Skip to main content

google_cloud_spanner/
partitioned_dml_transaction.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::client::amend_request_options_for_lar;
16use crate::database_client::DatabaseClient;
17use crate::google::spanner::v1::result_set_stats::RowCount::RowCountLowerBound;
18use crate::model::transaction_options::PartitionedDml;
19use crate::model::{
20    BeginTransactionRequest, TransactionOptions, TransactionSelector, transaction_selector,
21};
22use crate::server_streaming::stream::PartialResultSetStream;
23use crate::statement::Statement;
24use crate::transaction_retry_policy::{
25    BasicTransactionRetryPolicy, TransactionRetryPolicy, retry_aborted,
26};
27use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
28
29/// A builder for [PartitionedDmlTransaction].
30///
31/// # Example
32/// ```
33/// # use google_cloud_spanner::client::Spanner;
34/// # use google_cloud_spanner::statement::Statement;
35/// # async fn build_transaction(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
36///     let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
37///     let transaction = db_client.partitioned_dml_transaction().build().await?;
38///     let statement = Statement::builder("UPDATE users SET active = true WHERE TRUE").build();
39///     let modified_rows = transaction.execute_update(statement).await?;
40/// #   Ok(())
41/// # }
42/// ```
43pub struct PartitionedDmlTransactionBuilder {
44    client: DatabaseClient,
45    retry_policy: Box<dyn TransactionRetryPolicy>,
46    exclude_txn_from_change_streams: bool,
47}
48
49impl PartitionedDmlTransactionBuilder {
50    pub(crate) fn new(client: DatabaseClient) -> Self {
51        Self {
52            client,
53            retry_policy: Box::new(BasicTransactionRetryPolicy::default()),
54            exclude_txn_from_change_streams: false,
55        }
56    }
57
58    /// Sets whether to exclude the transaction from change streams.
59    ///
60    /// # Example
61    /// ```
62    /// # use google_cloud_spanner::client::Spanner;
63    /// # async fn build_transaction(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
64    ///     let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
65    ///     let transaction = db_client
66    ///         .partitioned_dml_transaction()
67    ///         .with_exclude_txn_from_change_streams(true)
68    ///         .build()
69    ///         .await?;
70    /// #   Ok(())
71    /// # }
72    /// ```
73    ///
74    /// When set to `true`, it prevents modifications from this transaction from being tracked in change streams.
75    /// Note that this only affects change streams that have been created with the DDL option `allow_txn_exclusion = true`.
76    /// If `allow_txn_exclusion` is not set or set to `false` for a change stream, updates made within this transaction
77    /// are recorded in that change stream regardless of this setting.
78    ///
79    /// When set to `false` or not specified, modifications from this transaction are recorded in all change streams
80    /// tracking columns modified by this transaction.
81    pub fn with_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
82        self.exclude_txn_from_change_streams = exclude;
83        self
84    }
85
86    /// Sets the retry policy for the transaction.
87    ///
88    /// # Example
89    /// ```
90    /// # use std::time::Duration;
91    /// # use google_cloud_spanner::client::Spanner;
92    /// # use google_cloud_spanner::transaction::BasicTransactionRetryPolicy;
93    /// # async fn build_transaction(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
94    ///     let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
95    ///     
96    ///     let retry_policy = BasicTransactionRetryPolicy::new()
97    ///         .with_max_attempts(5)
98    ///         .with_total_timeout(Duration::from_secs(60));
99    ///
100    ///     let transaction = db_client
101    ///         .partitioned_dml_transaction()
102    ///         .with_retry_policy(retry_policy)
103    ///         .build()
104    ///         .await?;
105    /// #   Ok(())
106    /// # }
107    /// ```
108    ///
109    /// The client will retry the entire transaction if it is aborted by Spanner.
110    /// This policy can be used to customize whether a transaction should be retried
111    /// or not. The default is to retry indefinitely until the transaction succeeds.
112    pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(mut self, policy: P) -> Self {
113        self.retry_policy = Box::new(policy);
114        self
115    }
116
117    /// Builds the [PartitionedDmlTransaction].
118    pub async fn build(self) -> crate::Result<PartitionedDmlTransaction> {
119        Ok(PartitionedDmlTransaction {
120            client: self.client,
121            retry_policy: self.retry_policy,
122            exclude_txn_from_change_streams: self.exclude_txn_from_change_streams,
123        })
124    }
125}
126
127/// A Partitioned DML transaction.
128///
129/// Partitioned DML transactions are used to execute a single DML statement that may modify a large
130/// number of rows. The execution of the statement will automatically be partitioned into smaller
131/// transactions by Spanner, which may execute in parallel.
132///
133/// A Partitioned DML transaction cannot be committed or rolled back.
134///
135/// See also: <https://docs.cloud.google.com/spanner/docs/dml-partitioned>
136pub struct PartitionedDmlTransaction {
137    client: DatabaseClient,
138    retry_policy: Box<dyn TransactionRetryPolicy>,
139    exclude_txn_from_change_streams: bool,
140}
141
142impl PartitionedDmlTransaction {
143    /// Executes a Partitioned DML statement.
144    ///
145    /// # Example
146    /// ```
147    /// # use google_cloud_spanner::client::Spanner;
148    /// # use google_cloud_spanner::statement::Statement;
149    /// # async fn run(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> {
150    /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?;
151    /// let transaction = db_client.partitioned_dml_transaction().build().await?;
152    /// let statement = Statement::builder("UPDATE users SET active = true WHERE TRUE").build();
153    /// let modified_rows = transaction.execute_update(statement).await?;
154    /// # Ok(())
155    /// # }
156    /// ```
157    ///
158    /// # Return
159    ///
160    /// The number of rows that was at least modified by the statement. Note that the actual number
161    /// of rows that was modified may be higher than this number if the statement was retried or
162    /// split into multiple transactions by Spanner, and some of these (sub)transactions were
163    /// executed multiple times.
164    ///
165    /// See also: <https://docs.cloud.google.com/spanner/docs/dml-partitioned>
166    pub async fn execute_update<T: Into<Statement>>(self, statement: T) -> crate::Result<i64> {
167        let statement = statement.into();
168        let mut gax_options = statement.gax_options().clone();
169        self.amend_gax_options(&mut gax_options);
170
171        let session_name = self.client.session_name();
172        let transaction_options = TransactionOptions::default()
173            .set_partitioned_dml(PartitionedDml::default())
174            .set_exclude_txn_from_change_streams(self.exclude_txn_from_change_streams);
175        let begin_request = BeginTransactionRequest {
176            session: session_name.clone(),
177            options: Some(transaction_options),
178            ..Default::default()
179        };
180        let base_request = statement.into_request();
181        let channel_hint = self.client.spanner.next_channel_hint();
182        let client = self.client;
183        let is_emulator = client.is_emulator();
184
185        let action = || {
186            let begin_request = begin_request.clone();
187            let base_request = base_request.clone();
188            let session_name = session_name.clone();
189            let gax_options = gax_options.clone();
190            let client = client.clone();
191
192            async move {
193                let transaction = client
194                    .spanner
195                    .begin_transaction(begin_request, gax_options.clone(), channel_hint)
196                    .await?;
197
198                let execute_request =
199                    base_request
200                        .set_session(session_name)
201                        .set_transaction(TransactionSelector {
202                            selector: Some(transaction_selector::Selector::Id(
203                                transaction.id.clone(),
204                            )),
205                            ..Default::default()
206                        });
207
208                let stream_builder = client.spanner.execute_streaming_sql(
209                    execute_request,
210                    gax_options,
211                    channel_hint,
212                );
213                let stream = stream_builder.send().await?;
214
215                extract_lower_bound_update_count_from_stream(stream).await
216            }
217        };
218
219        retry_aborted(&*self.retry_policy, action, is_emulator).await
220    }
221
222    fn amend_gax_options(&self, options: &mut GaxRequestOptions) {
223        *options = amend_request_options_for_lar(
224            self.client.leader_aware_routing_enabled,
225            options.clone(),
226        );
227    }
228}
229
230/// Reads through the stream of `PartialResultSet` messages returned by the execution
231/// of a Partitioned DML statement and extracts the `row_count_lower_bound` from the
232/// query statistics. If the execution is successful but no lower bound is found,
233/// an internal error is returned.
234async fn extract_lower_bound_update_count_from_stream(
235    mut stream: PartialResultSetStream,
236) -> crate::Result<i64> {
237    let mut lower_bound: Option<i64> = None;
238    while let Some(prs) = stream.next_message().await.transpose()? {
239        if let Some(RowCountLowerBound(val)) = prs.stats.and_then(|s| s.row_count) {
240            lower_bound = Some(val);
241        }
242    }
243    lower_bound.ok_or_else(|| {
244        crate::Error::deser(crate::error::SpannerInternalError::new(
245            "ExecuteStreamingSql completed successfully but no row_count_lower_bound was returned",
246        ))
247    })
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
    use crate::result_set::tests::adapt;
    use crate::transaction_retry_policy::tests::create_aborted_status;
    use gaxi::grpc::tonic;
    use google_cloud_test_macros::tokio_test_no_panics;
    use spanner_grpc_mock::google::spanner::v1;

    /// Session name expected in `BeginTransaction` requests.
    const SESSION_NAME: &str = "projects/p/instances/i/databases/d/sessions/123";
    /// DML statement executed by the tests.
    const UPDATE_SQL: &str = "UPDATE Users SET active = true";
    /// Metadata key carrying the leader-aware routing flag.
    const ROUTE_TO_LEADER_HEADER: &str = "x-goog-spanner-route-to-leader";

    /// The transaction returned by mocked `BeginTransaction` calls.
    fn begun_transaction() -> v1::Transaction {
        v1::Transaction {
            id: vec![0, 1, 2],
            ..Default::default()
        }
    }

    /// A `PartialResultSet` whose statistics report `row_count`.
    fn result_set_with(row_count: v1::result_set_stats::RowCount) -> v1::PartialResultSet {
        v1::PartialResultSet {
            stats: Some(v1::ResultSetStats {
                row_count: Some(row_count),
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    /// A `PartialResultSet` whose statistics report `count` as `row_count_lower_bound`.
    fn result_set_with_lower_bound(count: i64) -> v1::PartialResultSet {
        result_set_with(v1::result_set_stats::RowCount::RowCountLowerBound(count))
    }

    /// Returns the leader-aware routing header of `req`, panicking if it is missing.
    fn route_to_leader<T>(req: &tonic::Request<T>) -> &str {
        req.metadata()
            .get(ROUTE_TO_LEADER_HEADER)
            .expect("header required")
            .to_str()
            .unwrap()
    }

    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(PartitionedDmlTransactionBuilder: Send, Sync);
        static_assertions::assert_impl_all!(PartitionedDmlTransaction: Send, Sync);
    }

    #[tokio_test_no_panics]
    async fn execute_update_success() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            assert_eq!(req.into_inner().session, SESSION_NAME);
            Ok(tonic::Response::new(begun_transaction()))
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            assert_eq!(req.into_inner().sql, UPDATE_SQL);
            let stream = adapt([Ok(result_set_with_lower_bound(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let modified: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(modified, 500);
    }

    #[tokio_test_no_panics]
    async fn execute_update_with_exclude_txn_from_change_streams() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            let options = req
                .into_inner()
                .options
                .expect("missing transaction options");
            assert!(options.exclude_txn_from_change_streams);
            Ok(tonic::Response::new(begun_transaction()))
        });

        mock.expect_execute_streaming_sql().once().returning(|_req| {
            let stream = adapt([Ok(result_set_with_lower_bound(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .with_exclude_txn_from_change_streams(true)
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let modified: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(modified, 500);
    }

    #[tokio_test_no_panics]
    async fn execute_update_with_aborted_retry() {
        let mut mock = create_session_mock();

        // One BeginTransaction per attempt: the aborted one and the successful retry.
        mock.expect_begin_transaction()
            .times(2)
            .returning(|_req| Ok(tonic::Response::new(begun_transaction())));

        let mut seq = mockall::Sequence::new();
        // First attempt: the stream fails with ABORTED.
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|_req| {
                let aborted = create_aborted_status(std::time::Duration::from_nanos(1));
                let stream = adapt([Err(aborted)]);
                Ok(tonic::Response::from(stream))
            });
        // Second attempt: the stream reports a lower bound.
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|_req| {
                let stream = adapt([Ok(result_set_with_lower_bound(100))]);
                Ok(tonic::Response::from(stream))
            });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let modified: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(modified, 100);
    }

    #[tokio_test_no_panics]
    async fn builder_with_retry_settings() {
        let (db_client, _server) = setup_db_client(create_session_mock()).await;

        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(10)
            .with_total_timeout(std::time::Duration::from_secs(42));

        let _transaction = db_client
            .partitioned_dml_transaction()
            .with_retry_policy(policy)
            .build()
            .await
            .unwrap();
    }

    #[tokio_test_no_panics]
    async fn execute_update_missing_lower_bound() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction()
            .once()
            .returning(|_req| Ok(tonic::Response::new(begun_transaction())));

        mock.expect_execute_streaming_sql().once().returning(|_req| {
            // An exact row count is reported instead of a lower bound.
            let exact = v1::result_set_stats::RowCount::RowCountExact(100);
            let stream = adapt([Ok(result_set_with(exact))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();

        let statement = Statement::builder(UPDATE_SQL).build();
        let err = transaction
            .execute_update(statement)
            .await
            .expect_err("a missing lower bound must be reported as an error");
        assert!(err.is_deserialization());
        assert!(
            err.to_string()
                .contains("no row_count_lower_bound was returned")
        );
    }

    #[tokio_test_no_panics]
    async fn leader_aware_routing_enabled_by_default() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            assert_eq!(route_to_leader(&req), "true");
            Ok(tonic::Response::new(begun_transaction()))
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            assert_eq!(route_to_leader(&req), "true");
            let stream = adapt([Ok(result_set_with_lower_bound(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let modified: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(modified, 500);
    }
}