google_cloud_spanner/
partitioned_dml_transaction.rs1use crate::client::amend_request_options_for_lar;
16use crate::database_client::DatabaseClient;
17use crate::google::spanner::v1::result_set_stats::RowCount::RowCountLowerBound;
18use crate::model::transaction_options::PartitionedDml;
19use crate::model::{
20 BeginTransactionRequest, TransactionOptions, TransactionSelector, transaction_selector,
21};
22use crate::server_streaming::stream::PartialResultSetStream;
23use crate::statement::Statement;
24use crate::transaction_retry_policy::{
25 BasicTransactionRetryPolicy, TransactionRetryPolicy, retry_aborted,
26};
27use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
28
29pub struct PartitionedDmlTransactionBuilder {
44 client: DatabaseClient,
45 retry_policy: Box<dyn TransactionRetryPolicy>,
46 exclude_txn_from_change_streams: bool,
47}
48
49impl PartitionedDmlTransactionBuilder {
50 pub(crate) fn new(client: DatabaseClient) -> Self {
51 Self {
52 client,
53 retry_policy: Box::new(BasicTransactionRetryPolicy::default()),
54 exclude_txn_from_change_streams: false,
55 }
56 }
57
58 pub fn with_exclude_txn_from_change_streams(mut self, exclude: bool) -> Self {
82 self.exclude_txn_from_change_streams = exclude;
83 self
84 }
85
86 pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(mut self, policy: P) -> Self {
113 self.retry_policy = Box::new(policy);
114 self
115 }
116
117 pub async fn build(self) -> crate::Result<PartitionedDmlTransaction> {
119 Ok(PartitionedDmlTransaction {
120 client: self.client,
121 retry_policy: self.retry_policy,
122 exclude_txn_from_change_streams: self.exclude_txn_from_change_streams,
123 })
124 }
125}
126
127pub struct PartitionedDmlTransaction {
137 client: DatabaseClient,
138 retry_policy: Box<dyn TransactionRetryPolicy>,
139 exclude_txn_from_change_streams: bool,
140}
141
142impl PartitionedDmlTransaction {
143 pub async fn execute_update<T: Into<Statement>>(self, statement: T) -> crate::Result<i64> {
167 let statement = statement.into();
168 let mut gax_options = statement.gax_options().clone();
169 self.amend_gax_options(&mut gax_options);
170
171 let session_name = self.client.session_name();
172 let transaction_options = TransactionOptions::default()
173 .set_partitioned_dml(PartitionedDml::default())
174 .set_exclude_txn_from_change_streams(self.exclude_txn_from_change_streams);
175 let begin_request = BeginTransactionRequest {
176 session: session_name.clone(),
177 options: Some(transaction_options),
178 ..Default::default()
179 };
180 let base_request = statement.into_request();
181 let channel_hint = self.client.spanner.next_channel_hint();
182 let client = self.client;
183
184 retry_aborted(&*self.retry_policy, || {
186 let begin_request = begin_request.clone();
187 let base_request = base_request.clone();
188 let session_name = session_name.clone();
189 let gax_options = gax_options.clone();
190 let client = client.clone();
191
192 async move {
193 let transaction = client
194 .spanner
195 .begin_transaction(begin_request, gax_options.clone(), channel_hint)
196 .await?;
197
198 let execute_request =
199 base_request
200 .set_session(session_name)
201 .set_transaction(TransactionSelector {
202 selector: Some(transaction_selector::Selector::Id(
203 transaction.id.clone(),
204 )),
205 ..Default::default()
206 });
207
208 let stream_builder = client.spanner.execute_streaming_sql(
209 execute_request,
210 gax_options,
211 channel_hint,
212 );
213 let stream = stream_builder.send().await?;
214
215 extract_lower_bound_update_count_from_stream(stream).await
216 }
217 })
218 .await
219 }
220
221 fn amend_gax_options(&self, options: &mut GaxRequestOptions) {
222 *options = amend_request_options_for_lar(
223 self.client.leader_aware_routing_enabled,
224 options.clone(),
225 );
226 }
227}
228
229async fn extract_lower_bound_update_count_from_stream(
234 mut stream: PartialResultSetStream,
235) -> crate::Result<i64> {
236 let mut lower_bound: Option<i64> = None;
237 while let Some(prs) = stream.next_message().await.transpose()? {
238 if let Some(RowCountLowerBound(val)) = prs.stats.and_then(|s| s.row_count) {
239 lower_bound = Some(val);
240 }
241 }
242 lower_bound.ok_or_else(|| {
243 crate::Error::deser(crate::error::SpannerInternalError::new(
244 "ExecuteStreamingSql completed successfully but no row_count_lower_bound was returned",
245 ))
246 })
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
    use crate::result_set::tests::adapt;
    use crate::transaction_retry_policy::tests::create_aborted_status;
    use gaxi::grpc::tonic;
    use google_cloud_test_macros::tokio_test_no_panics;
    use spanner_grpc_mock::google::spanner::v1;

    /// Session name the client uses with `create_session_mock`.
    const SESSION_NAME: &str = "projects/p/instances/i/databases/d/sessions/123";
    /// The DML statement executed by the tests.
    const UPDATE_SQL: &str = "UPDATE Users SET active = true";
    /// Transaction id returned by the mocked `BeginTransaction` RPC.
    const TRANSACTION_ID: [u8; 3] = [0, 1, 2];

    /// Response for a successful `BeginTransaction` call.
    fn begin_transaction_response() -> v1::Transaction {
        v1::Transaction {
            id: TRANSACTION_ID.to_vec(),
            ..Default::default()
        }
    }

    /// A `PartialResultSet` whose only content is a `row_count_lower_bound` stat.
    fn lower_bound_result_set(count: i64) -> v1::PartialResultSet {
        v1::PartialResultSet {
            stats: Some(v1::ResultSetStats {
                row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(count)),
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(PartitionedDmlTransactionBuilder: Send, Sync);
        static_assertions::assert_impl_all!(PartitionedDmlTransaction: Send, Sync);
    }

    #[tokio_test_no_panics]
    async fn execute_update_success() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            assert_eq!(req.into_inner().session, SESSION_NAME);
            Ok(tonic::Response::new(begin_transaction_response()))
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            assert_eq!(req.into_inner().sql, UPDATE_SQL);
            let stream = adapt([Ok(lower_bound_result_set(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }

    #[tokio_test_no_panics]
    async fn execute_update_begins_pdml_transaction_and_uses_its_id() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            let options = req
                .into_inner()
                .options
                .expect("missing transaction options");
            assert!(matches!(
                options.mode,
                Some(v1::transaction_options::Mode::PartitionedDml(_))
            ));
            // Change streams are only excluded when explicitly requested.
            assert!(!options.exclude_txn_from_change_streams);
            Ok(tonic::Response::new(begin_transaction_response()))
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            let req = req.into_inner();
            assert_eq!(req.session, SESSION_NAME);
            assert_eq!(
                req.transaction.and_then(|t| t.selector),
                Some(v1::transaction_selector::Selector::Id(
                    TRANSACTION_ID.to_vec()
                ))
            );
            let stream = adapt([Ok(lower_bound_result_set(7))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 7);
    }

    #[tokio_test_no_panics]
    async fn execute_update_reads_lower_bound_from_later_message() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction()
            .once()
            .returning(|_req| Ok(tonic::Response::new(begin_transaction_response())));

        mock.expect_execute_streaming_sql().once().returning(|_req| {
            // The first message carries no stats; the count arrives at the end.
            let stream = adapt([
                Ok(v1::PartialResultSet::default()),
                Ok(lower_bound_result_set(42)),
            ]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 42);
    }

    #[tokio_test_no_panics]
    async fn execute_update_with_exclude_txn_from_change_streams() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction().once().returning(|req| {
            let options = req
                .into_inner()
                .options
                .expect("missing transaction options");
            assert!(options.exclude_txn_from_change_streams);
            Ok(tonic::Response::new(begin_transaction_response()))
        });

        mock.expect_execute_streaming_sql().once().returning(|_req| {
            let stream = adapt([Ok(lower_bound_result_set(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .with_exclude_txn_from_change_streams(true)
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }

    #[tokio_test_no_panics]
    async fn execute_update_with_aborted_retry() {
        let mut mock = create_session_mock();

        // Every attempt begins a fresh transaction, so two attempts mean two
        // BeginTransaction calls.
        mock.expect_begin_transaction()
            .times(2)
            .returning(|_req| Ok(tonic::Response::new(begin_transaction_response())));

        let mut seq = mockall::Sequence::new();
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|_req| {
                let stream = adapt([Err(create_aborted_status(
                    std::time::Duration::from_nanos(1),
                ))]);
                Ok(tonic::Response::from(stream))
            });
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|_req| {
                let stream = adapt([Ok(lower_bound_result_set(100))]);
                Ok(tonic::Response::from(stream))
            });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let res: i64 = transaction
            .execute_update(Statement::builder(UPDATE_SQL).build())
            .await
            .unwrap();
        assert_eq!(res, 100);
    }

    #[tokio_test_no_panics]
    async fn builder_with_retry_settings() {
        let mock = create_session_mock();
        let (db_client, _server) = setup_db_client(mock).await;

        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(10)
            .with_total_timeout(std::time::Duration::from_secs(42));

        let _transaction = db_client
            .partitioned_dml_transaction()
            .with_retry_policy(policy)
            .build()
            .await
            .unwrap();
    }

    #[tokio_test_no_panics]
    async fn execute_update_missing_lower_bound() {
        let mut mock = create_session_mock();

        mock.expect_begin_transaction()
            .once()
            .returning(|_req| Ok(tonic::Response::new(begin_transaction_response())));

        mock.expect_execute_streaming_sql().once().returning(|_req| {
            // Only `row_count_lower_bound` is accepted; an exact count must be
            // treated as if no count had been returned.
            let stream = adapt([Ok(v1::PartialResultSet {
                stats: Some(v1::ResultSetStats {
                    row_count: Some(v1::result_set_stats::RowCount::RowCountExact(100)),
                    ..Default::default()
                }),
                ..Default::default()
            })]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();

        let statement = Statement::builder(UPDATE_SQL).build();
        let err = transaction
            .execute_update(statement)
            .await
            .expect_err("an exact row count must not satisfy execute_update");
        assert!(err.is_deserialization());
        assert!(
            err.to_string()
                .contains("no row_count_lower_bound was returned")
        );
    }

    #[tokio_test_no_panics]
    async fn leader_aware_routing_enabled_by_default() {
        let mut mock = create_session_mock();
        mock.expect_begin_transaction().once().returning(|req| {
            assert_eq!(
                req.metadata()
                    .get("x-goog-spanner-route-to-leader")
                    .expect("header required")
                    .to_str()
                    .unwrap(),
                "true"
            );
            Ok(tonic::Response::new(begin_transaction_response()))
        });

        mock.expect_execute_streaming_sql().once().returning(|req| {
            assert_eq!(
                req.metadata()
                    .get("x-goog-spanner-route-to-leader")
                    .expect("header required")
                    .to_str()
                    .unwrap(),
                "true"
            );
            let stream = adapt([Ok(lower_bound_result_set(500))]);
            Ok(tonic::Response::from(stream))
        });

        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder(UPDATE_SQL).build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }
}