use crate::client::amend_request_options_for_lar;
use crate::database_client::DatabaseClient;
use crate::google::spanner::v1::result_set_stats::RowCount::RowCountLowerBound;
use crate::model::transaction_options::PartitionedDml;
use crate::model::{
BeginTransactionRequest, TransactionOptions, TransactionSelector, transaction_selector,
};
use crate::server_streaming::stream::PartialResultSetStream;
use crate::statement::Statement;
use crate::transaction_retry_policy::{
BasicTransactionRetryPolicy, TransactionRetryPolicy, retry_aborted,
};
use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
/// Builder for a [`PartitionedDmlTransaction`].
///
/// Returned by the database client's `partitioned_dml_transaction()` method.
/// Configure it with the `with_*` methods, then call
/// [`build`](Self::build) to obtain the transaction.
#[must_use = "a PartitionedDmlTransactionBuilder does nothing until `build()` is called"]
pub struct PartitionedDmlTransactionBuilder {
    /// Client used to begin the transaction and to execute the statement.
    client: DatabaseClient,
    /// Retry policy handed to the built transaction for aborted attempts.
    retry_policy: Box<dyn TransactionRetryPolicy>,
    /// Whether the begun transaction is excluded from change streams.
    exclude_txn_from_change_streams: bool,
}
impl PartitionedDmlTransactionBuilder {
    /// Creates a builder for `client`.
    ///
    /// The builder starts with [`BasicTransactionRetryPolicy::default`] as
    /// its retry policy and with change-stream exclusion turned off.
    pub(crate) fn new(client: DatabaseClient) -> Self {
        let retry_policy: Box<dyn TransactionRetryPolicy> =
            Box::new(BasicTransactionRetryPolicy::default());
        Self {
            exclude_txn_from_change_streams: false,
            retry_policy,
            client,
        }
    }

    /// Sets whether the Partitioned DML transaction is excluded from change
    /// streams. Defaults to `false`.
    pub fn with_exclude_txn_from_change_streams(self, exclude: bool) -> Self {
        Self {
            exclude_txn_from_change_streams: exclude,
            ..self
        }
    }

    /// Replaces the retry policy that is consulted when an attempt is aborted.
    pub fn with_retry_policy<P: TransactionRetryPolicy + 'static>(self, policy: P) -> Self {
        Self {
            retry_policy: Box::new(policy),
            ..self
        }
    }

    /// Consumes the builder and produces a [`PartitionedDmlTransaction`].
    ///
    /// No RPC is issued here. The transaction itself is begun by
    /// [`PartitionedDmlTransaction::execute_update`]. This method always
    /// returns `Ok`.
    pub async fn build(self) -> crate::Result<PartitionedDmlTransaction> {
        let Self {
            client,
            retry_policy,
            exclude_txn_from_change_streams,
        } = self;
        Ok(PartitionedDmlTransaction {
            client,
            retry_policy,
            exclude_txn_from_change_streams,
        })
    }
}
/// A configured Partitioned DML transaction, ready to run one statement.
///
/// Produced by [`PartitionedDmlTransactionBuilder::build`]. Calling
/// [`execute_update`](Self::execute_update) consumes the value.
pub struct PartitionedDmlTransaction {
    /// Whether the begun transaction is excluded from change streams.
    exclude_txn_from_change_streams: bool,
    /// Policy that decides whether an aborted attempt is retried.
    retry_policy: Box<dyn TransactionRetryPolicy>,
    /// Client used for `BeginTransaction` and `ExecuteStreamingSql`.
    client: DatabaseClient,
}
impl PartitionedDmlTransaction {
    /// Executes `statement` as a Partitioned DML update and returns the
    /// `row_count_lower_bound` reported by Spanner.
    ///
    /// Each attempt does two things:
    ///
    /// 1. It begins a new transaction on the client's session, using
    ///    `PartitionedDml` options and the configured
    ///    `exclude_txn_from_change_streams` flag.
    /// 2. It runs the statement through `ExecuteStreamingSql`, bound to that
    ///    transaction's id.
    ///
    /// Attempts are driven by `retry_aborted` with the configured
    /// [`TransactionRetryPolicy`], so an aborted attempt is retried from the
    /// `BeginTransaction` step.
    ///
    /// # Errors
    ///
    /// Returns the error from `BeginTransaction`, from `ExecuteStreamingSql`,
    /// or from reading the result stream, once the retry policy gives up.
    /// Returns a deserialization error if the stream completes without a
    /// `row_count_lower_bound`.
    pub async fn execute_update<T: Into<Statement>>(self, statement: T) -> crate::Result<i64> {
        let statement = statement.into();
        let mut gax_options = statement.gax_options().clone();
        self.amend_gax_options(&mut gax_options);
        let session_name = self.client.session_name();
        let transaction_options = TransactionOptions::default()
            .set_partitioned_dml(PartitionedDml::default())
            .set_exclude_txn_from_change_streams(self.exclude_txn_from_change_streams);
        let begin_request = BeginTransactionRequest {
            session: session_name.clone(),
            options: Some(transaction_options),
            ..Default::default()
        };
        let base_request = statement.into_request();
        // Computed once, so every attempt (including retries) reuses the same hint.
        let channel_hint = self.client.spanner.next_channel_hint();
        let client = self.client;
        retry_aborted(&*self.retry_policy, || {
            // The closure may run once per attempt, and the returned future must
            // own its inputs, so each attempt works on fresh copies.
            let begin_request = begin_request.clone();
            let base_request = base_request.clone();
            let session_name = session_name.clone();
            let gax_options = gax_options.clone();
            let client = client.clone();
            async move {
                let transaction = client
                    .spanner
                    .begin_transaction(begin_request, gax_options.clone(), channel_hint)
                    .await?;
                // `transaction` is not used after this point, so its id is moved
                // into the selector rather than cloned.
                let selector = TransactionSelector {
                    selector: Some(transaction_selector::Selector::Id(transaction.id)),
                    ..Default::default()
                };
                let execute_request = base_request
                    .set_session(session_name)
                    .set_transaction(selector);
                let stream_builder =
                    client
                        .spanner
                        .execute_streaming_sql(execute_request, gax_options, channel_hint);
                let stream = stream_builder.send().await?;
                extract_lower_bound_update_count_from_stream(stream).await
            }
        })
        .await
    }

    /// Rewrites `options` with `amend_request_options_for_lar`, based on the
    /// client's `leader_aware_routing_enabled` flag.
    fn amend_gax_options(&self, options: &mut GaxRequestOptions) {
        *options = amend_request_options_for_lar(
            self.client.leader_aware_routing_enabled,
            options.clone(),
        );
    }
}
async fn extract_lower_bound_update_count_from_stream(
mut stream: PartialResultSetStream,
) -> crate::Result<i64> {
let mut lower_bound: Option<i64> = None;
while let Some(prs) = stream.next_message().await.transpose()? {
if let Some(RowCountLowerBound(val)) = prs.stats.and_then(|s| s.row_count) {
lower_bound = Some(val);
}
}
lower_bound.ok_or_else(|| {
crate::Error::deser(crate::error::SpannerInternalError::new(
"ExecuteStreamingSql completed successfully but no row_count_lower_bound was returned",
))
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
    use crate::result_set::tests::adapt;
    use crate::transaction_retry_policy::tests::create_aborted_status;
    use gaxi::grpc::tonic;
    use google_cloud_test_macros::tokio_test_no_panics;
    use spanner_grpc_mock::google::spanner::v1;

    /// Session name that requests against the mocked session are expected to carry.
    const SESSION_NAME: &str = "projects/p/instances/i/databases/d/sessions/123";

    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(PartitionedDmlTransactionBuilder: Send, Sync);
        static_assertions::assert_impl_all!(PartitionedDmlTransaction: Send, Sync);
    }

    /// Happy path. A Partitioned DML transaction is begun on the session, the
    /// statement runs inside that transaction, and the reported lower bound
    /// is returned.
    #[tokio_test_no_panics]
    async fn execute_update_success() {
        let mut mock = create_session_mock();
        mock.expect_begin_transaction().once().returning(|req| {
            let req = req.into_inner();
            assert_eq!(req.session, SESSION_NAME);
            let options = req.options.expect("missing transaction options");
            assert!(
                matches!(
                    options.mode,
                    Some(v1::transaction_options::Mode::PartitionedDml(_))
                ),
                "expected partitioned DML transaction options, got {:?}",
                options.mode
            );
            Ok(tonic::Response::new(v1::Transaction {
                id: vec![0, 1, 2],
                ..Default::default()
            }))
        });
        mock.expect_execute_streaming_sql().once().returning(|req| {
            let req = req.into_inner();
            assert_eq!(req.sql, "UPDATE Users SET active = true");
            assert_eq!(req.session, SESSION_NAME);
            // The statement must be bound to the transaction begun above.
            assert_eq!(
                req.transaction.and_then(|t| t.selector),
                Some(v1::transaction_selector::Selector::Id(vec![0, 1, 2]))
            );
            let stream = adapt([Ok(v1::PartialResultSet {
                stats: Some(v1::ResultSetStats {
                    row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(500)),
                    ..Default::default()
                }),
                ..Default::default()
            })]);
            Ok(tonic::Response::from(stream))
        });
        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder("UPDATE Users SET active = true").build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }

    /// The builder flag is forwarded into the `BeginTransaction` options.
    #[tokio_test_no_panics]
    async fn execute_update_with_exclude_txn_from_change_streams() {
        let mut mock = create_session_mock();
        mock.expect_begin_transaction().once().returning(|req| {
            let req = req.into_inner();
            let options = req.options.expect("missing transaction options");
            assert!(options.exclude_txn_from_change_streams);
            Ok(tonic::Response::new(v1::Transaction {
                id: vec![0, 1, 2],
                ..Default::default()
            }))
        });
        mock.expect_execute_streaming_sql()
            .once()
            .returning(|_req| {
                let stream = adapt([Ok(v1::PartialResultSet {
                    stats: Some(v1::ResultSetStats {
                        row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(500)),
                        ..Default::default()
                    }),
                    ..Default::default()
                })]);
                Ok(tonic::Response::from(stream))
            });
        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .with_exclude_txn_from_change_streams(true)
            .build()
            .await
            .unwrap();
        let statement = Statement::builder("UPDATE Users SET active = true").build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }

    /// An aborted stream causes the whole begin-then-execute sequence to be
    /// retried, and the second attempt's lower bound is returned.
    #[tokio_test_no_panics]
    async fn execute_update_with_aborted_retry() {
        let mut mock = create_session_mock();
        mock.expect_begin_transaction().times(2).returning(|_req| {
            Ok(tonic::Response::new(v1::Transaction {
                id: vec![0, 1, 2],
                ..Default::default()
            }))
        });
        let mut seq = mockall::Sequence::new();
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|_req| {
                let stream = adapt([Err(create_aborted_status(std::time::Duration::from_nanos(
                    1,
                )))]);
                Ok(tonic::Response::from(stream))
            });
        mock.expect_execute_streaming_sql()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|_req| {
                let stream = adapt([Ok(v1::PartialResultSet {
                    stats: Some(v1::ResultSetStats {
                        row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(100)),
                        ..Default::default()
                    }),
                    ..Default::default()
                })]);
                Ok(tonic::Response::from(stream))
            });
        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let res: i64 = transaction
            .execute_update(Statement::builder("UPDATE Users SET active = true").build())
            .await
            .unwrap();
        assert_eq!(res, 100);
    }

    /// A custom retry policy is accepted by the builder.
    #[tokio_test_no_panics]
    async fn builder_with_retry_settings() {
        let mock = create_session_mock();
        let (db_client, _server) = setup_db_client(mock).await;
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(10)
            .with_total_timeout(std::time::Duration::from_secs(42));
        let _transaction = db_client
            .partitioned_dml_transaction()
            .with_retry_policy(policy)
            .build()
            .await
            .unwrap();
    }

    /// An exact row count is not accepted in place of a lower bound.
    #[tokio_test_no_panics]
    async fn execute_update_missing_lower_bound() {
        let mut mock = create_session_mock();
        mock.expect_begin_transaction().once().returning(|_req| {
            Ok(tonic::Response::new(v1::Transaction {
                id: vec![0, 1, 2],
                ..Default::default()
            }))
        });
        mock.expect_execute_streaming_sql()
            .once()
            .returning(|_req| {
                let stream = adapt([Ok(v1::PartialResultSet {
                    stats: Some(v1::ResultSetStats {
                        row_count: Some(v1::result_set_stats::RowCount::RowCountExact(100)),
                        ..Default::default()
                    }),
                    ..Default::default()
                })]);
                Ok(tonic::Response::from(stream))
            });
        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder("UPDATE Users SET active = true").build();
        let err = transaction
            .execute_update(statement)
            .await
            .expect_err("a stream without row_count_lower_bound must fail");
        assert!(err.is_deserialization());
        assert!(
            err.to_string()
                .contains("no row_count_lower_bound was returned")
        );
    }

    /// Both RPCs carry the leader-aware-routing header by default.
    #[tokio_test_no_panics]
    async fn leader_aware_routing_enabled_by_default() {
        let mut mock = create_session_mock();
        mock.expect_begin_transaction().once().returning(|req| {
            assert_eq!(
                req.metadata()
                    .get("x-goog-spanner-route-to-leader")
                    .expect("header required")
                    .to_str()
                    .unwrap(),
                "true"
            );
            Ok(tonic::Response::new(v1::Transaction {
                id: vec![0, 1, 2],
                ..Default::default()
            }))
        });
        mock.expect_execute_streaming_sql().once().returning(|req| {
            assert_eq!(
                req.metadata()
                    .get("x-goog-spanner-route-to-leader")
                    .expect("header required")
                    .to_str()
                    .unwrap(),
                "true"
            );
            let stream = adapt([Ok(v1::PartialResultSet {
                stats: Some(v1::ResultSetStats {
                    row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(500)),
                    ..Default::default()
                }),
                ..Default::default()
            })]);
            Ok(tonic::Response::from(stream))
        });
        let (db_client, _server) = setup_db_client(mock).await;
        let transaction = db_client
            .partitioned_dml_transaction()
            .build()
            .await
            .unwrap();
        let statement = Statement::builder("UPDATE Users SET active = true").build();
        let res: i64 = transaction.execute_update(statement).await.unwrap();
        assert_eq!(res, 500);
    }
}