use std::fmt;
use arrow_array::RecordBatch;
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
use crate::{BulkWriter, Error, Result, SchemaMapping, TableName, WriteOptions, WriteStats};
/// Transport handed to `tiberius::Client`: a tokio `TcpStream` wrapped in
/// `tokio_util::compat::Compat` (produced by `compat_write()` when connecting).
type CompatibleMssqlTransport = Compat<TcpStream>;
/// A connected SQL Server client.
///
/// The underlying `tiberius::Client` is kept private so that the raw client and
/// transport types do not appear in this type's public signature.
pub struct ConnectedMssqlClient {
client: tiberius::Client<CompatibleMssqlTransport>,
}
/// A bulk writer that mutably borrows a [`ConnectedMssqlClient`] for `'client`.
///
/// Wraps a private `BulkWriter` so the transport type stays out of the public signature.
pub struct ConnectedBulkWriter<'client> {
writer: BulkWriter<'client, CompatibleMssqlTransport>,
}
impl fmt::Debug for ConnectedMssqlClient {
    /// Prints `ConnectedMssqlClient { .. }`.
    ///
    /// No fields are emitted, so nothing about the wrapped client or its connection
    /// shows up in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ConnectedMssqlClient");
        builder.finish_non_exhaustive()
    }
}
impl fmt::Debug for ConnectedBulkWriter<'_> {
    /// Prints `ConnectedBulkWriter { .. }`.
    ///
    /// The wrapped writer is intentionally left out of the debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ConnectedBulkWriter");
        builder.finish_non_exhaustive()
    }
}
/// Affected-row counts returned from executing a SQL statement.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct SqlExecutionOutcome {
    /// Affected-row counts in the order the driver reported them.
    pub rows_affected: Vec<u64>,
}

impl SqlExecutionOutcome {
    /// Returns the sum of every entry in [`rows_affected`](Self::rows_affected).
    ///
    /// An empty list yields `0`. The sum saturates at [`u64::MAX`] rather than
    /// overflowing. A plain `Iterator::sum` would panic in debug builds and
    /// silently wrap in release builds.
    #[must_use]
    pub fn total_rows_affected(&self) -> u64 {
        self.rows_affected
            .iter()
            .fold(0_u64, |total, &count| total.saturating_add(count))
    }
}
impl ConnectedMssqlClient {
pub async fn table_exists(&mut self, table: &TableName) -> Result<bool> {
let query = table_exists_query(table);
let row = self
.client
.simple_query(query)
.await
.map_err(|source| Error::TableExistsQuery { source })?
.into_row()
.await
.map_err(|source| Error::TableExistsQuery { source })?
.ok_or_else(|| Error::TableExistsUnexpectedResult {
reason: "metadata query returned no rows".to_owned(),
})?;
row.try_get("exists")
.map_err(|source| Error::TableExistsQuery { source })?
.ok_or_else(|| Error::TableExistsUnexpectedResult {
reason: "metadata query returned NULL".to_owned(),
})
}
pub async fn execute_statement(&mut self, sql: &str) -> Result<SqlExecutionOutcome> {
let result = self
.client
.execute(sql, &[])
.await
.map_err(|source| Error::SqlExecution { source })?;
Ok(SqlExecutionOutcome {
rows_affected: result.rows_affected().to_vec(),
})
}
pub async fn bulk_writer(
&mut self,
table: TableName,
mappings: Vec<SchemaMapping>,
options: WriteOptions,
) -> Result<ConnectedBulkWriter<'_>> {
let writer = BulkWriter::new(&mut self.client, table, mappings, options).await?;
Ok(ConnectedBulkWriter { writer })
}
}
impl ConnectedBulkWriter<'_> {
    /// Sends one Arrow `RecordBatch` through the wrapped `BulkWriter`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped writer's `write_batch`.
    pub async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteStats> {
        let inner = &mut self.writer;
        inner.write_batch(batch).await
    }

    /// Consumes this writer and delegates to the wrapped `BulkWriter::finish`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped writer's `finish`.
    pub async fn finish(self) -> Result<WriteStats> {
        let Self { writer } = self;
        writer.finish().await
    }
}
pub async fn connect_mssql_client_from_ado_string(
connection_string: &str,
) -> Result<ConnectedMssqlClient> {
let config = tiberius::Config::from_ado_string(connection_string)
.map_err(|_source| Error::InvalidConnectionString)?;
let tcp = TcpStream::connect(config.get_addr())
.await
.map_err(|source| Error::ConnectionTcpConnect { source })?;
tcp.set_nodelay(true)
.map_err(|source| Error::ConnectionTcpConnect { source })?;
let client = tiberius::Client::connect(config, tcp.compat_write())
.await
.map_err(|source| Error::ConnectionClientSetup { source })?;
Ok(ConnectedMssqlClient { client })
}
/// Builds the T-SQL metadata query that checks whether `table` exists.
///
/// The query selects one `bit` value aliased `[exists]` from `sys.tables` joined with
/// `sys.schemas`. The table name is always filtered. The schema name is filtered only
/// when `table` has a schema, so an unqualified name matches that table name in any
/// schema. Names are embedded as escaped `N'...'` literals via `sql_string_literal`.
fn table_exists_query(table: &TableName) -> String {
    let table_filter = format!("t.name = {}", sql_string_literal(table.table().as_str()));
    let where_clause = match table.schema() {
        Some(schema) => format!(
            "{table_filter} AND s.name = {}",
            sql_string_literal(schema.as_str())
        ),
        None => table_filter,
    };
    format!(
        "SELECT CASE WHEN EXISTS (SELECT 1 FROM sys.tables AS t \
         INNER JOIN sys.schemas AS s ON s.schema_id = t.schema_id \
         WHERE {where_clause}) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END AS [exists]"
    )
}
/// Renders `value` as a T-SQL Unicode string literal (`N'...'`).
///
/// Every single quote is doubled, which is T-SQL's escape for a quote inside a literal.
fn sql_string_literal(value: &str) -> String {
    let mut literal = String::with_capacity(value.len() + 3);
    literal.push_str("N'");
    for ch in value.chars() {
        if ch == '\'' {
            literal.push_str("''");
        } else {
            literal.push(ch);
        }
    }
    literal.push('\'');
    literal
}
#[cfg(test)]
mod tests {
    use crate::{Error, connect_mssql_client_from_ado_string};

    #[test]
    fn sql_execution_outcome_records_rows_affected_in_order() {
        let outcome = crate::SqlExecutionOutcome {
            rows_affected: vec![2, 3, 5],
        };
        assert_eq!(outcome.rows_affected, vec![2, 3, 5]);
        assert_eq!(outcome.total_rows_affected(), 10);
    }

    #[test]
    fn table_exists_query_filters_schema_and_table() -> crate::Result<()> {
        let table = crate::TableName::new("tenant", "people")?;
        let query = super::table_exists_query(&table);
        assert!(query.contains("FROM sys.tables AS t"));
        assert!(query.contains("INNER JOIN sys.schemas AS s"));
        assert!(query.contains("t.name = N'people'"));
        assert!(query.contains("s.name = N'tenant'"));
        Ok(())
    }

    #[test]
    fn table_exists_query_escapes_string_literals() -> crate::Result<()> {
        let table = crate::TableName::new("tenant's", "people's")?;
        let query = super::table_exists_query(&table);
        assert!(query.contains("t.name = N'people''s'"));
        assert!(query.contains("s.name = N'tenant''s'"));
        Ok(())
    }

    #[test]
    fn unqualified_table_exists_query_filters_only_table_name() -> crate::Result<()> {
        let table = crate::TableName::unqualified("people")?;
        let query = super::table_exists_query(&table);
        assert!(query.contains("t.name = N'people'"));
        assert!(!query.contains("s.name ="));
        Ok(())
    }

    #[test]
    fn connected_client_type_is_public_without_raw_client_signature() {
        let type_name = std::any::type_name::<crate::ConnectedMssqlClient>();
        assert!(type_name.contains("ConnectedMssqlClient"));
        assert!(!type_name.contains("tiberius::Client"));
    }

    #[test]
    fn connected_writer_type_is_public_without_raw_transport_signature() {
        let type_name = std::any::type_name::<crate::ConnectedBulkWriter<'static>>();
        assert!(type_name.contains("ConnectedBulkWriter"));
        assert!(!type_name.contains("tiberius::Client"));
        assert!(!type_name.contains("tokio::net::TcpStream"));
    }

    #[tokio::test]
    async fn invalid_connection_string_error_is_redacted() {
        let connection_string =
            "Server=tcp:localhost,notaport;Password=secret-token-123;Access Token=token-456";
        // `expect_err` gives an explicit failure message if the connection unexpectedly
        // succeeds. Returning an `Err` sentinel instead would report the very error
        // variant this test expects.
        let error = connect_mssql_client_from_ado_string(connection_string)
            .await
            .expect_err("a connection string with a non-numeric port must be rejected");
        assert!(matches!(error, Error::InvalidConnectionString));
        let display = error.to_string();
        let debug = format!("{error:?}");
        for secret in ["secret-token-123", "token-456", connection_string] {
            assert!(!display.contains(secret));
            assert!(!debug.contains(secret));
        }
    }

    #[tokio::test]
    async fn tcp_connect_error_is_structured_and_redacted() {
        let connection_string =
            "Server=tcp:127.0.0.1,1;User Id=sa;Password=secret-token-123;Encrypt=false";
        let error = connect_mssql_client_from_ado_string(connection_string)
            .await
            .expect_err("connecting to 127.0.0.1:1 is expected to fail");
        assert!(matches!(error, Error::ConnectionTcpConnect { .. }));
        let display = error.to_string();
        let debug = format!("{error:?}");
        for secret in ["secret-token-123", connection_string] {
            assert!(!display.contains(secret));
            assert!(!debug.contains(secret));
        }
    }
}