Skip to main content

google_cloud_spanner/
write_only_transaction.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::client::{DatabaseClient, amend_request_options_for_lar};
16use crate::model::request_options::Priority;
17use crate::model::transaction_options::ReadWrite;
18use crate::model::{
19    BeginTransactionRequest, CommitRequest, CommitResponse, MultiplexedSessionPrecommitToken,
20    Mutation as ProtoMutation, RequestOptions, TransactionOptions,
21};
22use crate::mutation::Mutation;
23use crate::transaction_retry_policy::{
24    BasicTransactionRetryPolicy, TransactionRetryPolicy, retry_aborted,
25};
26use bytes::Bytes;
27use google_cloud_gax::backoff_policy::BackoffPolicyArg;
28use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
29use google_cloud_gax::retry_policy::RetryPolicyArg;
30use std::sync::{Arc, Mutex};
31use wkt::Duration;
32
/// A builder for [WriteOnlyTransaction].
pub struct WriteOnlyTransactionBuilder {
    /// Database client handed over to the built transaction.
    client: DatabaseClient,
    /// Transaction tag; `None` until `set_transaction_tag` is called.
    transaction_tag: Option<String>,
    /// Maximum commit delay; `None` until `set_max_commit_delay` is called.
    max_commit_delay: Option<Duration>,
    /// Transaction-level retry policy; starts as `BasicTransactionRetryPolicy::default()`.
    retry_policy: Box<dyn TransactionRetryPolicy>,
    /// Whether to exclude the transaction from change streams; starts as `false`.
    exclude_txn_from_change_streams: bool,
    /// Whether commit stats are requested; starts as `false`.
    return_commit_stats: bool,
    /// RPC priority for the commit; starts as `Priority::Unspecified`.
    commit_priority: Priority,
    /// Per-RPC options (attempt timeout, retry, backoff) for BeginTransaction.
    begin_gax_options: GaxRequestOptions,
    /// Per-RPC options (attempt timeout, retry, backoff) for Commit.
    commit_gax_options: GaxRequestOptions,
}
45
46impl WriteOnlyTransactionBuilder {
47    pub(crate) fn new(client: DatabaseClient) -> Self {
48        Self {
49            client,
50            transaction_tag: None,
51            max_commit_delay: None,
52            retry_policy: Box::new(BasicTransactionRetryPolicy::default()),
53            exclude_txn_from_change_streams: false,
54            return_commit_stats: false,
55            commit_priority: Priority::Unspecified,
56            begin_gax_options: GaxRequestOptions::default(),
57            commit_gax_options: GaxRequestOptions::default(),
58        }
59    }
60
61    /// Sets a transaction tag to be used for the transaction.
62    ///
63    /// # Example
64    /// ```
65    /// # use google_cloud_spanner::client::Spanner;
66    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
67    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
68    /// let transaction = db_client.write_only_transaction()
69    ///     .set_transaction_tag("my-tag")
70    ///     .build();
71    /// # Ok(())
72    /// # }
73    /// ```
74    ///
75    /// See also: [Troubleshooting with tags](https://docs.cloud.google.com/spanner/docs/introspection/troubleshooting-with-tags)
76    pub fn set_transaction_tag(mut self, tag: impl Into<String>) -> Self {
77        self.transaction_tag = Some(tag.into());
78        self
79    }
80
81    /// Sets the RPC priority to use for the commit of this transaction.
82    ///
83    /// # Example
84    /// ```
85    /// # use google_cloud_spanner::client::Spanner;
86    /// # use google_cloud_spanner::model::request_options::Priority;
87    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
88    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
89    /// let transaction = db_client.write_only_transaction()
90    ///     .set_commit_priority(Priority::Low)
91    ///     .build();
92    /// # Ok(())
93    /// # }
94    /// ```
95    pub fn set_commit_priority(mut self, priority: Priority) -> Self {
96        self.commit_priority = priority;
97        self
98    }
99
100    /// Sets the maximum commit delay for the transaction.
101    ///
102    /// # Example
103    /// ```
104    /// # use google_cloud_spanner::client::Spanner;
105    /// # use wkt::Duration;
106    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
107    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
108    /// let transaction = db_client.write_only_transaction()
109    ///     .set_max_commit_delay(Duration::try_from("0.1s").unwrap())
110    ///     .build();
111    /// # Ok(())
112    /// # }
113    /// ```
114    ///
115    /// This option allows you to specify the maximum amount of time Spanner can
116    /// adjust the commit timestamp of the transaction to allow for commit batching.
117    /// Increasing this value can increase throughput at the expense of latency.
118    /// The value must be between 0 and 500 milliseconds. If not set, or set to 0,
119    /// Spanner does not delay the commit.
120    pub fn set_max_commit_delay(mut self, delay: Duration) -> Self {
121        self.max_commit_delay = Some(delay);
122        self
123    }
124
125    /// Sets whether to exclude the transaction from change streams.
126    ///
127    /// # Example
128    /// ```
129    /// # use google_cloud_spanner::client::Spanner;
130    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
131    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
132    /// let transaction = db_client.write_only_transaction()
133    ///     .set_exclude_txn_from_change_streams(true)
134    ///     .build();
135    /// # Ok(())
136    /// # }
137    /// ```
138    ///
139    /// When set to `true`, it prevents modifications from this transaction from being tracked in change streams.
140    /// Note that this only affects change streams that have been created with the DDL option `allow_txn_exclusion = true`.
141    /// If `allow_txn_exclusion` is not set or set to `false` for a change stream, updates made within this transaction
142    /// are recorded in that change stream regardless of this setting.
143    ///
144    /// When set to `false` or not specified, modifications from this transaction are recorded in all change streams
145    /// tracking columns modified by this transaction.
146    pub fn set_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
147        self.exclude_txn_from_change_streams = exclude;
148        self
149    }
150
151    /// Sets whether to return commit stats for the transaction.
152    ///
153    /// # Example
154    /// ```
155    /// # use google_cloud_spanner::mutation::Mutation;
156    /// # use google_cloud_spanner::client::Spanner;
157    /// # async fn test_doc() -> Result<(), Box<dyn std::error::Error>> {
158    /// # let client = Spanner::builder().build().await?;
159    /// # let db = client.database_client("projects/p/instances/i/databases/d").build().await?;
160    /// let mutation = Mutation::new_insert_builder("Users")
161    ///     .set("UserId").to(&1)
162    ///     .build();
163    ///
164    /// let response = db.write_only_transaction()
165    ///     .set_return_commit_stats(true)
166    ///     .build()
167    ///     .write(vec![mutation])
168    ///     .await?;
169    ///
170    /// if let Some(stats) = response.commit_stats {
171    ///     println!("Mutation count: {}", stats.mutation_count);
172    /// }
173    /// # Ok(())
174    /// # }
175    /// ```
176    ///
177    /// See also: <https://docs.cloud.google.com/spanner/docs/commit-statistics>
178    pub fn set_return_commit_stats(mut self, return_stats: bool) -> Self {
179        self.return_commit_stats = return_stats;
180        self
181    }
182
183    /// Sets the retry policy for the transaction.
184    ///
185    /// # Example
186    /// ```
187    /// # use std::time::Duration;
188    /// # use google_cloud_spanner::client::Spanner;
189    /// # use google_cloud_spanner::transaction::BasicTransactionRetryPolicy;
190    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
191    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
192    ///
193    /// let retry_policy = BasicTransactionRetryPolicy::new()
194    ///     .with_max_attempts(5)
195    ///     .with_total_timeout(Duration::from_secs(60));
196    ///
197    /// let transaction = db_client.write_only_transaction()
198    ///     .with_retry_policy(retry_policy)
199    ///     .build();
200    /// # Ok(())
201    /// # }
202    /// ```
203    ///
204    /// The client will retry the transaction if it is aborted by Spanner.
205    /// This policy can be used to customize whether a transaction should be retried
206    /// or not. The default is to retry indefinitely until the transaction succeeds.
207    pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(mut self, policy: P) -> Self {
208        self.retry_policy = Box::new(policy);
209        self
210    }
211
212    /// Sets the per-attempt timeout for the BeginTransaction RPC.
213    ///
214    /// # Example
215    /// ```
216    /// # use google_cloud_spanner::client::Spanner;
217    /// # use std::time::Duration;
218    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
219    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
220    /// let transaction = db_client.write_only_transaction()
221    ///     .with_begin_attempt_timeout(Duration::from_secs(5))
222    ///     .build();
223    /// # Ok(())
224    /// # }
225    /// ```
226    pub fn with_begin_attempt_timeout(mut self, timeout: std::time::Duration) -> Self {
227        self.begin_gax_options.set_attempt_timeout(timeout);
228        self
229    }
230
231    /// Sets the retry policy for the BeginTransaction RPC.
232    ///
233    /// # Example
234    /// ```
235    /// # use google_cloud_spanner::client::Spanner;
236    /// # use google_cloud_gax::retry_policy::NeverRetry;
237    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
238    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
239    /// let transaction = db_client.write_only_transaction()
240    ///     .with_begin_retry_policy(NeverRetry)
241    ///     .build();
242    /// # Ok(())
243    /// # }
244    /// ```
245    pub fn with_begin_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
246        self.begin_gax_options.set_retry_policy(policy);
247        self
248    }
249
250    /// Sets the backoff policy for the BeginTransaction RPC.
251    ///
252    /// # Example
253    /// ```
254    /// # use google_cloud_spanner::client::Spanner;
255    /// # use google_cloud_gax::exponential_backoff::ExponentialBackoff;
256    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
257    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
258    /// let transaction = db_client.write_only_transaction()
259    ///     .with_begin_backoff_policy(ExponentialBackoff::default())
260    ///     .build();
261    /// # Ok(())
262    /// # }
263    /// ```
264    pub fn with_begin_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
265        self.begin_gax_options.set_backoff_policy(policy);
266        self
267    }
268
269    /// Sets the per-attempt timeout for the Commit RPC.
270    ///
271    /// # Example
272    /// ```
273    /// # use google_cloud_spanner::client::Spanner;
274    /// # use std::time::Duration;
275    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
276    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
277    /// let transaction = db_client.write_only_transaction()
278    ///     .with_commit_attempt_timeout(Duration::from_secs(5))
279    ///     .build();
280    /// # Ok(())
281    /// # }
282    /// ```
283    pub fn with_commit_attempt_timeout(mut self, timeout: std::time::Duration) -> Self {
284        self.commit_gax_options.set_attempt_timeout(timeout);
285        self
286    }
287
288    /// Sets the retry policy for the Commit RPC.
289    ///
290    /// # Example
291    /// ```
292    /// # use google_cloud_spanner::client::Spanner;
293    /// # use google_cloud_gax::retry_policy::NeverRetry;
294    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
295    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
296    /// let transaction = db_client.write_only_transaction()
297    ///     .with_commit_retry_policy(NeverRetry)
298    ///     .build();
299    /// # Ok(())
300    /// # }
301    /// ```
302    pub fn with_commit_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
303        self.commit_gax_options.set_retry_policy(policy);
304        self
305    }
306
307    /// Sets the backoff policy for the Commit RPC.
308    ///
309    /// # Example
310    /// ```
311    /// # use google_cloud_spanner::client::Spanner;
312    /// # use google_cloud_gax::exponential_backoff::ExponentialBackoff;
313    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
314    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
315    /// let transaction = db_client.write_only_transaction()
316    ///     .with_commit_backoff_policy(ExponentialBackoff::default())
317    ///     .build();
318    /// # Ok(())
319    /// # }
320    /// ```
321    pub fn with_commit_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
322        self.commit_gax_options.set_backoff_policy(policy);
323        self
324    }
325
326    /// Builds the [WriteOnlyTransaction].
327    ///
328    /// # Example
329    /// ```
330    /// # use google_cloud_spanner::client::Spanner;
331    /// # async fn build(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
332    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
333    /// let transaction = db_client.write_only_transaction().build();
334    /// # Ok(())
335    /// # }
336    /// ```
337    pub fn build(self) -> WriteOnlyTransaction {
338        let session_name = self.client.session_name();
339        WriteOnlyTransaction {
340            session_name,
341            client: self.client,
342            transaction_tag: self.transaction_tag,
343            max_commit_delay: self.max_commit_delay,
344            retry_policy: self.retry_policy,
345            exclude_txn_from_change_streams: self.exclude_txn_from_change_streams,
346            return_commit_stats: self.return_commit_stats,
347            commit_priority: self.commit_priority,
348            begin_gax_options: self.begin_gax_options,
349            commit_gax_options: self.commit_gax_options,
350        }
351    }
352}
353
/// A write-only transaction.
///
/// A write-only transaction can be used to execute blind writes.
pub struct WriteOnlyTransaction {
    /// Session name read from the client when the transaction was built; set as the
    /// `session` of every request issued by this transaction.
    pub(crate) session_name: String,
    /// Client whose `spanner` stub issues the BeginTransaction and Commit RPCs.
    client: DatabaseClient,
    /// Transaction tag placed in the request options; an empty tag is sent when `None`.
    transaction_tag: Option<String>,
    /// Maximum commit delay forwarded to the commit request when set.
    max_commit_delay: Option<Duration>,
    /// Policy passed to `retry_aborted` to drive the retry loop around each write.
    retry_policy: Box<dyn TransactionRetryPolicy>,
    /// Forwarded to the transaction options of the transaction that performs the write.
    exclude_txn_from_change_streams: bool,
    /// Forwarded to the commit request.
    return_commit_stats: bool,
    /// Priority placed in the request options.
    commit_priority: Priority,
    /// Caller-supplied options for the BeginTransaction RPC, before the
    /// `amend_request_options_for_lar` adjustment.
    begin_gax_options: GaxRequestOptions,
    /// Caller-supplied options for the Commit RPC, before the
    /// `amend_request_options_for_lar` adjustment.
    commit_gax_options: GaxRequestOptions,
}
369
370impl WriteOnlyTransaction {
371    /// Writes a set of mutations atomically to Spanner.
372    ///
373    /// # Example
374    /// ```
375    /// # use google_cloud_spanner::mutation::Mutation;
376    /// # use google_cloud_spanner::client::Spanner;
377    /// # async fn test_doc() -> Result<(), Box<dyn std::error::Error>> {
378    /// let client = Spanner::builder().build().await?;
379    /// let db = client.database_client("projects/p/instances/i/databases/d").build().await?;
380    ///
381    /// let mutation = Mutation::new_insert_builder("Users")
382    ///     .set("UserId").to(&1)
383    ///     .set("UserName").to(&"Alice")
384    ///     .build();
385    ///
386    /// let response = db.write_only_transaction()
387    ///     .set_transaction_tag("my-tag")
388    ///     .build()
389    ///     .write(vec![mutation])
390    ///     .await?;
391    /// # Ok(())
392    /// # }
393    /// ```
394    ///
395    /// This method uses retries and replay protection internally, which means that the mutations
396    /// are applied exactly once on success, or not at all if an error is returned, regardless of
397    /// any failures in the underlying network. Note that if the call is cancelled or reaches
398    /// deadline, it is not possible to know whether the mutations were applied without performing
399    /// a subsequent database operation, but the mutations will have been applied at most once.
400    pub async fn write<I>(self, mutations: I) -> crate::Result<CommitResponse>
401    where
402        I: IntoIterator<Item = Mutation>,
403    {
404        let begin_gax_options = self.begin_gax_options();
405        let commit_gax_options = self.commit_gax_options();
406        let req_options = RequestOptions::default()
407            .set_transaction_tag(self.transaction_tag.unwrap_or_default())
408            .set_priority(self.commit_priority.clone());
409
410        let mutations_proto: Vec<_> = mutations.into_iter().map(|m| m.build_proto()).collect();
411        let mutation_key = Mutation::select_mutation_key(&mutations_proto);
412        let client = self.client;
413        let session_name = self.session_name.clone();
414        let previous_transaction_id = Arc::new(Mutex::new(Bytes::new()));
415        let channel_hint = client.spanner.next_channel_hint();
416
417        let max_commit_delay = self.max_commit_delay;
418        let return_commit_stats = self.return_commit_stats;
419
420        retry_aborted(&*self.retry_policy, || {
421            let client = client.clone();
422            let session_name = session_name.clone();
423            let req_options = req_options.clone();
424            let mutations_proto = mutations_proto.clone();
425            let mutation_key = mutation_key.clone();
426            let previous_transaction_id = previous_transaction_id.clone();
427            let begin_gax_options = begin_gax_options.clone();
428            let commit_gax_options = commit_gax_options.clone();
429
430            async move {
431                let previous_id: Bytes = previous_transaction_id.lock().unwrap().clone();
432
433                let begin_req = BeginTransactionRequest::default()
434                    .set_session(session_name.clone())
435                    .set_options(
436                        TransactionOptions::default()
437                            .set_read_write(Box::new(
438                                ReadWrite::default()
439                                    .set_multiplexed_session_previous_transaction_id(previous_id),
440                            ))
441                            .set_exclude_txn_from_change_streams(
442                                self.exclude_txn_from_change_streams,
443                            ),
444                    )
445                    .set_request_options(req_options.clone())
446                    .set_or_clear_mutation_key(mutation_key.clone());
447
448                let tx = client
449                    .spanner
450                    .begin_transaction(begin_req, begin_gax_options, channel_hint)
451                    .await?;
452                *previous_transaction_id.lock().unwrap() = tx.id.clone();
453
454                let commit_req = create_commit_request(
455                    session_name.clone(),
456                    tx.id.clone(),
457                    mutations_proto,
458                    tx.precommit_token,
459                    Some(req_options.clone()),
460                    max_commit_delay,
461                    return_commit_stats,
462                );
463
464                let response = client
465                    .spanner
466                    .commit(commit_req, commit_gax_options.clone(), channel_hint)
467                    .await?;
468
469                // If a commit_response with a precommit_token is returned, then we need to
470                // retry the commit with the new precommit_token and without any mutations.
471                if let Some(new_token) = response.precommit_token().map(|b| *b.clone()) {
472                    let retry_commit_req = create_commit_request(
473                        session_name.clone(),
474                        tx.id,
475                        Vec::new(),
476                        Some(new_token),
477                        Some(req_options),
478                        max_commit_delay,
479                        return_commit_stats,
480                    );
481                    client
482                        .spanner
483                        .commit(retry_commit_req, commit_gax_options, channel_hint)
484                        .await
485                } else {
486                    Ok(response)
487                }
488            }
489        })
490        .await
491    }
492
493    /// Writes a set of mutations at least once using a single Commit RPC.
494    ///
495    /// # Example
496    /// ```
497    /// # use google_cloud_spanner::mutation::Mutation;
498    /// # use google_cloud_spanner::client::Spanner;
499    /// # async fn test_doc() -> Result<(), Box<dyn std::error::Error>> {
500    /// let client = Spanner::builder().build().await?;
501    /// let db = client.database_client("projects/p/instances/i/databases/d").build().await?;
502    ///
503    /// let mutation = Mutation::new_insert_or_update_builder("Users")
504    ///     .set("UserId").to(&1)
505    ///     .set("UserName").to(&"Alice")
506    ///     .build();
507    ///
508    /// let response = db.write_only_transaction()
509    ///     .set_transaction_tag("my-tag")
510    ///     .build()
511    ///     .write_at_least_once(vec![mutation])
512    ///     .await?;
513    /// # Ok(())
514    /// # }
515    /// ```
516    ///
517    /// Since this method does not feature replay protection, it may attempt to apply the provided
518    /// mutations more than once. If the mutations are not idempotent, this may lead to a failure
519    /// being reported even if the mutation was applied successfully the first time. For example,
520    /// an insert may fail with an `AlreadyExists` error even though the row did not exist before
521    /// this method was called. For this reason, most users of the library will prefer to use write
522    /// transactions with replay protection instead.
523    /// However, `write_at_least_once` requires only a single RPC, whereas replay-protected
524    /// writes require two RPCs. Thus, this method may be appropriate for latency sensitive
525    /// and/or high throughput blind writing.
526    pub async fn write_at_least_once<I>(self, mutations: I) -> crate::Result<CommitResponse>
527    where
528        I: IntoIterator<Item = Mutation>,
529    {
530        let commit_gax_options = self.commit_gax_options();
531        let single_use = TransactionOptions::new()
532            .set_read_write(Box::new(ReadWrite::new()))
533            .set_exclude_txn_from_change_streams(self.exclude_txn_from_change_streams);
534        let req_options = RequestOptions::default()
535            .set_transaction_tag(self.transaction_tag.unwrap_or_default())
536            .set_priority(self.commit_priority.clone());
537        let request = CommitRequest::new()
538            .set_session(self.session_name.clone())
539            .set_mutations(mutations.into_iter().map(|m| m.build_proto()))
540            .set_single_use_transaction(Box::new(single_use))
541            .set_request_options(req_options)
542            .set_or_clear_max_commit_delay(self.max_commit_delay)
543            .set_return_commit_stats(self.return_commit_stats);
544        let client = self.client;
545        let channel_hint = client.spanner.next_channel_hint();
546
547        retry_aborted(&*self.retry_policy, || {
548            let client = client.clone();
549            let request = request.clone();
550            let commit_gax_options = commit_gax_options.clone();
551
552            async move {
553                client
554                    .spanner
555                    .commit(request, commit_gax_options, channel_hint)
556                    .await
557            }
558        })
559        .await
560    }
561
562    fn begin_gax_options(&self) -> GaxRequestOptions {
563        amend_request_options_for_lar(
564            self.client.leader_aware_routing_enabled,
565            self.begin_gax_options.clone(),
566        )
567    }
568
569    fn commit_gax_options(&self) -> GaxRequestOptions {
570        amend_request_options_for_lar(
571            self.client.leader_aware_routing_enabled,
572            self.commit_gax_options.clone(),
573        )
574    }
575}
576
577pub(crate) fn create_commit_request(
578    session_name: String,
579    transaction_id: bytes::Bytes,
580    mutations: Vec<ProtoMutation>,
581    precommit_token: Option<MultiplexedSessionPrecommitToken>,
582    request_options: Option<RequestOptions>,
583    max_commit_delay: Option<Duration>,
584    return_commit_stats: bool,
585) -> CommitRequest {
586    CommitRequest::default()
587        .set_session(session_name)
588        .set_transaction_id(transaction_id)
589        .set_mutations(mutations)
590        .set_or_clear_precommit_token(precommit_token)
591        .set_or_clear_request_options(request_options)
592        .set_or_clear_max_commit_delay(max_commit_delay)
593        .set_return_commit_stats(return_commit_stats)
594}
595
596#[cfg(test)]
597mod tests {
598    use super::*;
599    use crate::client::Spanner;
600    use crate::transaction_retry_policy::tests::create_aborted_status;
601    use gaxi::grpc::tonic::Response;
602    use google_cloud_gax::exponential_backoff::ExponentialBackoff;
603    use google_cloud_gax::retry_policy::NeverRetry;
604    use google_cloud_test_macros::tokio_test_no_panics;
605    use prost_types::Duration as ProstDuration;
606    use prost_types::Timestamp;
607    use spanner_grpc_mock::google::spanner::v1::CommitResponse;
608    use spanner_grpc_mock::google::spanner::v1::Session;
609    use spanner_grpc_mock::google::spanner::v1::Transaction;
610    use spanner_grpc_mock::google::spanner::v1::commit_response::CommitStats;
611    use spanner_grpc_mock::google::spanner::v1::transaction_options::Mode;
612    use std::time::Duration as StdDuration;
613    use wkt::Duration;
614
615    pub(crate) async fn setup_db_client(
616        mock: spanner_grpc_mock::MockSpanner,
617    ) -> (DatabaseClient, tokio::task::JoinHandle<()>) {
618        use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
619        let (address, server) = spanner_grpc_mock::start("0.0.0.0:0", mock)
620            .await
621            .expect("Failed to start mock server");
622        let spanner = Spanner::builder()
623            .with_endpoint(address)
624            .with_credentials(Anonymous::new().build())
625            .build()
626            .await
627            .expect("Failed to build client");
628
629        let db_client = spanner
630            .database_client("projects/p/instances/i/databases/d")
631            .build()
632            .await
633            .expect("Failed to create DatabaseClient");
634
635        (db_client, server)
636    }
637
638    #[tokio_test_no_panics]
639    async fn write_at_least_once() {
640        let mut mock = spanner_grpc_mock::MockSpanner::new();
641        mock.expect_create_session().returning(|_| {
642            Ok(gaxi::grpc::tonic::Response::new(
643                spanner_grpc_mock::google::spanner::v1::Session {
644                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
645                    ..Default::default()
646                },
647            ))
648        });
649
650        mock.expect_commit().once().returning(|req| {
651            let req = req.into_inner();
652            assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
653
654            // Validate the custom request options contain the transaction tag and priority
655            assert!(req.request_options.is_some());
656            let req_opts = req.request_options.as_ref().expect("request_options should be present");
657            assert_eq!(req_opts.transaction_tag, "my_tag");
658            assert_eq!(Priority::from(req_opts.priority), Priority::High);
659
660            assert!(req.mutations.len() == 1);
661
662            // Validate it's a single-use transaction configured correctly
663            match req.transaction {
664                Some(spanner_grpc_mock::google::spanner::v1::commit_request::Transaction::SingleUseTransaction(opts)) => {
665                    assert!(opts.mode.is_some());
666                }
667                _ => panic!("Expected SingleUseTransaction"),
668            }
669
670            Ok(gaxi::grpc::tonic::Response::new(
671                spanner_grpc_mock::google::spanner::v1::CommitResponse {
672                    commit_timestamp: Some(prost_types::Timestamp {
673                        seconds: 1234,
674                        nanos: 0,
675                    }),
676                    ..Default::default()
677                },
678            ))
679        });
680
681        let (db_client, _server) = setup_db_client(mock).await;
682
683        let mutation = Mutation::new_insert_or_update_builder("Users")
684            .set("UserId")
685            .to(&1)
686            .build();
687
688        let res = db_client
689            .write_only_transaction()
690            .set_transaction_tag("my_tag")
691            .set_commit_priority(Priority::High)
692            .build()
693            .write_at_least_once(vec![mutation])
694            .await;
695
696        assert!(res.is_ok());
697        let res = res.expect("write_at_least_once should succeed");
698        assert!(res.commit_timestamp.is_some());
699        assert_eq!(
700            res.commit_timestamp
701                .expect("commit_timestamp should be present")
702                .seconds(),
703            1234
704        );
705    }
706
707    #[tokio_test_no_panics]
708    async fn write() {
709        let mut mock = spanner_grpc_mock::MockSpanner::new();
710        mock.expect_create_session().returning(|_| {
711            Ok(gaxi::grpc::tonic::Response::new(
712                spanner_grpc_mock::google::spanner::v1::Session {
713                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
714                    ..Default::default()
715                },
716            ))
717        });
718
719        mock.expect_begin_transaction().once().returning(|req| {
720            let req = req.into_inner();
721            assert_eq!(
722                req.session,
723                "projects/p/instances/i/databases/d/sessions/123"
724            );
725            assert!(req.options.is_some());
726            assert!(req.mutation_key.is_some());
727
728            Ok(gaxi::grpc::tonic::Response::new(
729                spanner_grpc_mock::google::spanner::v1::Transaction {
730                    id: vec![42],
731                    precommit_token: Some(
732                        spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken {
733                            precommit_token: vec![1, 2, 3],
734                            seq_num: 1,
735                        },
736                    ),
737                    ..Default::default()
738                },
739            ))
740        });
741
742        mock.expect_commit().once().returning(|req| {
743            let req = req.into_inner();
744            assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
745            assert_eq!(
746                req.precommit_token.expect("precommit_token required").precommit_token,
747                vec![1, 2, 3]
748            );
749
750            // Validate that we pass down the transaction ID from BeginTransaction.
751            match req.transaction {
752                Some(spanner_grpc_mock::google::spanner::v1::commit_request::Transaction::TransactionId(tid)) => {
753                    assert_eq!(tid, vec![42]);
754                }
755                _ => panic!("Expected TransactionId"),
756            }
757
758            Ok(gaxi::grpc::tonic::Response::new(
759                spanner_grpc_mock::google::spanner::v1::CommitResponse {
760                    commit_timestamp: Some(prost_types::Timestamp {
761                        seconds: 5678,
762                        nanos: 0,
763                    }),
764                    ..Default::default()
765                },
766            ))
767        });
768
769        let (db_client, _server) = setup_db_client(mock).await;
770
771        let mutation = Mutation::new_insert_or_update_builder("Users")
772            .set("UserId")
773            .to(&1)
774            .build();
775
776        let res = db_client
777            .write_only_transaction()
778            .build()
779            .write(vec![mutation])
780            .await;
781
782        assert!(res.is_ok());
783        let res = res.expect("write should succeed");
784        assert!(res.commit_timestamp.is_some());
785        assert_eq!(
786            res.commit_timestamp
787                .expect("commit_timestamp should be present")
788                .seconds(),
789            5678
790        );
791    }
792
793    #[tokio_test_no_panics]
794    async fn write_at_least_once_with_commit_stats() -> anyhow::Result<()> {
795        let mut mock = spanner_grpc_mock::MockSpanner::new();
796        mock.expect_create_session().returning(|_| {
797            Ok(Response::new(Session {
798                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
799                ..Default::default()
800            }))
801        });
802
803        mock.expect_commit().once().returning(|req| {
804            let req = req.into_inner();
805            assert!(req.return_commit_stats);
806
807            Ok(Response::new(CommitResponse {
808                commit_timestamp: Some(prost_types::Timestamp {
809                    seconds: 1234,
810                    nanos: 0,
811                }),
812                commit_stats: Some(CommitStats { mutation_count: 5 }),
813                ..Default::default()
814            }))
815        });
816
817        let (db_client, _server) = setup_db_client(mock).await;
818
819        let mutation = Mutation::new_insert_or_update_builder("Users")
820            .set("UserId")
821            .to(&1)
822            .build();
823
824        let res = db_client
825            .write_only_transaction()
826            .set_return_commit_stats(true)
827            .build()
828            .write_at_least_once(vec![mutation])
829            .await?;
830
831        let stats = res.commit_stats.expect("Commit stats should be present");
832        assert_eq!(stats.mutation_count, 5);
833        Ok(())
834    }
835
836    #[tokio_test_no_panics]
837    async fn write_with_commit_stats() -> anyhow::Result<()> {
838        let mut mock = spanner_grpc_mock::MockSpanner::new();
839        mock.expect_create_session().returning(|_| {
840            Ok(Response::new(Session {
841                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
842                ..Default::default()
843            }))
844        });
845
846        mock.expect_begin_transaction().once().returning(|_| {
847            Ok(Response::new(Transaction {
848                id: vec![42],
849                ..Default::default()
850            }))
851        });
852
853        mock.expect_commit().once().returning(|req| {
854            let req = req.into_inner();
855            assert!(req.return_commit_stats);
856
857            Ok(Response::new(CommitResponse {
858                commit_timestamp: Some(prost_types::Timestamp {
859                    seconds: 5678,
860                    nanos: 0,
861                }),
862                commit_stats: Some(CommitStats { mutation_count: 10 }),
863                ..Default::default()
864            }))
865        });
866
867        let (db_client, _server) = setup_db_client(mock).await;
868
869        let mutation = Mutation::new_insert_or_update_builder("Users")
870            .set("UserId")
871            .to(&1)
872            .build();
873
874        let res = db_client
875            .write_only_transaction()
876            .set_return_commit_stats(true)
877            .build()
878            .write(vec![mutation])
879            .await?;
880
881        let stats = res.commit_stats.expect("Commit stats should be present");
882        assert_eq!(stats.mutation_count, 10);
883        Ok(())
884    }
885
886    #[tokio_test_no_panics]
887    async fn write_at_least_once_with_exclude_txn_from_change_streams() {
888        let mut mock = spanner_grpc_mock::MockSpanner::new();
889        mock.expect_create_session().returning(|_| {
890            Ok(gaxi::grpc::tonic::Response::new(
891                spanner_grpc_mock::google::spanner::v1::Session {
892                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
893                    ..Default::default()
894                },
895            ))
896        });
897
898        mock.expect_commit().once().returning(|req| {
899            let req = req.into_inner();
900            match req.transaction {
901                Some(spanner_grpc_mock::google::spanner::v1::commit_request::Transaction::SingleUseTransaction(opts)) => {
902                    assert!(opts.exclude_txn_from_change_streams);
903                }
904                _ => panic!("Expected SingleUseTransaction"),
905            }
906
907            Ok(gaxi::grpc::tonic::Response::new(
908                spanner_grpc_mock::google::spanner::v1::CommitResponse {
909                    commit_timestamp: Some(prost_types::Timestamp {
910                        seconds: 1234,
911                        nanos: 0,
912                    }),
913                    ..Default::default()
914                },
915            ))
916        });
917
918        let (db_client, _server) = setup_db_client(mock).await;
919
920        let mutation = Mutation::new_insert_or_update_builder("Users")
921            .set("UserId")
922            .to(&1)
923            .build();
924
925        let res = db_client
926            .write_only_transaction()
927            .set_exclude_txn_from_change_streams(true)
928            .build()
929            .write_at_least_once(vec![mutation])
930            .await;
931
932        assert!(res.is_ok());
933    }
934
935    #[tokio_test_no_panics]
936    async fn write_with_exclude_txn_from_change_streams() {
937        let mut mock = spanner_grpc_mock::MockSpanner::new();
938        mock.expect_create_session().returning(|_| {
939            Ok(gaxi::grpc::tonic::Response::new(
940                spanner_grpc_mock::google::spanner::v1::Session {
941                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
942                    ..Default::default()
943                },
944            ))
945        });
946
947        mock.expect_begin_transaction().once().returning(|req| {
948            let req = req.into_inner();
949            let options = req.options.expect("Missing transaction options");
950            assert!(options.exclude_txn_from_change_streams);
951
952            Ok(gaxi::grpc::tonic::Response::new(
953                spanner_grpc_mock::google::spanner::v1::Transaction {
954                    id: vec![42],
955                    ..Default::default()
956                },
957            ))
958        });
959
960        mock.expect_commit().once().returning(|_req| {
961            Ok(gaxi::grpc::tonic::Response::new(
962                spanner_grpc_mock::google::spanner::v1::CommitResponse {
963                    commit_timestamp: Some(prost_types::Timestamp {
964                        seconds: 5678,
965                        nanos: 0,
966                    }),
967                    ..Default::default()
968                },
969            ))
970        });
971
972        let (db_client, _server) = setup_db_client(mock).await;
973
974        let mutation = Mutation::new_insert_or_update_builder("Users")
975            .set("UserId")
976            .to(&1)
977            .build();
978
979        let res = db_client
980            .write_only_transaction()
981            .set_exclude_txn_from_change_streams(true)
982            .build()
983            .write(vec![mutation])
984            .await;
985
986        assert!(res.is_ok());
987    }
988
989    #[tokio_test_no_panics]
990    async fn write_with_commit_retry() {
991        let mut mock = spanner_grpc_mock::MockSpanner::new();
992        mock.expect_create_session().returning(|_| {
993            Ok(gaxi::grpc::tonic::Response::new(
994                spanner_grpc_mock::google::spanner::v1::Session {
995                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
996                    ..Default::default()
997                },
998            ))
999        });
1000
1001        mock.expect_begin_transaction().once().returning(|req| {
1002            let req = req.into_inner();
1003            assert!(req.mutation_key.is_some());
1004
1005            Ok(gaxi::grpc::tonic::Response::new(
1006                spanner_grpc_mock::google::spanner::v1::Transaction {
1007                    id: vec![42],
1008                    ..Default::default()
1009                },
1010            ))
1011        });
1012
1013        let commit_call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
1014        mock.expect_commit().times(2).returning(move |req| {
1015            let count = commit_call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1016            let req = req.into_inner();
1017            assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
1018
1019            if count == 0 {
1020                assert!(!req.mutations.is_empty());
1021                Ok(gaxi::grpc::tonic::Response::new(
1022                    spanner_grpc_mock::google::spanner::v1::CommitResponse {
1023                        multiplexed_session_retry: Some(
1024                            spanner_grpc_mock::google::spanner::v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
1025                                spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken {
1026                                    precommit_token: vec![4, 5, 6],
1027                                    seq_num: 2,
1028                                }
1029                            )
1030                        ),
1031                        ..Default::default()
1032                    },
1033                ))
1034            } else {
1035                assert!(req.mutations.is_empty());
1036                assert_eq!(
1037                    req.precommit_token.expect("precommit_token required").precommit_token,
1038                    vec![4, 5, 6]
1039                );
1040                Ok(gaxi::grpc::tonic::Response::new(
1041                    spanner_grpc_mock::google::spanner::v1::CommitResponse {
1042                        commit_timestamp: Some(prost_types::Timestamp {
1043                            seconds: 9999,
1044                            nanos: 0,
1045                        }),
1046                        ..Default::default()
1047                    },
1048                ))
1049            }
1050        });
1051
1052        let (db_client, _server) = setup_db_client(mock).await;
1053
1054        let mutation = Mutation::new_insert_or_update_builder("Users")
1055            .set("UserId")
1056            .to(&1)
1057            .build();
1058
1059        let res = db_client
1060            .write_only_transaction()
1061            .build()
1062            .write(vec![mutation])
1063            .await;
1064
1065        assert!(res.is_ok());
1066        let res = res.expect("write should succeed");
1067        assert!(res.commit_timestamp.is_some());
1068        assert_eq!(
1069            res.commit_timestamp
1070                .expect("commit_timestamp should be present")
1071                .seconds(),
1072            9999
1073        );
1074    }
1075
1076    #[tokio_test_no_panics]
1077    async fn write_with_commit_retry_preserves_options() -> anyhow::Result<()> {
1078        let mut mock = spanner_grpc_mock::MockSpanner::new();
1079        mock.expect_create_session().returning(|_| {
1080            Ok(gaxi::grpc::tonic::Response::new(
1081                spanner_grpc_mock::google::spanner::v1::Session {
1082                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1083                    ..Default::default()
1084                },
1085            ))
1086        });
1087
1088        mock.expect_begin_transaction().once().returning(|req| {
1089            let req = req.into_inner();
1090            assert!(req.mutation_key.is_some());
1091
1092            Ok(gaxi::grpc::tonic::Response::new(
1093                spanner_grpc_mock::google::spanner::v1::Transaction {
1094                    id: vec![42],
1095                    ..Default::default()
1096                },
1097            ))
1098        });
1099
1100        let expected_delay = prost_types::Duration {
1101            seconds: 0,
1102            nanos: 200_000_000,
1103        };
1104
1105        let expected_delay_clone = expected_delay;
1106        let commit_call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
1107        mock.expect_commit().times(2).returning(move |req| {
1108            let count = commit_call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1109            let req = req.into_inner();
1110            assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
1111
1112            // Verify options are present in both attempts
1113            assert!(req.return_commit_stats, "Expected return_commit_stats to be true");
1114            assert_eq!(req.max_commit_delay.as_ref(), Some(&expected_delay_clone), "Expected max_commit_delay to be 200ms");
1115
1116            if count == 0 {
1117                assert!(!req.mutations.is_empty());
1118                Ok(gaxi::grpc::tonic::Response::new(
1119                    spanner_grpc_mock::google::spanner::v1::CommitResponse {
1120                        multiplexed_session_retry: Some(
1121                            spanner_grpc_mock::google::spanner::v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
1122                                spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken {
1123                                    precommit_token: vec![4, 5, 6],
1124                                    seq_num: 2,
1125                                }
1126                            )
1127                        ),
1128                        ..Default::default()
1129                    },
1130                ))
1131            } else {
1132                assert!(req.mutations.is_empty(), "Expected mutations to be empty on retry");
1133                assert_eq!(
1134                    req.precommit_token.expect("precommit_token required").precommit_token,
1135                    vec![4, 5, 6]
1136                );
1137                Ok(gaxi::grpc::tonic::Response::new(
1138                    spanner_grpc_mock::google::spanner::v1::CommitResponse {
1139                        commit_timestamp: Some(prost_types::Timestamp {
1140                            seconds: 9999,
1141                            nanos: 0,
1142                        }),
1143                        commit_stats: Some(CommitStats { mutation_count: 12 }),
1144                        ..Default::default()
1145                    },
1146                ))
1147            }
1148        });
1149
1150        let (db_client, _server) = setup_db_client(mock).await;
1151
1152        let mutation = Mutation::new_insert_or_update_builder("Users")
1153            .set("UserId")
1154            .to(&1)
1155            .build();
1156
1157        let res = db_client
1158            .write_only_transaction()
1159            .set_return_commit_stats(true)
1160            .set_max_commit_delay(Duration::new(0, 200_000_000).expect("valid duration"))
1161            .build()
1162            .write(vec![mutation])
1163            .await?;
1164
1165        let stats = res.commit_stats.expect("Expected commit stats in response");
1166        assert_eq!(stats.mutation_count, 12);
1167        assert_eq!(
1168            res.commit_timestamp
1169                .expect("timestamp should be present")
1170                .seconds(),
1171            9999
1172        );
1173
1174        Ok(())
1175    }
1176
1177    #[tokio_test_no_panics]
1178    async fn write_with_commit_aborted_retry() -> anyhow::Result<()> {
1179        let mut mock = spanner_grpc_mock::MockSpanner::new();
1180        mock.expect_create_session().returning(|_| {
1181            Ok(Response::new(Session {
1182                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1183                ..Default::default()
1184            }))
1185        });
1186
1187        let mut seq = mockall::Sequence::new();
1188
1189        mock.expect_begin_transaction()
1190            .once()
1191            .in_sequence(&mut seq)
1192            .returning(move |req| {
1193                let req = req.into_inner();
1194                assert!(req.mutation_key.is_some());
1195
1196                Ok(Response::new(Transaction {
1197                    id: vec![42],
1198                    ..Default::default()
1199                }))
1200            });
1201
1202        mock.expect_commit()
1203            .once()
1204            .in_sequence(&mut seq)
1205            .returning(move |_req| Err(create_aborted_status(std::time::Duration::from_nanos(1))));
1206
1207        mock.expect_begin_transaction()
1208            .once()
1209            .in_sequence(&mut seq)
1210            .returning(move |req| {
1211                let req = req.into_inner();
1212                assert!(req.mutation_key.is_some());
1213
1214                let options = req.options.as_ref().expect("options required on retry");
1215                let read_write = options.mode.as_ref().expect("mode required on retry");
1216                match read_write {
1217                    Mode::ReadWrite(rw) => {
1218                        assert_eq!(rw.multiplexed_session_previous_transaction_id, vec![42], "previous_transaction_id should be set to the ID of the aborted transaction");
1219                    }
1220                    _ => panic!("Expected ReadWrite mode"),
1221                }
1222
1223                Ok(Response::new(Transaction {
1224                    id: vec![42],
1225                    ..Default::default()
1226                }))
1227            });
1228
1229        mock.expect_commit()
1230            .once()
1231            .in_sequence(&mut seq)
1232            .returning(move |_req| {
1233                Ok(Response::new(CommitResponse {
1234                    commit_timestamp: Some(Timestamp {
1235                        seconds: 8888,
1236                        nanos: 0,
1237                    }),
1238                    ..Default::default()
1239                }))
1240            });
1241
1242        let (db_client, _server) = setup_db_client(mock).await;
1243
1244        let mutation = Mutation::new_insert_or_update_builder("Users")
1245            .set("UserId")
1246            .to(&1)
1247            .build();
1248
1249        let res = db_client
1250            .write_only_transaction()
1251            .build()
1252            .write(vec![mutation])
1253            .await;
1254
1255        let res = res.expect("write should succeed");
1256        assert_eq!(
1257            res.commit_timestamp
1258                .expect("commit_timestamp should be present")
1259                .seconds(),
1260            8888,
1261            "expected commit timestamp to match"
1262        );
1263        Ok(())
1264    }
1265
1266    #[tokio_test_no_panics]
1267    async fn write_at_least_once_with_max_commit_delay() {
1268        let mut mock = spanner_grpc_mock::MockSpanner::new();
1269        mock.expect_create_session().returning(|_| {
1270            Ok(Response::new(Session {
1271                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1272                ..Default::default()
1273            }))
1274        });
1275
1276        mock.expect_commit().once().returning(|req| {
1277            let req = req.into_inner();
1278            assert_eq!(
1279                req.session,
1280                "projects/p/instances/i/databases/d/sessions/123"
1281            );
1282            assert_eq!(
1283                req.max_commit_delay,
1284                Some(ProstDuration {
1285                    seconds: 0,
1286                    nanos: 100_000_000, // 100ms
1287                })
1288            );
1289
1290            Ok(Response::new(CommitResponse {
1291                commit_timestamp: Some(Timestamp {
1292                    seconds: 1234,
1293                    nanos: 0,
1294                }),
1295                ..Default::default()
1296            }))
1297        });
1298
1299        let (db_client, _server) = setup_db_client(mock).await;
1300
1301        let mutation = Mutation::new_insert_or_update_builder("Users")
1302            .set("UserId")
1303            .to(&1)
1304            .build();
1305
1306        let res = db_client
1307            .write_only_transaction()
1308            .set_max_commit_delay(Duration::try_from("0.1s").unwrap())
1309            .build()
1310            .write_at_least_once(vec![mutation])
1311            .await;
1312
1313        assert!(res.is_ok());
1314    }
1315
1316    #[tokio_test_no_panics]
1317    async fn leader_aware_routing_enabled_by_default() {
1318        let mut mock = spanner_grpc_mock::MockSpanner::new();
1319        mock.expect_create_session().returning(|_| {
1320            Ok(Response::new(Session {
1321                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1322                ..Default::default()
1323            }))
1324        });
1325
1326        mock.expect_commit().once().returning(|req| {
1327            assert_eq!(
1328                req.metadata()
1329                    .get("x-goog-spanner-route-to-leader")
1330                    .expect("header required")
1331                    .to_str()
1332                    .unwrap(),
1333                "true"
1334            );
1335            Ok(Response::new(CommitResponse {
1336                commit_timestamp: Some(Timestamp {
1337                    seconds: 1234,
1338                    nanos: 0,
1339                }),
1340                ..Default::default()
1341            }))
1342        });
1343
1344        let (db_client, _server) = setup_db_client(mock).await;
1345        let mutation = Mutation::new_insert_or_update_builder("Users")
1346            .set("UserId")
1347            .to(&1)
1348            .build();
1349        let res = db_client
1350            .write_only_transaction()
1351            .build()
1352            .write_at_least_once(vec![mutation])
1353            .await;
1354        assert!(res.is_ok());
1355    }
1356
1357    #[tokio_test_no_panics]
1358    async fn write_only_transaction_builder_sets_gax_options() -> anyhow::Result<()> {
1359        let mut mock = spanner_grpc_mock::MockSpanner::new();
1360        mock.expect_create_session().returning(|_| {
1361            Ok(Response::new(Session {
1362                name: "session".to_string(),
1363                ..Default::default()
1364            }))
1365        });
1366        let (db_client, _server) = setup_db_client(mock).await;
1367
1368        let builder = db_client
1369            .write_only_transaction()
1370            .with_begin_attempt_timeout(StdDuration::from_secs(5))
1371            .with_begin_retry_policy(NeverRetry)
1372            .with_begin_backoff_policy(ExponentialBackoff::default())
1373            .with_commit_attempt_timeout(StdDuration::from_secs(10))
1374            .with_commit_retry_policy(NeverRetry)
1375            .with_commit_backoff_policy(ExponentialBackoff::default());
1376
1377        let begin_gax = &builder.begin_gax_options;
1378        assert_eq!(
1379            *begin_gax.attempt_timeout(),
1380            Some(StdDuration::from_secs(5))
1381        );
1382        assert!(begin_gax.retry_policy().is_some());
1383        assert!(begin_gax.backoff_policy().is_some());
1384
1385        let commit_gax = &builder.commit_gax_options;
1386        assert_eq!(
1387            *commit_gax.attempt_timeout(),
1388            Some(StdDuration::from_secs(10))
1389        );
1390        assert!(commit_gax.retry_policy().is_some());
1391        assert!(commit_gax.backoff_policy().is_some());
1392
1393        Ok(())
1394    }
1395
1396    fn parse_grpc_timeout(metadata: &gaxi::grpc::tonic::MetadataMap) -> Option<StdDuration> {
1397        let timeout_header = metadata.get("grpc-timeout")?.to_str().ok()?;
1398        let numeric_part: String = timeout_header
1399            .chars()
1400            .take_while(|c| c.is_ascii_digit())
1401            .collect();
1402        let value = numeric_part.parse::<u64>().ok()?;
1403        let unit = timeout_header.trim_start_matches(&numeric_part);
1404        let duration = match unit {
1405            "u" => StdDuration::from_micros(value),
1406            "m" => StdDuration::from_millis(value),
1407            "S" => StdDuration::from_secs(value),
1408            "M" => StdDuration::from_secs(value * 60),
1409            "H" => StdDuration::from_secs(value * 3600),
1410            _ => return None,
1411        };
1412        Some(duration)
1413    }
1414
1415    #[tokio_test_no_panics]
1416    async fn write_only_transaction_with_custom_options() -> anyhow::Result<()> {
1417        let mut mock = spanner_grpc_mock::MockSpanner::new();
1418        mock.expect_create_session().returning(|_| {
1419            Ok(Response::new(Session {
1420                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1421                ..Default::default()
1422            }))
1423        });
1424
1425        let mut seq = mockall::Sequence::new();
1426
1427        mock.expect_begin_transaction()
1428            .once()
1429            .in_sequence(&mut seq)
1430            .withf(|req| {
1431                let duration =
1432                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
1433                assert_eq!(duration, StdDuration::from_secs(5));
1434                true
1435            })
1436            .returning(|_| {
1437                Ok(Response::new(Transaction {
1438                    id: vec![42],
1439                    ..Default::default()
1440                }))
1441            });
1442
1443        mock.expect_commit()
1444            .once()
1445            .in_sequence(&mut seq)
1446            .withf(|req| {
1447                let duration =
1448                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
1449                assert_eq!(duration, StdDuration::from_secs(10));
1450                true
1451            })
1452            .returning(|_| {
1453                Ok(Response::new(CommitResponse {
1454                    commit_timestamp: Some(Timestamp {
1455                        seconds: 8888,
1456                        nanos: 0,
1457                    }),
1458                    ..Default::default()
1459                }))
1460            });
1461
1462        let (db_client, _server) = setup_db_client(mock).await;
1463
1464        let mutation = Mutation::new_insert_or_update_builder("Users")
1465            .set("UserId")
1466            .to(&1)
1467            .build();
1468
1469        let res = db_client
1470            .write_only_transaction()
1471            .with_begin_attempt_timeout(StdDuration::from_secs(5))
1472            .with_commit_attempt_timeout(StdDuration::from_secs(10))
1473            .build()
1474            .write(vec![mutation])
1475            .await?;
1476
1477        assert_eq!(
1478            res.commit_timestamp
1479                .expect("commit_timestamp should be present")
1480                .seconds(),
1481            8888
1482        );
1483        Ok(())
1484    }
1485
1486    #[tokio_test_no_panics]
1487    async fn write_at_least_once_with_custom_commit_options() -> anyhow::Result<()> {
1488        let mut mock = spanner_grpc_mock::MockSpanner::new();
1489        mock.expect_create_session().returning(|_| {
1490            Ok(Response::new(Session {
1491                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1492                ..Default::default()
1493            }))
1494        });
1495
1496        mock.expect_begin_transaction().never();
1497
1498        mock.expect_commit()
1499            .once()
1500            .withf(|req| {
1501                let duration =
1502                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
1503                assert_eq!(duration, StdDuration::from_secs(7));
1504                true
1505            })
1506            .returning(|_| {
1507                Ok(Response::new(CommitResponse {
1508                    commit_timestamp: Some(Timestamp {
1509                        seconds: 7777,
1510                        nanos: 0,
1511                    }),
1512                    ..Default::default()
1513                }))
1514            });
1515
1516        let (db_client, _server) = setup_db_client(mock).await;
1517
1518        let mutation = Mutation::new_insert_or_update_builder("Users")
1519            .set("UserId")
1520            .to(&1)
1521            .build();
1522
1523        let res = db_client
1524            .write_only_transaction()
1525            .with_commit_attempt_timeout(StdDuration::from_secs(7))
1526            .build()
1527            .write_at_least_once(vec![mutation])
1528            .await?;
1529
1530        assert_eq!(
1531            res.commit_timestamp
1532                .expect("commit_timestamp should be present")
1533                .seconds(),
1534            7777
1535        );
1536        Ok(())
1537    }
1538}