1use crate::error::{BatchUpdateError, internal_error};
16use crate::model::result_set_stats::RowCount;
17use crate::model::{ExecuteBatchDmlResponse, RequestOptions};
18use crate::statement::Statement;
19use google_cloud_gax::backoff_policy::BackoffPolicyArg;
20use google_cloud_gax::error::rpc::Code;
21use google_cloud_gax::error::rpc::Status as RpcStatus;
22use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
23use google_cloud_gax::retry_policy::RetryPolicyArg;
24use std::time::Duration;
25
26#[derive(Clone, Default, Debug)]
28pub struct BatchDmlBuilder {
29 statements: Vec<Statement>,
30 request_options: Option<RequestOptions>,
31 gax_options: GaxRequestOptions,
32}
33
34impl BatchDmlBuilder {
35 pub fn new() -> Self {
37 BatchDmlBuilder::default()
38 }
39
40 pub fn add_statement(mut self, statement: impl Into<Statement>) -> Self {
42 self.statements.push(statement.into());
43 self
44 }
45
46 pub fn set_request_tag(mut self, tag: impl Into<String>) -> Self {
61 self.request_options
62 .get_or_insert_with(RequestOptions::default)
63 .request_tag = tag.into();
64 self
65 }
66
67 pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self {
69 self.gax_options.set_attempt_timeout(timeout);
70 self
71 }
72
73 pub fn with_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
75 self.gax_options.set_retry_policy(policy);
76 self
77 }
78
79 pub fn with_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
81 self.gax_options.set_backoff_policy(policy);
82 self
83 }
84
85 pub fn build(self) -> BatchDml {
87 BatchDml {
88 statements: self.statements,
89 request_options: self.request_options,
90 gax_options: self.gax_options,
91 }
92 }
93}
94
95#[derive(Clone, Debug)]
97pub struct BatchDml {
98 pub(crate) statements: Vec<Statement>,
99 pub(crate) request_options: Option<RequestOptions>,
100 pub(crate) gax_options: GaxRequestOptions,
101}
102
103impl BatchDml {
104 pub fn builder() -> BatchDmlBuilder {
106 BatchDmlBuilder::new()
107 }
108}
109
110impl From<BatchDmlBuilder> for BatchDml {
111 fn from(builder: BatchDmlBuilder) -> Self {
112 builder.build()
113 }
114}
115
116impl<T: Into<Statement>> From<Vec<T>> for BatchDml {
117 fn from(statements: Vec<T>) -> Self {
118 BatchDml {
119 statements: statements.into_iter().map(Into::into).collect(),
120 request_options: None,
121 gax_options: GaxRequestOptions::default(),
122 }
123 }
124}
125
126pub(crate) fn process_response(response: ExecuteBatchDmlResponse) -> crate::Result<Vec<i64>> {
128 let mut update_counts = Vec::with_capacity(response.result_sets.len());
129 for result_set in response.result_sets {
130 if let Some(stats) = result_set.stats {
131 let exact_count = match stats.row_count {
132 Some(RowCount::RowCountExact(c)) => c,
133 _ => {
134 return Err(internal_error(
135 "ExecuteBatchDml returned an invalid or missing row count type",
136 ));
137 }
138 };
139 update_counts.push(exact_count);
140 }
141 }
142
143 if let Some(status) = response.status.filter(|s| s.code != Code::Ok as i32) {
145 let grpc_status = RpcStatus::default()
146 .set_code(status.code)
147 .set_message(status.message);
148
149 if status.code == Code::Aborted as i32 {
152 return Err(crate::Error::service(grpc_status));
153 }
154 return Err(BatchUpdateError::build_error(update_counts, grpc_status));
155 }
156
157 Ok(update_counts)
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{ResultSet, ResultSetStats};
    use google_cloud_rpc::model::Status;
    use static_assertions::assert_impl_all;

    /// Builds a `ResultSet` whose stats carry the given `row_count`.
    fn result_set_with_row_count(row_count: Option<RowCount>) -> ResultSet {
        ResultSet {
            stats: Some(ResultSetStats {
                row_count,
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    /// Builds a `ResultSet` that reports an exact row count of `count`.
    fn exact_result_set(count: i64) -> ResultSet {
        result_set_with_row_count(Some(RowCount::RowCountExact(count)))
    }

    #[test]
    fn auto_traits() {
        assert_impl_all!(BatchDml: Send, Sync, Clone, std::fmt::Debug);
        assert_impl_all!(BatchDmlBuilder: Send, Sync, Clone, std::fmt::Debug);
    }

    #[test]
    fn builder() {
        let stmt1 = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();
        let stmt2 = Statement::builder("UPDATE t SET c = 2 WHERE id = 2").build();

        let batch = BatchDml::builder()
            .add_statement(stmt1)
            .add_statement(stmt2)
            .build();

        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE t SET c = 1 WHERE id = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE t SET c = 2 WHERE id = 2");
        assert!(
            batch.request_options.is_none(),
            "Unexpected request_options set: {:#?}",
            batch.request_options
        );
    }

    #[test]
    fn builder_with_gax_options() {
        use google_cloud_gax::backoff_policy::BackoffPolicy;
        use google_cloud_gax::retry_policy::Aip194Strict;
        use google_cloud_gax::retry_state::RetryState;
        use std::time::Duration;

        #[derive(Debug)]
        struct DummyBackoff;
        impl BackoffPolicy for DummyBackoff {
            fn on_failure(&self, _state: &RetryState) -> Duration {
                Duration::ZERO
            }
        }

        let stmt = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();

        let batch = BatchDml::builder()
            .add_statement(stmt)
            .with_attempt_timeout(Duration::from_secs(5))
            .with_retry_policy(Aip194Strict)
            .with_backoff_policy(DummyBackoff)
            .build();

        assert_eq!(
            *batch.gax_options.attempt_timeout(),
            Some(Duration::from_secs(5))
        );
        assert!(batch.gax_options.retry_policy().is_some());
        assert!(batch.gax_options.backoff_policy().is_some());
    }

    #[test]
    fn builder_with_request_tag() {
        let stmt = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();

        let batch = BatchDml::builder()
            .add_statement(stmt)
            .set_request_tag("tag1")
            .build();

        assert_eq!(batch.statements.len(), 1);
        assert_eq!(
            batch
                .request_options
                .expect("request options missing")
                .request_tag,
            "tag1"
        );
    }

    #[test]
    fn process_response_success() -> anyhow::Result<()> {
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![exact_result_set(5), exact_result_set(10)],
            status: None,
            ..Default::default()
        };

        let counts = process_response(response)?;
        assert_eq!(counts, vec![5, 10]);
        Ok(())
    }

    // An explicit OK status must be treated exactly like a missing status.
    #[test]
    fn process_response_ok_status() -> anyhow::Result<()> {
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![exact_result_set(1), exact_result_set(2)],
            status: Some(Status::default().set_code(Code::Ok as i32)),
            ..Default::default()
        };

        let counts = process_response(response)?;
        assert_eq!(counts, vec![1, 2]);
        Ok(())
    }

    #[test]
    fn process_response_grpc_error() {
        let err_status = Status::default()
            .set_code(Code::InvalidArgument as i32)
            .set_message("Bad query");

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![exact_result_set(3)],
            status: Some(err_status),
            ..Default::default()
        };

        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err).expect("should extract BatchUpdateError");

        assert_eq!(batch_err.update_counts, vec![3]);
        assert_eq!(
            batch_err.status.status().expect("status").code,
            Code::InvalidArgument
        );
        assert_eq!(
            batch_err.status.status().expect("status").message,
            "Bad query"
        );
    }

    #[test]
    fn process_response_aborted() {
        let err_status = Status::default()
            .set_code(Code::Aborted as i32)
            .set_message("transaction aborted");

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![exact_result_set(3)],
            status: Some(err_status),
            ..Default::default()
        };

        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err);
        assert!(
            batch_err.is_none(),
            "Unexpected BatchUpdateError: {batch_err:?}"
        );
        assert_eq!(err.status().expect("status").code, Code::Aborted);
        assert_eq!(err.status().expect("status").message, "transaction aborted");
    }

    #[test]
    fn process_response_missing_stats() {
        let rs = ResultSet {
            stats: None,
            ..Default::default()
        };

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            ..Default::default()
        };

        let result = process_response(response).expect("should return empty update counts");
        assert!(result.is_empty());
    }

    #[test]
    fn process_response_missing_row_count_type() {
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![result_set_with_row_count(None)],
            ..Default::default()
        };

        let result = process_response(response);
        let err = result.expect_err("should fail");
        assert!(
            err.to_string()
                .contains("invalid or missing row count type")
        );
    }

    // DML results must report exact counts; a lower bound is rejected.
    #[test]
    fn process_response_lower_bound_row_count() {
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![result_set_with_row_count(Some(
                RowCount::RowCountLowerBound(4),
            ))],
            ..Default::default()
        };

        let err = process_response(response).expect_err("lower-bound row count should fail");
        assert!(
            err.to_string()
                .contains("invalid or missing row count type"),
            "{err}"
        );
    }

    #[test]
    fn from_vector_of_strings() {
        let statements = vec!["UPDATE table SET col = 1", "UPDATE table SET col = 2"];
        let batch: BatchDml = statements.into();
        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE table SET col = 2");
    }

    #[test]
    fn from_vector_of_statements() {
        let statement1 = Statement::builder("UPDATE table SET col = 1").build();
        let statement2 = Statement::builder("UPDATE table SET col = 2").build();
        let statements = vec![statement1, statement2];
        let batch: BatchDml = statements.into();
        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE table SET col = 2");
    }

    #[test]
    fn from_builder() {
        let builder = BatchDml::builder().add_statement("UPDATE table SET col = 1");
        let batch: BatchDml = builder.into();
        assert_eq!(batch.statements.len(), 1);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
    }

    #[test]
    fn process_response_metadata_no_stats_grpc_error() {
        let rs = ResultSet {
            metadata: Some(crate::model::ResultSetMetadata {
                transaction: Some(crate::model::Transaction {
                    id: vec![7, 7, 7].into(),
                    ..Default::default()
                }),
                ..Default::default()
            }),
            stats: None,
            ..Default::default()
        };

        let err_status = Status::default()
            .set_code(Code::InvalidArgument as i32)
            .set_message("Table not found or syntax invalid");

        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            status: Some(err_status),
            ..Default::default()
        };

        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err)
            .expect("should extract BatchUpdateError cleanly and not return internal error");

        assert_eq!(
            batch_err.update_counts,
            Vec::<i64>::new(),
            "Update counts should be completely empty"
        );
        assert_eq!(
            batch_err.status.status().expect("status").code,
            Code::InvalidArgument
        );
        assert_eq!(
            batch_err.status.status().expect("status").message,
            "Table not found or syntax invalid"
        );
    }
}