arrow_tiberius/
connection.rs1use std::fmt;
4
5use arrow_array::RecordBatch;
6use tokio::net::TcpStream;
7use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
8
9use crate::{BulkWriter, Error, Result, SchemaMapping, TableName, WriteOptions, WriteStats};
10
11type CompatibleMssqlTransport = Compat<TcpStream>;
12
13pub struct ConnectedMssqlClient {
20 client: tiberius::Client<CompatibleMssqlTransport>,
21}
22
23pub struct ConnectedBulkWriter<'client> {
29 writer: BulkWriter<'client, CompatibleMssqlTransport>,
30}
31
32impl fmt::Debug for ConnectedMssqlClient {
33 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
34 formatter
35 .debug_struct("ConnectedMssqlClient")
36 .finish_non_exhaustive()
37 }
38}
39
40impl fmt::Debug for ConnectedBulkWriter<'_> {
41 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
42 formatter
43 .debug_struct("ConnectedBulkWriter")
44 .finish_non_exhaustive()
45 }
46}
47
/// Rows-affected counts returned by [`ConnectedMssqlClient::execute_statement`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct SqlExecutionOutcome {
    /// Rows-affected counts, in the order the driver reported them.
    pub rows_affected: Vec<u64>,
}

impl SqlExecutionOutcome {
    /// Returns the sum of every entry in [`rows_affected`](Self::rows_affected).
    ///
    /// An empty outcome totals `0`. The sum saturates at `u64::MAX` instead of
    /// overflowing. `Iterator::sum` would panic in debug builds and silently wrap in
    /// release builds, so saturating gives the same answer under both profiles.
    #[must_use]
    pub fn total_rows_affected(&self) -> u64 {
        self.rows_affected
            .iter()
            .fold(0_u64, |total, &count| total.saturating_add(count))
    }
}
65
66impl ConnectedMssqlClient {
67 pub async fn table_exists(&mut self, table: &TableName) -> Result<bool> {
74 let query = table_exists_query(table);
75 let row = self
76 .client
77 .simple_query(query)
78 .await
79 .map_err(|source| Error::TableExistsQuery { source })?
80 .into_row()
81 .await
82 .map_err(|source| Error::TableExistsQuery { source })?
83 .ok_or_else(|| Error::TableExistsUnexpectedResult {
84 reason: "metadata query returned no rows".to_owned(),
85 })?;
86
87 row.try_get("exists")
88 .map_err(|source| Error::TableExistsQuery { source })?
89 .ok_or_else(|| Error::TableExistsUnexpectedResult {
90 reason: "metadata query returned NULL".to_owned(),
91 })
92 }
93
94 pub async fn execute_statement(&mut self, sql: &str) -> Result<SqlExecutionOutcome> {
100 let result = self
101 .client
102 .execute(sql, &[])
103 .await
104 .map_err(|source| Error::SqlExecution { source })?;
105
106 Ok(SqlExecutionOutcome {
107 rows_affected: result.rows_affected().to_vec(),
108 })
109 }
110
111 pub async fn bulk_writer(
117 &mut self,
118 table: TableName,
119 mappings: Vec<SchemaMapping>,
120 options: WriteOptions,
121 ) -> Result<ConnectedBulkWriter<'_>> {
122 let writer = BulkWriter::new(&mut self.client, table, mappings, options).await?;
123
124 Ok(ConnectedBulkWriter { writer })
125 }
126}
127
128impl ConnectedBulkWriter<'_> {
129 pub async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteStats> {
131 self.writer.write_batch(batch).await
132 }
133
134 pub async fn finish(self) -> Result<WriteStats> {
136 self.writer.finish().await
137 }
138}
139
140pub async fn connect_mssql_client_from_ado_string(
148 connection_string: &str,
149) -> Result<ConnectedMssqlClient> {
150 let config = tiberius::Config::from_ado_string(connection_string)
151 .map_err(|_source| Error::InvalidConnectionString)?;
152 let tcp = TcpStream::connect(config.get_addr())
153 .await
154 .map_err(|source| Error::ConnectionTcpConnect { source })?;
155 tcp.set_nodelay(true)
156 .map_err(|source| Error::ConnectionTcpConnect { source })?;
157
158 let client = tiberius::Client::connect(config, tcp.compat_write())
159 .await
160 .map_err(|source| Error::ConnectionClientSetup { source })?;
161
162 Ok(ConnectedMssqlClient { client })
163}
164
165fn table_exists_query(table: &TableName) -> String {
166 let mut conditions = vec![format!(
167 "t.name = {}",
168 sql_string_literal(table.table().as_str())
169 )];
170 if let Some(schema) = table.schema() {
171 conditions.push(format!("s.name = {}", sql_string_literal(schema.as_str())));
172 }
173
174 format!(
175 "SELECT CASE WHEN EXISTS (SELECT 1 FROM sys.tables AS t \
176 INNER JOIN sys.schemas AS s ON s.schema_id = t.schema_id \
177 WHERE {}) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END AS [exists]",
178 conditions.join(" AND ")
179 )
180}
181
/// Renders `value` as a T-SQL Unicode string literal (`N'...'`).
///
/// Every embedded single quote is doubled, which is the only escaping such a
/// literal requires.
fn sql_string_literal(value: &str) -> String {
    // "N'" prefix + closing "'" = 3 extra bytes beyond the unescaped text.
    let mut literal = String::with_capacity(value.len() + 3);
    literal.push_str("N'");
    for ch in value.chars() {
        if ch == '\'' {
            literal.push_str("''");
        } else {
            literal.push(ch);
        }
    }
    literal.push('\'');
    literal
}
185
#[cfg(test)]
mod tests {
    use crate::{Error, connect_mssql_client_from_ado_string};

    /// Asserts that neither the `Display` nor the `Debug` rendering of `error`
    /// contains any of `secrets`.
    fn assert_redacted(error: &Error, secrets: &[&str]) {
        let renderings = [error.to_string(), format!("{error:?}")];
        for secret in secrets {
            for rendering in &renderings {
                assert!(!rendering.contains(*secret));
            }
        }
    }

    #[test]
    fn sql_execution_outcome_records_rows_affected_in_order() {
        let outcome = crate::SqlExecutionOutcome {
            rows_affected: vec![2, 3, 5],
        };

        assert_eq!(outcome.rows_affected, [2, 3, 5]);
        assert_eq!(outcome.total_rows_affected(), 10);
    }

    #[test]
    fn table_exists_query_filters_schema_and_table() -> crate::Result<()> {
        let query = super::table_exists_query(&crate::TableName::new("tenant", "people")?);

        for fragment in [
            "FROM sys.tables AS t",
            "INNER JOIN sys.schemas AS s",
            "t.name = N'people'",
            "s.name = N'tenant'",
        ] {
            assert!(query.contains(fragment));
        }
        Ok(())
    }

    #[test]
    fn table_exists_query_escapes_string_literals() -> crate::Result<()> {
        let query = super::table_exists_query(&crate::TableName::new("tenant's", "people's")?);

        for fragment in ["t.name = N'people''s'", "s.name = N'tenant''s'"] {
            assert!(query.contains(fragment));
        }
        Ok(())
    }

    #[test]
    fn unqualified_table_exists_query_filters_only_table_name() -> crate::Result<()> {
        let query = super::table_exists_query(&crate::TableName::unqualified("people")?);

        assert!(query.contains("t.name = N'people'"));
        assert!(!query.contains("s.name ="));
        Ok(())
    }

    #[test]
    fn connected_client_type_is_public_without_raw_client_signature() {
        let name = std::any::type_name::<crate::ConnectedMssqlClient>();

        assert!(name.contains("ConnectedMssqlClient"));
        assert!(!name.contains("tiberius::Client"));
    }

    #[test]
    fn connected_writer_type_is_public_without_raw_transport_signature() {
        let name = std::any::type_name::<crate::ConnectedBulkWriter<'static>>();

        assert!(name.contains("ConnectedBulkWriter"));
        for raw_type in ["tiberius::Client", "tokio::net::TcpStream"] {
            assert!(!name.contains(raw_type));
        }
    }

    #[tokio::test]
    async fn invalid_connection_string_error_is_redacted() -> crate::Result<()> {
        let connection_string =
            "Server=tcp:localhost,notaport;Password=secret-token-123;Access Token=token-456";
        let Err(error) = connect_mssql_client_from_ado_string(connection_string).await else {
            // Connecting successfully is itself a failure of this test.
            return Err(Error::InvalidConnectionString);
        };

        assert!(matches!(error, Error::InvalidConnectionString));
        assert_redacted(
            &error,
            &["secret-token-123", "token-456", connection_string],
        );
        Ok(())
    }

    #[tokio::test]
    async fn tcp_connect_error_is_structured_and_redacted() -> crate::Result<()> {
        let connection_string =
            "Server=tcp:127.0.0.1,1;User Id=sa;Password=secret-token-123;Encrypt=false";
        let Err(error) = connect_mssql_client_from_ado_string(connection_string).await else {
            // Connecting successfully is itself a failure of this test.
            return Err(Error::InvalidConnectionString);
        };

        assert!(matches!(error, Error::ConnectionTcpConnect { .. }));
        assert_redacted(&error, &["secret-token-123", connection_string]);
        Ok(())
    }
}
290}