google_cloud_spanner/
partitioned_dml_transaction.rs1use crate::client::amend_request_options_for_lar;
16use crate::database_client::DatabaseClient;
17use crate::google::spanner::v1::result_set_stats::RowCount::RowCountLowerBound;
18use crate::model::transaction_options::PartitionedDml;
19use crate::model::{
20 BeginTransactionRequest, TransactionOptions, TransactionSelector, transaction_selector,
21};
22use crate::server_streaming::stream::PartialResultSetStream;
23use crate::statement::Statement;
24use crate::transaction_retry_policy::{
25 BasicTransactionRetryPolicy, TransactionRetryPolicy, retry_aborted,
26};
27use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
28
29pub struct PartitionedDmlTransactionBuilder {
44 client: DatabaseClient,
45 retry_policy: Box<dyn TransactionRetryPolicy>,
46 exclude_txn_from_change_streams: bool,
47}
48
49impl PartitionedDmlTransactionBuilder {
50 pub(crate) fn new(client: DatabaseClient) -> Self {
51 Self {
52 client,
53 retry_policy: Box::new(BasicTransactionRetryPolicy::default()),
54 exclude_txn_from_change_streams: false,
55 }
56 }
57
58 pub fn with_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
82 self.exclude_txn_from_change_streams = exclude;
83 self
84 }
85
86 pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(mut self, policy: P) -> Self {
113 self.retry_policy = Box::new(policy);
114 self
115 }
116
117 pub async fn build(self) -> crate::Result<PartitionedDmlTransaction> {
119 Ok(PartitionedDmlTransaction {
120 client: self.client,
121 retry_policy: self.retry_policy,
122 exclude_txn_from_change_streams: self.exclude_txn_from_change_streams,
123 })
124 }
125}
126
127pub struct PartitionedDmlTransaction {
137 client: DatabaseClient,
138 retry_policy: Box<dyn TransactionRetryPolicy>,
139 exclude_txn_from_change_streams: bool,
140}
141
142impl PartitionedDmlTransaction {
143 pub async fn execute_update<T: Into<Statement>>(self, statement: T) -> crate::Result<i64> {
167 let statement = statement.into();
168 let mut gax_options = statement.gax_options().clone();
169 self.amend_gax_options(&mut gax_options);
170
171 let session_name = self.client.session_name();
172 let transaction_options = TransactionOptions::default()
173 .set_partitioned_dml(PartitionedDml::default())
174 .set_exclude_txn_from_change_streams(self.exclude_txn_from_change_streams);
175 let begin_request = BeginTransactionRequest {
176 session: session_name.clone(),
177 options: Some(transaction_options),
178 ..Default::default()
179 };
180 let base_request = statement.into_request();
181 let channel_hint = self.client.spanner.next_channel_hint();
182 let client = self.client;
183 let is_emulator = client.is_emulator();
184
185 let action = || {
186 let begin_request = begin_request.clone();
187 let base_request = base_request.clone();
188 let session_name = session_name.clone();
189 let gax_options = gax_options.clone();
190 let client = client.clone();
191
192 async move {
193 let transaction = client
194 .spanner
195 .begin_transaction(begin_request, gax_options.clone(), channel_hint)
196 .await?;
197
198 let execute_request =
199 base_request
200 .set_session(session_name)
201 .set_transaction(TransactionSelector {
202 selector: Some(transaction_selector::Selector::Id(
203 transaction.id.clone(),
204 )),
205 ..Default::default()
206 });
207
208 let stream_builder = client.spanner.execute_streaming_sql(
209 execute_request,
210 gax_options,
211 channel_hint,
212 );
213 let stream = stream_builder.send().await?;
214
215 extract_lower_bound_update_count_from_stream(stream).await
216 }
217 };
218
219 retry_aborted(&*self.retry_policy, action, is_emulator).await
220 }
221
222 fn amend_gax_options(&self, options: &mut GaxRequestOptions) {
223 *options = amend_request_options_for_lar(
224 self.client.leader_aware_routing_enabled,
225 options.clone(),
226 );
227 }
228}
229
230async fn extract_lower_bound_update_count_from_stream(
235 mut stream: PartialResultSetStream,
236) -> crate::Result<i64> {
237 let mut lower_bound: Option<i64> = None;
238 while let Some(prs) = stream.next_message().await.transpose()? {
239 if let Some(RowCountLowerBound(val)) = prs.stats.and_then(|s| s.row_count) {
240 lower_bound = Some(val);
241 }
242 }
243 lower_bound.ok_or_else(|| {
244 crate::Error::deser(crate::error::SpannerInternalError::new(
245 "ExecuteStreamingSql completed successfully but no row_count_lower_bound was returned",
246 ))
247 })
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
    use crate::result_set::tests::adapt;
    use crate::transaction_retry_policy::tests::create_aborted_status;
    use gaxi::grpc::tonic;
    use google_cloud_test_macros::tokio_test_no_panics;
    use spanner_grpc_mock::google::spanner::v1;

    /// SQL text used by every test that executes a statement.
    const UPDATE_SQL: &str = "UPDATE Users SET active = true";

    /// Canned `BeginTransaction` response with a fixed transaction id.
    fn transaction_response() -> tonic::Response<v1::Transaction> {
        tonic::Response::new(v1::Transaction {
            id: vec![0, 1, 2],
            ..Default::default()
        })
    }

    /// A `PartialResultSet` whose stats carry the given row count.
    fn row_count_message(row_count: v1::result_set_stats::RowCount) -> v1::PartialResultSet {
        v1::PartialResultSet {
            stats: Some(v1::ResultSetStats {
                row_count: Some(row_count),
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    /// A `PartialResultSet` whose stats carry `RowCountLowerBound(count)`.
    fn lower_bound_message(count: i64) -> v1::PartialResultSet {
        row_count_message(v1::result_set_stats::RowCount::RowCountLowerBound(count))
    }

    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(PartitionedDmlTransactionBuilder: Send, Sync);
        static_assertions::assert_impl_all!(PartitionedDmlTransaction: Send, Sync);
    }

    #[tokio_test_no_panics]
    async fn execute_update_success() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            let request = req.into_inner();
            assert_eq!(
                request.session,
                "projects/p/instances/i/databases/d/sessions/123"
            );
            Ok(transaction_response())
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            let request = req.into_inner();
            assert_eq!(request.sql, UPDATE_SQL);

            let stream = adapt([Ok(lower_bound_message(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let updated: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(updated, 500);
    }

    #[tokio_test_no_panics]
    async fn execute_update_with_exclude_txn_from_change_streams() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            let request = req.into_inner();
            let options = request.options.expect("missing transaction options");
            assert!(options.exclude_txn_from_change_streams);

            Ok(transaction_response())
        });

        mock.expect_execute_streaming_sql()
            .once()
            .returning(|_req| {
                let stream = adapt([Ok(lower_bound_message(500))]);
                Ok(tonic::Response::from(stream))
            });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .with_exclude_txn_from_change_streams(true)
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let updated: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(updated, 500);
    }

    #[tokio_test_no_panics]
    async fn execute_update_with_aborted_retry() {
        let mut mock = create_session_mock();

        // One BeginTransaction per attempt: the aborted one and the retry.
        mock.expect_begin_transaction()
            .times(2)
            .returning(|_req| Ok(transaction_response()));

        let mut seq = mockall::Sequence::new();
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move |_req| {
                let retry_delay = std::time::Duration::from_nanos(1);
                let stream = adapt([Err(create_aborted_status(retry_delay))]);
                Ok(tonic::Response::from(stream))
            });
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move |_req| {
                let stream = adapt([Ok(lower_bound_message(100))]);
                Ok(tonic::Response::from(stream))
            });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let updated: i64 = transaction
            .execute_update(Statement::builder(UPDATE_SQL).build())
            .await
            .unwrap();
        assert_eq!(updated, 100);
    }

    #[tokio_test_no_panics]
    async fn builder_with_retry_settings() {
        let mock = create_session_mock();
        let (db_client, _server) = setup_db_client(mock).await;

        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(10)
            .with_total_timeout(std::time::Duration::from_secs(42));

        let _transaction = db_client
            .partitioned_dml_transaction()
            .with_retry_policy(policy)
            .build()
            .await
            .unwrap();
    }

    #[tokio_test_no_panics]
    async fn execute_update_missing_lower_bound() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction()
            .once()
            .returning(|_req| Ok(transaction_response()));

        // An exact row count is not a lower bound, so the call must fail.
        mock.expect_execute_streaming_sql()
            .once()
            .returning(|_req| {
                let exact = v1::result_set_stats::RowCount::RowCountExact(100);
                let stream = adapt([Ok(row_count_message(exact))]);
                Ok(tonic::Response::from(stream))
            });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();

        let statement = Statement::builder(UPDATE_SQL).build();
        let res = transaction.execute_update(statement).await;

        assert!(res.is_err());
        let err = res.unwrap_err();
        assert!(err.is_deserialization());
        assert!(
            err.to_string()
                .contains("no row_count_lower_bound was returned")
        );
    }

    #[tokio_test_no_panics]
    async fn leader_aware_routing_enabled_by_default() {
        let mut mock = create_session_mock();
        mock.expect_begin_transaction().once().returning(|req| {
            let header = req
                .metadata()
                .get("x-goog-spanner-route-to-leader")
                .expect("header required")
                .to_str()
                .unwrap();
            assert_eq!(header, "true");
            Ok(transaction_response())
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            let header = req
                .metadata()
                .get("x-goog-spanner-route-to-leader")
                .expect("header required")
                .to_str()
                .unwrap();
            assert_eq!(header, "true");
            let stream = adapt([Ok(lower_bound_message(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let updated: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(updated, 500);
    }
}