Skip to main content

google_cloud_spanner/
read_write_transaction.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::Error;
16use crate::RequestOptions;
17use crate::batch::BatchDml;
18use crate::client::amend_request_options_for_lar;
19use crate::database_client::DatabaseClient;
20use crate::error::internal_error;
21use crate::model::CommitRequest;
22use crate::model::CommitResponse;
23use crate::model::ExecuteBatchDmlRequest;
24use crate::model::ExecuteBatchDmlResponse;
25use crate::model::Mutation as ProtoMutation;
26use crate::model::ResultSet as ProtoResultSet;
27use crate::model::RollbackRequest;
28use crate::model::TransactionOptions;
29use crate::model::TransactionSelector;
30use crate::model::execute_batch_dml_request::Statement as ExecuteBatchDmlStatement;
31use crate::model::request_options::Priority;
32use crate::model::result_set_stats::RowCount;
33use crate::model::transaction_options::IsolationLevel;
34use crate::model::transaction_options::Mode;
35use crate::model::transaction_options::ReadWrite;
36use crate::model::transaction_options::read_write::ReadLockMode;
37use crate::model::transaction_selector::Selector;
38use crate::mutation::Mutation;
39use crate::precommit::PrecommitTokenTracker;
40use crate::read_only_transaction::{
41    BeginTransactionOption, ReadContext, ReadContextTransactionSelector, TransactionState,
42};
43use crate::result_set::ResultSet;
44use crate::statement::Statement;
45use crate::transaction_retry_policy::is_aborted;
46use crate::write_only_transaction::create_commit_request;
47use google_cloud_gax::error::Error as GaxError;
48use google_cloud_gax::error::rpc::{Code, Status};
49use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
50use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicy};
51use google_cloud_gax::retry_result::RetryResult;
52use google_cloud_gax::retry_state::RetryState;
53use google_cloud_gax::throttle_result::ThrottleResult;
54use std::cmp::min;
55use std::mem::take;
56use std::sync::Arc;
57use std::sync::Mutex;
58use std::sync::atomic::{AtomicI64, Ordering};
59use std::time::Duration as StdDuration;
60use tokio::time::Instant;
61use wkt::Duration;
62
/// A builder for [ReadWriteTransaction].
///
/// Collects the transaction options, commit settings and per-RPC request
/// options, and produces a [ReadWriteTransaction] through `build`.
#[derive(Clone, Debug)]
pub(crate) struct ReadWriteTransactionBuilder {
    /// Client used to issue RPCs; it is cloned into every transaction built.
    client: DatabaseClient,
    /// Transaction options; `new` initialises these with a read-write mode.
    options: TransactionOptions,
    /// Optional tag forwarded when beginning the transaction and copied into the read context.
    transaction_tag: Option<String>,
    /// Optional maximum commit delay, copied into the built transaction.
    max_commit_delay: Option<Duration>,
    /// Session name, taken from the client in `new`.
    pub(crate) session_name: String,
    /// Whether the built transaction asks for commit statistics.
    return_commit_stats: bool,
    /// Priority applied to the commit request options unless it is `Priority::Unspecified`.
    commit_priority: Priority,
    /// Whether `build` begins the transaction with an explicit RPC or leaves it to start lazily.
    begin_transaction_option: BeginTransactionOption,
    /// Request options used as the base for explicit begin-transaction calls.
    begin_gax_options: Option<crate::RequestOptions>,
    /// Request options used as the base for commit calls.
    commit_gax_options: Option<crate::RequestOptions>,
}
77
78impl ReadWriteTransactionBuilder {
79    pub(crate) fn new(client: DatabaseClient) -> Self {
80        let session_name = client.session_name();
81        Self {
82            client,
83            options: TransactionOptions::default().set_read_write(ReadWrite::default()),
84            transaction_tag: None,
85            max_commit_delay: None,
86            session_name,
87            return_commit_stats: false,
88            commit_priority: Priority::Unspecified,
89            begin_transaction_option: BeginTransactionOption::InlineBegin,
90            begin_gax_options: None,
91            commit_gax_options: None,
92        }
93    }
94
95    pub(crate) fn set_isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
96        self.options = self.options.set_isolation_level(isolation_level);
97        self
98    }
99
100    pub(crate) fn set_read_lock_mode(mut self, read_lock_mode: ReadLockMode) -> Self {
101        if let Some(Mode::ReadWrite(rw)) = self.options.mode.take() {
102            self.options = self
103                .options
104                .set_read_write(rw.set_read_lock_mode(read_lock_mode));
105        }
106        self
107    }
108
109    pub(crate) fn set_previous_transaction_id(mut self, id: Option<bytes::Bytes>) -> Self {
110        if let Some(id) = id
111            && let Some(Mode::ReadWrite(rw)) = self.options.mode.take()
112        {
113            self.options = self
114                .options
115                .set_read_write(rw.set_multiplexed_session_previous_transaction_id(id));
116        }
117        self
118    }
119
120    pub(crate) fn set_transaction_tag(mut self, tag: impl Into<String>) -> Self {
121        self.transaction_tag = Some(tag.into());
122        self
123    }
124
125    pub(crate) fn set_commit_priority(mut self, priority: Priority) -> Self {
126        self.commit_priority = priority;
127        self
128    }
129
130    pub(crate) fn set_max_commit_delay(mut self, delay: Duration) -> Self {
131        self.max_commit_delay = Some(delay);
132        self
133    }
134
135    pub(crate) fn set_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
136        self.options = self.options.set_exclude_txn_from_change_streams(exclude);
137        self
138    }
139
140    pub(crate) fn set_return_commit_stats(mut self, return_stats: bool) -> Self {
141        self.return_commit_stats = return_stats;
142        self
143    }
144
145    pub fn with_begin_transaction_option(mut self, option: BeginTransactionOption) -> Self {
146        self.begin_transaction_option = option;
147        self
148    }
149
150    pub(crate) fn with_begin_transaction_request_options(
151        mut self,
152        options: Option<crate::RequestOptions>,
153    ) -> Self {
154        self.begin_gax_options = options;
155        self
156    }
157
158    pub(crate) fn with_commit_request_options(
159        mut self,
160        options: Option<crate::RequestOptions>,
161    ) -> Self {
162        self.commit_gax_options = options;
163        self
164    }
165
166    async fn begin(
167        &self,
168        session_name: String,
169        channel_hint: usize,
170        request_options: crate::RequestOptions,
171    ) -> crate::Result<ReadContextTransactionSelector> {
172        let response = crate::read_only_transaction::execute_begin_transaction(
173            &self.client,
174            session_name,
175            self.options.clone(),
176            self.transaction_tag.clone(),
177            channel_hint,
178            request_options,
179            None,
180        )
181        .await?;
182
183        Ok(ReadContextTransactionSelector::Fixed(
184            TransactionSelector::default().set_id(response.id),
185            None,
186        ))
187    }
188
189    pub(crate) async fn build(
190        &self,
191        deadline: Option<Instant>,
192    ) -> crate::Result<ReadWriteTransaction> {
193        let session_name = self.session_name.clone();
194        let channel_hint = self.client.spanner.next_channel_hint();
195        let transaction_selector = match self.begin_transaction_option {
196            BeginTransactionOption::ExplicitBegin => {
197                let mut options = self.begin_gax_options.clone().unwrap_or_default();
198                amend_gax_options(
199                    self.client.leader_aware_routing_enabled,
200                    deadline,
201                    &mut options,
202                );
203
204                self.begin(session_name.clone(), channel_hint, options)
205                    .await?
206            }
207            BeginTransactionOption::InlineBegin => ReadContextTransactionSelector::Lazy(Arc::new(
208                Mutex::new(TransactionState::NotStarted(self.options.clone())),
209            )),
210        };
211
212        Ok(ReadWriteTransaction {
213            context: ReadContext {
214                session_name,
215                client: self.client.clone(),
216                transaction_selector,
217                precommit_token_tracker: PrecommitTokenTracker::new(),
218                transaction_tag: self.transaction_tag.clone(),
219                channel_hint,
220                begin_transaction_request_options: None,
221            },
222            seqno: Arc::new(AtomicI64::new(1)),
223            max_commit_delay: self.max_commit_delay,
224            return_commit_stats: self.return_commit_stats,
225            deadline,
226            commit_priority: self.commit_priority.clone(),
227            mutations: Arc::new(Mutex::new(Vec::new())),
228            begin_gax_options: self.begin_gax_options.clone(),
229            commit_gax_options: self.commit_gax_options.clone(),
230        })
231    }
232}
233
234trait CheckServiceError {
235    fn check_service_error(&self) -> Option<Error>;
236}
237
238impl CheckServiceError for ProtoResultSet {
239    fn check_service_error(&self) -> Option<Error> {
240        None
241    }
242}
243
244/// Normalizes responses from `ExecuteBatchDml`.
245/// If Spanner encounters an error during inline transaction initialization (such as a missing table),
246/// it returns an `Ok(ExecuteBatchDmlResponse)` containing the error status but with empty `result_sets`.
247/// This implementation evaluates that payload so fallback handlers can recover.
248impl CheckServiceError for ExecuteBatchDmlResponse {
249    fn check_service_error(&self) -> Option<Error> {
250        if self.result_sets.is_empty()
251            && let Some(status) = &self.status
252            && status.code != Code::Ok as i32
253        {
254            let rpc_status = Status::default()
255                .set_code(status.code)
256                .set_message(status.message.clone());
257            return Some(Error::service(rpc_status));
258        }
259        None
260    }
261}
262
263/// A scope-bound guard that manages the state of a lazy transaction start attempt.
264///
265/// If the first statement in a transaction is executed using an inline `BeginTransaction` option,
266/// the transaction selector is transitioned to the `Starting` state.
267/// If that initial statement execution fails, or if the transaction ID is not successfully returned,
268/// we must reset the starting state back to `NotStarted` and unlock any concurrent threads waiting
269/// for this transaction to start.
270///
271/// This struct implements the RAII pattern:
272/// - It is initialized with `active = true` when the statement is starting the transaction.
273/// - If the transaction successfully starts and yields a valid ID, the guard is `disarm()`ed.
274/// - If the scope exits early due to an error (e.g., aborted error, protocol error, etc.), the guard
275///   is dropped, and its `Drop` implementation automatically calls `maybe_reset_starting()` to
276///   restore the selector state and notify waiters.
277struct LazyTransactionStartGuard {
278    selector: ReadContextTransactionSelector,
279    active: bool,
280}
281
282impl LazyTransactionStartGuard {
283    fn new(selector: ReadContextTransactionSelector, active: bool) -> Self {
284        Self { selector, active }
285    }
286
287    fn disarm(&mut self) {
288        self.active = false;
289    }
290}
291
292impl Drop for LazyTransactionStartGuard {
293    fn drop(&mut self) {
294        if self.active {
295            self.selector.maybe_reset_starting();
296        }
297    }
298}
299
/// Helper macro to execute a DML or BatchDML RPC with retry logic if the
/// request included a BeginTransaction option.
///
/// Arguments:
/// - `$self`: the transaction; its `context` supplies the client, channel hint
///   and transaction selector.
/// - `$request`: identifier of a mutable request variable; its `transaction`
///   field is overwritten before a retry.
/// - `$gax_options`: request options, cloned for the first attempt and consumed
///   by the retry.
/// - `$rpc_method`: name of the method to call on `context.client.spanner`.
/// - `$extract_id`: callable that takes `&response` and returns the transaction
///   ID as an `Option`.
///
/// Evaluates to the RPC response. Errors are propagated with `?`, so the macro
/// can only be used inside a function that returns a compatible `Result`.
macro_rules! execute_with_retry {
    ($self:expr, $request:ident, $gax_options:expr, $rpc_method:ident, $extract_id:expr) => {{
        // True when this request asks to begin the transaction inline.
        let is_starting = matches!(
            $request
                .transaction
                .as_ref()
                .and_then(|t| t.selector.as_ref()),
            Some(Selector::Begin(_))
        );

        // Armed only when this request is the one starting the transaction. Unless
        // disarmed, dropping it calls `maybe_reset_starting()` on the selector.
        let mut guard =
            LazyTransactionStartGuard::new($self.context.transaction_selector.clone(), is_starting);

        // First attempt.
        let response_result = $self
            .context
            .client
            .spanner
            .$rpc_method(
                $request.clone(),
                $gax_options.clone(),
                $self.context.channel_hint,
            )
            .await;

        // An `Ok` response may still carry an error in its payload.
        let service_error = response_result
            .as_ref()
            .ok()
            .and_then(|res| res.check_service_error());
        // The RPC error takes precedence over an error embedded in the response.
        let err_ref = response_result.as_ref().err().or(service_error.as_ref());

        let response = match err_ref {
            None => {
                let response = response_result?;
                if is_starting {
                    // The response must carry the ID of the transaction that was begun inline.
                    let id = $extract_id(&response).ok_or_else(|| {
                        internal_error("Transaction ID was not returned by Spanner")
                    })?;
                    $self.context.transaction_selector.update(id, None)?;
                    guard.disarm();
                }
                response
            }
            Some(error) => {
                // For a request that was not starting the transaction, and for aborted
                // errors, the original result is passed through unchanged: an RPC error
                // is returned to the caller, while a response with an embedded error is
                // handed on as the macro's value.
                if !is_starting {
                    response_result?
                } else if is_aborted(error) {
                    response_result?
                } else {
                    // The inline begin failed: begin the transaction with an explicit
                    // call, point the request at the resulting selector and retry once.
                    $self.begin_explicitly_if_not_started(true, None).await?;

                    $request.transaction =
                        Some($self.context.transaction_selector.selector().await?);

                    let res = $self
                        .context
                        .client
                        .spanner
                        .$rpc_method($request.clone(), $gax_options, $self.context.channel_hint)
                        .await?;

                    guard.disarm();
                    res
                }
            }
        };

        response
    }};
}
371
/// A read-write transaction.
///
/// Clones share the DML sequence counter and the buffered mutations, because
/// both are held behind an `Arc`.
#[derive(Clone, Debug)]
pub struct ReadWriteTransaction {
    /// Read context holding the session name, client, transaction selector,
    /// precommit token tracker, transaction tag and channel hint.
    pub(crate) context: ReadContext,
    /// Optional deadline used to bound the attempt timeout and retry policy of every RPC.
    pub(crate) deadline: Option<Instant>,
    /// Sequence number for DML requests; incremented for every update and batch update.
    seqno: Arc<AtomicI64>,
    /// Optional maximum commit delay passed to the commit request.
    max_commit_delay: Option<Duration>,
    /// Whether the commit request asks for commit statistics.
    return_commit_stats: bool,
    /// Priority applied to the commit request options unless it is `Priority::Unspecified`.
    commit_priority: Priority,
    /// Mutations buffered through `buffer`, taken when the transaction commits.
    mutations: Arc<Mutex<Vec<ProtoMutation>>>,
    /// Request options used as the base for explicit begin-transaction calls.
    begin_gax_options: Option<crate::RequestOptions>,
    /// Request options used as the base for commit calls.
    commit_gax_options: Option<crate::RequestOptions>,
}
385
386impl ReadWriteTransaction {
387    /// Buffers one or more mutations to be applied when the transaction commits.
388    ///
389    /// # Example
390    /// ```
391    /// # use google_cloud_spanner::client::Spanner;
392    /// # use google_cloud_spanner::mutation::Mutation;
393    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
394    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
395    /// let runner = db_client.read_write_transaction().build().await?;
396    /// runner.run(async |tx| {
397    ///     let mutation = Mutation::new_insert_builder("users")
398    ///         .set("id").to(&1)
399    ///         .build();
400    ///     tx.buffer([mutation])?;
401    ///     Ok(())
402    /// }).await?;
403    /// # Ok(())
404    /// # }
405    /// ```
406    pub fn buffer<I>(&self, mutations: I) -> crate::Result<()>
407    where
408        I: IntoIterator<Item = Mutation>,
409    {
410        let mut guard = self
411            .mutations
412            .lock()
413            .map_err(|_| crate::error::internal_error("mutations mutex poisoned"))?;
414        for mutation in mutations {
415            guard.push(mutation.build_proto());
416        }
417        Ok(())
418    }
419
420    /// Executes a query using this transaction.
421    pub async fn execute_query<T: Into<Statement>>(
422        &self,
423        statement: T,
424    ) -> crate::Result<ResultSet> {
425        let stmt = statement.into();
426        let mut gax_options = stmt.gax_options().clone();
427        self.amend_gax_options(&mut gax_options);
428        let stmt = stmt.with_gax_options(gax_options);
429        self.context.execute_query(stmt).await
430    }
431
432    /// Reads rows from the database using key lookups and scans, as a simple key/value style alternative to execute_query.
433    pub async fn execute_read<T: Into<crate::read::ReadRequest>>(
434        &self,
435        read: T,
436    ) -> crate::Result<ResultSet> {
437        let mut req = read.into();
438        self.amend_gax_options(&mut req.gax_options);
439        self.context.execute_read(req).await
440    }
441
442    /// Executes an update using this transaction.
443    pub async fn execute_update<T: Into<Statement>>(&self, statement: T) -> crate::Result<i64> {
444        let statement = statement.into();
445        let mut gax_options = statement.gax_options().clone();
446        self.amend_gax_options(&mut gax_options);
447        let seqno = self.seqno.fetch_add(1, Ordering::SeqCst);
448        let mut request = statement
449            .into_request()
450            .set_session(self.context.session_name.clone())
451            .set_transaction(self.context.transaction_selector.selector().await?)
452            .set_seqno(seqno);
453        request.request_options = self.context.amend_request_options(request.request_options);
454
455        let response = execute_with_retry!(
456            self,
457            request,
458            gax_options,
459            execute_sql,
460            |response: &crate::model::ResultSet| {
461                response
462                    .metadata
463                    .as_ref()
464                    .and_then(|md| md.transaction.as_ref())
465                    .map(|t| t.id.clone())
466            }
467        );
468
469        self.context
470            .precommit_token_tracker
471            .update(response.precommit_token);
472
473        let stats = response
474            .stats
475            .ok_or_else(|| internal_error("No stats returned"))?;
476        match stats.row_count {
477            Some(RowCount::RowCountExact(c)) => Ok(c),
478            _ => Err(internal_error(
479                "ExecuteSql returned an invalid or missing row count type for a read/write transaction",
480            )),
481        }
482    }
483
484    /// Executes a batch of DML statements using this transaction.
485    ///
486    /// # Example
487    /// ```
488    /// # use google_cloud_spanner::client::Spanner;
489    /// # use google_cloud_spanner::statement::Statement;
490    /// # use google_cloud_spanner::batch::BatchDml;
491    /// # async fn build(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
492    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
493    /// let runner = db_client.read_write_transaction().build().await?;
494    /// let result = runner.run(async |transaction| {
495    ///     let statement1 = Statement::builder("UPDATE users SET active = true WHERE id = @id")
496    ///         .add_param("id", &1)
497    ///         .build();
498    ///     let statement2 = Statement::builder("UPDATE users SET active = true WHERE id = @id")
499    ///         .add_param("id", &2)
500    ///         .build();
501    ///     let batch = BatchDml::builder()
502    ///         .add_statement(statement1)
503    ///         .add_statement(statement2)
504    ///         .build();
505    ///     let update_counts = transaction.execute_batch_update(batch).await?;
506    ///     Ok(())
507    /// }).await?;
508    /// # Ok(())
509    /// # }
510    /// ```
511    ///
512    /// If a `BatchDml` request fails halfway through execution, `execute_batch_update` will return a
513    /// `BatchUpdateError` indicating exactly which statements succeeded (and their respective update counts)
514    /// before the batch execution failed.
515    ///
516    /// # Error Handling Example
517    /// ```
518    /// # use google_cloud_spanner::client::Spanner;
519    /// # use google_cloud_spanner::statement::Statement;
520    /// # use google_cloud_spanner::batch::BatchDml;
521    /// # use google_cloud_spanner::error::BatchUpdateError;
522    /// # async fn build(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
523    /// # let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
524    /// # let runner = db_client.read_write_transaction().build().await?;
525    /// # let result = runner.run(async |transaction| {
526    /// let statement1 = Statement::builder("UPDATE users SET active = true WHERE id = 1").build();
527    /// let statement2 = Statement::builder("UPDATE non_existent_table SET active = true WHERE id = 2").build();
528    ///
529    /// let batch = BatchDml::builder()
530    ///     .add_statement(statement1)
531    ///     .add_statement(statement2)
532    ///     .build();
533    ///
534    /// match transaction.execute_batch_update(batch).await {
535    ///     Ok(update_counts) => {
536    ///         println!("All statements succeeded. Update counts: {:?}", update_counts);
537    ///     }
538    ///     Err(e) => {
539    ///         if let Some(batch_error) = BatchUpdateError::extract(&e) {
540    ///             println!("Batch execution failed. Successful update counts: {:?}", batch_error.update_counts);
541    ///         } else {
542    ///             println!("RPC failed or internal error occurred: {}", e);
543    ///         }
544    ///     }
545    /// }
546    /// # Ok(())
547    /// # }).await?;
548    /// # Ok(())
549    /// # }
550    /// ```
551    pub async fn execute_batch_update<T: Into<BatchDml>>(
552        &self,
553        batch: T,
554    ) -> crate::Result<Vec<i64>> {
555        let mut batch = batch.into();
556        self.amend_gax_options(&mut batch.gax_options);
557        let seqno = self.seqno.fetch_add(1, Ordering::SeqCst);
558
559        let statements: Vec<ExecuteBatchDmlStatement> = batch
560            .statements
561            .into_iter()
562            .map(|stmt: crate::statement::Statement| stmt.into_batch_statement())
563            .collect();
564
565        let mut request = ExecuteBatchDmlRequest::default()
566            .set_session(self.context.session_name.clone())
567            .set_transaction(self.context.transaction_selector.selector().await?)
568            .set_seqno(seqno)
569            .set_statements(statements)
570            .set_or_clear_request_options(
571                self.context.amend_request_options(batch.request_options),
572            );
573
574        let response = execute_with_retry!(
575            self,
576            request,
577            batch.gax_options,
578            execute_batch_dml,
579            |response: &crate::model::ExecuteBatchDmlResponse| {
580                response
581                    .result_sets
582                    .first()
583                    .and_then(|rs| rs.metadata.as_ref())
584                    .and_then(|md| md.transaction.as_ref())
585                    .map(|t| t.id.clone())
586            }
587        );
588
589        self.context
590            .precommit_token_tracker
591            .update(response.precommit_token.clone());
592        crate::batch_dml::process_response(response)
593    }
594
595    pub(crate) async fn begin_explicitly_if_not_started(
596        &self,
597        is_stream_fallback: bool,
598        mutation_key: Option<crate::model::Mutation>,
599    ) -> crate::Result<bool> {
600        let mut begin_options = self.begin_gax_options.clone().unwrap_or_default();
601        self.amend_gax_options(&mut begin_options);
602        self.context
603            .begin_explicitly_if_not_started(begin_options, is_stream_fallback, mutation_key)
604            .await
605    }
606
607    pub(crate) fn is_starting(&self) -> crate::Result<bool> {
608        self.context.transaction_selector.is_starting()
609    }
610
611    fn commit_request_options(&self) -> Option<crate::model::RequestOptions> {
612        let mut options = self.context.amend_request_options(None);
613        if self.commit_priority != Priority::Unspecified {
614            options
615                .get_or_insert_with(crate::model::RequestOptions::default)
616                .priority = self.commit_priority.clone();
617        }
618        options
619    }
620
621    fn build_commit_request(
622        &self,
623        transaction_id: bytes::Bytes,
624        mutations: Vec<ProtoMutation>,
625        precommit_token: Option<crate::model::MultiplexedSessionPrecommitToken>,
626    ) -> CommitRequest {
627        create_commit_request(
628            self.context.session_name.clone(),
629            transaction_id,
630            mutations,
631            precommit_token,
632            self.commit_request_options(),
633            self.max_commit_delay,
634            self.return_commit_stats,
635        )
636    }
637
638    /// Commits the transaction.
639    pub(crate) async fn commit(self) -> crate::Result<CommitResponse> {
640        let mutations = take(&mut *self.mutations.lock().unwrap());
641        let mut id = self.context.transaction_selector.get_id_no_wait()?;
642        if id.is_none() {
643            if self.is_starting()? {
644                return Err(crate::error::internal_error(
645                    "Commit called while an asynchronous statement is still starting the transaction",
646                ));
647            }
648            let mutation_key = Mutation::select_mutation_key(&mutations);
649            if self
650                .begin_explicitly_if_not_started(false, mutation_key)
651                .await?
652            {
653                id = self.context.transaction_selector.get_id_no_wait()?;
654            }
655        }
656        let transaction_id = id.ok_or_else(|| internal_error("Transaction ID is missing"))?;
657        let precommit_token = self.context.precommit_token_tracker.get();
658
659        let request = self.build_commit_request(transaction_id.clone(), mutations, precommit_token);
660
661        let mut gax_options = self.commit_gax_options.clone().unwrap_or_default();
662        self.amend_gax_options(&mut gax_options);
663
664        let response = self
665            .context
666            .client
667            .spanner
668            .commit(request, gax_options, self.context.channel_hint)
669            .await?;
670
671        let response =
672            if let Some(new_precommit_token) = response.precommit_token().map(|b| (*b).clone()) {
673                let retry_commit_req = self.build_commit_request(
674                    transaction_id,
675                    Vec::new(), // mutations are never re-sent in retry requests
676                    Some(*new_precommit_token),
677                );
678
679                let mut gax_options = self.commit_gax_options.clone().unwrap_or_default();
680                self.amend_gax_options(&mut gax_options);
681
682                self.context
683                    .client
684                    .spanner
685                    .commit(retry_commit_req, gax_options, self.context.channel_hint)
686                    .await?
687            } else {
688                response
689            };
690
691        Ok(response)
692    }
693
694    /// Rolls back the transaction.
695    pub(crate) async fn rollback(self) -> crate::Result<()> {
696        let Some(transaction_id) = self.context.transaction_selector.get_id_no_wait()? else {
697            return Ok(());
698        };
699
700        let request = RollbackRequest::default()
701            .set_session(self.context.session_name.clone())
702            .set_transaction_id(transaction_id);
703
704        let mut gax_options = RequestOptions::default();
705        self.amend_gax_options(&mut gax_options);
706
707        self.context
708            .client
709            .spanner
710            .rollback(request, gax_options, self.context.channel_hint)
711            .await?;
712
713        Ok(())
714    }
715
716    fn amend_gax_options(&self, options: &mut GaxRequestOptions) {
717        amend_gax_options(
718            self.context.client.leader_aware_routing_enabled,
719            self.deadline,
720            options,
721        );
722    }
723}
724
725pub(crate) fn amend_gax_options(
726    leader_aware_routing_enabled: bool,
727    deadline: Option<Instant>,
728    options: &mut GaxRequestOptions,
729) {
730    if let Some(deadline) = deadline {
731        let remaining = deadline.saturating_duration_since(Instant::now());
732        let attempt_timeout = match options.attempt_timeout() {
733            Some(custom_timeout) => std::cmp::min(*custom_timeout, remaining),
734            None => remaining,
735        };
736        options.set_attempt_timeout(attempt_timeout);
737
738        let inner_policy = options
739            .retry_policy()
740            .clone()
741            .unwrap_or_else(|| Arc::new(Aip194Strict));
742        let bounded_policy = TransactionBoundedRetryPolicy {
743            inner: inner_policy,
744            deadline,
745        };
746        options.set_retry_policy(bounded_policy);
747    }
748    *options = amend_request_options_for_lar(leader_aware_routing_enabled, take(options));
749}
750
751/// A retry policy that wraps another policy and bounds the total execution time
752/// by a specific transaction deadline.
753///
754/// This policy delegates `on_error` to the inner policy but overrides `remaining_time`
755/// to ensure that it never exceeds the time left until the transaction deadline.
756#[derive(Debug)]
757struct TransactionBoundedRetryPolicy {
758    inner: Arc<dyn RetryPolicy>,
759    deadline: Instant,
760}
761
762impl RetryPolicy for TransactionBoundedRetryPolicy {
763    fn on_error(&self, state: &RetryState, error: GaxError) -> RetryResult {
764        self.inner.on_error(state, error)
765    }
766
767    fn on_throttle(&self, state: &RetryState, error: GaxError) -> ThrottleResult {
768        self.inner.on_throttle(state, error)
769    }
770
771    fn remaining_time(&self, state: &RetryState) -> Option<StdDuration> {
772        let remaining = self.deadline.saturating_duration_since(Instant::now());
773        let attempt_timeout = self
774            .inner
775            .remaining_time(state)
776            .map(|inner| min(remaining, inner))
777            .unwrap_or(remaining);
778        Some(attempt_timeout)
779    }
780}
781
782#[cfg(test)]
783mod tests {
784    use super::*;
785    use crate::error::BatchUpdateError;
786    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
787    use crate::result_set::tests::adapt;
788    use crate::transaction_retry_policy::BasicTransactionRetryPolicy;
789    use gaxi::grpc::tonic;
790    use gaxi::grpc::tonic::MetadataMap;
791    use google_cloud_gax::options::internal::RequestOptionsExt as _;
792    use google_cloud_gax::retry_policy::NeverRetry;
793    use google_cloud_gax::retry_result::RetryResult;
794    use google_cloud_gax::retry_state::RetryState;
795    use google_cloud_test_macros::tokio_test_no_panics;
796    use http::HeaderMap;
797    use prost_types::Timestamp;
798    use spanner_grpc_mock::google::spanner::v1;
799    use std::fmt::Debug;
800    use std::sync::Mutex;
801    use std::time::Duration as StdDuration;
802
803    #[test]
804    fn auto_traits() {
805        static_assertions::assert_impl_all!(ReadWriteTransactionBuilder: Send, Sync, Clone, Debug);
806        static_assertions::assert_impl_all!(ReadWriteTransaction: Send, Sync, Debug);
807    }
808
809    #[tokio_test_no_panics]
810    async fn read_write_transaction_commit_retry_explicit() -> anyhow::Result<()> {
811        run_read_write_transaction_commit_retry(BeginTransactionOption::ExplicitBegin).await
812    }
813
814    #[tokio_test_no_panics]
815    async fn read_write_transaction_commit_retry_inline() -> anyhow::Result<()> {
816        run_read_write_transaction_commit_retry(BeginTransactionOption::InlineBegin).await
817    }
818
    // Shared body for the commit-retry tests.
    //
    // Scenario driven against the mock server:
    //   1. (explicit begin only) BeginTransaction returns transaction id [0, 0, 7].
    //   2. ExecuteSql returns precommit token 101 / seq_num 1.
    //   3. The first Commit must carry token 101 and answers with a
    //      `MultiplexedSessionRetry::PrecommitToken` (202 / seq_num 2).
    //   4. The second Commit must carry token 202 and no mutations, and
    //      returns commit timestamp 1001.
    //
    // Every mock handler also records the request's remote address so the
    // test can assert that all RPCs arrived from the same address.
    async fn run_read_write_transaction_commit_retry(
        begin_transaction_option: BeginTransactionOption,
    ) -> anyhow::Result<()> {
        let mut mock = create_session_mock();
        // Remote addresses observed by the mock handlers, in call order.
        let remotes = Arc::new(Mutex::new(Vec::new()));

        // BeginTransaction is only expected when the transaction is begun explicitly.
        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
            let remotes_clone = remotes.clone();
            mock.expect_begin_transaction()
                .once()
                .returning(move |req| {
                    remotes_clone
                        .lock()
                        .unwrap()
                        .push(req.remote_addr().expect("remote_addr should be available"));
                    let req = req.into_inner();
                    assert_eq!(
                        req.session,
                        "projects/p/instances/i/databases/d/sessions/123"
                    );
                    Ok(tonic::Response::new(v1::Transaction {
                        id: vec![0, 0, 7],
                        ..Default::default()
                    }))
                });
        }

        // execute_update returns a precommit token.
        let remotes_clone = remotes.clone();
        mock.expect_execute_sql().once().returning(move |req| {
            remotes_clone
                .lock()
                .unwrap()
                .push(req.remote_addr().expect("remote_addr should be available"));
            let req = req.into_inner();
            assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1");

            // With inline begin the statement itself must carry a `Begin` selector.
            if begin_transaction_option == BeginTransactionOption::InlineBegin {
                let transaction = req
                    .transaction
                    .as_ref()
                    .expect("transaction options required for inline begin");
                let selector = transaction.selector.as_ref().expect("selector required");
                assert!(matches!(
                    selector,
                    v1::transaction_selector::Selector::Begin(_)
                ));
            }

            let mut metadata = v1::ResultSetMetadata {
                row_type: Some(v1::StructType { fields: vec![] }),
                ..Default::default()
            };
            // With inline begin the transaction id is delivered in the result metadata.
            if begin_transaction_option == BeginTransactionOption::InlineBegin {
                metadata.transaction = Some(v1::Transaction {
                    id: vec![0, 0, 7],
                    ..Default::default()
                });
            }

            Ok(tonic::Response::new(v1::ResultSet {
                metadata: Some(metadata),
                stats: Some(v1::ResultSetStats {
                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
                    ..Default::default()
                }),
                precommit_token: Some(v1::MultiplexedSessionPrecommitToken {
                    precommit_token: vec![101],
                    seq_num: 1,
                }),
                ..Default::default()
            }))
        });

        // Simulate that commit returns a precommit token in the response.
        // This would normally not happen, but we test it here to verify
        // that the commit is retried.
        let remotes_clone = remotes.clone();
        mock.expect_commit().once().returning(move |req| {
            remotes_clone
                .lock()
                .unwrap()
                .push(req.remote_addr().expect("remote_addr should be available"));
            let req = req.into_inner();
            // The first commit must carry the token returned by ExecuteSql.
            assert_eq!(
                req.precommit_token,
                Some(v1::MultiplexedSessionPrecommitToken {
                    precommit_token: vec![101],
                    seq_num: 1,
                })
            );
            Ok(tonic::Response::new(v1::CommitResponse {
                commit_timestamp: Some(prost_types::Timestamp {
                    seconds: 1000,
                    nanos: 0,
                }),
                multiplexed_session_retry: Some(
                    v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
                        v1::MultiplexedSessionPrecommitToken {
                            precommit_token: vec![202],
                            seq_num: 2,
                        },
                    ),
                ),
                ..Default::default()
            }))
        });

        // Second commit retry is automatically issued with the new token
        let remotes_clone = remotes.clone();
        mock.expect_commit().once().returning(move |req| {
            remotes_clone
                .lock()
                .unwrap()
                .push(req.remote_addr().expect("remote_addr should be available"));
            let req = req.into_inner();
            assert_eq!(
                req.precommit_token,
                Some(v1::MultiplexedSessionPrecommitToken {
                    precommit_token: vec![202],
                    seq_num: 2,
                })
            );
            assert!(
                req.mutations.is_empty(),
                "Expected mutations to be empty in retried CommitRequest"
            );
            Ok(tonic::Response::new(v1::CommitResponse {
                commit_timestamp: Some(prost_types::Timestamp {
                    seconds: 1001,
                    nanos: 0,
                }),
                ..Default::default()
            }))
        });

        let (db_client, _server) = setup_db_client(mock).await;

        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
            .with_begin_transaction_option(begin_transaction_option)
            .build(None)
            .await
            .expect("Failed to build transaction");

        let count = tx
            .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1")
            .await?;
        assert_eq!(count, 1);

        // The timestamp seen by the caller must be the one from the second
        // (retried) commit response, not the first.
        let timestamp = tx.commit().await?;
        assert_eq!(
            timestamp
                .commit_timestamp
                .as_ref()
                .expect("timestamp should be present")
                .seconds(),
            1001
        );

        // Verify that all RPCs used the same channel (same remote address)
        let remotes = remotes.lock().unwrap();
        // Explicit begin adds one BeginTransaction RPC on top of
        // ExecuteSql + two Commit calls.
        let expected_rpcs = if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
            4
        } else {
            3
        };
        assert_eq!(
            remotes.len(),
            expected_rpcs,
            "Expected exactly {} RPCs",
            expected_rpcs
        );
        let first = remotes[0];
        for addr in remotes.iter() {
            assert_eq!(*addr, first, "All RPCs should use the same gRPC channel");
        }

        Ok(())
    }
998
    // Verifies that `return_commit_stats` and `max_commit_delay` configured on
    // the builder are present on the first CommitRequest and are still present
    // on the CommitRequest that is re-sent after the server answers with a new
    // precommit token.
    #[tokio_test_no_panics]
    async fn read_write_transaction_commit_retry_preserves_options() -> anyhow::Result<()> {
        let mut mock = create_session_mock();

        // execute_update returns a precommit token.
        mock.expect_execute_sql().once().returning(move |req| {
            let req = req.into_inner();
            assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1");

            let mut metadata = v1::ResultSetMetadata {
                row_type: Some(v1::StructType { fields: vec![] }),
                ..Default::default()
            };
            // The transaction is begun inline, so its id is returned in the metadata.
            metadata.transaction = Some(v1::Transaction {
                id: vec![0, 0, 7],
                ..Default::default()
            });

            Ok(tonic::Response::new(v1::ResultSet {
                metadata: Some(metadata),
                stats: Some(v1::ResultSetStats {
                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
                    ..Default::default()
                }),
                precommit_token: Some(v1::MultiplexedSessionPrecommitToken {
                    precommit_token: vec![101],
                    seq_num: 1,
                }),
                ..Default::default()
            }))
        });

        // Simulate that commit returns a precommit token in the response.
        // This would normally not happen, but we test it here to verify
        // that the commit is retried and the options are preserved.
        // 200ms, matching the `set_max_commit_delay` value passed to the builder below.
        let expected_delay = prost_types::Duration {
            seconds: 0,
            nanos: 200_000_000,
        };

        // Separate copy for the first commit handler; the second handler
        // captures `expected_delay` itself.
        let expected_delay_clone = expected_delay;
        mock.expect_commit().once().returning(move |req| {
            let req = req.into_inner();
            assert_eq!(
                req.precommit_token,
                Some(v1::MultiplexedSessionPrecommitToken {
                    precommit_token: vec![101],
                    seq_num: 1,
                })
            );
            // Assert original options are present
            assert!(
                req.return_commit_stats,
                "Expected return_commit_stats to be true in first commit"
            );
            assert_eq!(
                req.max_commit_delay.as_ref(),
                Some(&expected_delay_clone),
                "Expected max_commit_delay to be set in first commit"
            );

            Ok(tonic::Response::new(v1::CommitResponse {
                commit_timestamp: Some(prost_types::Timestamp {
                    seconds: 1000,
                    nanos: 0,
                }),
                multiplexed_session_retry: Some(
                    v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
                        v1::MultiplexedSessionPrecommitToken {
                            precommit_token: vec![202],
                            seq_num: 2,
                        },
                    ),
                ),
                ..Default::default()
            }))
        });

        // Second commit retry is automatically issued with the new token and MUST preserve original options
        mock.expect_commit().once().returning(move |req| {
            let req = req.into_inner();
            assert_eq!(
                req.precommit_token,
                Some(v1::MultiplexedSessionPrecommitToken {
                    precommit_token: vec![202],
                    seq_num: 2,
                })
            );
            assert!(
                req.return_commit_stats,
                "Expected return_commit_stats to be preserved in retried commit request"
            );
            assert_eq!(
                req.max_commit_delay.as_ref(),
                Some(&expected_delay),
                "Expected max_commit_delay to be preserved in retried commit request"
            );
            assert!(
                req.mutations.is_empty(),
                "Expected mutations to be empty in retried CommitRequest"
            );

            Ok(tonic::Response::new(v1::CommitResponse {
                commit_timestamp: Some(prost_types::Timestamp {
                    seconds: 1001,
                    nanos: 0,
                }),
                ..Default::default()
            }))
        });

        let (db_client, _server) = setup_db_client(mock).await;

        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
            .set_return_commit_stats(true)
            .set_max_commit_delay(Duration::new(0, 200_000_000).expect("valid duration"))
            .build(None)
            .await
            .expect("Failed to build transaction");

        let count = tx
            .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1")
            .await?;
        assert_eq!(count, 1);

        // The caller must observe the timestamp of the retried commit.
        let timestamp = tx.commit().await?;
        assert_eq!(
            timestamp
                .commit_timestamp
                .as_ref()
                .expect("timestamp should be present")
                .seconds(),
            1001
        );

        Ok(())
    }
1137
1138    #[tokio_test_no_panics]
1139    async fn read_write_transaction_commit_carries_commit_priority() -> anyhow::Result<()> {
1140        let mut mock = create_session_mock();
1141
1142        mock.expect_execute_sql().once().returning(move |_req| {
1143            let mut metadata = v1::ResultSetMetadata {
1144                row_type: Some(v1::StructType { fields: vec![] }),
1145                ..Default::default()
1146            };
1147            metadata.transaction = Some(v1::Transaction {
1148                id: vec![1, 2, 3],
1149                ..Default::default()
1150            });
1151
1152            Ok(tonic::Response::new(v1::ResultSet {
1153                metadata: Some(metadata),
1154                stats: Some(v1::ResultSetStats {
1155                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1156                    ..Default::default()
1157                }),
1158                ..Default::default()
1159            }))
1160        });
1161
1162        mock.expect_commit().once().returning(|req| {
1163            let req = req.into_inner();
1164            let request_options = req
1165                .request_options
1166                .expect("Expected request_options in CommitRequest");
1167            assert_eq!(
1168                request_options.priority,
1169                v1::request_options::Priority::Low as i32,
1170                "Expected priority to be Priority::Low in CommitRequest"
1171            );
1172
1173            Ok(tonic::Response::new(v1::CommitResponse {
1174                commit_timestamp: Some(prost_types::Timestamp {
1175                    seconds: 123456789,
1176                    nanos: 0,
1177                }),
1178                ..Default::default()
1179            }))
1180        });
1181
1182        let (db_client, _server) = setup_db_client(mock).await;
1183
1184        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1185            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
1186            .set_commit_priority(Priority::Low)
1187            .build(None)
1188            .await
1189            .expect("Failed to build transaction");
1190
1191        let count = tx
1192            .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1193            .await?;
1194        assert_eq!(count, 1);
1195
1196        let _ = tx.commit().await?;
1197
1198        Ok(())
1199    }
1200
1201    #[tokio_test_no_panics]
1202    async fn read_write_transaction_execute_update_explicit() {
1203        run_read_write_transaction_execute_update(BeginTransactionOption::ExplicitBegin).await;
1204    }
1205
1206    #[tokio_test_no_panics]
1207    async fn read_write_transaction_execute_update_inline() {
1208        run_read_write_transaction_execute_update(BeginTransactionOption::InlineBegin).await;
1209    }
1210
1211    async fn run_read_write_transaction_execute_update(
1212        begin_transaction_option: BeginTransactionOption,
1213    ) {
1214        let mut mock = create_session_mock();
1215
1216        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1217            mock.expect_begin_transaction().once().returning(|req| {
1218                let req = req.into_inner();
1219                assert_eq!(
1220                    req.session,
1221                    "projects/p/instances/i/databases/d/sessions/123"
1222                );
1223                Ok(tonic::Response::new(v1::Transaction {
1224                    id: vec![1, 2, 3],
1225                    ..Default::default()
1226                }))
1227            });
1228        }
1229
1230        mock.expect_execute_sql().once().returning(move |req| {
1231            let req = req.into_inner();
1232            assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1");
1233            assert_eq!(req.seqno, 1);
1234
1235            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1236                let transaction = req
1237                    .transaction
1238                    .as_ref()
1239                    .expect("transaction options required for inline begin");
1240                let selector = transaction.selector.as_ref().expect("selector required");
1241                assert!(matches!(
1242                    selector,
1243                    v1::transaction_selector::Selector::Begin(_)
1244                ));
1245            }
1246
1247            let mut metadata = v1::ResultSetMetadata {
1248                row_type: Some(v1::StructType { fields: vec![] }),
1249                ..Default::default()
1250            };
1251            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1252                metadata.transaction = Some(v1::Transaction {
1253                    id: vec![1, 2, 3],
1254                    ..Default::default()
1255                });
1256            }
1257
1258            Ok(tonic::Response::new(v1::ResultSet {
1259                metadata: Some(metadata),
1260                stats: Some(v1::ResultSetStats {
1261                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1262                    ..Default::default()
1263                }),
1264                ..Default::default()
1265            }))
1266        });
1267
1268        mock.expect_commit().once().returning(|req| {
1269            let req = req.into_inner();
1270            assert_eq!(
1271                req.session,
1272                "projects/p/instances/i/databases/d/sessions/123"
1273            );
1274            assert_eq!(
1275                req.transaction,
1276                Some(v1::commit_request::Transaction::TransactionId(vec![
1277                    1, 2, 3
1278                ]))
1279            );
1280            Ok(tonic::Response::new(v1::CommitResponse {
1281                commit_timestamp: Some(prost_types::Timestamp {
1282                    seconds: 123456789,
1283                    nanos: 0,
1284                }),
1285                ..Default::default()
1286            }))
1287        });
1288
1289        let (db_client, _server) = setup_db_client(mock).await;
1290
1291        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1292            .with_begin_transaction_option(begin_transaction_option)
1293            .build(None)
1294            .await
1295            .expect("Failed to build transaction");
1296        let count = tx
1297            .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1298            .await
1299            .expect("Failed to execute update");
1300        assert_eq!(count, 1);
1301
1302        let ts = tx.commit().await.expect("Failed to commit");
1303        assert_eq!(
1304            ts.commit_timestamp
1305                .expect("Commit timestamp should be present")
1306                .seconds(),
1307            123456789
1308        );
1309    }
1310
1311    #[tokio_test_no_panics]
1312    async fn read_write_transaction_execute_update_invalid_stats_explicit() -> anyhow::Result<()> {
1313        run_read_write_transaction_execute_update_invalid_stats(
1314            BeginTransactionOption::ExplicitBegin,
1315        )
1316        .await
1317    }
1318
1319    #[tokio_test_no_panics]
1320    async fn read_write_transaction_execute_update_invalid_stats_inline() -> anyhow::Result<()> {
1321        run_read_write_transaction_execute_update_invalid_stats(BeginTransactionOption::InlineBegin)
1322            .await
1323    }
1324
1325    async fn run_read_write_transaction_execute_update_invalid_stats(
1326        begin_transaction_option: BeginTransactionOption,
1327    ) -> anyhow::Result<()> {
1328        let mut mock = create_session_mock();
1329
1330        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1331            mock.expect_begin_transaction().once().returning(|_| {
1332                Ok(tonic::Response::new(v1::Transaction {
1333                    id: vec![1, 2, 3],
1334                    ..Default::default()
1335                }))
1336            });
1337        }
1338
1339        mock.expect_execute_sql().once().returning(move |req| {
1340            let req = req.into_inner();
1341            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1342                let transaction = req
1343                    .transaction
1344                    .as_ref()
1345                    .expect("transaction options required for inline begin");
1346                let selector = transaction.selector.as_ref().expect("selector required");
1347                assert!(matches!(
1348                    selector,
1349                    v1::transaction_selector::Selector::Begin(_)
1350                ));
1351            }
1352
1353            let mut metadata = v1::ResultSetMetadata {
1354                row_type: Some(v1::StructType { fields: vec![] }),
1355                ..Default::default()
1356            };
1357            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1358                metadata.transaction = Some(v1::Transaction {
1359                    id: vec![1, 2, 3],
1360                    ..Default::default()
1361                });
1362            }
1363
1364            Ok(tonic::Response::new(v1::ResultSet {
1365                metadata: Some(metadata),
1366                stats: Some(v1::ResultSetStats {
1367                    row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(1)),
1368                    ..Default::default()
1369                }),
1370                ..Default::default()
1371            }))
1372        });
1373
1374        let (db_client, _server) = setup_db_client(mock).await;
1375
1376        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1377            .with_begin_transaction_option(begin_transaction_option)
1378            .build(None)
1379            .await
1380            .expect("Failed to build transaction");
1381
1382        let result = tx
1383            .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1384            .await;
1385
1386        let err = result.expect_err("Expected an error for invalid row count stats");
1387        assert!(
1388            format!("{:?}", err).contains("invalid or missing row count type"),
1389            "Error did not contain expected message: {:?}",
1390            err
1391        );
1392        Ok(())
1393    }
1394
1395    #[tokio_test_no_panics]
1396    async fn read_write_transaction_rollback_explicit() -> anyhow::Result<()> {
1397        run_read_write_transaction_rollback(BeginTransactionOption::ExplicitBegin).await
1398    }
1399
1400    #[tokio_test_no_panics]
1401    async fn read_write_transaction_rollback_inline() -> anyhow::Result<()> {
1402        run_read_write_transaction_rollback(BeginTransactionOption::InlineBegin).await
1403    }
1404
1405    async fn run_read_write_transaction_rollback(
1406        begin_transaction_option: BeginTransactionOption,
1407    ) -> anyhow::Result<()> {
1408        let mut mock = create_session_mock();
1409        let transaction_id = vec![9, 9, 9];
1410
1411        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1412            let id = transaction_id.clone();
1413            mock.expect_begin_transaction().once().returning(move |_| {
1414                Ok(tonic::Response::new(v1::Transaction {
1415                    id: id.clone(),
1416                    ..Default::default()
1417                }))
1418            });
1419        } else {
1420            let id = transaction_id.clone();
1421            mock.expect_execute_sql().once().returning(move |req| {
1422                let req = req.into_inner();
1423                let transaction = req
1424                    .transaction
1425                    .as_ref()
1426                    .expect("transaction options required for inline begin");
1427                let selector = transaction.selector.as_ref().expect("selector required");
1428                assert!(matches!(
1429                    selector,
1430                    v1::transaction_selector::Selector::Begin(_)
1431                ));
1432
1433                Ok(tonic::Response::new(v1::ResultSet {
1434                    metadata: Some(v1::ResultSetMetadata {
1435                        transaction: Some(v1::Transaction {
1436                            id: id.clone(),
1437                            ..Default::default()
1438                        }),
1439                        ..Default::default()
1440                    }),
1441                    stats: Some(v1::ResultSetStats {
1442                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1443                        ..Default::default()
1444                    }),
1445                    ..Default::default()
1446                }))
1447            });
1448        }
1449
1450        let id = transaction_id.clone();
1451        mock.expect_rollback().once().returning(move |req| {
1452            let req = req.into_inner();
1453            assert_eq!(
1454                req.session,
1455                "projects/p/instances/i/databases/d/sessions/123"
1456            );
1457            assert_eq!(req.transaction_id, id);
1458            Ok(tonic::Response::new(()))
1459        });
1460
1461        let (db_client, _server) = setup_db_client(mock).await;
1462
1463        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1464            .with_begin_transaction_option(begin_transaction_option)
1465            .build(None)
1466            .await?;
1467
1468        if begin_transaction_option == BeginTransactionOption::InlineBegin {
1469            tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1470                .await
1471                .expect("Failed to execute update");
1472        }
1473
1474        tx.rollback().await?;
1475        Ok(())
1476    }
1477
1478    #[tokio_test_no_panics]
1479    async fn read_write_transaction_execute_batch_update_explicit() -> anyhow::Result<()> {
1480        let batch = BatchDml::builder()
1481            .add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1482            .add_statement("UPDATE Users SET Name = 'Bob' WHERE Id = 2")
1483            .build();
1484        run_read_write_transaction_execute_batch_update(
1485            BeginTransactionOption::ExplicitBegin,
1486            batch,
1487        )
1488        .await
1489    }
1490
1491    #[tokio_test_no_panics]
1492    async fn read_write_transaction_execute_batch_update_inline() -> anyhow::Result<()> {
1493        let batch = BatchDml::builder()
1494            .add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1495            .add_statement("UPDATE Users SET Name = 'Bob' WHERE Id = 2")
1496            .build();
1497        run_read_write_transaction_execute_batch_update(BeginTransactionOption::InlineBegin, batch)
1498            .await
1499    }
1500
1501    #[tokio_test_no_panics]
1502    async fn read_write_transaction_execute_batch_update_vec() -> anyhow::Result<()> {
1503        let statements = vec![
1504            "UPDATE Users SET Name = 'Alice' WHERE Id = 1",
1505            "UPDATE Users SET Name = 'Bob' WHERE Id = 2",
1506        ];
1507        run_read_write_transaction_execute_batch_update(
1508            BeginTransactionOption::InlineBegin,
1509            statements,
1510        )
1511        .await
1512    }
1513
1514    #[tokio_test_no_panics]
1515    async fn read_write_transaction_execute_batch_update_vec_statement() -> anyhow::Result<()> {
1516        let statement1 = Statement::builder("UPDATE Users SET Name = 'Alice' WHERE Id = 1").build();
1517        let statement2 = Statement::builder("UPDATE Users SET Name = 'Bob' WHERE Id = 2").build();
1518        let statements = vec![statement1, statement2];
1519        run_read_write_transaction_execute_batch_update(
1520            BeginTransactionOption::InlineBegin,
1521            statements,
1522        )
1523        .await
1524    }
1525
1526    async fn run_read_write_transaction_execute_batch_update(
1527        begin_transaction_option: BeginTransactionOption,
1528        batch: impl Into<BatchDml>,
1529    ) -> anyhow::Result<()> {
1530        let mut mock = create_session_mock();
1531
1532        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1533            mock.expect_begin_transaction().once().returning(|_| {
1534                Ok(tonic::Response::new(v1::Transaction {
1535                    id: vec![4, 5, 6],
1536                    ..Default::default()
1537                }))
1538            });
1539        }
1540
1541        mock.expect_execute_batch_dml()
1542            .once()
1543            .returning(move |req| {
1544                let req = req.into_inner();
1545                assert_eq!(req.statements.len(), 2);
1546                assert_eq!(
1547                    req.statements[0].sql,
1548                    "UPDATE Users SET Name = 'Alice' WHERE Id = 1"
1549                );
1550                assert_eq!(
1551                    req.statements[1].sql,
1552                    "UPDATE Users SET Name = 'Bob' WHERE Id = 2"
1553                );
1554
1555                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1556                    let selector = req
1557                        .transaction
1558                        .expect("missing transaction selector")
1559                        .selector
1560                        .expect("missing selector");
1561                    assert!(matches!(
1562                        selector,
1563                        v1::transaction_selector::Selector::Begin(_)
1564                    ));
1565                }
1566
1567                let mut metadata = v1::ResultSetMetadata {
1568                    ..Default::default()
1569                };
1570                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1571                    metadata.transaction = Some(v1::Transaction {
1572                        id: vec![4, 5, 6],
1573                        ..Default::default()
1574                    });
1575                }
1576
1577                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
1578                    result_sets: vec![
1579                        v1::ResultSet {
1580                            metadata: Some(metadata),
1581                            stats: Some(v1::ResultSetStats {
1582                                row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1583                                ..Default::default()
1584                            }),
1585                            ..Default::default()
1586                        },
1587                        v1::ResultSet {
1588                            stats: Some(v1::ResultSetStats {
1589                                row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1590                                ..Default::default()
1591                            }),
1592                            ..Default::default()
1593                        },
1594                    ],
1595                    status: Some(spanner_grpc_mock::google::rpc::Status {
1596                        code: 0,
1597                        message: "OK".into(),
1598                        details: vec![],
1599                    }),
1600                    ..Default::default()
1601                }))
1602            });
1603
1604        let (db_client, _server) = setup_db_client(mock).await;
1605
1606        let transaction = ReadWriteTransactionBuilder::new(db_client)
1607            .with_begin_transaction_option(begin_transaction_option)
1608            .build(None)
1609            .await?;
1610
1611        let counts = transaction.execute_batch_update(batch).await?;
1612
1613        assert_eq!(counts, vec![1, 1]);
1614        Ok(())
1615    }
1616
1617    #[tokio_test_no_panics]
1618    async fn read_write_transaction_execute_batch_update_partial_failure_explicit()
1619    -> anyhow::Result<()> {
1620        run_read_write_transaction_execute_batch_update_partial_failure(
1621            BeginTransactionOption::ExplicitBegin,
1622        )
1623        .await
1624    }
1625
1626    #[tokio_test_no_panics]
1627    async fn read_write_transaction_execute_batch_update_partial_failure_inline()
1628    -> anyhow::Result<()> {
1629        run_read_write_transaction_execute_batch_update_partial_failure(
1630            BeginTransactionOption::InlineBegin,
1631        )
1632        .await
1633    }
1634
1635    async fn run_read_write_transaction_execute_batch_update_partial_failure(
1636        begin_transaction_option: BeginTransactionOption,
1637    ) -> anyhow::Result<()> {
1638        let mut mock = create_session_mock();
1639
1640        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1641            mock.expect_begin_transaction().once().returning(|_| {
1642                Ok(tonic::Response::new(v1::Transaction {
1643                    id: vec![7, 8, 9],
1644                    ..Default::default()
1645                }))
1646            });
1647        }
1648
1649        mock.expect_execute_batch_dml()
1650            .once()
1651            .returning(move |req| {
1652                let req = req.into_inner();
1653                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1654                    let selector = req
1655                        .transaction
1656                        .expect("missing transaction selector")
1657                        .selector
1658                        .expect("missing selector");
1659                    assert!(matches!(
1660                        selector,
1661                        v1::transaction_selector::Selector::Begin(_)
1662                    ));
1663                }
1664
1665                let mut metadata = v1::ResultSetMetadata {
1666                    ..Default::default()
1667                };
1668                if begin_transaction_option == BeginTransactionOption::InlineBegin {
1669                    metadata.transaction = Some(v1::Transaction {
1670                        id: vec![7, 8, 9],
1671                        ..Default::default()
1672                    });
1673                }
1674
1675                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
1676                    result_sets: vec![v1::ResultSet {
1677                        metadata: Some(metadata),
1678                        stats: Some(v1::ResultSetStats {
1679                            row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1680                            ..Default::default()
1681                        }),
1682                        ..Default::default()
1683                    }],
1684                    status: Some(spanner_grpc_mock::google::rpc::Status {
1685                        code: gaxi::grpc::tonic::Code::AlreadyExists as i32,
1686                        message: "row already exists".into(),
1687                        details: vec![],
1688                    }),
1689                    ..Default::default()
1690                }))
1691            });
1692
1693        let (db_client, _server) = setup_db_client(mock).await;
1694
1695        let tx = ReadWriteTransactionBuilder::new(db_client)
1696            .with_begin_transaction_option(begin_transaction_option)
1697            .build(None)
1698            .await?;
1699
1700        let batch = BatchDml::builder()
1701            .add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1702            .add_statement("INSERT INTO Users (Id) VALUES (2)");
1703
1704        let res = tx.execute_batch_update(batch.build()).await;
1705
1706        let err = res.expect_err("expected error");
1707        use std::error::Error;
1708        let batch_err = err
1709            .source()
1710            .and_then(|e| e.downcast_ref::<BatchUpdateError>())
1711            .expect("should be BatchUpdateError");
1712        assert_eq!(batch_err.update_counts, vec![1]);
1713        assert_eq!(
1714            batch_err.status.status().expect("status").code,
1715            (gaxi::grpc::tonic::Code::AlreadyExists as i32).into()
1716        );
1717        Ok(())
1718    }
1719
1720    #[tokio_test_no_panics]
1721    async fn read_write_transaction_execute_multiple_updates_explicit() -> anyhow::Result<()> {
1722        run_read_write_transaction_execute_multiple_updates(BeginTransactionOption::ExplicitBegin)
1723            .await
1724    }
1725
1726    #[tokio_test_no_panics]
1727    async fn read_write_transaction_execute_multiple_updates_inline() -> anyhow::Result<()> {
1728        run_read_write_transaction_execute_multiple_updates(BeginTransactionOption::InlineBegin)
1729            .await
1730    }
1731
1732    async fn run_read_write_transaction_execute_multiple_updates(
1733        begin_transaction_option: BeginTransactionOption,
1734    ) -> anyhow::Result<()> {
1735        let mut mock = create_session_mock();
1736
1737        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
1738            mock.expect_begin_transaction().once().returning(|req| {
1739                let req = req.into_inner();
1740                assert_eq!(
1741                    req.session,
1742                    "projects/p/instances/i/databases/d/sessions/123"
1743                );
1744                Ok(tonic::Response::new(v1::Transaction {
1745                    id: vec![4, 5, 6],
1746                    ..Default::default()
1747                }))
1748            });
1749        }
1750
1751        let counter = Arc::new(AtomicI64::new(1));
1752        mock.expect_execute_sql().times(3).returning(move |req| {
1753            let req = req.into_inner();
1754            assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1");
1755            let c = counter.fetch_add(1, Ordering::SeqCst);
1756            assert_eq!(req.seqno, c);
1757
1758            let mut metadata = v1::ResultSetMetadata {
1759                ..Default::default()
1760            };
1761
1762            if begin_transaction_option == BeginTransactionOption::InlineBegin {
1763                if c == 1 {
1764                    let selector = req
1765                        .transaction
1766                        .expect("missing transaction selector")
1767                        .selector
1768                        .expect("missing selector");
1769                    assert!(matches!(
1770                        selector,
1771                        v1::transaction_selector::Selector::Begin(_)
1772                    ));
1773                    metadata.transaction = Some(v1::Transaction {
1774                        id: vec![4, 5, 6],
1775                        ..Default::default()
1776                    });
1777                } else {
1778                    let selector = req
1779                        .transaction
1780                        .expect("missing transaction selector")
1781                        .selector
1782                        .expect("missing selector");
1783                    match selector {
1784                        v1::transaction_selector::Selector::Id(id) => {
1785                            assert_eq!(id, vec![4, 5, 6]);
1786                        }
1787                        _ => panic!("Expected Selector::Id"),
1788                    }
1789                }
1790            }
1791
1792            Ok(tonic::Response::new(v1::ResultSet {
1793                metadata: Some(metadata),
1794                stats: Some(v1::ResultSetStats {
1795                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1796                    ..Default::default()
1797                }),
1798                ..Default::default()
1799            }))
1800        });
1801
1802        let (db_client, _server) = setup_db_client(mock).await;
1803
1804        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1805            .with_begin_transaction_option(begin_transaction_option)
1806            .build(None)
1807            .await
1808            .expect("Failed to build transaction");
1809
1810        for i in 1..=3 {
1811            let count = tx
1812                .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
1813                .await
1814                .unwrap_or_else(|_| panic!("Failed to execute update {}", i));
1815            assert_eq!(count, 1);
1816        }
1817        Ok(())
1818    }
1819
1820    #[tokio_test_no_panics]
1821    async fn read_write_transaction_execute_query() {
1822        use crate::statement::Statement;
1823        let mut mock = create_session_mock();
1824
1825        mock.expect_begin_transaction().once().returning(|req| {
1826            let req = req.into_inner();
1827            assert_eq!(
1828                req.session,
1829                "projects/p/instances/i/databases/d/sessions/123"
1830            );
1831            Ok(tonic::Response::new(v1::Transaction {
1832                id: vec![7, 8, 9],
1833                ..Default::default()
1834            }))
1835        });
1836
1837        mock.expect_execute_streaming_sql().once().returning(|req| {
1838            let req = req.into_inner();
1839            assert_eq!(req.sql, "SELECT 1");
1840            // Queries do not need to include a sequence number.
1841            assert_eq!(req.seqno, 0);
1842
1843            assert_eq!(
1844                req.transaction,
1845                Some(v1::TransactionSelector {
1846                    selector: Some(v1::transaction_selector::Selector::Id(vec![7, 8, 9]))
1847                })
1848            );
1849
1850            let prs = v1::PartialResultSet {
1851                metadata: Some(v1::ResultSetMetadata {
1852                    row_type: Some(v1::StructType { fields: vec![] }),
1853                    ..Default::default()
1854                }),
1855                ..Default::default()
1856            };
1857            Ok(tonic::Response::from(crate::result_set::tests::adapt([
1858                Ok(prs),
1859            ])))
1860        });
1861
1862        let (db_client, _server) = setup_db_client(mock).await;
1863
1864        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
1865            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
1866            .build(None)
1867            .await
1868            .expect("Failed to build transaction");
1869
1870        let mut rs = tx
1871            .execute_query(Statement::builder("SELECT 1").build())
1872            .await
1873            .expect("Failed to execute query");
1874
1875        let result = rs.next().await;
1876        assert!(result.is_none(), "expected None, got empty stream");
1877    }
1878
1879    #[tokio_test_no_panics]
1880    async fn read_write_transaction_with_options() {
1881        let mut mock = create_session_mock();
1882
1883        mock.expect_begin_transaction().once().returning(|req| {
1884            let req = req.into_inner();
1885            assert_eq!(
1886                req.session,
1887                "projects/p/instances/i/databases/d/sessions/123"
1888            );
1889
1890            let options = req.options.expect("missing transaction options");
1891            let mode = options.mode.expect("missing mode");
1892            match mode {
1893                v1::transaction_options::Mode::ReadWrite(rw) => {
1894                    assert_eq!(
1895                        rw.read_lock_mode,
1896                        v1::transaction_options::read_write::ReadLockMode::Pessimistic as i32
1897                    );
1898                }
1899                _ => panic!("Expected ReadWrite transaction mode"),
1900            }
1901            // Ensure isolation level is passed through
1902            assert_eq!(
1903                options.isolation_level,
1904                v1::transaction_options::IsolationLevel::Serializable as i32
1905            );
1906
1907            Ok(tonic::Response::new(v1::Transaction {
1908                id: vec![9, 9, 9],
1909                ..Default::default()
1910            }))
1911        });
1912
1913        let (db_client, _server) = setup_db_client(mock).await;
1914
1915        let _tx = ReadWriteTransactionBuilder::new(db_client.clone())
1916            .set_isolation_level(IsolationLevel::Serializable)
1917            .set_read_lock_mode(ReadLockMode::Pessimistic)
1918            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
1919            .build(None)
1920            .await
1921            .expect("Failed to build transaction");
1922    }
1923
1924    #[tokio_test_no_panics]
1925    async fn read_write_transaction_with_exclude_txn_from_change_streams() {
1926        let mut mock = create_session_mock();
1927
1928        mock.expect_begin_transaction().once().returning(|req| {
1929            let req = req.into_inner();
1930            let options = req.options.expect("missing transaction options");
1931            assert!(options.exclude_txn_from_change_streams);
1932
1933            Ok(tonic::Response::new(v1::Transaction {
1934                id: vec![9, 9, 9],
1935                ..Default::default()
1936            }))
1937        });
1938
1939        let (db_client, _server) = setup_db_client(mock).await;
1940
1941        let _tx = ReadWriteTransactionBuilder::new(db_client.clone())
1942            .set_exclude_txn_from_change_streams(true)
1943            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
1944            .build(None)
1945            .await
1946            .expect("Failed to build transaction");
1947    }
1948
1949    #[tokio_test_no_panics]
1950    async fn read_write_transaction_tracks_highest_precommit_token() {
1951        let mut mock = create_session_mock();
1952
1953        mock.expect_begin_transaction().once().returning(|_| {
1954            Ok(tonic::Response::new(v1::Transaction {
1955                id: vec![4, 2],
1956                ..Default::default()
1957            }))
1958        });
1959
1960        // 3 sequential updates returning tokens [seq 2, seq 5, seq 3]
1961        let tokens_iter = vec![2, 5, 3].into_iter();
1962        let counter_mutex = Mutex::new(tokens_iter);
1963
1964        mock.expect_execute_sql().times(3).returning(move |_req| {
1965            let seq = counter_mutex
1966                .lock()
1967                .expect("Failed to lock mutex")
1968                .next()
1969                .expect("Failed to get next token");
1970            Ok(tonic::Response::new(v1::ResultSet {
1971                stats: Some(v1::ResultSetStats {
1972                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
1973                    ..Default::default()
1974                }),
1975                precommit_token: Some(v1::MultiplexedSessionPrecommitToken {
1976                    precommit_token: vec![seq as u8],
1977                    seq_num: seq,
1978                }),
1979                ..Default::default()
1980            }))
1981        });
1982
1983        // Commit should only use the highest token (seq 5)
1984        mock.expect_commit().once().returning(|req| {
1985            let req = req.into_inner();
1986            assert_eq!(
1987                req.precommit_token,
1988                Some(v1::MultiplexedSessionPrecommitToken {
1989                    precommit_token: vec![5],
1990                    seq_num: 5,
1991                })
1992            );
1993            Ok(tonic::Response::new(v1::CommitResponse {
1994                commit_timestamp: Some(prost_types::Timestamp {
1995                    seconds: 12345,
1996                    nanos: 0,
1997                }),
1998                ..Default::default()
1999            }))
2000        });
2001
2002        let (db_client, _server) = setup_db_client(mock).await;
2003        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
2004            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2005            .build(None)
2006            .await
2007            .expect("Failed to build transaction");
2008
2009        for _ in 0..3 {
2010            tx.execute_update("UPDATE Y")
2011                .await
2012                .expect("Failed to execute update");
2013        }
2014        let ts = tx.commit().await.expect("Failed to commit transaction");
2015        assert_eq!(
2016            ts.commit_timestamp
2017                .expect("Commit timestamp should be present")
2018                .seconds(),
2019            12345
2020        );
2021    }
2022
2023    #[tokio_test_no_panics]
2024    async fn read_write_transaction_commit_retry_exactly_once_explicit() -> anyhow::Result<()> {
2025        run_read_write_transaction_commit_retry_exactly_once(BeginTransactionOption::ExplicitBegin)
2026            .await
2027    }
2028
2029    #[tokio_test_no_panics]
2030    async fn read_write_transaction_commit_retry_exactly_once_inline() -> anyhow::Result<()> {
2031        run_read_write_transaction_commit_retry_exactly_once(BeginTransactionOption::InlineBegin)
2032            .await
2033    }
2034
2035    async fn run_read_write_transaction_commit_retry_exactly_once(
2036        begin_transaction_option: BeginTransactionOption,
2037    ) -> anyhow::Result<()> {
2038        let mut mock = create_session_mock();
2039
2040        let transaction_id = vec![7, 7];
2041
2042        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
2043            let id = transaction_id.clone();
2044            mock.expect_begin_transaction().once().returning(move |_| {
2045                Ok(tonic::Response::new(v1::Transaction {
2046                    id: id.clone(),
2047                    ..Default::default()
2048                }))
2049            });
2050        } else {
2051            let id = transaction_id.clone();
2052            mock.expect_execute_sql().once().returning(move |req| {
2053                let req = req.into_inner();
2054                let transaction = req
2055                    .transaction
2056                    .as_ref()
2057                    .expect("transaction options required for inline begin");
2058                let selector = transaction.selector.as_ref().expect("selector required");
2059                assert!(matches!(
2060                    selector,
2061                    v1::transaction_selector::Selector::Begin(_)
2062                ));
2063
2064                Ok(tonic::Response::new(v1::ResultSet {
2065                    metadata: Some(v1::ResultSetMetadata {
2066                        transaction: Some(v1::Transaction {
2067                            id: id.clone(),
2068                            ..Default::default()
2069                        }),
2070                        ..Default::default()
2071                    }),
2072                    stats: Some(v1::ResultSetStats {
2073                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2074                        ..Default::default()
2075                    }),
2076                    ..Default::default()
2077                }))
2078            });
2079        }
2080
2081        let mut seq = mockall::Sequence::new();
2082
2083        // Initial commit returns a retry token (seq 2)
2084        mock.expect_commit()
2085            .once()
2086            .in_sequence(&mut seq)
2087            .returning(|_| {
2088                Ok(tonic::Response::new(v1::CommitResponse {
2089                    commit_timestamp: Some(prost_types::Timestamp {
2090                        seconds: 1000,
2091                        nanos: 0,
2092                    }),
2093                    multiplexed_session_retry: Some(
2094                        v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
2095                            v1::MultiplexedSessionPrecommitToken {
2096                                precommit_token: vec![2],
2097                                seq_num: 2,
2098                            },
2099                        ),
2100                    ),
2101                    ..Default::default()
2102                }))
2103            });
2104
2105        // Retry commit returns another retry token (seq 3).
2106        // The library should not retry multiple times.
2107        mock.expect_commit()
2108            .once()
2109            .in_sequence(&mut seq)
2110            .returning(|req| {
2111                let req = req.into_inner();
2112                assert_eq!(
2113                    req.precommit_token
2114                        .as_ref()
2115                        .expect("Missing precommit token in retry req")
2116                        .seq_num,
2117                    2
2118                );
2119
2120                Ok(tonic::Response::new(v1::CommitResponse {
2121                    commit_timestamp: Some(prost_types::Timestamp {
2122                        seconds: 9999,
2123                        nanos: 0,
2124                    }),
2125                    multiplexed_session_retry: Some(
2126                        v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
2127                            v1::MultiplexedSessionPrecommitToken {
2128                                precommit_token: vec![3],
2129                                seq_num: 3,
2130                            },
2131                        ),
2132                    ),
2133                    ..Default::default()
2134                }))
2135            });
2136
2137        let (db_client, _server) = setup_db_client(mock).await;
2138        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
2139            .with_begin_transaction_option(begin_transaction_option)
2140            .build(None)
2141            .await?;
2142
2143        if begin_transaction_option == BeginTransactionOption::InlineBegin {
2144            tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
2145                .await?;
2146        }
2147
2148        let ts = tx.commit().await.expect("Failed to commit transaction");
2149        assert_eq!(
2150            ts.commit_timestamp
2151                .as_ref()
2152                .expect("timestamp should be present")
2153                .seconds(),
2154            9999
2155        );
2156        Ok(())
2157    }
2158
2159    #[tokio_test_no_panics]
2160    async fn read_write_transaction_commit_with_max_commit_delay_explicit() -> anyhow::Result<()> {
2161        run_read_write_transaction_commit_with_max_commit_delay(
2162            BeginTransactionOption::ExplicitBegin,
2163        )
2164        .await
2165    }
2166
2167    #[tokio_test_no_panics]
2168    async fn read_write_transaction_commit_with_max_commit_delay_inline() -> anyhow::Result<()> {
2169        run_read_write_transaction_commit_with_max_commit_delay(BeginTransactionOption::InlineBegin)
2170            .await
2171    }
2172
2173    async fn run_read_write_transaction_commit_with_max_commit_delay(
2174        begin_transaction_option: BeginTransactionOption,
2175    ) -> anyhow::Result<()> {
2176        let mut mock = create_session_mock();
2177
2178        if begin_transaction_option == BeginTransactionOption::ExplicitBegin {
2179            mock.expect_begin_transaction().once().returning(|_| {
2180                Ok(tonic::Response::new(v1::Transaction {
2181                    id: vec![1, 2, 3],
2182                    ..Default::default()
2183                }))
2184            });
2185        } else {
2186            mock.expect_execute_sql().once().returning(|req| {
2187                let req = req.into_inner();
2188                let transaction = req
2189                    .transaction
2190                    .as_ref()
2191                    .expect("transaction options required for inline begin");
2192                let selector = transaction.selector.as_ref().expect("selector required");
2193                assert!(matches!(
2194                    selector,
2195                    v1::transaction_selector::Selector::Begin(_)
2196                ));
2197
2198                Ok(tonic::Response::new(v1::ResultSet {
2199                    metadata: Some(v1::ResultSetMetadata {
2200                        transaction: Some(v1::Transaction {
2201                            id: vec![1, 2, 3],
2202                            ..Default::default()
2203                        }),
2204                        ..Default::default()
2205                    }),
2206                    stats: Some(v1::ResultSetStats {
2207                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2208                        ..Default::default()
2209                    }),
2210                    ..Default::default()
2211                }))
2212            });
2213        }
2214
2215        mock.expect_commit().once().returning(|req| {
2216            let req = req.into_inner();
2217            assert_eq!(
2218                req.max_commit_delay,
2219                Some(::prost_types::Duration {
2220                    seconds: 0,
2221                    nanos: 200_000_000, // 200ms
2222                })
2223            );
2224            Ok(tonic::Response::new(v1::CommitResponse {
2225                commit_timestamp: Some(prost_types::Timestamp {
2226                    seconds: 123456789,
2227                    nanos: 0,
2228                }),
2229                ..Default::default()
2230            }))
2231        });
2232
2233        let (db_client, _server) = setup_db_client(mock).await;
2234
2235        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
2236            .set_max_commit_delay(Duration::new(0, 200_000_000).unwrap())
2237            .with_begin_transaction_option(begin_transaction_option)
2238            .build(None)
2239            .await
2240            .expect("Failed to build transaction");
2241
2242        if begin_transaction_option == BeginTransactionOption::InlineBegin {
2243            tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
2244                .await?;
2245        }
2246
2247        let ts = tx.commit().await.expect("Failed to commit");
2248        assert_eq!(
2249            ts.commit_timestamp
2250                .expect("Commit timestamp should be present")
2251                .seconds(),
2252            123456789
2253        );
2254        Ok(())
2255    }
2256
2257    #[tokio_test_no_panics]
2258    async fn read_write_transaction_execute_update_fallback() {
2259        let mut mock = create_session_mock();
2260
2261        // 1. First DML attempt fails!
2262        mock.expect_execute_sql().once().returning(|req| {
2263            let req = req.into_inner();
2264            assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1");
2265
2266            let selector = req
2267                .transaction
2268                .expect("missing transaction selector")
2269                .selector
2270                .expect("missing selector");
2271            match selector {
2272                v1::transaction_selector::Selector::Begin(_) => {}
2273                _ => panic!("Expected Selector::Begin"),
2274            }
2275
2276            Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
2277        });
2278
2279        // 2. Client falls back to explicit BeginTransaction!
2280        mock.expect_begin_transaction().once().returning(|_| {
2281            Ok(tonic::Response::new(v1::Transaction {
2282                id: vec![7, 8, 9],
2283                ..Default::default()
2284            }))
2285        });
2286
2287        // 3. Client retries DML with new ID!
2288        mock.expect_execute_sql().once().returning(|req| {
2289            let req = req.into_inner();
2290            assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1");
2291
2292            let selector = req
2293                .transaction
2294                .expect("missing transaction selector")
2295                .selector
2296                .expect("missing selector");
2297            match selector {
2298                v1::transaction_selector::Selector::Id(id) => {
2299                    assert_eq!(id, vec![7, 8, 9]);
2300                }
2301                _ => panic!("Expected Selector::Id"),
2302            }
2303
2304            Ok(tonic::Response::new(v1::ResultSet {
2305                stats: Some(v1::ResultSetStats {
2306                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2307                    ..Default::default()
2308                }),
2309                ..Default::default()
2310            }))
2311        });
2312
2313        let (db_client, _server) = setup_db_client(mock).await;
2314
2315        let tx = ReadWriteTransactionBuilder::new(db_client.clone())
2316            .build(None)
2317            .await
2318            .expect("Failed to build transaction");
2319
2320        let count = tx
2321            .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1")
2322            .await
2323            .expect("Failed to execute update after fallback");
2324        assert_eq!(count, 1);
2325    }
2326
2327    #[tokio_test_no_panics]
2328    async fn read_write_transaction_execute_batch_update_fallback() -> anyhow::Result<()> {
2329        let mut mock = create_session_mock();
2330
2331        // 1. First Batch DML attempt fails!
2332        mock.expect_execute_batch_dml().once().returning(|req| {
2333            let req = req.into_inner();
2334            let selector = req
2335                .transaction
2336                .expect("missing transaction selector")
2337                .selector
2338                .expect("missing selector");
2339            match selector {
2340                v1::transaction_selector::Selector::Begin(_) => {}
2341                _ => panic!("Expected Selector::Begin"),
2342            }
2343
2344            Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
2345        });
2346
2347        // 2. Client falls back to explicit BeginTransaction!
2348        mock.expect_begin_transaction().once().returning(|_| {
2349            Ok(tonic::Response::new(v1::Transaction {
2350                id: vec![4, 5, 6],
2351                ..Default::default()
2352            }))
2353        });
2354
2355        // 3. Client retries Batch DML with new ID!
2356        mock.expect_execute_batch_dml().once().returning(|req| {
2357            let req = req.into_inner();
2358            let selector = req
2359                .transaction
2360                .expect("missing transaction selector")
2361                .selector
2362                .expect("missing selector");
2363            match selector {
2364                v1::transaction_selector::Selector::Id(id) => {
2365                    assert_eq!(id, vec![4, 5, 6]);
2366                }
2367                _ => panic!("Expected Selector::Id"),
2368            }
2369
2370            Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
2371                result_sets: vec![v1::ResultSet {
2372                    stats: Some(v1::ResultSetStats {
2373                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2374                        ..Default::default()
2375                    }),
2376                    ..Default::default()
2377                }],
2378                status: Some(spanner_grpc_mock::google::rpc::Status {
2379                    code: 0,
2380                    message: "OK".into(),
2381                    details: vec![],
2382                }),
2383                ..Default::default()
2384            }))
2385        });
2386
2387        let (db_client, _server) = setup_db_client(mock).await;
2388
2389        let tx = ReadWriteTransactionBuilder::new(db_client)
2390            .build(None)
2391            .await?;
2392
2393        let batch =
2394            BatchDml::builder().add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1");
2395
2396        let counts = tx.execute_batch_update(batch.build()).await?;
2397
2398        assert_eq!(counts, vec![1]);
2399
2400        Ok(())
2401    }
2402
2403    #[tokio_test_no_panics]
2404    async fn leader_aware_routing_enabled_by_default() -> anyhow::Result<()> {
2405        let mut mock = create_session_mock();
2406        mock.expect_begin_transaction().once().returning(|req| {
2407            assert_eq!(
2408                req.metadata()
2409                    .get("x-goog-spanner-route-to-leader")
2410                    .expect("header required")
2411                    .to_str()
2412                    .unwrap(),
2413                "true"
2414            );
2415            Ok(tonic::Response::new(v1::Transaction {
2416                id: vec![1, 2, 3],
2417                ..Default::default()
2418            }))
2419        });
2420        mock.expect_execute_sql().once().returning(|req| {
2421            assert_eq!(
2422                req.metadata()
2423                    .get("x-goog-spanner-route-to-leader")
2424                    .expect("header required")
2425                    .to_str()
2426                    .unwrap(),
2427                "true"
2428            );
2429            Ok(tonic::Response::new(v1::ResultSet {
2430                stats: Some(v1::ResultSetStats {
2431                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2432                    ..Default::default()
2433                }),
2434                ..Default::default()
2435            }))
2436        });
2437        mock.expect_commit().once().returning(|req| {
2438            assert_eq!(
2439                req.metadata()
2440                    .get("x-goog-spanner-route-to-leader")
2441                    .expect("header required")
2442                    .to_str()
2443                    .unwrap(),
2444                "true"
2445            );
2446            Ok(tonic::Response::new(v1::CommitResponse {
2447                commit_timestamp: Some(prost_types::Timestamp {
2448                    seconds: 1,
2449                    nanos: 0,
2450                }),
2451                ..Default::default()
2452            }))
2453        });
2454
2455        let (db_client, _server) = setup_db_client(mock).await;
2456        let tx = ReadWriteTransactionBuilder::new(db_client)
2457            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2458            .build(None)
2459            .await?;
2460        let count = tx.execute_update("UPDATE Users SET active = true").await?;
2461        assert_eq!(count, 1);
2462        let _ = tx.commit().await?;
2463        Ok(())
2464    }
2465
2466    #[tokio_test_no_panics]
2467    async fn leader_aware_routing_disabled() -> anyhow::Result<()> {
2468        use crate::client::Spanner;
2469        use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
2470
2471        let mut mock = create_session_mock();
2472        mock.expect_begin_transaction().once().returning(|req| {
2473            assert!(
2474                req.metadata()
2475                    .get("x-goog-spanner-route-to-leader")
2476                    .is_none()
2477            );
2478            Ok(tonic::Response::new(v1::Transaction {
2479                id: vec![1, 2, 3],
2480                ..Default::default()
2481            }))
2482        });
2483        mock.expect_execute_sql().once().returning(|req| {
2484            assert!(
2485                req.metadata()
2486                    .get("x-goog-spanner-route-to-leader")
2487                    .is_none()
2488            );
2489            Ok(tonic::Response::new(v1::ResultSet {
2490                stats: Some(v1::ResultSetStats {
2491                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2492                    ..Default::default()
2493                }),
2494                ..Default::default()
2495            }))
2496        });
2497        mock.expect_commit().once().returning(|req| {
2498            assert!(
2499                req.metadata()
2500                    .get("x-goog-spanner-route-to-leader")
2501                    .is_none()
2502            );
2503            Ok(tonic::Response::new(v1::CommitResponse {
2504                commit_timestamp: Some(prost_types::Timestamp {
2505                    seconds: 1,
2506                    nanos: 0,
2507                }),
2508                ..Default::default()
2509            }))
2510        });
2511
2512        let (address, _server) = spanner_grpc_mock::start("0.0.0.0:0", mock).await.unwrap();
2513        let spanner = Spanner::builder()
2514            .with_endpoint(address)
2515            .with_credentials(Anonymous::new().build())
2516            .build()
2517            .await?;
2518        let db_client = spanner
2519            .database_client("projects/p/instances/i/databases/d")
2520            .with_leader_aware_routing(false)
2521            .build()
2522            .await?;
2523
2524        let tx = ReadWriteTransactionBuilder::new(db_client)
2525            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2526            .build(None)
2527            .await?;
2528        let count = tx.execute_update("UPDATE Users SET active = true").await?;
2529        assert_eq!(count, 1);
2530        let _ = tx.commit().await?;
2531        Ok(())
2532    }
2533
2534    #[tokio_test_no_panics]
2535    async fn leader_aware_routing_query_in_read_write() -> anyhow::Result<()> {
2536        let mut mock = create_session_mock();
2537        mock.expect_begin_transaction().once().returning(|req| {
2538            assert_eq!(
2539                req.metadata()
2540                    .get("x-goog-spanner-route-to-leader")
2541                    .expect("header required")
2542                    .to_str()
2543                    .unwrap(),
2544                "true"
2545            );
2546            Ok(tonic::Response::new(v1::Transaction {
2547                id: vec![1, 2, 3],
2548                ..Default::default()
2549            }))
2550        });
2551        mock.expect_execute_streaming_sql().once().returning(|req| {
2552            assert_eq!(
2553                req.metadata()
2554                    .get("x-goog-spanner-route-to-leader")
2555                    .expect("header required")
2556                    .to_str()
2557                    .unwrap(),
2558                "true"
2559            );
2560            let stream = adapt([Ok(v1::PartialResultSet {
2561                metadata: Some(v1::ResultSetMetadata {
2562                    row_type: Some(v1::StructType { fields: vec![] }),
2563                    ..Default::default()
2564                }),
2565                ..Default::default()
2566            })]);
2567            Ok(tonic::Response::from(stream))
2568        });
2569
2570        let (db_client, _server) = setup_db_client(mock).await;
2571        let tx = ReadWriteTransactionBuilder::new(db_client)
2572            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2573            .build(None)
2574            .await?;
2575        let _rs = tx
2576            .execute_query(Statement::builder("SELECT 1").build())
2577            .await?;
2578        Ok(())
2579    }
2580
2581    #[tokio_test_no_panics]
2582    async fn leader_aware_routing_merges_custom_headers() -> anyhow::Result<()> {
2583        let mut mock = create_session_mock();
2584        mock.expect_begin_transaction().once().returning(|req| {
2585            assert_eq!(
2586                req.metadata()
2587                    .get("x-goog-spanner-route-to-leader")
2588                    .expect("header required")
2589                    .to_str()
2590                    .unwrap(),
2591                "true"
2592            );
2593            Ok(tonic::Response::new(v1::Transaction {
2594                id: vec![1, 2, 3],
2595                ..Default::default()
2596            }))
2597        });
2598        mock.expect_execute_sql().once().returning(|req| {
2599            assert_eq!(
2600                req.metadata()
2601                    .get("x-goog-spanner-route-to-leader")
2602                    .expect("header required")
2603                    .to_str()
2604                    .unwrap(),
2605                "true"
2606            );
2607            assert_eq!(
2608                req.metadata()
2609                    .get("x-custom-user-header")
2610                    .expect("custom header required")
2611                    .to_str()
2612                    .unwrap(),
2613                "custom-value"
2614            );
2615            Ok(tonic::Response::new(v1::ResultSet {
2616                stats: Some(v1::ResultSetStats {
2617                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2618                    ..Default::default()
2619                }),
2620                ..Default::default()
2621            }))
2622        });
2623
2624        let (db_client, _server) = setup_db_client(mock).await;
2625        let tx = ReadWriteTransactionBuilder::new(db_client)
2626            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
2627            .build(None)
2628            .await?;
2629
2630        let mut custom_headers = http::HeaderMap::new();
2631        custom_headers.insert(
2632            "x-custom-user-header",
2633            http::HeaderValue::from_static("custom-value"),
2634        );
2635
2636        let mut stmt = Statement::builder("UPDATE Users SET active = true").build();
2637        let opts = stmt.gax_options().clone().insert_extension(custom_headers);
2638        stmt = stmt.with_gax_options(opts);
2639
2640        let count = tx.execute_update(stmt).await?;
2641        assert_eq!(count, 1);
2642        Ok(())
2643    }
2644
2645    #[tokio_test_no_panics]
2646    async fn leader_aware_routing_implicit_begin_fallback() -> anyhow::Result<()> {
2647        let mut mock = create_session_mock();
2648
2649        // 1. Initial execute_sql attempts implicit begin and transiently fails.
2650        // It must include the LAR header because it is a modifying operation.
2651        mock.expect_execute_sql().once().returning(|req| {
2652            assert_eq!(
2653                req.metadata()
2654                    .get("x-goog-spanner-route-to-leader")
2655                    .expect("header required on initial execute")
2656                    .to_str()
2657                    .unwrap(),
2658                "true"
2659            );
2660            Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
2661        });
2662
2663        // 2. Client fallback mechanism invokes begin_explicitly_if_not_started.
2664        // This should also include the LAR header.
2665        mock.expect_begin_transaction().once().returning(|req| {
2666            assert_eq!(
2667                req.metadata()
2668                    .get("x-goog-spanner-route-to-leader")
2669                    .expect("header required on explicit begin fallback")
2670                    .to_str()
2671                    .unwrap(),
2672                "true"
2673            );
2674            Ok(tonic::Response::new(v1::Transaction {
2675                id: vec![42],
2676                ..Default::default()
2677            }))
2678        });
2679
2680        // 3. Retried execute_sql with fixed ID.
2681        mock.expect_execute_sql().once().returning(|req| {
2682            assert_eq!(
2683                req.metadata()
2684                    .get("x-goog-spanner-route-to-leader")
2685                    .expect("header required on retried execute")
2686                    .to_str()
2687                    .unwrap(),
2688                "true"
2689            );
2690            Ok(tonic::Response::new(v1::ResultSet {
2691                stats: Some(v1::ResultSetStats {
2692                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2693                    ..Default::default()
2694                }),
2695                ..Default::default()
2696            }))
2697        });
2698
2699        let (db_client, _server) = setup_db_client(mock).await;
2700        // Construct transaction using implicit begin (explicit_begin = false)
2701        let tx = ReadWriteTransactionBuilder::new(db_client)
2702            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2703            .build(None)
2704            .await?;
2705
2706        let count = tx.execute_update("UPDATE Users SET active = true").await?;
2707        assert_eq!(count, 1);
2708        Ok(())
2709    }
2710
2711    #[tokio_test_no_panics]
2712    async fn read_write_transaction_fallback_forwards_transaction_tag() -> anyhow::Result<()> {
2713        let mut mock = create_session_mock();
2714
2715        // 1. Initial execute_sql attempts inline begin and transiently fails.
2716        // The initial request includes the transaction tag.
2717        mock.expect_execute_sql().once().returning(|req| {
2718            let req = req.into_inner();
2719            assert_eq!(
2720                req.request_options
2721                    .as_ref()
2722                    .expect("Missing request_options on initial RPC")
2723                    .transaction_tag,
2724                "fallback-test-tag"
2725            );
2726            Err(tonic::Status::new(tonic::Code::Internal, "internal error"))
2727        });
2728
2729        // 2. Client fallback mechanism invokes explicit begin.
2730        // This should include the transaction tag.
2731        mock.expect_begin_transaction().once().returning(|req| {
2732            let req = req.into_inner();
2733            assert_eq!(
2734                req.request_options
2735                    .as_ref()
2736                    .expect("Missing request_options on explicit begin fallback")
2737                    .transaction_tag,
2738                "fallback-test-tag"
2739            );
2740            Ok(tonic::Response::new(v1::Transaction {
2741                id: vec![7, 7, 7],
2742                ..Default::default()
2743            }))
2744        });
2745
2746        // 3. Retried execute_sql with the explicit transaction ID.
2747        mock.expect_execute_sql().once().returning(|req| {
2748            let req = req.into_inner();
2749            assert_eq!(
2750                req.request_options
2751                    .as_ref()
2752                    .expect("Missing request_options on retried RPC")
2753                    .transaction_tag,
2754                "fallback-test-tag"
2755            );
2756            Ok(tonic::Response::new(v1::ResultSet {
2757                stats: Some(v1::ResultSetStats {
2758                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2759                    ..Default::default()
2760                }),
2761                ..Default::default()
2762            }))
2763        });
2764
2765        let (db_client, _server) = setup_db_client(mock).await;
2766        let tx = ReadWriteTransactionBuilder::new(db_client)
2767            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2768            .set_transaction_tag("fallback-test-tag")
2769            .build(None)
2770            .await?;
2771
2772        let count = tx.execute_update("UPDATE Users SET active = true").await?;
2773        assert_eq!(count, 1);
2774        Ok(())
2775    }
2776
2777    #[tokio_test_no_panics]
2778    async fn read_write_transaction_mutation_only_inline_begin_commit() -> anyhow::Result<()> {
2779        let mut mock = create_session_mock();
2780
2781        // Since no statement was executed, commit will detect NotStarted and call begin_explicitly
2782        mock.expect_begin_transaction().once().returning(|req| {
2783            let req = req.into_inner();
2784            assert_eq!(
2785                req.session,
2786                "projects/p/instances/i/databases/d/sessions/123"
2787            );
2788            assert!(
2789                req.mutation_key.is_some(),
2790                "mutation_key should be populated when starting transaction at commit time"
2791            );
2792            let key = req
2793                .mutation_key
2794                .as_ref()
2795                .expect("mutation_key is populated");
2796            assert!(
2797                key.operation.is_some(),
2798                "mutation_key should have an operation"
2799            );
2800            Ok(tonic::Response::new(v1::Transaction {
2801                id: vec![7, 7, 7],
2802                ..Default::default()
2803            }))
2804        });
2805
2806        mock.expect_commit().once().returning(|req| {
2807            let req = req.into_inner();
2808            assert_eq!(
2809                req.session,
2810                "projects/p/instances/i/databases/d/sessions/123"
2811            );
2812            assert_eq!(
2813                req.transaction,
2814                Some(v1::commit_request::Transaction::TransactionId(vec![
2815                    7, 7, 7
2816                ]))
2817            );
2818            assert_eq!(req.mutations.len(), 1);
2819            Ok(tonic::Response::new(v1::CommitResponse {
2820                commit_timestamp: Some(prost_types::Timestamp {
2821                    seconds: 5000,
2822                    nanos: 0,
2823                }),
2824                ..Default::default()
2825            }))
2826        });
2827
2828        let (db_client, _server) = setup_db_client(mock).await;
2829        let tx = ReadWriteTransactionBuilder::new(db_client)
2830            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
2831            .build(None)
2832            .await
2833            .expect("Transaction build should succeed");
2834
2835        let mutation = Mutation::new_insert_builder("Users")
2836            .set("UserId")
2837            .to(&1)
2838            .build();
2839        tx.buffer([mutation]).expect("Buffer should succeed");
2840
2841        let response = tx.commit().await.expect("Commit should succeed");
2842        assert_eq!(
2843            response
2844                .commit_timestamp
2845                .expect("timestamp present")
2846                .seconds(),
2847            5000
2848        );
2849        Ok(())
2850    }
2851
2852    #[tokio_test_no_panics]
2853    async fn transaction_runner_batch_dml_aborted_retry() -> anyhow::Result<()> {
2854        let mut mock = create_session_mock();
2855        let mut sequence = mockall::Sequence::new();
2856
2857        // 1. First attempt: Inline begin, execute_batch_dml returns OK with status Aborted.
2858        mock.expect_execute_batch_dml()
2859            .times(1)
2860            .in_sequence(&mut sequence)
2861            .returning(|req| {
2862                let req = req.into_inner();
2863                assert!(matches!(
2864                    req.transaction.unwrap().selector.unwrap(),
2865                    v1::transaction_selector::Selector::Begin(_)
2866                ));
2867                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
2868                    result_sets: vec![],
2869                    status: Some(spanner_grpc_mock::google::rpc::Status {
2870                        code: tonic::Code::Aborted as i32,
2871                        message: "concurrent lock abort".into(),
2872                        details: vec![],
2873                    }),
2874                    ..Default::default()
2875                }))
2876            });
2877
2878        // 2. TransactionRunner catches Aborted error and initiates attempt 2.
2879        mock.expect_execute_batch_dml()
2880            .times(1)
2881            .in_sequence(&mut sequence)
2882            .returning(|req| {
2883                let req = req.into_inner();
2884                assert!(matches!(
2885                    req.transaction.unwrap().selector.unwrap(),
2886                    v1::transaction_selector::Selector::Begin(_)
2887                ));
2888                Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse {
2889                    result_sets: vec![v1::ResultSet {
2890                        metadata: Some(v1::ResultSetMetadata {
2891                            transaction: Some(v1::Transaction {
2892                                id: vec![9, 9, 9],
2893                                ..Default::default()
2894                            }),
2895                            ..Default::default()
2896                        }),
2897                        stats: Some(v1::ResultSetStats {
2898                            row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2899                            ..Default::default()
2900                        }),
2901                        ..Default::default()
2902                    }],
2903                    status: Some(spanner_grpc_mock::google::rpc::Status {
2904                        code: 0,
2905                        message: "OK".into(),
2906                        details: vec![],
2907                    }),
2908                    ..Default::default()
2909                }))
2910            });
2911
2912        mock.expect_commit().once().returning(|req| {
2913            let req = req.into_inner();
2914            assert_eq!(
2915                req.transaction,
2916                Some(v1::commit_request::Transaction::TransactionId(vec![
2917                    9, 9, 9
2918                ]))
2919            );
2920            Ok(tonic::Response::new(v1::CommitResponse {
2921                commit_timestamp: Some(prost_types::Timestamp {
2922                    seconds: 999,
2923                    nanos: 0,
2924                }),
2925                ..Default::default()
2926            }))
2927        });
2928
2929        let (db_client, _server) = setup_db_client(mock).await;
2930
2931        let runner = db_client
2932            .read_write_transaction()
2933            .with_retry_policy(
2934                BasicTransactionRetryPolicy::new()
2935                    .with_max_attempts(3)
2936                    .with_total_timeout(std::time::Duration::from_secs(5)),
2937            )
2938            .build()
2939            .await?;
2940
2941        runner
2942            .run(async |tx| {
2943                let batch = BatchDml::builder()
2944                    .add_statement("UPDATE Users SET active = true WHERE id = 1");
2945                tx.execute_batch_update(batch.build()).await?;
2946                Ok(())
2947            })
2948            .await?;
2949
2950        Ok(())
2951    }
2952
2953    #[tokio_test_no_panics]
2954    async fn read_write_transaction_first_dml_aborted_and_continue_success() -> anyhow::Result<()> {
2955        let mut mock = create_session_mock();
2956        let mut sequence = mockall::Sequence::new();
2957
2958        // 1. First statement (execute_sql) attempts inline begin and is aborted by Spanner
2959        mock.expect_execute_sql()
2960            .times(1)
2961            .in_sequence(&mut sequence)
2962            .returning(|req| {
2963                let req = req.into_inner();
2964                assert!(matches!(
2965                    req.transaction.unwrap().selector.unwrap(),
2966                    v1::transaction_selector::Selector::Begin(_)
2967                ));
2968                Err(tonic::Status::new(
2969                    tonic::Code::Aborted,
2970                    "concurrent lock abort",
2971                ))
2972            });
2973
2974        // 2. Second statement (execute_sql) sees NotStarted and attempts inline begin again
2975        mock.expect_execute_sql()
2976            .times(1)
2977            .in_sequence(&mut sequence)
2978            .returning(|req| {
2979                let req = req.into_inner();
2980                assert!(matches!(
2981                    req.transaction.unwrap().selector.unwrap(),
2982                    v1::transaction_selector::Selector::Begin(_)
2983                ));
2984                Ok(tonic::Response::new(v1::ResultSet {
2985                    metadata: Some(v1::ResultSetMetadata {
2986                        transaction: Some(v1::Transaction {
2987                            id: vec![9, 9, 9],
2988                            ..Default::default()
2989                        }),
2990                        ..Default::default()
2991                    }),
2992                    stats: Some(v1::ResultSetStats {
2993                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
2994                        ..Default::default()
2995                    }),
2996                    ..Default::default()
2997                }))
2998            });
2999
3000        // 3. Commit called with the transaction ID returned in step 2
3001        mock.expect_commit().once().returning(|req| {
3002            let req = req.into_inner();
3003            assert_eq!(
3004                req.transaction,
3005                Some(v1::commit_request::Transaction::TransactionId(vec![
3006                    9, 9, 9
3007                ]))
3008            );
3009            Ok(tonic::Response::new(v1::CommitResponse {
3010                commit_timestamp: Some(prost_types::Timestamp {
3011                    seconds: 999,
3012                    nanos: 0,
3013                }),
3014                ..Default::default()
3015            }))
3016        });
3017
3018        let (db_client, _server) = setup_db_client(mock).await;
3019
3020        let runner = db_client
3021            .read_write_transaction()
3022            .with_retry_policy(
3023                BasicTransactionRetryPolicy::new()
3024                    .with_max_attempts(1)
3025                    .with_total_timeout(std::time::Duration::from_secs(5)),
3026            )
3027            .build()
3028            .await?;
3029
3030        runner
3031            .run(async |tx| {
3032                // 1. First statement fails with Aborted. We catch it and continue.
3033                let res = tx
3034                    .execute_update("UPDATE Users SET active = true WHERE id = 1")
3035                    .await;
3036                assert!(res.is_err(), "First statement must return error");
3037                assert!(is_aborted(&res.unwrap_err()), "Error must be Aborted");
3038
3039                // 2. Second statement continues. Without the fix, this would block/deadlock forever.
3040                let count = tx
3041                    .execute_update("UPDATE Users SET active = true WHERE id = 2")
3042                    .await?;
3043                assert_eq!(count, 1);
3044                Ok(())
3045            })
3046            .await?;
3047
3048        Ok(())
3049    }
3050
3051    #[tokio_test_no_panics]
3052    async fn read_write_transaction_first_batch_dml_aborted_and_continue_success()
3053    -> anyhow::Result<()> {
3054        let mut mock = create_session_mock();
3055        let mut sequence = mockall::Sequence::new();
3056
3057        // 1. First statement (execute_batch_dml) attempts inline begin and is aborted by Spanner
3058        mock.expect_execute_batch_dml()
3059            .times(1)
3060            .in_sequence(&mut sequence)
3061            .returning(|req| {
3062                let req = req.into_inner();
3063                assert!(matches!(
3064                    req.transaction.unwrap().selector.unwrap(),
3065                    v1::transaction_selector::Selector::Begin(_)
3066                ));
3067                Err(tonic::Status::new(
3068                    tonic::Code::Aborted,
3069                    "concurrent lock abort",
3070                ))
3071            });
3072
3073        // 2. Second statement (execute_sql) sees NotStarted and attempts inline begin again
3074        mock.expect_execute_sql()
3075            .times(1)
3076            .in_sequence(&mut sequence)
3077            .returning(|req| {
3078                let req = req.into_inner();
3079                assert!(matches!(
3080                    req.transaction.unwrap().selector.unwrap(),
3081                    v1::transaction_selector::Selector::Begin(_)
3082                ));
3083                Ok(tonic::Response::new(v1::ResultSet {
3084                    metadata: Some(v1::ResultSetMetadata {
3085                        transaction: Some(v1::Transaction {
3086                            id: vec![9, 9, 9],
3087                            ..Default::default()
3088                        }),
3089                        ..Default::default()
3090                    }),
3091                    stats: Some(v1::ResultSetStats {
3092                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
3093                        ..Default::default()
3094                    }),
3095                    ..Default::default()
3096                }))
3097            });
3098
3099        // 3. Commit called with the transaction ID returned in step 2
3100        mock.expect_commit().once().returning(|req| {
3101            let req = req.into_inner();
3102            assert_eq!(
3103                req.transaction,
3104                Some(v1::commit_request::Transaction::TransactionId(vec![
3105                    9, 9, 9
3106                ]))
3107            );
3108            Ok(tonic::Response::new(v1::CommitResponse {
3109                commit_timestamp: Some(prost_types::Timestamp {
3110                    seconds: 999,
3111                    nanos: 0,
3112                }),
3113                ..Default::default()
3114            }))
3115        });
3116
3117        let (db_client, _server) = setup_db_client(mock).await;
3118
3119        let runner = db_client
3120            .read_write_transaction()
3121            .with_retry_policy(
3122                BasicTransactionRetryPolicy::new()
3123                    .with_max_attempts(1)
3124                    .with_total_timeout(std::time::Duration::from_secs(5)),
3125            )
3126            .build()
3127            .await?;
3128
3129        runner
3130            .run(async |tx| {
3131                // 1. First statement (Batch DML) fails with Aborted. We catch it and continue.
3132                let batch = BatchDml::builder()
3133                    .add_statement("UPDATE Users SET active = true WHERE id = 1");
3134                let res = tx.execute_batch_update(batch.build()).await;
3135                assert!(res.is_err(), "First statement must return error");
3136                assert!(is_aborted(&res.unwrap_err()), "Error must be Aborted");
3137
3138                // 2. Second statement continues. Without the fix, this would block/deadlock forever.
3139                let count = tx
3140                    .execute_update("UPDATE Users SET active = true WHERE id = 2")
3141                    .await?;
3142                assert_eq!(count, 1);
3143                Ok(())
3144            })
3145            .await?;
3146
3147        Ok(())
3148    }
3149
3150    fn parse_grpc_timeout(metadata: &MetadataMap) -> Option<StdDuration> {
3151        let timeout_header = metadata.get("grpc-timeout")?.to_str().ok()?;
3152        let numeric_part: String = timeout_header
3153            .chars()
3154            .take_while(|c| c.is_ascii_digit())
3155            .collect();
3156        let value = numeric_part.parse::<u64>().ok()?;
3157        let unit = timeout_header.trim_start_matches(&numeric_part);
3158        let duration = match unit {
3159            "u" => StdDuration::from_micros(value),
3160            "m" => StdDuration::from_millis(value),
3161            "S" => StdDuration::from_secs(value),
3162            "M" => StdDuration::from_secs(value * 60),
3163            "H" => StdDuration::from_secs(value * 3600),
3164            _ => return None,
3165        };
3166        Some(duration)
3167    }
3168
3169    #[tokio_test_no_panics]
3170    async fn read_write_transaction_lazy_begin_fallback_never_retry() -> anyhow::Result<()> {
3171        let mut mock = create_session_mock();
3172        let mut sequence = mockall::Sequence::new();
3173
3174        // 1. First statement execution uses inline-begin and fails with Unavailable (transient error)
3175        mock.expect_execute_sql()
3176            .once()
3177            .in_sequence(&mut sequence)
3178            .withf(|req| {
3179                matches!(
3180                    req.get_ref()
3181                        .transaction
3182                        .as_ref()
3183                        .and_then(|t| t.selector.as_ref()),
3184                    Some(v1::transaction_selector::Selector::Begin(_))
3185                )
3186            })
3187            .returning(move |_req| Err(tonic::Status::unavailable("transient error")));
3188
3189        // 2. Fallback explicit BeginTransaction is executed exactly once and fails (because we configure NeverRetry)
3190        mock.expect_begin_transaction()
3191            .once()
3192            .in_sequence(&mut sequence)
3193            .returning(move |_req| Err(tonic::Status::unavailable("transient error")));
3194
3195        let (db_client, _server) = setup_db_client(mock).await;
3196
3197        let runner = db_client
3198            .read_write_transaction()
3199            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
3200            .with_begin_retry_policy(NeverRetry)
3201            .build()
3202            .await?;
3203
3204        let res = runner
3205            .run(async |tx| {
3206                let mut stmt_opts = crate::RequestOptions::default();
3207                stmt_opts.set_retry_policy(NeverRetry);
3208                let stmt = Statement::builder("UPDATE Users SET active = true WHERE id = 1")
3209                    .build()
3210                    .with_gax_options(stmt_opts);
3211                let _count = tx.execute_update(stmt).await?;
3212                Ok(())
3213            })
3214            .await;
3215
3216        assert!(
3217            res.is_err(),
3218            "Should fail immediately because NeverRetry aborted retries of explicit begin"
3219        );
3220        let err = res.unwrap_err();
3221        assert_eq!(err.status().map(|s| s.code), Some(Code::Unavailable));
3222
3223        Ok(())
3224    }
3225
3226    #[tokio_test_no_panics]
3227    async fn read_write_transaction_commit_under_deadline_delegates_to_custom_retry_policy()
3228    -> anyhow::Result<()> {
3229        let mut mock = create_session_mock();
3230        mock.expect_begin_transaction().once().returning(|_| {
3231            Ok(tonic::Response::new(v1::Transaction {
3232                id: vec![8, 8, 8],
3233                ..Default::default()
3234            }))
3235        });
3236
3237        // Commit fails with Unavailable. Since we use NeverRetry, it must fail immediately without retry.
3238        mock.expect_commit()
3239            .once()
3240            .returning(|_| Err(tonic::Status::unavailable("transient error")));
3241
3242        let (db_client, _server) = setup_db_client(mock).await;
3243
3244        let runner = db_client
3245            .read_write_transaction()
3246            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
3247            .with_commit_retry_policy(NeverRetry)
3248            .with_transaction_timeout(StdDuration::from_secs(5))
3249            .build()
3250            .await?;
3251
3252        let res = runner.run(async |_tx| Ok(())).await;
3253
3254        assert!(
3255            res.is_err(),
3256            "Should fail because NeverRetry aborted retries"
3257        );
3258        let err = res.unwrap_err();
3259        assert_eq!(
3260            err.status().map(|s| s.code),
3261            Some(Code::Unavailable),
3262            "Error code should be Unavailable"
3263        );
3264        Ok(())
3265    }
3266
3267    #[tokio_test_no_panics]
3268    async fn read_write_transaction_commit_timeout_combination() -> anyhow::Result<()> {
3269        let mut mock = create_session_mock();
3270        mock.expect_begin_transaction().once().returning(|_| {
3271            Ok(tonic::Response::new(v1::Transaction {
3272                id: vec![8, 8, 8],
3273                ..Default::default()
3274            }))
3275        });
3276
3277        // Assert that the commit attempt timeout of 2 seconds propagates as the gRPC timeout header metadata (approx 2000m/2000000u).
3278        mock.expect_commit()
3279            .once()
3280            .withf(|req| {
3281                let duration =
3282                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
3283                assert_eq!(
3284                    duration,
3285                    StdDuration::from_secs(2),
3286                    "Timeout duration should be exactly 2 seconds"
3287                );
3288                true
3289            })
3290            .returning(|_| {
3291                Ok(tonic::Response::new(v1::CommitResponse {
3292                    commit_timestamp: Some(Timestamp {
3293                        seconds: 999,
3294                        nanos: 0,
3295                    }),
3296                    ..Default::default()
3297                }))
3298            });
3299
3300        let (db_client, _server) = setup_db_client(mock).await;
3301
3302        let runner = db_client
3303            .read_write_transaction()
3304            .with_begin_transaction_option(BeginTransactionOption::ExplicitBegin)
3305            .with_commit_attempt_timeout(StdDuration::from_secs(2))
3306            .with_transaction_timeout(StdDuration::from_secs(10))
3307            .build()
3308            .await?;
3309
3310        let res = runner.run(async |_tx| Ok(())).await?;
3311
3312        assert!(res.commit_response.commit_timestamp.is_some());
3313        Ok(())
3314    }
3315
3316    #[tokio_test_no_panics]
3317    async fn read_write_transaction_fallback_begin_under_deadline() -> anyhow::Result<()> {
3318        let mut mock = create_session_mock();
3319        let mut sequence = mockall::Sequence::new();
3320
3321        // 1. First statement execution fails with Unavailable (transient error)
3322        mock.expect_execute_sql()
3323            .once()
3324            .in_sequence(&mut sequence)
3325            .withf(|req| {
3326                matches!(
3327                    req.get_ref()
3328                        .transaction
3329                        .as_ref()
3330                        .and_then(|t| t.selector.as_ref()),
3331                    Some(v1::transaction_selector::Selector::Begin(_))
3332                )
3333            })
3334            .returning(move |_req| Err(tonic::Status::unavailable("transient error")));
3335
3336        // 2. Fallback explicit BeginTransaction is executed and sets attempt timeout based on remaining transaction deadline (approx 5 seconds).
3337        mock.expect_begin_transaction()
3338            .once()
3339            .in_sequence(&mut sequence)
3340            .withf(|req| {
3341                let duration =
3342                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
3343                assert!(
3344                    duration >= StdDuration::from_millis(4000)
3345                        && duration <= StdDuration::from_millis(6000),
3346                    "Fallback begin timeout is wrong: {:?}",
3347                    duration
3348                );
3349                true
3350            })
3351            .returning(move |_req| {
3352                Ok(tonic::Response::new(v1::Transaction {
3353                    id: vec![42],
3354                    ..Default::default()
3355                }))
3356            });
3357
3358        // 3. Statement retry succeeds
3359        mock.expect_execute_sql()
3360            .once()
3361            .in_sequence(&mut sequence)
3362            .withf(|req| {
3363                matches!(
3364                    req.get_ref()
3365                        .transaction
3366                        .as_ref()
3367                        .and_then(|t| t.selector.as_ref()),
3368                    Some(v1::transaction_selector::Selector::Id(_))
3369                )
3370            })
3371            .returning(move |_req| {
3372                Ok(tonic::Response::new(v1::ResultSet {
3373                    metadata: Some(v1::ResultSetMetadata {
3374                        transaction: Some(v1::Transaction {
3375                            id: vec![42],
3376                            ..Default::default()
3377                        }),
3378                        ..Default::default()
3379                    }),
3380                    stats: Some(v1::ResultSetStats {
3381                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)),
3382                        ..Default::default()
3383                    }),
3384                    ..Default::default()
3385                }))
3386            });
3387
3388        // 4. Commit succeeds
3389        mock.expect_commit()
3390            .once()
3391            .in_sequence(&mut sequence)
3392            .returning(move |_req| {
3393                Ok(tonic::Response::new(v1::CommitResponse {
3394                    commit_timestamp: Some(Timestamp {
3395                        seconds: 1234,
3396                        nanos: 0,
3397                    }),
3398                    ..Default::default()
3399                }))
3400            });
3401
3402        let (db_client, _server) = setup_db_client(mock).await;
3403
3404        let runner = db_client
3405            .read_write_transaction()
3406            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
3407            .with_transaction_timeout(StdDuration::from_secs(5))
3408            .build()
3409            .await?;
3410
3411        let res = runner
3412            .run(async |tx| {
3413                let mut query_opts = crate::RequestOptions::default();
3414                query_opts.set_retry_policy(NeverRetry);
3415                let stmt = Statement::builder("UPDATE Users SET active = true WHERE id = 1")
3416                    .build()
3417                    .with_gax_options(query_opts);
3418                let count = tx.execute_update(stmt).await?;
3419                assert_eq!(count, 1);
3420                Ok(())
3421            })
3422            .await?;
3423
3424        assert!(res.commit_response.commit_timestamp.is_some());
3425        Ok(())
3426    }
3427
3428    #[tokio_test_no_panics]
3429    async fn read_write_transaction_commit_fallback_begin_under_deadline() -> anyhow::Result<()> {
3430        let mut mock = create_session_mock();
3431        let mut sequence = mockall::Sequence::new();
3432
3433        // 1. Transaction was never started (empty runner block), so commit falls back to explicit BeginTransaction.
3434        // Assert fallback explicit BeginTransaction sets timeout based on remaining transaction deadline (approx 5 seconds).
3435        mock.expect_begin_transaction()
3436            .once()
3437            .in_sequence(&mut sequence)
3438            .withf(|req| {
3439                let duration =
3440                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
3441                assert!(
3442                    duration >= StdDuration::from_millis(4000)
3443                        && duration <= StdDuration::from_millis(6000),
3444                    "Fallback begin timeout inside commit is wrong: {:?}",
3445                    duration
3446                );
3447                true
3448            })
3449            .returning(move |_req| {
3450                Ok(tonic::Response::new(v1::Transaction {
3451                    id: vec![42],
3452                    ..Default::default()
3453                }))
3454            });
3455
3456        // 2. Commit succeeds
3457        mock.expect_commit()
3458            .once()
3459            .in_sequence(&mut sequence)
3460            .returning(move |_req| {
3461                Ok(tonic::Response::new(v1::CommitResponse {
3462                    commit_timestamp: Some(Timestamp {
3463                        seconds: 5678,
3464                        nanos: 0,
3465                    }),
3466                    ..Default::default()
3467                }))
3468            });
3469
3470        let (db_client, _server) = setup_db_client(mock).await;
3471
3472        let runner = db_client
3473            .read_write_transaction()
3474            .with_begin_transaction_option(BeginTransactionOption::InlineBegin)
3475            .with_transaction_timeout(StdDuration::from_secs(5))
3476            .build()
3477            .await?;
3478
3479        let res = runner.run(async |_tx| Ok(())).await?;
3480
3481        assert!(res.commit_response.commit_timestamp.is_some());
3482        Ok(())
3483    }
3484
3485    #[test]
3486    fn test_amend_gax_options() {
3487        // Case 1: No Deadline, LAR disabled
3488        let mut options = RequestOptions::default();
3489        options.set_attempt_timeout(StdDuration::from_secs(4));
3490        amend_gax_options(false, None, &mut options);
3491        assert_eq!(*options.attempt_timeout(), Some(StdDuration::from_secs(4)));
3492        assert!(options.retry_policy().is_none());
3493
3494        // Case 2: No Deadline, LAR enabled
3495        let mut options = RequestOptions::default();
3496        amend_gax_options(true, None, &mut options);
3497        // Verify LAR extension is added
3498        let headers = options
3499            .get_extension::<HeaderMap>()
3500            .expect("HeaderMap extension missing");
3501        assert_eq!(
3502            headers
3503                .get("x-goog-spanner-route-to-leader")
3504                .unwrap()
3505                .to_str()
3506                .unwrap(),
3507            "true"
3508        );
3509
3510        // Case 3: Deadline present, no custom timeout.
3511        // Since Instant::now() is called inside amend_gax_options slightly after the test's
3512        // Instant::now() call, the remaining time will be slightly less than 5 seconds.
3513        // Therefore, we assert that it falls within a very close range.
3514        let mut options = RequestOptions::default();
3515        let deadline = Instant::now() + StdDuration::from_secs(5);
3516        amend_gax_options(false, Some(deadline), &mut options);
3517        let timeout = options.attempt_timeout().expect("attempt timeout missing");
3518        assert!(
3519            timeout >= StdDuration::from_millis(4500) && timeout <= StdDuration::from_millis(5500)
3520        );
3521        assert!(
3522            options.retry_policy().is_some(),
3523            "retry policy should be wrapped"
3524        );
3525
3526        // Case 4: Deadline present, custom timeout shorter than deadline.
3527        // Since custom timeout is 2s and remaining deadline is 10s, it does not depend
3528        // on Time/Instant and must be exactly 2s.
3529        let mut options = RequestOptions::default();
3530        options.set_attempt_timeout(StdDuration::from_secs(2));
3531        let deadline = Instant::now() + StdDuration::from_secs(10);
3532        amend_gax_options(false, Some(deadline), &mut options);
3533        assert_eq!(*options.attempt_timeout(), Some(StdDuration::from_secs(2)));
3534
3535        // Case 5: Deadline present, custom timeout longer than deadline.
3536        // The remaining deadline (approx 2 seconds) is shorter than custom timeout (10s).
3537        // Due to slight time passing, remaining will be slightly less than 2 seconds.
3538        let mut options = RequestOptions::default();
3539        options.set_attempt_timeout(StdDuration::from_secs(10));
3540        let deadline = Instant::now() + StdDuration::from_secs(2);
3541        amend_gax_options(false, Some(deadline), &mut options);
3542        let timeout = options.attempt_timeout().expect("attempt timeout missing");
3543        assert!(
3544            timeout >= StdDuration::from_millis(1500) && timeout <= StdDuration::from_millis(2500)
3545        );
3546    }
3547
3548    #[test]
3549    fn test_transaction_bounded_retry_policy_throttle_delegation() {
3550        #[derive(Debug)]
3551        struct ThrottleTestPolicy;
3552        impl RetryPolicy for ThrottleTestPolicy {
3553            fn on_error(&self, _state: &RetryState, error: GaxError) -> RetryResult {
3554                RetryResult::Continue(error)
3555            }
3556            fn on_throttle(&self, _state: &RetryState, error: GaxError) -> ThrottleResult {
3557                ThrottleResult::Exhausted(error)
3558            }
3559        }
3560
3561        let inner = Arc::new(ThrottleTestPolicy);
3562        let deadline = Instant::now() + StdDuration::from_secs(10);
3563        let bounded = TransactionBoundedRetryPolicy { inner, deadline };
3564
3565        let state = RetryState::new(true);
3566        let status = Status::default()
3567            .set_code(Code::Unavailable)
3568            .set_message("error");
3569        let error = GaxError::service(status);
3570
3571        let res = bounded.on_throttle(&state, error);
3572        assert!(matches!(res, ThrottleResult::Exhausted(_)));
3573    }
3574
3575    #[test]
3576    fn test_transaction_bounded_retry_policy_remaining_time_capping() {
3577        #[derive(Debug)]
3578        struct RemainingTimeTestPolicy {
3579            timeout: Option<StdDuration>,
3580        }
3581        impl RetryPolicy for RemainingTimeTestPolicy {
3582            fn on_error(&self, _state: &RetryState, error: GaxError) -> RetryResult {
3583                RetryResult::Continue(error)
3584            }
3585            fn remaining_time(&self, _state: &RetryState) -> Option<StdDuration> {
3586                self.timeout
3587            }
3588        }
3589
3590        let state = RetryState::new(true);
3591
3592        // Case A: Inner policy timeout (3s) is shorter than remaining transaction deadline (approx 10s)
3593        let inner = Arc::new(RemainingTimeTestPolicy {
3594            timeout: Some(StdDuration::from_secs(3)),
3595        });
3596        let deadline = Instant::now() + StdDuration::from_secs(10);
3597        let bounded = TransactionBoundedRetryPolicy { inner, deadline };
3598        let remaining = bounded.remaining_time(&state).expect("remaining time");
3599        assert!(
3600            remaining >= StdDuration::from_millis(2500)
3601                && remaining <= StdDuration::from_millis(3500)
3602        );
3603
3604        // Case B: Transaction deadline (approx 2s) is shorter than inner policy timeout (10s)
3605        let inner = Arc::new(RemainingTimeTestPolicy {
3606            timeout: Some(StdDuration::from_secs(10)),
3607        });
3608        let deadline = Instant::now() + StdDuration::from_secs(2);
3609        let bounded = TransactionBoundedRetryPolicy { inner, deadline };
3610        let remaining = bounded.remaining_time(&state).expect("remaining time");
3611        assert!(
3612            remaining >= StdDuration::from_millis(1500)
3613                && remaining <= StdDuration::from_millis(2500)
3614        );
3615
3616        // Case C: Inner policy timeout is None (returns transaction remaining approx 10s)
3617        let inner = Arc::new(RemainingTimeTestPolicy { timeout: None });
3618        let deadline = Instant::now() + StdDuration::from_secs(10);
3619        let bounded = TransactionBoundedRetryPolicy { inner, deadline };
3620        let remaining = bounded.remaining_time(&state).expect("remaining time");
3621        assert!(
3622            remaining >= StdDuration::from_millis(9500)
3623                && remaining <= StdDuration::from_millis(10500)
3624        );
3625    }
3626}