Skip to main content

google_cloud_spanner/
write_only_transaction.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::client::{DatabaseClient, amend_request_options_for_lar};
16use crate::model::request_options::Priority;
17use crate::model::transaction_options::ReadWrite;
18use crate::model::{
19    BeginTransactionRequest, CommitRequest, CommitResponse, MultiplexedSessionPrecommitToken,
20    Mutation as ProtoMutation, RequestOptions, TransactionOptions,
21};
22use crate::mutation::Mutation;
23use crate::transaction_retry_policy::{
24    BasicTransactionRetryPolicy, TransactionRetryPolicy, retry_aborted,
25};
26use bytes::Bytes;
27use google_cloud_gax::backoff_policy::BackoffPolicyArg;
28use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
29use google_cloud_gax::retry_policy::RetryPolicyArg;
30use std::sync::{Arc, Mutex};
31use wkt::Duration;
32
33/// A builder for [WriteOnlyTransaction].
34pub struct WriteOnlyTransactionBuilder {
35    client: DatabaseClient,
36    transaction_tag: Option<String>,
37    max_commit_delay: Option<Duration>,
38    retry_policy: Box<dyn TransactionRetryPolicy>,
39    exclude_txn_from_change_streams: bool,
40    return_commit_stats: bool,
41    commit_priority: Priority,
42    begin_gax_options: GaxRequestOptions,
43    commit_gax_options: GaxRequestOptions,
44}
45
46impl WriteOnlyTransactionBuilder {
47    pub(crate) fn new(client: DatabaseClient) -> Self {
48        Self {
49            client,
50            transaction_tag: None,
51            max_commit_delay: None,
52            retry_policy: Box::new(BasicTransactionRetryPolicy::default()),
53            exclude_txn_from_change_streams: false,
54            return_commit_stats: false,
55            commit_priority: Priority::Unspecified,
56            begin_gax_options: GaxRequestOptions::default(),
57            commit_gax_options: GaxRequestOptions::default(),
58        }
59    }
60
61    /// Sets a transaction tag to be used for the transaction.
62    ///
63    /// # Example
64    /// ```
65    /// # use google_cloud_spanner::client::Spanner;
66    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
67    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
68    /// let transaction = db_client.write_only_transaction()
69    ///     .set_transaction_tag("my-tag")
70    ///     .build();
71    /// # Ok(())
72    /// # }
73    /// ```
74    ///
75    /// See also: [Troubleshooting with tags](https://docs.cloud.google.com/spanner/docs/introspection/troubleshooting-with-tags)
76    pub fn set_transaction_tag(mut self, tag: impl Into<String>) -> Self {
77        self.transaction_tag = Some(tag.into());
78        self
79    }
80
81    /// Sets the RPC priority to use for the commit of this transaction.
82    ///
83    /// # Example
84    /// ```
85    /// # use google_cloud_spanner::client::Spanner;
86    /// # use google_cloud_spanner::model::request_options::Priority;
87    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
88    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
89    /// let transaction = db_client.write_only_transaction()
90    ///     .set_commit_priority(Priority::Low)
91    ///     .build();
92    /// # Ok(())
93    /// # }
94    /// ```
95    pub fn set_commit_priority(mut self, priority: Priority) -> Self {
96        self.commit_priority = priority;
97        self
98    }
99
100    /// Sets the maximum commit delay for the transaction.
101    ///
102    /// # Example
103    /// ```
104    /// # use google_cloud_spanner::client::Spanner;
105    /// # use wkt::Duration;
106    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
107    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
108    /// let transaction = db_client.write_only_transaction()
109    ///     .set_max_commit_delay(Duration::try_from("0.1s").unwrap())
110    ///     .build();
111    /// # Ok(())
112    /// # }
113    /// ```
114    ///
115    /// This option allows you to specify the maximum amount of time Spanner can
116    /// adjust the commit timestamp of the transaction to allow for commit batching.
117    /// Increasing this value can increase throughput at the expense of latency.
118    /// The value must be between 0 and 500 milliseconds. If not set, or set to 0,
119    /// Spanner does not delay the commit.
120    pub fn set_max_commit_delay(mut self, delay: Duration) -> Self {
121        self.max_commit_delay = Some(delay);
122        self
123    }
124
125    /// Sets whether to exclude the transaction from change streams.
126    ///
127    /// # Example
128    /// ```
129    /// # use google_cloud_spanner::client::Spanner;
130    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
131    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
132    /// let transaction = db_client.write_only_transaction()
133    ///     .set_exclude_txn_from_change_streams(true)
134    ///     .build();
135    /// # Ok(())
136    /// # }
137    /// ```
138    ///
139    /// When set to `true`, it prevents modifications from this transaction from being tracked in change streams.
140    /// Note that this only affects change streams that have been created with the DDL option `allow_txn_exclusion = true`.
141    /// If `allow_txn_exclusion` is not set or set to `false` for a change stream, updates made within this transaction
142    /// are recorded in that change stream regardless of this setting.
143    ///
144    /// When set to `false` or not specified, modifications from this transaction are recorded in all change streams
145    /// tracking columns modified by this transaction.
146    pub fn set_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
147        self.exclude_txn_from_change_streams = exclude;
148        self
149    }
150
151    /// Sets whether to return commit stats for the transaction.
152    ///
153    /// # Example
154    /// ```
155    /// # use google_cloud_spanner::mutation::Mutation;
156    /// # use google_cloud_spanner::client::Spanner;
157    /// # async fn test_doc() -> Result<(), Box<dyn std::error::Error>> {
158    /// # let client = Spanner::builder().build().await?;
159    /// # let db = client.database_client("projects/p/instances/i/databases/d").build().await?;
160    /// let mutation = Mutation::new_insert_builder("Users")
161    ///     .set("UserId").to(&1)
162    ///     .build();
163    ///
164    /// let response = db.write_only_transaction()
165    ///     .set_return_commit_stats(true)
166    ///     .build()
167    ///     .write(vec![mutation])
168    ///     .await?;
169    ///
170    /// if let Some(stats) = response.commit_stats {
171    ///     println!("Mutation count: {}", stats.mutation_count);
172    /// }
173    /// # Ok(())
174    /// # }
175    /// ```
176    ///
177    /// See also: <https://docs.cloud.google.com/spanner/docs/commit-statistics>
178    pub fn set_return_commit_stats(mut self, return_stats: bool) -> Self {
179        self.return_commit_stats = return_stats;
180        self
181    }
182
183    /// Sets the retry policy for the transaction.
184    ///
185    /// # Example
186    /// ```
187    /// # use std::time::Duration;
188    /// # use google_cloud_spanner::client::Spanner;
189    /// # use google_cloud_spanner::transaction::BasicTransactionRetryPolicy;
190    /// # async fn build_tx(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
191    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
192    ///
193    /// let retry_policy = BasicTransactionRetryPolicy::new()
194    ///     .with_max_attempts(5)
195    ///     .with_total_timeout(Duration::from_secs(60));
196    ///
197    /// let transaction = db_client.write_only_transaction()
198    ///     .with_retry_policy(retry_policy)
199    ///     .build();
200    /// # Ok(())
201    /// # }
202    /// ```
203    ///
204    /// The client will retry the transaction if it is aborted by Spanner.
205    /// This policy can be used to customize whether a transaction should be retried
206    /// or not. The default is to retry indefinitely until the transaction succeeds.
207    pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(mut self, policy: P) -> Self {
208        self.retry_policy = Box::new(policy);
209        self
210    }
211
212    /// Sets the per-attempt timeout for the BeginTransaction RPC.
213    ///
214    /// # Example
215    /// ```
216    /// # use google_cloud_spanner::client::Spanner;
217    /// # use std::time::Duration;
218    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
219    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
220    /// let transaction = db_client.write_only_transaction()
221    ///     .with_begin_attempt_timeout(Duration::from_secs(5))
222    ///     .build();
223    /// # Ok(())
224    /// # }
225    /// ```
226    pub fn with_begin_attempt_timeout(mut self, timeout: std::time::Duration) -> Self {
227        self.begin_gax_options.set_attempt_timeout(timeout);
228        self
229    }
230
231    /// Sets the retry policy for the BeginTransaction RPC.
232    ///
233    /// # Example
234    /// ```
235    /// # use google_cloud_spanner::client::Spanner;
236    /// # use google_cloud_gax::retry_policy::NeverRetry;
237    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
238    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
239    /// let transaction = db_client.write_only_transaction()
240    ///     .with_begin_retry_policy(NeverRetry)
241    ///     .build();
242    /// # Ok(())
243    /// # }
244    /// ```
245    pub fn with_begin_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
246        self.begin_gax_options.set_retry_policy(policy);
247        self
248    }
249
250    /// Sets the backoff policy for the BeginTransaction RPC.
251    ///
252    /// # Example
253    /// ```
254    /// # use google_cloud_spanner::client::Spanner;
255    /// # use google_cloud_gax::exponential_backoff::ExponentialBackoff;
256    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
257    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
258    /// let transaction = db_client.write_only_transaction()
259    ///     .with_begin_backoff_policy(ExponentialBackoff::default())
260    ///     .build();
261    /// # Ok(())
262    /// # }
263    /// ```
264    pub fn with_begin_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
265        self.begin_gax_options.set_backoff_policy(policy);
266        self
267    }
268
269    /// Sets the per-attempt timeout for the Commit RPC.
270    ///
271    /// # Example
272    /// ```
273    /// # use google_cloud_spanner::client::Spanner;
274    /// # use std::time::Duration;
275    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
276    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
277    /// let transaction = db_client.write_only_transaction()
278    ///     .with_commit_attempt_timeout(Duration::from_secs(5))
279    ///     .build();
280    /// # Ok(())
281    /// # }
282    /// ```
283    pub fn with_commit_attempt_timeout(mut self, timeout: std::time::Duration) -> Self {
284        self.commit_gax_options.set_attempt_timeout(timeout);
285        self
286    }
287
288    /// Sets the retry policy for the Commit RPC.
289    ///
290    /// # Example
291    /// ```
292    /// # use google_cloud_spanner::client::Spanner;
293    /// # use google_cloud_gax::retry_policy::NeverRetry;
294    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
295    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
296    /// let transaction = db_client.write_only_transaction()
297    ///     .with_commit_retry_policy(NeverRetry)
298    ///     .build();
299    /// # Ok(())
300    /// # }
301    /// ```
302    pub fn with_commit_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
303        self.commit_gax_options.set_retry_policy(policy);
304        self
305    }
306
307    /// Sets the backoff policy for the Commit RPC.
308    ///
309    /// # Example
310    /// ```
311    /// # use google_cloud_spanner::client::Spanner;
312    /// # use google_cloud_gax::exponential_backoff::ExponentialBackoff;
313    /// # async fn sample(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
314    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
315    /// let transaction = db_client.write_only_transaction()
316    ///     .with_commit_backoff_policy(ExponentialBackoff::default())
317    ///     .build();
318    /// # Ok(())
319    /// # }
320    /// ```
321    pub fn with_commit_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
322        self.commit_gax_options.set_backoff_policy(policy);
323        self
324    }
325
326    /// Builds the [WriteOnlyTransaction].
327    ///
328    /// # Example
329    /// ```
330    /// # use google_cloud_spanner::client::Spanner;
331    /// # async fn build(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
332    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
333    /// let transaction = db_client.write_only_transaction().build();
334    /// # Ok(())
335    /// # }
336    /// ```
337    pub fn build(self) -> WriteOnlyTransaction {
338        let session_name = self.client.session_name();
339        WriteOnlyTransaction {
340            session_name,
341            client: self.client,
342            transaction_tag: self.transaction_tag,
343            max_commit_delay: self.max_commit_delay,
344            retry_policy: self.retry_policy,
345            exclude_txn_from_change_streams: self.exclude_txn_from_change_streams,
346            return_commit_stats: self.return_commit_stats,
347            commit_priority: self.commit_priority,
348            begin_gax_options: self.begin_gax_options,
349            commit_gax_options: self.commit_gax_options,
350        }
351    }
352}
353
354/// A write-only transaction.
355///
356/// A write-only transaction can be used to execute blind writes.
357pub struct WriteOnlyTransaction {
358    pub(crate) session_name: String,
359    client: DatabaseClient,
360    transaction_tag: Option<String>,
361    max_commit_delay: Option<Duration>,
362    retry_policy: Box<dyn TransactionRetryPolicy>,
363    exclude_txn_from_change_streams: bool,
364    return_commit_stats: bool,
365    commit_priority: Priority,
366    begin_gax_options: GaxRequestOptions,
367    commit_gax_options: GaxRequestOptions,
368}
369
370impl WriteOnlyTransaction {
371    /// Writes a set of mutations atomically to Spanner.
372    ///
373    /// # Example
374    /// ```
375    /// # use google_cloud_spanner::mutation::Mutation;
376    /// # use google_cloud_spanner::client::Spanner;
377    /// # async fn test_doc() -> Result<(), Box<dyn std::error::Error>> {
378    /// let client = Spanner::builder().build().await?;
379    /// let db = client.database_client("projects/p/instances/i/databases/d").build().await?;
380    ///
381    /// let mutation = Mutation::new_insert_builder("Users")
382    ///     .set("UserId").to(&1)
383    ///     .set("UserName").to(&"Alice")
384    ///     .build();
385    ///
386    /// let response = db.write_only_transaction()
387    ///     .set_transaction_tag("my-tag")
388    ///     .build()
389    ///     .write(vec![mutation])
390    ///     .await?;
391    /// # Ok(())
392    /// # }
393    /// ```
394    ///
395    /// This method uses retries and replay protection internally, which means that the mutations
396    /// are applied exactly once on success, or not at all if an error is returned, regardless of
397    /// any failures in the underlying network. Note that if the call is cancelled or reaches
398    /// deadline, it is not possible to know whether the mutations were applied without performing
399    /// a subsequent database operation, but the mutations will have been applied at most once.
400    pub async fn write<I>(self, mutations: I) -> crate::Result<CommitResponse>
401    where
402        I: IntoIterator<Item = Mutation>,
403    {
404        let begin_gax_options = self.begin_gax_options();
405        let commit_gax_options = self.commit_gax_options();
406        let req_options = RequestOptions::default()
407            .set_transaction_tag(self.transaction_tag.unwrap_or_default())
408            .set_priority(self.commit_priority.clone());
409
410        let mutations_proto: Vec<_> = mutations.into_iter().map(|m| m.build_proto()).collect();
411        let mutation_key = Mutation::select_mutation_key(&mutations_proto);
412        let client = self.client;
413        let session_name = self.session_name.clone();
414        let previous_transaction_id = Arc::new(Mutex::new(Bytes::new()));
415        let channel_hint = client.spanner.next_channel_hint();
416
417        let max_commit_delay = self.max_commit_delay;
418        let return_commit_stats = self.return_commit_stats;
419        let is_emulator = client.is_emulator();
420
421        let action = || {
422            let client = client.clone();
423            let session_name = session_name.clone();
424            let req_options = req_options.clone();
425            let mutations_proto = mutations_proto.clone();
426            let mutation_key = mutation_key.clone();
427            let previous_transaction_id = previous_transaction_id.clone();
428            let begin_gax_options = begin_gax_options.clone();
429            let commit_gax_options = commit_gax_options.clone();
430
431            async move {
432                let previous_id: Bytes = previous_transaction_id.lock().unwrap().clone();
433
434                let begin_req = BeginTransactionRequest::default()
435                    .set_session(session_name.clone())
436                    .set_options(
437                        TransactionOptions::default()
438                            .set_read_write(Box::new(
439                                ReadWrite::default()
440                                    .set_multiplexed_session_previous_transaction_id(previous_id),
441                            ))
442                            .set_exclude_txn_from_change_streams(
443                                self.exclude_txn_from_change_streams,
444                            ),
445                    )
446                    .set_request_options(req_options.clone())
447                    .set_or_clear_mutation_key(mutation_key.clone());
448
449                let tx = client
450                    .spanner
451                    .begin_transaction(begin_req, begin_gax_options, channel_hint)
452                    .await?;
453                *previous_transaction_id.lock().unwrap() = tx.id.clone();
454
455                let commit_req = create_commit_request(
456                    session_name.clone(),
457                    tx.id.clone(),
458                    mutations_proto,
459                    tx.precommit_token,
460                    Some(req_options.clone()),
461                    max_commit_delay,
462                    return_commit_stats,
463                );
464
465                let response = client
466                    .spanner
467                    .commit(commit_req, commit_gax_options.clone(), channel_hint)
468                    .await?;
469
470                // If a commit_response with a precommit_token is returned, then we need to
471                // retry the commit with the new precommit_token and without any mutations.
472                if let Some(new_token) = response.precommit_token().map(|b| *b.clone()) {
473                    let retry_commit_req = create_commit_request(
474                        session_name.clone(),
475                        tx.id,
476                        Vec::new(),
477                        Some(new_token),
478                        Some(req_options),
479                        max_commit_delay,
480                        return_commit_stats,
481                    );
482                    client
483                        .spanner
484                        .commit(retry_commit_req, commit_gax_options, channel_hint)
485                        .await
486                } else {
487                    Ok(response)
488                }
489            }
490        };
491
492        retry_aborted(&*self.retry_policy, action, is_emulator).await
493    }
494
495    /// Writes a set of mutations at least once using a single Commit RPC.
496    ///
497    /// # Example
498    /// ```
499    /// # use google_cloud_spanner::mutation::Mutation;
500    /// # use google_cloud_spanner::client::Spanner;
501    /// # async fn test_doc() -> Result<(), Box<dyn std::error::Error>> {
502    /// let client = Spanner::builder().build().await?;
503    /// let db = client.database_client("projects/p/instances/i/databases/d").build().await?;
504    ///
505    /// let mutation = Mutation::new_insert_or_update_builder("Users")
506    ///     .set("UserId").to(&1)
507    ///     .set("UserName").to(&"Alice")
508    ///     .build();
509    ///
510    /// let response = db.write_only_transaction()
511    ///     .set_transaction_tag("my-tag")
512    ///     .build()
513    ///     .write_at_least_once(vec![mutation])
514    ///     .await?;
515    /// # Ok(())
516    /// # }
517    /// ```
518    ///
519    /// Since this method does not feature replay protection, it may attempt to apply the provided
520    /// mutations more than once. If the mutations are not idempotent, this may lead to a failure
521    /// being reported even if the mutation was applied successfully the first time. For example,
522    /// an insert may fail with an `AlreadyExists` error even though the row did not exist before
523    /// this method was called. For this reason, most users of the library will prefer to use write
524    /// transactions with replay protection instead.
525    /// However, `write_at_least_once` requires only a single RPC, whereas replay-protected
526    /// writes require two RPCs. Thus, this method may be appropriate for latency sensitive
527    /// and/or high throughput blind writing.
528    pub async fn write_at_least_once<I>(self, mutations: I) -> crate::Result<CommitResponse>
529    where
530        I: IntoIterator<Item = Mutation>,
531    {
532        let commit_gax_options = self.commit_gax_options();
533        let single_use = TransactionOptions::new()
534            .set_read_write(Box::new(ReadWrite::new()))
535            .set_exclude_txn_from_change_streams(self.exclude_txn_from_change_streams);
536        let req_options = RequestOptions::default()
537            .set_transaction_tag(self.transaction_tag.unwrap_or_default())
538            .set_priority(self.commit_priority.clone());
539        let request = CommitRequest::new()
540            .set_session(self.session_name.clone())
541            .set_mutations(mutations.into_iter().map(|m| m.build_proto()))
542            .set_single_use_transaction(Box::new(single_use))
543            .set_request_options(req_options)
544            .set_or_clear_max_commit_delay(self.max_commit_delay)
545            .set_return_commit_stats(self.return_commit_stats);
546        let client = self.client;
547        let channel_hint = client.spanner.next_channel_hint();
548        let is_emulator = client.is_emulator();
549
550        let action = || {
551            let client = client.clone();
552            let request = request.clone();
553            let commit_gax_options = commit_gax_options.clone();
554
555            async move {
556                client
557                    .spanner
558                    .commit(request, commit_gax_options, channel_hint)
559                    .await
560            }
561        };
562
563        retry_aborted(&*self.retry_policy, action, is_emulator).await
564    }
565
566    fn begin_gax_options(&self) -> GaxRequestOptions {
567        amend_request_options_for_lar(
568            self.client.leader_aware_routing_enabled,
569            self.begin_gax_options.clone(),
570        )
571    }
572
573    fn commit_gax_options(&self) -> GaxRequestOptions {
574        amend_request_options_for_lar(
575            self.client.leader_aware_routing_enabled,
576            self.commit_gax_options.clone(),
577        )
578    }
579}
580
581pub(crate) fn create_commit_request(
582    session_name: String,
583    transaction_id: bytes::Bytes,
584    mutations: Vec<ProtoMutation>,
585    precommit_token: Option<MultiplexedSessionPrecommitToken>,
586    request_options: Option<RequestOptions>,
587    max_commit_delay: Option<Duration>,
588    return_commit_stats: bool,
589) -> CommitRequest {
590    CommitRequest::default()
591        .set_session(session_name)
592        .set_transaction_id(transaction_id)
593        .set_mutations(mutations)
594        .set_or_clear_precommit_token(precommit_token)
595        .set_or_clear_request_options(request_options)
596        .set_or_clear_max_commit_delay(max_commit_delay)
597        .set_return_commit_stats(return_commit_stats)
598}
599
600#[cfg(test)]
601mod tests {
602    use super::*;
603    use crate::client::Spanner;
604    use crate::transaction_retry_policy::tests::create_aborted_status;
605    use gaxi::grpc::tonic::Response;
606    use google_cloud_gax::exponential_backoff::ExponentialBackoff;
607    use google_cloud_gax::retry_policy::NeverRetry;
608    use google_cloud_test_macros::tokio_test_no_panics;
609    use prost_types::Duration as ProstDuration;
610    use prost_types::Timestamp;
611    use spanner_grpc_mock::google::spanner::v1::CommitResponse;
612    use spanner_grpc_mock::google::spanner::v1::Session;
613    use spanner_grpc_mock::google::spanner::v1::Transaction;
614    use spanner_grpc_mock::google::spanner::v1::commit_response::CommitStats;
615    use spanner_grpc_mock::google::spanner::v1::transaction_options::Mode;
616    use std::time::Duration as StdDuration;
617    use wkt::Duration;
618
619    pub(crate) async fn setup_db_client(
620        mock: spanner_grpc_mock::MockSpanner,
621    ) -> (DatabaseClient, tokio::task::JoinHandle<()>) {
622        use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
623        let (address, server) = spanner_grpc_mock::start("0.0.0.0:0", mock)
624            .await
625            .expect("Failed to start mock server");
626        let spanner = Spanner::builder()
627            .with_endpoint(address)
628            .with_credentials(Anonymous::new().build())
629            .build()
630            .await
631            .expect("Failed to build client");
632
633        let db_client = spanner
634            .database_client("projects/p/instances/i/databases/d")
635            .build()
636            .await
637            .expect("Failed to create DatabaseClient");
638
639        (db_client, server)
640    }
641
642    #[tokio_test_no_panics]
643    async fn write_at_least_once() {
644        let mut mock = spanner_grpc_mock::MockSpanner::new();
645        mock.expect_create_session().returning(|_| {
646            Ok(gaxi::grpc::tonic::Response::new(
647                spanner_grpc_mock::google::spanner::v1::Session {
648                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
649                    ..Default::default()
650                },
651            ))
652        });
653
654        mock.expect_commit().once().returning(|req| {
655            let req = req.into_inner();
656            assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
657
658            // Validate the custom request options contain the transaction tag and priority
659            assert!(req.request_options.is_some());
660            let req_opts = req.request_options.as_ref().expect("request_options should be present");
661            assert_eq!(req_opts.transaction_tag, "my_tag");
662            assert_eq!(Priority::from(req_opts.priority), Priority::High);
663
664            assert!(req.mutations.len() == 1);
665
666            // Validate it's a single-use transaction configured correctly
667            match req.transaction {
668                Some(spanner_grpc_mock::google::spanner::v1::commit_request::Transaction::SingleUseTransaction(opts)) => {
669                    assert!(opts.mode.is_some());
670                }
671                _ => panic!("Expected SingleUseTransaction"),
672            }
673
674            Ok(gaxi::grpc::tonic::Response::new(
675                spanner_grpc_mock::google::spanner::v1::CommitResponse {
676                    commit_timestamp: Some(prost_types::Timestamp {
677                        seconds: 1234,
678                        nanos: 0,
679                    }),
680                    ..Default::default()
681                },
682            ))
683        });
684
685        let (db_client, _server) = setup_db_client(mock).await;
686
687        let mutation = Mutation::new_insert_or_update_builder("Users")
688            .set("UserId")
689            .to(&1)
690            .build();
691
692        let res = db_client
693            .write_only_transaction()
694            .set_transaction_tag("my_tag")
695            .set_commit_priority(Priority::High)
696            .build()
697            .write_at_least_once(vec![mutation])
698            .await;
699
700        assert!(res.is_ok());
701        let res = res.expect("write_at_least_once should succeed");
702        assert!(res.commit_timestamp.is_some());
703        assert_eq!(
704            res.commit_timestamp
705                .expect("commit_timestamp should be present")
706                .seconds(),
707            1234
708        );
709    }
710
711    #[tokio_test_no_panics]
712    async fn write() {
713        let mut mock = spanner_grpc_mock::MockSpanner::new();
714        mock.expect_create_session().returning(|_| {
715            Ok(gaxi::grpc::tonic::Response::new(
716                spanner_grpc_mock::google::spanner::v1::Session {
717                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
718                    ..Default::default()
719                },
720            ))
721        });
722
723        mock.expect_begin_transaction().once().returning(|req| {
724            let req = req.into_inner();
725            assert_eq!(
726                req.session,
727                "projects/p/instances/i/databases/d/sessions/123"
728            );
729            assert!(req.options.is_some());
730            assert!(req.mutation_key.is_some());
731
732            Ok(gaxi::grpc::tonic::Response::new(
733                spanner_grpc_mock::google::spanner::v1::Transaction {
734                    id: vec![42],
735                    precommit_token: Some(
736                        spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken {
737                            precommit_token: vec![1, 2, 3],
738                            seq_num: 1,
739                        },
740                    ),
741                    ..Default::default()
742                },
743            ))
744        });
745
746        mock.expect_commit().once().returning(|req| {
747            let req = req.into_inner();
748            assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
749            assert_eq!(
750                req.precommit_token.expect("precommit_token required").precommit_token,
751                vec![1, 2, 3]
752            );
753
754            // Validate that we pass down the transaction ID from BeginTransaction.
755            match req.transaction {
756                Some(spanner_grpc_mock::google::spanner::v1::commit_request::Transaction::TransactionId(tid)) => {
757                    assert_eq!(tid, vec![42]);
758                }
759                _ => panic!("Expected TransactionId"),
760            }
761
762            Ok(gaxi::grpc::tonic::Response::new(
763                spanner_grpc_mock::google::spanner::v1::CommitResponse {
764                    commit_timestamp: Some(prost_types::Timestamp {
765                        seconds: 5678,
766                        nanos: 0,
767                    }),
768                    ..Default::default()
769                },
770            ))
771        });
772
773        let (db_client, _server) = setup_db_client(mock).await;
774
775        let mutation = Mutation::new_insert_or_update_builder("Users")
776            .set("UserId")
777            .to(&1)
778            .build();
779
780        let res = db_client
781            .write_only_transaction()
782            .build()
783            .write(vec![mutation])
784            .await;
785
786        assert!(res.is_ok());
787        let res = res.expect("write should succeed");
788        assert!(res.commit_timestamp.is_some());
789        assert_eq!(
790            res.commit_timestamp
791                .expect("commit_timestamp should be present")
792                .seconds(),
793            5678
794        );
795    }
796
797    #[tokio_test_no_panics]
798    async fn write_at_least_once_with_commit_stats() -> anyhow::Result<()> {
799        let mut mock = spanner_grpc_mock::MockSpanner::new();
800        mock.expect_create_session().returning(|_| {
801            Ok(Response::new(Session {
802                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
803                ..Default::default()
804            }))
805        });
806
807        mock.expect_commit().once().returning(|req| {
808            let req = req.into_inner();
809            assert!(req.return_commit_stats);
810
811            Ok(Response::new(CommitResponse {
812                commit_timestamp: Some(prost_types::Timestamp {
813                    seconds: 1234,
814                    nanos: 0,
815                }),
816                commit_stats: Some(CommitStats { mutation_count: 5 }),
817                ..Default::default()
818            }))
819        });
820
821        let (db_client, _server) = setup_db_client(mock).await;
822
823        let mutation = Mutation::new_insert_or_update_builder("Users")
824            .set("UserId")
825            .to(&1)
826            .build();
827
828        let res = db_client
829            .write_only_transaction()
830            .set_return_commit_stats(true)
831            .build()
832            .write_at_least_once(vec![mutation])
833            .await?;
834
835        let stats = res.commit_stats.expect("Commit stats should be present");
836        assert_eq!(stats.mutation_count, 5);
837        Ok(())
838    }
839
840    #[tokio_test_no_panics]
841    async fn write_with_commit_stats() -> anyhow::Result<()> {
842        let mut mock = spanner_grpc_mock::MockSpanner::new();
843        mock.expect_create_session().returning(|_| {
844            Ok(Response::new(Session {
845                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
846                ..Default::default()
847            }))
848        });
849
850        mock.expect_begin_transaction().once().returning(|_| {
851            Ok(Response::new(Transaction {
852                id: vec![42],
853                ..Default::default()
854            }))
855        });
856
857        mock.expect_commit().once().returning(|req| {
858            let req = req.into_inner();
859            assert!(req.return_commit_stats);
860
861            Ok(Response::new(CommitResponse {
862                commit_timestamp: Some(prost_types::Timestamp {
863                    seconds: 5678,
864                    nanos: 0,
865                }),
866                commit_stats: Some(CommitStats { mutation_count: 10 }),
867                ..Default::default()
868            }))
869        });
870
871        let (db_client, _server) = setup_db_client(mock).await;
872
873        let mutation = Mutation::new_insert_or_update_builder("Users")
874            .set("UserId")
875            .to(&1)
876            .build();
877
878        let res = db_client
879            .write_only_transaction()
880            .set_return_commit_stats(true)
881            .build()
882            .write(vec![mutation])
883            .await?;
884
885        let stats = res.commit_stats.expect("Commit stats should be present");
886        assert_eq!(stats.mutation_count, 10);
887        Ok(())
888    }
889
890    #[tokio_test_no_panics]
891    async fn write_at_least_once_with_exclude_txn_from_change_streams() {
892        let mut mock = spanner_grpc_mock::MockSpanner::new();
893        mock.expect_create_session().returning(|_| {
894            Ok(gaxi::grpc::tonic::Response::new(
895                spanner_grpc_mock::google::spanner::v1::Session {
896                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
897                    ..Default::default()
898                },
899            ))
900        });
901
902        mock.expect_commit().once().returning(|req| {
903            let req = req.into_inner();
904            match req.transaction {
905                Some(spanner_grpc_mock::google::spanner::v1::commit_request::Transaction::SingleUseTransaction(opts)) => {
906                    assert!(opts.exclude_txn_from_change_streams);
907                }
908                _ => panic!("Expected SingleUseTransaction"),
909            }
910
911            Ok(gaxi::grpc::tonic::Response::new(
912                spanner_grpc_mock::google::spanner::v1::CommitResponse {
913                    commit_timestamp: Some(prost_types::Timestamp {
914                        seconds: 1234,
915                        nanos: 0,
916                    }),
917                    ..Default::default()
918                },
919            ))
920        });
921
922        let (db_client, _server) = setup_db_client(mock).await;
923
924        let mutation = Mutation::new_insert_or_update_builder("Users")
925            .set("UserId")
926            .to(&1)
927            .build();
928
929        let res = db_client
930            .write_only_transaction()
931            .set_exclude_txn_from_change_streams(true)
932            .build()
933            .write_at_least_once(vec![mutation])
934            .await;
935
936        assert!(res.is_ok());
937    }
938
939    #[tokio_test_no_panics]
940    async fn write_with_exclude_txn_from_change_streams() {
941        let mut mock = spanner_grpc_mock::MockSpanner::new();
942        mock.expect_create_session().returning(|_| {
943            Ok(gaxi::grpc::tonic::Response::new(
944                spanner_grpc_mock::google::spanner::v1::Session {
945                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
946                    ..Default::default()
947                },
948            ))
949        });
950
951        mock.expect_begin_transaction().once().returning(|req| {
952            let req = req.into_inner();
953            let options = req.options.expect("Missing transaction options");
954            assert!(options.exclude_txn_from_change_streams);
955
956            Ok(gaxi::grpc::tonic::Response::new(
957                spanner_grpc_mock::google::spanner::v1::Transaction {
958                    id: vec![42],
959                    ..Default::default()
960                },
961            ))
962        });
963
964        mock.expect_commit().once().returning(|_req| {
965            Ok(gaxi::grpc::tonic::Response::new(
966                spanner_grpc_mock::google::spanner::v1::CommitResponse {
967                    commit_timestamp: Some(prost_types::Timestamp {
968                        seconds: 5678,
969                        nanos: 0,
970                    }),
971                    ..Default::default()
972                },
973            ))
974        });
975
976        let (db_client, _server) = setup_db_client(mock).await;
977
978        let mutation = Mutation::new_insert_or_update_builder("Users")
979            .set("UserId")
980            .to(&1)
981            .build();
982
983        let res = db_client
984            .write_only_transaction()
985            .set_exclude_txn_from_change_streams(true)
986            .build()
987            .write(vec![mutation])
988            .await;
989
990        assert!(res.is_ok());
991    }
992
993    #[tokio_test_no_panics]
994    async fn write_with_commit_retry() {
995        let mut mock = spanner_grpc_mock::MockSpanner::new();
996        mock.expect_create_session().returning(|_| {
997            Ok(gaxi::grpc::tonic::Response::new(
998                spanner_grpc_mock::google::spanner::v1::Session {
999                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1000                    ..Default::default()
1001                },
1002            ))
1003        });
1004
1005        mock.expect_begin_transaction().once().returning(|req| {
1006            let req = req.into_inner();
1007            assert!(req.mutation_key.is_some());
1008
1009            Ok(gaxi::grpc::tonic::Response::new(
1010                spanner_grpc_mock::google::spanner::v1::Transaction {
1011                    id: vec![42],
1012                    ..Default::default()
1013                },
1014            ))
1015        });
1016
1017        let commit_call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
1018        mock.expect_commit().times(2).returning(move |req| {
1019            let count = commit_call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1020            let req = req.into_inner();
1021            assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
1022
1023            if count == 0 {
1024                assert!(!req.mutations.is_empty());
1025                Ok(gaxi::grpc::tonic::Response::new(
1026                    spanner_grpc_mock::google::spanner::v1::CommitResponse {
1027                        multiplexed_session_retry: Some(
1028                            spanner_grpc_mock::google::spanner::v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
1029                                spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken {
1030                                    precommit_token: vec![4, 5, 6],
1031                                    seq_num: 2,
1032                                }
1033                            )
1034                        ),
1035                        ..Default::default()
1036                    },
1037                ))
1038            } else {
1039                assert!(req.mutations.is_empty());
1040                assert_eq!(
1041                    req.precommit_token.expect("precommit_token required").precommit_token,
1042                    vec![4, 5, 6]
1043                );
1044                Ok(gaxi::grpc::tonic::Response::new(
1045                    spanner_grpc_mock::google::spanner::v1::CommitResponse {
1046                        commit_timestamp: Some(prost_types::Timestamp {
1047                            seconds: 9999,
1048                            nanos: 0,
1049                        }),
1050                        ..Default::default()
1051                    },
1052                ))
1053            }
1054        });
1055
1056        let (db_client, _server) = setup_db_client(mock).await;
1057
1058        let mutation = Mutation::new_insert_or_update_builder("Users")
1059            .set("UserId")
1060            .to(&1)
1061            .build();
1062
1063        let res = db_client
1064            .write_only_transaction()
1065            .build()
1066            .write(vec![mutation])
1067            .await;
1068
1069        assert!(res.is_ok());
1070        let res = res.expect("write should succeed");
1071        assert!(res.commit_timestamp.is_some());
1072        assert_eq!(
1073            res.commit_timestamp
1074                .expect("commit_timestamp should be present")
1075                .seconds(),
1076            9999
1077        );
1078    }
1079
1080    #[tokio_test_no_panics]
1081    async fn write_with_commit_retry_preserves_options() -> anyhow::Result<()> {
1082        let mut mock = spanner_grpc_mock::MockSpanner::new();
1083        mock.expect_create_session().returning(|_| {
1084            Ok(gaxi::grpc::tonic::Response::new(
1085                spanner_grpc_mock::google::spanner::v1::Session {
1086                    name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1087                    ..Default::default()
1088                },
1089            ))
1090        });
1091
1092        mock.expect_begin_transaction().once().returning(|req| {
1093            let req = req.into_inner();
1094            assert!(req.mutation_key.is_some());
1095
1096            Ok(gaxi::grpc::tonic::Response::new(
1097                spanner_grpc_mock::google::spanner::v1::Transaction {
1098                    id: vec![42],
1099                    ..Default::default()
1100                },
1101            ))
1102        });
1103
1104        let expected_delay = prost_types::Duration {
1105            seconds: 0,
1106            nanos: 200_000_000,
1107        };
1108
1109        let expected_delay_clone = expected_delay;
1110        let commit_call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
1111        mock.expect_commit().times(2).returning(move |req| {
1112            let count = commit_call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1113            let req = req.into_inner();
1114            assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123");
1115
1116            // Verify options are present in both attempts
1117            assert!(req.return_commit_stats, "Expected return_commit_stats to be true");
1118            assert_eq!(req.max_commit_delay.as_ref(), Some(&expected_delay_clone), "Expected max_commit_delay to be 200ms");
1119
1120            if count == 0 {
1121                assert!(!req.mutations.is_empty());
1122                Ok(gaxi::grpc::tonic::Response::new(
1123                    spanner_grpc_mock::google::spanner::v1::CommitResponse {
1124                        multiplexed_session_retry: Some(
1125                            spanner_grpc_mock::google::spanner::v1::commit_response::MultiplexedSessionRetry::PrecommitToken(
1126                                spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken {
1127                                    precommit_token: vec![4, 5, 6],
1128                                    seq_num: 2,
1129                                }
1130                            )
1131                        ),
1132                        ..Default::default()
1133                    },
1134                ))
1135            } else {
1136                assert!(req.mutations.is_empty(), "Expected mutations to be empty on retry");
1137                assert_eq!(
1138                    req.precommit_token.expect("precommit_token required").precommit_token,
1139                    vec![4, 5, 6]
1140                );
1141                Ok(gaxi::grpc::tonic::Response::new(
1142                    spanner_grpc_mock::google::spanner::v1::CommitResponse {
1143                        commit_timestamp: Some(prost_types::Timestamp {
1144                            seconds: 9999,
1145                            nanos: 0,
1146                        }),
1147                        commit_stats: Some(CommitStats { mutation_count: 12 }),
1148                        ..Default::default()
1149                    },
1150                ))
1151            }
1152        });
1153
1154        let (db_client, _server) = setup_db_client(mock).await;
1155
1156        let mutation = Mutation::new_insert_or_update_builder("Users")
1157            .set("UserId")
1158            .to(&1)
1159            .build();
1160
1161        let res = db_client
1162            .write_only_transaction()
1163            .set_return_commit_stats(true)
1164            .set_max_commit_delay(Duration::new(0, 200_000_000).expect("valid duration"))
1165            .build()
1166            .write(vec![mutation])
1167            .await?;
1168
1169        let stats = res.commit_stats.expect("Expected commit stats in response");
1170        assert_eq!(stats.mutation_count, 12);
1171        assert_eq!(
1172            res.commit_timestamp
1173                .expect("timestamp should be present")
1174                .seconds(),
1175            9999
1176        );
1177
1178        Ok(())
1179    }
1180
1181    #[tokio_test_no_panics]
1182    async fn write_with_commit_aborted_retry() -> anyhow::Result<()> {
1183        let mut mock = spanner_grpc_mock::MockSpanner::new();
1184        mock.expect_create_session().returning(|_| {
1185            Ok(Response::new(Session {
1186                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1187                ..Default::default()
1188            }))
1189        });
1190
1191        let mut seq = mockall::Sequence::new();
1192
1193        mock.expect_begin_transaction()
1194            .once()
1195            .in_sequence(&mut seq)
1196            .returning(move |req| {
1197                let req = req.into_inner();
1198                assert!(req.mutation_key.is_some());
1199
1200                Ok(Response::new(Transaction {
1201                    id: vec![42],
1202                    ..Default::default()
1203                }))
1204            });
1205
1206        mock.expect_commit()
1207            .once()
1208            .in_sequence(&mut seq)
1209            .returning(move |_req| Err(create_aborted_status(std::time::Duration::from_nanos(1))));
1210
1211        mock.expect_begin_transaction()
1212            .once()
1213            .in_sequence(&mut seq)
1214            .returning(move |req| {
1215                let req = req.into_inner();
1216                assert!(req.mutation_key.is_some());
1217
1218                let options = req.options.as_ref().expect("options required on retry");
1219                let read_write = options.mode.as_ref().expect("mode required on retry");
1220                match read_write {
1221                    Mode::ReadWrite(rw) => {
1222                        assert_eq!(rw.multiplexed_session_previous_transaction_id, vec![42], "previous_transaction_id should be set to the ID of the aborted transaction");
1223                    }
1224                    _ => panic!("Expected ReadWrite mode"),
1225                }
1226
1227                Ok(Response::new(Transaction {
1228                    id: vec![42],
1229                    ..Default::default()
1230                }))
1231            });
1232
1233        mock.expect_commit()
1234            .once()
1235            .in_sequence(&mut seq)
1236            .returning(move |_req| {
1237                Ok(Response::new(CommitResponse {
1238                    commit_timestamp: Some(Timestamp {
1239                        seconds: 8888,
1240                        nanos: 0,
1241                    }),
1242                    ..Default::default()
1243                }))
1244            });
1245
1246        let (db_client, _server) = setup_db_client(mock).await;
1247
1248        let mutation = Mutation::new_insert_or_update_builder("Users")
1249            .set("UserId")
1250            .to(&1)
1251            .build();
1252
1253        let res = db_client
1254            .write_only_transaction()
1255            .build()
1256            .write(vec![mutation])
1257            .await;
1258
1259        let res = res.expect("write should succeed");
1260        assert_eq!(
1261            res.commit_timestamp
1262                .expect("commit_timestamp should be present")
1263                .seconds(),
1264            8888,
1265            "expected commit timestamp to match"
1266        );
1267        Ok(())
1268    }
1269
1270    #[tokio_test_no_panics]
1271    async fn write_at_least_once_with_max_commit_delay() {
1272        let mut mock = spanner_grpc_mock::MockSpanner::new();
1273        mock.expect_create_session().returning(|_| {
1274            Ok(Response::new(Session {
1275                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1276                ..Default::default()
1277            }))
1278        });
1279
1280        mock.expect_commit().once().returning(|req| {
1281            let req = req.into_inner();
1282            assert_eq!(
1283                req.session,
1284                "projects/p/instances/i/databases/d/sessions/123"
1285            );
1286            assert_eq!(
1287                req.max_commit_delay,
1288                Some(ProstDuration {
1289                    seconds: 0,
1290                    nanos: 100_000_000, // 100ms
1291                })
1292            );
1293
1294            Ok(Response::new(CommitResponse {
1295                commit_timestamp: Some(Timestamp {
1296                    seconds: 1234,
1297                    nanos: 0,
1298                }),
1299                ..Default::default()
1300            }))
1301        });
1302
1303        let (db_client, _server) = setup_db_client(mock).await;
1304
1305        let mutation = Mutation::new_insert_or_update_builder("Users")
1306            .set("UserId")
1307            .to(&1)
1308            .build();
1309
1310        let res = db_client
1311            .write_only_transaction()
1312            .set_max_commit_delay(Duration::try_from("0.1s").unwrap())
1313            .build()
1314            .write_at_least_once(vec![mutation])
1315            .await;
1316
1317        assert!(res.is_ok());
1318    }
1319
1320    #[tokio_test_no_panics]
1321    async fn leader_aware_routing_enabled_by_default() {
1322        let mut mock = spanner_grpc_mock::MockSpanner::new();
1323        mock.expect_create_session().returning(|_| {
1324            Ok(Response::new(Session {
1325                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1326                ..Default::default()
1327            }))
1328        });
1329
1330        mock.expect_commit().once().returning(|req| {
1331            assert_eq!(
1332                req.metadata()
1333                    .get("x-goog-spanner-route-to-leader")
1334                    .expect("header required")
1335                    .to_str()
1336                    .unwrap(),
1337                "true"
1338            );
1339            Ok(Response::new(CommitResponse {
1340                commit_timestamp: Some(Timestamp {
1341                    seconds: 1234,
1342                    nanos: 0,
1343                }),
1344                ..Default::default()
1345            }))
1346        });
1347
1348        let (db_client, _server) = setup_db_client(mock).await;
1349        let mutation = Mutation::new_insert_or_update_builder("Users")
1350            .set("UserId")
1351            .to(&1)
1352            .build();
1353        let res = db_client
1354            .write_only_transaction()
1355            .build()
1356            .write_at_least_once(vec![mutation])
1357            .await;
1358        assert!(res.is_ok());
1359    }
1360
1361    #[tokio_test_no_panics]
1362    async fn write_only_transaction_builder_sets_gax_options() -> anyhow::Result<()> {
1363        let mut mock = spanner_grpc_mock::MockSpanner::new();
1364        mock.expect_create_session().returning(|_| {
1365            Ok(Response::new(Session {
1366                name: "session".to_string(),
1367                ..Default::default()
1368            }))
1369        });
1370        let (db_client, _server) = setup_db_client(mock).await;
1371
1372        let builder = db_client
1373            .write_only_transaction()
1374            .with_begin_attempt_timeout(StdDuration::from_secs(5))
1375            .with_begin_retry_policy(NeverRetry)
1376            .with_begin_backoff_policy(ExponentialBackoff::default())
1377            .with_commit_attempt_timeout(StdDuration::from_secs(10))
1378            .with_commit_retry_policy(NeverRetry)
1379            .with_commit_backoff_policy(ExponentialBackoff::default());
1380
1381        let begin_gax = &builder.begin_gax_options;
1382        assert_eq!(
1383            *begin_gax.attempt_timeout(),
1384            Some(StdDuration::from_secs(5))
1385        );
1386        assert!(begin_gax.retry_policy().is_some());
1387        assert!(begin_gax.backoff_policy().is_some());
1388
1389        let commit_gax = &builder.commit_gax_options;
1390        assert_eq!(
1391            *commit_gax.attempt_timeout(),
1392            Some(StdDuration::from_secs(10))
1393        );
1394        assert!(commit_gax.retry_policy().is_some());
1395        assert!(commit_gax.backoff_policy().is_some());
1396
1397        Ok(())
1398    }
1399
1400    fn parse_grpc_timeout(metadata: &gaxi::grpc::tonic::MetadataMap) -> Option<StdDuration> {
1401        let timeout_header = metadata.get("grpc-timeout")?.to_str().ok()?;
1402        let numeric_part: String = timeout_header
1403            .chars()
1404            .take_while(|c| c.is_ascii_digit())
1405            .collect();
1406        let value = numeric_part.parse::<u64>().ok()?;
1407        let unit = timeout_header.trim_start_matches(&numeric_part);
1408        let duration = match unit {
1409            "u" => StdDuration::from_micros(value),
1410            "m" => StdDuration::from_millis(value),
1411            "S" => StdDuration::from_secs(value),
1412            "M" => StdDuration::from_secs(value * 60),
1413            "H" => StdDuration::from_secs(value * 3600),
1414            _ => return None,
1415        };
1416        Some(duration)
1417    }
1418
1419    #[tokio_test_no_panics]
1420    async fn write_only_transaction_with_custom_options() -> anyhow::Result<()> {
1421        let mut mock = spanner_grpc_mock::MockSpanner::new();
1422        mock.expect_create_session().returning(|_| {
1423            Ok(Response::new(Session {
1424                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1425                ..Default::default()
1426            }))
1427        });
1428
1429        let mut seq = mockall::Sequence::new();
1430
1431        mock.expect_begin_transaction()
1432            .once()
1433            .in_sequence(&mut seq)
1434            .withf(|req| {
1435                let duration =
1436                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
1437                assert_eq!(duration, StdDuration::from_secs(5));
1438                true
1439            })
1440            .returning(|_| {
1441                Ok(Response::new(Transaction {
1442                    id: vec![42],
1443                    ..Default::default()
1444                }))
1445            });
1446
1447        mock.expect_commit()
1448            .once()
1449            .in_sequence(&mut seq)
1450            .withf(|req| {
1451                let duration =
1452                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
1453                assert_eq!(duration, StdDuration::from_secs(10));
1454                true
1455            })
1456            .returning(|_| {
1457                Ok(Response::new(CommitResponse {
1458                    commit_timestamp: Some(Timestamp {
1459                        seconds: 8888,
1460                        nanos: 0,
1461                    }),
1462                    ..Default::default()
1463                }))
1464            });
1465
1466        let (db_client, _server) = setup_db_client(mock).await;
1467
1468        let mutation = Mutation::new_insert_or_update_builder("Users")
1469            .set("UserId")
1470            .to(&1)
1471            .build();
1472
1473        let res = db_client
1474            .write_only_transaction()
1475            .with_begin_attempt_timeout(StdDuration::from_secs(5))
1476            .with_commit_attempt_timeout(StdDuration::from_secs(10))
1477            .build()
1478            .write(vec![mutation])
1479            .await?;
1480
1481        assert_eq!(
1482            res.commit_timestamp
1483                .expect("commit_timestamp should be present")
1484                .seconds(),
1485            8888
1486        );
1487        Ok(())
1488    }
1489
1490    #[tokio_test_no_panics]
1491    async fn write_at_least_once_with_custom_commit_options() -> anyhow::Result<()> {
1492        let mut mock = spanner_grpc_mock::MockSpanner::new();
1493        mock.expect_create_session().returning(|_| {
1494            Ok(Response::new(Session {
1495                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1496                ..Default::default()
1497            }))
1498        });
1499
1500        mock.expect_begin_transaction().never();
1501
1502        mock.expect_commit()
1503            .once()
1504            .withf(|req| {
1505                let duration =
1506                    parse_grpc_timeout(req.metadata()).expect("valid grpc-timeout header");
1507                assert_eq!(duration, StdDuration::from_secs(7));
1508                true
1509            })
1510            .returning(|_| {
1511                Ok(Response::new(CommitResponse {
1512                    commit_timestamp: Some(Timestamp {
1513                        seconds: 7777,
1514                        nanos: 0,
1515                    }),
1516                    ..Default::default()
1517                }))
1518            });
1519
1520        let (db_client, _server) = setup_db_client(mock).await;
1521
1522        let mutation = Mutation::new_insert_or_update_builder("Users")
1523            .set("UserId")
1524            .to(&1)
1525            .build();
1526
1527        let res = db_client
1528            .write_only_transaction()
1529            .with_commit_attempt_timeout(StdDuration::from_secs(7))
1530            .build()
1531            .write_at_least_once(vec![mutation])
1532            .await?;
1533
1534        assert_eq!(
1535            res.commit_timestamp
1536                .expect("commit_timestamp should be present")
1537                .seconds(),
1538            7777
1539        );
1540        Ok(())
1541    }
1542}