Skip to main content

google_cloud_spanner/
transaction_runner.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::database_client::DatabaseClient;
16use crate::model::CommitResponse;
17use crate::model::request_options::Priority;
18use crate::model::transaction_options::IsolationLevel;
19use crate::model::transaction_options::read_write::ReadLockMode;
20use crate::read_only_transaction::BeginTransactionOption;
21use crate::read_write_transaction::{ReadWriteTransaction, ReadWriteTransactionBuilder};
22use crate::transaction_retry_policy::{
23    BasicTransactionRetryPolicy, TransactionRetryPolicy, backoff_if_aborted, is_aborted,
24};
25use google_cloud_gax::backoff_policy::BackoffPolicyArg;
26use google_cloud_gax::retry_policy::RetryPolicyArg;
27
28use std::time::Duration as StdDuration;
29use tokio::time::Instant;
30use wkt::Duration;
31
32/// A builder for a [TransactionRunner] for a read/write transaction.
33///
34/// # Example
35/// ```
36/// # use google_cloud_spanner::client::Spanner;
37/// # use google_cloud_spanner::statement::Statement;
38/// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
39/// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
40/// let runner = db_client.read_write_transaction().build().await?;
41///
42/// let result = runner.run(async |transaction| {
43///     let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
44///     transaction.execute_update(statement).await?;
45///     Ok(42)
46/// }).await?;
47/// # Ok(())
48/// # }
49/// ```
50///
51/// Spanner can abort any read/write transaction at any time. A [TransactionRunner]
52/// automatically retries aborted transactions according to the configured retry policy.
53pub struct TransactionRunnerBuilder {
54    builder: ReadWriteTransactionBuilder,
55    retry_policy: Box<dyn TransactionRetryPolicy>,
56    timeout: Option<StdDuration>,
57    begin_gax_options: Option<crate::RequestOptions>,
58    commit_gax_options: Option<crate::RequestOptions>,
59}
60
61impl TransactionRunnerBuilder {
62    pub(crate) fn new(client: DatabaseClient) -> Self {
63        Self {
64            builder: ReadWriteTransactionBuilder::new(client),
65            retry_policy: Box::new(BasicTransactionRetryPolicy::default()),
66            timeout: None,
67            begin_gax_options: None,
68            commit_gax_options: None,
69        }
70    }
71
72    /// Sets the timeout for the entire transaction.
73    ///
74    /// # Example
75    /// ```
76    /// # use google_cloud_spanner::client::Spanner;
77    /// # use std::time::Duration;
78    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
79    /// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
80    /// let runner = db_client.read_write_transaction()
81    ///     .with_transaction_timeout(Duration::from_secs(5))
82    ///     .build()
83    ///     .await?;
84    /// # Ok(())
85    /// # }
86    /// ```
87    ///
88    /// This timeout applies to the total time spent executing the transaction, including
89    /// all statements and automatic retries. Each individual RPC within the transaction
90    /// is automatically assigned a deadline derived from the remaining time of this
91    /// overall timeout.
92    pub fn with_transaction_timeout(mut self, timeout: StdDuration) -> Self {
93        self.timeout = Some(timeout);
94        self
95    }
96
97    /// Sets the per-attempt timeout for the BeginTransaction RPC.
98    ///
99    /// # Example
100    /// ```
101    /// # use google_cloud_spanner::client::Spanner;
102    /// # use std::time::Duration;
103    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
104    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
105    /// let runner = db_client.read_write_transaction()
106    ///     .with_begin_attempt_timeout(Duration::from_secs(5))
107    ///     .build()
108    ///     .await?;
109    /// # Ok(())
110    /// # }
111    /// ```
112    ///
113    /// Note: This timeout is only used if the transaction uses the `ExplicitBegin` transaction option.
114    pub fn with_begin_attempt_timeout(mut self, timeout: StdDuration) -> Self {
115        self.begin_gax_options
116            .get_or_insert_with(crate::RequestOptions::default)
117            .set_attempt_timeout(timeout);
118        self
119    }
120
121    /// Sets the retry policy for the BeginTransaction RPC.
122    ///
123    /// # Example
124    /// ```
125    /// # use google_cloud_spanner::client::Spanner;
126    /// # use google_cloud_gax::retry_policy::NeverRetry;
127    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
128    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
129    /// let runner = db_client.read_write_transaction()
130    ///     .with_begin_retry_policy(NeverRetry)
131    ///     .build()
132    ///     .await?;
133    /// # Ok(())
134    /// # }
135    /// ```
136    ///
137    /// Note: This policy is only used if the transaction uses the `ExplicitBegin` transaction option.
138    pub fn with_begin_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
139        self.begin_gax_options
140            .get_or_insert_with(crate::RequestOptions::default)
141            .set_retry_policy(policy);
142        self
143    }
144
145    /// Sets the backoff policy for the BeginTransaction RPC.
146    ///
147    /// # Example
148    /// ```
149    /// # use google_cloud_spanner::client::Spanner;
150    /// # use google_cloud_gax::exponential_backoff::ExponentialBackoff;
151    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
152    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
153    /// let runner = db_client.read_write_transaction()
154    ///     .with_begin_backoff_policy(ExponentialBackoff::default())
155    ///     .build()
156    ///     .await?;
157    /// # Ok(())
158    /// # }
159    /// ```
160    ///
161    /// Note: This policy is only used if the transaction uses the `ExplicitBegin` transaction option.
162    pub fn with_begin_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
163        self.begin_gax_options
164            .get_or_insert_with(crate::RequestOptions::default)
165            .set_backoff_policy(policy);
166        self
167    }
168
169    /// Sets the per-attempt timeout for the Commit RPC.
170    ///
171    /// # Example
172    /// ```
173    /// # use google_cloud_spanner::client::Spanner;
174    /// # use std::time::Duration;
175    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
176    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
177    /// let runner = db_client.read_write_transaction()
178    ///     .with_commit_attempt_timeout(Duration::from_secs(5))
179    ///     .build()
180    ///     .await?;
181    /// # Ok(())
182    /// # }
183    /// ```
184    pub fn with_commit_attempt_timeout(mut self, timeout: StdDuration) -> Self {
185        self.commit_gax_options
186            .get_or_insert_with(crate::RequestOptions::default)
187            .set_attempt_timeout(timeout);
188        self
189    }
190
191    /// Sets the retry policy for the Commit RPC.
192    ///
193    /// # Example
194    /// ```
195    /// # use google_cloud_spanner::client::Spanner;
196    /// # use google_cloud_gax::retry_policy::NeverRetry;
197    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
198    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
199    /// let runner = db_client.read_write_transaction()
200    ///     .with_commit_retry_policy(NeverRetry)
201    ///     .build()
202    ///     .await?;
203    /// # Ok(())
204    /// # }
205    /// ```
206    pub fn with_commit_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
207        self.commit_gax_options
208            .get_or_insert_with(crate::RequestOptions::default)
209            .set_retry_policy(policy);
210        self
211    }
212
213    /// Sets the backoff policy for the Commit RPC.
214    ///
215    /// # Example
216    /// ```
217    /// # use google_cloud_spanner::client::Spanner;
218    /// # use google_cloud_gax::exponential_backoff::ExponentialBackoff;
219    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
220    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
221    /// let runner = db_client.read_write_transaction()
222    ///     .with_commit_backoff_policy(ExponentialBackoff::default())
223    ///     .build()
224    ///     .await?;
225    /// # Ok(())
226    /// # }
227    /// ```
228    pub fn with_commit_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
229        self.commit_gax_options
230            .get_or_insert_with(crate::RequestOptions::default)
231            .set_backoff_policy(policy);
232        self
233    }
234
235    /// Sets the isolation level for the transaction.
236    ///
237    /// # Example
238    /// ```
239    /// # use google_cloud_spanner::client::Spanner;
240    /// # use google_cloud_spanner::model::transaction_options::IsolationLevel;
241    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
242    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
243    /// let runner = db_client
244    ///     .read_write_transaction()
245    ///     .set_isolation_level(IsolationLevel::Serializable)
246    ///     .build()
247    ///     .await?;
248    /// # Ok(())
249    /// # }
250    /// ```
251    ///
252    /// See also: <https://docs.cloud.google.com/spanner/docs/isolation-levels>
253    pub fn set_isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
254        self.builder = self.builder.set_isolation_level(isolation_level);
255        self
256    }
257
258    /// Sets the read lock mode for the transaction.
259    ///
260    /// # Example
261    /// ```
262    /// # use google_cloud_spanner::client::Spanner;
263    /// # use google_cloud_spanner::model::transaction_options::read_write::ReadLockMode;
264    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
265    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
266    /// let runner = db_client
267    ///     .read_write_transaction()
268    ///     .set_read_lock_mode(ReadLockMode::Pessimistic)
269    ///     .build()
270    ///     .await?;
271    /// # Ok(())
272    /// # }
273    /// ```
274    ///
275    /// See also: <https://docs.cloud.google.com/spanner/docs/concurrency-control>
276    pub fn set_read_lock_mode(mut self, read_lock_mode: ReadLockMode) -> Self {
277        self.builder = self.builder.set_read_lock_mode(read_lock_mode);
278        self
279    }
280
281    /// Sets the transaction tag for the transaction.
282    ///
283    /// # Example
284    /// ```
285    /// # use google_cloud_spanner::client::Spanner;
286    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
287    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
288    /// let runner = db_client.read_write_transaction()
289    ///     .set_transaction_tag("my-tag")
290    ///     .build()
291    ///     .await?;
292    /// # Ok(())
293    /// # }
294    /// ```
295    ///
296    /// The tag is applied to all statements executed within the transaction.
297    ///
298    /// See also: [Troubleshooting with tags](https://docs.cloud.google.com/spanner/docs/introspection/troubleshooting-with-tags)
299    pub fn set_transaction_tag(mut self, tag: impl Into<String>) -> Self {
300        self.builder = self.builder.set_transaction_tag(tag);
301        self
302    }
303
304    /// Sets the option for how to start a transaction.
305    ///
306    /// # Example
307    /// ```
308    /// # use google_cloud_spanner::client::Spanner;
309    /// # use google_cloud_spanner::transaction::BeginTransactionOption;
310    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
311    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
312    /// let runner = db_client
313    ///     .read_write_transaction()
314    ///     .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
315    ///     .build()
316    ///     .await?;
317    /// # Ok(())
318    /// # }
319    /// ```
320    ///
321    /// By default, the Spanner client will inline the `BeginTransaction` call with the first query
322    /// or DML statement in the transaction. This reduces the number of round-trips to Spanner that
323    /// are needed for a transaction. Setting this option to `ExplicitBegin` can be beneficial for
324    /// specific transaction shapes:
325    ///
326    /// 1. When the transaction executes multiple parallel queries at the start of the transaction.
327    ///    Only one query can include a `BeginTransaction` option, and all other queries must wait for
328    ///    the first query to return the first result before they can proceed to execute. A
329    ///    `BeginTransaction` RPC will quickly return a transaction ID and allow all queries to start
330    ///    execution in parallel once the transaction ID has been returned.
331    /// 2. When the first statement in the transaction could fail. If the statement fails, then it
332    ///    will also not start a transaction and return a transaction ID. The transaction will then
333    ///    fall back to executing a `BeginTransaction` RPC and retry the first statement.
334    ///
335    /// Default is `BeginTransactionOption::InlineBegin`.
336    pub fn with_begin_transaction_option(mut self, option: BeginTransactionOption) -> Self {
337        self.builder = self.builder.with_begin_transaction_option(option);
338        self
339    }
340
341    /// Sets the RPC priority to use for the commit of this transaction.
342    ///
343    /// # Example
344    /// ```
345    /// # use google_cloud_spanner::client::Spanner;
346    /// # use google_cloud_spanner::model::request_options::Priority;
347    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
348    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
349    /// let runner = db_client
350    ///     .read_write_transaction()
351    ///     .set_commit_priority(Priority::Low)
352    ///     .build()
353    ///     .await?;
354    /// # Ok(())
355    /// # }
356    /// ```
357    pub fn set_commit_priority(mut self, priority: Priority) -> Self {
358        self.builder = self.builder.set_commit_priority(priority);
359        self
360    }
361
362    /// Sets the maximum commit delay for the transaction.
363    ///
364    /// # Example
365    /// ```
366    /// # use google_cloud_spanner::client::Spanner;
367    /// # use wkt::Duration;
368    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
369    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
370    /// let runner = db_client
371    ///     .read_write_transaction()
372    ///     .set_max_commit_delay(Duration::try_from("0.2s").unwrap())
373    ///     .build()
374    ///     .await?;
375    /// # Ok(())
376    /// # }
377    /// ```
378    ///
379    /// This option allows you to specify the maximum amount of time Spanner can
380    /// adjust the commit timestamp of the transaction to allow for commit batching.
381    /// Increasing this value can increase throughput at the expense of latency.
382    /// The value must be between 0 and 500 milliseconds. If not set, or set to 0,
383    /// Spanner does not delay the commit.
384    pub fn set_max_commit_delay(mut self, delay: Duration) -> Self {
385        self.builder = self.builder.set_max_commit_delay(delay);
386        self
387    }
388
389    /// Sets whether to exclude the transaction from change streams.
390    ///
391    /// # Example
392    /// ```
393    /// # use google_cloud_spanner::client::Spanner;
394    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
395    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
396    /// let runner = db_client.read_write_transaction()
397    ///     .set_exclude_txn_from_change_streams(true)
398    ///     .build()
399    ///     .await?;
400    /// # Ok(())
401    /// # }
402    /// ```
403    ///
404    /// When set to `true`, it prevents modifications from this transaction from being tracked in change streams.
405    /// Note that this only affects change streams that have been created with the DDL option `allow_txn_exclusion = true`.
406    /// If `allow_txn_exclusion` is not set or set to `false` for a change stream, updates made within this transaction
407    /// are recorded in that change stream regardless of this setting.
408    ///
409    /// When set to `false` or not specified, modifications from this transaction are recorded in all change streams
410    /// tracking columns modified by this transaction.
411    pub fn set_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
412        self.builder = self.builder.set_exclude_txn_from_change_streams(exclude);
413        self
414    }
415
416    /// Sets whether to return commit stats for the transaction.
417    ///
418    /// # Example
419    /// ```
420    /// # use google_cloud_spanner::client::Spanner;
421    /// # use google_cloud_spanner::statement::Statement;
422    /// # async fn run_tx(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
423    /// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
424    /// let runner = db_client.read_write_transaction()
425    ///     .set_return_commit_stats(true)
426    ///     .build()
427    ///     .await?;
428    ///
429    /// let result = runner.run(async |transaction| {
430    ///     let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
431    ///     transaction.execute_update(statement).await?;
432    ///     Ok(42)
433    /// }).await?;
434    ///
435    /// if let Some(stats) = result.commit_response.commit_stats {
436    ///     println!("Mutation count: {}", stats.mutation_count);
437    /// }
438    /// # Ok(())
439    /// # }
440    /// ```
441    ///
442    /// See also: <https://docs.cloud.google.com/spanner/docs/commit-statistics>
443    pub fn set_return_commit_stats(mut self, return_stats: bool) -> Self {
444        self.builder = self.builder.set_return_commit_stats(return_stats);
445        self
446    }
447
448    /// Sets the retry policy for the transaction.
449    ///
450    /// # Example
451    /// ```
452    /// # use std::time::Duration;
453    /// # use google_cloud_spanner::client::Spanner;
454    /// # use google_cloud_spanner::transaction::BasicTransactionRetryPolicy;
455    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
456    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
457    ///
458    /// let retry_policy = BasicTransactionRetryPolicy::new()
459    ///     .with_max_attempts(5)
460    ///     .with_total_timeout(Duration::from_secs(60));
461    ///
462    /// let runner = db_client
463    ///     .read_write_transaction()
464    ///     .with_retry_policy(retry_policy)
465    ///     .build()
466    ///     .await?;
467    /// # Ok(())
468    /// # }
469    /// ```
470    pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(mut self, policy: P) -> Self {
471        self.retry_policy = Box::new(policy);
472        self
473    }
474
475    /// Builds a [TransactionRunner] for a read/write transaction.
476    ///
477    /// # Example
478    /// ```
479    /// # use google_cloud_spanner::client::Spanner;
480    /// # use google_cloud_spanner::statement::Statement;
481    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
482    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
483    /// let runner = db_client.read_write_transaction().build().await?;
484    ///
485    /// let result = runner.run(async |transaction| {
486    ///     let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
487    ///     transaction.execute_update(statement).await?;
488    ///     Ok(42)
489    /// }).await?;
490    /// # Ok(())
491    /// # }
492    /// ```
493    pub async fn build(self) -> crate::Result<TransactionRunner> {
494        Ok(TransactionRunner {
495            builder: self
496                .builder
497                .with_begin_transaction_request_options(self.begin_gax_options)
498                .with_commit_request_options(self.commit_gax_options),
499            retry_policy: self.retry_policy,
500            timeout: self.timeout,
501        })
502    }
503}
504
505/// Result of a read/write transaction executed by a [TransactionRunner].
506#[derive(Debug)]
507#[non_exhaustive]
508pub struct TransactionResult<T> {
509    /// The result returned by the closure executed within the transaction.
510    pub result: T,
511    /// The response from the commit RPC.
512    pub commit_response: CommitResponse,
513}
514
515/// A runner for read/write transactions. Aborted transactions are automatically retried.
516pub struct TransactionRunner {
517    builder: ReadWriteTransactionBuilder,
518    retry_policy: Box<dyn TransactionRetryPolicy>,
519    timeout: Option<StdDuration>,
520}
521
522impl TransactionRunner {
523    /// Runs the provided closure within the context of a read/write transaction.
524    ///
525    /// # Example
526    /// ```
527    /// # use google_cloud_spanner::client::Spanner;
528    /// # use google_cloud_spanner::statement::Statement;
529    /// # async fn run_tx(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
530    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
531    /// let runner = db_client.read_write_transaction().build().await?;
532    ///
533    /// let result = runner.run(async |transaction| {
534    ///     let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
535    ///     transaction.execute_update(statement).await?;
536    ///     Ok(42)
537    /// }).await?;
538    /// # Ok(())
539    /// # }
540    /// ```
541    ///
542    /// If the transaction is aborted by Spanner, the closure will be retried
543    /// automatically according to the configured `TransactionRetryPolicy`.
544    ///
545    /// The transaction is automatically committed if the closure returns `Ok`.
546    /// If the closure returns `Err`, the transaction will be rolled back and
547    /// the error will be propagated.
548    pub async fn run<T, F>(mut self, mut work: F) -> crate::Result<TransactionResult<T>>
549    where
550        F: std::ops::AsyncFnMut(ReadWriteTransaction) -> crate::Result<T>,
551    {
552        let start_time = Instant::now();
553        let mut attempts: u32 = 0;
554        let backoff = crate::transaction_retry_policy::default_retry_backoff();
555        let deadline = self.timeout.map(|t| start_time + t);
556
557        loop {
558            attempts += 1;
559
560            let mut current_tx_id = None;
561            let attempt_result = async {
562                let transaction = self.builder.clone().build(deadline).await?;
563
564                let result = match work(transaction.clone()).await {
565                    Ok(res) => res,
566                    Err(e) => {
567                        // We call `get_id_no_wait` here to retrieve the transaction ID without waiting.
568                        // We do not require the transaction ID to be unconditionally available here;
569                        // we only wish to capture it if the transaction successfully started prior to
570                        // failing, so it can be used as the previous transaction ID if the transaction
571                        // was aborted.
572                        let id = transaction
573                            .context
574                            .transaction_selector
575                            .get_id_no_wait()
576                            .ok()
577                            .flatten();
578                        // Rollback if the closure failed and it was not an Aborted error.
579                        if !is_aborted(&e) {
580                            let _ = transaction.rollback().await;
581                        }
582                        current_tx_id = id;
583                        return Err(e);
584                    }
585                };
586
587                // `commit()` consumes `transaction`. If the commit RPC fails with an Aborted error,
588                // we still need access to the transaction ID so we can provide it as `previous_transaction_id`
589                // on retry. Cloning only `transaction_selector` preserves access to the internal state efficiently.
590                let selector = transaction.context.transaction_selector.clone();
591                let commit_result = transaction.commit().await;
592                current_tx_id = selector.get_id_no_wait().ok().flatten();
593                let commit_response = commit_result?;
594                Ok::<TransactionResult<T>, crate::Error>(TransactionResult {
595                    result,
596                    commit_response,
597                })
598            }
599            .await;
600
601            match attempt_result {
602                Ok(res) => return Ok(res),
603                Err(e) => {
604                    if is_aborted(&e) {
605                        let current_tx_id = current_tx_id.clone();
606                        self.builder = self.builder.set_previous_transaction_id(current_tx_id);
607                    }
608
609                    backoff_if_aborted(
610                        e,
611                        attempts,
612                        start_time.elapsed(),
613                        self.retry_policy.as_ref(),
614                        &backoff,
615                    )
616                    .await?;
617                }
618            }
619        }
620    }
621}
622
623#[cfg(test)]
624mod tests {
625    use super::*;
626    use crate::mutation::Mutation;
627    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
628    use crate::transaction_retry_policy::tests::create_aborted_status;
629    use gaxi::grpc::tonic;
630    use google_cloud_gax::exponential_backoff::ExponentialBackoff;
631    use google_cloud_gax::retry_policy::NeverRetry;
632    use google_cloud_test_macros::tokio_test_no_panics;
633    use prost_types::value::Kind;
634    use spanner_grpc_mock::google::spanner::v1;
635    use spanner_grpc_mock::google::spanner::v1::CommitResponse;
636    use spanner_grpc_mock::google::spanner::v1::commit_request::Transaction as CommitTransaction;
637    use spanner_grpc_mock::google::spanner::v1::commit_response::CommitStats;
638    use spanner_grpc_mock::google::spanner::v1::mutation::Operation;
639    use spanner_grpc_mock::google::spanner::v1::transaction_options::Mode;
640    use spanner_grpc_mock::google::spanner::v1::transaction_selector::Selector as ProtoSelector;
641    use std::sync::Mutex;
642    use std::sync::mpsc::channel as std_channel;
643    use std::time::Duration as StdDuration;
644    use std::time::Duration as StdTimeDuration;
645    use tokio::sync::oneshot::channel as oneshot_channel;
646
647    fn expect_begin_transaction(
648        mock: &mut spanner_grpc_mock::MockSpanner,
649        times: usize,
650        transaction_id: Vec<u8>,
651    ) {
652        mock.expect_begin_transaction()
653            .times(times)
654            .returning(move |req| {
655                let req = req.into_inner();
656                assert_eq!(
657                    req.session,
658                    "projects/p/instances/i/databases/d/sessions/123"
659                );
660                Ok(tonic::Response::new(v1::Transaction {
661                    id: transaction_id.clone(),
662                    ..Default::default()
663                }))
664            });
665    }
666
667    async fn execute_test_runner(
668        mock: spanner_grpc_mock::MockSpanner,
669        begin_transaction_option: BeginTransactionOption,
670    ) -> Result<i64, crate::Error> {
671        let (db_client, server) = setup_db_client(mock).await;
672        let runner = TransactionRunnerBuilder::new(db_client)
673            .with_begin_transaction_option(begin_transaction_option)
674            .build()
675            .await
676            .unwrap();
677        tokio::select! {
678            res = runner.run(async |tx| {
679                let count = tx.execute_update("UPDATE Users SET active = true").await?;
680                Ok(count)
681            }) => res.map(|r| r.result),
682            err = server => panic!("Mock server panicked or terminated unexpectedly: {:?}", err),
683        }
684    }
685
686    fn commit_response() -> Result<tonic::Response<v1::CommitResponse>, tonic::Status> {
687        Ok(tonic::Response::new(v1::CommitResponse {
688            commit_timestamp: Some(prost_types::Timestamp {
689                seconds: 123456789,
690                nanos: 0,
691            }),
692            ..Default::default()
693        }))
694    }
695
696    fn row_count_exact_response(
697        count: i64,
698    ) -> Result<tonic::Response<v1::ResultSet>, tonic::Status> {
699        Ok(tonic::Response::new(v1::ResultSet {
700            stats: Some(v1::ResultSetStats {
701                row_count: Some(v1::result_set_stats::RowCount::RowCountExact(count)),
702                ..Default::default()
703            }),
704            ..Default::default()
705        }))
706    }
707
708    #[test]
709    fn auto_traits() {
710        static_assertions::assert_impl_all!(TransactionRunnerBuilder: Send, Sync);
711        static_assertions::assert_impl_all!(TransactionRunner: Send, Sync);
712    }
713
714    #[tokio_test_no_panics]
715    async fn execute_run_success_explicit() {
716        run_success(BeginTransactionOption::ExplicitBegin).await;
717    }
718
719    #[tokio_test_no_panics]
720    async fn execute_run_success_inline() {
721        run_success(BeginTransactionOption::InlineBegin).await;
722    }
723
724    async fn run_success(begin_transaction_option: BeginTransactionOption) {
725        let mut mock = create_session_mock();
726
727        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
728            expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);
729        }
730
731        mock.expect_execute_sql().once().returning(move |req| {
732            let req = req.into_inner();
733            assert_eq!(req.sql, "UPDATE Users SET active = true");
734            assert_eq!(req.seqno, 1);
735
736            if begin_transaction_option == BeginTransactionOption::InlineBegin {
737                let transaction = req
738                    .transaction
739                    .as_ref()
740                    .expect("transaction options required for inline begin");
741                let selector = transaction.selector.as_ref().expect("selector required");
742                assert!(matches!(
743                    selector,
744                    v1::transaction_selector::Selector::Begin(_)
745                ));
746            }
747
748            let mut metadata = v1::ResultSetMetadata {
749                ..Default::default()
750            };
751            if begin_transaction_option == BeginTransactionOption::InlineBegin {
752                metadata.transaction = Some(v1::Transaction {
753                    id: vec![1, 2, 3],
754                    ..Default::default()
755                });
756            }
757
758            Ok(tonic::Response::new(v1::ResultSet {
759                metadata: Some(metadata),
760                stats: Some(v1::ResultSetStats {
761                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
762                    ..Default::default()
763                }),
764                ..Default::default()
765            }))
766        });
767
768        mock.expect_commit().once().returning(|req| {
769            let req = req.into_inner();
770            assert_eq!(
771                req.transaction,
772                Some(v1::commit_request::Transaction::TransactionId(vec![
773                    1, 2, 3
774                ]))
775            );
776            commit_response()
777        });
778
779        let res = execute_test_runner(mock, begin_transaction_option)
780            .await
781            .unwrap();
782        assert_eq!(res, 1);
783    }
784
785    #[tokio_test_no_panics]
786    async fn execute_run_success_with_commit_stats_explicit() {
787        run_success_with_commit_stats(BeginTransactionOption::ExplicitBegin).await;
788    }
789
790    #[tokio_test_no_panics]
791    async fn execute_run_success_with_commit_stats_inline() {
792        run_success_with_commit_stats(BeginTransactionOption::InlineBegin).await;
793    }
794
795    async fn run_success_with_commit_stats(begin_transaction_option: BeginTransactionOption) {
796        let mut mock = create_session_mock();
797
798        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
799            expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);
800        }
801
802        mock.expect_execute_sql().once().returning(move |req| {
803            let req = req.into_inner();
804            assert_eq!(req.sql, "UPDATE Users SET active = true");
805
806            if begin_transaction_option == BeginTransactionOption::InlineBegin {
807                let transaction = req
808                    .transaction
809                    .as_ref()
810                    .expect("transaction options required for inline begin");
811                let selector = transaction.selector.as_ref().expect("selector required");
812                assert!(matches!(
813                    selector,
814                    v1::transaction_selector::Selector::Begin(_)
815                ));
816            }
817
818            let mut metadata = v1::ResultSetMetadata {
819                ..Default::default()
820            };
821            if begin_transaction_option == BeginTransactionOption::InlineBegin {
822                metadata.transaction = Some(v1::Transaction {
823                    id: vec![1, 2, 3],
824                    ..Default::default()
825                });
826            }
827
828            Ok(tonic::Response::new(v1::ResultSet {
829                metadata: Some(metadata),
830                stats: Some(v1::ResultSetStats {
831                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
832                    ..Default::default()
833                }),
834                ..Default::default()
835            }))
836        });
837
838        mock.expect_commit().once().returning(|req| {
839            let req = req.into_inner();
840            assert!(req.return_commit_stats);
841            Ok(tonic::Response::new(CommitResponse {
842                commit_timestamp: Some(prost_types::Timestamp {
843                    seconds: 123456789,
844                    nanos: 0,
845                }),
846                commit_stats: Some(CommitStats { mutation_count: 5 }),
847                ..Default::default()
848            }))
849        });
850
851        let (db_client, _server) = setup_db_client(mock).await;
852        let runner = TransactionRunnerBuilder::new(db_client)
853            .set_return_commit_stats(true)
854            .with_begin_transaction_option(begin_transaction_option)
855            .build()
856            .await
857            .unwrap();
858
859        let res = runner
860            .run(async |tx| {
861                let count = tx.execute_update("UPDATE Users SET active = true").await?;
862                Ok(count)
863            })
864            .await
865            .unwrap();
866
867        assert_eq!(res.result, 1);
868        assert!(res.commit_response.commit_stats.is_some());
869        assert_eq!(
870            res.commit_response
871                .commit_stats
872                .expect("Commit stats should be present")
873                .mutation_count,
874            5
875        );
876    }
877
878    #[tokio_test_no_panics]
879    async fn execute_run_with_aborted_retry_explicit() -> anyhow::Result<()> {
880        run_with_aborted_retry(BeginTransactionOption::ExplicitBegin).await
881    }
882
883    #[tokio_test_no_panics]
884    async fn execute_run_with_aborted_retry_inline() -> anyhow::Result<()> {
885        run_with_aborted_retry(BeginTransactionOption::InlineBegin).await
886    }
887
888    async fn run_with_aborted_retry(
889        begin_transaction_option: BeginTransactionOption,
890    ) -> anyhow::Result<()> {
891        let mut mock = create_session_mock();
892        let mut seq = mockall::Sequence::new();
893
894        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
895            mock.expect_begin_transaction()
896                .once()
897                .in_sequence(&mut seq)
898                .returning(move |req| {
899                    let req = req.into_inner();
900                    assert_eq!(
901                        req.session,
902                        "projects/p/instances/i/databases/d/sessions/123"
903                    );
904                    Ok(tonic::Response::new(v1::Transaction {
905                        id: vec![9, 9, 9],
906                        ..Default::default()
907                    }))
908                });
909        }
910
911        if begin_transaction_option == BeginTransactionOption::InlineBegin {
912            // Attempt 1: execute_sql fails with Aborted
913            mock.expect_execute_sql()
914                .once()
915                .in_sequence(&mut seq)
916                .returning(move |req| {
917                    let req = req.into_inner();
918                    let transaction = req
919                        .transaction
920                        .as_ref()
921                        .expect("transaction options required for inline begin");
922                    let selector = transaction.selector.as_ref().expect("selector required");
923                    assert!(matches!(
924                        selector,
925                        v1::transaction_selector::Selector::Begin(_)
926                    ));
927
928                    Err(create_aborted_status(std::time::Duration::from_nanos(1)))
929                });
930        } else {
931            mock.expect_execute_sql()
932                .once()
933                .in_sequence(&mut seq)
934                .returning(move |_req| {
935                    Err(create_aborted_status(std::time::Duration::from_nanos(1)))
936                });
937        }
938
939        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
940            mock.expect_begin_transaction()
941                .once()
942                .in_sequence(&mut seq)
943                .returning(move |req| {
944                    let req = req.into_inner();
945                    assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
946
947                    let options = req.options.as_ref().expect("options required on retry");
948                    let read_write = options.mode.as_ref().expect("mode required on retry");
949                    match read_write {
950                        Mode::ReadWrite(rw) => {
951                            assert_eq!(rw.multiplexed_session_previous_transaction_id, vec![9, 9, 9], "previous_transaction_id should be set to the ID of the aborted transaction");
952                        }
953                        _ => panic!("Expected ReadWrite mode"),
954                    }
955
956                    Ok(tonic::Response::new(v1::Transaction {
957                        id: vec![8, 8, 8],
958                        ..Default::default()
959                    }))
960                });
961        }
962
963        // Attempt 2 (retry of closure)
964        mock.expect_execute_sql()
965            .once()
966            .in_sequence(&mut seq)
967            .returning(move |req| {
968                if begin_transaction_option == BeginTransactionOption::InlineBegin {
969                    let req = req.into_inner();
970                    let transaction = req
971                        .transaction
972                        .as_ref()
973                        .expect("transaction options required for inline begin");
974                    let selector = transaction.selector.as_ref().expect("selector required");
975                    assert!(matches!(
976                        selector,
977                        v1::transaction_selector::Selector::Begin(_)
978                    ));
979
980                    let options = match selector {
981                        v1::transaction_selector::Selector::Begin(o) => o,
982                        _ => panic!("Expected Begin"),
983                    };
984                    let read_write = options.mode.as_ref().expect("mode required");
985                    match read_write {
986                        Mode::ReadWrite(rw) => {
987                            assert!(rw.multiplexed_session_previous_transaction_id.is_empty());
988                        }
989                        _ => panic!("Expected ReadWrite"),
990                    }
991                }
992
993                let mut metadata = v1::ResultSetMetadata {
994                    ..Default::default()
995                };
996                if begin_transaction_option == BeginTransactionOption::InlineBegin {
997                    metadata.transaction = Some(v1::Transaction {
998                        id: vec![8, 8, 8],
999                        ..Default::default()
1000                    });
1001                }
1002
1003                Ok(tonic::Response::new(v1::ResultSet {
1004                    metadata: Some(metadata),
1005                    stats: Some(v1::ResultSetStats {
1006                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)),
1007                        ..Default::default()
1008                    }),
1009                    ..Default::default()
1010                }))
1011            });
1012
1013        mock.expect_commit()
1014            .once()
1015            .returning(|_req| commit_response());
1016
1017        let res = execute_test_runner(mock, begin_transaction_option)
1018            .await
1019            .expect("runner should succeed");
1020        assert_eq!(res, 5);
1021        Ok(())
1022    }
1023
1024    #[tokio_test_no_panics]
1025    async fn execute_run_query_stream_with_aborted_retry_explicit() -> anyhow::Result<()> {
1026        run_query_stream_with_aborted_retry(BeginTransactionOption::ExplicitBegin).await
1027    }
1028
1029    #[tokio_test_no_panics]
1030    async fn execute_run_query_stream_with_aborted_retry_inline() -> anyhow::Result<()> {
1031        run_query_stream_with_aborted_retry(BeginTransactionOption::InlineBegin).await
1032    }
1033
1034    async fn run_query_stream_with_aborted_retry(
1035        begin_transaction_option: BeginTransactionOption,
1036    ) -> anyhow::Result<()> {
1037        let mut mock = create_session_mock();
1038        let mut seq = mockall::Sequence::new();
1039
1040        let tx_id_1 = vec![9, 9, 9];
1041        let tx_id_2 = vec![8, 8, 8];
1042
1043        let tx_id_1_c1 = tx_id_1.clone();
1044        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1045            mock.expect_begin_transaction()
1046                .once()
1047                .in_sequence(&mut seq)
1048                .returning(move |_| {
1049                    Ok(tonic::Response::new(v1::Transaction {
1050                        id: tx_id_1_c1.clone(),
1051                        ..Default::default()
1052                    }))
1053                });
1054        }
1055
1056        let tx_id_1_c2 = tx_id_1.clone();
1057        mock.expect_execute_streaming_sql()
1058            .once()
1059            .in_sequence(&mut seq)
1060            .returning(move |req| {
1061                let req = req.into_inner();
1062                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1063                    let transaction = req
1064                        .transaction
1065                        .as_ref()
1066                        .expect("transaction options required for inline begin");
1067                    let selector = transaction.selector.as_ref().expect("selector required");
1068                    assert!(matches!(
1069                        selector,
1070                        v1::transaction_selector::Selector::Begin(_)
1071                    ));
1072                }
1073
1074                let mut rs = v1::PartialResultSet {
1075                    metadata: Some(v1::ResultSetMetadata {
1076                        row_type: Some(v1::StructType {
1077                            fields: vec![Default::default()],
1078                        }),
1079                        ..Default::default()
1080                    }),
1081                    values: vec![prost_types::Value {
1082                        kind: Some(prost_types::value::Kind::StringValue("1".to_string())),
1083                    }],
1084                    resume_token: b"token1".to_vec(),
1085                    ..Default::default()
1086                };
1087
1088                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1089                    rs.metadata.as_mut().unwrap().transaction = Some(v1::Transaction {
1090                        id: tx_id_1_c2.clone(),
1091                        ..Default::default()
1092                    });
1093                }
1094
1095                let (tx, rx) = tokio::sync::mpsc::channel(2);
1096                tx.try_send(Ok(rs)).unwrap();
1097                tx.try_send(Err(tonic::Status::new(tonic::Code::Aborted, "aborted")))
1098                    .unwrap();
1099                Ok(tonic::Response::from(rx))
1100            });
1101
1102        let tx_id_2_c1 = tx_id_2.clone();
1103        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1104            mock.expect_begin_transaction()
1105                .once()
1106                .in_sequence(&mut seq)
1107                .returning(move |req| {
1108                    let req = req.into_inner();
1109                    let options = req.options.as_ref().expect("options required on retry");
1110                    let read_write = options.mode.as_ref().expect("mode required on retry");
1111                    match read_write {
1112                        Mode::ReadWrite(rw) => {
1113                            assert_eq!(
1114                                rw.multiplexed_session_previous_transaction_id,
1115                                vec![9, 9, 9]
1116                            );
1117                        }
1118                        _ => panic!("Expected ReadWrite mode"),
1119                    }
1120
1121                    Ok(tonic::Response::new(v1::Transaction {
1122                        id: tx_id_2_c1.clone(),
1123                        ..Default::default()
1124                    }))
1125                });
1126        }
1127
1128        let tx_id_2_c2 = tx_id_2.clone();
1129        mock.expect_execute_streaming_sql()
1130            .once()
1131            .in_sequence(&mut seq)
1132            .returning(move |req| {
1133                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1134                    let req = req.into_inner();
1135                    let transaction = req
1136                        .transaction
1137                        .as_ref()
1138                        .expect("transaction options required for inline begin");
1139                    let selector = transaction.selector.as_ref().expect("selector required");
1140                    assert!(matches!(
1141                        selector,
1142                        v1::transaction_selector::Selector::Begin(_)
1143                    ));
1144
1145                    let options = match selector {
1146                        v1::transaction_selector::Selector::Begin(o) => o,
1147                        _ => panic!("Expected Begin"),
1148                    };
1149                    let read_write = options.mode.as_ref().expect("mode required");
1150                    match read_write {
1151                        Mode::ReadWrite(rw) => {
1152                            assert_eq!(
1153                                rw.multiplexed_session_previous_transaction_id,
1154                                vec![9, 9, 9]
1155                            );
1156                        }
1157                        _ => panic!("Expected ReadWrite"),
1158                    }
1159                }
1160
1161                let mut rs = v1::PartialResultSet {
1162                    metadata: Some(v1::ResultSetMetadata {
1163                        row_type: Some(v1::StructType {
1164                            fields: vec![Default::default()],
1165                        }),
1166                        ..Default::default()
1167                    }),
1168                    values: vec![prost_types::Value {
1169                        kind: Some(prost_types::value::Kind::StringValue("1".to_string())),
1170                    }],
1171                    last: true,
1172                    ..Default::default()
1173                };
1174
1175                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1176                    rs.metadata.as_mut().unwrap().transaction = Some(v1::Transaction {
1177                        id: tx_id_2_c2.clone(),
1178                        ..Default::default()
1179                    });
1180                }
1181
1182                let (tx, rx) = tokio::sync::mpsc::channel(2);
1183                tx.try_send(Ok(rs)).unwrap();
1184                Ok(tonic::Response::from(rx))
1185            });
1186
1187        mock.expect_commit()
1188            .once()
1189            .returning(|_req| commit_response());
1190
1191        let (db_client, _server) = setup_db_client(mock).await;
1192        let runner = TransactionRunnerBuilder::new(db_client)
1193            .with_begin_transaction_option(begin_transaction_option)
1194            .build()
1195            .await?;
1196
1197        let mut attempt_counter = 0;
1198        let res = runner
1199            .run(async |tx| {
1200                attempt_counter += 1;
1201                let mut rs = tx.execute_query("SELECT 1").await?;
1202                let mut last_val = None;
1203                while let Some(row_res) = rs.next().await {
1204                    let row = row_res?;
1205                    last_val = Some(row.raw_values()[0].as_string().to_string());
1206                }
1207                Ok(last_val.unwrap())
1208            })
1209            .await?;
1210
1211        assert_eq!(res.result, "1");
1212        assert_eq!(attempt_counter, 2);
1213        Ok(())
1214    }
1215
1216    #[tokio_test_no_panics]
1217    async fn execute_run_with_non_aborted_error_explicit() {
1218        run_with_non_aborted_error(BeginTransactionOption::ExplicitBegin).await;
1219    }
1220
1221    #[tokio_test_no_panics]
1222    async fn execute_run_with_non_aborted_error_inline() {
1223        run_with_non_aborted_error(BeginTransactionOption::InlineBegin).await;
1224    }
1225
1226    async fn run_with_non_aborted_error(begin_transaction_option: BeginTransactionOption) {
1227        let mut mock = create_session_mock();
1228
1229        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1230            expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]);
1231        }
1232
1233        // Let execute_sql return an error to trigger a rollback.
1234        mock.expect_execute_sql().once().returning(move |_req| {
1235            Err(tonic::Status::new(
1236                tonic::Code::PermissionDenied,
1237                "permission denied",
1238            ))
1239        });
1240
1241        if begin_transaction_option == BeginTransactionOption::InlineBegin {
1242            expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]);
1243            mock.expect_execute_sql().once().returning(move |_req| {
1244                Err(tonic::Status::new(
1245                    tonic::Code::PermissionDenied,
1246                    "permission denied",
1247                ))
1248            });
1249        }
1250
1251        // Must explicitly trigger rollback
1252        mock.expect_rollback()
1253            .once()
1254            .returning(|_req| Ok(tonic::Response::new(())));
1255
1256        let res = execute_test_runner(mock, begin_transaction_option).await;
1257
1258        assert!(res.is_err());
1259        let err = res.unwrap_err();
1260        if let Some(status) = err.status() {
1261            assert_eq!(
1262                status.code,
1263                google_cloud_gax::error::rpc::Code::PermissionDenied
1264            );
1265        } else {
1266            panic!("Expected GRPC error");
1267        }
1268    }
1269
1270    #[tokio_test_no_panics]
1271    async fn execute_run_with_non_aborted_error_and_rollback_fails_explicit() {
1272        run_with_non_aborted_error_and_rollback_fails(BeginTransactionOption::ExplicitBegin).await;
1273    }
1274
1275    #[tokio_test_no_panics]
1276    async fn execute_run_with_non_aborted_error_and_rollback_fails_inline() {
1277        run_with_non_aborted_error_and_rollback_fails(BeginTransactionOption::InlineBegin).await;
1278    }
1279
1280    async fn run_with_non_aborted_error_and_rollback_fails(
1281        begin_transaction_option: BeginTransactionOption,
1282    ) {
1283        let mut mock = create_session_mock();
1284
1285        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1286            expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]);
1287        }
1288
1289        // Let execute_sql return an error to trigger a rollback.
1290        mock.expect_execute_sql().once().returning(move |_req| {
1291            Err(tonic::Status::new(
1292                tonic::Code::PermissionDenied,
1293                "permission denied",
1294            ))
1295        });
1296
1297        if begin_transaction_option == BeginTransactionOption::InlineBegin {
1298            expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]);
1299            mock.expect_execute_sql().once().returning(move |_req| {
1300                Err(tonic::Status::new(
1301                    tonic::Code::PermissionDenied,
1302                    "permission denied",
1303                ))
1304            });
1305        }
1306
1307        // Force the rollback itself to fail as well
1308        mock.expect_rollback()
1309            .once()
1310            .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "rollback failed")));
1311
1312        let res = execute_test_runner(mock, begin_transaction_option).await;
1313
1314        // Verify the user unequivocally receives the PRIMARY original error
1315        assert!(res.is_err());
1316        let err = res.unwrap_err();
1317        if let Some(status) = err.status() {
1318            assert_eq!(
1319                status.code,
1320                google_cloud_gax::error::rpc::Code::PermissionDenied
1321            );
1322        } else {
1323            panic!("Expected GRPC error");
1324        }
1325    }
1326
1327    #[tokio_test_no_panics]
1328    async fn execute_run_commit_aborted_retry_explicit() {
1329        run_commit_aborted_retry(BeginTransactionOption::ExplicitBegin).await;
1330    }
1331
1332    #[tokio_test_no_panics]
1333    async fn execute_run_commit_aborted_retry_inline() {
1334        run_commit_aborted_retry(BeginTransactionOption::InlineBegin).await;
1335    }
1336
1337    async fn run_commit_aborted_retry(begin_transaction_option: BeginTransactionOption) {
1338        let mut mock = create_session_mock();
1339
1340        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1341            expect_begin_transaction(&mut mock, 2, vec![9, 9, 9]);
1342        }
1343
1344        let mut attempt = 0;
1345        mock.expect_execute_sql().times(2).returning(move |req| {
1346            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1347                let req = req.into_inner();
1348                let transaction = req
1349                    .transaction
1350                    .as_ref()
1351                    .expect("transaction options required for inline begin");
1352                let selector = transaction.selector.as_ref().expect("selector required");
1353                assert!(matches!(
1354                    selector,
1355                    v1::transaction_selector::Selector::Begin(_)
1356                ));
1357
1358                attempt += 1;
1359                if attempt == 2 {
1360                    let options = match selector {
1361                        v1::transaction_selector::Selector::Begin(o) => o,
1362                        _ => panic!("Expected Begin"),
1363                    };
1364                    let read_write = options.mode.as_ref().expect("mode required");
1365                    match read_write {
1366                        Mode::ReadWrite(rw) => {
1367                            assert_eq!(
1368                                rw.multiplexed_session_previous_transaction_id,
1369                                vec![9, 9, 9]
1370                            );
1371                        }
1372                        _ => panic!("Expected ReadWrite"),
1373                    }
1374                }
1375
1376                let mut metadata = v1::ResultSetMetadata {
1377                    ..Default::default()
1378                };
1379                metadata.transaction = Some(v1::Transaction {
1380                    id: vec![9, 9, 9],
1381                    ..Default::default()
1382                });
1383
1384                return Ok(tonic::Response::new(v1::ResultSet {
1385                    metadata: Some(metadata),
1386                    stats: Some(v1::ResultSetStats {
1387                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)),
1388                        ..Default::default()
1389                    }),
1390                    ..Default::default()
1391                }));
1392            }
1393            row_count_exact_response(5)
1394        });
1395
1396        let mut commit_attempt = 0;
1397        mock.expect_commit().times(2).returning(move |_req| {
1398            commit_attempt += 1;
1399            if commit_attempt == 1 {
1400                Err(create_aborted_status(std::time::Duration::from_nanos(1)))
1401            } else {
1402                commit_response()
1403            }
1404        });
1405
1406        let res = execute_test_runner(mock, begin_transaction_option)
1407            .await
1408            .unwrap();
1409        assert_eq!(res, 5);
1410    }
1411
1412    #[tokio_test_no_panics]
1413    async fn execute_run_begin_transaction_fails_explicit() {
1414        run_begin_transaction_fails(BeginTransactionOption::ExplicitBegin).await;
1415    }
1416
1417    #[tokio_test_no_panics]
1418    async fn execute_run_begin_transaction_fails_inline() {
1419        run_begin_transaction_fails(BeginTransactionOption::InlineBegin).await;
1420    }
1421
1422    async fn run_begin_transaction_fails(begin_transaction_option: BeginTransactionOption) {
1423        let mut mock = create_session_mock();
1424        let mut seq = mockall::Sequence::new();
1425
1426        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1427            mock.expect_begin_transaction()
1428                .once()
1429                .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "internal error")));
1430        } else {
1431            mock.expect_execute_sql()
1432                .once()
1433                .in_sequence(&mut seq)
1434                .returning(move |req| {
1435                    let req = req.into_inner();
1436                    let transaction = req
1437                        .transaction
1438                        .as_ref()
1439                        .expect("transaction options required for inline begin");
1440                    let selector = transaction.selector.as_ref().expect("selector required");
1441                    assert!(matches!(
1442                        selector,
1443                        v1::transaction_selector::Selector::Begin(_)
1444                    ));
1445
1446                    Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
1447                });
1448
1449            mock.expect_begin_transaction()
1450                .once()
1451                .in_sequence(&mut seq)
1452                .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "internal error")));
1453        }
1454
1455        let res = execute_test_runner(mock, begin_transaction_option).await;
1456
1457        assert!(res.is_err());
1458        let err = res.unwrap_err();
1459        if let Some(status) = err.status() {
1460            assert_eq!(status.code, google_cloud_gax::error::rpc::Code::Internal);
1461        } else {
1462            panic!("Expected GRPC error");
1463        }
1464    }
1465
1466    #[tokio_test_no_panics]
1467    async fn builder_options() {
1468        use crate::transaction_retry_policy::BasicTransactionRetryPolicy;
1469
1470        let mock = create_session_mock();
1471        let (db_client, _server) = setup_db_client(mock).await;
1472
1473        let retry_policy = BasicTransactionRetryPolicy::new()
1474            .with_max_attempts(1)
1475            .with_total_timeout(std::time::Duration::from_secs(10));
1476
1477        // Validate builder chaining safely accepts and compiles options dynamically
1478        let _runner = TransactionRunnerBuilder::new(db_client)
1479            .set_isolation_level(IsolationLevel::Serializable)
1480            .set_read_lock_mode(ReadLockMode::Pessimistic)
1481            .with_retry_policy(retry_policy)
1482            .build()
1483            .await
1484            .unwrap();
1485    }
1486
1487    #[tokio_test_no_panics]
1488    async fn execute_run_batch_dml_aborted_retry_explicit() {
1489        run_batch_dml_aborted_retry(BeginTransactionOption::ExplicitBegin).await;
1490    }
1491
1492    #[tokio_test_no_panics]
1493    async fn execute_run_batch_dml_aborted_retry_inline() {
1494        run_batch_dml_aborted_retry(BeginTransactionOption::InlineBegin).await;
1495    }
1496
1497    async fn run_batch_dml_aborted_retry(begin_transaction_option: BeginTransactionOption) {
1498        use crate::batch_dml::BatchDml;
1499        use crate::statement::Statement;
1500        use gaxi::grpc::tonic::Code;
1501        use spanner_grpc_mock::google::rpc::Status;
1502        use spanner_grpc_mock::google::spanner::v1::result_set_stats::RowCount;
1503
1504        let mut mock = create_session_mock();
1505
1506        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1507            expect_begin_transaction(&mut mock, 2, vec![9, 9, 9]);
1508        }
1509
1510        let mut seq = mockall::Sequence::new();
1511        mock.expect_execute_batch_dml()
1512            .once()
1513            .in_sequence(&mut seq)
1514            .returning(move |req| {
1515                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1516                    let req = req.into_inner();
1517                    let selector = req
1518                        .transaction
1519                        .expect("missing transaction selector")
1520                        .selector
1521                        .expect("missing selector");
1522                    assert!(matches!(
1523                        selector,
1524                        v1::transaction_selector::Selector::Begin(_)
1525                    ));
1526                }
1527
1528                // Return a successful response but with an embedded aborted status.
1529                let status = Status {
1530                    code: Code::Aborted as i32,
1531                    message: "transaction aborted".to_string(),
1532                    ..Default::default()
1533                };
1534
1535                let mut metadata = v1::ResultSetMetadata {
1536                    ..Default::default()
1537                };
1538                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1539                    metadata.transaction = Some(v1::Transaction {
1540                        id: vec![9, 9, 9],
1541                        ..Default::default()
1542                    });
1543                }
1544
1545                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
1546                    result_sets: vec![v1::ResultSet {
1547                        metadata: Some(metadata),
1548                        stats: Some(v1::ResultSetStats {
1549                            row_count: Some(RowCount::RowCountExact(1)),
1550                            ..Default::default()
1551                        }),
1552                        ..Default::default()
1553                    }],
1554                    status: Some(status),
1555                    ..Default::default()
1556                }))
1557            });
1558        mock.expect_execute_batch_dml()
1559            .once()
1560            .in_sequence(&mut seq)
1561            .returning(move |req| {
1562                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1563                    let req = req.into_inner();
1564                    let selector = req
1565                        .transaction
1566                        .expect("missing transaction selector")
1567                        .selector
1568                        .expect("missing selector");
1569                    assert!(matches!(
1570                        selector,
1571                        v1::transaction_selector::Selector::Begin(_)
1572                    ));
1573                }
1574
1575                let mut metadata = v1::ResultSetMetadata {
1576                    ..Default::default()
1577                };
1578                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1579                    metadata.transaction = Some(v1::Transaction {
1580                        id: vec![9, 9, 9],
1581                        ..Default::default()
1582                    });
1583                }
1584
1585                // Return success after the retry.
1586                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
1587                    result_sets: vec![v1::ResultSet {
1588                        metadata: Some(metadata),
1589                        stats: Some(v1::ResultSetStats {
1590                            row_count: Some(RowCount::RowCountExact(5)),
1591                            ..Default::default()
1592                        }),
1593                        ..Default::default()
1594                    }],
1595                    ..Default::default()
1596                }))
1597            });
1598
1599        mock.expect_commit()
1600            .once()
1601            .returning(move |_| commit_response());
1602
1603        let (db_client, _) = setup_db_client(mock).await;
1604        let runner = TransactionRunnerBuilder::new(db_client)
1605            .with_begin_transaction_option(begin_transaction_option)
1606            .build()
1607            .await
1608            .expect("failed to build TransactionRunner");
1609
1610        let mut attempt_counter = 0;
1611
1612        // TransactionRunner retries the closure on transaction aborts
1613        let res = runner
1614            .run(async |tx| {
1615                attempt_counter += 1;
1616                let stmt = Statement::builder("UPDATE t SET c = 1").build();
1617                let batch = BatchDml::builder().add_statement(stmt).build();
1618                let counts = tx.execute_batch_update(batch).await?;
1619                Ok(counts)
1620            })
1621            .await
1622            .expect("transaction failed");
1623
1624        assert_eq!(res.result, vec![5]);
1625        assert_eq!(attempt_counter, 2);
1626    }
1627
1628    #[tokio_test_no_panics]
1629    async fn execute_run_with_transaction_tag_explicit() -> anyhow::Result<()> {
1630        run_with_transaction_tag(BeginTransactionOption::ExplicitBegin).await
1631    }
1632
1633    #[tokio_test_no_panics]
1634    async fn execute_run_with_transaction_tag_inline() -> anyhow::Result<()> {
1635        run_with_transaction_tag(BeginTransactionOption::InlineBegin).await
1636    }
1637
1638    async fn run_with_transaction_tag(
1639        begin_transaction_option: BeginTransactionOption,
1640    ) -> anyhow::Result<()> {
1641        let mut mock = create_session_mock();
1642
1643        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1644            mock.expect_begin_transaction().once().returning(|req| {
1645                let req = req.into_inner();
1646                // Check if the transaction tag is correctly propagated.
1647                assert_eq!(
1648                    req.request_options
1649                        .expect("Missing request_options")
1650                        .transaction_tag,
1651                    "my-test-tag"
1652                );
1653
1654                Ok(tonic::Response::new(v1::Transaction {
1655                    id: vec![9, 9, 9],
1656                    ..Default::default()
1657                }))
1658            });
1659        }
1660
1661        mock.expect_execute_sql().once().returning(move |req| {
1662            let req = req.into_inner();
1663            assert_eq!(
1664                req.request_options
1665                    .expect("Missing request_options")
1666                    .transaction_tag,
1667                "my-test-tag"
1668            );
1669
1670            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1671                let transaction = req
1672                    .transaction
1673                    .as_ref()
1674                    .expect("transaction options required for inline begin");
1675                let selector = transaction.selector.as_ref().expect("selector required");
1676                assert!(matches!(
1677                    selector,
1678                    v1::transaction_selector::Selector::Begin(_)
1679                ));
1680            }
1681
1682            let mut metadata = v1::ResultSetMetadata {
1683                ..Default::default()
1684            };
1685            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1686                metadata.transaction = Some(v1::Transaction {
1687                    id: vec![9, 9, 9],
1688                    ..Default::default()
1689                });
1690            }
1691
1692            Ok(tonic::Response::new(v1::ResultSet {
1693                metadata: Some(metadata),
1694                stats: Some(v1::ResultSetStats {
1695                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)),
1696                    ..Default::default()
1697                }),
1698                ..Default::default()
1699            }))
1700        });
1701
1702        mock.expect_commit().once().returning(|req| {
1703            let req = req.into_inner();
1704            assert_eq!(
1705                req.request_options
1706                    .expect("Missing request_options")
1707                    .transaction_tag,
1708                "my-test-tag"
1709            );
1710            commit_response()
1711        });
1712
1713        let (db_client, _server) = setup_db_client(mock).await;
1714
1715        let runner = TransactionRunnerBuilder::new(db_client)
1716            .with_begin_transaction_option(begin_transaction_option)
1717            .set_transaction_tag("my-test-tag")
1718            .build()
1719            .await?;
1720
1721        let res = runner
1722            .run(async |tx| {
1723                let count = tx.execute_update("UPDATE Users SET active = true").await?;
1724                Ok(count)
1725            })
1726            .await?;
1727
1728        assert_eq!(res.result, 5);
1729
1730        Ok(())
1731    }
1732
1733    #[tokio_test_no_panics]
1734    async fn execute_run_with_exclude_txn_from_change_streams_explicit() -> anyhow::Result<()> {
1735        run_with_exclude_txn_from_change_streams(BeginTransactionOption::ExplicitBegin).await
1736    }
1737
1738    #[tokio_test_no_panics]
1739    async fn execute_run_with_exclude_txn_from_change_streams_inline() -> anyhow::Result<()> {
1740        run_with_exclude_txn_from_change_streams(BeginTransactionOption::InlineBegin).await
1741    }
1742
1743    async fn run_with_exclude_txn_from_change_streams(
1744        begin_transaction_option: BeginTransactionOption,
1745    ) -> anyhow::Result<()> {
1746        let mut mock = create_session_mock();
1747
1748        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1749            mock.expect_begin_transaction().once().returning(|req| {
1750                let req = req.into_inner();
1751                let options = req.options.expect("Missing transaction options");
1752                assert!(options.exclude_txn_from_change_streams);
1753
1754                Ok(tonic::Response::new(v1::Transaction {
1755                    id: vec![9, 9, 9],
1756                    ..Default::default()
1757                }))
1758            });
1759        }
1760
1761        mock.expect_execute_sql().once().returning(move |req| {
1762            let req = req.into_inner();
1763            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1764                let transaction = req
1765                    .transaction
1766                    .as_ref()
1767                    .expect("transaction options required for inline begin");
1768                let selector = transaction.selector.as_ref().expect("selector required");
1769                assert!(matches!(
1770                    selector,
1771                    v1::transaction_selector::Selector::Begin(_)
1772                ));
1773            }
1774
1775            let mut metadata = v1::ResultSetMetadata {
1776                ..Default::default()
1777            };
1778            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1779                metadata.transaction = Some(v1::Transaction {
1780                    id: vec![9, 9, 9],
1781                    ..Default::default()
1782                });
1783            }
1784
1785            Ok(tonic::Response::new(v1::ResultSet {
1786                metadata: Some(metadata),
1787                stats: Some(v1::ResultSetStats {
1788                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)),
1789                    ..Default::default()
1790                }),
1791                ..Default::default()
1792            }))
1793        });
1794
1795        mock.expect_commit()
1796            .once()
1797            .returning(|_req| commit_response());
1798
1799        let (db_client, _server) = setup_db_client(mock).await;
1800
1801        let runner = TransactionRunnerBuilder::new(db_client)
1802            .set_exclude_txn_from_change_streams(true)
1803            .with_begin_transaction_option(begin_transaction_option)
1804            .build()
1805            .await?;
1806
1807        let res = runner
1808            .run(async |tx| {
1809                let count = tx.execute_update("UPDATE Users SET active = true").await?;
1810                Ok(count)
1811            })
1812            .await?;
1813
1814        assert_eq!(res.result, 5);
1815
1816        Ok(())
1817    }
1818
1819    #[tokio_test_no_panics]
1820    async fn execute_run_with_max_commit_delay_explicit() -> anyhow::Result<()> {
1821        run_with_max_commit_delay(BeginTransactionOption::ExplicitBegin).await
1822    }
1823
1824    #[tokio_test_no_panics]
1825    async fn execute_run_with_max_commit_delay_inline() -> anyhow::Result<()> {
1826        run_with_max_commit_delay(BeginTransactionOption::InlineBegin).await
1827    }
1828
1829    async fn run_with_max_commit_delay(
1830        begin_transaction_option: BeginTransactionOption,
1831    ) -> anyhow::Result<()> {
1832        let mut mock = create_session_mock();
1833
1834        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1835            expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);
1836        }
1837
1838        mock.expect_execute_sql().once().returning(move |req| {
1839            let req = req.into_inner();
1840            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1841                let transaction = req
1842                    .transaction
1843                    .as_ref()
1844                    .expect("transaction options required for inline begin");
1845                let selector = transaction.selector.as_ref().expect("selector required");
1846                assert!(matches!(
1847                    selector,
1848                    v1::transaction_selector::Selector::Begin(_)
1849                ));
1850            }
1851
1852            let mut metadata = v1::ResultSetMetadata {
1853                ..Default::default()
1854            };
1855            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1856                metadata.transaction = Some(v1::Transaction {
1857                    id: vec![1, 2, 3],
1858                    ..Default::default()
1859                });
1860            }
1861
1862            Ok(tonic::Response::new(v1::ResultSet {
1863                metadata: Some(metadata),
1864                stats: Some(v1::ResultSetStats {
1865                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1866                    ..Default::default()
1867                }),
1868                ..Default::default()
1869            }))
1870        });
1871
1872        mock.expect_commit().once().returning(|req| {
1873            let req = req.into_inner();
1874            assert_eq!(
1875                req.max_commit_delay,
1876                Some(::prost_types::Duration {
1877                    seconds: 0,
1878                    nanos: 200_000_000, // 200ms
1879                })
1880            );
1881            commit_response()
1882        });
1883
1884        let (db_client, _server) = setup_db_client(mock).await;
1885        let runner = TransactionRunnerBuilder::new(db_client)
1886            .set_max_commit_delay(Duration::try_from("0.2s").unwrap())
1887            .with_begin_transaction_option(begin_transaction_option)
1888            .build()
1889            .await?;
1890
1891        let res = runner
1892            .run(async |tx| {
1893                let count = tx.execute_update("UPDATE Users SET active = true").await?;
1894                Ok(count)
1895            })
1896            .await?;
1897        assert_eq!(res.result, 1);
1898        Ok(())
1899    }
1900
1901    #[tokio_test_no_panics]
1902    async fn execute_run_empty_closure_inline() {
1903        let mut mock = create_session_mock();
1904        expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);
1905        mock.expect_commit().once().returning(|req| {
1906            let req = req.into_inner();
1907            assert_eq!(
1908                req.transaction,
1909                Some(v1::commit_request::Transaction::TransactionId(vec![
1910                    1, 2, 3
1911                ]))
1912            );
1913            commit_response()
1914        });
1915
1916        let (db_client, _server) = setup_db_client(mock).await;
1917
1918        let runner = TransactionRunnerBuilder::new(db_client)
1919            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
1920            .build()
1921            .await
1922            .unwrap();
1923
1924        let res = runner.run(async |_tx| Ok(42)).await.unwrap();
1925        assert_eq!(res.result, 42);
1926    }
1927
1928    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1929    async fn execute_run_async_statement_still_starting() {
1930        let (tx_rpc, rx_rpc) = std_channel();
1931        let (tx_started, rx_started) = oneshot_channel();
1932        let tx_started_mutex = Mutex::new(Some(tx_started));
1933
1934        let mut mock = create_session_mock();
1935
1936        mock.expect_execute_sql().once().returning(move |_req| {
1937            if let Some(tx) = tx_started_mutex.lock().unwrap().take() {
1938                let _ = tx.send(());
1939            }
1940            rx_rpc.recv().unwrap();
1941            row_count_exact_response(1)
1942        });
1943
1944        let (db_client, _server) = setup_db_client(mock).await;
1945
1946        let runner = TransactionRunnerBuilder::new(db_client)
1947            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
1948            .build()
1949            .await
1950            .unwrap();
1951
1952        let mut rx_started_opt = Some(rx_started);
1953        let res = runner
1954            .run(async |tx| {
1955                tokio::spawn(async move {
1956                    let _ = tx.execute_update("UPDATE Users SET active = true").await;
1957                });
1958                if let Some(rx) = rx_started_opt.take() {
1959                    rx.await.unwrap();
1960                }
1961                Ok(42)
1962            })
1963            .await;
1964
1965        tx_rpc.send(()).unwrap();
1966
1967        assert!(res.is_err());
1968        assert!(
1969            format!("{:?}", res.unwrap_err())
1970                .contains("asynchronous statement is still starting the transaction")
1971        );
1972    }
1973
1974    #[tokio_test_no_panics]
1975    async fn execute_run_with_mutations_happy_flow() {
1976        let mut mock = create_session_mock();
1977
1978        mock.expect_execute_sql().once().returning(move |req| {
1979            let req = req.into_inner();
1980            assert_eq!(req.sql, "UPDATE Users SET active = true");
1981            let transaction = req
1982                .transaction
1983                .as_ref()
1984                .expect("transaction options required for inline begin");
1985            let selector = transaction.selector.as_ref().expect("selector required");
1986            assert!(matches!(selector, ProtoSelector::Begin(_)));
1987
1988            Ok(tonic::Response::new(v1::ResultSet {
1989                metadata: Some(v1::ResultSetMetadata {
1990                    transaction: Some(v1::Transaction {
1991                        id: vec![1, 1, 1],
1992                        ..Default::default()
1993                    }),
1994                    ..Default::default()
1995                }),
1996                stats: Some(v1::ResultSetStats {
1997                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1998                    ..Default::default()
1999                }),
2000                ..Default::default()
2001            }))
2002        });
2003
2004        mock.expect_commit().once().returning(|req| {
2005            let req = req.into_inner();
2006            assert_eq!(
2007                req.transaction,
2008                Some(CommitTransaction::TransactionId(vec![1, 1, 1]))
2009            );
2010            assert_eq!(req.mutations.len(), 1);
2011            commit_response()
2012        });
2013
2014        let (db_client, _server) = setup_db_client(mock).await;
2015        let runner = TransactionRunnerBuilder::new(db_client)
2016            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2017            .build()
2018            .await
2019            .expect("Failed to build transaction runner");
2020
2021        let res = runner
2022            .run(async |tx| {
2023                let count = tx.execute_update("UPDATE Users SET active = true").await?;
2024                let mutation = Mutation::new_insert_builder("Audits")
2025                    .set("AuditId")
2026                    .to(&1)
2027                    .build();
2028                tx.buffer([mutation])?;
2029                Ok(count)
2030            })
2031            .await
2032            .expect("Transaction runner failed");
2033
2034        assert_eq!(res.result, 1);
2035    }
2036
2037    #[tokio_test_no_panics]
2038    async fn execute_run_with_mutations_aborted_retry() {
2039        let mut mock = create_session_mock();
2040        let mut sequence = mockall::Sequence::new();
2041
2042        // Initial attempt: statement succeeds, returns tx id [10, 20, 30]
2043        mock.expect_execute_sql()
2044            .once()
2045            .in_sequence(&mut sequence)
2046            .returning(move |_req| {
2047                Ok(tonic::Response::new(v1::ResultSet {
2048                    metadata: Some(v1::ResultSetMetadata {
2049                        transaction: Some(v1::Transaction {
2050                            id: vec![10, 20, 30],
2051                            ..Default::default()
2052                        }),
2053                        ..Default::default()
2054                    }),
2055                    stats: Some(v1::ResultSetStats {
2056                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2057                        ..Default::default()
2058                    }),
2059                    ..Default::default()
2060                }))
2061            });
2062
2063        // Initial commit fails with Aborted
2064        mock.expect_commit()
2065            .once()
2066            .in_sequence(&mut sequence)
2067            .returning(|req| {
2068                let req = req.into_inner();
2069                assert_eq!(req.mutations.len(), 1);
2070                // Verify initial attempt added mutation for UserId 100
2071                let write = req.mutations[0]
2072                    .operation
2073                    .as_ref()
2074                    .expect("Operation required");
2075                match write {
2076                    Operation::Insert(w) => {
2077                        assert_eq!(
2078                            w.values[0].values[0].kind,
2079                            Some(Kind::StringValue("100".to_string()))
2080                        );
2081                    }
2082                    _ => panic!("Expected insert mutation"),
2083                }
2084                Err(create_aborted_status(StdTimeDuration::from_nanos(1)))
2085            });
2086
2087        // Retry attempt: execute_sql sends inline BeginTransaction with previous_transaction_id
2088        mock.expect_execute_sql()
2089            .once()
2090            .in_sequence(&mut sequence)
2091            .returning(move |req| {
2092                let req = req.into_inner();
2093                let transaction = req
2094                    .transaction
2095                    .as_ref()
2096                    .expect("transaction options required for inline begin");
2097                let selector = transaction.selector.as_ref().expect("selector required");
2098                let options = match selector {
2099                    ProtoSelector::Begin(o) => o,
2100                    _ => panic!("Expected Begin"),
2101                };
2102                let read_write = options.mode.as_ref().expect("mode required");
2103                match read_write {
2104                    Mode::ReadWrite(rw) => {
2105                        assert_eq!(
2106                            rw.multiplexed_session_previous_transaction_id,
2107                            vec![10, 20, 30]
2108                        );
2109                    }
2110                    _ => panic!("Expected ReadWrite"),
2111                }
2112
2113                Ok(tonic::Response::new(v1::ResultSet {
2114                    metadata: Some(v1::ResultSetMetadata {
2115                        transaction: Some(v1::Transaction {
2116                            id: vec![99, 99, 99],
2117                            ..Default::default()
2118                        }),
2119                        ..Default::default()
2120                    }),
2121                    stats: Some(v1::ResultSetStats {
2122                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2123                        ..Default::default()
2124                    }),
2125                    ..Default::default()
2126                }))
2127            });
2128
2129        // Second commit succeeds with the new mutation
2130        mock.expect_commit()
2131            .once()
2132            .in_sequence(&mut sequence)
2133            .returning(|req| {
2134                let req = req.into_inner();
2135                assert_eq!(
2136                    req.transaction,
2137                    Some(CommitTransaction::TransactionId(vec![99, 99, 99]))
2138                );
2139                assert_eq!(req.mutations.len(), 1);
2140                // Verify retry attempt added mutation for UserId 200
2141                let write = req.mutations[0]
2142                    .operation
2143                    .as_ref()
2144                    .expect("Operation required");
2145                match write {
2146                    Operation::Insert(w) => {
2147                        assert_eq!(
2148                            w.values[0].values[0].kind,
2149                            Some(Kind::StringValue("200".to_string()))
2150                        );
2151                    }
2152                    _ => panic!("Expected insert mutation"),
2153                }
2154                commit_response()
2155            });
2156
2157        let (db_client, _server) = setup_db_client(mock).await;
2158        let runner = TransactionRunnerBuilder::new(db_client)
2159            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2160            .build()
2161            .await
2162            .expect("Failed to build transaction runner");
2163
2164        let mut attempt = 0;
2165        let res = runner
2166            .run(async |tx| {
2167                attempt += 1;
2168                let count = tx.execute_update("UPDATE Users SET active = true").await?;
2169                let mutation_value = if attempt == 1 { 100 } else { 200 };
2170                let mutation = Mutation::new_insert_builder("Users")
2171                    .set("UserId")
2172                    .to(&mutation_value)
2173                    .build();
2174                tx.buffer([mutation])?;
2175                Ok(count)
2176            })
2177            .await
2178            .expect("Transaction runner failed");
2179
2180        assert_eq!(res.result, 1);
2181    }
2182
2183    #[tokio_test_no_panics]
2184    async fn execute_run_mutation_only_explicit_begin_fallback() {
2185        let mut mock = create_session_mock();
2186
2187        // Since the user closure executes no statements, commit() calls explicit BeginTransaction
2188        mock.expect_begin_transaction().once().returning(|req| {
2189            let req = req.into_inner();
2190            assert_eq!(
2191                req.session,
2192                "projects/p/instances/i/databases/d/sessions/123"
2193            );
2194            Ok(tonic::Response::new(v1::Transaction {
2195                id: vec![77, 88, 99],
2196                ..Default::default()
2197            }))
2198        });
2199
2200        mock.expect_commit().once().returning(|req| {
2201            let req = req.into_inner();
2202            assert_eq!(
2203                req.transaction,
2204                Some(CommitTransaction::TransactionId(vec![77, 88, 99]))
2205            );
2206            assert_eq!(req.mutations.len(), 2);
2207            commit_response()
2208        });
2209
2210        let (db_client, _server) = setup_db_client(mock).await;
2211        let runner = TransactionRunnerBuilder::new(db_client)
2212            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2213            .build()
2214            .await
2215            .expect("Failed to build transaction runner");
2216
2217        let res = runner
2218            .run(async |tx| {
2219                let m1 = Mutation::new_insert_builder("Orders")
2220                    .set("OrderId")
2221                    .to(&1)
2222                    .build();
2223                let m2 = Mutation::new_insert_builder("Orders")
2224                    .set("OrderId")
2225                    .to(&2)
2226                    .build();
2227                tx.buffer([m1, m2])?;
2228                Ok(())
2229            })
2230            .await
2231            .expect("Transaction runner failed");
2232
2233        assert_eq!(
2234            res.commit_response
2235                .commit_timestamp
2236                .expect("Timestamp required")
2237                .seconds(),
2238            123456789
2239        );
2240    }
2241
2242    #[tokio_test_no_panics]
2243    async fn read_write_transaction_builder_sets_gax_options() -> anyhow::Result<()> {
2244        let mock = create_session_mock();
2245        let (db_client, _server) = setup_db_client(mock).await;
2246
2247        let runner = TransactionRunnerBuilder::new(db_client)
2248            .with_begin_attempt_timeout(StdDuration::from_secs(5))
2249            .with_begin_retry_policy(NeverRetry)
2250            .with_begin_backoff_policy(ExponentialBackoff::default())
2251            .with_commit_attempt_timeout(StdDuration::from_secs(10))
2252            .with_commit_retry_policy(NeverRetry)
2253            .with_commit_backoff_policy(ExponentialBackoff::default());
2254
2255        let begin_gax = runner
2256            .begin_gax_options
2257            .as_ref()
2258            .expect("begin_gax_options missing");
2259        assert_eq!(
2260            *begin_gax.attempt_timeout(),
2261            Some(StdDuration::from_secs(5))
2262        );
2263        assert!(begin_gax.retry_policy().is_some());
2264        assert!(begin_gax.backoff_policy().is_some());
2265
2266        let commit_gax = runner
2267            .commit_gax_options
2268            .as_ref()
2269            .expect("commit_gax_options missing");
2270        assert_eq!(
2271            *commit_gax.attempt_timeout(),
2272            Some(StdDuration::from_secs(10))
2273        );
2274        assert!(commit_gax.retry_policy().is_some());
2275        assert!(commit_gax.backoff_policy().is_some());
2276
2277        Ok(())
2278    }
2279}