use crate::error::{BatchUpdateError, internal_error};
use crate::model::result_set_stats::RowCount;
use crate::model::{ExecuteBatchDmlResponse, RequestOptions};
use crate::statement::Statement;
use google_cloud_gax::backoff_policy::BackoffPolicyArg;
use google_cloud_gax::error::rpc::Code;
use google_cloud_gax::error::rpc::Status as RpcStatus;
use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
use google_cloud_gax::retry_policy::RetryPolicyArg;
use std::time::Duration;
/// Builder for a [`BatchDml`].
///
/// Collects DML statements in insertion order, plus optional request options
/// (such as a request tag) and gax request options (attempt timeout, retry
/// policy, backoff policy).
#[derive(Clone, Default, Debug)]
pub struct BatchDmlBuilder {
    // Statements in the order they were added.
    statements: Vec<Statement>,
    // `None` until a request option (currently only the request tag) is set.
    request_options: Option<RequestOptions>,
    // Attempt timeout, retry and backoff settings set via the `with_*` methods.
    gax_options: GaxRequestOptions,
}
impl BatchDmlBuilder {
    /// Creates an empty builder: no statements, no request options and default
    /// gax request options.
    #[must_use]
    pub fn new() -> Self {
        BatchDmlBuilder::default()
    }

    /// Appends `statement` to the batch.
    ///
    /// Statements are kept in the order in which they are added.
    #[must_use]
    pub fn add_statement(mut self, statement: impl Into<Statement>) -> Self {
        self.statements.push(statement.into());
        self
    }

    /// Sets the request tag of the batch.
    ///
    /// If no [`RequestOptions`] have been set yet, default ones are created
    /// first. A previously set tag is replaced.
    #[must_use]
    pub fn set_request_tag(mut self, tag: impl Into<String>) -> Self {
        self.request_options
            .get_or_insert_with(RequestOptions::default)
            .request_tag = tag.into();
        self
    }

    /// Sets the attempt timeout stored in the batch's gax request options.
    #[must_use]
    pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self {
        self.gax_options.set_attempt_timeout(timeout);
        self
    }

    /// Sets the retry policy stored in the batch's gax request options.
    #[must_use]
    pub fn with_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
        self.gax_options.set_retry_policy(policy);
        self
    }

    /// Sets the backoff policy stored in the batch's gax request options.
    #[must_use]
    pub fn with_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
        self.gax_options.set_backoff_policy(policy);
        self
    }

    /// Consumes the builder and returns the configured [`BatchDml`].
    #[must_use]
    pub fn build(self) -> BatchDml {
        BatchDml {
            statements: self.statements,
            request_options: self.request_options,
            gax_options: self.gax_options,
        }
    }
}
/// A batch of DML statements together with the options to use when it is
/// executed.
///
/// Create one with [`BatchDml::builder`], or convert from a `Vec` of values
/// that convert into [`Statement`].
#[derive(Clone, Debug)]
pub struct BatchDml {
    // Statements in the order they were added to the batch.
    pub(crate) statements: Vec<Statement>,
    // Optional request options (e.g. the request tag); `None` if never set.
    pub(crate) request_options: Option<RequestOptions>,
    // Attempt timeout, retry and backoff settings.
    pub(crate) gax_options: GaxRequestOptions,
}
impl BatchDml {
    /// Returns a new, empty [`BatchDmlBuilder`] with no statements, no request
    /// options and default gax request options.
    pub fn builder() -> BatchDmlBuilder {
        BatchDmlBuilder::default()
    }
}
impl From<BatchDmlBuilder> for BatchDml {
fn from(builder: BatchDmlBuilder) -> Self {
builder.build()
}
}
/// Builds a batch from a list of statements, or of anything convertible into a
/// [`Statement`], preserving their order.
///
/// The resulting batch has no request options and default gax request options.
impl<T> From<Vec<T>> for BatchDml
where
    T: Into<Statement>,
{
    fn from(statements: Vec<T>) -> Self {
        let converted: Vec<Statement> = statements.into_iter().map(|stmt| stmt.into()).collect();
        Self {
            statements: converted,
            request_options: None,
            gax_options: GaxRequestOptions::default(),
        }
    }
}
/// Converts an `ExecuteBatchDml` response into update counts, one per result
/// set that carries statistics.
///
/// Result sets without `stats` are skipped rather than counted.
///
/// # Errors
///
/// * A non-OK status with code `Aborted` is returned as a plain service error,
///   not as a [`BatchUpdateError`]. It is checked before any row count is read,
///   so a malformed result set cannot hide it behind an internal error.
/// * Any other non-OK status produces a [`BatchUpdateError`] carrying the
///   update counts parsed from the returned result sets together with the
///   status.
/// * A result set whose stats lack an exact row count yields an internal
///   error.
pub(crate) fn process_response(response: ExecuteBatchDmlResponse) -> crate::Result<Vec<i64>> {
    // Translate a non-OK status first. `Aborted` needs no update counts, so it
    // is surfaced immediately; callers may rely on seeing the `Aborted` code
    // (e.g. to retry the transaction).
    let failure = match response.status.filter(|s| s.code != Code::Ok as i32) {
        None => None,
        Some(status) => {
            let grpc_status = RpcStatus::default()
                .set_code(status.code)
                .set_message(status.message);
            if status.code == Code::Aborted as i32 {
                return Err(crate::Error::service(grpc_status));
            }
            Some(grpc_status)
        }
    };

    let mut update_counts = Vec::with_capacity(response.result_sets.len());
    // Result sets without stats (e.g. a metadata-only result set) carry no count.
    for stats in response.result_sets.into_iter().filter_map(|rs| rs.stats) {
        match stats.row_count {
            Some(RowCount::RowCountExact(count)) => update_counts.push(count),
            _ => {
                return Err(internal_error(
                    "ExecuteBatchDml returned an invalid or missing row count type",
                ));
            }
        }
    }

    match failure {
        Some(grpc_status) => Err(BatchUpdateError::build_error(update_counts, grpc_status)),
        None => Ok(update_counts),
    }
}
// Unit tests for the batch DML builder, the `From` conversions and
// `process_response`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{ResultSet, ResultSetStats};
    use google_cloud_rpc::model::Status;
    use static_assertions::assert_impl_all;

    // Both public types must be Send + Sync + Clone + Debug.
    #[test]
    fn auto_traits() {
        assert_impl_all!(BatchDml: Send, Sync, Clone, std::fmt::Debug);
        assert_impl_all!(BatchDmlBuilder: Send, Sync, Clone, std::fmt::Debug);
    }

    // Statements keep insertion order; no request options are created by default.
    #[test]
    fn builder() {
        let stmt1 = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();
        let stmt2 = Statement::builder("UPDATE t SET c = 2 WHERE id = 2").build();
        let batch = BatchDml::builder()
            .add_statement(stmt1)
            .add_statement(stmt2)
            .build();
        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE t SET c = 1 WHERE id = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE t SET c = 2 WHERE id = 2");
        assert!(
            batch.request_options.is_none(),
            "Unexpected request_options set: {:#?}",
            batch.request_options
        );
    }

    // The `with_*` methods store timeout, retry and backoff settings in gax_options.
    #[test]
    fn builder_with_gax_options() {
        use google_cloud_gax::backoff_policy::BackoffPolicy;
        use google_cloud_gax::retry_policy::Aip194Strict;
        use google_cloud_gax::retry_state::RetryState;
        use std::time::Duration;
        // Minimal backoff policy; only its presence is asserted below.
        #[derive(Debug)]
        struct DummyBackoff;
        impl BackoffPolicy for DummyBackoff {
            fn on_failure(&self, _state: &RetryState) -> Duration {
                Duration::ZERO
            }
        }
        let stmt = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();
        let batch = BatchDml::builder()
            .add_statement(stmt)
            .with_attempt_timeout(Duration::from_secs(5))
            .with_retry_policy(Aip194Strict)
            .with_backoff_policy(DummyBackoff)
            .build();
        assert_eq!(
            *batch.gax_options.attempt_timeout(),
            Some(Duration::from_secs(5))
        );
        assert!(batch.gax_options.retry_policy().is_some());
        assert!(batch.gax_options.backoff_policy().is_some());
    }

    // Setting a request tag creates request options holding that tag.
    #[test]
    fn builder_with_request_tag() {
        let stmt = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build();
        let batch = BatchDml::builder()
            .add_statement(stmt)
            .set_request_tag("tag1")
            .build();
        assert_eq!(batch.statements.len(), 1);
        assert_eq!(
            batch
                .request_options
                .expect("request options missing")
                .request_tag,
            "tag1"
        );
    }

    // With no status, every exact row count is returned in result-set order.
    #[test]
    fn process_response_success() -> anyhow::Result<()> {
        let stats1 = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(5)),
            ..Default::default()
        };
        let stats2 = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(10)),
            ..Default::default()
        };
        let rs1 = ResultSet {
            stats: Some(stats1),
            ..Default::default()
        };
        let rs2 = ResultSet {
            stats: Some(stats2),
            ..Default::default()
        };
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs1, rs2],
            status: None,
            ..Default::default()
        };
        let counts = process_response(response)?;
        assert_eq!(counts, vec![5, 10]);
        Ok(())
    }

    // A non-Aborted error status yields a BatchUpdateError with the partial counts.
    #[test]
    fn process_response_grpc_error() {
        let stats = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(3)),
            ..Default::default()
        };
        let rs = ResultSet {
            stats: Some(stats),
            ..Default::default()
        };
        let err_status = Status::default()
            .set_code(Code::InvalidArgument as i32)
            .set_message("Bad query");
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            status: Some(err_status),
            ..Default::default()
        };
        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err).expect("should extract BatchUpdateError");
        assert_eq!(batch_err.update_counts, vec![3]);
        assert_eq!(
            batch_err.status.status().expect("status").code,
            Code::InvalidArgument
        );
        assert_eq!(
            batch_err.status.status().expect("status").message,
            "Bad query"
        );
    }

    // An Aborted status is a plain service error, not a BatchUpdateError.
    #[test]
    fn process_response_aborted() {
        let stats = ResultSetStats {
            row_count: Some(RowCount::RowCountExact(3)),
            ..Default::default()
        };
        let rs = ResultSet {
            stats: Some(stats),
            ..Default::default()
        };
        let err_status = Status::default()
            .set_code(Code::Aborted as i32)
            .set_message("transaction aborted");
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            status: Some(err_status),
            ..Default::default()
        };
        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err);
        assert!(
            batch_err.is_none(),
            "Unexpected BatchUpdateError: {batch_err:?}"
        );
        assert_eq!(err.status().expect("status").code, Code::Aborted);
        assert_eq!(err.status().expect("status").message, "transaction aborted");
    }

    // Result sets without stats are skipped, not treated as errors.
    #[test]
    fn process_response_missing_stats() {
        let rs = ResultSet {
            stats: None,
            ..Default::default()
        };
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            ..Default::default()
        };
        let result = process_response(response).expect("should return empty update counts");
        assert!(result.is_empty());
    }

    // Stats without an exact row count produce an internal error.
    #[test]
    fn process_response_missing_row_count_type() {
        let stats = ResultSetStats {
            row_count: None,
            ..Default::default()
        };
        let rs = ResultSet {
            stats: Some(stats),
            ..Default::default()
        };
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            ..Default::default()
        };
        let result = process_response(response);
        let err = result.expect_err("should fail");
        assert!(
            err.to_string()
                .contains("invalid or missing row count type")
        );
    }

    // `Vec<&str>` converts into a batch with the same SQL in the same order.
    #[test]
    fn from_vector_of_strings() {
        let statements = vec!["UPDATE table SET col = 1", "UPDATE table SET col = 2"];
        let batch: BatchDml = statements.into();
        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE table SET col = 2");
    }

    // `Vec<Statement>` converts into a batch with the same statements in order.
    #[test]
    fn from_vector_of_statements() {
        let statement1 = Statement::builder("UPDATE table SET col = 1").build();
        let statement2 = Statement::builder("UPDATE table SET col = 2").build();
        let statements = vec![statement1, statement2];
        let batch: BatchDml = statements.into();
        assert_eq!(batch.statements.len(), 2);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
        assert_eq!(batch.statements[1].sql, "UPDATE table SET col = 2");
    }

    // A builder converts into a batch via `From<BatchDmlBuilder>`.
    #[test]
    fn from_builder() {
        let builder = BatchDml::builder().add_statement("UPDATE table SET col = 1");
        let batch: BatchDml = builder.into();
        assert_eq!(batch.statements.len(), 1);
        assert_eq!(batch.statements[0].sql, "UPDATE table SET col = 1");
    }

    // A metadata-only result set (no stats) plus an error status yields a
    // BatchUpdateError with empty counts rather than an internal error.
    #[test]
    fn process_response_metadata_no_stats_grpc_error() {
        let rs = ResultSet {
            metadata: Some(crate::model::ResultSetMetadata {
                transaction: Some(crate::model::Transaction {
                    id: vec![7, 7, 7].into(),
                    ..Default::default()
                }),
                ..Default::default()
            }),
            stats: None,
            ..Default::default()
        };
        let err_status = Status::default()
            .set_code(Code::InvalidArgument as i32)
            .set_message("Table not found or syntax invalid");
        let response = ExecuteBatchDmlResponse {
            result_sets: vec![rs],
            status: Some(err_status),
            ..Default::default()
        };
        let result = process_response(response);
        let err = result.expect_err("should return error");
        let batch_err = BatchUpdateError::extract(&err)
            .expect("should extract BatchUpdateError cleanly and not return internal error");
        assert_eq!(
            batch_err.update_counts,
            Vec::<i64>::new(),
            "Update counts should be completely empty"
        );
        assert_eq!(
            batch_err.status.status().expect("status").code,
            Code::InvalidArgument
        );
        assert_eq!(
            batch_err.status.status().expect("status").message,
            "Table not found or syntax invalid"
        );
    }
}