Skip to main content

google_cloud_spanner/
partitioned_dml_transaction.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::client::amend_request_options_for_lar;
16use crate::database_client::DatabaseClient;
17use crate::google::spanner::v1::result_set_stats::RowCount::RowCountLowerBound;
18use crate::model::transaction_options::PartitionedDml;
19use crate::model::{
20    BeginTransactionRequest, TransactionOptions, TransactionSelector, transaction_selector,
21};
22use crate::server_streaming::stream::PartialResultSetStream;
23use crate::statement::Statement;
24use crate::transaction_retry_policy::{
25    BasicTransactionRetryPolicy, TransactionRetryPolicy, retry_aborted,
26};
27use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
28
29/// A builder for [PartitionedDmlTransaction].
30///
31/// # Example
32/// ```
33/// # use google_cloud_spanner::client::Spanner;
34/// # use google_cloud_spanner::statement::Statement;
35/// # async fn build_transaction(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
36///     let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
37///     let transaction = db_client.partitioned_dml_transaction().build().await?;
38///     let statement = Statement::builder("UPDATE users SET active = true WHERE TRUE").build();
39///     let modified_rows = transaction.execute_update(statement).await?;
40/// #   Ok(())
41/// # }
42/// ```
43pub struct PartitionedDmlTransactionBuilder {
44    client: DatabaseClient,
45    retry_policy: Box<dyn TransactionRetryPolicy>,
46    exclude_txn_from_change_streams: bool,
47}
48
49impl PartitionedDmlTransactionBuilder {
50    pub(crate) fn new(client: DatabaseClient) -> Self {
51        Self {
52            client,
53            retry_policy: Box::new(BasicTransactionRetryPolicy::default()),
54            exclude_txn_from_change_streams: false,
55        }
56    }
57
58    /// Sets whether to exclude the transaction from change streams.
59    ///
60    /// # Example
61    /// ```
62    /// # use google_cloud_spanner::client::Spanner;
63    /// # async fn build_transaction(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
64    ///     let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
65    ///     let transaction = db_client
66    ///         .partitioned_dml_transaction()
67    ///         .with_exclude_txn_from_change_streams(true)
68    ///         .build()
69    ///         .await?;
70    /// #   Ok(())
71    /// # }
72    /// ```
73    ///
74    /// When set to `true`, it prevents modifications from this transaction from being tracked in change streams.
75    /// Note that this only affects change streams that have been created with the DDL option `allow_txn_exclusion = true`.
76    /// If `allow_txn_exclusion` is not set or set to `false` for a change stream, updates made within this transaction
77    /// are recorded in that change stream regardless of this setting.
78    ///
79    /// When set to `false` or not specified, modifications from this transaction are recorded in all change streams
80    /// tracking columns modified by this transaction.
81    pub fn with_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
82        self.exclude_txn_from_change_streams = exclude;
83        self
84    }
85
86    /// Sets the retry policy for the transaction.
87    ///
88    /// # Example
89    /// ```
90    /// # use std::time::Duration;
91    /// # use google_cloud_spanner::client::Spanner;
92    /// # use google_cloud_spanner::transaction::BasicTransactionRetryPolicy;
93    /// # async fn build_transaction(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
94    ///     let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
95    ///     
96    ///     let retry_policy = BasicTransactionRetryPolicy::new()
97    ///         .with_max_attempts(5)
98    ///         .with_total_timeout(Duration::from_secs(60));
99    ///
100    ///     let transaction = db_client
101    ///         .partitioned_dml_transaction()
102    ///         .with_retry_policy(retry_policy)
103    ///         .build()
104    ///         .await?;
105    /// #   Ok(())
106    /// # }
107    /// ```
108    ///
109    /// The client will retry the entire transaction if it is aborted by Spanner.
110    /// This policy can be used to customize whether a transaction should be retried
111    /// or not. The default is to retry indefinitely until the transaction succeeds.
112    pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(mut self, policy: P) -> Self {
113        self.retry_policy = Box::new(policy);
114        self
115    }
116
117    /// Builds the [PartitionedDmlTransaction].
118    pub async fn build(self) -> crate::Result<PartitionedDmlTransaction> {
119        Ok(PartitionedDmlTransaction {
120            client: self.client,
121            retry_policy: self.retry_policy,
122            exclude_txn_from_change_streams: self.exclude_txn_from_change_streams,
123        })
124    }
125}
126
127/// A Partitioned DML transaction.
128///
129/// Partitioned DML transactions are used to execute a single DML statement that may modify a large
130/// number of rows. The execution of the statement will automatically be partitioned into smaller
131/// transactions by Spanner, which may execute in parallel.
132///
133/// A Partitioned DML transaction cannot be committed or rolled back.
134///
135/// See also: <https://docs.cloud.google.com/spanner/docs/dml-partitioned>
136pub struct PartitionedDmlTransaction {
137    client: DatabaseClient,
138    retry_policy: Box<dyn TransactionRetryPolicy>,
139    exclude_txn_from_change_streams: bool,
140}
141
142impl PartitionedDmlTransaction {
143    /// Executes a Partitioned DML statement.
144    ///
145    /// # Example
146    /// ```
147    /// # use google_cloud_spanner::client::Spanner;
148    /// # use google_cloud_spanner::statement::Statement;
149    /// # async fn run(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
150    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
151    /// let transaction = db_client.partitioned_dml_transaction().build().await?;
152    /// let statement = Statement::builder("UPDATE users SET active = true WHERE TRUE").build();
153    /// let modified_rows = transaction.execute_update(statement).await?;
154    /// # Ok(())
155    /// # }
156    /// ```
157    ///
158    /// # Return
159    ///
160    /// The number of rows that was at least modified by the statement. Note that the actual number
161    /// of rows that was modified may be higher than this number if the statement was retried or
162    /// split into multiple transactions by Spanner, and some of these (sub)transactions were
163    /// executed multiple times.
164    ///
165    /// See also: <https://docs.cloud.google.com/spanner/docs/dml-partitioned>
166    pub async fn execute_update<T: Into<Statement>>(self, statement: T) -> crate::Result<i64> {
167        let statement = statement.into();
168        let mut gax_options = statement.gax_options().clone();
169        self.amend_gax_options(&mut gax_options);
170
171        let session_name = self.client.session_name();
172        let transaction_options = TransactionOptions::default()
173            .set_partitioned_dml(PartitionedDml::default())
174            .set_exclude_txn_from_change_streams(self.exclude_txn_from_change_streams);
175        let begin_request = BeginTransactionRequest {
176            session: session_name.clone(),
177            options: Some(transaction_options),
178            ..Default::default()
179        };
180        let base_request = statement.into_request();
181        let channel_hint = self.client.spanner.next_channel_hint();
182        let client = self.client;
183
184        // Execute the statement and retry if the transaction is aborted by Spanner.
185        retry_aborted(&*self.retry_policy, || {
186            let begin_request = begin_request.clone();
187            let base_request = base_request.clone();
188            let session_name = session_name.clone();
189            let gax_options = gax_options.clone();
190            let client = client.clone();
191
192            async move {
193                let transaction = client
194                    .spanner
195                    .begin_transaction(begin_request, gax_options.clone(), channel_hint)
196                    .await?;
197
198                let execute_request =
199                    base_request
200                        .set_session(session_name)
201                        .set_transaction(TransactionSelector {
202                            selector: Some(transaction_selector::Selector::Id(
203                                transaction.id.clone(),
204                            )),
205                            ..Default::default()
206                        });
207
208                let stream_builder = client.spanner.execute_streaming_sql(
209                    execute_request,
210                    gax_options,
211                    channel_hint,
212                );
213                let stream = stream_builder.send().await?;
214
215                extract_lower_bound_update_count_from_stream(stream).await
216            }
217        })
218        .await
219    }
220
221    fn amend_gax_options(&self, options: &mut GaxRequestOptions) {
222        *options = amend_request_options_for_lar(
223            self.client.leader_aware_routing_enabled,
224            options.clone(),
225        );
226    }
227}
228
229/// Reads through the stream of `PartialResultSet` messages returned by the execution
230/// of a Partitioned DML statement and extracts the `row_count_lower_bound` from the
231/// query statistics. If the execution is successful but no lower bound is found,
232/// an internal error is returned.
233async fn extract_lower_bound_update_count_from_stream(
234    mut stream: PartialResultSetStream,
235) -> crate::Result<i64> {
236    let mut lower_bound: Option<i64> = None;
237    while let Some(prs) = stream.next_message().await.transpose()? {
238        if let Some(RowCountLowerBound(val)) = prs.stats.and_then(|s| s.row_count) {
239            lower_bound = Some(val);
240        }
241    }
242    lower_bound.ok_or_else(|| {
243        crate::Error::deser(crate::error::SpannerInternalError::new(
244            "ExecuteStreamingSql completed successfully but no row_count_lower_bound was returned",
245        ))
246    })
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
    use crate::result_set::tests::adapt;
    use crate::transaction_retry_policy::tests::create_aborted_status;
    use gaxi::grpc::tonic;
    use google_cloud_test_macros::tokio_test_no_panics;
    use spanner_grpc_mock::google::spanner::v1;

    /// Session name the client is expected to use with `create_session_mock`.
    const SESSION_NAME: &str = "projects/p/instances/i/databases/d/sessions/123";

    /// The transaction handed out by the mocked `BeginTransaction` calls.
    fn pdml_transaction() -> v1::Transaction {
        v1::Transaction {
            id: vec![0, 1, 2],
            ..Default::default()
        }
    }

    /// A `PartialResultSet` whose statistics report `count` as the row count lower bound.
    fn lower_bound_result_set(count: i64) -> v1::PartialResultSet {
        v1::PartialResultSet {
            stats: Some(v1::ResultSetStats {
                row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(count)),
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(PartitionedDmlTransactionBuilder: Send, Sync);
        static_assertions::assert_impl_all!(PartitionedDmlTransaction: Send, Sync);
    }

    #[tokio_test_no_panics]
    async fn execute_update_success() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            let req = req.into_inner();
            assert_eq!(req.session, SESSION_NAME);
            // The transaction must be a Partitioned DML transaction that, by default, is not
            // excluded from change streams.
            let options = req.options.expect("missing transaction options");
            assert!(matches!(
                options.mode,
                Some(v1::transaction_options::Mode::PartitionedDml(_))
            ));
            assert!(!options.exclude_txn_from_change_streams);
            Ok(tonic::Response::new(pdml_transaction()))
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            let req = req.into_inner();
            assert_eq!(req.sql, "UPDATE Users SET active = true");
            assert_eq!(req.session, SESSION_NAME);
            // The statement must run in the transaction returned by BeginTransaction.
            assert_eq!(
                req.transaction.and_then(|selector| selector.selector),
                Some(v1::transaction_selector::Selector::Id(vec![0, 1, 2]))
            );

            let stream = adapt([Ok(lower_bound_result_set(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder("UPDATE Users SET active = true").build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }

    #[tokio_test_no_panics]
    async fn execute_update_with_exclude_txn_from_change_streams() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            let req = req.into_inner();
            let options = req.options.expect("missing transaction options");
            assert!(options.exclude_txn_from_change_streams);
            Ok(tonic::Response::new(pdml_transaction()))
        });

        mock.expect_execute_streaming_sql()
            .once()
            .returning(|_req| {
                let stream = adapt([Ok(lower_bound_result_set(500))]);
                Ok(tonic::Response::from(stream))
            });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .with_exclude_txn_from_change_streams(true)
            .build()
            .await
            .unwrap();
        let statement = Statement::builder("UPDATE Users SET active = true").build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }

    #[tokio_test_no_panics]
    async fn execute_update_with_aborted_retry() {
        let mut mock = create_session_mock();

        // Every attempt begins a new transaction.
        mock.expect_begin_transaction()
            .times(2)
            .returning(|_req| Ok(tonic::Response::new(pdml_transaction())));

        let mut seq = mockall::Sequence::new();
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move |_req| {
                // The first attempt is aborted by the server.
                let stream = adapt([Err(create_aborted_status(std::time::Duration::from_nanos(
                    1,
                )))]);
                Ok(tonic::Response::from(stream))
            });
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move |_req| {
                let stream = adapt([Ok(lower_bound_result_set(100))]);
                Ok(tonic::Response::from(stream))
            });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let res: i64 = transaction
            .execute_update(Statement::builder("UPDATE Users SET active = true").build())
            .await
            .unwrap();
        assert_eq!(res, 100);
    }

    /// Only verifies that the builder accepts a custom retry policy.
    #[tokio_test_no_panics]
    async fn builder_with_retry_settings() {
        let mock = create_session_mock();
        let (db_client, _server) = setup_db_client(mock).await;

        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(10)
            .with_total_timeout(std::time::Duration::from_secs(42));

        let _transaction = db_client
            .partitioned_dml_transaction()
            .with_retry_policy(policy)
            .build()
            .await
            .unwrap();
    }

    #[tokio_test_no_panics]
    async fn execute_update_missing_lower_bound() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction()
            .once()
            .returning(|_req| Ok(tonic::Response::new(pdml_transaction())));

        mock.expect_execute_streaming_sql()
            .once()
            .returning(|_req| {
                let stream = adapt([Ok(v1::PartialResultSet {
                    stats: Some(v1::ResultSetStats {
                        // Provide a RowCountExact instead of RowCountLowerBound
                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(100)),
                        ..Default::default()
                    }),
                    ..Default::default()
                })]);
                Ok(tonic::Response::from(stream))
            });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();

        let statement = Statement::builder("UPDATE Users SET active = true").build();
        let res = transaction.execute_update(statement).await;

        assert!(res.is_err());
        let err = res.unwrap_err();
        assert!(err.is_deserialization());
        assert!(
            err.to_string()
                .contains("no row_count_lower_bound was returned")
        );
    }

    #[tokio_test_no_panics]
    async fn leader_aware_routing_enabled_by_default() {
        let mut mock = create_session_mock();
        mock.expect_begin_transaction().once().returning(|req| {
            assert_eq!(
                req.metadata()
                    .get("x-goog-spanner-route-to-leader")
                    .expect("header required")
                    .to_str()
                    .unwrap(),
                "true"
            );
            Ok(tonic::Response::new(pdml_transaction()))
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            assert_eq!(
                req.metadata()
                    .get("x-goog-spanner-route-to-leader")
                    .expect("header required")
                    .to_str()
                    .unwrap(),
                "true"
            );
            let stream = adapt([Ok(lower_bound_result_set(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder("UPDATE Users SET active = true").build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }
}