Skip to main content

arrow_tiberius/
connection.rs

1//! SQL Server connection helpers.
2
3use std::fmt;
4
5use arrow_array::RecordBatch;
6use tokio::net::TcpStream;
7use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
8
9use crate::{BulkWriter, Error, Result, SchemaMapping, TableName, WriteOptions, WriteStats};
10
11type CompatibleMssqlTransport = Compat<TcpStream>;
12
13/// Opaque SQL Server client constructed with this crate's compatible Tiberius dependency.
14///
15/// Use [`connect_mssql_client_from_ado_string`] to create this type. Its
16/// concrete Tiberius client and async transport types are intentionally hidden
17/// so downstream crates do not have to name or match `tiberius-raw-bulk`
18/// directly.
19pub struct ConnectedMssqlClient {
20    client: tiberius::Client<CompatibleMssqlTransport>,
21}
22
23/// Bulk writer created from a [`ConnectedMssqlClient`].
24///
25/// This wrapper keeps the compatible Tiberius client and transport types out of
26/// downstream signatures while exposing the same write and finish operations as
27/// [`BulkWriter`].
28pub struct ConnectedBulkWriter<'client> {
29    writer: BulkWriter<'client, CompatibleMssqlTransport>,
30}
31
32impl fmt::Debug for ConnectedMssqlClient {
33    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
34        formatter
35            .debug_struct("ConnectedMssqlClient")
36            .finish_non_exhaustive()
37    }
38}
39
40impl fmt::Debug for ConnectedBulkWriter<'_> {
41    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
42        formatter
43            .debug_struct("ConnectedBulkWriter")
44            .finish_non_exhaustive()
45    }
46}
47
/// Affected-row metadata collected after running SQL on a connected client.
///
/// This type belongs to the intentionally narrow lifecycle SQL API: it carries
/// row counts only, never result rows. Statement execution is kept separate
/// from connection construction so each part can be reviewed on its own.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct SqlExecutionOutcome {
    /// Row counts reported by SQL Server DONE tokens, in server result order.
    pub rows_affected: Vec<u64>,
}

impl SqlExecutionOutcome {
    /// Adds up every entry of [`Self::rows_affected`]; an empty list yields `0`.
    pub fn total_rows_affected(&self) -> u64 {
        self.rows_affected
            .iter()
            .fold(0, |running_total, &count| running_total + count)
    }
}
65
66impl ConnectedMssqlClient {
67    /// Returns whether the target table exists in SQL Server metadata.
68    ///
69    /// This is a narrow metadata probe, not a generic query API. For
70    /// schema-qualified names it checks the exact schema and table. For
71    /// unqualified names it checks whether any table with that name exists in
72    /// the current database.
73    pub async fn table_exists(&mut self, table: &TableName) -> Result<bool> {
74        let query = table_exists_query(table);
75        let row = self
76            .client
77            .simple_query(query)
78            .await
79            .map_err(|source| Error::TableExistsQuery { source })?
80            .into_row()
81            .await
82            .map_err(|source| Error::TableExistsQuery { source })?
83            .ok_or_else(|| Error::TableExistsUnexpectedResult {
84                reason: "metadata query returned no rows".to_owned(),
85            })?;
86
87        row.try_get("exists")
88            .map_err(|source| Error::TableExistsQuery { source })?
89            .ok_or_else(|| Error::TableExistsUnexpectedResult {
90                reason: "metadata query returned NULL".to_owned(),
91            })
92    }
93
94    /// Executes a prepared lifecycle SQL statement.
95    ///
96    /// This method accepts statement text but intentionally returns only
97    /// affected-row metadata. It does not expose a generic result-row mapping
98    /// API.
99    pub async fn execute_statement(&mut self, sql: &str) -> Result<SqlExecutionOutcome> {
100        let result = self
101            .client
102            .execute(sql, &[])
103            .await
104            .map_err(|source| Error::SqlExecution { source })?;
105
106        Ok(SqlExecutionOutcome {
107            rows_affected: result.rows_affected().to_vec(),
108        })
109    }
110
111    /// Starts a bulk writer on this same SQL Server connection.
112    ///
113    /// The returned writer borrows the connected client, so lifecycle SQL and
114    /// bulk loading cannot accidentally use two different connections through
115    /// this API.
116    pub async fn bulk_writer(
117        &mut self,
118        table: TableName,
119        mappings: Vec<SchemaMapping>,
120        options: WriteOptions,
121    ) -> Result<ConnectedBulkWriter<'_>> {
122        let writer = BulkWriter::new(&mut self.client, table, mappings, options).await?;
123
124        Ok(ConnectedBulkWriter { writer })
125    }
126}
127
128impl ConnectedBulkWriter<'_> {
129    /// Writes one Arrow record batch.
130    pub async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteStats> {
131        self.writer.write_batch(batch).await
132    }
133
134    /// Finalizes the bulk writer and returns cumulative write statistics.
135    pub async fn finish(self) -> Result<WriteStats> {
136        self.writer.finish().await
137    }
138}
139
140/// Connects to SQL Server from an ADO-style connection string.
141///
142/// The connection uses this crate's `tiberius-raw-bulk` dependency identity and
143/// Tokio TCP transport internally. The returned wrapper hides those concrete
144/// types from downstream crates.
145///
146/// The raw connection string is not stored in the returned client or in errors.
147pub async fn connect_mssql_client_from_ado_string(
148    connection_string: &str,
149) -> Result<ConnectedMssqlClient> {
150    let config = tiberius::Config::from_ado_string(connection_string)
151        .map_err(|_source| Error::InvalidConnectionString)?;
152    let tcp = TcpStream::connect(config.get_addr())
153        .await
154        .map_err(|source| Error::ConnectionTcpConnect { source })?;
155    tcp.set_nodelay(true)
156        .map_err(|source| Error::ConnectionTcpConnect { source })?;
157
158    let client = tiberius::Client::connect(config, tcp.compat_write())
159        .await
160        .map_err(|source| Error::ConnectionClientSetup { source })?;
161
162    Ok(ConnectedMssqlClient { client })
163}
164
165fn table_exists_query(table: &TableName) -> String {
166    let mut conditions = vec![format!(
167        "t.name = {}",
168        sql_string_literal(table.table().as_str())
169    )];
170    if let Some(schema) = table.schema() {
171        conditions.push(format!("s.name = {}", sql_string_literal(schema.as_str())));
172    }
173
174    format!(
175        "SELECT CASE WHEN EXISTS (SELECT 1 FROM sys.tables AS t \
176         INNER JOIN sys.schemas AS s ON s.schema_id = t.schema_id \
177         WHERE {}) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END AS [exists]",
178        conditions.join(" AND ")
179    )
180}
181
/// Renders `value` as a T-SQL Unicode string literal (`N'…'`).
///
/// Each embedded single quote is doubled, so the text cannot end the literal
/// early.
fn sql_string_literal(value: &str) -> String {
    // Room for the `N'` prefix, the closing quote, and the unescaped text.
    let mut literal = String::with_capacity(value.len() + 3);
    literal.push_str("N'");
    for ch in value.chars() {
        if ch == '\'' {
            literal.push('\'');
        }
        literal.push(ch);
    }
    literal.push('\'');
    literal
}
185
#[cfg(test)]
mod tests {
    use crate::{Error, connect_mssql_client_from_ado_string};

    /// Asserts that none of `secrets` appears in the error's `Display` or
    /// `Debug` output.
    fn assert_redacted(error: &Error, secrets: &[&str]) {
        let display = error.to_string();
        let debug = format!("{error:?}");

        for secret in secrets {
            assert!(!display.contains(secret), "Display leaked a secret");
            assert!(!debug.contains(secret), "Debug leaked a secret");
        }
    }

    #[test]
    fn sql_execution_outcome_records_rows_affected_in_order() {
        let outcome = crate::SqlExecutionOutcome {
            rows_affected: vec![2, 3, 5],
        };

        assert_eq!(outcome.rows_affected, vec![2, 3, 5]);
        assert_eq!(outcome.total_rows_affected(), 10);
    }

    #[test]
    fn table_exists_query_filters_schema_and_table() -> crate::Result<()> {
        let table = crate::TableName::new("tenant", "people")?;
        let query = super::table_exists_query(&table);

        assert!(query.contains("FROM sys.tables AS t"));
        assert!(query.contains("INNER JOIN sys.schemas AS s"));
        assert!(query.contains("t.name = N'people'"));
        assert!(query.contains("s.name = N'tenant'"));
        Ok(())
    }

    #[test]
    fn table_exists_query_escapes_string_literals() -> crate::Result<()> {
        let table = crate::TableName::new("tenant's", "people's")?;
        let query = super::table_exists_query(&table);

        assert!(query.contains("t.name = N'people''s'"));
        assert!(query.contains("s.name = N'tenant''s'"));
        Ok(())
    }

    #[test]
    fn unqualified_table_exists_query_filters_only_table_name() -> crate::Result<()> {
        let table = crate::TableName::unqualified("people")?;
        let query = super::table_exists_query(&table);

        assert!(query.contains("t.name = N'people'"));
        assert!(!query.contains("s.name ="));
        Ok(())
    }

    // `type_name` reports only the path of the named type and never its private
    // fields. These two checks are therefore smoke tests that the wrappers are
    // publicly nameable, not proofs about what they contain.
    #[test]
    fn connected_client_type_is_public_without_raw_client_signature() {
        let type_name = std::any::type_name::<crate::ConnectedMssqlClient>();

        assert!(type_name.contains("ConnectedMssqlClient"));
        assert!(!type_name.contains("tiberius::Client"));
    }

    #[test]
    fn connected_writer_type_is_public_without_raw_transport_signature() {
        let type_name = std::any::type_name::<crate::ConnectedBulkWriter<'static>>();

        assert!(type_name.contains("ConnectedBulkWriter"));
        assert!(!type_name.contains("tiberius::Client"));
        assert!(!type_name.contains("tokio::net::TcpStream"));
    }

    #[tokio::test]
    async fn invalid_connection_string_error_is_redacted() {
        let connection_string =
            "Server=tcp:localhost,notaport;Password=secret-token-123;Access Token=token-456";
        // An unexpected success fails here with a clear message instead of a
        // misleading error value.
        let error = connect_mssql_client_from_ado_string(connection_string)
            .await
            .expect_err("a connection string with a non-numeric port must be rejected");

        assert!(
            matches!(error, Error::InvalidConnectionString),
            "unexpected error variant"
        );
        assert_redacted(&error, &["secret-token-123", "token-456", connection_string]);
    }

    #[tokio::test]
    async fn tcp_connect_error_is_structured_and_redacted() {
        let connection_string =
            "Server=tcp:127.0.0.1,1;User Id=sa;Password=secret-token-123;Encrypt=false";
        let error = connect_mssql_client_from_ado_string(connection_string)
            .await
            .expect_err("nothing should accept SQL Server connections on 127.0.0.1:1");

        assert!(
            matches!(error, Error::ConnectionTcpConnect { .. }),
            "unexpected error variant"
        );
        assert_redacted(&error, &["secret-token-123", connection_string]);
    }
}