Skip to main content

google_cloud_spanner/
read_write_transaction.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::Error;
16use crate::RequestOptions;
17use crate::batch::BatchDml;
18use crate::client::amend_request_options_for_lar;
19use crate::database_client::DatabaseClient;
20use crate::error::internal_error;
21use crate::model::CommitRequest;
22use crate::model::CommitResponse;
23use crate::model::ExecuteBatchDmlRequest;
24use crate::model::ExecuteBatchDmlResponse;
25use crate::model::Mutation as ProtoMutation;
26use crate::model::ResultSet as ProtoResultSet;
27use crate::model::RollbackRequest;
28use crate::model::TransactionOptions;
29use crate::model::TransactionSelector;
30use crate::model::execute_batch_dml_request::Statement as ExecuteBatchDmlStatement;
31use crate::model::request_options::Priority;
32use crate::model::result_set_stats::RowCount;
33use crate::model::transaction_options::IsolationLevel;
34use crate::model::transaction_options::Mode;
35use crate::model::transaction_options::ReadWrite;
36use crate::model::transaction_options::read_write::ReadLockMode;
37use crate::model::transaction_selector::Selector;
38use crate::mutation::Mutation;
39use crate::precommit::PrecommitTokenTracker;
40use crate::read_only_transaction::{
41    BeginTransactionOption, ReadContext, ReadContextTransactionSelector, TransactionState,
42};
43use crate::result_set::ResultSet;
44use crate::retry_policy::SpannerRetryPolicy;
45use crate::statement::Statement;
46use crate::transaction_retry_policy::is_aborted;
47use crate::write_only_transaction::create_commit_request;
48use google_cloud_gax::error::Error as GaxError;
49use google_cloud_gax::error::rpc::{Code, Status};
50use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
51use google_cloud_gax::retry_policy::RetryPolicy;
52use google_cloud_gax::retry_result::RetryResult;
53use google_cloud_gax::retry_state::RetryState;
54use google_cloud_gax::throttle_result::ThrottleResult;
55use std::cmp::min;
56use std::mem::take;
57use std::sync::Arc;
58use std::sync::Mutex;
59use std::sync::atomic::{AtomicI64, Ordering};
60use std::time::Duration as StdDuration;
61use tokio::time::Instant;
62use wkt::Duration;
63
64/// A builder for [ReadWriteTransaction].
65#[derive(Clone, Debug)]
66pub(crate) struct ReadWriteTransactionBuilder {
67    pub(crate) client: DatabaseClient,
68    options: TransactionOptions,
69    transaction_tag: Option<String>,
70    max_commit_delay: Option<Duration>,
71    pub(crate) session_name: String,
72    return_commit_stats: bool,
73    commit_priority: Priority,
74    begin_transaction_option: BeginTransactionOption,
75    begin_gax_options: Option<crate::RequestOptions>,
76    commit_gax_options: Option<crate::RequestOptions>,
77}
78
79impl ReadWriteTransactionBuilder {
80    pub(crate) fn new(client: DatabaseClient) -> Self {
81        let session_name = client.session_name();
82        Self {
83            client,
84            options: TransactionOptions::default().set_read_write(ReadWrite::default()),
85            transaction_tag: None,
86            max_commit_delay: None,
87            session_name,
88            return_commit_stats: false,
89            commit_priority: Priority::Unspecified,
90            begin_transaction_option: BeginTransactionOption::InlineBegin,
91            begin_gax_options: None,
92            commit_gax_options: None,
93        }
94    }
95
96    pub(crate) fn set_isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
97        self.options = self.options.set_isolation_level(isolation_level);
98        self
99    }
100
101    pub(crate) fn set_read_lock_mode(mut self, read_lock_mode: ReadLockMode) -> Self {
102        if let Some(Mode::ReadWrite(rw)) = self.options.mode.take() {
103            self.options = self
104                .options
105                .set_read_write(rw.set_read_lock_mode(read_lock_mode));
106        }
107        self
108    }
109
110    pub(crate) fn set_previous_transaction_id(mut self, id: Option<bytes::Bytes>) -> Self {
111        if let Some(id) = id
112            && let Some(Mode::ReadWrite(rw)) = self.options.mode.take()
113        {
114            self.options = self
115                .options
116                .set_read_write(rw.set_multiplexed_session_previous_transaction_id(id));
117        }
118        self
119    }
120
121    pub(crate) fn set_transaction_tag(mut self, tag: impl Into<String>) -> Self {
122        self.transaction_tag = Some(tag.into());
123        self
124    }
125
126    pub(crate) fn set_commit_priority(mut self, priority: Priority) -> Self {
127        self.commit_priority = priority;
128        self
129    }
130
131    pub(crate) fn set_max_commit_delay(mut self, delay: Duration) -> Self {
132        self.max_commit_delay = Some(delay);
133        self
134    }
135
136    pub(crate) fn set_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
137        self.options = self.options.set_exclude_txn_from_change_streams(exclude);
138        self
139    }
140
141    pub(crate) fn set_return_commit_stats(mut self, return_stats: bool) -> Self {
142        self.return_commit_stats = return_stats;
143        self
144    }
145
146    pub fn with_begin_transaction_option(mut self, option: BeginTransactionOption) -> Self {
147        self.begin_transaction_option = option;
148        self
149    }
150
151    pub(crate) fn with_begin_transaction_request_options(
152        mut self,
153        options: Option<crate::RequestOptions>,
154    ) -> Self {
155        self.begin_gax_options = options;
156        self
157    }
158
159    pub(crate) fn with_commit_request_options(
160        mut self,
161        options: Option<crate::RequestOptions>,
162    ) -> Self {
163        self.commit_gax_options = options;
164        self
165    }
166
167    async fn begin(
168        &self,
169        session_name: String,
170        channel_hint: usize,
171        request_options: crate::RequestOptions,
172    ) -> crate::Result<ReadContextTransactionSelector> {
173        let response = crate::read_only_transaction::execute_begin_transaction(
174            &self.client,
175            session_name,
176            self.options.clone(),
177            self.transaction_tag.clone(),
178            channel_hint,
179            request_options,
180            None,
181        )
182        .await?;
183
184        Ok(ReadContextTransactionSelector::Fixed(
185            TransactionSelector::default().set_id(response.id),
186            None,
187        ))
188    }
189
190    pub(crate) async fn build(
191        &self,
192        deadline: Option<Instant>,
193    ) -> crate::Result<ReadWriteTransaction> {
194        let session_name = self.session_name.clone();
195        let channel_hint = self.client.spanner.next_channel_hint();
196        let transaction_selector = match self.begin_transaction_option {
197            BeginTransactionOption::ExplicitBegin => {
198                let mut options = self.begin_gax_options.clone().unwrap_or_default();
199                amend_gax_options(
200                    self.client.leader_aware_routing_enabled,
201                    deadline,
202                    &mut options,
203                );
204
205                self.begin(session_name.clone(), channel_hint, options)
206                    .await?
207            }
208            BeginTransactionOption::InlineBegin => ReadContextTransactionSelector::Lazy(Arc::new(
209                Mutex::new(TransactionState::NotStarted(self.options.clone())),
210            )),
211        };
212
213        Ok(ReadWriteTransaction {
214            context: ReadContext {
215                session_name,
216                client: self.client.clone(),
217                transaction_selector,
218                precommit_token_tracker: PrecommitTokenTracker::new(),
219                transaction_tag: self.transaction_tag.clone(),
220                channel_hint,
221                begin_transaction_request_options: None,
222            },
223            seqno: Arc::new(AtomicI64::new(1)),
224            max_commit_delay: self.max_commit_delay,
225            return_commit_stats: self.return_commit_stats,
226            deadline,
227            commit_priority: self.commit_priority.clone(),
228            mutations: Arc::new(Mutex::new(Vec::new())),
229            begin_gax_options: self.begin_gax_options.clone(),
230            commit_gax_options: self.commit_gax_options.clone(),
231        })
232    }
233}
234
235trait CheckServiceError {
236    fn check_service_error(&self) -> Option<Error>;
237}
238
239impl CheckServiceError for ProtoResultSet {
240    fn check_service_error(&self) -> Option<Error> {
241        None
242    }
243}
244
245/// Normalizes responses from `ExecuteBatchDml`.
246/// If Spanner encounters an error during inline transaction initialization (such as a missing table),
247/// it returns an `Ok(ExecuteBatchDmlResponse)` containing the error status but with empty `result_sets`.
248/// This implementation evaluates that payload so fallback handlers can recover.
249impl CheckServiceError for ExecuteBatchDmlResponse {
250    fn check_service_error(&self) -> Option<Error> {
251        if self.result_sets.is_empty()
252            && let Some(status) = &self.status
253            && status.code != Code::Ok as i32
254        {
255            let rpc_status = Status::default()
256                .set_code(status.code)
257                .set_message(status.message.clone());
258            return Some(Error::service(rpc_status));
259        }
260        None
261    }
262}
263
264/// A scope-bound guard that manages the state of a lazy transaction start attempt.
265///
266/// If the first statement in a transaction is executed using an inline `BeginTransaction` option,
267/// the transaction selector is transitioned to the `Starting` state.
268/// If that initial statement execution fails, or if the transaction ID is not successfully returned,
269/// we must reset the starting state back to `NotStarted` and unlock any concurrent threads waiting
270/// for this transaction to start.
271///
272/// This struct implements the RAII pattern:
273/// - It is initialized with `active = true` when the statement is starting the transaction.
274/// - If the transaction successfully starts and yields a valid ID, the guard is `disarm()`ed.
275/// - If the scope exits early due to an error (e.g., aborted error, protocol error, etc.), the guard
276///   is dropped, and its `Drop` implementation automatically calls `maybe_reset_starting()` to
277///   restore the selector state and notify waiters.
278struct LazyTransactionStartGuard {
279    selector: ReadContextTransactionSelector,
280    active: bool,
281}
282
283impl LazyTransactionStartGuard {
284    fn new(selector: ReadContextTransactionSelector, active: bool) -> Self {
285        Self { selector, active }
286    }
287
288    fn disarm(&mut self) {
289        self.active = false;
290    }
291}
292
293impl Drop for LazyTransactionStartGuard {
294    fn drop(&mut self) {
295        if self.active {
296            self.selector.maybe_reset_starting();
297        }
298    }
299}
300
/// Helper macro to execute a DML or BatchDML RPC with retry logic if the
/// request included a BeginTransaction option.
///
/// Arguments:
/// - `$self`: the `ReadWriteTransaction` issuing the request.
/// - `$request`: a mutable request binding with a `transaction` field. It is
///   overwritten when the request has to be re-sent.
/// - `$gax_options`: the request options for the RPC.
/// - `$rpc_method`: the `spanner` client method to call (e.g. `execute_sql`).
/// - `$extract_id`: a closure returning the transaction ID from a successful response.
///
/// When the request's selector is `Begin` (an inline start):
/// - On success, the returned transaction ID is recorded in the transaction selector.
/// - On an `ABORTED` error, the RPC result is passed through unchanged.
/// - On any other error, the transaction is begun explicitly and the request is
///   re-sent once.
///
/// When the request does not start the transaction, the RPC result is passed through
/// unchanged. "Passed through" means an RPC error is propagated with `?`, while an
/// error embedded in an `Ok` response (see `CheckServiceError`) leaves that response
/// for the caller to interpret.
///
/// The macro evaluates to the response and propagates errors with `?` from the
/// enclosing function.
macro_rules! execute_with_retry {
    ($self:expr, $request:ident, $gax_options:expr, $rpc_method:ident, $extract_id:expr) => {{
        // The request starts the transaction inline iff its selector is `Begin`.
        let is_starting = matches!(
            $request
                .transaction
                .as_ref()
                .and_then(|t| t.selector.as_ref()),
            Some(Selector::Begin(_))
        );

        // While armed, dropping this guard calls `maybe_reset_starting()`, so every
        // early `?` exit below undoes the inline-begin attempt.
        let mut guard =
            LazyTransactionStartGuard::new($self.context.transaction_selector.clone(), is_starting);

        let response_result = $self
            .context
            .client
            .spanner
            .$rpc_method(
                $request.clone(),
                $gax_options.clone(),
                $self.context.channel_hint,
            )
            .await;

        // An `Ok` response may still carry an error in its payload (batch DML); treat
        // an RPC error and an embedded error uniformly when deciding what to do.
        let service_error = response_result
            .as_ref()
            .ok()
            .and_then(|res| res.check_service_error());
        let err_ref = response_result.as_ref().err().or(service_error.as_ref());

        let response = match err_ref {
            None => {
                let response = response_result?;
                if is_starting {
                    // Record the transaction ID that came back with this statement.
                    let id = $extract_id(&response).ok_or_else(|| {
                        internal_error("Transaction ID was not returned by Spanner")
                    })?;
                    $self.context.transaction_selector.update(id, None)?;
                    guard.disarm();
                }
                response
            }
            Some(error) => {
                if !is_starting {
                    // Not starting the transaction: pass the result through unchanged.
                    response_result?
                } else if is_aborted(error) {
                    // Aborted: pass through unchanged; the still-armed guard resets
                    // the starting state when it is dropped.
                    response_result?
                } else {
                    // Inline begin failed for a non-aborted reason: begin explicitly,
                    // then re-send the request once with the resulting selector.
                    // NOTE(review): `true` is `is_stream_fallback`; presumably this lets
                    // the explicit begin take over the `Starting` state — confirm in
                    // `ReadContext::begin_explicitly_if_not_started`.
                    $self.begin_explicitly_if_not_started(true, None).await?;

                    $request.transaction =
                        Some($self.context.transaction_selector.selector().await?);

                    let res = $self
                        .context
                        .client
                        .spanner
                        .$rpc_method($request.clone(), $gax_options, $self.context.channel_hint)
                        .await?;

                    guard.disarm();
                    res
                }
            }
        };

        response
    }};
}
372
373/// A read-write transaction.
374#[derive(Clone, Debug)]
375pub struct ReadWriteTransaction {
376    pub(crate) context: ReadContext,
377    pub(crate) deadline: Option<Instant>,
378    seqno: Arc<AtomicI64>,
379    max_commit_delay: Option<Duration>,
380    return_commit_stats: bool,
381    commit_priority: Priority,
382    mutations: Arc<Mutex<Vec<ProtoMutation>>>,
383    begin_gax_options: Option<crate::RequestOptions>,
384    commit_gax_options: Option<crate::RequestOptions>,
385}
386
387impl ReadWriteTransaction {
388    /// Buffers one or more mutations to be applied when the transaction commits.
389    ///
390    /// # Example
391    /// ```
392    /// # use google_cloud_spanner::client::Spanner;
393    /// # use google_cloud_spanner::mutation::Mutation;
394    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
395    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
396    /// let runner = db_client.read_write_transaction().build().await?;
397    /// runner.run(async |tx| {
398    ///     let mutation = Mutation::new_insert_builder("users")
399    ///         .set("id").to(&1)
400    ///         .build();
401    ///     tx.buffer([mutation])?;
402    ///     Ok(())
403    /// }).await?;
404    /// # Ok(())
405    /// # }
406    /// ```
407    pub fn buffer<I>(&self, mutations: I) -> crate::Result<()>
408    where
409        I: IntoIterator<Item = Mutation>,
410    {
411        let mut guard = self
412            .mutations
413            .lock()
414            .map_err(|_| crate::error::internal_error("mutations mutex poisoned"))?;
415        for mutation in mutations {
416            guard.push(mutation.build_proto());
417        }
418        Ok(())
419    }
420
421    /// Executes a query using this transaction.
422    pub async fn execute_query<T: Into<Statement>>(
423        &self,
424        statement: T,
425    ) -> crate::Result<ResultSet> {
426        let stmt = statement.into();
427        let mut gax_options = stmt.gax_options().clone();
428        self.amend_gax_options(&mut gax_options);
429        let stmt = stmt.with_gax_options(gax_options);
430        self.context.execute_query(stmt).await
431    }
432
433    /// Reads rows from the database using key lookups and scans, as a simple key/value style alternative to execute_query.
434    pub async fn execute_read<T: Into<crate::read::ReadRequest>>(
435        &self,
436        read: T,
437    ) -> crate::Result<ResultSet> {
438        let mut req = read.into();
439        self.amend_gax_options(&mut req.gax_options);
440        self.context.execute_read(req).await
441    }
442
443    /// Executes an update using this transaction.
444    pub async fn execute_update<T: Into<Statement>>(&self, statement: T) -> crate::Result<i64> {
445        let statement = statement.into();
446        let mut gax_options = statement.gax_options().clone();
447        self.amend_gax_options(&mut gax_options);
448        let seqno = self.seqno.fetch_add(1, Ordering::SeqCst);
449        let mut request = statement
450            .into_request()
451            .set_session(self.context.session_name.clone())
452            .set_transaction(self.context.transaction_selector.selector().await?)
453            .set_seqno(seqno);
454        request.request_options = self.context.amend_request_options(request.request_options);
455
456        let response = execute_with_retry!(
457            self,
458            request,
459            gax_options,
460            execute_sql,
461            |response: &crate::model::ResultSet| {
462                response
463                    .metadata
464                    .as_ref()
465                    .and_then(|md| md.transaction.as_ref())
466                    .map(|t| t.id.clone())
467            }
468        );
469
470        self.context
471            .precommit_token_tracker
472            .update(response.precommit_token);
473
474        let stats = response
475            .stats
476            .ok_or_else(|| internal_error("No stats returned"))?;
477        match stats.row_count {
478            Some(RowCount::RowCountExact(c)) => Ok(c),
479            _ => Err(internal_error(
480                "ExecuteSql returned an invalid or missing row count type for a read/write transaction",
481            )),
482        }
483    }
484
485    /// Executes a batch of DML statements using this transaction.
486    ///
487    /// # Example
488    /// ```
489    /// # use google_cloud_spanner::client::Spanner;
490    /// # use google_cloud_spanner::statement::Statement;
491    /// # use google_cloud_spanner::batch::BatchDml;
492    /// # async fn build(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
493    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
494    /// let runner = db_client.read_write_transaction().build().await?;
495    /// let result = runner.run(async |transaction| {
496    ///     let statement1 = Statement::builder("UPDATE users SET active = true WHERE id = @id")
497    ///         .add_param("id", &1)
498    ///         .build();
499    ///     let statement2 = Statement::builder("UPDATE users SET active = true WHERE id = @id")
500    ///         .add_param("id", &2)
501    ///         .build();
502    ///     let batch = BatchDml::builder()
503    ///         .add_statement(statement1)
504    ///         .add_statement(statement2)
505    ///         .build();
506    ///     let update_counts = transaction.execute_batch_update(batch).await?;
507    ///     Ok(())
508    /// }).await?;
509    /// # Ok(())
510    /// # }
511    /// ```
512    ///
513    /// If a `BatchDml` request fails halfway through execution, `execute_batch_update` will return a
514    /// `BatchUpdateError` indicating exactly which statements succeeded (and their respective update counts)
515    /// before the batch execution failed.
516    ///
517    /// # Error Handling Example
518    /// ```
519    /// # use google_cloud_spanner::client::Spanner;
520    /// # use google_cloud_spanner::statement::Statement;
521    /// # use google_cloud_spanner::batch::BatchDml;
522    /// # use google_cloud_spanner::error::BatchUpdateError;
523    /// # async fn build(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
524    /// # let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
525    /// # let runner = db_client.read_write_transaction().build().await?;
526    /// # let result = runner.run(async |transaction| {
527    /// let statement1 = Statement::builder("UPDATE users SET active = true WHERE id = 1").build();
528    /// let statement2 = Statement::builder("UPDATE non_existent_table SET active = true WHERE id = 2").build();
529    ///
530    /// let batch = BatchDml::builder()
531    ///     .add_statement(statement1)
532    ///     .add_statement(statement2)
533    ///     .build();
534    ///
535    /// match transaction.execute_batch_update(batch).await {
536    ///     Ok(update_counts) => {
537    ///         println!("All statements succeeded. Update counts: {:?}", update_counts);
538    ///     }
539    ///     Err(e) => {
540    ///         if let Some(batch_error) = BatchUpdateError::extract(&e) {
541    ///             println!("Batch execution failed. Successful update counts: {:?}", batch_error.update_counts);
542    ///         } else {
543    ///             println!("RPC failed or internal error occurred: {}", e);
544    ///         }
545    ///     }
546    /// }
547    /// # Ok(())
548    /// # }).await?;
549    /// # Ok(())
550    /// # }
551    /// ```
552    pub async fn execute_batch_update<T: Into<BatchDml>>(
553        &self,
554        batch: T,
555    ) -> crate::Result<Vec<i64>> {
556        let mut batch = batch.into();
557        self.amend_gax_options(&mut batch.gax_options);
558        let seqno = self.seqno.fetch_add(1, Ordering::SeqCst);
559
560        let statements: Vec<ExecuteBatchDmlStatement> = batch
561            .statements
562            .into_iter()
563            .map(|stmt: crate::statement::Statement| stmt.into_batch_statement())
564            .collect();
565
566        let mut request = ExecuteBatchDmlRequest::default()
567            .set_session(self.context.session_name.clone())
568            .set_transaction(self.context.transaction_selector.selector().await?)
569            .set_seqno(seqno)
570            .set_statements(statements)
571            .set_or_clear_request_options(
572                self.context.amend_request_options(batch.request_options),
573            );
574
575        let response = execute_with_retry!(
576            self,
577            request,
578            batch.gax_options,
579            execute_batch_dml,
580            |response: &crate::model::ExecuteBatchDmlResponse| {
581                response
582                    .result_sets
583                    .first()
584                    .and_then(|rs| rs.metadata.as_ref())
585                    .and_then(|md| md.transaction.as_ref())
586                    .map(|t| t.id.clone())
587            }
588        );
589
590        self.context
591            .precommit_token_tracker
592            .update(response.precommit_token.clone());
593        crate::batch_dml::process_response(response)
594    }
595
596    pub(crate) async fn begin_explicitly_if_not_started(
597        &self,
598        is_stream_fallback: bool,
599        mutation_key: Option<crate::model::Mutation>,
600    ) -> crate::Result<bool> {
601        let mut begin_options = self.begin_gax_options.clone().unwrap_or_default();
602        self.amend_gax_options(&mut begin_options);
603        self.context
604            .begin_explicitly_if_not_started(begin_options, is_stream_fallback, mutation_key)
605            .await
606    }
607
608    pub(crate) fn is_starting(&self) -> crate::Result<bool> {
609        self.context.transaction_selector.is_starting()
610    }
611
612    fn commit_request_options(&self) -> Option<crate::model::RequestOptions> {
613        let mut options = self.context.amend_request_options(None);
614        if self.commit_priority != Priority::Unspecified {
615            options
616                .get_or_insert_with(crate::model::RequestOptions::default)
617                .priority = self.commit_priority.clone();
618        }
619        options
620    }
621
622    fn build_commit_request(
623        &self,
624        transaction_id: bytes::Bytes,
625        mutations: Vec<ProtoMutation>,
626        precommit_token: Option<crate::model::MultiplexedSessionPrecommitToken>,
627    ) -> CommitRequest {
628        create_commit_request(
629            self.context.session_name.clone(),
630            transaction_id,
631            mutations,
632            precommit_token,
633            self.commit_request_options(),
634            self.max_commit_delay,
635            self.return_commit_stats,
636        )
637    }
638
639    /// Commits the transaction.
640    pub(crate) async fn commit(self) -> crate::Result<CommitResponse> {
641        let mutations = take(&mut *self.mutations.lock().unwrap());
642        let mut id = self.context.transaction_selector.get_id_no_wait()?;
643        if id.is_none() {
644            if self.is_starting()? {
645                return Err(crate::error::internal_error(
646                    "Commit called while an asynchronous statement is still starting the transaction",
647                ));
648            }
649            let mutation_key = Mutation::select_mutation_key(&mutations);
650            if self
651                .begin_explicitly_if_not_started(false, mutation_key)
652                .await?
653            {
654                id = self.context.transaction_selector.get_id_no_wait()?;
655            }
656        }
657        let transaction_id = id.ok_or_else(|| internal_error("Transaction ID is missing"))?;
658        let precommit_token = self.context.precommit_token_tracker.get();
659
660        let request = self.build_commit_request(transaction_id.clone(), mutations, precommit_token);
661
662        let mut gax_options = self.commit_gax_options.clone().unwrap_or_default();
663        self.amend_gax_options(&mut gax_options);
664
665        let response = self
666            .context
667            .client
668            .spanner
669            .commit(request, gax_options, self.context.channel_hint)
670            .await?;
671
672        let response =
673            if let Some(new_precommit_token) = response.precommit_token().map(|b| (*b).clone()) {
674                let retry_commit_req = self.build_commit_request(
675                    transaction_id,
676                    Vec::new(), // mutations are never re-sent in retry requests
677                    Some(*new_precommit_token),
678                );
679
680                let mut gax_options = self.commit_gax_options.clone().unwrap_or_default();
681                self.amend_gax_options(&mut gax_options);
682
683                self.context
684                    .client
685                    .spanner
686                    .commit(retry_commit_req, gax_options, self.context.channel_hint)
687                    .await?
688            } else {
689                response
690            };
691
692        Ok(response)
693    }
694
695    /// Rolls back the transaction.
696    pub(crate) async fn rollback(self) -> crate::Result<()> {
697        let Some(transaction_id) = self.context.transaction_selector.get_id_no_wait()? else {
698            return Ok(());
699        };
700
701        let request = RollbackRequest::default()
702            .set_session(self.context.session_name.clone())
703            .set_transaction_id(transaction_id);
704
705        let mut gax_options = RequestOptions::default();
706        self.amend_gax_options(&mut gax_options);
707
708        self.context
709            .client
710            .spanner
711            .rollback(request, gax_options, self.context.channel_hint)
712            .await?;
713
714        Ok(())
715    }
716
717    fn amend_gax_options(&self, options: &mut GaxRequestOptions) {
718        amend_gax_options(
719            self.context.client.leader_aware_routing_enabled,
720            self.deadline,
721            options,
722        );
723    }
724}
725
726pub(crate) fn amend_gax_options(
727    leader_aware_routing_enabled: bool,
728    deadline: Option<Instant>,
729    options: &mut GaxRequestOptions,
730) {
731    if let Some(deadline) = deadline {
732        let remaining = deadline.saturating_duration_since(Instant::now());
733        let attempt_timeout = match options.attempt_timeout() {
734            Some(custom_timeout) => std::cmp::min(*custom_timeout, remaining),
735            None => remaining,
736        };
737        options.set_attempt_timeout(attempt_timeout);
738
739        let inner_policy = options
740            .retry_policy()
741            .clone()
742            .unwrap_or_else(|| Arc::new(SpannerRetryPolicy::new()));
743        let bounded_policy = TransactionBoundedRetryPolicy {
744            inner: inner_policy,
745            deadline,
746        };
747        options.set_retry_policy(bounded_policy);
748    }
749    *options = amend_request_options_for_lar(leader_aware_routing_enabled, take(options));
750}
751
752/// A retry policy that wraps another policy and bounds the total execution time
753/// by a specific transaction deadline.
754///
755/// This policy delegates `on_error` to the inner policy but overrides `remaining_time`
756/// to ensure that it never exceeds the time left until the transaction deadline.
757#[derive(Debug)]
758struct TransactionBoundedRetryPolicy {
759    inner: Arc<dyn RetryPolicy>,
760    deadline: Instant,
761}
762
763impl RetryPolicy for TransactionBoundedRetryPolicy {
764    fn on_error(&self, state: &RetryState, error: GaxError) -> RetryResult {
765        self.inner.on_error(state, error)
766    }
767
768    fn on_throttle(&self, state: &RetryState, error: GaxError) -> ThrottleResult {
769        self.inner.on_throttle(state, error)
770    }
771
772    fn remaining_time(&self, state: &RetryState) -> Option<StdDuration> {
773        let remaining = self.deadline.saturating_duration_since(Instant::now());
774        let attempt_timeout = self
775            .inner
776            .remaining_time(state)
777            .map(|inner| min(remaining, inner))
778            .unwrap_or(remaining);
779        Some(attempt_timeout)
780    }
781}
782
783#[cfg(test)]
784mod tests {
785    use super::*;
786    use crate::error::BatchUpdateError;
787    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
788    use crate::result_set::tests::adapt;
789    use crate::transaction_retry_policy::BasicTransactionRetryPolicy;
790    use gaxi::grpc::tonic;
791    use gaxi::grpc::tonic::MetadataMap;
792    use google_cloud_gax::options::internal::RequestOptionsExt as _;
793    use google_cloud_gax::retry_policy::NeverRetry;
794    use google_cloud_gax::retry_result::RetryResult;
795    use google_cloud_gax::retry_state::RetryState;
796    use google_cloud_test_macros::tokio_test_no_panics;
797    use http::HeaderMap;
798    use prost_types::Timestamp;
799    use spanner_grpc_mock::google::spanner::v1;
800    use std::fmt::Debug;
801    use std::sync::Mutex;
802    use std::time::Duration as StdDuration;
803
804    #[test]
805    fn auto_traits() {
806        static_assertions::assert_impl_all!(ReadWriteTransactionBuilder: Send, Sync, Clone, Debug);
807        static_assertions::assert_impl_all!(ReadWriteTransaction: Send, Sync, Debug);
808    }
809
810    #[tokio_test_no_panics]
811    async fn read_write_transaction_commit_retry_explicit() -> anyhow::Result<()> {
812        run_read_write_transaction_commit_retry(BeginTransactionOption::ExplicitBegin).await
813    }
814
815    #[tokio_test_no_panics]
816    async fn read_write_transaction_commit_retry_inline() -> anyhow::Result<()> {
817        run_read_write_transaction_commit_retry(BeginTransactionOption::InlineBegin).await
818    }
819
820    async fn run_read_write_transaction_commit_retry(
821        begin_transaction_option: BeginTransactionOption,
822    ) -> anyhow::Result<()> {
823        let mut mock = create_session_mock();
824        let remotes = Arc::new(Mutex::new(Vec::new()));
825
826        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
827            let remotes_clone = remotes.clone();
828            mock.expect_begin_transaction()
829                .once()
830                .returning(move |req| {
831                    remotes_clone
832                        .lock()
833                        .unwrap()
834                        .push(req.remote_addr().expect("remote_addr should be available"));
835                    let req = req.into_inner();
836                    assert_eq!(
837                        req.session,
838                        "projects/p/instances/i/databases/d/sessions/123"
839                    );
840                    Ok(tonic::Response::new(v1::Transaction {
841                        id: vec![0, 0, 7],
842                        ..Default::default()
843                    }))
844                });
845        }
846
847        // execute_update returns a precommit token.
848        let remotes_clone = remotes.clone();
849        mock.expect_execute_sql().once().returning(move |req| {
850            remotes_clone
851                .lock()
852                .unwrap()
853                .push(req.remote_addr().expect("remote_addr should be available"));
854            let req = req.into_inner();
855            assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1");
856
857            if begin_transaction_option == BeginTransactionOption::InlineBegin {
858                let transaction = req
859                    .transaction
860                    .as_ref()
861                    .expect("transaction options required for inline begin");
862                let selector = transaction.selector.as_ref().expect("selector required");
863                assert!(matches!(
864                    selector,
865                    v1::transaction_selector::Selector::Begin(_)
866                ));
867            }
868
869            let mut metadata = v1::ResultSetMetadata {
870                row_type: Some(v1::StructType { fields: vec![] }),
871                ..Default::default()
872            };
873            if begin_transaction_option == BeginTransactionOption::InlineBegin {
874                metadata.transaction = Some(v1::Transaction {
875                    id: vec![0, 0, 7],
876                    ..Default::default()
877                });
878            }
879
880            Ok(tonic::Response::new(v1::ResultSet {
881                metadata: Some(metadata),
882                stats: Some(v1::ResultSetStats {
883                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
884                    ..Default::default()
885                }),
886                precommit_token: Some(v1::MultiplexedSessionPrecommitToken {
887                    precommit_token: vec![101],
888                    seq_num: 1,
889                }),
890                ..Default::default()
891            }))
892        });
893
894        // Simulate that commit returns a precommit token in the response.
895        // This would normally not happen, but we test it here to verify
896        // that the commit is retried.
897        let remotes_clone = remotes.clone();
898        mock.expect_commit().once().returning(move |req| {
899            remotes_clone
900                .lock()
901                .unwrap()
902                .push(req.remote_addr().expect("remote_addr should be available"));
903            let req = req.into_inner();
904            assert_eq!(
905                req.precommit_token,
906                Some(v1::MultiplexedSessionPrecommitToken {
907                    precommit_token: vec![101],
908                    seq_num: 1,
909                })
910            );
911            Ok(tonic::Response::new(v1::CommitResponse {
912                commit_timestamp: Some(prost_types::Timestamp {
913                    seconds: 1000,
914                    nanos: 0,
915                }),
916                multiplexed_session_retry: Some(
917                    v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
918                        v1::MultiplexedSessionPrecommitToken {
919                            precommit_token: vec![202],
920                            seq_num: 2,
921                        },
922                    ),
923                ),
924                ..Default::default()
925            }))
926        });
927
928        // Second commit retry is automatically issued with the new token
929        let remotes_clone = remotes.clone();
930        mock.expect_commit().once().returning(move |req| {
931            remotes_clone
932                .lock()
933                .unwrap()
934                .push(req.remote_addr().expect("remote_addr should be available"));
935            let req = req.into_inner();
936            assert_eq!(
937                req.precommit_token,
938                Some(v1::MultiplexedSessionPrecommitToken {
939                    precommit_token: vec![202],
940                    seq_num: 2,
941                })
942            );
943            assert!(
944                req.mutations.is_empty(),
945                "Expected mutations to be empty in retried CommitRequest"
946            );
947            Ok(tonic::Response::new(v1::CommitResponse {
948                commit_timestamp: Some(prost_types::Timestamp {
949                    seconds: 1001,
950                    nanos: 0,
951                }),
952                ..Default::default()
953            }))
954        });
955
956        let (db_client, _server) = setup_db_client(mock).await;
957
958        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
959            .with_begin_transaction_option(begin_transaction_option)
960            .build(None)
961            .await
962            .expect("Failed to build transaction");
963
964        let count = tx
965            .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1")
966            .await?;
967        assert_eq!(count, 1);
968
969        let timestamp = tx.commit().await?;
970        assert_eq!(
971            timestamp
972                .commit_timestamp
973                .as_ref()
974                .expect("timestamp should be present")
975                .seconds(),
976            1001
977        );
978
979        // Verify that all RPCs used the same channel (same remote address)
980        let remotes = remotes.lock().unwrap();
981        let expected_rpcs = if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
982            4
983        } else {
984            3
985        };
986        assert_eq!(
987            remotes.len(),
988            expected_rpcs,
989            "Expected exactly {} RPCs",
990            expected_rpcs
991        );
992        let first = remotes[0];
993        for addr in remotes.iter() {
994            assert_eq!(*addr, first, "All RPCs should use the same gRPC channel");
995        }
996
997        Ok(())
998    }
999
1000    #[tokio_test_no_panics]
1001    async fn read_write_transaction_commit_retry_preserves_options() -> anyhow::Result<()> {
1002        let mut mock = create_session_mock();
1003
1004        // execute_update returns a precommit token.
1005        mock.expect_execute_sql().once().returning(move |req| {
1006            let req = req.into_inner();
1007            assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1");
1008
1009            let mut metadata = v1::ResultSetMetadata {
1010                row_type: Some(v1::StructType { fields: vec![] }),
1011                ..Default::default()
1012            };
1013            metadata.transaction = Some(v1::Transaction {
1014                id: vec![0, 0, 7],
1015                ..Default::default()
1016            });
1017
1018            Ok(tonic::Response::new(v1::ResultSet {
1019                metadata: Some(metadata),
1020                stats: Some(v1::ResultSetStats {
1021                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1022                    ..Default::default()
1023                }),
1024                precommit_token: Some(v1::MultiplexedSessionPrecommitToken {
1025                    precommit_token: vec![101],
1026                    seq_num: 1,
1027                }),
1028                ..Default::default()
1029            }))
1030        });
1031
1032        // Simulate that commit returns a precommit token in the response.
1033        // This would normally not happen, but we test it here to verify
1034        // that the commit is retried and the options are preserved.
1035        let expected_delay = prost_types::Duration {
1036            seconds: 0,
1037            nanos: 200_000_000,
1038        };
1039
1040        let expected_delay_clone = expected_delay;
1041        mock.expect_commit().once().returning(move |req| {
1042            let req = req.into_inner();
1043            assert_eq!(
1044                req.precommit_token,
1045                Some(v1::MultiplexedSessionPrecommitToken {
1046                    precommit_token: vec![101],
1047                    seq_num: 1,
1048                })
1049            );
1050            // Assert original options are present
1051            assert!(
1052                req.return_commit_stats,
1053                "Expected return_commit_stats to be true in first commit"
1054            );
1055            assert_eq!(
1056                req.max_commit_delay.as_ref(),
1057                Some(&expected_delay_clone),
1058                "Expected max_commit_delay to be set in first commit"
1059            );
1060
1061            Ok(tonic::Response::new(v1::CommitResponse {
1062                commit_timestamp: Some(prost_types::Timestamp {
1063                    seconds: 1000,
1064                    nanos: 0,
1065                }),
1066                multiplexed_session_retry: Some(
1067                    v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
1068                        v1::MultiplexedSessionPrecommitToken {
1069                            precommit_token: vec![202],
1070                            seq_num: 2,
1071                        },
1072                    ),
1073                ),
1074                ..Default::default()
1075            }))
1076        });
1077
1078        // Second commit retry is automatically issued with the new token and MUST preserve original options
1079        mock.expect_commit().once().returning(move |req| {
1080            let req = req.into_inner();
1081            assert_eq!(
1082                req.precommit_token,
1083                Some(v1::MultiplexedSessionPrecommitToken {
1084                    precommit_token: vec![202],
1085                    seq_num: 2,
1086                })
1087            );
1088            assert!(
1089                req.return_commit_stats,
1090                "Expected return_commit_stats to be preserved in retried commit request"
1091            );
1092            assert_eq!(
1093                req.max_commit_delay.as_ref(),
1094                Some(&expected_delay),
1095                "Expected max_commit_delay to be preserved in retried commit request"
1096            );
1097            assert!(
1098                req.mutations.is_empty(),
1099                "Expected mutations to be empty in retried CommitRequest"
1100            );
1101
1102            Ok(tonic::Response::new(v1::CommitResponse {
1103                commit_timestamp: Some(prost_types::Timestamp {
1104                    seconds: 1001,
1105                    nanos: 0,
1106                }),
1107                ..Default::default()
1108            }))
1109        });
1110
1111        let (db_client, _server) = setup_db_client(mock).await;
1112
1113        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1114            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
1115            .set_return_commit_stats(true)
1116            .set_max_commit_delay(Duration::new(0, 200_000_000).expect("valid duration"))
1117            .build(None)
1118            .await
1119            .expect("Failed to build transaction");
1120
1121        let count = tx
1122            .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1")
1123            .await?;
1124        assert_eq!(count, 1);
1125
1126        let timestamp = tx.commit().await?;
1127        assert_eq!(
1128            timestamp
1129                .commit_timestamp
1130                .as_ref()
1131                .expect("timestamp should be present")
1132                .seconds(),
1133            1001
1134        );
1135
1136        Ok(())
1137    }
1138
1139    #[tokio_test_no_panics]
1140    async fn read_write_transaction_commit_carries_commit_priority() -> anyhow::Result<()> {
1141        let mut mock = create_session_mock();
1142
1143        mock.expect_execute_sql().once().returning(move |_req| {
1144            let mut metadata = v1::ResultSetMetadata {
1145                row_type: Some(v1::StructType { fields: vec![] }),
1146                ..Default::default()
1147            };
1148            metadata.transaction = Some(v1::Transaction {
1149                id: vec![1, 2, 3],
1150                ..Default::default()
1151            });
1152
1153            Ok(tonic::Response::new(v1::ResultSet {
1154                metadata: Some(metadata),
1155                stats: Some(v1::ResultSetStats {
1156                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1157                    ..Default::default()
1158                }),
1159                ..Default::default()
1160            }))
1161        });
1162
1163        mock.expect_commit().once().returning(|req| {
1164            let req = req.into_inner();
1165            let request_options = req
1166                .request_options
1167                .expect("Expected request_options in CommitRequest");
1168            assert_eq!(
1169                request_options.priority,
1170                v1::request_options::Priority::Low as i32,
1171                "Expected priority to be Priority::Low in CommitRequest"
1172            );
1173
1174            Ok(tonic::Response::new(v1::CommitResponse {
1175                commit_timestamp: Some(prost_types::Timestamp {
1176                    seconds: 123456789,
1177                    nanos: 0,
1178                }),
1179                ..Default::default()
1180            }))
1181        });
1182
1183        let (db_client, _server) = setup_db_client(mock).await;
1184
1185        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1186            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
1187            .set_commit_priority(Priority::Low)
1188            .build(None)
1189            .await
1190            .expect("Failed to build transaction");
1191
1192        let count = tx
1193            .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1194            .await?;
1195        assert_eq!(count, 1);
1196
1197        let _ = tx.commit().await?;
1198
1199        Ok(())
1200    }
1201
1202    #[tokio_test_no_panics]
1203    async fn read_write_transaction_execute_update_explicit() {
1204        run_read_write_transaction_execute_update(BeginTransactionOption::ExplicitBegin).await;
1205    }
1206
1207    #[tokio_test_no_panics]
1208    async fn read_write_transaction_execute_update_inline() {
1209        run_read_write_transaction_execute_update(BeginTransactionOption::InlineBegin).await;
1210    }
1211
1212    async fn run_read_write_transaction_execute_update(
1213        begin_transaction_option: BeginTransactionOption,
1214    ) {
1215        let mut mock = create_session_mock();
1216
1217        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1218            mock.expect_begin_transaction().once().returning(|req| {
1219                let req = req.into_inner();
1220                assert_eq!(
1221                    req.session,
1222                    "projects/p/instances/i/databases/d/sessions/123"
1223                );
1224                Ok(tonic::Response::new(v1::Transaction {
1225                    id: vec![1, 2, 3],
1226                    ..Default::default()
1227                }))
1228            });
1229        }
1230
1231        mock.expect_execute_sql().once().returning(move |req| {
1232            let req = req.into_inner();
1233            assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1");
1234            assert_eq!(req.seqno, 1);
1235
1236            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1237                let transaction = req
1238                    .transaction
1239                    .as_ref()
1240                    .expect("transaction options required for inline begin");
1241                let selector = transaction.selector.as_ref().expect("selector required");
1242                assert!(matches!(
1243                    selector,
1244                    v1::transaction_selector::Selector::Begin(_)
1245                ));
1246            }
1247
1248            let mut metadata = v1::ResultSetMetadata {
1249                row_type: Some(v1::StructType { fields: vec![] }),
1250                ..Default::default()
1251            };
1252            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1253                metadata.transaction = Some(v1::Transaction {
1254                    id: vec![1, 2, 3],
1255                    ..Default::default()
1256                });
1257            }
1258
1259            Ok(tonic::Response::new(v1::ResultSet {
1260                metadata: Some(metadata),
1261                stats: Some(v1::ResultSetStats {
1262                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1263                    ..Default::default()
1264                }),
1265                ..Default::default()
1266            }))
1267        });
1268
1269        mock.expect_commit().once().returning(|req| {
1270            let req = req.into_inner();
1271            assert_eq!(
1272                req.session,
1273                "projects/p/instances/i/databases/d/sessions/123"
1274            );
1275            assert_eq!(
1276                req.transaction,
1277                Some(v1::commit_request::Transaction::TransactionId(vec![
1278                    1, 2, 3
1279                ]))
1280            );
1281            Ok(tonic::Response::new(v1::CommitResponse {
1282                commit_timestamp: Some(prost_types::Timestamp {
1283                    seconds: 123456789,
1284                    nanos: 0,
1285                }),
1286                ..Default::default()
1287            }))
1288        });
1289
1290        let (db_client, _server) = setup_db_client(mock).await;
1291
1292        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1293            .with_begin_transaction_option(begin_transaction_option)
1294            .build(None)
1295            .await
1296            .expect("Failed to build transaction");
1297        let count = tx
1298            .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1299            .await
1300            .expect("Failed to execute update");
1301        assert_eq!(count, 1);
1302
1303        let ts = tx.commit().await.expect("Failed to commit");
1304        assert_eq!(
1305            ts.commit_timestamp
1306                .expect("Commit timestamp should be present")
1307                .seconds(),
1308            123456789
1309        );
1310    }
1311
1312    #[tokio_test_no_panics]
1313    async fn read_write_transaction_execute_update_invalid_stats_explicit() -> anyhow::Result<()> {
1314        run_read_write_transaction_execute_update_invalid_stats(
1315            BeginTransactionOption::ExplicitBegin,
1316        )
1317        .await
1318    }
1319
1320    #[tokio_test_no_panics]
1321    async fn read_write_transaction_execute_update_invalid_stats_inline() -> anyhow::Result<()> {
1322        run_read_write_transaction_execute_update_invalid_stats(BeginTransactionOption::InlineBegin)
1323            .await
1324    }
1325
1326    async fn run_read_write_transaction_execute_update_invalid_stats(
1327        begin_transaction_option: BeginTransactionOption,
1328    ) -> anyhow::Result<()> {
1329        let mut mock = create_session_mock();
1330
1331        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1332            mock.expect_begin_transaction().once().returning(|_| {
1333                Ok(tonic::Response::new(v1::Transaction {
1334                    id: vec![1, 2, 3],
1335                    ..Default::default()
1336                }))
1337            });
1338        }
1339
1340        mock.expect_execute_sql().once().returning(move |req| {
1341            let req = req.into_inner();
1342            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1343                let transaction = req
1344                    .transaction
1345                    .as_ref()
1346                    .expect("transaction options required for inline begin");
1347                let selector = transaction.selector.as_ref().expect("selector required");
1348                assert!(matches!(
1349                    selector,
1350                    v1::transaction_selector::Selector::Begin(_)
1351                ));
1352            }
1353
1354            let mut metadata = v1::ResultSetMetadata {
1355                row_type: Some(v1::StructType { fields: vec![] }),
1356                ..Default::default()
1357            };
1358            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1359                metadata.transaction = Some(v1::Transaction {
1360                    id: vec![1, 2, 3],
1361                    ..Default::default()
1362                });
1363            }
1364
1365            Ok(tonic::Response::new(v1::ResultSet {
1366                metadata: Some(metadata),
1367                stats: Some(v1::ResultSetStats {
1368                    row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(1)),
1369                    ..Default::default()
1370                }),
1371                ..Default::default()
1372            }))
1373        });
1374
1375        let (db_client, _server) = setup_db_client(mock).await;
1376
1377        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1378            .with_begin_transaction_option(begin_transaction_option)
1379            .build(None)
1380            .await
1381            .expect("Failed to build transaction");
1382
1383        let result = tx
1384            .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1385            .await;
1386
1387        let err = result.expect_err("Expected an error for invalid row count stats");
1388        assert!(
1389            format!("{:?}", err).contains("invalid or missing row count type"),
1390            "Error did not contain expected message: {:?}",
1391            err
1392        );
1393        Ok(())
1394    }
1395
1396    #[tokio_test_no_panics]
1397    async fn read_write_transaction_rollback_explicit() -> anyhow::Result<()> {
1398        run_read_write_transaction_rollback(BeginTransactionOption::ExplicitBegin).await
1399    }
1400
1401    #[tokio_test_no_panics]
1402    async fn read_write_transaction_rollback_inline() -> anyhow::Result<()> {
1403        run_read_write_transaction_rollback(BeginTransactionOption::InlineBegin).await
1404    }
1405
1406    async fn run_read_write_transaction_rollback(
1407        begin_transaction_option: BeginTransactionOption,
1408    ) -> anyhow::Result<()> {
1409        let mut mock = create_session_mock();
1410        let transaction_id = vec![9, 9, 9];
1411
1412        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1413            let id = transaction_id.clone();
1414            mock.expect_begin_transaction().once().returning(move |_| {
1415                Ok(tonic::Response::new(v1::Transaction {
1416                    id: id.clone(),
1417                    ..Default::default()
1418                }))
1419            });
1420        } else {
1421            let id = transaction_id.clone();
1422            mock.expect_execute_sql().once().returning(move |req| {
1423                let req = req.into_inner();
1424                let transaction = req
1425                    .transaction
1426                    .as_ref()
1427                    .expect("transaction options required for inline begin");
1428                let selector = transaction.selector.as_ref().expect("selector required");
1429                assert!(matches!(
1430                    selector,
1431                    v1::transaction_selector::Selector::Begin(_)
1432                ));
1433
1434                Ok(tonic::Response::new(v1::ResultSet {
1435                    metadata: Some(v1::ResultSetMetadata {
1436                        transaction: Some(v1::Transaction {
1437                            id: id.clone(),
1438                            ..Default::default()
1439                        }),
1440                        ..Default::default()
1441                    }),
1442                    stats: Some(v1::ResultSetStats {
1443                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1444                        ..Default::default()
1445                    }),
1446                    ..Default::default()
1447                }))
1448            });
1449        }
1450
1451        let id = transaction_id.clone();
1452        mock.expect_rollback().once().returning(move |req| {
1453            let req = req.into_inner();
1454            assert_eq!(
1455                req.session,
1456                "projects/p/instances/i/databases/d/sessions/123"
1457            );
1458            assert_eq!(req.transaction_id, id);
1459            Ok(tonic::Response::new(()))
1460        });
1461
1462        let (db_client, _server) = setup_db_client(mock).await;
1463
1464        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1465            .with_begin_transaction_option(begin_transaction_option)
1466            .build(None)
1467            .await?;
1468
1469        if begin_transaction_option == BeginTransactionOption::InlineBegin {
1470            tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1471                .await
1472                .expect("Failed to execute update");
1473        }
1474
1475        tx.rollback().await?;
1476        Ok(())
1477    }
1478
1479    #[tokio_test_no_panics]
1480    async fn read_write_transaction_execute_batch_update_explicit() -> anyhow::Result<()> {
1481        let batch = BatchDml::builder()
1482            .add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1483            .add_statement("UPDATE Users SET Name = 'Bob' WHERE Id = 2")
1484            .build();
1485        run_read_write_transaction_execute_batch_update(
1486            BeginTransactionOption::ExplicitBegin,
1487            batch,
1488        )
1489        .await
1490    }
1491
1492    #[tokio_test_no_panics]
1493    async fn read_write_transaction_execute_batch_update_inline() -> anyhow::Result<()> {
1494        let batch = BatchDml::builder()
1495            .add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1496            .add_statement("UPDATE Users SET Name = 'Bob' WHERE Id = 2")
1497            .build();
1498        run_read_write_transaction_execute_batch_update(BeginTransactionOption::InlineBegin, batch)
1499            .await
1500    }
1501
1502    #[tokio_test_no_panics]
1503    async fn read_write_transaction_execute_batch_update_vec() -> anyhow::Result<()> {
1504        let statements = vec![
1505            "UPDATE Users SET Name = 'Alice' WHERE Id = 1",
1506            "UPDATE Users SET Name = 'Bob' WHERE Id = 2",
1507        ];
1508        run_read_write_transaction_execute_batch_update(
1509            BeginTransactionOption::InlineBegin,
1510            statements,
1511        )
1512        .await
1513    }
1514
1515    #[tokio_test_no_panics]
1516    async fn read_write_transaction_execute_batch_update_vec_statement() -> anyhow::Result<()> {
1517        let statement1 = Statement::builder("UPDATE Users SET Name = 'Alice' WHERE Id = 1").build();
1518        let statement2 = Statement::builder("UPDATE Users SET Name = 'Bob' WHERE Id = 2").build();
1519        let statements = vec![statement1, statement2];
1520        run_read_write_transaction_execute_batch_update(
1521            BeginTransactionOption::InlineBegin,
1522            statements,
1523        )
1524        .await
1525    }
1526
1527    async fn run_read_write_transaction_execute_batch_update(
1528        begin_transaction_option: BeginTransactionOption,
1529        batch: impl Into<BatchDml>,
1530    ) -> anyhow::Result<()> {
1531        let mut mock = create_session_mock();
1532
1533        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1534            mock.expect_begin_transaction().once().returning(|_| {
1535                Ok(tonic::Response::new(v1::Transaction {
1536                    id: vec![4, 5, 6],
1537                    ..Default::default()
1538                }))
1539            });
1540        }
1541
1542        mock.expect_execute_batch_dml()
1543            .once()
1544            .returning(move |req| {
1545                let req = req.into_inner();
1546                assert_eq!(req.statements.len(), 2);
1547                assert_eq!(
1548                    req.statements[0].sql,
1549                    "UPDATE Users SET Name = 'Alice' WHERE Id = 1"
1550                );
1551                assert_eq!(
1552                    req.statements[1].sql,
1553                    "UPDATE Users SET Name = 'Bob' WHERE Id = 2"
1554                );
1555
1556                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1557                    let selector = req
1558                        .transaction
1559                        .expect("missing transaction selector")
1560                        .selector
1561                        .expect("missing selector");
1562                    assert!(matches!(
1563                        selector,
1564                        v1::transaction_selector::Selector::Begin(_)
1565                    ));
1566                }
1567
1568                let mut metadata = v1::ResultSetMetadata {
1569                    ..Default::default()
1570                };
1571                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1572                    metadata.transaction = Some(v1::Transaction {
1573                        id: vec![4, 5, 6],
1574                        ..Default::default()
1575                    });
1576                }
1577
1578                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
1579                    result_sets: vec![
1580                        v1::ResultSet {
1581                            metadata: Some(metadata),
1582                            stats: Some(v1::ResultSetStats {
1583                                row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1584                                ..Default::default()
1585                            }),
1586                            ..Default::default()
1587                        },
1588                        v1::ResultSet {
1589                            stats: Some(v1::ResultSetStats {
1590                                row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1591                                ..Default::default()
1592                            }),
1593                            ..Default::default()
1594                        },
1595                    ],
1596                    status: Some(spanner_grpc_mock::google::rpc::Status {
1597                        code: 0,
1598                        message: "OK".into(),
1599                        details: vec![],
1600                    }),
1601                    ..Default::default()
1602                }))
1603            });
1604
1605        let (db_client, _server) = setup_db_client(mock).await;
1606
1607        let transaction = ReadWriteTransactionBuilder::new(db_client)
1608            .with_begin_transaction_option(begin_transaction_option)
1609            .build(None)
1610            .await?;
1611
1612        let counts = transaction.execute_batch_update(batch).await?;
1613
1614        assert_eq!(counts, vec![1, 1]);
1615        Ok(())
1616    }
1617
1618    #[tokio_test_no_panics]
1619    async fn read_write_transaction_execute_batch_update_partial_failure_explicit()
1620    -> anyhow::Result<()> {
1621        run_read_write_transaction_execute_batch_update_partial_failure(
1622            BeginTransactionOption::ExplicitBegin,
1623        )
1624        .await
1625    }
1626
1627    #[tokio_test_no_panics]
1628    async fn read_write_transaction_execute_batch_update_partial_failure_inline()
1629    -> anyhow::Result<()> {
1630        run_read_write_transaction_execute_batch_update_partial_failure(
1631            BeginTransactionOption::InlineBegin,
1632        )
1633        .await
1634    }
1635
1636    async fn run_read_write_transaction_execute_batch_update_partial_failure(
1637        begin_transaction_option: BeginTransactionOption,
1638    ) -> anyhow::Result<()> {
1639        let mut mock = create_session_mock();
1640
1641        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1642            mock.expect_begin_transaction().once().returning(|_| {
1643                Ok(tonic::Response::new(v1::Transaction {
1644                    id: vec![7, 8, 9],
1645                    ..Default::default()
1646                }))
1647            });
1648        }
1649
1650        mock.expect_execute_batch_dml()
1651            .once()
1652            .returning(move |req| {
1653                let req = req.into_inner();
1654                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1655                    let selector = req
1656                        .transaction
1657                        .expect("missing transaction selector")
1658                        .selector
1659                        .expect("missing selector");
1660                    assert!(matches!(
1661                        selector,
1662                        v1::transaction_selector::Selector::Begin(_)
1663                    ));
1664                }
1665
1666                let mut metadata = v1::ResultSetMetadata {
1667                    ..Default::default()
1668                };
1669                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1670                    metadata.transaction = Some(v1::Transaction {
1671                        id: vec![7, 8, 9],
1672                        ..Default::default()
1673                    });
1674                }
1675
1676                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
1677                    result_sets: vec![v1::ResultSet {
1678                        metadata: Some(metadata),
1679                        stats: Some(v1::ResultSetStats {
1680                            row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1681                            ..Default::default()
1682                        }),
1683                        ..Default::default()
1684                    }],
1685                    status: Some(spanner_grpc_mock::google::rpc::Status {
1686                        code: gaxi::grpc::tonic::Code::AlreadyExists as i32,
1687                        message: "row already exists".into(),
1688                        details: vec![],
1689                    }),
1690                    ..Default::default()
1691                }))
1692            });
1693
1694        let (db_client, _server) = setup_db_client(mock).await;
1695
1696        let tx = ReadWriteTransactionBuilder::new(db_client)
1697            .with_begin_transaction_option(begin_transaction_option)
1698            .build(None)
1699            .await?;
1700
1701        let batch = BatchDml::builder()
1702            .add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1703            .add_statement("INSERT INTO Users (Id) VALUES (2)");
1704
1705        let res = tx.execute_batch_update(batch.build()).await;
1706
1707        let err = res.expect_err("expected error");
1708        use std::error::Error;
1709        let batch_err = err
1710            .source()
1711            .and_then(|e| e.downcast_ref::<BatchUpdateError>())
1712            .expect("should be BatchUpdateError");
1713        assert_eq!(batch_err.update_counts, vec![1]);
1714        assert_eq!(
1715            batch_err.status.status().expect("status").code,
1716            (gaxi::grpc::tonic::Code::AlreadyExists as i32).into()
1717        );
1718        Ok(())
1719    }
1720
1721    #[tokio_test_no_panics]
1722    async fn read_write_transaction_execute_multiple_updates_explicit() -> anyhow::Result<()> {
1723        run_read_write_transaction_execute_multiple_updates(BeginTransactionOption::ExplicitBegin)
1724            .await
1725    }
1726
1727    #[tokio_test_no_panics]
1728    async fn read_write_transaction_execute_multiple_updates_inline() -> anyhow::Result<()> {
1729        run_read_write_transaction_execute_multiple_updates(BeginTransactionOption::InlineBegin)
1730            .await
1731    }
1732
1733    async fn run_read_write_transaction_execute_multiple_updates(
1734        begin_transaction_option: BeginTransactionOption,
1735    ) -> anyhow::Result<()> {
1736        let mut mock = create_session_mock();
1737
1738        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1739            mock.expect_begin_transaction().once().returning(|req| {
1740                let req = req.into_inner();
1741                assert_eq!(
1742                    req.session,
1743                    "projects/p/instances/i/databases/d/sessions/123"
1744                );
1745                Ok(tonic::Response::new(v1::Transaction {
1746                    id: vec![4, 5, 6],
1747                    ..Default::default()
1748                }))
1749            });
1750        }
1751
1752        let counter = Arc::new(AtomicI64::new(1));
1753        mock.expect_execute_sql().times(3).returning(move |req| {
1754            let req = req.into_inner();
1755            assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1");
1756            let c = counter.fetch_add(1, Ordering::SeqCst);
1757            assert_eq!(req.seqno, c);
1758
1759            let mut metadata = v1::ResultSetMetadata {
1760                ..Default::default()
1761            };
1762
1763            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1764                if c == 1 {
1765                    let selector = req
1766                        .transaction
1767                        .expect("missing transaction selector")
1768                        .selector
1769                        .expect("missing selector");
1770                    assert!(matches!(
1771                        selector,
1772                        v1::transaction_selector::Selector::Begin(_)
1773                    ));
1774                    metadata.transaction = Some(v1::Transaction {
1775                        id: vec![4, 5, 6],
1776                        ..Default::default()
1777                    });
1778                } else {
1779                    let selector = req
1780                        .transaction
1781                        .expect("missing transaction selector")
1782                        .selector
1783                        .expect("missing selector");
1784                    match selector {
1785                        v1::transaction_selector::Selector::Id(id) => {
1786                            assert_eq!(id, vec![4, 5, 6]);
1787                        }
1788                        _ => panic!("Expected Selector::Id"),
1789                    }
1790                }
1791            }
1792
1793            Ok(tonic::Response::new(v1::ResultSet {
1794                metadata: Some(metadata),
1795                stats: Some(v1::ResultSetStats {
1796                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1797                    ..Default::default()
1798                }),
1799                ..Default::default()
1800            }))
1801        });
1802
1803        let (db_client, _server) = setup_db_client(mock).await;
1804
1805        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1806            .with_begin_transaction_option(begin_transaction_option)
1807            .build(None)
1808            .await
1809            .expect("Failed to build transaction");
1810
1811        for i in 1..=3 {
1812            let count = tx
1813                .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1814                .await
1815                .unwrap_or_else(|_| panic!("Failed to execute update {}", i));
1816            assert_eq!(count, 1);
1817        }
1818        Ok(())
1819    }
1820
1821    #[tokio_test_no_panics]
1822    async fn read_write_transaction_execute_query() {
1823        use crate::statement::Statement;
1824        let mut mock = create_session_mock();
1825
1826        mock.expect_begin_transaction().once().returning(|req| {
1827            let req = req.into_inner();
1828            assert_eq!(
1829                req.session,
1830                "projects/p/instances/i/databases/d/sessions/123"
1831            );
1832            Ok(tonic::Response::new(v1::Transaction {
1833                id: vec![7, 8, 9],
1834                ..Default::default()
1835            }))
1836        });
1837
1838        mock.expect_execute_streaming_sql().once().returning(|req| {
1839            let req = req.into_inner();
1840            assert_eq!(req.sql, "SELECT 1");
1841            // Queries do not need to include a sequence number.
1842            assert_eq!(req.seqno, 0);
1843
1844            assert_eq!(
1845                req.transaction,
1846                Some(v1::TransactionSelector {
1847                    selector: Some(v1::transaction_selector::Selector::Id(vec![7, 8, 9]))
1848                })
1849            );
1850
1851            let prs = v1::PartialResultSet {
1852                metadata: Some(v1::ResultSetMetadata {
1853                    row_type: Some(v1::StructType { fields: vec![] }),
1854                    ..Default::default()
1855                }),
1856                ..Default::default()
1857            };
1858            Ok(tonic::Response::from(crate::result_set::tests::adapt([
1859                Ok(prs),
1860            ])))
1861        });
1862
1863        let (db_client, _server) = setup_db_client(mock).await;
1864
1865        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1866            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
1867            .build(None)
1868            .await
1869            .expect("Failed to build transaction");
1870
1871        let mut rs = tx
1872            .execute_query(Statement::builder("SELECT 1").build())
1873            .await
1874            .expect("Failed to execute query");
1875
1876        let result = rs.next().await;
1877        assert!(result.is_none(), "expected None, got empty stream");
1878    }
1879
1880    #[tokio_test_no_panics]
1881    async fn read_write_transaction_with_options() {
1882        let mut mock = create_session_mock();
1883
1884        mock.expect_begin_transaction().once().returning(|req| {
1885            let req = req.into_inner();
1886            assert_eq!(
1887                req.session,
1888                "projects/p/instances/i/databases/d/sessions/123"
1889            );
1890
1891            let options = req.options.expect("missing transaction options");
1892            let mode = options.mode.expect("missing mode");
1893            match mode {
1894                v1::transaction_options::Mode::ReadWrite(rw) => {
1895                    assert_eq!(
1896                        rw.read_lock_mode,
1897                        v1::transaction_options::read_write::ReadLockMode::Pessimistic as i32
1898                    );
1899                }
1900                _ => panic!("Expected ReadWrite transaction mode"),
1901            }
1902            // Ensure isolation level is passed through
1903            assert_eq!(
1904                options.isolation_level,
1905                v1::transaction_options::IsolationLevel::Serializable as i32
1906            );
1907
1908            Ok(tonic::Response::new(v1::Transaction {
1909                id: vec![9, 9, 9],
1910                ..Default::default()
1911            }))
1912        });
1913
1914        let (db_client, _server) = setup_db_client(mock).await;
1915
1916        let _tx = ReadWriteTransactionBuilder::new(db_client.clone())
1917            .set_isolation_level(IsolationLevel::Serializable)
1918            .set_read_lock_mode(ReadLockMode::Pessimistic)
1919            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
1920            .build(None)
1921            .await
1922            .expect("Failed to build transaction");
1923    }
1924
1925    #[tokio_test_no_panics]
1926    async fn read_write_transaction_with_exclude_txn_from_change_streams() {
1927        let mut mock = create_session_mock();
1928
1929        mock.expect_begin_transaction().once().returning(|req| {
1930            let req = req.into_inner();
1931            let options = req.options.expect("missing transaction options");
1932            assert!(options.exclude_txn_from_change_streams);
1933
1934            Ok(tonic::Response::new(v1::Transaction {
1935                id: vec![9, 9, 9],
1936                ..Default::default()
1937            }))
1938        });
1939
1940        let (db_client, _server) = setup_db_client(mock).await;
1941
1942        let _tx = ReadWriteTransactionBuilder::new(db_client.clone())
1943            .set_exclude_txn_from_change_streams(true)
1944            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
1945            .build(None)
1946            .await
1947            .expect("Failed to build transaction");
1948    }
1949
1950    #[tokio_test_no_panics]
1951    async fn read_write_transaction_tracks_highest_precommit_token() {
1952        let mut mock = create_session_mock();
1953
1954        mock.expect_begin_transaction().once().returning(|_| {
1955            Ok(tonic::Response::new(v1::Transaction {
1956                id: vec![4, 2],
1957                ..Default::default()
1958            }))
1959        });
1960
1961        // 3 sequential updates returning tokens [seq 2, seq 5, seq 3]
1962        let tokens_iter = vec![2, 5, 3].into_iter();
1963        let counter_mutex = Mutex::new(tokens_iter);
1964
1965        mock.expect_execute_sql().times(3).returning(move |_req| {
1966            let seq = counter_mutex
1967                .lock()
1968                .expect("Failed to lock mutex")
1969                .next()
1970                .expect("Failed to get next token");
1971            Ok(tonic::Response::new(v1::ResultSet {
1972                stats: Some(v1::ResultSetStats {
1973                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1974                    ..Default::default()
1975                }),
1976                precommit_token: Some(v1::MultiplexedSessionPrecommitToken {
1977                    precommit_token: vec![seq as u8],
1978                    seq_num: seq,
1979                }),
1980                ..Default::default()
1981            }))
1982        });
1983
1984        // Commit should only use the highest token (seq 5)
1985        mock.expect_commit().once().returning(|req| {
1986            let req = req.into_inner();
1987            assert_eq!(
1988                req.precommit_token,
1989                Some(v1::MultiplexedSessionPrecommitToken {
1990                    precommit_token: vec![5],
1991                    seq_num: 5,
1992                })
1993            );
1994            Ok(tonic::Response::new(v1::CommitResponse {
1995                commit_timestamp: Some(prost_types::Timestamp {
1996                    seconds: 12345,
1997                    nanos: 0,
1998                }),
1999                ..Default::default()
2000            }))
2001        });
2002
2003        let (db_client, _server) = setup_db_client(mock).await;
2004        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
2005            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2006            .build(None)
2007            .await
2008            .expect("Failed to build transaction");
2009
2010        for _ in 0..3 {
2011            tx.execute_update("UPDATE Y")
2012                .await
2013                .expect("Failed to execute update");
2014        }
2015        let ts = tx.commit().await.expect("Failed to commit transaction");
2016        assert_eq!(
2017            ts.commit_timestamp
2018                .expect("Commit timestamp should be present")
2019                .seconds(),
2020            12345
2021        );
2022    }
2023
2024    #[tokio_test_no_panics]
2025    async fn read_write_transaction_commit_retry_exactly_once_explicit() -> anyhow::Result<()> {
2026        run_read_write_transaction_commit_retry_exactly_once(BeginTransactionOption::ExplicitBegin)
2027            .await
2028    }
2029
2030    #[tokio_test_no_panics]
2031    async fn read_write_transaction_commit_retry_exactly_once_inline() -> anyhow::Result<()> {
2032        run_read_write_transaction_commit_retry_exactly_once(BeginTransactionOption::InlineBegin)
2033            .await
2034    }
2035
2036    async fn run_read_write_transaction_commit_retry_exactly_once(
2037        begin_transaction_option: BeginTransactionOption,
2038    ) -> anyhow::Result<()> {
2039        let mut mock = create_session_mock();
2040
2041        let transaction_id = vec![7, 7];
2042
2043        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
2044            let id = transaction_id.clone();
2045            mock.expect_begin_transaction().once().returning(move |_| {
2046                Ok(tonic::Response::new(v1::Transaction {
2047                    id: id.clone(),
2048                    ..Default::default()
2049                }))
2050            });
2051        } else {
2052            let id = transaction_id.clone();
2053            mock.expect_execute_sql().once().returning(move |req| {
2054                let req = req.into_inner();
2055                let transaction = req
2056                    .transaction
2057                    .as_ref()
2058                    .expect("transaction options required for inline begin");
2059                let selector = transaction.selector.as_ref().expect("selector required");
2060                assert!(matches!(
2061                    selector,
2062                    v1::transaction_selector::Selector::Begin(_)
2063                ));
2064
2065                Ok(tonic::Response::new(v1::ResultSet {
2066                    metadata: Some(v1::ResultSetMetadata {
2067                        transaction: Some(v1::Transaction {
2068                            id: id.clone(),
2069                            ..Default::default()
2070                        }),
2071                        ..Default::default()
2072                    }),
2073                    stats: Some(v1::ResultSetStats {
2074                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2075                        ..Default::default()
2076                    }),
2077                    ..Default::default()
2078                }))
2079            });
2080        }
2081
2082        let mut seq = mockall::Sequence::new();
2083
2084        // Initial commit returns a retry token (seq 2)
2085        mock.expect_commit()
2086            .once()
2087            .in_sequence(&mut seq)
2088            .returning(|_| {
2089                Ok(tonic::Response::new(v1::CommitResponse {
2090                    commit_timestamp: Some(prost_types::Timestamp {
2091                        seconds: 1000,
2092                        nanos: 0,
2093                    }),
2094                    multiplexed_session_retry: Some(
2095                        v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
2096                            v1::MultiplexedSessionPrecommitToken {
2097                                precommit_token: vec![2],
2098                                seq_num: 2,
2099                            },
2100                        ),
2101                    ),
2102                    ..Default::default()
2103                }))
2104            });
2105
2106        // Retry commit returns another retry token (seq 3).
2107        // The library should not retry multiple times.
2108        mock.expect_commit()
2109            .once()
2110            .in_sequence(&mut seq)
2111            .returning(|req| {
2112                let req = req.into_inner();
2113                assert_eq!(
2114                    req.precommit_token
2115                        .as_ref()
2116                        .expect("Missing precommit token in retry req")
2117                        .seq_num,
2118                    2
2119                );
2120
2121                Ok(tonic::Response::new(v1::CommitResponse {
2122                    commit_timestamp: Some(prost_types::Timestamp {
2123                        seconds: 9999,
2124                        nanos: 0,
2125                    }),
2126                    multiplexed_session_retry: Some(
2127                        v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
2128                            v1::MultiplexedSessionPrecommitToken {
2129                                precommit_token: vec![3],
2130                                seq_num: 3,
2131                            },
2132                        ),
2133                    ),
2134                    ..Default::default()
2135                }))
2136            });
2137
2138        let (db_client, _server) = setup_db_client(mock).await;
2139        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
2140            .with_begin_transaction_option(begin_transaction_option)
2141            .build(None)
2142            .await?;
2143
2144        if begin_transaction_option == BeginTransactionOption::InlineBegin {
2145            tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
2146                .await?;
2147        }
2148
2149        let ts = tx.commit().await.expect("Failed to commit transaction");
2150        assert_eq!(
2151            ts.commit_timestamp
2152                .as_ref()
2153                .expect("timestamp should be present")
2154                .seconds(),
2155            9999
2156        );
2157        Ok(())
2158    }
2159
2160    #[tokio_test_no_panics]
2161    async fn read_write_transaction_commit_with_max_commit_delay_explicit() -> anyhow::Result<()> {
2162        run_read_write_transaction_commit_with_max_commit_delay(
2163            BeginTransactionOption::ExplicitBegin,
2164        )
2165        .await
2166    }
2167
2168    #[tokio_test_no_panics]
2169    async fn read_write_transaction_commit_with_max_commit_delay_inline() -> anyhow::Result<()> {
2170        run_read_write_transaction_commit_with_max_commit_delay(BeginTransactionOption::InlineBegin)
2171            .await
2172    }
2173
2174    async fn run_read_write_transaction_commit_with_max_commit_delay(
2175        begin_transaction_option: BeginTransactionOption,
2176    ) -> anyhow::Result<()> {
2177        let mut mock = create_session_mock();
2178
2179        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
2180            mock.expect_begin_transaction().once().returning(|_| {
2181                Ok(tonic::Response::new(v1::Transaction {
2182                    id: vec![1, 2, 3],
2183                    ..Default::default()
2184                }))
2185            });
2186        } else {
2187            mock.expect_execute_sql().once().returning(|req| {
2188                let req = req.into_inner();
2189                let transaction = req
2190                    .transaction
2191                    .as_ref()
2192                    .expect("transaction options required for inline begin");
2193                let selector = transaction.selector.as_ref().expect("selector required");
2194                assert!(matches!(
2195                    selector,
2196                    v1::transaction_selector::Selector::Begin(_)
2197                ));
2198
2199                Ok(tonic::Response::new(v1::ResultSet {
2200                    metadata: Some(v1::ResultSetMetadata {
2201                        transaction: Some(v1::Transaction {
2202                            id: vec![1, 2, 3],
2203                            ..Default::default()
2204                        }),
2205                        ..Default::default()
2206                    }),
2207                    stats: Some(v1::ResultSetStats {
2208                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2209                        ..Default::default()
2210                    }),
2211                    ..Default::default()
2212                }))
2213            });
2214        }
2215
2216        mock.expect_commit().once().returning(|req| {
2217            let req = req.into_inner();
2218            assert_eq!(
2219                req.max_commit_delay,
2220                Some(::prost_types::Duration {
2221                    seconds: 0,
2222                    nanos: 200_000_000, // 200ms
2223                })
2224            );
2225            Ok(tonic::Response::new(v1::CommitResponse {
2226                commit_timestamp: Some(prost_types::Timestamp {
2227                    seconds: 123456789,
2228                    nanos: 0,
2229                }),
2230                ..Default::default()
2231            }))
2232        });
2233
2234        let (db_client, _server) = setup_db_client(mock).await;
2235
2236        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
2237            .set_max_commit_delay(Duration::new(0, 200_000_000).unwrap())
2238            .with_begin_transaction_option(begin_transaction_option)
2239            .build(None)
2240            .await
2241            .expect("Failed to build transaction");
2242
2243        if begin_transaction_option == BeginTransactionOption::InlineBegin {
2244            tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
2245                .await?;
2246        }
2247
2248        let ts = tx.commit().await.expect("Failed to commit");
2249        assert_eq!(
2250            ts.commit_timestamp
2251                .expect("Commit timestamp should be present")
2252                .seconds(),
2253            123456789
2254        );
2255        Ok(())
2256    }
2257
2258    #[tokio_test_no_panics]
2259    async fn read_write_transaction_execute_update_fallback() {
2260        let mut mock = create_session_mock();
2261
2262        // 1. First DML attempt fails!
2263        mock.expect_execute_sql().once().returning(|req| {
2264            let req = req.into_inner();
2265            assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1");
2266
2267            let selector = req
2268                .transaction
2269                .expect("missing transaction selector")
2270                .selector
2271                .expect("missing selector");
2272            match selector {
2273                v1::transaction_selector::Selector::Begin(_) => {}
2274                _ => panic!("Expected Selector::Begin"),
2275            }
2276
2277            Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
2278        });
2279
2280        // 2. Client falls back to explicit BeginTransaction!
2281        mock.expect_begin_transaction().once().returning(|_| {
2282            Ok(tonic::Response::new(v1::Transaction {
2283                id: vec![7, 8, 9],
2284                ..Default::default()
2285            }))
2286        });
2287
2288        // 3. Client retries DML with new ID!
2289        mock.expect_execute_sql().once().returning(|req| {
2290            let req = req.into_inner();
2291            assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1");
2292
2293            let selector = req
2294                .transaction
2295                .expect("missing transaction selector")
2296                .selector
2297                .expect("missing selector");
2298            match selector {
2299                v1::transaction_selector::Selector::Id(id) => {
2300                    assert_eq!(id, vec![7, 8, 9]);
2301                }
2302                _ => panic!("Expected Selector::Id"),
2303            }
2304
2305            Ok(tonic::Response::new(v1::ResultSet {
2306                stats: Some(v1::ResultSetStats {
2307                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2308                    ..Default::default()
2309                }),
2310                ..Default::default()
2311            }))
2312        });
2313
2314        let (db_client, _server) = setup_db_client(mock).await;
2315
2316        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
2317            .build(None)
2318            .await
2319            .expect("Failed to build transaction");
2320
2321        let count = tx
2322            .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
2323            .await
2324            .expect("Failed to execute update after fallback");
2325        assert_eq!(count, 1);
2326    }
2327
2328    #[tokio_test_no_panics]
2329    async fn read_write_transaction_execute_batch_update_fallback() -> anyhow::Result<()> {
2330        let mut mock = create_session_mock();
2331
2332        // 1. First Batch DML attempt fails!
2333        mock.expect_execute_batch_dml().once().returning(|req| {
2334            let req = req.into_inner();
2335            let selector = req
2336                .transaction
2337                .expect("missing transaction selector")
2338                .selector
2339                .expect("missing selector");
2340            match selector {
2341                v1::transaction_selector::Selector::Begin(_) => {}
2342                _ => panic!("Expected Selector::Begin"),
2343            }
2344
2345            Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
2346        });
2347
2348        // 2. Client falls back to explicit BeginTransaction!
2349        mock.expect_begin_transaction().once().returning(|_| {
2350            Ok(tonic::Response::new(v1::Transaction {
2351                id: vec![4, 5, 6],
2352                ..Default::default()
2353            }))
2354        });
2355
2356        // 3. Client retries Batch DML with new ID!
2357        mock.expect_execute_batch_dml().once().returning(|req| {
2358            let req = req.into_inner();
2359            let selector = req
2360                .transaction
2361                .expect("missing transaction selector")
2362                .selector
2363                .expect("missing selector");
2364            match selector {
2365                v1::transaction_selector::Selector::Id(id) => {
2366                    assert_eq!(id, vec![4, 5, 6]);
2367                }
2368                _ => panic!("Expected Selector::Id"),
2369            }
2370
2371            Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
2372                result_sets: vec![v1::ResultSet {
2373                    stats: Some(v1::ResultSetStats {
2374                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2375                        ..Default::default()
2376                    }),
2377                    ..Default::default()
2378                }],
2379                status: Some(spanner_grpc_mock::google::rpc::Status {
2380                    code: 0,
2381                    message: "OK".into(),
2382                    details: vec![],
2383                }),
2384                ..Default::default()
2385            }))
2386        });
2387
2388        let (db_client, _server) = setup_db_client(mock).await;
2389
2390        let tx = ReadWriteTransactionBuilder::new(db_client)
2391            .build(None)
2392            .await?;
2393
2394        let batch =
2395            BatchDml::builder().add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1");
2396
2397        let counts = tx.execute_batch_update(batch.build()).await?;
2398
2399        assert_eq!(counts, vec![1]);
2400
2401        Ok(())
2402    }
2403
2404    #[tokio_test_no_panics]
2405    async fn leader_aware_routing_enabled_by_default() -> anyhow::Result<()> {
2406        let mut mock = create_session_mock();
2407        mock.expect_begin_transaction().once().returning(|req| {
2408            assert_eq!(
2409                req.metadata()
2410                    .get("x-goog-spanner-route-to-leader")
2411                    .expect("header required")
2412                    .to_str()
2413                    .unwrap(),
2414                "true"
2415            );
2416            Ok(tonic::Response::new(v1::Transaction {
2417                id: vec![1, 2, 3],
2418                ..Default::default()
2419            }))
2420        });
2421        mock.expect_execute_sql().once().returning(|req| {
2422            assert_eq!(
2423                req.metadata()
2424                    .get("x-goog-spanner-route-to-leader")
2425                    .expect("header required")
2426                    .to_str()
2427                    .unwrap(),
2428                "true"
2429            );
2430            Ok(tonic::Response::new(v1::ResultSet {
2431                stats: Some(v1::ResultSetStats {
2432                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2433                    ..Default::default()
2434                }),
2435                ..Default::default()
2436            }))
2437        });
2438        mock.expect_commit().once().returning(|req| {
2439            assert_eq!(
2440                req.metadata()
2441                    .get("x-goog-spanner-route-to-leader")
2442                    .expect("header required")
2443                    .to_str()
2444                    .unwrap(),
2445                "true"
2446            );
2447            Ok(tonic::Response::new(v1::CommitResponse {
2448                commit_timestamp: Some(prost_types::Timestamp {
2449                    seconds: 1,
2450                    nanos: 0,
2451                }),
2452                ..Default::default()
2453            }))
2454        });
2455
2456        let (db_client, _server) = setup_db_client(mock).await;
2457        let tx = ReadWriteTransactionBuilder::new(db_client)
2458            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2459            .build(None)
2460            .await?;
2461        let count = tx.execute_update("UPDATE Users SET active = true").await?;
2462        assert_eq!(count, 1);
2463        let _ = tx.commit().await?;
2464        Ok(())
2465    }
2466
2467    #[tokio_test_no_panics]
2468    async fn leader_aware_routing_disabled() -> anyhow::Result<()> {
2469        use crate::client::Spanner;
2470        use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
2471
2472        let mut mock = create_session_mock();
2473        mock.expect_begin_transaction().once().returning(|req| {
2474            assert!(
2475                req.metadata()
2476                    .get("x-goog-spanner-route-to-leader")
2477                    .is_none()
2478            );
2479            Ok(tonic::Response::new(v1::Transaction {
2480                id: vec![1, 2, 3],
2481                ..Default::default()
2482            }))
2483        });
2484        mock.expect_execute_sql().once().returning(|req| {
2485            assert!(
2486                req.metadata()
2487                    .get("x-goog-spanner-route-to-leader")
2488                    .is_none()
2489            );
2490            Ok(tonic::Response::new(v1::ResultSet {
2491                stats: Some(v1::ResultSetStats {
2492                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2493                    ..Default::default()
2494                }),
2495                ..Default::default()
2496            }))
2497        });
2498        mock.expect_commit().once().returning(|req| {
2499            assert!(
2500                req.metadata()
2501                    .get("x-goog-spanner-route-to-leader")
2502                    .is_none()
2503            );
2504            Ok(tonic::Response::new(v1::CommitResponse {
2505                commit_timestamp: Some(prost_types::Timestamp {
2506                    seconds: 1,
2507                    nanos: 0,
2508                }),
2509                ..Default::default()
2510            }))
2511        });
2512
2513        let (address, _server) = spanner_grpc_mock::start("0.0.0.0:0", mock).await.unwrap();
2514        let spanner = Spanner::builder()
2515            .with_endpoint(address)
2516            .with_credentials(Anonymous::new().build())
2517            .build()
2518            .await?;
2519        let db_client = spanner
2520            .database_client("projects/p/instances/i/databases/d")
2521            .with_leader_aware_routing(false)
2522            .build()
2523            .await?;
2524
2525        let tx = ReadWriteTransactionBuilder::new(db_client)
2526            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2527            .build(None)
2528            .await?;
2529        let count = tx.execute_update("UPDATE Users SET active = true").await?;
2530        assert_eq!(count, 1);
2531        let _ = tx.commit().await?;
2532        Ok(())
2533    }
2534
2535    #[tokio_test_no_panics]
2536    async fn leader_aware_routing_query_in_read_write() -> anyhow::Result<()> {
2537        let mut mock = create_session_mock();
2538        mock.expect_begin_transaction().once().returning(|req| {
2539            assert_eq!(
2540                req.metadata()
2541                    .get("x-goog-spanner-route-to-leader")
2542                    .expect("header required")
2543                    .to_str()
2544                    .unwrap(),
2545                "true"
2546            );
2547            Ok(tonic::Response::new(v1::Transaction {
2548                id: vec![1, 2, 3],
2549                ..Default::default()
2550            }))
2551        });
2552        mock.expect_execute_streaming_sql().once().returning(|req| {
2553            assert_eq!(
2554                req.metadata()
2555                    .get("x-goog-spanner-route-to-leader")
2556                    .expect("header required")
2557                    .to_str()
2558                    .unwrap(),
2559                "true"
2560            );
2561            let stream = adapt([Ok(v1::PartialResultSet {
2562                metadata: Some(v1::ResultSetMetadata {
2563                    row_type: Some(v1::StructType { fields: vec![] }),
2564                    ..Default::default()
2565                }),
2566                ..Default::default()
2567            })]);
2568            Ok(tonic::Response::from(stream))
2569        });
2570
2571        let (db_client, _server) = setup_db_client(mock).await;
2572        let tx = ReadWriteTransactionBuilder::new(db_client)
2573            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2574            .build(None)
2575            .await?;
2576        let _rs = tx
2577            .execute_query(Statement::builder("SELECT 1").build())
2578            .await?;
2579        Ok(())
2580    }
2581
2582    #[tokio_test_no_panics]
2583    async fn leader_aware_routing_merges_custom_headers() -> anyhow::Result<()> {
2584        let mut mock = create_session_mock();
2585        mock.expect_begin_transaction().once().returning(|req| {
2586            assert_eq!(
2587                req.metadata()
2588                    .get("x-goog-spanner-route-to-leader")
2589                    .expect("header required")
2590                    .to_str()
2591                    .unwrap(),
2592                "true"
2593            );
2594            Ok(tonic::Response::new(v1::Transaction {
2595                id: vec![1, 2, 3],
2596                ..Default::default()
2597            }))
2598        });
2599        mock.expect_execute_sql().once().returning(|req| {
2600            assert_eq!(
2601                req.metadata()
2602                    .get("x-goog-spanner-route-to-leader")
2603                    .expect("header required")
2604                    .to_str()
2605                    .unwrap(),
2606                "true"
2607            );
2608            assert_eq!(
2609                req.metadata()
2610                    .get("x-custom-user-header")
2611                    .expect("custom header required")
2612                    .to_str()
2613                    .unwrap(),
2614                "custom-value"
2615            );
2616            Ok(tonic::Response::new(v1::ResultSet {
2617                stats: Some(v1::ResultSetStats {
2618                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2619                    ..Default::default()
2620                }),
2621                ..Default::default()
2622            }))
2623        });
2624
2625        let (db_client, _server) = setup_db_client(mock).await;
2626        let tx = ReadWriteTransactionBuilder::new(db_client)
2627            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2628            .build(None)
2629            .await?;
2630
2631        let mut custom_headers = http::HeaderMap::new();
2632        custom_headers.insert(
2633            "x-custom-user-header",
2634            http::HeaderValue::from_static("custom-value"),
2635        );
2636
2637        let mut stmt = Statement::builder("UPDATE Users SET active = true").build();
2638        let opts = stmt.gax_options().clone().insert_extension(custom_headers);
2639        stmt = stmt.with_gax_options(opts);
2640
2641        let count = tx.execute_update(stmt).await?;
2642        assert_eq!(count, 1);
2643        Ok(())
2644    }
2645
2646    #[tokio_test_no_panics]
2647    async fn leader_aware_routing_implicit_begin_fallback() -> anyhow::Result<()> {
2648        let mut mock = create_session_mock();
2649
2650        // 1. Initial execute_sql attempts implicit begin and transiently fails.
2651        // It must include the LAR header because it is a modifying operation.
2652        mock.expect_execute_sql().once().returning(|req| {
2653            assert_eq!(
2654                req.metadata()
2655                    .get("x-goog-spanner-route-to-leader")
2656                    .expect("header required on initial execute")
2657                    .to_str()
2658                    .unwrap(),
2659                "true"
2660            );
2661            Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
2662        });
2663
2664        // 2. Client fallback mechanism invokes begin_explicitly_if_not_started.
2665        // This should also include the LAR header.
2666        mock.expect_begin_transaction().once().returning(|req| {
2667            assert_eq!(
2668                req.metadata()
2669                    .get("x-goog-spanner-route-to-leader")
2670                    .expect("header required on explicit begin fallback")
2671                    .to_str()
2672                    .unwrap(),
2673                "true"
2674            );
2675            Ok(tonic::Response::new(v1::Transaction {
2676                id: vec![42],
2677                ..Default::default()
2678            }))
2679        });
2680
2681        // 3. Retried execute_sql with fixed ID.
2682        mock.expect_execute_sql().once().returning(|req| {
2683            assert_eq!(
2684                req.metadata()
2685                    .get("x-goog-spanner-route-to-leader")
2686                    .expect("header required on retried execute")
2687                    .to_str()
2688                    .unwrap(),
2689                "true"
2690            );
2691            Ok(tonic::Response::new(v1::ResultSet {
2692                stats: Some(v1::ResultSetStats {
2693                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2694                    ..Default::default()
2695                }),
2696                ..Default::default()
2697            }))
2698        });
2699
2700        let (db_client, _server) = setup_db_client(mock).await;
2701        // Construct transaction using implicit begin (explicit_begin = false)
2702        let tx = ReadWriteTransactionBuilder::new(db_client)
2703            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2704            .build(None)
2705            .await?;
2706
2707        let count = tx.execute_update("UPDATE Users SET active = true").await?;
2708        assert_eq!(count, 1);
2709        Ok(())
2710    }
2711
2712    #[tokio_test_no_panics]
2713    async fn read_write_transaction_fallback_forwards_transaction_tag() -> anyhow::Result<()> {
2714        let mut mock = create_session_mock();
2715
2716        // 1. Initial execute_sql attempts inline begin and transiently fails.
2717        // The initial request includes the transaction tag.
2718        mock.expect_execute_sql().once().returning(|req| {
2719            let req = req.into_inner();
2720            assert_eq!(
2721                req.request_options
2722                    .as_ref()
2723                    .expect("Missing request_options on initial RPC")
2724                    .transaction_tag,
2725                "fallback-test-tag"
2726            );
2727            Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
2728        });
2729
2730        // 2. Client fallback mechanism invokes explicit begin.
2731        // This should include the transaction tag.
2732        mock.expect_begin_transaction().once().returning(|req| {
2733            let req = req.into_inner();
2734            assert_eq!(
2735                req.request_options
2736                    .as_ref()
2737                    .expect("Missing request_options on explicit begin fallback")
2738                    .transaction_tag,
2739                "fallback-test-tag"
2740            );
2741            Ok(tonic::Response::new(v1::Transaction {
2742                id: vec![7, 7, 7],
2743                ..Default::default()
2744            }))
2745        });
2746
2747        // 3. Retried execute_sql with the explicit transaction ID.
2748        mock.expect_execute_sql().once().returning(|req| {
2749            let req = req.into_inner();
2750            assert_eq!(
2751                req.request_options
2752                    .as_ref()
2753                    .expect("Missing request_options on retried RPC")
2754                    .transaction_tag,
2755                "fallback-test-tag"
2756            );
2757            Ok(tonic::Response::new(v1::ResultSet {
2758                stats: Some(v1::ResultSetStats {
2759                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2760                    ..Default::default()
2761                }),
2762                ..Default::default()
2763            }))
2764        });
2765
2766        let (db_client, _server) = setup_db_client(mock).await;
2767        let tx = ReadWriteTransactionBuilder::new(db_client)
2768            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2769            .set_transaction_tag("fallback-test-tag")
2770            .build(None)
2771            .await?;
2772
2773        let count = tx.execute_update("UPDATE Users SET active = true").await?;
2774        assert_eq!(count, 1);
2775        Ok(())
2776    }
2777
2778    #[tokio_test_no_panics]
2779    async fn read_write_transaction_mutation_only_inline_begin_commit() -> anyhow::Result<()> {
2780        let mut mock = create_session_mock();
2781
2782        // Since no statement was executed, commit will detect NotStarted and call begin_explicitly
2783        mock.expect_begin_transaction().once().returning(|req| {
2784            let req = req.into_inner();
2785            assert_eq!(
2786                req.session,
2787                "projects/p/instances/i/databases/d/sessions/123"
2788            );
2789            assert!(
2790                req.mutation_key.is_some(),
2791                "mutation_key should be populated when starting transaction at commit time"
2792            );
2793            let key = req
2794                .mutation_key
2795                .as_ref()
2796                .expect("mutation_key is populated");
2797            assert!(
2798                key.operation.is_some(),
2799                "mutation_key should have an operation"
2800            );
2801            Ok(tonic::Response::new(v1::Transaction {
2802                id: vec![7, 7, 7],
2803                ..Default::default()
2804            }))
2805        });
2806
2807        mock.expect_commit().once().returning(|req| {
2808            let req = req.into_inner();
2809            assert_eq!(
2810                req.session,
2811                "projects/p/instances/i/databases/d/sessions/123"
2812            );
2813            assert_eq!(
2814                req.transaction,
2815                Some(v1::commit_request::Transaction::TransactionId(vec![
2816                    7, 7, 7
2817                ]))
2818            );
2819            assert_eq!(req.mutations.len(), 1);
2820            Ok(tonic::Response::new(v1::CommitResponse {
2821                commit_timestamp: Some(prost_types::Timestamp {
2822                    seconds: 5000,
2823                    nanos: 0,
2824                }),
2825                ..Default::default()
2826            }))
2827        });
2828
2829        let (db_client, _server) = setup_db_client(mock).await;
2830        let tx = ReadWriteTransactionBuilder::new(db_client)
2831            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2832            .build(None)
2833            .await
2834            .expect("Transaction build should succeed");
2835
2836        let mutation = Mutation::new_insert_builder("Users")
2837            .set("UserId")
2838            .to(&1)
2839            .build();
2840        tx.buffer([mutation]).expect("Buffer should succeed");
2841
2842        let response = tx.commit().await.expect("Commit should succeed");
2843        assert_eq!(
2844            response
2845                .commit_timestamp
2846                .expect("timestamp present")
2847                .seconds(),
2848            5000
2849        );
2850        Ok(())
2851    }
2852
2853    #[tokio_test_no_panics]
2854    async fn transaction_runner_batch_dml_aborted_retry() -> anyhow::Result<()> {
2855        let mut mock = create_session_mock();
2856        let mut sequence = mockall::Sequence::new();
2857
2858        // 1. First attempt: Inline begin, execute_batch_dml returns OK with status Aborted.
2859        mock.expect_execute_batch_dml()
2860            .times(1)
2861            .in_sequence(&mut sequence)
2862            .returning(|req| {
2863                let req = req.into_inner();
2864                assert!(matches!(
2865                    req.transaction.unwrap().selector.unwrap(),
2866                    v1::transaction_selector::Selector::Begin(_)
2867                ));
2868                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
2869                    result_sets: vec![],
2870                    status: Some(spanner_grpc_mock::google::rpc::Status {
2871                        code: tonic::Code::Aborted as i32,
2872                        message: "concurrent lock abort".into(),
2873                        details: vec![],
2874                    }),
2875                    ..Default::default()
2876                }))
2877            });
2878
2879        // 2. TransactionRunner catches Aborted error and initiates attempt 2.
2880        mock.expect_execute_batch_dml()
2881            .times(1)
2882            .in_sequence(&mut sequence)
2883            .returning(|req| {
2884                let req = req.into_inner();
2885                assert!(matches!(
2886                    req.transaction.unwrap().selector.unwrap(),
2887                    v1::transaction_selector::Selector::Begin(_)
2888                ));
2889                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
2890                    result_sets: vec![v1::ResultSet {
2891                        metadata: Some(v1::ResultSetMetadata {
2892                            transaction: Some(v1::Transaction {
2893                                id: vec![9, 9, 9],
2894                                ..Default::default()
2895                            }),
2896                            ..Default::default()
2897                        }),
2898                        stats: Some(v1::ResultSetStats {
2899                            row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2900                            ..Default::default()
2901                        }),
2902                        ..Default::default()
2903                    }],
2904                    status: Some(spanner_grpc_mock::google::rpc::Status {
2905                        code: 0,
2906                        message: "OK".into(),
2907                        details: vec![],
2908                    }),
2909                    ..Default::default()
2910                }))
2911            });
2912
2913        mock.expect_commit().once().returning(|req| {
2914            let req = req.into_inner();
2915            assert_eq!(
2916                req.transaction,
2917                Some(v1::commit_request::Transaction::TransactionId(vec![
2918                    9, 9, 9
2919                ]))
2920            );
2921            Ok(tonic::Response::new(v1::CommitResponse {
2922                commit_timestamp: Some(prost_types::Timestamp {
2923                    seconds: 999,
2924                    nanos: 0,
2925                }),
2926                ..Default::default()
2927            }))
2928        });
2929
2930        let (db_client, _server) = setup_db_client(mock).await;
2931
2932        let runner = db_client
2933            .read_write_transaction()
2934            .with_retry_policy(
2935                BasicTransactionRetryPolicy::new()
2936                    .with_max_attempts(3)
2937                    .with_total_timeout(std::time::Duration::from_secs(5)),
2938            )
2939            .build()
2940            .await?;
2941
2942        runner
2943            .run(async |tx| {
2944                let batch = BatchDml::builder()
2945                    .add_statement("UPDATE Users SET active = true WHERE id = 1");
2946                tx.execute_batch_update(batch.build()).await?;
2947                Ok(())
2948            })
2949            .await?;
2950
2951        Ok(())
2952    }
2953
2954    #[tokio_test_no_panics]
2955    async fn read_write_transaction_first_dml_aborted_and_continue_success() -> anyhow::Result<()> {
2956        let mut mock = create_session_mock();
2957        let mut sequence = mockall::Sequence::new();
2958
2959        // 1. First statement (execute_sql) attempts inline begin and is aborted by Spanner
2960        mock.expect_execute_sql()
2961            .times(1)
2962            .in_sequence(&mut sequence)
2963            .returning(|req| {
2964                let req = req.into_inner();
2965                assert!(matches!(
2966                    req.transaction.unwrap().selector.unwrap(),
2967                    v1::transaction_selector::Selector::Begin(_)
2968                ));
2969                Err(tonic::Status::new(
2970                    tonic::Code::Aborted,
2971                    "concurrent lock abort",
2972                ))
2973            });
2974
2975        // 2. Second statement (execute_sql) sees NotStarted and attempts inline begin again
2976        mock.expect_execute_sql()
2977            .times(1)
2978            .in_sequence(&mut sequence)
2979            .returning(|req| {
2980                let req = req.into_inner();
2981                assert!(matches!(
2982                    req.transaction.unwrap().selector.unwrap(),
2983                    v1::transaction_selector::Selector::Begin(_)
2984                ));
2985                Ok(tonic::Response::new(v1::ResultSet {
2986                    metadata: Some(v1::ResultSetMetadata {
2987                        transaction: Some(v1::Transaction {
2988                            id: vec![9, 9, 9],
2989                            ..Default::default()
2990                        }),
2991                        ..Default::default()
2992                    }),
2993                    stats: Some(v1::ResultSetStats {
2994                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2995                        ..Default::default()
2996                    }),
2997                    ..Default::default()
2998                }))
2999            });
3000
3001        // 3. Commit called with the transaction ID returned in step 2
3002        mock.expect_commit().once().returning(|req| {
3003            let req = req.into_inner();
3004            assert_eq!(
3005                req.transaction,
3006                Some(v1::commit_request::Transaction::TransactionId(vec![
3007                    9, 9, 9
3008                ]))
3009            );
3010            Ok(tonic::Response::new(v1::CommitResponse {
3011                commit_timestamp: Some(prost_types::Timestamp {
3012                    seconds: 999,
3013                    nanos: 0,
3014                }),
3015                ..Default::default()
3016            }))
3017        });
3018
3019        let (db_client, _server) = setup_db_client(mock).await;
3020
3021        let runner = db_client
3022            .read_write_transaction()
3023            .with_retry_policy(
3024                BasicTransactionRetryPolicy::new()
3025                    .with_max_attempts(1)
3026                    .with_total_timeout(std::time::Duration::from_secs(5)),
3027            )
3028            .build()
3029            .await?;
3030
3031        runner
3032            .run(async |tx| {
3033                // 1. First statement fails with Aborted. We catch it and continue.
3034                let res = tx
3035                    .execute_update("UPDATE Users SET active = true WHERE id = 1")
3036                    .await;
3037                assert!(res.is_err(), "First statement must return error");
3038                assert!(is_aborted(&res.unwrap_err()), "Error must be Aborted");
3039
3040                // 2. Second statement continues. Without the fix, this would block/deadlock forever.
3041                let count = tx
3042                    .execute_update("UPDATE Users SET active = true WHERE id = 2")
3043                    .await?;
3044                assert_eq!(count, 1);
3045                Ok(())
3046            })
3047            .await?;
3048
3049        Ok(())
3050    }
3051
3052    #[tokio_test_no_panics]
3053    async fn read_write_transaction_first_batch_dml_aborted_and_continue_success()
3054    -> anyhow::Result<()> {
3055        let mut mock = create_session_mock();
3056        let mut sequence = mockall::Sequence::new();
3057
3058        // 1. First statement (execute_batch_dml) attempts inline begin and is aborted by Spanner
3059        mock.expect_execute_batch_dml()
3060            .times(1)
3061            .in_sequence(&mut sequence)
3062            .returning(|req| {
3063                let req = req.into_inner();
3064                assert!(matches!(
3065                    req.transaction.unwrap().selector.unwrap(),
3066                    v1::transaction_selector::Selector::Begin(_)
3067                ));
3068                Err(tonic::Status::new(
3069                    tonic::Code::Aborted,
3070                    "concurrent lock abort",
3071                ))
3072            });
3073
3074        // 2. Second statement (execute_sql) sees NotStarted and attempts inline begin again
3075        mock.expect_execute_sql()
3076            .times(1)
3077            .in_sequence(&mut sequence)
3078            .returning(|req| {
3079                let req = req.into_inner();
3080                assert!(matches!(
3081                    req.transaction.unwrap().selector.unwrap(),
3082                    v1::transaction_selector::Selector::Begin(_)
3083                ));
3084                Ok(tonic::Response::new(v1::ResultSet {
3085                    metadata: Some(v1::ResultSetMetadata {
3086                        transaction: Some(v1::Transaction {
3087                            id: vec![9, 9, 9],
3088                            ..Default::default()
3089                        }),
3090                        ..Default::default()
3091                    }),
3092                    stats: Some(v1::ResultSetStats {
3093                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
3094                        ..Default::default()
3095                    }),
3096                    ..Default::default()
3097                }))
3098            });
3099
3100        // 3. Commit called with the transaction ID returned in step 2
3101        mock.expect_commit().once().returning(|req| {
3102            let req = req.into_inner();
3103            assert_eq!(
3104                req.transaction,
3105                Some(v1::commit_request::Transaction::TransactionId(vec![
3106                    9, 9, 9
3107                ]))
3108            );
3109            Ok(tonic::Response::new(v1::CommitResponse {
3110                commit_timestamp: Some(prost_types::Timestamp {
3111                    seconds: 999,
3112                    nanos: 0,
3113                }),
3114                ..Default::default()
3115            }))
3116        });
3117
3118        let (db_client, _server) = setup_db_client(mock).await;
3119
3120        let runner = db_client
3121            .read_write_transaction()
3122            .with_retry_policy(
3123                BasicTransactionRetryPolicy::new()
3124                    .with_max_attempts(1)
3125                    .with_total_timeout(std::time::Duration::from_secs(5)),
3126            )
3127            .build()
3128            .await?;
3129
3130        runner
3131            .run(async |tx| {
3132                // 1. First statement (Batch DML) fails with Aborted. We catch it and continue.
3133                let batch = BatchDml::builder()
3134                    .add_statement("UPDATE Users SET active = true WHERE id = 1");
3135                let res = tx.execute_batch_update(batch.build()).await;
3136                assert!(res.is_err(), "First statement must return error");
3137                assert!(is_aborted(&res.unwrap_err()), "Error must be Aborted");
3138
3139                // 2. Second statement continues. Without the fix, this would block/deadlock forever.
3140                let count = tx
3141                    .execute_update("UPDATE Users SET active = true WHERE id = 2")
3142                    .await?;
3143                assert_eq!(count, 1);
3144                Ok(())
3145            })
3146            .await?;
3147
3148        Ok(())
3149    }
3150
3151    fn parse_grpc_timeout(metadata: &MetadataMap) -> Option<StdDuration> {
3152        let timeout_header = metadata.get("grpc-timeout")?.to_str().ok()?;
3153        let numeric_part: String = timeout_header
3154            .chars()
3155            .take_while(|c| c.is_ascii_digit())
3156            .collect();
3157        let value = numeric_part.parse::<u64>().ok()?;
3158        let unit = timeout_header.trim_start_matches(&numeric_part);
3159        let duration = match unit {
3160            "u" => StdDuration::from_micros(value),
3161            "m" => StdDuration::from_millis(value),
3162            "S" => StdDuration::from_secs(value),
3163            "M" => StdDuration::from_secs(value * 60),
3164            "H" => StdDuration::from_secs(value * 3600),
3165            _ => return None,
3166        };
3167        Some(duration)
3168    }
3169
3170    #[tokio_test_no_panics]
3171    async fn read_write_transaction_lazy_begin_fallback_never_retry() -> anyhow::Result<()> {
3172        let mut mock = create_session_mock();
3173        let mut sequence = mockall::Sequence::new();
3174
3175        // 1. First statement execution uses inline-begin and fails with Unavailable (transient error)
3176        mock.expect_execute_sql()
3177            .once()
3178            .in_sequence(&mut sequence)
3179            .withf(|req| {
3180                matches!(
3181                    req.get_ref()
3182                        .transaction
3183                        .as_ref()
3184                        .and_then(|t| t.selector.as_ref()),
3185                    Some(v1::transaction_selector::Selector::Begin(_))
3186                )
3187            })
3188            .returning(move |_req| Err(tonic::Status::unavailable("transient error")));
3189
3190        // 2. Fallback explicit BeginTransaction is executed exactly once and fails (because we configure NeverRetry)
3191        mock.expect_begin_transaction()
3192            .once()
3193            .in_sequence(&mut sequence)
3194            .returning(move |_req| Err(tonic::Status::unavailable("transient error")));
3195
3196        let (db_client, _server) = setup_db_client(mock).await;
3197
3198        let runner = db_client
3199            .read_write_transaction()
3200            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
3201            .with_begin_retry_policy(NeverRetry)
3202            .build()
3203            .await?;
3204
3205        let res = runner
3206            .run(async |tx| {
3207                let mut stmt_opts = crate::RequestOptions::default();
3208                stmt_opts.set_retry_policy(NeverRetry);
3209                let stmt = Statement::builder("UPDATE Users SET active = true WHERE id = 1")
3210                    .build()
3211                    .with_gax_options(stmt_opts);
3212                let _count = tx.execute_update(stmt).await?;
3213                Ok(())
3214            })
3215            .await;
3216
3217        assert!(
3218            res.is_err(),
3219            "Should fail immediately because NeverRetry aborted retries of explicit begin"
3220        );
3221        let err = res.unwrap_err();
3222        assert_eq!(err.status().map(|s| s.code), Some(Code::Unavailable));
3223
3224        Ok(())
3225    }
3226
3227    #[tokio_test_no_panics]
3228    async fn read_write_transaction_commit_under_deadline_delegates_to_custom_retry_policy()
3229    -> anyhow::Result<()> {
3230        let mut mock = create_session_mock();
3231        mock.expect_begin_transaction().once().returning(|_| {
3232            Ok(tonic::Response::new(v1::Transaction {
3233                id: vec![8, 8, 8],
3234                ..Default::default()
3235            }))
3236        });
3237
3238        // Commit fails with Unavailable. Since we use NeverRetry, it must fail immediately without retry.
3239        mock.expect_commit()
3240            .once()
3241            .returning(|_| Err(tonic::Status::unavailable("transient error")));
3242
3243        let (db_client, _server) = setup_db_client(mock).await;
3244
3245        let runner = db_client
3246            .read_write_transaction()
3247            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
3248            .with_commit_retry_policy(NeverRetry)
3249            .with_transaction_timeout(StdDuration::from_secs(5))
3250            .build()
3251            .await?;
3252
3253        let res = runner.run(async |_tx| Ok(())).await;
3254
3255        assert!(
3256            res.is_err(),
3257            "Should fail because NeverRetry aborted retries"
3258        );
3259        let err = res.unwrap_err();
3260        assert_eq!(
3261            err.status().map(|s| s.code),
3262            Some(Code::Unavailable),
3263            "Error code should be Unavailable"
3264        );
3265        Ok(())
3266    }
3267
3268    #[tokio_test_no_panics]
3269    async fn read_write_transaction_commit_timeout_combination() -> anyhow::Result<()> {
3270        let mut mock = create_session_mock();
3271        mock.expect_begin_transaction().once().returning(|_| {
3272            Ok(tonic::Response::new(v1::Transaction {
3273                id: vec![8, 8, 8],
3274                ..Default::default()
3275            }))
3276        });
3277
3278        // Assert that the commit attempt timeout of 2 seconds propagates as the gRPC timeout header metadata (approx 2000m/2000000u).
3279        mock.expect_commit()
3280            .once()
3281            .withf(|req| {
3282                let duration =
3283                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
3284                assert_eq!(
3285                    duration,
3286                    StdDuration::from_secs(2),
3287                    "Timeout duration should be exactly 2 seconds"
3288                );
3289                true
3290            })
3291            .returning(|_| {
3292                Ok(tonic::Response::new(v1::CommitResponse {
3293                    commit_timestamp: Some(Timestamp {
3294                        seconds: 999,
3295                        nanos: 0,
3296                    }),
3297                    ..Default::default()
3298                }))
3299            });
3300
3301        let (db_client, _server) = setup_db_client(mock).await;
3302
3303        let runner = db_client
3304            .read_write_transaction()
3305            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
3306            .with_commit_attempt_timeout(StdDuration::from_secs(2))
3307            .with_transaction_timeout(StdDuration::from_secs(10))
3308            .build()
3309            .await?;
3310
3311        let res = runner.run(async |_tx| Ok(())).await?;
3312
3313        assert!(res.commit_response.commit_timestamp.is_some());
3314        Ok(())
3315    }
3316
3317    #[tokio_test_no_panics]
3318    async fn read_write_transaction_fallback_begin_under_deadline() -> anyhow::Result<()> {
3319        let mut mock = create_session_mock();
3320        let mut sequence = mockall::Sequence::new();
3321
3322        // 1. First statement execution fails with Unavailable (transient error)
3323        mock.expect_execute_sql()
3324            .once()
3325            .in_sequence(&mut sequence)
3326            .withf(|req| {
3327                matches!(
3328                    req.get_ref()
3329                        .transaction
3330                        .as_ref()
3331                        .and_then(|t| t.selector.as_ref()),
3332                    Some(v1::transaction_selector::Selector::Begin(_))
3333                )
3334            })
3335            .returning(move |_req| Err(tonic::Status::unavailable("transient error")));
3336
3337        // 2. Fallback explicit BeginTransaction is executed and sets attempt timeout based on remaining transaction deadline (approx 5 seconds).
3338        mock.expect_begin_transaction()
3339            .once()
3340            .in_sequence(&mut sequence)
3341            .withf(|req| {
3342                let duration =
3343                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
3344                assert!(
3345                    duration >= StdDuration::from_millis(4000)
3346                        && duration <= StdDuration::from_millis(6000),
3347                    "Fallback begin timeout is wrong: {:?}",
3348                    duration
3349                );
3350                true
3351            })
3352            .returning(move |_req| {
3353                Ok(tonic::Response::new(v1::Transaction {
3354                    id: vec![42],
3355                    ..Default::default()
3356                }))
3357            });
3358
3359        // 3. Statement retry succeeds
3360        mock.expect_execute_sql()
3361            .once()
3362            .in_sequence(&mut sequence)
3363            .withf(|req| {
3364                matches!(
3365                    req.get_ref()
3366                        .transaction
3367                        .as_ref()
3368                        .and_then(|t| t.selector.as_ref()),
3369                    Some(v1::transaction_selector::Selector::Id(_))
3370                )
3371            })
3372            .returning(move |_req| {
3373                Ok(tonic::Response::new(v1::ResultSet {
3374                    metadata: Some(v1::ResultSetMetadata {
3375                        transaction: Some(v1::Transaction {
3376                            id: vec![42],
3377                            ..Default::default()
3378                        }),
3379                        ..Default::default()
3380                    }),
3381                    stats: Some(v1::ResultSetStats {
3382                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
3383                        ..Default::default()
3384                    }),
3385                    ..Default::default()
3386                }))
3387            });
3388
3389        // 4. Commit succeeds
3390        mock.expect_commit()
3391            .once()
3392            .in_sequence(&mut sequence)
3393            .returning(move |_req| {
3394                Ok(tonic::Response::new(v1::CommitResponse {
3395                    commit_timestamp: Some(Timestamp {
3396                        seconds: 1234,
3397                        nanos: 0,
3398                    }),
3399                    ..Default::default()
3400                }))
3401            });
3402
3403        let (db_client, _server) = setup_db_client(mock).await;
3404
3405        let runner = db_client
3406            .read_write_transaction()
3407            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
3408            .with_transaction_timeout(StdDuration::from_secs(5))
3409            .build()
3410            .await?;
3411
3412        let res = runner
3413            .run(async |tx| {
3414                let mut query_opts = crate::RequestOptions::default();
3415                query_opts.set_retry_policy(NeverRetry);
3416                let stmt = Statement::builder("UPDATE Users SET active = true WHERE id = 1")
3417                    .build()
3418                    .with_gax_options(query_opts);
3419                let count = tx.execute_update(stmt).await?;
3420                assert_eq!(count, 1);
3421                Ok(())
3422            })
3423            .await?;
3424
3425        assert!(res.commit_response.commit_timestamp.is_some());
3426        Ok(())
3427    }
3428
3429    #[tokio_test_no_panics]
3430    async fn read_write_transaction_commit_fallback_begin_under_deadline() -> anyhow::Result<()> {
3431        let mut mock = create_session_mock();
3432        let mut sequence = mockall::Sequence::new();
3433
3434        // 1. Transaction was never started (empty runner block), so commit falls back to explicit BeginTransaction.
3435        // Assert fallback explicit BeginTransaction sets timeout based on remaining transaction deadline (approx 5 seconds).
3436        mock.expect_begin_transaction()
3437            .once()
3438            .in_sequence(&mut sequence)
3439            .withf(|req| {
3440                let duration =
3441                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
3442                assert!(
3443                    duration >= StdDuration::from_millis(4000)
3444                        && duration <= StdDuration::from_millis(6000),
3445                    "Fallback begin timeout inside commit is wrong: {:?}",
3446                    duration
3447                );
3448                true
3449            })
3450            .returning(move |_req| {
3451                Ok(tonic::Response::new(v1::Transaction {
3452                    id: vec![42],
3453                    ..Default::default()
3454                }))
3455            });
3456
3457        // 2. Commit succeeds
3458        mock.expect_commit()
3459            .once()
3460            .in_sequence(&mut sequence)
3461            .returning(move |_req| {
3462                Ok(tonic::Response::new(v1::CommitResponse {
3463                    commit_timestamp: Some(Timestamp {
3464                        seconds: 5678,
3465                        nanos: 0,
3466                    }),
3467                    ..Default::default()
3468                }))
3469            });
3470
3471        let (db_client, _server) = setup_db_client(mock).await;
3472
3473        let runner = db_client
3474            .read_write_transaction()
3475            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
3476            .with_transaction_timeout(StdDuration::from_secs(5))
3477            .build()
3478            .await?;
3479
3480        let res = runner.run(async |_tx| Ok(())).await?;
3481
3482        assert!(res.commit_response.commit_timestamp.is_some());
3483        Ok(())
3484    }
3485
3486    #[test]
3487    fn test_amend_gax_options() {
3488        // Case 1: No Deadline, LAR disabled
3489        let mut options = RequestOptions::default();
3490        options.set_attempt_timeout(StdDuration::from_secs(4));
3491        amend_gax_options(false, None, &mut options);
3492        assert_eq!(*options.attempt_timeout(), Some(StdDuration::from_secs(4)));
3493        assert!(options.retry_policy().is_none());
3494
3495        // Case 2: No Deadline, LAR enabled
3496        let mut options = RequestOptions::default();
3497        amend_gax_options(true, None, &mut options);
3498        // Verify LAR extension is added
3499        let headers = options
3500            .get_extension::<HeaderMap>()
3501            .expect("HeaderMap extension missing");
3502        assert_eq!(
3503            headers
3504                .get("x-goog-spanner-route-to-leader")
3505                .unwrap()
3506                .to_str()
3507                .unwrap(),
3508            "true"
3509        );
3510
3511        // Case 3: Deadline present, no custom timeout.
3512        // Since Instant::now() is called inside amend_gax_options slightly after the test's
3513        // Instant::now() call, the remaining time will be slightly less than 5 seconds.
3514        // Therefore, we assert that it falls within a very close range.
3515        let mut options = RequestOptions::default();
3516        let deadline = Instant::now() + StdDuration::from_secs(5);
3517        amend_gax_options(false, Some(deadline), &mut options);
3518        let timeout = options.attempt_timeout().expect("attempt timeout missing");
3519        assert!(
3520            timeout >= StdDuration::from_millis(4500) && timeout <= StdDuration::from_millis(5500)
3521        );
3522        assert!(
3523            options.retry_policy().is_some(),
3524            "retry policy should be wrapped"
3525        );
3526
3527        // Case 4: Deadline present, custom timeout shorter than deadline.
3528        // Since custom timeout is 2s and remaining deadline is 10s, it does not depend
3529        // on Time/Instant and must be exactly 2s.
3530        let mut options = RequestOptions::default();
3531        options.set_attempt_timeout(StdDuration::from_secs(2));
3532        let deadline = Instant::now() + StdDuration::from_secs(10);
3533        amend_gax_options(false, Some(deadline), &mut options);
3534        assert_eq!(*options.attempt_timeout(), Some(StdDuration::from_secs(2)));
3535
3536        // Case 5: Deadline present, custom timeout longer than deadline.
3537        // The remaining deadline (approx 2 seconds) is shorter than custom timeout (10s).
3538        // Due to slight time passing, remaining will be slightly less than 2 seconds.
3539        let mut options = RequestOptions::default();
3540        options.set_attempt_timeout(StdDuration::from_secs(10));
3541        let deadline = Instant::now() + StdDuration::from_secs(2);
3542        amend_gax_options(false, Some(deadline), &mut options);
3543        let timeout = options.attempt_timeout().expect("attempt timeout missing");
3544        assert!(
3545            timeout >= StdDuration::from_millis(1500) && timeout <= StdDuration::from_millis(2500)
3546        );
3547    }
3548
3549    #[test]
3550    fn test_transaction_bounded_retry_policy_throttle_delegation() {
3551        #[derive(Debug)]
3552        struct ThrottleTestPolicy;
3553        impl RetryPolicy for ThrottleTestPolicy {
3554            fn on_error(&self, _state: &RetryState, error: GaxError) -> RetryResult {
3555                RetryResult::Continue(error)
3556            }
3557            fn on_throttle(&self, _state: &RetryState, error: GaxError) -> ThrottleResult {
3558                ThrottleResult::Exhausted(error)
3559            }
3560        }
3561
3562        let inner = Arc::new(ThrottleTestPolicy);
3563        let deadline = Instant::now() + StdDuration::from_secs(10);
3564        let bounded = TransactionBoundedRetryPolicy { inner, deadline };
3565
3566        let state = RetryState::new(true);
3567        let status = Status::default()
3568            .set_code(Code::Unavailable)
3569            .set_message("error");
3570        let error = GaxError::service(status);
3571
3572        let res = bounded.on_throttle(&state, error);
3573        assert!(matches!(res, ThrottleResult::Exhausted(_)));
3574    }
3575
3576    #[test]
3577    fn test_transaction_bounded_retry_policy_remaining_time_capping() {
3578        #[derive(Debug)]
3579        struct RemainingTimeTestPolicy {
3580            timeout: Option<StdDuration>,
3581        }
3582        impl RetryPolicy for RemainingTimeTestPolicy {
3583            fn on_error(&self, _state: &RetryState, error: GaxError) -> RetryResult {
3584                RetryResult::Continue(error)
3585            }
3586            fn remaining_time(&self, _state: &RetryState) -> Option<StdDuration> {
3587                self.timeout
3588            }
3589        }
3590
3591        let state = RetryState::new(true);
3592
3593        // Case A: Inner policy timeout (3s) is shorter than remaining transaction deadline (approx 10s)
3594        let inner = Arc::new(RemainingTimeTestPolicy {
3595            timeout: Some(StdDuration::from_secs(3)),
3596        });
3597        let deadline = Instant::now() + StdDuration::from_secs(10);
3598        let bounded = TransactionBoundedRetryPolicy { inner, deadline };
3599        let remaining = bounded.remaining_time(&state).expect("remaining time");
3600        assert!(
3601            remaining >= StdDuration::from_millis(2500)
3602                && remaining <= StdDuration::from_millis(3500)
3603        );
3604
3605        // Case B: Transaction deadline (approx 2s) is shorter than inner policy timeout (10s)
3606        let inner = Arc::new(RemainingTimeTestPolicy {
3607            timeout: Some(StdDuration::from_secs(10)),
3608        });
3609        let deadline = Instant::now() + StdDuration::from_secs(2);
3610        let bounded = TransactionBoundedRetryPolicy { inner, deadline };
3611        let remaining = bounded.remaining_time(&state).expect("remaining time");
3612        assert!(
3613            remaining >= StdDuration::from_millis(1500)
3614                && remaining <= StdDuration::from_millis(2500)
3615        );
3616
3617        // Case C: Inner policy timeout is None (returns transaction remaining approx 10s)
3618        let inner = Arc::new(RemainingTimeTestPolicy { timeout: None });
3619        let deadline = Instant::now() + StdDuration::from_secs(10);
3620        let bounded = TransactionBoundedRetryPolicy { inner, deadline };
3621        let remaining = bounded.remaining_time(&state).expect("remaining time");
3622        assert!(
3623            remaining >= StdDuration::from_millis(9500)
3624                && remaining <= StdDuration::from_millis(10500)
3625        );
3626    }
3627}