Skip to main content

google_cloud_spanner/
batch_dml.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::error::{BatchUpdateError, internal_error};
16use crate::model::result_set_stats::RowCount;
17use crate::model::{ExecuteBatchDmlResponse, RequestOptions};
18use crate::statement::Statement;
19use google_cloud_gax::backoff_policy::BackoffPolicyArg;
20use google_cloud_gax::error::rpc::Code;
21use google_cloud_gax::error::rpc::Status as RpcStatus;
22use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
23use google_cloud_gax::retry_policy::RetryPolicyArg;
24use std::time::Duration;
25
26/// A builder for [BatchDml].
27#[derive(Clone, Default, Debug)]
28pub struct BatchDmlBuilder {
29    statements: Vec<Statement>,
30    request_options: Option<RequestOptions>,
31    gax_options: GaxRequestOptions,
32}
33
34impl BatchDmlBuilder {
35    /// Creates a new empty BatchDmlBuilder.
36    pub fn new() -> Self {
37        BatchDmlBuilder::default()
38    }
39
40    /// Adds a statement to the batch.
41    pub fn add_statement(mut self, statement: impl Into<Statement>) -> Self {
42        self.statements.push(statement.into());
43        self
44    }
45
46    /// Sets the request tag for this batch.
47    ///
48    /// # Example
49    /// ```
50    /// # use google_cloud_spanner::statement::Statement;
51    /// # use google_cloud_spanner::batch::BatchDml;
52    /// let statement1 = Statement::builder("UPDATE users SET active = true WHERE id = 1").build();
53    /// let batch = BatchDml::builder()
54    ///     .add_statement(statement1)
55    ///     .set_request_tag("my-tag")
56    ///     .build();
57    /// ```
58    ///
59    /// See also: [Troubleshooting with tags](https://docs.cloud.google.com/spanner/docs/introspection/troubleshooting-with-tags)
60    pub fn set_request_tag(mut self, tag: impl Into<String>) -> Self {
61        self.request_options
62            .get_or_insert_with(RequestOptions::default)
63            .request_tag = tag.into();
64        self
65    }
66
67    /// Sets the per-attempt timeout for this batch DML request.
68    pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self {
69        self.gax_options.set_attempt_timeout(timeout);
70        self
71    }
72
73    /// Sets the retry policy for this batch DML request.
74    pub fn with_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
75        self.gax_options.set_retry_policy(policy);
76        self
77    }
78
79    /// Sets the backoff policy for this batch DML request.
80    pub fn with_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
81        self.gax_options.set_backoff_policy(policy);
82        self
83    }
84
85    /// Builds and returns the finalized BatchDml object.
86    pub fn build(self) -> BatchDml {
87        BatchDml {
88            statements: self.statements,
89            request_options: self.request_options,
90            gax_options: self.gax_options,
91        }
92    }
93}
94
95/// A batch of DML statements to be executed in a single round-trip to Spanner.
96#[derive(Clone, Debug)]
97pub struct BatchDml {
98    pub(crate) statements: Vec<Statement>,
99    pub(crate) request_options: Option<RequestOptions>,
100    pub(crate) gax_options: GaxRequestOptions,
101}
102
103impl BatchDml {
104    /// Creates a new builder for constructing a [`BatchDml`] request.
105    pub fn builder() -> BatchDmlBuilder {
106        BatchDmlBuilder::new()
107    }
108}
109
110impl From<BatchDmlBuilder> for BatchDml {
111    fn from(builder: BatchDmlBuilder) -> Self {
112        builder.build()
113    }
114}
115
116impl<T: Into<Statement>> From<Vec<T>> for BatchDml {
117    fn from(statements: Vec<T>) -> Self {
118        BatchDml {
119            statements: statements.into_iter().map(Into::into).collect(),
120            request_options: None,
121            gax_options: GaxRequestOptions::default(),
122        }
123    }
124}
125
126/// Processes an ExecuteBatchDmlResponse and returns the success counts, or an error.
127pub(crate) fn process_response(response: ExecuteBatchDmlResponse) -> crate::Result<Vec<i64>> {
128    let mut update_counts = Vec::with_capacity(response.result_sets.len());
129    for result_set in response.result_sets {
130        if let Some(stats) = result_set.stats {
131            let exact_count = match stats.row_count {
132                Some(RowCount::RowCountExact(c)) => c,
133                _ => {
134                    return Err(internal_error(
135                        "ExecuteBatchDml returned an invalid or missing row count type",
136                    ));
137                }
138            };
139            update_counts.push(exact_count);
140        }
141    }
142
143    // If a non-zero status is present, it halted the batch somewhere in the middle of the batch.
144    if let Some(status) = response.status.filter(|s| s.code != Code::Ok as i32) {
145        let grpc_status = RpcStatus::default()
146            .set_code(status.code)
147            .set_message(status.message);
148
149        // If the error code is Aborted, then we propagate a 'normal' service error.
150        // The TransactionRunner will then retry the transaction.
151        if status.code == Code::Aborted as i32 {
152            return Err(crate::Error::service(grpc_status));
153        }
154        return Err(BatchUpdateError::build_error(update_counts, grpc_status));
155    }
156
157    Ok(update_counts)
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{ResultSet, ResultSetStats};
    use google_cloud_rpc::model::Status;
    use static_assertions::assert_impl_all;

    #[test]
    fn auto_traits() {
        assert_impl_all!(BatchDml: Send, Sync, Clone, std::fmt::Debug);
        assert_impl_all!(BatchDmlBuilder: Send, Sync, Clone, std::fmt::Debug);
    }

    #[test]
    fn builder() {
        let stmt1 = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();
        let stmt2 = Statement::builder("UPDATE t SET c = 2 WHERE id = 2").build();

        let batch = BatchDml::builder()
            .add_statement(stmt1)
            .add_statement(stmt2)
            .build();

        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE t SET c = 1 WHERE id = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE t SET c = 2 WHERE id = 2");
        assert!(
            batch.request_options.is_none(),
            "Unexpected request_options set: {:#?}",
            batch.request_options
        );
    }

    #[test]
    fn builder_with_gax_options() {
        use google_cloud_gax::backoff_policy::BackoffPolicy;
        use google_cloud_gax::retry_policy::Aip194Strict;
        use google_cloud_gax::retry_state::RetryState;
        use std::time::Duration;

        #[derive(Debug)]
        struct DummyBackoff;
        impl BackoffPolicy for DummyBackoff {
            fn on_failure(&self, _state: &RetryState) -> Duration {
                Duration::ZERO
            }
        }

        let stmt = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();

        let batch = BatchDml::builder()
            .add_statement(stmt)
            .with_attempt_timeout(Duration::from_secs(5))
            .with_retry_policy(Aip194Strict)
            .with_backoff_policy(DummyBackoff)
            .build();

        assert_eq!(
            *batch.gax_options.attempt_timeout(),
            Some(Duration::from_secs(5))
        );
        assert!(batch.gax_options.retry_policy().is_some());
        assert!(batch.gax_options.backoff_policy().is_some());
    }

    #[test]
    fn builder_with_request_tag() {
        let stmt = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();

        let batch = BatchDml::builder()
            .add_statement(stmt)
            .set_request_tag("tag1")
            .build();

        assert_eq!(batch.statements.len(), 1);
        assert_eq!(
            batch
                .request_options
                .expect("request options missing")
                .request_tag,
            "tag1"
        );
    }

    #[test]
    fn process_response_success() -> anyhow::Result<()> {
        let stats1 = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(5)),
            ..Default::default()
        };
        let stats2 = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(10)),
            ..Default::default()
        };

        let rs1 = ResultSet {
            stats: Some(stats1),
            ..Default::default()
        };
        let rs2 = ResultSet {
            stats: Some(stats2),
            ..Default::default()
        };

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs1, rs2],
            status: None,
            ..Default::default()
        };

        let counts = process_response(response)?;
        assert_eq!(counts, vec![5, 10]);
        Ok(())
    }

    #[test]
    fn process_response_ok_status() -> anyhow::Result<()> {
        // A status that is present but has code OK must not be treated as a failure.
        let stats = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(4)),
            ..Default::default()
        };
        let rs = ResultSet {
            stats: Some(stats),
            ..Default::default()
        };

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            status: Some(Status::default().set_code(Code::Ok as i32)),
            ..Default::default()
        };

        let counts = process_response(response)?;
        assert_eq!(counts, vec![4]);
        Ok(())
    }

    #[test]
    fn process_response_empty() -> anyhow::Result<()> {
        let counts = process_response(ExecuteBatchDmlResponse::default())?;
        assert!(counts.is_empty(), "Unexpected update counts: {counts:?}");
        Ok(())
    }

    #[test]
    fn process_response_grpc_error() {
        let stats = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(3)),
            ..Default::default()
        };
        let rs = ResultSet {
            stats: Some(stats),
            ..Default::default()
        };

        // Note: google_cloud_rpc::model::Status is the status type embedded in
        // generated responses such as ExecuteBatchDmlResponse.
        let err_status = Status::default()
            .set_code(Code::InvalidArgument as i32)
            .set_message("Bad query");

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            status: Some(err_status),
            ..Default::default()
        };

        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err).expect("should extract BatchUpdateError");

        assert_eq!(batch_err.update_counts, vec![3]);
        assert_eq!(
            batch_err.status.status().expect("status").code,
            Code::InvalidArgument
        );
        assert_eq!(
            batch_err.status.status().expect("status").message,
            "Bad query"
        );
    }

    #[test]
    fn process_response_aborted() {
        let stats = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(3)),
            ..Default::default()
        };
        let rs = ResultSet {
            stats: Some(stats),
            ..Default::default()
        };

        let err_status = Status::default()
            .set_code(Code::Aborted as i32)
            .set_message("transaction aborted");

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            status: Some(err_status),
            ..Default::default()
        };

        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err);
        assert!(
            batch_err.is_none(),
            "Unexpected BatchUpdateError: {batch_err:?}"
        );
        assert_eq!(err.status().expect("status").code, Code::Aborted);
        assert_eq!(err.status().expect("status").message, "transaction aborted");
    }

    #[test]
    fn process_response_missing_stats() {
        let rs = ResultSet {
            stats: None,
            ..Default::default()
        };

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            ..Default::default()
        };

        let result = process_response(response).expect("should return empty update counts");
        assert!(result.is_empty());
    }

    #[test]
    fn process_response_missing_row_count_type() {
        let stats = ResultSetStats {
            row_count: None,
            ..Default::default()
        };

        let rs = ResultSet {
            stats: Some(stats),
            ..Default::default()
        };

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            ..Default::default()
        };

        let result = process_response(response);
        let err = result.expect_err("should fail");
        assert!(
            err.to_string()
                .contains("invalid or missing row count type")
        );
    }

    #[test]
    fn process_response_lower_bound_row_count() {
        // process_response accepts only exact row counts; a lower bound is rejected.
        let stats = ResultSetStats {
            row_count: Some(RowCount::RowCountLowerBound(7)),
            ..Default::default()
        };

        let rs = ResultSet {
            stats: Some(stats),
            ..Default::default()
        };

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            ..Default::default()
        };

        let err = process_response(response).expect_err("should fail");
        assert!(
            err.to_string()
                .contains("invalid or missing row count type"),
            "{err:?}"
        );
    }

    #[test]
    fn from_vector_of_strings() {
        let statements = vec!["UPDATE table SET col = 1", "UPDATE table SET col = 2"];
        let batch: BatchDml = statements.into();
        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE table SET col = 2");
    }

    #[test]
    fn from_vector_of_statements() {
        let statement1 = Statement::builder("UPDATE table SET col = 1").build();
        let statement2 = Statement::builder("UPDATE table SET col = 2").build();
        let statements = vec![statement1, statement2];
        let batch: BatchDml = statements.into();
        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE table SET col = 2");
    }

    #[test]
    fn from_empty_vector() {
        let batch: BatchDml = Vec::<Statement>::new().into();
        assert!(batch.statements.is_empty());
        assert!(
            batch.request_options.is_none(),
            "Unexpected request_options set: {:#?}",
            batch.request_options
        );
    }

    #[test]
    fn from_builder() {
        let builder = BatchDml::builder().add_statement("UPDATE table SET col = 1");
        let batch: BatchDml = builder.into();
        assert_eq!(batch.statements.len(), 1);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
    }

    #[test]
    fn process_response_metadata_no_stats_grpc_error() {
        let rs = ResultSet {
            metadata: Some(crate::model::ResultSetMetadata {
                transaction: Some(crate::model::Transaction {
                    id: vec![7, 7, 7].into(),
                    ..Default::default()
                }),
                ..Default::default()
            }),
            stats: None,
            ..Default::default()
        };

        let err_status = Status::default()
            .set_code(Code::InvalidArgument as i32)
            .set_message("Table not found or syntax invalid");

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            status: Some(err_status),
            ..Default::default()
        };

        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err)
            .expect("should extract BatchUpdateError cleanly and not return internal error");

        assert_eq!(
            batch_err.update_counts,
            Vec::<i64>::new(),
            "Update counts should be completely empty"
        );
        assert_eq!(
            batch_err.status.status().expect("status").code,
            Code::InvalidArgument
        );
        assert_eq!(
            batch_err.status.status().expect("status").message,
            "Table not found or syntax invalid"
        );
    }
}