Skip to main content

rivven_rdbc/
sqlserver.rs

1//! SQL Server backend implementation for rivven-rdbc
2//!
3//! Provides Microsoft SQL Server-specific implementations:
4//! - Connection and prepared statements
5//! - Transaction support with savepoints
6//! - Streaming row iteration
7//! - Connection pooling
8//! - Schema provider for introspection
9
10use async_trait::async_trait;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::Arc;
15use std::time::Instant;
16use tiberius::{AuthMethod, Client, Config};
17use tokio::net::TcpStream;
18use tokio::sync::Mutex;
19use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
20
21use crate::connection::{
22    Connection, ConnectionConfig, ConnectionFactory, ConnectionLifecycle, DatabaseType,
23    IsolationLevel, PreparedStatement, RowStream, Transaction,
24};
25use crate::error::{Error, Result};
26use crate::security::validate_sql_identifier;
27use crate::types::{Row, Value};
28
29/// Streaming row iterator backed by a Vec.
30///
31/// tiberius processes results in result-set granularity; this adapter
32/// yields rows one-by-one from a materialized result. Preserves the
33/// `RowStream` contract for callers while keeping memory behaviour
34/// explicit and documented.
35struct VecRowStream {
36    rows: std::vec::IntoIter<Row>,
37}
38
39impl VecRowStream {
40    fn new(rows: Vec<Row>) -> Self {
41        Self {
42            rows: rows.into_iter(),
43        }
44    }
45}
46
47impl RowStream for VecRowStream {
48    fn next(
49        &mut self,
50    ) -> Pin<Box<dyn Future<Output = crate::error::Result<Option<Row>>> + Send + '_>> {
51        Box::pin(async move { Ok(self.rows.next()) })
52    }
53}
54
55/// SQL Server connection
56pub struct SqlServerConnection {
57    client: Arc<Mutex<Client<Compat<TcpStream>>>>,
58    in_transaction: Arc<AtomicBool>,
59    created_at: Instant,
60    last_used: Mutex<Instant>,
61}
62
63impl SqlServerConnection {
64    /// Get the age of this connection (time since creation)
65    pub fn age(&self) -> std::time::Duration {
66        self.created_at.elapsed()
67    }
68
69    /// Check if connection is older than the specified max lifetime
70    pub fn is_expired(&self, max_lifetime: std::time::Duration) -> bool {
71        self.age() > max_lifetime
72    }
73
74    /// Get time since last use
75    pub async fn idle_time(&self) -> std::time::Duration {
76        self.last_used.lock().await.elapsed()
77    }
78
79    /// Create a new SQL Server connection from config
80    pub async fn connect(config: &ConnectionConfig) -> Result<Self> {
81        // Parse URL: sqlserver://user:pass@host:port/database
82        let url = url::Url::parse(&config.url)
83            .map_err(|e| Error::config(format!("Invalid SQL Server URL: {}", e)))?;
84
85        let mut tib_config = Config::new();
86
87        tib_config.host(url.host_str().unwrap_or("localhost"));
88        tib_config.port(url.port().unwrap_or(1433));
89        tib_config.database(url.path().trim_start_matches('/'));
90
91        // Authentication from URL
92        let username = if url.username().is_empty() {
93            "sa"
94        } else {
95            url.username()
96        };
97        let password = url.password().unwrap_or("");
98        tib_config.authentication(AuthMethod::sql_server(username, password));
99
100        // TLS settings from properties
101        if config
102            .properties
103            .get("trust_cert")
104            .map(|s| s == "true")
105            .unwrap_or(false)
106        {
107            tib_config.trust_cert();
108        }
109
110        let tcp = TcpStream::connect(tib_config.get_addr())
111            .await
112            .map_err(|e| Error::connection(format!("Failed to connect: {}", e)))?;
113
114        tcp.set_nodelay(true).ok();
115
116        let client = Client::connect(tib_config, tcp.compat_write())
117            .await
118            .map_err(|e| Error::connection(format!("Failed to authenticate: {}", e)))?;
119
120        let now = Instant::now();
121        Ok(Self {
122            client: Arc::new(Mutex::new(client)),
123            in_transaction: Arc::new(AtomicBool::new(false)),
124            created_at: now,
125            last_used: Mutex::new(now),
126        })
127    }
128
129    /// Create connection from URL
130    pub async fn from_url(url: &str) -> Result<Self> {
131        let config = ConnectionConfig::new(url);
132        Self::connect(&config).await
133    }
134
135    async fn update_last_used(&self) {
136        *self.last_used.lock().await = Instant::now();
137    }
138}
139
140#[async_trait]
141impl ConnectionLifecycle for SqlServerConnection {
142    fn created_at(&self) -> Instant {
143        self.created_at
144    }
145
146    async fn idle_time(&self) -> std::time::Duration {
147        self.last_used.lock().await.elapsed()
148    }
149
150    async fn touch(&self) {
151        self.update_last_used().await;
152    }
153}
154
155/// Owned parameter wrapper for safe, native tiberius parameter binding (C-1 fix).
156///
157/// Converts rivven `Value` to tiberius `ColumnData` for typed TDS protocol binding.
158/// Parameters are **never interpolated into SQL text** — they are sent as typed
159/// protocol-level parameters, making SQL injection impossible regardless of content.
160///
161/// # Security
162///
163/// This replaces the previous `substitute_params()` + `value_to_string()` approach
164/// which performed client-side string interpolation, allowing SQL injection on any
165/// Value that contained SQL metacharacters.
166struct SqlParam(Value);
167
168impl tiberius::ToSql for SqlParam {
169    fn to_sql(&self) -> tiberius::ColumnData<'_> {
170        use std::borrow::Cow;
171        use tiberius::ColumnData;
172        use Value::*;
173
174        match &self.0 {
175            Null => ColumnData::String(None),
176            Bool(b) => ColumnData::Bit(Some(*b)),
177            Int8(n) => ColumnData::I16(Some(*n as i16)), // TDS has no i8
178            Int16(n) => ColumnData::I16(Some(*n)),
179            Int32(n) => ColumnData::I32(Some(*n)),
180            Int64(n) => ColumnData::I64(Some(*n)),
181            Float32(n) => ColumnData::F32(Some(*n)),
182            Float64(n) => ColumnData::F64(Some(*n)),
183            String(s) => ColumnData::String(Some(Cow::Borrowed(s.as_str()))),
184            Bytes(b) => ColumnData::Binary(Some(Cow::Borrowed(b.as_slice()))),
185            Uuid(u) => ColumnData::Guid(Some(*u)),
186            // chrono types → ISO 8601 string representation (SQL Server parses these natively)
187            // Using string parameters is safe: the value is bound as a typed TDS parameter,
188            // not interpolated into SQL text. SQL Server handles the implicit conversion.
189            Date(d) => ColumnData::String(Some(Cow::Owned(d.format("%Y-%m-%d").to_string()))),
190            Time(t) => ColumnData::String(Some(Cow::Owned(t.format("%H:%M:%S%.f").to_string()))),
191            DateTime(dt) => ColumnData::String(Some(Cow::Owned(
192                dt.format("%Y-%m-%dT%H:%M:%S%.f").to_string(),
193            ))),
194            DateTimeTz(dt) => ColumnData::String(Some(Cow::Owned(
195                dt.format("%Y-%m-%dT%H:%M:%S%.f%:z").to_string(),
196            ))),
197            // Types without direct TDS mapping → string representation
198            Decimal(d) => ColumnData::String(Some(Cow::Owned(d.to_string()))),
199            Json(j) => ColumnData::String(Some(Cow::Owned(j.to_string()))),
200            Enum(s) => ColumnData::String(Some(Cow::Borrowed(s.as_str()))),
201            Array(arr) => {
202                let json = serde_json::to_string(arr).unwrap_or_else(|e| {
203                    tracing::warn!(
204                        "Failed to serialize Array to JSON for SQL Server param: {}",
205                        e
206                    );
207                    "[]".to_string()
208                });
209                ColumnData::String(Some(Cow::Owned(json)))
210            }
211            Composite(map) => {
212                let json = serde_json::to_string(map).unwrap_or_else(|e| {
213                    tracing::warn!(
214                        "Failed to serialize Composite to JSON for SQL Server param: {}",
215                        e
216                    );
217                    "{}".to_string()
218                });
219                ColumnData::String(Some(Cow::Owned(json)))
220            }
221            Interval(micros) => ColumnData::I64(Some(*micros)),
222            Bit(bits) => ColumnData::Binary(Some(Cow::Borrowed(bits.as_slice()))),
223            Geometry(wkb) | Geography(wkb) => {
224                ColumnData::Binary(Some(Cow::Borrowed(wkb.as_slice())))
225            }
226            Range { .. } => ColumnData::String(None), // ranges not natively supported in SQL Server
227            Custom { data, .. } => ColumnData::Binary(Some(Cow::Borrowed(data.as_slice()))),
228        }
229    }
230}
231
232/// Build a slice of tiberius parameter references from owned SqlParams.
233///
234/// Returns the param_refs vector. Caller must keep `tib_params` alive
235/// for the duration of the query call since `param_refs` borrows from them.
236#[inline]
237fn param_refs(tib_params: &[SqlParam]) -> Vec<&dyn tiberius::ToSql> {
238    tib_params
239        .iter()
240        .map(|p| p as &dyn tiberius::ToSql)
241        .collect()
242}
243
244/// Convert tiberius column value to rivven Value
245fn tiberius_to_value(_col: &tiberius::Column, row: &tiberius::Row, idx: usize) -> Value {
246    // probe typed columns before raw bytes to avoid BIT returning as Bytes.
247    // Order: bool first, then numeric, then string, then bytes (catch-all).
248    if let Ok(Some(v)) = row.try_get::<bool, _>(idx) {
249        return Value::Bool(v);
250    }
251    if let Ok(Some(v)) = row.try_get::<i16, _>(idx) {
252        return Value::Int16(v);
253    }
254    if let Ok(Some(v)) = row.try_get::<i32, _>(idx) {
255        return Value::Int32(v);
256    }
257    if let Ok(Some(v)) = row.try_get::<i64, _>(idx) {
258        return Value::Int64(v);
259    }
260    if let Ok(Some(v)) = row.try_get::<f32, _>(idx) {
261        return Value::Float32(v);
262    }
263    if let Ok(Some(v)) = row.try_get::<f64, _>(idx) {
264        return Value::Float64(v);
265    }
266    if let Ok(Some(v)) = row.try_get::<&str, _>(idx) {
267        return Value::String(v.to_string());
268    }
269    if let Ok(Some(v)) = row.try_get::<uuid::Uuid, _>(idx) {
270        return Value::Uuid(v);
271    }
272    if let Ok(Some(bytes)) = row.try_get::<&[u8], _>(idx) {
273        return Value::Bytes(bytes.to_vec());
274    }
275
276    Value::Null
277}
278
279/// Convert tiberius Row to rivven Row
280fn tiberius_row_to_row(tib_row: &tiberius::Row) -> Row {
281    let columns: Vec<String> = tib_row
282        .columns()
283        .iter()
284        .map(|c| c.name().to_string())
285        .collect();
286
287    let values: Vec<Value> = tib_row
288        .columns()
289        .iter()
290        .enumerate()
291        .map(|(i, col)| tiberius_to_value(col, tib_row, i))
292        .collect();
293
294    Row::new(columns, values)
295}
296
297#[async_trait]
298impl Connection for SqlServerConnection {
299    async fn execute(&self, query: &str, params: &[Value]) -> Result<u64> {
300        self.update_last_used().await;
301
302        let tib_params: Vec<SqlParam> = params.iter().cloned().map(SqlParam).collect();
303        let refs = param_refs(&tib_params);
304        let mut client = self.client.lock().await;
305
306        let result = client
307            .execute(query, &refs)
308            .await
309            .map_err(|e| Error::execution(format!("Execute failed: {}", e)))?;
310
311        Ok(result.total() as u64)
312    }
313
314    async fn query(&self, query: &str, params: &[Value]) -> Result<Vec<Row>> {
315        self.update_last_used().await;
316
317        let tib_params: Vec<SqlParam> = params.iter().cloned().map(SqlParam).collect();
318        let refs = param_refs(&tib_params);
319        let mut client = self.client.lock().await;
320
321        let stream = client
322            .query(query, &refs)
323            .await
324            .map_err(|e| Error::execution(format!("Query failed: {}", e)))?;
325
326        let tib_rows = stream
327            .into_first_result()
328            .await
329            .map_err(|e| Error::execution(format!("Failed to fetch rows: {}", e)))?;
330
331        Ok(tib_rows.iter().map(tiberius_row_to_row).collect())
332    }
333
334    async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement>> {
335        // SQL Server uses parameterized queries (sp_executesql), not server-side
336        // prepared statements. The connection reference enables execute/query.
337        Ok(Box::new(SqlServerPreparedStatement {
338            sql: sql.to_string(),
339            client: Arc::clone(&self.client),
340        }))
341    }
342
343    async fn begin(&self) -> Result<Box<dyn Transaction>> {
344        self.update_last_used().await;
345
346        {
347            let mut client = self.client.lock().await;
348            client
349                .execute("BEGIN TRANSACTION", &[])
350                .await
351                .map_err(|e| Error::transaction(format!("Failed to begin transaction: {}", e)))?;
352        }
353
354        self.in_transaction.store(true, Ordering::SeqCst);
355
356        Ok(Box::new(SqlServerTransaction {
357            client: Arc::clone(&self.client),
358            committed: AtomicBool::new(false),
359            rolled_back: AtomicBool::new(false),
360            in_transaction: Arc::clone(&self.in_transaction),
361        }))
362    }
363
364    /// SQL Server requires SET TRANSACTION ISOLATION LEVEL *before*
365    /// BEGIN TRANSACTION for it to take effect for that transaction.
366    async fn begin_with_isolation(
367        &self,
368        isolation: IsolationLevel,
369    ) -> Result<Box<dyn Transaction>> {
370        self.update_last_used().await;
371
372        let level_sql = match isolation {
373            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
374            IsolationLevel::ReadCommitted => "READ COMMITTED",
375            IsolationLevel::RepeatableRead => "REPEATABLE READ",
376            IsolationLevel::Serializable => "SERIALIZABLE",
377            IsolationLevel::Snapshot => "SNAPSHOT",
378        };
379
380        {
381            let mut client = self.client.lock().await;
382            client
383                .execute(
384                    format!("SET TRANSACTION ISOLATION LEVEL {}", level_sql),
385                    &[],
386                )
387                .await
388                .map_err(|e| Error::transaction(format!("Failed to set isolation level: {}", e)))?;
389            client
390                .execute("BEGIN TRANSACTION", &[])
391                .await
392                .map_err(|e| Error::transaction(format!("Failed to begin transaction: {}", e)))?;
393        }
394
395        self.in_transaction.store(true, Ordering::SeqCst);
396
397        Ok(Box::new(SqlServerTransaction {
398            client: Arc::clone(&self.client),
399            committed: AtomicBool::new(false),
400            rolled_back: AtomicBool::new(false),
401            in_transaction: Arc::clone(&self.in_transaction),
402        }))
403    }
404
405    async fn query_stream(&self, query: &str, params: &[Value]) -> Result<Pin<Box<dyn RowStream>>> {
406        // tiberius's TDS protocol pipeline processes results in
407        // result-set granularity; this adapter materializes them as a RowStream.
408        // A hard row limit prevents OOM for unbounded queries. For large result
409        // sets, callers should use OFFSET/FETCH pagination or the TableSource
410        // incremental mode.
411        const MAX_STREAM_ROWS: usize = 1_000_000;
412
413        let rows = self.query(query, params).await?;
414        if rows.len() >= MAX_STREAM_ROWS {
415            tracing::warn!(
416                row_count = rows.len(),
417                max = MAX_STREAM_ROWS,
418                "query_stream result set reached MAX_STREAM_ROWS limit; \
419                 consider using OFFSET/FETCH pagination for large queries"
420            );
421        }
422        Ok(Box::pin(VecRowStream::new(rows)))
423    }
424
425    async fn is_valid(&self) -> bool {
426        let mut client = self.client.lock().await;
427        client.execute("SELECT 1", &[]).await.is_ok()
428    }
429
430    async fn close(&self) -> Result<()> {
431        // Connection closes when dropped
432        Ok(())
433    }
434}
435
436/// SQL Server prepared statement backed by a shared connection.
437///
438/// SQL Server uses parameterized queries (sp_executesql) rather than
439/// server-side prepared statements. This wrapper stores the query text
440/// and a reference to the parent connection for execute/query delegation.
441pub struct SqlServerPreparedStatement {
442    sql: String,
443    client: Arc<Mutex<Client<Compat<TcpStream>>>>,
444}
445
446#[async_trait]
447impl PreparedStatement for SqlServerPreparedStatement {
448    async fn execute(&self, params: &[Value]) -> Result<u64> {
449        let tib_params: Vec<SqlParam> = params.iter().cloned().map(SqlParam).collect();
450        let refs = param_refs(&tib_params);
451        let mut client = self.client.lock().await;
452
453        let result = client
454            .execute(&*self.sql, &refs)
455            .await
456            .map_err(|e| Error::execution(format!("Prepared execute failed: {}", e)))?;
457
458        Ok(result.total() as u64)
459    }
460
461    async fn query(&self, params: &[Value]) -> Result<Vec<Row>> {
462        let tib_params: Vec<SqlParam> = params.iter().cloned().map(SqlParam).collect();
463        let refs = param_refs(&tib_params);
464        let mut client = self.client.lock().await;
465
466        let stream = client
467            .query(&*self.sql, &refs)
468            .await
469            .map_err(|e| Error::execution(format!("Prepared query failed: {}", e)))?;
470
471        let tib_rows = stream
472            .into_first_result()
473            .await
474            .map_err(|e| Error::execution(format!("Failed to fetch rows: {}", e)))?;
475
476        Ok(tib_rows.iter().map(tiberius_row_to_row).collect())
477    }
478
479    fn sql(&self) -> &str {
480        &self.sql
481    }
482}
483
484/// SQL Server transaction
485pub struct SqlServerTransaction {
486    client: Arc<Mutex<Client<Compat<TcpStream>>>>,
487    committed: AtomicBool,
488    rolled_back: AtomicBool,
489    in_transaction: Arc<AtomicBool>,
490}
491
492#[async_trait]
493impl Transaction for SqlServerTransaction {
494    async fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>> {
495        let tib_params: Vec<SqlParam> = params.iter().cloned().map(SqlParam).collect();
496        let refs = param_refs(&tib_params);
497        let mut client = self.client.lock().await;
498
499        let stream = client
500            .query(sql, &refs)
501            .await
502            .map_err(|e| Error::execution(format!("Query failed: {}", e)))?;
503
504        let tib_rows = stream
505            .into_first_result()
506            .await
507            .map_err(|e| Error::execution(format!("Failed to fetch rows: {}", e)))?;
508
509        Ok(tib_rows.iter().map(tiberius_row_to_row).collect())
510    }
511
512    async fn execute(&self, sql: &str, params: &[Value]) -> Result<u64> {
513        let tib_params: Vec<SqlParam> = params.iter().cloned().map(SqlParam).collect();
514        let refs = param_refs(&tib_params);
515        let mut client = self.client.lock().await;
516
517        let result = client
518            .execute(sql, &refs)
519            .await
520            .map_err(|e| Error::execution(format!("Execute failed: {}", e)))?;
521
522        Ok(result.total() as u64)
523    }
524
525    async fn commit(self: Box<Self>) -> Result<()> {
526        if self.rolled_back.load(Ordering::SeqCst) {
527            return Err(Error::transaction("Transaction already rolled back"));
528        }
529        if self.committed.load(Ordering::SeqCst) {
530            return Err(Error::transaction("Transaction already committed"));
531        }
532
533        let mut client = self.client.lock().await;
534        client
535            .execute("COMMIT TRANSACTION", &[])
536            .await
537            .map_err(|e| Error::transaction(format!("Failed to commit: {}", e)))?;
538
539        self.committed.store(true, Ordering::SeqCst);
540        self.in_transaction.store(false, Ordering::SeqCst);
541        Ok(())
542    }
543
544    async fn rollback(self: Box<Self>) -> Result<()> {
545        if self.committed.load(Ordering::SeqCst) {
546            return Err(Error::transaction("Transaction already committed"));
547        }
548        if self.rolled_back.load(Ordering::SeqCst) {
549            return Ok(()); // Idempotent rollback
550        }
551
552        let mut client = self.client.lock().await;
553        client
554            .execute("ROLLBACK TRANSACTION", &[])
555            .await
556            .map_err(|e| Error::transaction(format!("Failed to rollback: {}", e)))?;
557
558        self.rolled_back.store(true, Ordering::SeqCst);
559        self.in_transaction.store(false, Ordering::SeqCst);
560        Ok(())
561    }
562
563    async fn set_isolation_level(&self, level: IsolationLevel) -> Result<()> {
564        let level_sql = match level {
565            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
566            IsolationLevel::ReadCommitted => "READ COMMITTED",
567            IsolationLevel::RepeatableRead => "REPEATABLE READ",
568            IsolationLevel::Serializable => "SERIALIZABLE",
569            IsolationLevel::Snapshot => "SNAPSHOT",
570        };
571
572        let mut client = self.client.lock().await;
573        client
574            .execute(
575                format!("SET TRANSACTION ISOLATION LEVEL {}", level_sql),
576                &[],
577            )
578            .await
579            .map_err(|e| Error::transaction(format!("Failed to set isolation level: {}", e)))?;
580
581        Ok(())
582    }
583
584    async fn savepoint(&self, name: &str) -> Result<()> {
585        validate_sql_identifier(name)?;
586        let mut client = self.client.lock().await;
587        client
588            .execute(format!("SAVE TRANSACTION {}", name), &[])
589            .await
590            .map_err(|e| Error::transaction(format!("Failed to create savepoint: {}", e)))?;
591        Ok(())
592    }
593
594    async fn rollback_to_savepoint(&self, name: &str) -> Result<()> {
595        validate_sql_identifier(name)?;
596        let mut client = self.client.lock().await;
597        client
598            .execute(format!("ROLLBACK TRANSACTION {}", name), &[])
599            .await
600            .map_err(|e| Error::transaction(format!("Failed to rollback to savepoint: {}", e)))?;
601        Ok(())
602    }
603
604    async fn release_savepoint(&self, _name: &str) -> Result<()> {
605        // SQL Server doesn't have RELEASE SAVEPOINT - savepoints are released on commit
606        Ok(())
607    }
608}
609
610// Drop impl to prevent abandoned transactions. If the transaction was neither
611// committed nor rolled back, issue a best-effort ROLLBACK and reset the
612// parent connection's in_transaction flag.
613impl Drop for SqlServerTransaction {
614    fn drop(&mut self) {
615        if !self.committed.load(Ordering::SeqCst) && !self.rolled_back.load(Ordering::SeqCst) {
616            let client = self.client.clone();
617            let in_tx = self.in_transaction.clone();
618            // Fire-and-forget rollback: avoids block_in_place which panics on
619            // current_thread (single-threaded) runtimes.
620            if let Ok(handle) = tokio::runtime::Handle::try_current() {
621                handle.spawn(async move {
622                    let mut guard = client.lock().await;
623                    if let Err(e) = guard.execute("ROLLBACK TRANSACTION", &[]).await {
624                        tracing::warn!("Auto-rollback on SqlServerTransaction drop failed: {}", e);
625                    } else {
626                        tracing::debug!("SqlServerTransaction auto-rolled back on drop");
627                    }
628                    in_tx.store(false, Ordering::SeqCst);
629                });
630            } else {
631                tracing::warn!(
632                    "SqlServerTransaction dropped outside of a Tokio runtime; \
633                     rollback skipped — connection may be left in a dirty state"
634                );
635            }
636        }
637    }
638}
639
640/// SQL Server connection factory
641pub struct SqlServerConnectionFactory {
642    #[allow(dead_code)] // Stored for future use; connect() uses passed config
643    config: ConnectionConfig,
644}
645
646impl SqlServerConnectionFactory {
647    /// Create a new SQL Server connection factory
648    pub fn new(config: ConnectionConfig) -> Self {
649        Self { config }
650    }
651}
652
653#[async_trait]
654impl ConnectionFactory for SqlServerConnectionFactory {
655    async fn connect(&self, config: &ConnectionConfig) -> Result<Box<dyn Connection>> {
656        // Use passed config (consistent with ConnectionFactory contract)
657        let conn = SqlServerConnection::connect(config).await?;
658        Ok(Box::new(conn))
659    }
660
661    fn database_type(&self) -> DatabaseType {
662        DatabaseType::SqlServer
663    }
664}
665
#[cfg(test)]
mod tests {
    use super::*;
    use tiberius::ToSql;

    /// Return the text of a string-typed parameter; panic on any other variant.
    fn bound_string(cd: tiberius::ColumnData<'_>) -> String {
        match cd {
            tiberius::ColumnData::String(Some(cow)) => cow.into_owned(),
            _ => panic!("Expected String ColumnData"),
        }
    }

    #[test]
    fn test_sql_param_null() {
        let p = SqlParam(Value::Null);
        let cd = p.to_sql();
        // Null string → None variant
        assert!(matches!(cd, tiberius::ColumnData::String(None)));
    }

    #[test]
    fn test_sql_param_bool() {
        let p = SqlParam(Value::Bool(true));
        let cd = p.to_sql();
        assert!(matches!(cd, tiberius::ColumnData::Bit(Some(true))));
    }

    #[test]
    fn test_sql_param_integers() {
        assert!(matches!(
            SqlParam(Value::Int8(42)).to_sql(),
            tiberius::ColumnData::I16(Some(42))
        ));
        // Negative i8 must widen to i16 without sign loss.
        assert!(matches!(
            SqlParam(Value::Int8(-5)).to_sql(),
            tiberius::ColumnData::I16(Some(-5))
        ));
        assert!(matches!(
            SqlParam(Value::Int16(1000)).to_sql(),
            tiberius::ColumnData::I16(Some(1000))
        ));
        assert!(matches!(
            SqlParam(Value::Int32(100_000)).to_sql(),
            tiberius::ColumnData::I32(Some(100_000))
        ));
        assert!(matches!(
            SqlParam(Value::Int64(1_000_000_000)).to_sql(),
            tiberius::ColumnData::I64(Some(1_000_000_000))
        ));
    }

    #[test]
    fn test_sql_param_floats_and_interval() {
        assert!(matches!(
            SqlParam(Value::Float32(1.5)).to_sql(),
            tiberius::ColumnData::F32(Some(v)) if v == 1.5
        ));
        assert!(matches!(
            SqlParam(Value::Float64(2.25)).to_sql(),
            tiberius::ColumnData::F64(Some(v)) if v == 2.25
        ));
        // Intervals are bound as their microsecond count.
        assert!(matches!(
            SqlParam(Value::Interval(90_000_000)).to_sql(),
            tiberius::ColumnData::I64(Some(90_000_000))
        ));
    }

    #[test]
    fn test_sql_param_string() {
        let p = SqlParam(Value::String("hello".into()));
        assert_eq!(bound_string(p.to_sql()), "hello");
    }

    #[test]
    fn test_sql_param_string_with_injection_chars() {
        // SQL metacharacters are harmless because the value is bound as a typed
        // parameter — never interpolated into SQL text
        let p = SqlParam(Value::String("x'; DROP TABLE users--".into()));
        assert_eq!(bound_string(p.to_sql()), "x'; DROP TABLE users--");
    }

    #[test]
    fn test_sql_param_enum() {
        let p = SqlParam(Value::Enum("red".into()));
        assert_eq!(bound_string(p.to_sql()), "red");
    }

    #[test]
    fn test_sql_param_bytes() {
        let p = SqlParam(Value::Bytes(vec![0xDE, 0xAD]));
        if let tiberius::ColumnData::Binary(Some(cow)) = p.to_sql() {
            assert_eq!(&*cow, &[0xDE, 0xAD]);
        } else {
            panic!("Expected Binary ColumnData");
        }
    }

    #[test]
    fn test_sql_param_uuid() {
        let uuid = uuid::Uuid::new_v4();
        let p = SqlParam(Value::Uuid(uuid));
        assert!(matches!(
            p.to_sql(),
            tiberius::ColumnData::Guid(Some(g)) if g == uuid
        ));
    }

    #[test]
    fn test_sql_param_chrono_types() {
        use chrono::{NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};

        let d = NaiveDate::from_ymd_opt(2025, 1, 15).unwrap();
        assert_eq!(bound_string(SqlParam(Value::Date(d)).to_sql()), "2025-01-15");

        // `%.f` prints no fractional part when the nanoseconds are zero...
        let t = NaiveTime::from_hms_opt(12, 30, 45).unwrap();
        assert_eq!(bound_string(SqlParam(Value::Time(t)).to_sql()), "12:30:45");

        // ...and the shortest of 3/6/9 digits otherwise.
        let t_frac = NaiveTime::from_hms_milli_opt(12, 30, 45, 250).unwrap();
        assert_eq!(
            bound_string(SqlParam(Value::Time(t_frac)).to_sql()),
            "12:30:45.250"
        );

        let dt = NaiveDateTime::new(d, t);
        assert_eq!(
            bound_string(SqlParam(Value::DateTime(dt)).to_sql()),
            "2025-01-15T12:30:45"
        );

        let dtz = Utc.from_utc_datetime(&dt);
        assert_eq!(
            bound_string(SqlParam(Value::DateTimeTz(dtz)).to_sql()),
            "2025-01-15T12:30:45+00:00"
        );
    }

    #[test]
    fn test_sql_param_json() {
        let j = serde_json::json!({"key": "value"});
        let p = SqlParam(Value::Json(j));
        assert_eq!(bound_string(p.to_sql()), r#"{"key":"value"}"#);
    }

    #[test]
    fn test_savepoint_name_validation() {
        // Valid names succeed
        assert!(validate_sql_identifier("sp1").is_ok());
        assert!(validate_sql_identifier("my_savepoint").is_ok());

        // Injection attempts are rejected
        assert!(validate_sql_identifier("x; DROP TABLE users--").is_err());
        assert!(validate_sql_identifier("").is_err());
        assert!(validate_sql_identifier("x' OR '1'='1").is_err());
    }

    #[test]
    fn test_connection_config() {
        use crate::connection::ConnectionConfig;
        let config = ConnectionConfig::new("sqlserver://user:pass@localhost:1433/mydb");
        assert_eq!(config.url, "sqlserver://user:pass@localhost:1433/mydb");
    }
}