Skip to main content

google_cloud_spanner/
transaction_runner.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::database_client::DatabaseClient;
16use crate::model::CommitResponse;
17use crate::model::request_options::Priority;
18use crate::model::transaction_options::IsolationLevel;
19use crate::model::transaction_options::read_write::ReadLockMode;
20use crate::read_only_transaction::BeginTransactionOption;
21use crate::read_write_transaction::{ReadWriteTransaction, ReadWriteTransactionBuilder};
22use crate::transaction_retry_policy::{
23    BasicTransactionRetryPolicy, TransactionRetryPolicy, backoff_if_aborted, is_aborted,
24};
25use google_cloud_gax::backoff_policy::BackoffPolicyArg;
26use google_cloud_gax::retry_policy::RetryPolicyArg;
27
28use std::time::Duration as StdDuration;
29use tokio::time::Instant;
30use wkt::Duration;
31
32/// A builder for a [TransactionRunner] for a read/write transaction.
33///
34/// # Example
35/// ```
36/// # use google_cloud_spanner::client::Spanner;
37/// # use google_cloud_spanner::statement::Statement;
38/// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
39/// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
40/// let runner = db_client.read_write_transaction().build().await?;
41///
42/// let result = runner.run(async |transaction| {
43///     let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
44///     transaction.execute_update(statement).await?;
45///     Ok(42)
46/// }).await?;
47/// # Ok(())
48/// # }
49/// ```
50///
51/// Spanner can abort any read/write transaction at any time. A [TransactionRunner]
52/// automatically retries aborted transactions according to the configured retry policy.
53pub struct TransactionRunnerBuilder {
54    builder: ReadWriteTransactionBuilder,
55    retry_policy: Box<dyn TransactionRetryPolicy>,
56    timeout: Option<StdDuration>,
57    begin_gax_options: Option<crate::RequestOptions>,
58    commit_gax_options: Option<crate::RequestOptions>,
59}
60
61impl TransactionRunnerBuilder {
62    pub(crate) fn new(client: DatabaseClient) -> Self {
63        Self {
64            builder: ReadWriteTransactionBuilder::new(client),
65            retry_policy: Box::new(BasicTransactionRetryPolicy::default()),
66            timeout: None,
67            begin_gax_options: None,
68            commit_gax_options: None,
69        }
70    }
71
72    /// Sets the timeout for the entire transaction.
73    ///
74    /// # Example
75    /// ```
76    /// # use google_cloud_spanner::client::Spanner;
77    /// # use std::time::Duration;
78    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
79    /// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
80    /// let runner = db_client.read_write_transaction()
81    ///     .with_transaction_timeout(Duration::from_secs(5))
82    ///     .build()
83    ///     .await?;
84    /// # Ok(())
85    /// # }
86    /// ```
87    ///
88    /// This timeout applies to the total time spent executing the transaction, including
89    /// all statements and automatic retries. Each individual RPC within the transaction
90    /// is automatically assigned a deadline derived from the remaining time of this
91    /// overall timeout.
92    pub fn with_transaction_timeout(mut self, timeout: StdDuration) -> Self {
93        self.timeout = Some(timeout);
94        self
95    }
96
97    /// Sets the per-attempt timeout for the BeginTransaction RPC.
98    ///
99    /// # Example
100    /// ```
101    /// # use google_cloud_spanner::client::Spanner;
102    /// # use std::time::Duration;
103    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
104    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
105    /// let runner = db_client.read_write_transaction()
106    ///     .with_begin_attempt_timeout(Duration::from_secs(5))
107    ///     .build()
108    ///     .await?;
109    /// # Ok(())
110    /// # }
111    /// ```
112    ///
113    /// Note: This timeout is only used if the transaction uses the `ExplicitBegin` transaction option.
114    pub fn with_begin_attempt_timeout(mut self, timeout: StdDuration) -> Self {
115        self.begin_gax_options
116            .get_or_insert_with(crate::RequestOptions::default)
117            .set_attempt_timeout(timeout);
118        self
119    }
120
121    /// Sets the retry policy for the BeginTransaction RPC.
122    ///
123    /// # Example
124    /// ```
125    /// # use google_cloud_spanner::client::Spanner;
126    /// # use google_cloud_gax::retry_policy::NeverRetry;
127    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
128    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
129    /// let runner = db_client.read_write_transaction()
130    ///     .with_begin_retry_policy(NeverRetry)
131    ///     .build()
132    ///     .await?;
133    /// # Ok(())
134    /// # }
135    /// ```
136    ///
137    /// Note: This policy is only used if the transaction uses the `ExplicitBegin` transaction option.
138    pub fn with_begin_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
139        self.begin_gax_options
140            .get_or_insert_with(crate::RequestOptions::default)
141            .set_retry_policy(policy);
142        self
143    }
144
145    /// Sets the backoff policy for the BeginTransaction RPC.
146    ///
147    /// # Example
148    /// ```
149    /// # use google_cloud_spanner::client::Spanner;
150    /// # use google_cloud_gax::exponential_backoff::ExponentialBackoff;
151    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
152    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
153    /// let runner = db_client.read_write_transaction()
154    ///     .with_begin_backoff_policy(ExponentialBackoff::default())
155    ///     .build()
156    ///     .await?;
157    /// # Ok(())
158    /// # }
159    /// ```
160    ///
161    /// Note: This policy is only used if the transaction uses the `ExplicitBegin` transaction option.
162    pub fn with_begin_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
163        self.begin_gax_options
164            .get_or_insert_with(crate::RequestOptions::default)
165            .set_backoff_policy(policy);
166        self
167    }
168
169    /// Sets the per-attempt timeout for the Commit RPC.
170    ///
171    /// # Example
172    /// ```
173    /// # use google_cloud_spanner::client::Spanner;
174    /// # use std::time::Duration;
175    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
176    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
177    /// let runner = db_client.read_write_transaction()
178    ///     .with_commit_attempt_timeout(Duration::from_secs(5))
179    ///     .build()
180    ///     .await?;
181    /// # Ok(())
182    /// # }
183    /// ```
184    pub fn with_commit_attempt_timeout(mut self, timeout: StdDuration) -> Self {
185        self.commit_gax_options
186            .get_or_insert_with(crate::RequestOptions::default)
187            .set_attempt_timeout(timeout);
188        self
189    }
190
191    /// Sets the retry policy for the Commit RPC.
192    ///
193    /// # Example
194    /// ```
195    /// # use google_cloud_spanner::client::Spanner;
196    /// # use google_cloud_gax::retry_policy::NeverRetry;
197    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
198    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
199    /// let runner = db_client.read_write_transaction()
200    ///     .with_commit_retry_policy(NeverRetry)
201    ///     .build()
202    ///     .await?;
203    /// # Ok(())
204    /// # }
205    /// ```
206    pub fn with_commit_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
207        self.commit_gax_options
208            .get_or_insert_with(crate::RequestOptions::default)
209            .set_retry_policy(policy);
210        self
211    }
212
213    /// Sets the backoff policy for the Commit RPC.
214    ///
215    /// # Example
216    /// ```
217    /// # use google_cloud_spanner::client::Spanner;
218    /// # use google_cloud_gax::exponential_backoff::ExponentialBackoff;
219    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
220    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
221    /// let runner = db_client.read_write_transaction()
222    ///     .with_commit_backoff_policy(ExponentialBackoff::default())
223    ///     .build()
224    ///     .await?;
225    /// # Ok(())
226    /// # }
227    /// ```
228    pub fn with_commit_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
229        self.commit_gax_options
230            .get_or_insert_with(crate::RequestOptions::default)
231            .set_backoff_policy(policy);
232        self
233    }
234
235    /// Sets the isolation level for the transaction.
236    ///
237    /// # Example
238    /// ```
239    /// # use google_cloud_spanner::client::Spanner;
240    /// # use google_cloud_spanner::model::transaction_options::IsolationLevel;
241    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
242    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
243    /// let runner = db_client
244    ///     .read_write_transaction()
245    ///     .set_isolation_level(IsolationLevel::Serializable)
246    ///     .build()
247    ///     .await?;
248    /// # Ok(())
249    /// # }
250    /// ```
251    ///
252    /// See also: <https://docs.cloud.google.com/spanner/docs/isolation-levels>
253    pub fn set_isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
254        self.builder = self.builder.set_isolation_level(isolation_level);
255        self
256    }
257
258    /// Sets the read lock mode for the transaction.
259    ///
260    /// # Example
261    /// ```
262    /// # use google_cloud_spanner::client::Spanner;
263    /// # use google_cloud_spanner::model::transaction_options::read_write::ReadLockMode;
264    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
265    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
266    /// let runner = db_client
267    ///     .read_write_transaction()
268    ///     .set_read_lock_mode(ReadLockMode::Pessimistic)
269    ///     .build()
270    ///     .await?;
271    /// # Ok(())
272    /// # }
273    /// ```
274    ///
275    /// See also: <https://docs.cloud.google.com/spanner/docs/concurrency-control>
276    pub fn set_read_lock_mode(mut self, read_lock_mode: ReadLockMode) -> Self {
277        self.builder = self.builder.set_read_lock_mode(read_lock_mode);
278        self
279    }
280
281    /// Sets the transaction tag for the transaction.
282    ///
283    /// # Example
284    /// ```
285    /// # use google_cloud_spanner::client::Spanner;
286    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
287    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
288    /// let runner = db_client.read_write_transaction()
289    ///     .set_transaction_tag("my-tag")
290    ///     .build()
291    ///     .await?;
292    /// # Ok(())
293    /// # }
294    /// ```
295    ///
296    /// The tag is applied to all statements executed within the transaction.
297    ///
298    /// See also: [Troubleshooting with tags](https://docs.cloud.google.com/spanner/docs/introspection/troubleshooting-with-tags)
299    pub fn set_transaction_tag(mut self, tag: impl Into<String>) -> Self {
300        self.builder = self.builder.set_transaction_tag(tag);
301        self
302    }
303
304    /// Sets the option for how to start a transaction.
305    ///
306    /// # Example
307    /// ```
308    /// # use google_cloud_spanner::client::Spanner;
309    /// # use google_cloud_spanner::transaction::BeginTransactionOption;
310    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
311    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
312    /// let runner = db_client
313    ///     .read_write_transaction()
314    ///     .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
315    ///     .build()
316    ///     .await?;
317    /// # Ok(())
318    /// # }
319    /// ```
320    ///
321    /// By default, the Spanner client will inline the `BeginTransaction` call with the first query
322    /// or DML statement in the transaction. This reduces the number of round-trips to Spanner that
323    /// are needed for a transaction. Setting this option to `ExplicitBegin` can be beneficial for
324    /// specific transaction shapes:
325    ///
326    /// 1. When the transaction executes multiple parallel queries at the start of the transaction.
327    ///    Only one query can include a `BeginTransaction` option, and all other queries must wait for
328    ///    the first query to return the first result before they can proceed to execute. A
329    ///    `BeginTransaction` RPC will quickly return a transaction ID and allow all queries to start
330    ///    execution in parallel once the transaction ID has been returned.
331    /// 2. When the first statement in the transaction could fail. If the statement fails, then it
332    ///    will also not start a transaction and return a transaction ID. The transaction will then
333    ///    fall back to executing a `BeginTransaction` RPC and retry the first statement.
334    ///
335    /// Default is `BeginTransactionOption::InlineBegin`.
336    pub fn with_begin_transaction_option(mut self, option: BeginTransactionOption) -> Self {
337        self.builder = self.builder.with_begin_transaction_option(option);
338        self
339    }
340
341    /// Sets the RPC priority to use for the commit of this transaction.
342    ///
343    /// # Example
344    /// ```
345    /// # use google_cloud_spanner::client::Spanner;
346    /// # use google_cloud_spanner::model::request_options::Priority;
347    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
348    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
349    /// let runner = db_client
350    ///     .read_write_transaction()
351    ///     .set_commit_priority(Priority::Low)
352    ///     .build()
353    ///     .await?;
354    /// # Ok(())
355    /// # }
356    /// ```
357    pub fn set_commit_priority(mut self, priority: Priority) -> Self {
358        self.builder = self.builder.set_commit_priority(priority);
359        self
360    }
361
362    /// Sets the maximum commit delay for the transaction.
363    ///
364    /// # Example
365    /// ```
366    /// # use google_cloud_spanner::client::Spanner;
367    /// # use wkt::Duration;
368    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
369    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
370    /// let runner = db_client
371    ///     .read_write_transaction()
372    ///     .set_max_commit_delay(Duration::try_from("0.2s").unwrap())
373    ///     .build()
374    ///     .await?;
375    /// # Ok(())
376    /// # }
377    /// ```
378    ///
379    /// This option allows you to specify the maximum amount of time Spanner can
380    /// adjust the commit timestamp of the transaction to allow for commit batching.
381    /// Increasing this value can increase throughput at the expense of latency.
382    /// The value must be between 0 and 500 milliseconds. If not set, or set to 0,
383    /// Spanner does not delay the commit.
384    pub fn set_max_commit_delay(mut self, delay: Duration) -> Self {
385        self.builder = self.builder.set_max_commit_delay(delay);
386        self
387    }
388
389    /// Sets whether to exclude the transaction from change streams.
390    ///
391    /// # Example
392    /// ```
393    /// # use google_cloud_spanner::client::Spanner;
394    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
395    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
396    /// let runner = db_client.read_write_transaction()
397    ///     .set_exclude_txn_from_change_streams(true)
398    ///     .build()
399    ///     .await?;
400    /// # Ok(())
401    /// # }
402    /// ```
403    ///
404    /// When set to `true`, it prevents modifications from this transaction from being tracked in change streams.
405    /// Note that this only affects change streams that have been created with the DDL option `allow_txn_exclusion = true`.
406    /// If `allow_txn_exclusion` is not set or set to `false` for a change stream, updates made within this transaction
407    /// are recorded in that change stream regardless of this setting.
408    ///
409    /// When set to `false` or not specified, modifications from this transaction are recorded in all change streams
410    /// tracking columns modified by this transaction.
411    pub fn set_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
412        self.builder = self.builder.set_exclude_txn_from_change_streams(exclude);
413        self
414    }
415
416    /// Sets whether to return commit stats for the transaction.
417    ///
418    /// # Example
419    /// ```
420    /// # use google_cloud_spanner::client::Spanner;
421    /// # use google_cloud_spanner::statement::Statement;
422    /// # async fn run_tx(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
423    /// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
424    /// let runner = db_client.read_write_transaction()
425    ///     .set_return_commit_stats(true)
426    ///     .build()
427    ///     .await?;
428    ///
429    /// let result = runner.run(async |transaction| {
430    ///     let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
431    ///     transaction.execute_update(statement).await?;
432    ///     Ok(42)
433    /// }).await?;
434    ///
435    /// if let Some(stats) = result.commit_response.commit_stats {
436    ///     println!("Mutation count: {}", stats.mutation_count);
437    /// }
438    /// # Ok(())
439    /// # }
440    /// ```
441    ///
442    /// See also: <https://docs.cloud.google.com/spanner/docs/commit-statistics>
443    pub fn set_return_commit_stats(mut self, return_stats: bool) -> Self {
444        self.builder = self.builder.set_return_commit_stats(return_stats);
445        self
446    }
447
448    /// Sets the retry policy for the transaction.
449    ///
450    /// # Example
451    /// ```
452    /// # use std::time::Duration;
453    /// # use google_cloud_spanner::client::Spanner;
454    /// # use google_cloud_spanner::transaction::BasicTransactionRetryPolicy;
455    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
456    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
457    ///
458    /// let retry_policy = BasicTransactionRetryPolicy::new()
459    ///     .with_max_attempts(5)
460    ///     .with_total_timeout(Duration::from_secs(60));
461    ///
462    /// let runner = db_client
463    ///     .read_write_transaction()
464    ///     .with_retry_policy(retry_policy)
465    ///     .build()
466    ///     .await?;
467    /// # Ok(())
468    /// # }
469    /// ```
470    pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(mut self, policy: P) -> Self {
471        self.retry_policy = Box::new(policy);
472        self
473    }
474
475    /// Builds a [TransactionRunner] for a read/write transaction.
476    ///
477    /// # Example
478    /// ```
479    /// # use google_cloud_spanner::client::Spanner;
480    /// # use google_cloud_spanner::statement::Statement;
481    /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
482    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
483    /// let runner = db_client.read_write_transaction().build().await?;
484    ///
485    /// let result = runner.run(async |transaction| {
486    ///     let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
487    ///     transaction.execute_update(statement).await?;
488    ///     Ok(42)
489    /// }).await?;
490    /// # Ok(())
491    /// # }
492    /// ```
493    pub async fn build(self) -> crate::Result<TransactionRunner> {
494        Ok(TransactionRunner {
495            builder: self
496                .builder
497                .with_begin_transaction_request_options(self.begin_gax_options)
498                .with_commit_request_options(self.commit_gax_options),
499            retry_policy: self.retry_policy,
500            timeout: self.timeout,
501        })
502    }
503}
504
505/// Result of a read/write transaction executed by a [TransactionRunner].
506#[derive(Debug)]
507#[non_exhaustive]
508pub struct TransactionResult<T> {
509    /// The result returned by the closure executed within the transaction.
510    pub result: T,
511    /// The response from the commit RPC.
512    pub commit_response: CommitResponse,
513}
514
515/// A runner for read/write transactions. Aborted transactions are automatically retried.
516pub struct TransactionRunner {
517    builder: ReadWriteTransactionBuilder,
518    retry_policy: Box<dyn TransactionRetryPolicy>,
519    timeout: Option<StdDuration>,
520}
521
522impl TransactionRunner {
523    /// Runs the provided closure within the context of a read/write transaction.
524    ///
525    /// # Example
526    /// ```
527    /// # use google_cloud_spanner::client::Spanner;
528    /// # use google_cloud_spanner::statement::Statement;
529    /// # async fn run_tx(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
530    /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
531    /// let runner = db_client.read_write_transaction().build().await?;
532    ///
533    /// let result = runner.run(async |transaction| {
534    ///     let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
535    ///     transaction.execute_update(statement).await?;
536    ///     Ok(42)
537    /// }).await?;
538    /// # Ok(())
539    /// # }
540    /// ```
541    ///
542    /// If the transaction is aborted by Spanner, the closure will be retried
543    /// automatically according to the configured `TransactionRetryPolicy`.
544    ///
545    /// The transaction is automatically committed if the closure returns `Ok`.
546    /// If the closure returns `Err`, the transaction will be rolled back and
547    /// the error will be propagated.
548    pub async fn run<T, F>(mut self, mut work: F) -> crate::Result<TransactionResult<T>>
549    where
550        F: std::ops::AsyncFnMut(ReadWriteTransaction) -> crate::Result<T>,
551    {
552        let start_time = Instant::now();
553        let mut attempts: u32 = 0;
554        let backoff = crate::transaction_retry_policy::default_retry_backoff();
555        let deadline = self.timeout.map(|t| start_time + t);
556
557        loop {
558            attempts += 1;
559
560            let mut current_tx_id = None;
561            let attempt_result = async {
562                let transaction = self.builder.clone().build(deadline).await?;
563
564                let result = match work(transaction.clone()).await {
565                    Ok(res) => res,
566                    Err(e) => {
567                        // We call `get_id_no_wait` here to retrieve the transaction ID without waiting.
568                        // We do not require the transaction ID to be unconditionally available here;
569                        // we only wish to capture it if the transaction successfully started prior to
570                        // failing, so it can be used as the previous transaction ID if the transaction
571                        // was aborted.
572                        let id = transaction
573                            .context
574                            .transaction_selector
575                            .get_id_no_wait()
576                            .ok()
577                            .flatten();
578                        // Rollback if the closure failed and it was not an Aborted error.
579                        if !is_aborted(&e) {
580                            let _ = transaction.rollback().await;
581                        }
582                        current_tx_id = id;
583                        return Err(e);
584                    }
585                };
586
587                // `commit()` consumes `transaction`. If the commit RPC fails with an Aborted error,
588                // we still need access to the transaction ID so we can provide it as `previous_transaction_id`
589                // on retry. Cloning only `transaction_selector` preserves access to the internal state efficiently.
590                let selector = transaction.context.transaction_selector.clone();
591                let commit_result = transaction.commit().await;
592                current_tx_id = selector.get_id_no_wait().ok().flatten();
593                let commit_response = commit_result?;
594                Ok::<TransactionResult<T>, crate::Error>(TransactionResult {
595                    result,
596                    commit_response,
597                })
598            }
599            .await;
600
601            match attempt_result {
602                Ok(res) => return Ok(res),
603                Err(e) => {
604                    if is_aborted(&e) {
605                        let current_tx_id = current_tx_id.clone();
606                        self.builder = self.builder.set_previous_transaction_id(current_tx_id);
607                    }
608
609                    backoff_if_aborted(
610                        e,
611                        attempts,
612                        start_time.elapsed(),
613                        self.retry_policy.as_ref(),
614                        &backoff,
615                        self.builder.client.is_emulator(),
616                    )
617                    .await?;
618                }
619            }
620        }
621    }
622}
623
624#[cfg(test)]
625mod tests {
626    use super::*;
627    use crate::mutation::Mutation;
628    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
629    use crate::transaction_retry_policy::tests::create_aborted_status;
630    use gaxi::grpc::tonic;
631    use google_cloud_gax::exponential_backoff::ExponentialBackoff;
632    use google_cloud_gax::retry_policy::NeverRetry;
633    use google_cloud_test_macros::tokio_test_no_panics;
634    use prost_types::value::Kind;
635    use spanner_grpc_mock::google::spanner::v1;
636    use spanner_grpc_mock::google::spanner::v1::CommitResponse;
637    use spanner_grpc_mock::google::spanner::v1::commit_request::Transaction as CommitTransaction;
638    use spanner_grpc_mock::google::spanner::v1::commit_response::CommitStats;
639    use spanner_grpc_mock::google::spanner::v1::mutation::Operation;
640    use spanner_grpc_mock::google::spanner::v1::transaction_options::Mode;
641    use spanner_grpc_mock::google::spanner::v1::transaction_selector::Selector as ProtoSelector;
642    use std::sync::Mutex;
643    use std::sync::mpsc::channel as std_channel;
644    use std::time::Duration as StdDuration;
645    use std::time::Duration as StdTimeDuration;
646    use tokio::sync::oneshot::channel as oneshot_channel;
647
648    fn expect_begin_transaction(
649        mock: &mut spanner_grpc_mock::MockSpanner,
650        times: usize,
651        transaction_id: Vec<u8>,
652    ) {
653        mock.expect_begin_transaction()
654            .times(times)
655            .returning(move |req| {
656                let req = req.into_inner();
657                assert_eq!(
658                    req.session,
659                    "projects/p/instances/i/databases/d/sessions/123"
660                );
661                Ok(tonic::Response::new(v1::Transaction {
662                    id: transaction_id.clone(),
663                    ..Default::default()
664                }))
665            });
666    }
667
668    async fn execute_test_runner(
669        mock: spanner_grpc_mock::MockSpanner,
670        begin_transaction_option: BeginTransactionOption,
671    ) -> Result<i64, crate::Error> {
672        let (db_client, server) = setup_db_client(mock).await;
673        let runner = TransactionRunnerBuilder::new(db_client)
674            .with_begin_transaction_option(begin_transaction_option)
675            .build()
676            .await
677            .unwrap();
678        tokio::select! {
679            res = runner.run(async |tx| {
680                let count = tx.execute_update("UPDATE Users SET active = true").await?;
681                Ok(count)
682            }) => res.map(|r| r.result),
683            err = server => panic!("Mock server panicked or terminated unexpectedly: {:?}", err),
684        }
685    }
686
687    fn commit_response() -> Result<tonic::Response<v1::CommitResponse>, tonic::Status> {
688        Ok(tonic::Response::new(v1::CommitResponse {
689            commit_timestamp: Some(prost_types::Timestamp {
690                seconds: 123456789,
691                nanos: 0,
692            }),
693            ..Default::default()
694        }))
695    }
696
697    fn row_count_exact_response(
698        count: i64,
699    ) -> Result<tonic::Response<v1::ResultSet>, tonic::Status> {
700        Ok(tonic::Response::new(v1::ResultSet {
701            stats: Some(v1::ResultSetStats {
702                row_count: Some(v1::result_set_stats::RowCount::RowCountExact(count)),
703                ..Default::default()
704            }),
705            ..Default::default()
706        }))
707    }
708
709    #[test]
710    fn auto_traits() {
711        static_assertions::assert_impl_all!(TransactionRunnerBuilder: Send, Sync);
712        static_assertions::assert_impl_all!(TransactionRunner: Send, Sync);
713    }
714
715    #[tokio_test_no_panics]
716    async fn execute_run_success_explicit() {
717        run_success(BeginTransactionOption::ExplicitBegin).await;
718    }
719
720    #[tokio_test_no_panics]
721    async fn execute_run_success_inline() {
722        run_success(BeginTransactionOption::InlineBegin).await;
723    }
724
725    async fn run_success(begin_transaction_option: BeginTransactionOption) {
726        let mut mock = create_session_mock();
727
728        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
729            expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);
730        }
731
732        mock.expect_execute_sql().once().returning(move |req| {
733            let req = req.into_inner();
734            assert_eq!(req.sql, "UPDATE Users SET active = true");
735            assert_eq!(req.seqno, 1);
736
737            if begin_transaction_option == BeginTransactionOption::InlineBegin {
738                let transaction = req
739                    .transaction
740                    .as_ref()
741                    .expect("transaction options required for inline begin");
742                let selector = transaction.selector.as_ref().expect("selector required");
743                assert!(matches!(
744                    selector,
745                    v1::transaction_selector::Selector::Begin(_)
746                ));
747            }
748
749            let mut metadata = v1::ResultSetMetadata {
750                ..Default::default()
751            };
752            if begin_transaction_option == BeginTransactionOption::InlineBegin {
753                metadata.transaction = Some(v1::Transaction {
754                    id: vec![1, 2, 3],
755                    ..Default::default()
756                });
757            }
758
759            Ok(tonic::Response::new(v1::ResultSet {
760                metadata: Some(metadata),
761                stats: Some(v1::ResultSetStats {
762                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
763                    ..Default::default()
764                }),
765                ..Default::default()
766            }))
767        });
768
769        mock.expect_commit().once().returning(|req| {
770            let req = req.into_inner();
771            assert_eq!(
772                req.transaction,
773                Some(v1::commit_request::Transaction::TransactionId(vec![
774                    1, 2, 3
775                ]))
776            );
777            commit_response()
778        });
779
780        let res = execute_test_runner(mock, begin_transaction_option)
781            .await
782            .unwrap();
783        assert_eq!(res, 1);
784    }
785
786    #[tokio_test_no_panics]
787    async fn execute_run_success_with_commit_stats_explicit() {
788        run_success_with_commit_stats(BeginTransactionOption::ExplicitBegin).await;
789    }
790
791    #[tokio_test_no_panics]
792    async fn execute_run_success_with_commit_stats_inline() {
793        run_success_with_commit_stats(BeginTransactionOption::InlineBegin).await;
794    }
795
796    async fn run_success_with_commit_stats(begin_transaction_option: BeginTransactionOption) {
797        let mut mock = create_session_mock();
798
799        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
800            expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);
801        }
802
803        mock.expect_execute_sql().once().returning(move |req| {
804            let req = req.into_inner();
805            assert_eq!(req.sql, "UPDATE Users SET active = true");
806
807            if begin_transaction_option == BeginTransactionOption::InlineBegin {
808                let transaction = req
809                    .transaction
810                    .as_ref()
811                    .expect("transaction options required for inline begin");
812                let selector = transaction.selector.as_ref().expect("selector required");
813                assert!(matches!(
814                    selector,
815                    v1::transaction_selector::Selector::Begin(_)
816                ));
817            }
818
819            let mut metadata = v1::ResultSetMetadata {
820                ..Default::default()
821            };
822            if begin_transaction_option == BeginTransactionOption::InlineBegin {
823                metadata.transaction = Some(v1::Transaction {
824                    id: vec![1, 2, 3],
825                    ..Default::default()
826                });
827            }
828
829            Ok(tonic::Response::new(v1::ResultSet {
830                metadata: Some(metadata),
831                stats: Some(v1::ResultSetStats {
832                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
833                    ..Default::default()
834                }),
835                ..Default::default()
836            }))
837        });
838
839        mock.expect_commit().once().returning(|req| {
840            let req = req.into_inner();
841            assert!(req.return_commit_stats);
842            Ok(tonic::Response::new(CommitResponse {
843                commit_timestamp: Some(prost_types::Timestamp {
844                    seconds: 123456789,
845                    nanos: 0,
846                }),
847                commit_stats: Some(CommitStats { mutation_count: 5 }),
848                ..Default::default()
849            }))
850        });
851
852        let (db_client, _server) = setup_db_client(mock).await;
853        let runner = TransactionRunnerBuilder::new(db_client)
854            .set_return_commit_stats(true)
855            .with_begin_transaction_option(begin_transaction_option)
856            .build()
857            .await
858            .unwrap();
859
860        let res = runner
861            .run(async |tx| {
862                let count = tx.execute_update("UPDATE Users SET active = true").await?;
863                Ok(count)
864            })
865            .await
866            .unwrap();
867
868        assert_eq!(res.result, 1);
869        assert!(res.commit_response.commit_stats.is_some());
870        assert_eq!(
871            res.commit_response
872                .commit_stats
873                .expect("Commit stats should be present")
874                .mutation_count,
875            5
876        );
877    }
878
879    #[tokio_test_no_panics]
880    async fn execute_run_with_aborted_retry_explicit() -> anyhow::Result<()> {
881        run_with_aborted_retry(BeginTransactionOption::ExplicitBegin).await
882    }
883
884    #[tokio_test_no_panics]
885    async fn execute_run_with_aborted_retry_inline() -> anyhow::Result<()> {
886        run_with_aborted_retry(BeginTransactionOption::InlineBegin).await
887    }
888
889    async fn run_with_aborted_retry(
890        begin_transaction_option: BeginTransactionOption,
891    ) -> anyhow::Result<()> {
892        let mut mock = create_session_mock();
893        let mut seq = mockall::Sequence::new();
894
895        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
896            mock.expect_begin_transaction()
897                .once()
898                .in_sequence(&mut seq)
899                .returning(move |req| {
900                    let req = req.into_inner();
901                    assert_eq!(
902                        req.session,
903                        "projects/p/instances/i/databases/d/sessions/123"
904                    );
905                    Ok(tonic::Response::new(v1::Transaction {
906                        id: vec![9, 9, 9],
907                        ..Default::default()
908                    }))
909                });
910        }
911
912        if begin_transaction_option == BeginTransactionOption::InlineBegin {
913            // Attempt 1: execute_sql fails with Aborted
914            mock.expect_execute_sql()
915                .once()
916                .in_sequence(&mut seq)
917                .returning(move |req| {
918                    let req = req.into_inner();
919                    let transaction = req
920                        .transaction
921                        .as_ref()
922                        .expect("transaction options required for inline begin");
923                    let selector = transaction.selector.as_ref().expect("selector required");
924                    assert!(matches!(
925                        selector,
926                        v1::transaction_selector::Selector::Begin(_)
927                    ));
928
929                    Err(create_aborted_status(std::time::Duration::from_nanos(1)))
930                });
931        } else {
932            mock.expect_execute_sql()
933                .once()
934                .in_sequence(&mut seq)
935                .returning(move |_req| {
936                    Err(create_aborted_status(std::time::Duration::from_nanos(1)))
937                });
938        }
939
940        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
941            mock.expect_begin_transaction()
942                .once()
943                .in_sequence(&mut seq)
944                .returning(move |req| {
945                    let req = req.into_inner();
946                    assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
947
948                    let options = req.options.as_ref().expect("options required on retry");
949                    let read_write = options.mode.as_ref().expect("mode required on retry");
950                    match read_write {
951                        Mode::ReadWrite(rw) => {
952                            assert_eq!(rw.multiplexed_session_previous_transaction_id, vec![9, 9, 9], "previous_transaction_id should be set to the ID of the aborted transaction");
953                        }
954                        _ => panic!("Expected ReadWrite mode"),
955                    }
956
957                    Ok(tonic::Response::new(v1::Transaction {
958                        id: vec![8, 8, 8],
959                        ..Default::default()
960                    }))
961                });
962        }
963
964        // Attempt 2 (retry of closure)
965        mock.expect_execute_sql()
966            .once()
967            .in_sequence(&mut seq)
968            .returning(move |req| {
969                if begin_transaction_option == BeginTransactionOption::InlineBegin {
970                    let req = req.into_inner();
971                    let transaction = req
972                        .transaction
973                        .as_ref()
974                        .expect("transaction options required for inline begin");
975                    let selector = transaction.selector.as_ref().expect("selector required");
976                    assert!(matches!(
977                        selector,
978                        v1::transaction_selector::Selector::Begin(_)
979                    ));
980
981                    let options = match selector {
982                        v1::transaction_selector::Selector::Begin(o) => o,
983                        _ => panic!("Expected Begin"),
984                    };
985                    let read_write = options.mode.as_ref().expect("mode required");
986                    match read_write {
987                        Mode::ReadWrite(rw) => {
988                            assert!(rw.multiplexed_session_previous_transaction_id.is_empty());
989                        }
990                        _ => panic!("Expected ReadWrite"),
991                    }
992                }
993
994                let mut metadata = v1::ResultSetMetadata {
995                    ..Default::default()
996                };
997                if begin_transaction_option == BeginTransactionOption::InlineBegin {
998                    metadata.transaction = Some(v1::Transaction {
999                        id: vec![8, 8, 8],
1000                        ..Default::default()
1001                    });
1002                }
1003
1004                Ok(tonic::Response::new(v1::ResultSet {
1005                    metadata: Some(metadata),
1006                    stats: Some(v1::ResultSetStats {
1007                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)),
1008                        ..Default::default()
1009                    }),
1010                    ..Default::default()
1011                }))
1012            });
1013
1014        mock.expect_commit()
1015            .once()
1016            .returning(|_req| commit_response());
1017
1018        let res = execute_test_runner(mock, begin_transaction_option)
1019            .await
1020            .expect("runner should succeed");
1021        assert_eq!(res, 5);
1022        Ok(())
1023    }
1024
1025    #[tokio_test_no_panics]
1026    async fn execute_run_query_stream_with_aborted_retry_explicit() -> anyhow::Result<()> {
1027        run_query_stream_with_aborted_retry(BeginTransactionOption::ExplicitBegin).await
1028    }
1029
1030    #[tokio_test_no_panics]
1031    async fn execute_run_query_stream_with_aborted_retry_inline() -> anyhow::Result<()> {
1032        run_query_stream_with_aborted_retry(BeginTransactionOption::InlineBegin).await
1033    }
1034
1035    async fn run_query_stream_with_aborted_retry(
1036        begin_transaction_option: BeginTransactionOption,
1037    ) -> anyhow::Result<()> {
1038        let mut mock = create_session_mock();
1039        let mut seq = mockall::Sequence::new();
1040
1041        let tx_id_1 = vec![9, 9, 9];
1042        let tx_id_2 = vec![8, 8, 8];
1043
1044        let tx_id_1_c1 = tx_id_1.clone();
1045        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1046            mock.expect_begin_transaction()
1047                .once()
1048                .in_sequence(&mut seq)
1049                .returning(move |_| {
1050                    Ok(tonic::Response::new(v1::Transaction {
1051                        id: tx_id_1_c1.clone(),
1052                        ..Default::default()
1053                    }))
1054                });
1055        }
1056
1057        let tx_id_1_c2 = tx_id_1.clone();
1058        mock.expect_execute_streaming_sql()
1059            .once()
1060            .in_sequence(&mut seq)
1061            .returning(move |req| {
1062                let req = req.into_inner();
1063                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1064                    let transaction = req
1065                        .transaction
1066                        .as_ref()
1067                        .expect("transaction options required for inline begin");
1068                    let selector = transaction.selector.as_ref().expect("selector required");
1069                    assert!(matches!(
1070                        selector,
1071                        v1::transaction_selector::Selector::Begin(_)
1072                    ));
1073                }
1074
1075                let mut rs = v1::PartialResultSet {
1076                    metadata: Some(v1::ResultSetMetadata {
1077                        row_type: Some(v1::StructType {
1078                            fields: vec![Default::default()],
1079                        }),
1080                        ..Default::default()
1081                    }),
1082                    values: vec![prost_types::Value {
1083                        kind: Some(prost_types::value::Kind::StringValue("1".to_string())),
1084                    }],
1085                    resume_token: b"token1".to_vec(),
1086                    ..Default::default()
1087                };
1088
1089                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1090                    rs.metadata.as_mut().unwrap().transaction = Some(v1::Transaction {
1091                        id: tx_id_1_c2.clone(),
1092                        ..Default::default()
1093                    });
1094                }
1095
1096                let (tx, rx) = tokio::sync::mpsc::channel(2);
1097                tx.try_send(Ok(rs)).unwrap();
1098                tx.try_send(Err(tonic::Status::new(tonic::Code::Aborted, "aborted")))
1099                    .unwrap();
1100                Ok(tonic::Response::from(rx))
1101            });
1102
1103        let tx_id_2_c1 = tx_id_2.clone();
1104        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1105            mock.expect_begin_transaction()
1106                .once()
1107                .in_sequence(&mut seq)
1108                .returning(move |req| {
1109                    let req = req.into_inner();
1110                    let options = req.options.as_ref().expect("options required on retry");
1111                    let read_write = options.mode.as_ref().expect("mode required on retry");
1112                    match read_write {
1113                        Mode::ReadWrite(rw) => {
1114                            assert_eq!(
1115                                rw.multiplexed_session_previous_transaction_id,
1116                                vec![9, 9, 9]
1117                            );
1118                        }
1119                        _ => panic!("Expected ReadWrite mode"),
1120                    }
1121
1122                    Ok(tonic::Response::new(v1::Transaction {
1123                        id: tx_id_2_c1.clone(),
1124                        ..Default::default()
1125                    }))
1126                });
1127        }
1128
1129        let tx_id_2_c2 = tx_id_2.clone();
1130        mock.expect_execute_streaming_sql()
1131            .once()
1132            .in_sequence(&mut seq)
1133            .returning(move |req| {
1134                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1135                    let req = req.into_inner();
1136                    let transaction = req
1137                        .transaction
1138                        .as_ref()
1139                        .expect("transaction options required for inline begin");
1140                    let selector = transaction.selector.as_ref().expect("selector required");
1141                    assert!(matches!(
1142                        selector,
1143                        v1::transaction_selector::Selector::Begin(_)
1144                    ));
1145
1146                    let options = match selector {
1147                        v1::transaction_selector::Selector::Begin(o) => o,
1148                        _ => panic!("Expected Begin"),
1149                    };
1150                    let read_write = options.mode.as_ref().expect("mode required");
1151                    match read_write {
1152                        Mode::ReadWrite(rw) => {
1153                            assert_eq!(
1154                                rw.multiplexed_session_previous_transaction_id,
1155                                vec![9, 9, 9]
1156                            );
1157                        }
1158                        _ => panic!("Expected ReadWrite"),
1159                    }
1160                }
1161
1162                let mut rs = v1::PartialResultSet {
1163                    metadata: Some(v1::ResultSetMetadata {
1164                        row_type: Some(v1::StructType {
1165                            fields: vec![Default::default()],
1166                        }),
1167                        ..Default::default()
1168                    }),
1169                    values: vec![prost_types::Value {
1170                        kind: Some(prost_types::value::Kind::StringValue("1".to_string())),
1171                    }],
1172                    last: true,
1173                    ..Default::default()
1174                };
1175
1176                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1177                    rs.metadata.as_mut().unwrap().transaction = Some(v1::Transaction {
1178                        id: tx_id_2_c2.clone(),
1179                        ..Default::default()
1180                    });
1181                }
1182
1183                let (tx, rx) = tokio::sync::mpsc::channel(2);
1184                tx.try_send(Ok(rs)).unwrap();
1185                Ok(tonic::Response::from(rx))
1186            });
1187
1188        mock.expect_commit()
1189            .once()
1190            .returning(|_req| commit_response());
1191
1192        let (db_client, _server) = setup_db_client(mock).await;
1193        let runner = TransactionRunnerBuilder::new(db_client)
1194            .with_begin_transaction_option(begin_transaction_option)
1195            .build()
1196            .await?;
1197
1198        let mut attempt_counter = 0;
1199        let res = runner
1200            .run(async |tx| {
1201                attempt_counter += 1;
1202                let mut rs = tx.execute_query("SELECT 1").await?;
1203                let mut last_val = None;
1204                while let Some(row_res) = rs.next().await {
1205                    let row = row_res?;
1206                    last_val = Some(row.raw_values()[0].as_string().to_string());
1207                }
1208                Ok(last_val.unwrap())
1209            })
1210            .await?;
1211
1212        assert_eq!(res.result, "1");
1213        assert_eq!(attempt_counter, 2);
1214        Ok(())
1215    }
1216
1217    #[tokio_test_no_panics]
1218    async fn execute_run_with_non_aborted_error_explicit() {
1219        run_with_non_aborted_error(BeginTransactionOption::ExplicitBegin).await;
1220    }
1221
1222    #[tokio_test_no_panics]
1223    async fn execute_run_with_non_aborted_error_inline() {
1224        run_with_non_aborted_error(BeginTransactionOption::InlineBegin).await;
1225    }
1226
1227    async fn run_with_non_aborted_error(begin_transaction_option: BeginTransactionOption) {
1228        let mut mock = create_session_mock();
1229
1230        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1231            expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]);
1232        }
1233
1234        // Let execute_sql return an error to trigger a rollback.
1235        mock.expect_execute_sql().once().returning(move |_req| {
1236            Err(tonic::Status::new(
1237                tonic::Code::PermissionDenied,
1238                "permission denied",
1239            ))
1240        });
1241
1242        if begin_transaction_option == BeginTransactionOption::InlineBegin {
1243            expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]);
1244            mock.expect_execute_sql().once().returning(move |_req| {
1245                Err(tonic::Status::new(
1246                    tonic::Code::PermissionDenied,
1247                    "permission denied",
1248                ))
1249            });
1250        }
1251
1252        // Must explicitly trigger rollback
1253        mock.expect_rollback()
1254            .once()
1255            .returning(|_req| Ok(tonic::Response::new(())));
1256
1257        let res = execute_test_runner(mock, begin_transaction_option).await;
1258
1259        assert!(res.is_err());
1260        let err = res.unwrap_err();
1261        if let Some(status) = err.status() {
1262            assert_eq!(
1263                status.code,
1264                google_cloud_gax::error::rpc::Code::PermissionDenied
1265            );
1266        } else {
1267            panic!("Expected GRPC error");
1268        }
1269    }
1270
1271    #[tokio_test_no_panics]
1272    async fn execute_run_with_non_aborted_error_and_rollback_fails_explicit() {
1273        run_with_non_aborted_error_and_rollback_fails(BeginTransactionOption::ExplicitBegin).await;
1274    }
1275
1276    #[tokio_test_no_panics]
1277    async fn execute_run_with_non_aborted_error_and_rollback_fails_inline() {
1278        run_with_non_aborted_error_and_rollback_fails(BeginTransactionOption::InlineBegin).await;
1279    }
1280
1281    async fn run_with_non_aborted_error_and_rollback_fails(
1282        begin_transaction_option: BeginTransactionOption,
1283    ) {
1284        let mut mock = create_session_mock();
1285
1286        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1287            expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]);
1288        }
1289
1290        // Let execute_sql return an error to trigger a rollback.
1291        mock.expect_execute_sql().once().returning(move |_req| {
1292            Err(tonic::Status::new(
1293                tonic::Code::PermissionDenied,
1294                "permission denied",
1295            ))
1296        });
1297
1298        if begin_transaction_option == BeginTransactionOption::InlineBegin {
1299            expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]);
1300            mock.expect_execute_sql().once().returning(move |_req| {
1301                Err(tonic::Status::new(
1302                    tonic::Code::PermissionDenied,
1303                    "permission denied",
1304                ))
1305            });
1306        }
1307
1308        // Force the rollback itself to fail as well
1309        mock.expect_rollback()
1310            .once()
1311            .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "rollback failed")));
1312
1313        let res = execute_test_runner(mock, begin_transaction_option).await;
1314
1315        // Verify the user unequivocally receives the PRIMARY original error
1316        assert!(res.is_err());
1317        let err = res.unwrap_err();
1318        if let Some(status) = err.status() {
1319            assert_eq!(
1320                status.code,
1321                google_cloud_gax::error::rpc::Code::PermissionDenied
1322            );
1323        } else {
1324            panic!("Expected GRPC error");
1325        }
1326    }
1327
1328    #[tokio_test_no_panics]
1329    async fn execute_run_commit_aborted_retry_explicit() {
1330        run_commit_aborted_retry(BeginTransactionOption::ExplicitBegin).await;
1331    }
1332
1333    #[tokio_test_no_panics]
1334    async fn execute_run_commit_aborted_retry_inline() {
1335        run_commit_aborted_retry(BeginTransactionOption::InlineBegin).await;
1336    }
1337
1338    async fn run_commit_aborted_retry(begin_transaction_option: BeginTransactionOption) {
1339        let mut mock = create_session_mock();
1340
1341        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1342            expect_begin_transaction(&mut mock, 2, vec![9, 9, 9]);
1343        }
1344
1345        let mut attempt = 0;
1346        mock.expect_execute_sql().times(2).returning(move |req| {
1347            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1348                let req = req.into_inner();
1349                let transaction = req
1350                    .transaction
1351                    .as_ref()
1352                    .expect("transaction options required for inline begin");
1353                let selector = transaction.selector.as_ref().expect("selector required");
1354                assert!(matches!(
1355                    selector,
1356                    v1::transaction_selector::Selector::Begin(_)
1357                ));
1358
1359                attempt += 1;
1360                if attempt == 2 {
1361                    let options = match selector {
1362                        v1::transaction_selector::Selector::Begin(o) => o,
1363                        _ => panic!("Expected Begin"),
1364                    };
1365                    let read_write = options.mode.as_ref().expect("mode required");
1366                    match read_write {
1367                        Mode::ReadWrite(rw) => {
1368                            assert_eq!(
1369                                rw.multiplexed_session_previous_transaction_id,
1370                                vec![9, 9, 9]
1371                            );
1372                        }
1373                        _ => panic!("Expected ReadWrite"),
1374                    }
1375                }
1376
1377                let mut metadata = v1::ResultSetMetadata {
1378                    ..Default::default()
1379                };
1380                metadata.transaction = Some(v1::Transaction {
1381                    id: vec![9, 9, 9],
1382                    ..Default::default()
1383                });
1384
1385                return Ok(tonic::Response::new(v1::ResultSet {
1386                    metadata: Some(metadata),
1387                    stats: Some(v1::ResultSetStats {
1388                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)),
1389                        ..Default::default()
1390                    }),
1391                    ..Default::default()
1392                }));
1393            }
1394            row_count_exact_response(5)
1395        });
1396
1397        let mut commit_attempt = 0;
1398        mock.expect_commit().times(2).returning(move |_req| {
1399            commit_attempt += 1;
1400            if commit_attempt == 1 {
1401                Err(create_aborted_status(std::time::Duration::from_nanos(1)))
1402            } else {
1403                commit_response()
1404            }
1405        });
1406
1407        let res = execute_test_runner(mock, begin_transaction_option)
1408            .await
1409            .unwrap();
1410        assert_eq!(res, 5);
1411    }
1412
1413    #[tokio_test_no_panics]
1414    async fn execute_run_begin_transaction_fails_explicit() {
1415        run_begin_transaction_fails(BeginTransactionOption::ExplicitBegin).await;
1416    }
1417
1418    #[tokio_test_no_panics]
1419    async fn execute_run_begin_transaction_fails_inline() {
1420        run_begin_transaction_fails(BeginTransactionOption::InlineBegin).await;
1421    }
1422
1423    async fn run_begin_transaction_fails(begin_transaction_option: BeginTransactionOption) {
1424        let mut mock = create_session_mock();
1425        let mut seq = mockall::Sequence::new();
1426
1427        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1428            mock.expect_begin_transaction()
1429                .once()
1430                .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "internal error")));
1431        } else {
1432            mock.expect_execute_sql()
1433                .once()
1434                .in_sequence(&mut seq)
1435                .returning(move |req| {
1436                    let req = req.into_inner();
1437                    let transaction = req
1438                        .transaction
1439                        .as_ref()
1440                        .expect("transaction options required for inline begin");
1441                    let selector = transaction.selector.as_ref().expect("selector required");
1442                    assert!(matches!(
1443                        selector,
1444                        v1::transaction_selector::Selector::Begin(_)
1445                    ));
1446
1447                    Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
1448                });
1449
1450            mock.expect_begin_transaction()
1451                .once()
1452                .in_sequence(&mut seq)
1453                .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "internal error")));
1454        }
1455
1456        let res = execute_test_runner(mock, begin_transaction_option).await;
1457
1458        assert!(res.is_err());
1459        let err = res.unwrap_err();
1460        if let Some(status) = err.status() {
1461            assert_eq!(status.code, google_cloud_gax::error::rpc::Code::Internal);
1462        } else {
1463            panic!("Expected GRPC error");
1464        }
1465    }
1466
1467    #[tokio_test_no_panics]
1468    async fn builder_options() {
1469        use crate::transaction_retry_policy::BasicTransactionRetryPolicy;
1470
1471        let mock = create_session_mock();
1472        let (db_client, _server) = setup_db_client(mock).await;
1473
1474        let retry_policy = BasicTransactionRetryPolicy::new()
1475            .with_max_attempts(1)
1476            .with_total_timeout(std::time::Duration::from_secs(10));
1477
1478        // Validate builder chaining safely accepts and compiles options dynamically
1479        let _runner = TransactionRunnerBuilder::new(db_client)
1480            .set_isolation_level(IsolationLevel::Serializable)
1481            .set_read_lock_mode(ReadLockMode::Pessimistic)
1482            .with_retry_policy(retry_policy)
1483            .build()
1484            .await
1485            .unwrap();
1486    }
1487
1488    #[tokio_test_no_panics]
1489    async fn execute_run_batch_dml_aborted_retry_explicit() {
1490        run_batch_dml_aborted_retry(BeginTransactionOption::ExplicitBegin).await;
1491    }
1492
1493    #[tokio_test_no_panics]
1494    async fn execute_run_batch_dml_aborted_retry_inline() {
1495        run_batch_dml_aborted_retry(BeginTransactionOption::InlineBegin).await;
1496    }
1497
1498    async fn run_batch_dml_aborted_retry(begin_transaction_option: BeginTransactionOption) {
1499        use crate::batch_dml::BatchDml;
1500        use crate::statement::Statement;
1501        use gaxi::grpc::tonic::Code;
1502        use spanner_grpc_mock::google::rpc::Status;
1503        use spanner_grpc_mock::google::spanner::v1::result_set_stats::RowCount;
1504
1505        let mut mock = create_session_mock();
1506
1507        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1508            expect_begin_transaction(&mut mock, 2, vec![9, 9, 9]);
1509        }
1510
1511        let mut seq = mockall::Sequence::new();
1512        mock.expect_execute_batch_dml()
1513            .once()
1514            .in_sequence(&mut seq)
1515            .returning(move |req| {
1516                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1517                    let req = req.into_inner();
1518                    let selector = req
1519                        .transaction
1520                        .expect("missing transaction selector")
1521                        .selector
1522                        .expect("missing selector");
1523                    assert!(matches!(
1524                        selector,
1525                        v1::transaction_selector::Selector::Begin(_)
1526                    ));
1527                }
1528
1529                // Return a successful response but with an embedded aborted status.
1530                let status = Status {
1531                    code: Code::Aborted as i32,
1532                    message: "transaction aborted".to_string(),
1533                    ..Default::default()
1534                };
1535
1536                let mut metadata = v1::ResultSetMetadata {
1537                    ..Default::default()
1538                };
1539                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1540                    metadata.transaction = Some(v1::Transaction {
1541                        id: vec![9, 9, 9],
1542                        ..Default::default()
1543                    });
1544                }
1545
1546                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
1547                    result_sets: vec![v1::ResultSet {
1548                        metadata: Some(metadata),
1549                        stats: Some(v1::ResultSetStats {
1550                            row_count: Some(RowCount::RowCountExact(1)),
1551                            ..Default::default()
1552                        }),
1553                        ..Default::default()
1554                    }],
1555                    status: Some(status),
1556                    ..Default::default()
1557                }))
1558            });
1559        mock.expect_execute_batch_dml()
1560            .once()
1561            .in_sequence(&mut seq)
1562            .returning(move |req| {
1563                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1564                    let req = req.into_inner();
1565                    let selector = req
1566                        .transaction
1567                        .expect("missing transaction selector")
1568                        .selector
1569                        .expect("missing selector");
1570                    assert!(matches!(
1571                        selector,
1572                        v1::transaction_selector::Selector::Begin(_)
1573                    ));
1574                }
1575
1576                let mut metadata = v1::ResultSetMetadata {
1577                    ..Default::default()
1578                };
1579                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1580                    metadata.transaction = Some(v1::Transaction {
1581                        id: vec![9, 9, 9],
1582                        ..Default::default()
1583                    });
1584                }
1585
1586                // Return success after the retry.
1587                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
1588                    result_sets: vec![v1::ResultSet {
1589                        metadata: Some(metadata),
1590                        stats: Some(v1::ResultSetStats {
1591                            row_count: Some(RowCount::RowCountExact(5)),
1592                            ..Default::default()
1593                        }),
1594                        ..Default::default()
1595                    }],
1596                    ..Default::default()
1597                }))
1598            });
1599
1600        mock.expect_commit()
1601            .once()
1602            .returning(move |_| commit_response());
1603
1604        let (db_client, _) = setup_db_client(mock).await;
1605        let runner = TransactionRunnerBuilder::new(db_client)
1606            .with_begin_transaction_option(begin_transaction_option)
1607            .build()
1608            .await
1609            .expect("failed to build TransactionRunner");
1610
1611        let mut attempt_counter = 0;
1612
1613        // TransactionRunner retries the closure on transaction aborts
1614        let res = runner
1615            .run(async |tx| {
1616                attempt_counter += 1;
1617                let stmt = Statement::builder("UPDATE t SET c = 1").build();
1618                let batch = BatchDml::builder().add_statement(stmt).build();
1619                let counts = tx.execute_batch_update(batch).await?;
1620                Ok(counts)
1621            })
1622            .await
1623            .expect("transaction failed");
1624
1625        assert_eq!(res.result, vec![5]);
1626        assert_eq!(attempt_counter, 2);
1627    }
1628
1629    #[tokio_test_no_panics]
1630    async fn execute_run_with_transaction_tag_explicit() -> anyhow::Result<()> {
1631        run_with_transaction_tag(BeginTransactionOption::ExplicitBegin).await
1632    }
1633
1634    #[tokio_test_no_panics]
1635    async fn execute_run_with_transaction_tag_inline() -> anyhow::Result<()> {
1636        run_with_transaction_tag(BeginTransactionOption::InlineBegin).await
1637    }
1638
1639    async fn run_with_transaction_tag(
1640        begin_transaction_option: BeginTransactionOption,
1641    ) -> anyhow::Result<()> {
1642        let mut mock = create_session_mock();
1643
1644        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1645            mock.expect_begin_transaction().once().returning(|req| {
1646                let req = req.into_inner();
1647                // Check if the transaction tag is correctly propagated.
1648                assert_eq!(
1649                    req.request_options
1650                        .expect("Missing request_options")
1651                        .transaction_tag,
1652                    "my-test-tag"
1653                );
1654
1655                Ok(tonic::Response::new(v1::Transaction {
1656                    id: vec![9, 9, 9],
1657                    ..Default::default()
1658                }))
1659            });
1660        }
1661
1662        mock.expect_execute_sql().once().returning(move |req| {
1663            let req = req.into_inner();
1664            assert_eq!(
1665                req.request_options
1666                    .expect("Missing request_options")
1667                    .transaction_tag,
1668                "my-test-tag"
1669            );
1670
1671            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1672                let transaction = req
1673                    .transaction
1674                    .as_ref()
1675                    .expect("transaction options required for inline begin");
1676                let selector = transaction.selector.as_ref().expect("selector required");
1677                assert!(matches!(
1678                    selector,
1679                    v1::transaction_selector::Selector::Begin(_)
1680                ));
1681            }
1682
1683            let mut metadata = v1::ResultSetMetadata {
1684                ..Default::default()
1685            };
1686            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1687                metadata.transaction = Some(v1::Transaction {
1688                    id: vec![9, 9, 9],
1689                    ..Default::default()
1690                });
1691            }
1692
1693            Ok(tonic::Response::new(v1::ResultSet {
1694                metadata: Some(metadata),
1695                stats: Some(v1::ResultSetStats {
1696                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)),
1697                    ..Default::default()
1698                }),
1699                ..Default::default()
1700            }))
1701        });
1702
1703        mock.expect_commit().once().returning(|req| {
1704            let req = req.into_inner();
1705            assert_eq!(
1706                req.request_options
1707                    .expect("Missing request_options")
1708                    .transaction_tag,
1709                "my-test-tag"
1710            );
1711            commit_response()
1712        });
1713
1714        let (db_client, _server) = setup_db_client(mock).await;
1715
1716        let runner = TransactionRunnerBuilder::new(db_client)
1717            .with_begin_transaction_option(begin_transaction_option)
1718            .set_transaction_tag("my-test-tag")
1719            .build()
1720            .await?;
1721
1722        let res = runner
1723            .run(async |tx| {
1724                let count = tx.execute_update("UPDATE Users SET active = true").await?;
1725                Ok(count)
1726            })
1727            .await?;
1728
1729        assert_eq!(res.result, 5);
1730
1731        Ok(())
1732    }
1733
1734    #[tokio_test_no_panics]
1735    async fn execute_run_with_exclude_txn_from_change_streams_explicit() -> anyhow::Result<()> {
1736        run_with_exclude_txn_from_change_streams(BeginTransactionOption::ExplicitBegin).await
1737    }
1738
1739    #[tokio_test_no_panics]
1740    async fn execute_run_with_exclude_txn_from_change_streams_inline() -> anyhow::Result<()> {
1741        run_with_exclude_txn_from_change_streams(BeginTransactionOption::InlineBegin).await
1742    }
1743
1744    async fn run_with_exclude_txn_from_change_streams(
1745        begin_transaction_option: BeginTransactionOption,
1746    ) -> anyhow::Result<()> {
1747        let mut mock = create_session_mock();
1748
1749        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1750            mock.expect_begin_transaction().once().returning(|req| {
1751                let req = req.into_inner();
1752                let options = req.options.expect("Missing transaction options");
1753                assert!(options.exclude_txn_from_change_streams);
1754
1755                Ok(tonic::Response::new(v1::Transaction {
1756                    id: vec![9, 9, 9],
1757                    ..Default::default()
1758                }))
1759            });
1760        }
1761
1762        mock.expect_execute_sql().once().returning(move |req| {
1763            let req = req.into_inner();
1764            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1765                let transaction = req
1766                    .transaction
1767                    .as_ref()
1768                    .expect("transaction options required for inline begin");
1769                let selector = transaction.selector.as_ref().expect("selector required");
1770                assert!(matches!(
1771                    selector,
1772                    v1::transaction_selector::Selector::Begin(_)
1773                ));
1774            }
1775
1776            let mut metadata = v1::ResultSetMetadata {
1777                ..Default::default()
1778            };
1779            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1780                metadata.transaction = Some(v1::Transaction {
1781                    id: vec![9, 9, 9],
1782                    ..Default::default()
1783                });
1784            }
1785
1786            Ok(tonic::Response::new(v1::ResultSet {
1787                metadata: Some(metadata),
1788                stats: Some(v1::ResultSetStats {
1789                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)),
1790                    ..Default::default()
1791                }),
1792                ..Default::default()
1793            }))
1794        });
1795
1796        mock.expect_commit()
1797            .once()
1798            .returning(|_req| commit_response());
1799
1800        let (db_client, _server) = setup_db_client(mock).await;
1801
1802        let runner = TransactionRunnerBuilder::new(db_client)
1803            .set_exclude_txn_from_change_streams(true)
1804            .with_begin_transaction_option(begin_transaction_option)
1805            .build()
1806            .await?;
1807
1808        let res = runner
1809            .run(async |tx| {
1810                let count = tx.execute_update("UPDATE Users SET active = true").await?;
1811                Ok(count)
1812            })
1813            .await?;
1814
1815        assert_eq!(res.result, 5);
1816
1817        Ok(())
1818    }
1819
1820    #[tokio_test_no_panics]
1821    async fn execute_run_with_max_commit_delay_explicit() -> anyhow::Result<()> {
1822        run_with_max_commit_delay(BeginTransactionOption::ExplicitBegin).await
1823    }
1824
1825    #[tokio_test_no_panics]
1826    async fn execute_run_with_max_commit_delay_inline() -> anyhow::Result<()> {
1827        run_with_max_commit_delay(BeginTransactionOption::InlineBegin).await
1828    }
1829
1830    async fn run_with_max_commit_delay(
1831        begin_transaction_option: BeginTransactionOption,
1832    ) -> anyhow::Result<()> {
1833        let mut mock = create_session_mock();
1834
1835        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1836            expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);
1837        }
1838
1839        mock.expect_execute_sql().once().returning(move |req| {
1840            let req = req.into_inner();
1841            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1842                let transaction = req
1843                    .transaction
1844                    .as_ref()
1845                    .expect("transaction options required for inline begin");
1846                let selector = transaction.selector.as_ref().expect("selector required");
1847                assert!(matches!(
1848                    selector,
1849                    v1::transaction_selector::Selector::Begin(_)
1850                ));
1851            }
1852
1853            let mut metadata = v1::ResultSetMetadata {
1854                ..Default::default()
1855            };
1856            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1857                metadata.transaction = Some(v1::Transaction {
1858                    id: vec![1, 2, 3],
1859                    ..Default::default()
1860                });
1861            }
1862
1863            Ok(tonic::Response::new(v1::ResultSet {
1864                metadata: Some(metadata),
1865                stats: Some(v1::ResultSetStats {
1866                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1867                    ..Default::default()
1868                }),
1869                ..Default::default()
1870            }))
1871        });
1872
1873        mock.expect_commit().once().returning(|req| {
1874            let req = req.into_inner();
1875            assert_eq!(
1876                req.max_commit_delay,
1877                Some(::prost_types::Duration {
1878                    seconds: 0,
1879                    nanos: 200_000_000, // 200ms
1880                })
1881            );
1882            commit_response()
1883        });
1884
1885        let (db_client, _server) = setup_db_client(mock).await;
1886        let runner = TransactionRunnerBuilder::new(db_client)
1887            .set_max_commit_delay(Duration::try_from("0.2s").unwrap())
1888            .with_begin_transaction_option(begin_transaction_option)
1889            .build()
1890            .await?;
1891
1892        let res = runner
1893            .run(async |tx| {
1894                let count = tx.execute_update("UPDATE Users SET active = true").await?;
1895                Ok(count)
1896            })
1897            .await?;
1898        assert_eq!(res.result, 1);
1899        Ok(())
1900    }
1901
1902    #[tokio_test_no_panics]
1903    async fn execute_run_empty_closure_inline() {
1904        let mut mock = create_session_mock();
1905        expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);
1906        mock.expect_commit().once().returning(|req| {
1907            let req = req.into_inner();
1908            assert_eq!(
1909                req.transaction,
1910                Some(v1::commit_request::Transaction::TransactionId(vec![
1911                    1, 2, 3
1912                ]))
1913            );
1914            commit_response()
1915        });
1916
1917        let (db_client, _server) = setup_db_client(mock).await;
1918
1919        let runner = TransactionRunnerBuilder::new(db_client)
1920            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
1921            .build()
1922            .await
1923            .unwrap();
1924
1925        let res = runner.run(async |_tx| Ok(42)).await.unwrap();
1926        assert_eq!(res.result, 42);
1927    }
1928
1929    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1930    async fn execute_run_async_statement_still_starting() {
1931        let (tx_rpc, rx_rpc) = std_channel();
1932        let (tx_started, rx_started) = oneshot_channel();
1933        let tx_started_mutex = Mutex::new(Some(tx_started));
1934
1935        let mut mock = create_session_mock();
1936
1937        mock.expect_execute_sql().once().returning(move |_req| {
1938            if let Some(tx) = tx_started_mutex.lock().unwrap().take() {
1939                let _ = tx.send(());
1940            }
1941            rx_rpc.recv().unwrap();
1942            row_count_exact_response(1)
1943        });
1944
1945        let (db_client, _server) = setup_db_client(mock).await;
1946
1947        let runner = TransactionRunnerBuilder::new(db_client)
1948            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
1949            .build()
1950            .await
1951            .unwrap();
1952
1953        let mut rx_started_opt = Some(rx_started);
1954        let res = runner
1955            .run(async |tx| {
1956                tokio::spawn(async move {
1957                    let _ = tx.execute_update("UPDATE Users SET active = true").await;
1958                });
1959                if let Some(rx) = rx_started_opt.take() {
1960                    rx.await.unwrap();
1961                }
1962                Ok(42)
1963            })
1964            .await;
1965
1966        tx_rpc.send(()).unwrap();
1967
1968        assert!(res.is_err());
1969        assert!(
1970            format!("{:?}", res.unwrap_err())
1971                .contains("asynchronous statement is still starting the transaction")
1972        );
1973    }
1974
1975    #[tokio_test_no_panics]
1976    async fn execute_run_with_mutations_happy_flow() {
1977        let mut mock = create_session_mock();
1978
1979        mock.expect_execute_sql().once().returning(move |req| {
1980            let req = req.into_inner();
1981            assert_eq!(req.sql, "UPDATE Users SET active = true");
1982            let transaction = req
1983                .transaction
1984                .as_ref()
1985                .expect("transaction options required for inline begin");
1986            let selector = transaction.selector.as_ref().expect("selector required");
1987            assert!(matches!(selector, ProtoSelector::Begin(_)));
1988
1989            Ok(tonic::Response::new(v1::ResultSet {
1990                metadata: Some(v1::ResultSetMetadata {
1991                    transaction: Some(v1::Transaction {
1992                        id: vec![1, 1, 1],
1993                        ..Default::default()
1994                    }),
1995                    ..Default::default()
1996                }),
1997                stats: Some(v1::ResultSetStats {
1998                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1999                    ..Default::default()
2000                }),
2001                ..Default::default()
2002            }))
2003        });
2004
2005        mock.expect_commit().once().returning(|req| {
2006            let req = req.into_inner();
2007            assert_eq!(
2008                req.transaction,
2009                Some(CommitTransaction::TransactionId(vec![1, 1, 1]))
2010            );
2011            assert_eq!(req.mutations.len(), 1);
2012            commit_response()
2013        });
2014
2015        let (db_client, _server) = setup_db_client(mock).await;
2016        let runner = TransactionRunnerBuilder::new(db_client)
2017            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2018            .build()
2019            .await
2020            .expect("Failed to build transaction runner");
2021
2022        let res = runner
2023            .run(async |tx| {
2024                let count = tx.execute_update("UPDATE Users SET active = true").await?;
2025                let mutation = Mutation::new_insert_builder("Audits")
2026                    .set("AuditId")
2027                    .to(&1)
2028                    .build();
2029                tx.buffer([mutation])?;
2030                Ok(count)
2031            })
2032            .await
2033            .expect("Transaction runner failed");
2034
2035        assert_eq!(res.result, 1);
2036    }
2037
2038    #[tokio_test_no_panics]
2039    async fn execute_run_with_mutations_aborted_retry() {
2040        let mut mock = create_session_mock();
2041        let mut sequence = mockall::Sequence::new();
2042
2043        // Initial attempt: statement succeeds, returns tx id [10, 20, 30]
2044        mock.expect_execute_sql()
2045            .once()
2046            .in_sequence(&mut sequence)
2047            .returning(move |_req| {
2048                Ok(tonic::Response::new(v1::ResultSet {
2049                    metadata: Some(v1::ResultSetMetadata {
2050                        transaction: Some(v1::Transaction {
2051                            id: vec![10, 20, 30],
2052                            ..Default::default()
2053                        }),
2054                        ..Default::default()
2055                    }),
2056                    stats: Some(v1::ResultSetStats {
2057                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2058                        ..Default::default()
2059                    }),
2060                    ..Default::default()
2061                }))
2062            });
2063
2064        // Initial commit fails with Aborted
2065        mock.expect_commit()
2066            .once()
2067            .in_sequence(&mut sequence)
2068            .returning(|req| {
2069                let req = req.into_inner();
2070                assert_eq!(req.mutations.len(), 1);
2071                // Verify initial attempt added mutation for UserId 100
2072                let write = req.mutations[0]
2073                    .operation
2074                    .as_ref()
2075                    .expect("Operation required");
2076                match write {
2077                    Operation::Insert(w) => {
2078                        assert_eq!(
2079                            w.values[0].values[0].kind,
2080                            Some(Kind::StringValue("100".to_string()))
2081                        );
2082                    }
2083                    _ => panic!("Expected insert mutation"),
2084                }
2085                Err(create_aborted_status(StdTimeDuration::from_nanos(1)))
2086            });
2087
2088        // Retry attempt: execute_sql sends inline BeginTransaction with previous_transaction_id
2089        mock.expect_execute_sql()
2090            .once()
2091            .in_sequence(&mut sequence)
2092            .returning(move |req| {
2093                let req = req.into_inner();
2094                let transaction = req
2095                    .transaction
2096                    .as_ref()
2097                    .expect("transaction options required for inline begin");
2098                let selector = transaction.selector.as_ref().expect("selector required");
2099                let options = match selector {
2100                    ProtoSelector::Begin(o) => o,
2101                    _ => panic!("Expected Begin"),
2102                };
2103                let read_write = options.mode.as_ref().expect("mode required");
2104                match read_write {
2105                    Mode::ReadWrite(rw) => {
2106                        assert_eq!(
2107                            rw.multiplexed_session_previous_transaction_id,
2108                            vec![10, 20, 30]
2109                        );
2110                    }
2111                    _ => panic!("Expected ReadWrite"),
2112                }
2113
2114                Ok(tonic::Response::new(v1::ResultSet {
2115                    metadata: Some(v1::ResultSetMetadata {
2116                        transaction: Some(v1::Transaction {
2117                            id: vec![99, 99, 99],
2118                            ..Default::default()
2119                        }),
2120                        ..Default::default()
2121                    }),
2122                    stats: Some(v1::ResultSetStats {
2123                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2124                        ..Default::default()
2125                    }),
2126                    ..Default::default()
2127                }))
2128            });
2129
2130        // Second commit succeeds with the new mutation
2131        mock.expect_commit()
2132            .once()
2133            .in_sequence(&mut sequence)
2134            .returning(|req| {
2135                let req = req.into_inner();
2136                assert_eq!(
2137                    req.transaction,
2138                    Some(CommitTransaction::TransactionId(vec![99, 99, 99]))
2139                );
2140                assert_eq!(req.mutations.len(), 1);
2141                // Verify retry attempt added mutation for UserId 200
2142                let write = req.mutations[0]
2143                    .operation
2144                    .as_ref()
2145                    .expect("Operation required");
2146                match write {
2147                    Operation::Insert(w) => {
2148                        assert_eq!(
2149                            w.values[0].values[0].kind,
2150                            Some(Kind::StringValue("200".to_string()))
2151                        );
2152                    }
2153                    _ => panic!("Expected insert mutation"),
2154                }
2155                commit_response()
2156            });
2157
2158        let (db_client, _server) = setup_db_client(mock).await;
2159        let runner = TransactionRunnerBuilder::new(db_client)
2160            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2161            .build()
2162            .await
2163            .expect("Failed to build transaction runner");
2164
2165        let mut attempt = 0;
2166        let res = runner
2167            .run(async |tx| {
2168                attempt += 1;
2169                let count = tx.execute_update("UPDATE Users SET active = true").await?;
2170                let mutation_value = if attempt == 1 { 100 } else { 200 };
2171                let mutation = Mutation::new_insert_builder("Users")
2172                    .set("UserId")
2173                    .to(&mutation_value)
2174                    .build();
2175                tx.buffer([mutation])?;
2176                Ok(count)
2177            })
2178            .await
2179            .expect("Transaction runner failed");
2180
2181        assert_eq!(res.result, 1);
2182    }
2183
2184    #[tokio_test_no_panics]
2185    async fn execute_run_mutation_only_explicit_begin_fallback() {
2186        let mut mock = create_session_mock();
2187
2188        // Since the user closure executes no statements, commit() calls explicit BeginTransaction
2189        mock.expect_begin_transaction().once().returning(|req| {
2190            let req = req.into_inner();
2191            assert_eq!(
2192                req.session,
2193                "projects/p/instances/i/databases/d/sessions/123"
2194            );
2195            Ok(tonic::Response::new(v1::Transaction {
2196                id: vec![77, 88, 99],
2197                ..Default::default()
2198            }))
2199        });
2200
2201        mock.expect_commit().once().returning(|req| {
2202            let req = req.into_inner();
2203            assert_eq!(
2204                req.transaction,
2205                Some(CommitTransaction::TransactionId(vec![77, 88, 99]))
2206            );
2207            assert_eq!(req.mutations.len(), 2);
2208            commit_response()
2209        });
2210
2211        let (db_client, _server) = setup_db_client(mock).await;
2212        let runner = TransactionRunnerBuilder::new(db_client)
2213            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2214            .build()
2215            .await
2216            .expect("Failed to build transaction runner");
2217
2218        let res = runner
2219            .run(async |tx| {
2220                let m1 = Mutation::new_insert_builder("Orders")
2221                    .set("OrderId")
2222                    .to(&1)
2223                    .build();
2224                let m2 = Mutation::new_insert_builder("Orders")
2225                    .set("OrderId")
2226                    .to(&2)
2227                    .build();
2228                tx.buffer([m1, m2])?;
2229                Ok(())
2230            })
2231            .await
2232            .expect("Transaction runner failed");
2233
2234        assert_eq!(
2235            res.commit_response
2236                .commit_timestamp
2237                .expect("Timestamp required")
2238                .seconds(),
2239            123456789
2240        );
2241    }
2242
2243    #[tokio_test_no_panics]
2244    async fn read_write_transaction_builder_sets_gax_options() -> anyhow::Result<()> {
2245        let mock = create_session_mock();
2246        let (db_client, _server) = setup_db_client(mock).await;
2247
2248        let runner = TransactionRunnerBuilder::new(db_client)
2249            .with_begin_attempt_timeout(StdDuration::from_secs(5))
2250            .with_begin_retry_policy(NeverRetry)
2251            .with_begin_backoff_policy(ExponentialBackoff::default())
2252            .with_commit_attempt_timeout(StdDuration::from_secs(10))
2253            .with_commit_retry_policy(NeverRetry)
2254            .with_commit_backoff_policy(ExponentialBackoff::default());
2255
2256        let begin_gax = runner
2257            .begin_gax_options
2258            .as_ref()
2259            .expect("begin_gax_options missing");
2260        assert_eq!(
2261            *begin_gax.attempt_timeout(),
2262            Some(StdDuration::from_secs(5))
2263        );
2264        assert!(begin_gax.retry_policy().is_some());
2265        assert!(begin_gax.backoff_policy().is_some());
2266
2267        let commit_gax = runner
2268            .commit_gax_options
2269            .as_ref()
2270            .expect("commit_gax_options missing");
2271        assert_eq!(
2272            *commit_gax.attempt_timeout(),
2273            Some(StdDuration::from_secs(10))
2274        );
2275        assert!(commit_gax.retry_policy().is_some());
2276        assert!(commit_gax.backoff_policy().is_some());
2277
2278        Ok(())
2279    }
2280}