1use crate::connection::{
2 AsyncConnection, BulkInsert, ConnectOptions, ExecutionSummary, ForeignKey, QueryResult,
3 SchemaInfo, StatementResult,
4};
5use crate::error::SqlError;
6use crate::stream::BoxRowStream;
7use crate::url::DatabaseUrl;
8use crate::value::{ColumnInfo, Row, TypeHint, Value};
9use async_trait::async_trait;
10use chrono::{DateTime as ChronoDateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Utc};
11use secrecy::ExposeSecret;
12use tiberius::{
13 Client, ColumnData, ColumnType, EncryptionLevel, IntoSql, TokenRow, numeric::Numeric,
14};
15use tokio::net::TcpStream;
16use tokio_util::compat::TokioAsyncWriteCompatExt;
17
/// A live SQL Server connection driven by a tiberius [`Client`].
///
/// Created by `connect` in this module; all `AsyncConnection` methods go
/// through the single `client` below.
pub struct MssqlConnection {
    /// tiberius client over a tokio `TcpStream`, wrapped with
    /// `tokio_util::compat` (see `compat_write()` in `connect`).
    client: Client<tokio_util::compat::Compat<TcpStream>>,
}
21
22#[async_trait]
23impl AsyncConnection for MssqlConnection {
24 async fn execute(&mut self, sql: &str) -> Result<ExecutionSummary, SqlError> {
25 let result = self
26 .client
27 .execute(sql, &[])
28 .await
29 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
30 let affected = result.rows_affected().first().copied();
31 Ok(ExecutionSummary {
32 rows_affected: affected,
33 command_tag: None,
34 })
35 }
36
37 async fn query(&mut self, sql: &str) -> Result<QueryResult, SqlError> {
38 let rows = self
39 .client
40 .query(sql, &[])
41 .await
42 .map_err(|e| SqlError::QueryFailed(e.to_string()))?
43 .into_first_result()
44 .await
45 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
46
47 if rows.is_empty() {
48 return Ok(QueryResult {
49 columns: Vec::new(),
50 rows: Vec::new(),
51 });
52 }
53
54 let columns: Vec<ColumnInfo> = rows[0]
55 .columns()
56 .iter()
57 .map(|c| ColumnInfo {
58 name: c.name().to_string(),
59 type_hint: mssql_type_to_hint(c.column_type()),
60 nullable: true,
61 })
62 .collect();
63
64 let data_rows: Vec<Row> = rows
65 .into_iter()
66 .map(|row| {
67 row.columns()
68 .iter()
69 .enumerate()
70 .map(|(i, col)| mssql_to_value(&row, i, col.column_type()))
71 .collect()
72 })
73 .collect();
74
75 Ok(QueryResult {
76 columns,
77 rows: data_rows,
78 })
79 }
80
81 async fn query_stream(
93 &mut self,
94 sql: &str,
95 ) -> Result<(Vec<ColumnInfo>, BoxRowStream<'_>), SqlError> {
96 use futures_util::stream::{StreamExt, TryStreamExt};
97 use tiberius::QueryItem;
98
99 let mut query_stream = self
100 .client
101 .query(sql, &[])
102 .await
103 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
104
105 let (columns, col_types) = match query_stream.try_next().await {
108 Ok(Some(QueryItem::Metadata(meta))) => {
109 let columns: Vec<ColumnInfo> = meta
110 .columns()
111 .iter()
112 .map(|c| ColumnInfo {
113 name: c.name().to_string(),
114 type_hint: mssql_type_to_hint(c.column_type()),
115 nullable: true,
116 })
117 .collect();
118 let col_types: Vec<ColumnType> =
119 meta.columns().iter().map(|c| c.column_type()).collect();
120 (columns, col_types)
121 }
122 Ok(Some(QueryItem::Row(_))) | Ok(None) => (Vec::new(), Vec::new()),
125 Err(e) => return Err(SqlError::QueryFailed(e.to_string())),
126 };
127
128 let stream = futures_util::stream::try_unfold(
129 (query_stream, col_types),
130 |(mut query_stream, col_types)| async move {
131 match query_stream.try_next().await {
132 Ok(Some(QueryItem::Row(row))) => {
133 let values: Row = col_types
134 .iter()
135 .enumerate()
136 .map(|(i, col_type)| mssql_to_value(&row, i, *col_type))
137 .collect();
138 Ok(Some((values, (query_stream, col_types))))
139 }
140 Ok(Some(QueryItem::Metadata(_))) | Ok(None) => Ok(None),
143 Err(e) => Err(SqlError::QueryFailed(e.to_string())),
144 }
145 },
146 )
147 .boxed();
148 Ok((columns, stream))
149 }
150
151 async fn execute_multi(&mut self, sql: &str) -> Result<Vec<StatementResult>, SqlError> {
152 let result_sets = self
153 .client
154 .query(sql, &[])
155 .await
156 .map_err(|e| SqlError::QueryFailed(e.to_string()))?
157 .into_results()
158 .await
159 .map_err(|e| SqlError::QueryFailed(e.to_string()))?;
160
161 let mut results = Vec::new();
162 for rows in result_sets {
163 if rows.is_empty() {
164 results.push(StatementResult::Query(QueryResult {
165 columns: Vec::new(),
166 rows: Vec::new(),
167 }));
168 continue;
169 }
170 let columns: Vec<ColumnInfo> = rows[0]
171 .columns()
172 .iter()
173 .map(|c| ColumnInfo {
174 name: c.name().to_string(),
175 type_hint: mssql_type_to_hint(c.column_type()),
176 nullable: true,
177 })
178 .collect();
179
180 let data_rows: Vec<Row> = rows
181 .into_iter()
182 .map(|row| {
183 row.columns()
184 .iter()
185 .enumerate()
186 .map(|(i, col)| mssql_to_value(&row, i, col.column_type()))
187 .collect()
188 })
189 .collect();
190
191 results.push(StatementResult::Query(QueryResult {
192 columns,
193 rows: data_rows,
194 }));
195 }
196
197 if results.is_empty() {
198 let summary = self.execute(sql).await?;
199 results.push(StatementResult::Summary(summary));
200 }
201
202 Ok(results)
203 }
204
205 async fn ping(&mut self) -> Result<(), SqlError> {
206 self.client
207 .query("SELECT 1", &[])
208 .await
209 .map_err(|e| SqlError::ConnectionFailed(e.to_string()))?
210 .into_first_result()
211 .await
212 .map_err(|e| SqlError::ConnectionFailed(e.to_string()))?;
213 Ok(())
214 }
215
216 async fn list_tables(&mut self, schema: Option<&str>) -> Result<Vec<String>, SqlError> {
217 let schema = schema.unwrap_or("dbo");
218 let sql = format!(
219 "SELECT TABLE_NAME AS table_name FROM information_schema.tables WHERE table_schema = '{}' AND table_type = 'BASE TABLE' ORDER BY table_name",
220 escape_mssql_string(schema)
221 );
222 let result = self.query(&sql).await?;
223 let names: Vec<String> = result
224 .rows
225 .into_iter()
226 .filter_map(|row| {
227 row.into_iter().next().and_then(|v| match v {
228 Value::String(s) => Some(s),
229 _ => None,
230 })
231 })
232 .collect();
233 Ok(names)
234 }
235
236 async fn list_schemas(&mut self) -> Result<Vec<SchemaInfo>, SqlError> {
237 let sql = "SELECT name, CASE WHEN name = SCHEMA_NAME() THEN 1 ELSE 0 END FROM sys.schemas ORDER BY name";
243 let result = self.query(sql).await?;
244 let schemas: Vec<SchemaInfo> = result
245 .rows
246 .into_iter()
247 .filter_map(|row| {
248 let name = match row.first() {
249 Some(Value::String(s)) => s.clone(),
250 _ => return None,
251 };
252 let is_default = crate::connection::is_default_from_value(row.get(1));
253 Some(SchemaInfo { name, is_default })
254 })
255 .collect();
256 Ok(schemas)
257 }
258
259 async fn describe_table(
260 &mut self,
261 schema: Option<&str>,
262 table: &str,
263 ) -> Result<QueryResult, SqlError> {
264 let schema = schema.unwrap_or("dbo");
265 let sql = format!(
266 "SELECT COLUMN_NAME AS column_name, DATA_TYPE AS data_type, IS_NULLABLE AS is_nullable, COLUMN_DEFAULT AS column_default, NUMERIC_PRECISION AS numeric_precision, NUMERIC_SCALE AS numeric_scale FROM information_schema.columns WHERE table_schema = '{}' AND table_name = '{}' ORDER BY ORDINAL_POSITION",
267 escape_mssql_string(schema),
268 escape_mssql_string(table)
269 );
270 self.query(&sql).await
271 }
272
273 async fn primary_key(
274 &mut self,
275 schema: Option<&str>,
276 table: &str,
277 ) -> Result<Vec<String>, SqlError> {
278 let schema = schema.unwrap_or("dbo");
279 let sql = format!(
282 "SELECT k.COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE k \
283 JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS c \
284 ON c.CONSTRAINT_NAME = k.CONSTRAINT_NAME \
285 AND c.TABLE_SCHEMA = k.TABLE_SCHEMA \
286 AND c.TABLE_NAME = k.TABLE_NAME \
287 WHERE c.CONSTRAINT_TYPE = 'PRIMARY KEY' \
288 AND k.TABLE_SCHEMA = '{}' AND k.TABLE_NAME = '{}' \
289 ORDER BY k.ORDINAL_POSITION",
290 escape_mssql_string(schema),
291 escape_mssql_string(table)
292 );
293 let result = self.query(&sql).await?;
294 Ok(result
295 .rows
296 .into_iter()
297 .filter_map(|row| {
298 row.into_iter().next().and_then(|v| match v {
299 Value::String(s) => Some(s),
300 _ => None,
301 })
302 })
303 .collect())
304 }
305
306 async fn list_foreign_keys(
307 &mut self,
308 schema: Option<&str>,
309 ) -> Result<Vec<ForeignKey>, SqlError> {
310 let schema = schema.unwrap_or("dbo");
311 let sql = format!(
314 "SELECT fk.name, \
315 OBJECT_NAME(fkc.parent_object_id) AS child_table, \
316 COL_NAME(fkc.parent_object_id, fkc.parent_column_id) AS child_col, \
317 OBJECT_NAME(fkc.referenced_object_id) AS parent_table, \
318 COL_NAME(fkc.referenced_object_id, fkc.referenced_column_id) AS parent_col, \
319 fk.delete_referential_action_desc, \
320 fkc.constraint_column_id \
321 FROM sys.foreign_keys fk \
322 JOIN sys.foreign_key_columns fkc ON fkc.constraint_object_id = fk.object_id \
323 WHERE SCHEMA_NAME(fk.schema_id) = '{}' \
324 ORDER BY fk.name, fkc.constraint_column_id",
325 escape_mssql_string(schema)
326 );
327 let result = self.query(&sql).await?;
328 let mut map: indexmap::IndexMap<String, ForeignKey> = indexmap::IndexMap::new();
329 for row in result.rows {
330 let mut cols = row.into_iter();
331 let conname = match cols.next() {
332 Some(Value::String(s)) => s,
333 _ => continue,
334 };
335 let child_table = match cols.next() {
336 Some(Value::String(s)) => s,
337 _ => continue,
338 };
339 let child_col = match cols.next() {
340 Some(Value::String(s)) => s,
341 _ => continue,
342 };
343 let parent_table = match cols.next() {
344 Some(Value::String(s)) => s,
345 _ => continue,
346 };
347 let parent_col = match cols.next() {
348 Some(Value::String(s)) => s,
349 _ => continue,
350 };
351 let on_delete = match cols.next() {
352 Some(Value::String(s)) if !s.is_empty() && s != "NO_ACTION" => {
353 Some(s.replace('_', " "))
354 }
355 _ => None,
356 };
357 let entry = map.entry(conname).or_insert_with(|| ForeignKey {
358 child_table: child_table.clone(),
359 child_columns: Vec::new(),
360 parent_table: parent_table.clone(),
361 parent_columns: Vec::new(),
362 on_delete,
363 });
364 entry.child_columns.push(child_col);
365 entry.parent_columns.push(parent_col);
366 }
367 Ok(map.into_values().collect())
368 }
369
370 async fn bulk_insert_rows(&mut self, target: BulkInsert<'_>) -> Result<usize, SqlError> {
371 if target.rows.is_empty() {
372 return Ok(0);
373 }
374
375 let qtable = crate::copy::quote_identifier(target.table, crate::backend::Backend::MsSql);
382
383 let dest_cols = self.fetch_bulk_updatable_columns(target.table).await?;
391 verify_bulk_column_alignment(&dest_cols, target.columns)?;
392
393 let mut req = self
394 .client
395 .bulk_insert(qtable.as_str())
396 .await
397 .map_err(|e| classify_bulk_setup_error(&e))?;
398
399 let hints: Vec<TypeHint> = target.columns.iter().map(|c| c.type_hint).collect();
400 for row in target.rows {
401 let mut token_row = TokenRow::<'static>::with_capacity(target.columns.len());
402 for (idx, v) in row.iter().enumerate() {
403 let hint = hints.get(idx).copied().unwrap_or(TypeHint::Other);
404 token_row.push(value_to_column_data(v, hint)?);
405 }
406 req.send(token_row)
407 .await
408 .map_err(|e| SqlError::QueryFailed(format!("MSSQL bulk send: {e}")))?;
409 }
410
411 let res = req
412 .finalize()
413 .await
414 .map_err(|e| SqlError::QueryFailed(format!("MSSQL bulk finalize: {e}")))?;
415 Ok(res.total() as usize)
416 }
417}
418
419impl MssqlConnection {
420 async fn fetch_bulk_updatable_columns(&mut self, table: &str) -> Result<Vec<String>, SqlError> {
438 let qualified = parse_mssql_qualified_identifier(table);
444 let schema_filter = match &qualified.schema {
445 Some(schema) => format!(" AND c.TABLE_SCHEMA = '{}'", escape_mssql_string(schema)),
446 None => String::new(),
447 };
448 let table_name = qualified.name;
449 let sql = format!(
450 "SELECT c.COLUMN_NAME, \
451 COLUMNPROPERTY(OBJECT_ID(QUOTENAME(c.TABLE_SCHEMA) + '.' + QUOTENAME(c.TABLE_NAME)), c.COLUMN_NAME, 'IsIdentity') AS is_identity, \
452 COLUMNPROPERTY(OBJECT_ID(QUOTENAME(c.TABLE_SCHEMA) + '.' + QUOTENAME(c.TABLE_NAME)), c.COLUMN_NAME, 'IsComputed') AS is_computed, \
453 c.DATA_TYPE \
454 FROM INFORMATION_SCHEMA.COLUMNS c \
455 WHERE c.TABLE_NAME = '{}'{} \
456 ORDER BY c.ORDINAL_POSITION",
457 escape_mssql_string(&table_name),
458 schema_filter,
459 );
460 let result = self.query(&sql).await.map_err(|e| {
461 SqlError::BulkUnavailable(format!(
462 "MSSQL bulk pre-flight: could not introspect destination columns: {e}"
463 ))
464 })?;
465 let mut cols = Vec::with_capacity(result.rows.len());
466 for row in &result.rows {
467 let is_identity = column_flag_bool(&row[1]);
470 let is_computed = column_flag_bool(&row[2]);
471 let is_rowversion =
472 matches!(&row[3], Value::String(s) if s.eq_ignore_ascii_case("timestamp"));
473 if is_identity || is_computed || is_rowversion {
474 continue;
475 }
476 if let Value::String(name) = &row[0] {
477 cols.push(name.clone());
478 }
479 }
480 Ok(cols)
481 }
482}
483
484fn column_flag_bool(v: &Value) -> bool {
491 match v {
492 Value::Bool(b) => *b,
493 Value::Int64(n) => *n != 0,
494 _ => false,
495 }
496}
497
/// A possibly schema-qualified SQL Server object name with any `[...]`
/// quoting removed.
#[derive(Debug, Clone, PartialEq, Eq)]
struct QualifiedIdentifier {
    /// Schema part, when the input named a non-empty one.
    schema: Option<String>,
    /// Object (table) name: always the last part of the input.
    name: String,
}

/// Splits a (possibly bracket-quoted) multi-part identifier into schema and
/// object name.
///
/// Accepts `name`, `schema.name`, `database.schema.name` and deeper prefixes.
/// The last part is the object name and the part before it is the schema;
/// any leading database/server parts are dropped. An empty schema part (as
/// in `db..table`) yields `schema: None`. Each part may be bare or wrapped in
/// `[...]`, with `]]` escaping a literal `]`. Surrounding whitespace on the
/// whole input is ignored.
fn parse_mssql_qualified_identifier(input: &str) -> QualifiedIdentifier {
    let mut parts: Vec<String> = Vec::new();
    let mut remaining = Some(input.trim());
    // Each step consumes at least the separating '.', so this terminates.
    while let Some(segment) = remaining {
        let (part, rest) = parse_one_identifier(segment);
        parts.push(part);
        remaining = rest;
    }
    // The loop body runs at least once, so `parts` is never empty here.
    let name = parts.pop().unwrap_or_default();
    let schema = parts.pop().filter(|s| !s.is_empty());
    QualifiedIdentifier { schema, name }
}

/// Parses one identifier part from the start of `s`.
///
/// Returns the unquoted part and, when a `.` immediately follows it, the text
/// after that dot.
/// - A part starting with `[` runs to the matching `]`, with `]]` read as a
///   literal `]`. If anything other than `.` follows the closing bracket, the
///   rest is dropped.
/// - An unterminated `[` yields the remainder of `s` verbatim and no rest.
/// - A bare part runs up to the first `.`.
fn parse_one_identifier(s: &str) -> (String, Option<&str>) {
    if let Some(after_open) = s.strip_prefix('[') {
        let bytes = after_open.as_bytes();
        // Scan bytes: ']' is ASCII, so any index found is a char boundary.
        let mut i = 0;
        let mut close = None;
        while i < bytes.len() {
            if bytes[i] == b']' {
                if i + 1 < bytes.len() && bytes[i + 1] == b']' {
                    // Escaped "]]" inside the brackets: skip both bytes.
                    i += 2;
                    continue;
                }
                close = Some(i);
                break;
            }
            i += 1;
        }
        match close {
            Some(end) => {
                let inner = &after_open[..end];
                let unquoted = inner.replace("]]", "]");
                let after_close = &after_open[end + 1..];
                let rest = after_close.strip_prefix('.');
                (unquoted, rest)
            }
            None => (after_open.to_string(), None),
        }
    } else {
        match s.find('.') {
            Some(i) => (s[..i].to_string(), Some(&s[i + 1..])),
            None => (s.to_string(), None),
        }
    }
}
593
594fn verify_bulk_column_alignment(
603 dest_cols: &[String],
604 target_cols: &[ColumnInfo],
605) -> Result<(), SqlError> {
606 if dest_cols.len() != target_cols.len() {
607 return Err(SqlError::BulkUnavailable(format!(
608 "MSSQL bulk path requires destination to have exactly the same \
609 non-IDENTITY columns as the source ({} dest cols vs {} source cols). \
610 The destination may have IDENTITY columns the source doesn't, or \
611 columns the source doesn't write to — generic INSERT can handle \
612 this with a named column list",
613 dest_cols.len(),
614 target_cols.len()
615 )));
616 }
617 for (idx, (dest, src)) in dest_cols.iter().zip(target_cols).enumerate() {
618 if !dest.eq_ignore_ascii_case(&src.name) {
619 return Err(SqlError::BulkUnavailable(format!(
620 "MSSQL bulk path requires destination column order to match source. \
621 Position {idx}: dest = {dest:?}, source = {src_name:?}. \
622 Generic INSERT uses a named column list and works regardless of order",
623 src_name = src.name
624 )));
625 }
626 }
627 Ok(())
628}
629
630fn classify_bulk_setup_error(e: &tiberius::error::Error) -> SqlError {
644 let msg = e.to_string();
645 if msg.contains("Cannot bulk load") || msg.contains("expecting column metadata") {
646 return SqlError::BulkUnavailable(format!("MSSQL rejected bulk_insert setup: {msg}"));
647 }
648 SqlError::QueryFailed(format!("MSSQL bulk_insert setup: {msg}"))
653}
654
655fn value_to_column_data(v: &Value, hint: TypeHint) -> Result<ColumnData<'static>, SqlError> {
662 use std::borrow::Cow;
663
664 Ok(match v {
665 Value::Null => null_for_hint(hint),
666 Value::Bool(b) => ColumnData::Bit(Some(*b)),
667 Value::Int64(n) => ColumnData::I64(Some(*n)),
668 Value::Float64(f) => ColumnData::F64(Some(*f)),
669 Value::Decimal(s) => {
670 let n = parse_decimal_to_numeric(s)
671 .map_err(|e| SqlError::QueryFailed(format!("MSSQL bulk: decimal {s:?}: {e}")))?;
672 ColumnData::Numeric(Some(n))
673 }
674 Value::String(s) => ColumnData::String(Some(Cow::Owned(s.clone()))),
675 Value::Bytes(b) => ColumnData::Binary(Some(Cow::Owned(b.clone()))),
676 Value::Date(d) => (*d).into_sql(),
677 Value::Time(t) => (*t).into_sql(),
678 Value::DateTime(dt) => (*dt).into_sql(),
679 Value::DateTimeTz(dt) => (*dt).into_sql(),
680 Value::Json(j) => {
681 let rendered = serde_json::to_string(j)
684 .map_err(|e| SqlError::QueryFailed(format!("MSSQL bulk: JSON serialize: {e}")))?;
685 ColumnData::String(Some(Cow::Owned(rendered)))
686 }
687 Value::Uuid(s) => {
688 let u = tiberius::Uuid::parse_str(s)
689 .map_err(|e| SqlError::QueryFailed(format!("MSSQL bulk: UUID {s:?}: {e}")))?;
690 ColumnData::Guid(Some(u))
691 }
692 Value::Array(a) => {
693 let rendered = serde_json::to_string(a)
696 .map_err(|e| SqlError::QueryFailed(format!("MSSQL bulk: array serialize: {e}")))?;
697 ColumnData::String(Some(Cow::Owned(rendered)))
698 }
699 })
700}
701
702fn null_for_hint(hint: TypeHint) -> ColumnData<'static> {
707 match hint {
708 TypeHint::Bool => ColumnData::Bit(None),
709 TypeHint::Int64 => ColumnData::I64(None),
710 TypeHint::Float64 => ColumnData::F64(None),
711 TypeHint::Decimal => ColumnData::Numeric(None),
712 TypeHint::Bytes => ColumnData::Binary(None),
713 TypeHint::Date => ColumnData::Date(None),
714 TypeHint::Time => ColumnData::Time(None),
715 TypeHint::DateTime => ColumnData::DateTime2(None),
716 TypeHint::DateTimeTz => ColumnData::DateTimeOffset(None),
717 TypeHint::Uuid => ColumnData::Guid(None),
718 _ => ColumnData::String(None),
721 }
722}
723
724fn parse_decimal_to_numeric(s: &str) -> Result<Numeric, String> {
729 let trimmed = s.trim();
730 if trimmed.is_empty() {
731 return Err("empty string".into());
732 }
733 if trimmed.contains(['e', 'E']) {
734 return Err("scientific notation not supported".into());
735 }
736 let (sign, rest) = match trimmed.as_bytes()[0] {
737 b'-' => (-1i128, &trimmed[1..]),
738 b'+' => (1i128, &trimmed[1..]),
739 _ => (1i128, trimmed),
740 };
741 let (int_part, frac_part) = match rest.split_once('.') {
742 Some((a, b)) => (a, b),
743 None => (rest, ""),
744 };
745 if int_part.is_empty() && frac_part.is_empty() {
746 return Err("no digits".into());
747 }
748 let mut digits = String::with_capacity(int_part.len() + frac_part.len());
749 digits.push_str(int_part);
750 digits.push_str(frac_part);
751 if !digits.chars().all(|c| c.is_ascii_digit()) {
752 return Err(format!("non-digit character in {s:?}"));
753 }
754 let raw: i128 = digits.parse().map_err(|e| format!("parse mantissa: {e}"))?;
755 let scale: u8 = frac_part
756 .len()
757 .try_into()
758 .map_err(|_| "scale exceeds u8".to_string())?;
759 if scale >= 38 {
760 return Err(format!("scale {scale} exceeds MSSQL max 37"));
761 }
762 Ok(Numeric::new_with_scale(sign * raw, scale))
763}
764
765pub(crate) async fn connect(
766 url: &DatabaseUrl,
767 opts: &ConnectOptions,
768) -> Result<MssqlConnection, SqlError> {
769 let mut config = tiberius::Config::new();
770 config.host(url.host().unwrap_or("localhost"));
771 config.port(url.port().unwrap_or(1433));
772
773 if !url.username().is_empty() {
774 let password = opts
776 .effective_password(url)
777 .map(|p| p.expose_secret().to_string())
778 .unwrap_or_default();
779 config.authentication(tiberius::AuthMethod::sql_server(url.username(), password));
780 }
781
782 if !url.database().is_empty() {
783 config.database(url.database());
784 }
785
786 if opts.insecure {
787 config.trust_cert();
788 }
789
790 let params = url.params();
791 if let Some(encrypt) = params.get("encrypt") {
792 match encrypt.as_str() {
793 "false" | "disable" | "off" => config.encryption(EncryptionLevel::Off),
794 "true" | "on" | "require" => config.encryption(EncryptionLevel::Required),
795 _ => {}
796 }
797 }
798 if let Some(trust) = params
799 .get("trust_server_certificate")
800 .or_else(|| params.get("trustServerCertificate"))
801 && (trust == "true" || trust == "yes" || trust == "1")
802 {
803 config.trust_cert();
804 }
805
806 let tcp = tokio::net::TcpStream::connect(config.get_addr())
807 .await
808 .map_err(|e| SqlError::ConnectionFailed(e.to_string()))?;
809 tcp.set_nodelay(true)
810 .map_err(|e| SqlError::ConnectionFailed(e.to_string()))?;
811
812 let client = tiberius::Client::connect(config, tcp.compat_write())
813 .await
814 .map_err(|e| SqlError::ConnectionFailed(e.to_string()))?;
815
816 Ok(MssqlConnection { client })
817}
818
819fn mssql_type_to_hint(col_type: ColumnType) -> TypeHint {
820 match col_type {
821 ColumnType::Bit | ColumnType::Bitn => TypeHint::Bool,
822 ColumnType::Int1
823 | ColumnType::Int2
824 | ColumnType::Int4
825 | ColumnType::Int8
826 | ColumnType::Intn => TypeHint::Int64,
827 ColumnType::Float4 | ColumnType::Float8 | ColumnType::Floatn => TypeHint::Float64,
828 ColumnType::Decimaln | ColumnType::Numericn | ColumnType::Money | ColumnType::Money4 => {
829 TypeHint::Decimal
830 }
831 ColumnType::BigVarChar
832 | ColumnType::BigChar
833 | ColumnType::NVarchar
834 | ColumnType::NChar
835 | ColumnType::Text
836 | ColumnType::NText
837 | ColumnType::Xml => TypeHint::String,
838 ColumnType::BigVarBin | ColumnType::BigBinary | ColumnType::Image => TypeHint::Bytes,
839 ColumnType::Datetime4
840 | ColumnType::Datetime
841 | ColumnType::Datetimen
842 | ColumnType::Datetime2 => TypeHint::DateTime,
843 ColumnType::Daten => TypeHint::Date,
844 ColumnType::Timen => TypeHint::Time,
845 ColumnType::DatetimeOffsetn => TypeHint::DateTimeTz,
846 ColumnType::Guid => TypeHint::Uuid,
847 ColumnType::Udt | ColumnType::SSVariant => TypeHint::Other,
848 ColumnType::Null => TypeHint::Null,
849 }
850}
851
852fn mssql_to_value(row: &tiberius::Row, idx: usize, col_type: ColumnType) -> Value {
853 fn opt<T, E>(r: Result<Option<T>, E>) -> Option<T> {
854 r.ok().flatten()
855 }
856
857 match col_type {
858 ColumnType::Bit | ColumnType::Bitn => opt(row.try_get::<bool, _>(idx))
859 .map(Value::Bool)
860 .unwrap_or(Value::Null),
861 ColumnType::Int1 => opt(row.try_get::<u8, _>(idx))
862 .map(|v| Value::Int64(v as i64))
863 .unwrap_or(Value::Null),
864 ColumnType::Int2 => opt(row.try_get::<i16, _>(idx))
865 .map(|v| Value::Int64(v as i64))
866 .unwrap_or(Value::Null),
867 ColumnType::Int4 => opt(row.try_get::<i32, _>(idx))
868 .map(|v| Value::Int64(v as i64))
869 .unwrap_or(Value::Null),
870 ColumnType::Int8 => opt(row.try_get::<i64, _>(idx))
871 .map(Value::Int64)
872 .unwrap_or(Value::Null),
873 ColumnType::Intn => opt(row.try_get::<i64, _>(idx))
874 .map(Value::Int64)
875 .or_else(|| opt(row.try_get::<i32, _>(idx)).map(|v| Value::Int64(v as i64)))
876 .or_else(|| opt(row.try_get::<i16, _>(idx)).map(|v| Value::Int64(v as i64)))
877 .or_else(|| opt(row.try_get::<u8, _>(idx)).map(|v| Value::Int64(v as i64)))
878 .unwrap_or(Value::Null),
879 ColumnType::Float4 => opt(row.try_get::<f32, _>(idx))
880 .map(|v| Value::Float64(v as f64))
881 .unwrap_or(Value::Null),
882 ColumnType::Float8 => opt(row.try_get::<f64, _>(idx))
883 .map(Value::Float64)
884 .unwrap_or(Value::Null),
885 ColumnType::Floatn => opt(row.try_get::<f64, _>(idx))
886 .map(Value::Float64)
887 .or_else(|| opt(row.try_get::<f32, _>(idx)).map(|v| Value::Float64(v as f64)))
888 .unwrap_or(Value::Null),
889 ColumnType::Money | ColumnType::Money4 => opt(row.try_get::<f64, _>(idx))
890 .map(|v| Value::Decimal(format!("{:.4}", v)))
891 .unwrap_or(Value::Null),
892 ColumnType::Decimaln | ColumnType::Numericn => {
893 opt(row.try_get::<tiberius::numeric::Numeric, _>(idx))
894 .map(|v| Value::Decimal(v.to_string()))
895 .unwrap_or(Value::Null)
896 }
897 ColumnType::BigVarChar
898 | ColumnType::BigChar
899 | ColumnType::NVarchar
900 | ColumnType::NChar
901 | ColumnType::Text
902 | ColumnType::NText => opt(row.try_get::<&str, _>(idx))
903 .map(|v| Value::String(v.to_string()))
904 .unwrap_or(Value::Null),
905 ColumnType::Xml => opt(row.try_get::<&tiberius::xml::XmlData, _>(idx))
906 .map(|v| Value::String(v.to_string()))
907 .unwrap_or(Value::Null),
908 ColumnType::BigVarBin | ColumnType::BigBinary | ColumnType::Image => {
909 opt(row.try_get::<&[u8], _>(idx))
910 .map(|v| Value::Bytes(v.to_vec()))
911 .unwrap_or(Value::Null)
912 }
913 ColumnType::Guid => opt(row.try_get::<tiberius::Uuid, _>(idx))
914 .map(|v| Value::Uuid(v.to_string()))
915 .unwrap_or(Value::Null),
916 ColumnType::Datetime4
917 | ColumnType::Datetime
918 | ColumnType::Datetimen
919 | ColumnType::Datetime2 => opt(row.try_get::<NaiveDateTime, _>(idx))
920 .map(Value::DateTime)
921 .unwrap_or(Value::Null),
922 ColumnType::Daten => opt(row.try_get::<NaiveDate, _>(idx))
923 .map(Value::Date)
924 .unwrap_or(Value::Null),
925 ColumnType::Timen => opt(row.try_get::<NaiveTime, _>(idx))
926 .map(Value::Time)
927 .unwrap_or(Value::Null),
928 ColumnType::DatetimeOffsetn => opt(row.try_get::<ChronoDateTime<FixedOffset>, _>(idx))
929 .map(|v| Value::DateTimeTz(v.with_timezone(&Utc)))
930 .or_else(|| opt(row.try_get::<ChronoDateTime<Utc>, _>(idx)).map(Value::DateTimeTz))
931 .unwrap_or(Value::Null),
932 ColumnType::Udt | ColumnType::SSVariant => opt(row.try_get::<&str, _>(idx))
933 .map(|v| Value::String(v.to_string()))
934 .unwrap_or(Value::Null),
935 ColumnType::Null => Value::Null,
936 }
937}
938
/// Escapes `s` for use inside a single-quoted T-SQL string literal by
/// doubling every `'`.
///
/// Only quotes are touched; the caller supplies the surrounding quotes.
fn escape_mssql_string(s: &str) -> String {
    s.replace('\'', "''")
}
942
943#[cfg(test)]
944mod tests {
945 use super::*;
946 use crate::url::DatabaseUrl;
947
948 const TEST_MSSQL_URL: &str =
949 "mssql://sa:Ferrule123!@127.0.0.1:11433/ferrule?trustServerCertificate=true";
950
951 fn try_connect() -> Option<Box<dyn crate::Connection>> {
952 let url = DatabaseUrl::parse(TEST_MSSQL_URL).ok()?;
953 let conn = crate::connect(&url, &ConnectOptions::default(), None).ok()?;
954 Some(conn)
955 }
956
957 #[test]
958 fn test_mssql_ping() {
959 let Some(mut conn) = try_connect() else {
960 eprintln!("MSSQL test container not available, skipping test_mssql_ping");
961 return;
962 };
963 conn.ping().expect("ping should succeed");
964 }
965
966 #[test]
967 fn test_mssql_query() {
968 let Some(mut conn) = try_connect() else {
969 eprintln!("MSSQL test container not available, skipping test_mssql_query");
970 return;
971 };
972 let result = conn
973 .query("SELECT * FROM test_users")
974 .expect("query should succeed");
975 assert!(!result.columns.is_empty(), "should have columns");
976 assert!(!result.rows.is_empty(), "should have rows");
977 }
978
979 #[test]
980 fn test_mssql_execute() {
981 let Some(mut conn) = try_connect() else {
982 eprintln!("MSSQL test container not available, skipping test_mssql_execute");
983 return;
984 };
985 let summary = conn
986 .execute("INSERT INTO test_users (name, age) VALUES ('TestUser', 99)")
987 .expect("execute should succeed");
988 assert!(
989 summary.rows_affected.is_some_and(|n| n > 0),
990 "should have affected rows"
991 );
992 }
993
994 #[test]
995 fn test_mssql_list_tables() {
996 let Some(mut conn) = try_connect() else {
997 eprintln!("MSSQL test container not available, skipping test_mssql_list_tables");
998 return;
999 };
1000 let tables = conn.list_tables(None).expect("list_tables should succeed");
1001 assert!(
1002 tables.contains(&"test_users".to_string()),
1003 "should contain test_users"
1004 );
1005 }
1006
1007 #[test]
1008 fn test_mssql_list_schemas() {
1009 let Some(mut conn) = try_connect() else {
1010 eprintln!("MSSQL test container not available, skipping test_mssql_list_schemas");
1011 return;
1012 };
1013 let schemas = conn.list_schemas().expect("list_schemas should succeed");
1014 let dbo = schemas
1015 .iter()
1016 .find(|s| s.name == "dbo")
1017 .unwrap_or_else(|| panic!("should contain dbo, got: {schemas:?}"));
1018 assert!(dbo.is_default, "dbo should be the default schema");
1019 }
1020
1021 #[test]
1022 fn test_mssql_describe_table() {
1023 let Some(mut conn) = try_connect() else {
1024 eprintln!("MSSQL test container not available, skipping test_mssql_describe_table");
1025 return;
1026 };
1027 let result = conn
1028 .describe_table(None, "test_users")
1029 .expect("describe_table should succeed");
1030 assert_eq!(result.columns.len(), 6, "should return 6 metadata columns");
1031 let col_names: Vec<String> = result.columns.iter().map(|c| c.name.clone()).collect();
1032 assert_eq!(
1033 col_names,
1034 vec![
1035 "column_name",
1036 "data_type",
1037 "is_nullable",
1038 "column_default",
1039 "numeric_precision",
1040 "numeric_scale"
1041 ]
1042 );
1043 }
1044
1045 #[test]
1046 fn test_mssql_type_mapping() {
1047 let Some(mut conn) = try_connect() else {
1048 eprintln!("MSSQL test container not available, skipping test_mssql_type_mapping");
1049 return;
1050 };
1051 let result = conn
1052 .query("SELECT name, age, score, active, meta FROM test_users WHERE name = 'Alice'")
1053 .expect("query should succeed");
1054 assert_eq!(result.rows.len(), 1);
1055 let row = &result.rows[0];
1056 assert!(matches!(row[0], Value::String(_)), "name should be String");
1057 assert!(matches!(row[1], Value::Int64(_)), "age should be Int64");
1058 assert!(
1059 matches!(row[2], Value::Float64(_) | Value::Decimal(_)),
1060 "score should be Float64 or Decimal"
1061 );
1062 assert!(
1063 matches!(row[3], Value::Int64(_) | Value::Bool(_)),
1064 "active should be Int64 or Bool"
1065 );
1066 assert!(
1067 matches!(row[4], Value::Json(_) | Value::String(_)),
1068 "meta should be Json or String"
1069 );
1070 }
1071
1072 #[test]
1075 fn parse_decimal_simple() {
1076 let n = parse_decimal_to_numeric("99.5").unwrap();
1077 assert_eq!(n.value(), 995);
1078 assert_eq!(n.scale(), 1);
1079 }
1080
1081 #[test]
1082 fn parse_decimal_negative_with_explicit_plus() {
1083 let n = parse_decimal_to_numeric("-12.345").unwrap();
1084 assert_eq!(n.value(), -12345);
1085 assert_eq!(n.scale(), 3);
1086 let p = parse_decimal_to_numeric("+0.5").unwrap();
1087 assert_eq!(p.value(), 5);
1088 assert_eq!(p.scale(), 1);
1089 }
1090
1091 #[test]
1092 fn parse_decimal_integer_has_zero_scale() {
1093 let n = parse_decimal_to_numeric("42").unwrap();
1094 assert_eq!(n.value(), 42);
1095 assert_eq!(n.scale(), 0);
1096 }
1097
1098 fn col(name: &str) -> ColumnInfo {
1101 ColumnInfo {
1102 name: name.to_string(),
1103 type_hint: TypeHint::String,
1104 nullable: true,
1105 }
1106 }
1107
1108 #[test]
1109 fn verify_alignment_accepts_exact_match() {
1110 let dest = vec!["id".to_string(), "name".to_string(), "age".to_string()];
1111 let target = vec![col("id"), col("name"), col("age")];
1112 verify_bulk_column_alignment(&dest, &target).expect("matched columns should pass");
1113 }
1114
1115 #[test]
1116 fn verify_alignment_is_case_insensitive() {
1117 let dest = vec!["ID".to_string(), "Name".to_string()];
1121 let target = vec![col("id"), col("name")];
1122 verify_bulk_column_alignment(&dest, &target).expect("case-insensitive should pass");
1123 }
1124
1125 #[test]
1126 fn verify_alignment_rejects_count_mismatch() {
1127 let dest = vec!["a".to_string(), "b".to_string()];
1128 let target = vec![col("a"), col("b"), col("c")];
1129 let err = verify_bulk_column_alignment(&dest, &target).expect_err("count mismatch");
1130 assert!(matches!(err, SqlError::BulkUnavailable(_)));
1131 let msg = err.to_string();
1132 assert!(
1133 msg.contains("2 dest cols") && msg.contains("3 source cols"),
1134 "useful diagnostic: {msg}"
1135 );
1136 }
1137
1138 #[test]
1139 fn verify_alignment_rejects_order_mismatch() {
1140 let dest = vec!["b".to_string(), "a".to_string()];
1143 let target = vec![col("a"), col("b")];
1144 let err = verify_bulk_column_alignment(&dest, &target).expect_err("order mismatch");
1145 assert!(matches!(err, SqlError::BulkUnavailable(_)));
1146 let msg = err.to_string();
1147 assert!(
1148 msg.contains("Position 0") && msg.contains("\"b\"") && msg.contains("\"a\""),
1149 "useful diagnostic: {msg}"
1150 );
1151 }
1152
1153 #[test]
1154 fn verify_alignment_rejects_extra_destination_columns() {
1155 let dest = vec!["a".to_string(), "b".to_string(), "extra".to_string()];
1159 let target = vec![col("a"), col("b")];
1160 let err = verify_bulk_column_alignment(&dest, &target).expect_err("extra dest cols");
1161 assert!(matches!(err, SqlError::BulkUnavailable(_)));
1162 }
1163
1164 fn qual(schema: Option<&str>, name: &str) -> QualifiedIdentifier {
1167 QualifiedIdentifier {
1168 schema: schema.map(|s| s.to_string()),
1169 name: name.to_string(),
1170 }
1171 }
1172
1173 #[test]
1174 fn parse_qualified_plain_unqualified() {
1175 assert_eq!(
1176 parse_mssql_qualified_identifier("test_users"),
1177 qual(None, "test_users")
1178 );
1179 }
1180
1181 #[test]
1182 fn parse_qualified_dot_form() {
1183 assert_eq!(
1184 parse_mssql_qualified_identifier("dbo.test_users"),
1185 qual(Some("dbo"), "test_users")
1186 );
1187 }
1188
1189 #[test]
1190 fn parse_qualified_bracketed_both_halves() {
1191 assert_eq!(
1192 parse_mssql_qualified_identifier("[dbo].[test_users]"),
1193 qual(Some("dbo"), "test_users")
1194 );
1195 }
1196
1197 #[test]
1198 fn parse_qualified_bracketed_with_embedded_dot() {
1199 assert_eq!(
1204 parse_mssql_qualified_identifier("[my.weird].[table]"),
1205 qual(Some("my.weird"), "table")
1206 );
1207 }
1208
1209 #[test]
1210 fn parse_qualified_mixed_dot_and_brackets() {
1211 assert_eq!(
1212 parse_mssql_qualified_identifier("dbo.[test users]"),
1213 qual(Some("dbo"), "test users")
1214 );
1215 assert_eq!(
1216 parse_mssql_qualified_identifier("[dbo].test_users"),
1217 qual(Some("dbo"), "test_users")
1218 );
1219 }
1220
1221 #[test]
1222 fn parse_qualified_unbracketed_with_space() {
1223 assert_eq!(
1228 parse_mssql_qualified_identifier("my table"),
1229 qual(None, "my table")
1230 );
1231 }
1232
1233 #[test]
1234 fn parse_qualified_escaped_close_bracket() {
1235 assert_eq!(
1238 parse_mssql_qualified_identifier("[wei]]rd].[table]"),
1239 qual(Some("wei]rd"), "table")
1240 );
1241 }
1242
1243 #[test]
1244 fn parse_qualified_unmatched_bracket_is_defensive() {
1245 assert_eq!(
1250 parse_mssql_qualified_identifier("[unfinished"),
1251 qual(None, "unfinished")
1252 );
1253 }
1254
1255 #[test]
1256 fn parse_qualified_trims_surrounding_whitespace() {
1257 assert_eq!(
1258 parse_mssql_qualified_identifier(" dbo.test_users "),
1259 qual(Some("dbo"), "test_users")
1260 );
1261 }
1262
1263 #[test]
1273 fn column_flag_bool_handles_int_bool_null() {
1274 assert!(column_flag_bool(&Value::Bool(true)));
1275 assert!(!column_flag_bool(&Value::Bool(false)));
1276 assert!(column_flag_bool(&Value::Int64(1)));
1277 assert!(!column_flag_bool(&Value::Int64(0)));
1278 assert!(!column_flag_bool(&Value::Null));
1283 assert!(!column_flag_bool(&Value::String("yes".into())));
1285 }
1286
1287 #[test]
1288 fn parse_decimal_rejects_scientific_notation() {
1289 assert!(parse_decimal_to_numeric("1.5e10").is_err());
1290 assert!(parse_decimal_to_numeric("1E5").is_err());
1291 }
1292
1293 #[test]
1294 fn parse_decimal_rejects_malformed() {
1295 assert!(parse_decimal_to_numeric("").is_err());
1296 assert!(parse_decimal_to_numeric("abc").is_err());
1297 assert!(parse_decimal_to_numeric("1..5").is_err());
1298 assert!(parse_decimal_to_numeric(".").is_err());
1299 }
1300
1301 #[test]
1302 fn value_to_column_data_handles_primitives() {
1303 assert!(matches!(
1304 value_to_column_data(&Value::Bool(true), TypeHint::Bool).unwrap(),
1305 ColumnData::Bit(Some(true))
1306 ));
1307 assert!(matches!(
1308 value_to_column_data(&Value::Int64(42), TypeHint::Int64).unwrap(),
1309 ColumnData::I64(Some(42))
1310 ));
1311 let f = value_to_column_data(&Value::Float64(1.5), TypeHint::Float64).unwrap();
1312 assert!(matches!(f, ColumnData::F64(Some(v)) if (v - 1.5).abs() < 1e-12));
1313 }
1314
1315 #[test]
1316 fn value_to_column_data_decimal_routes_through_numeric() {
1317 let d = value_to_column_data(&Value::Decimal("12.34".into()), TypeHint::Decimal).unwrap();
1318 match d {
1319 ColumnData::Numeric(Some(n)) => {
1320 assert_eq!(n.value(), 1234);
1321 assert_eq!(n.scale(), 2);
1322 }
1323 other => panic!("expected Numeric, got {other:?}"),
1324 }
1325 }
1326
1327 #[test]
1328 fn value_to_column_data_string_bytes_uuid() {
1329 match value_to_column_data(&Value::String("hi".into()), TypeHint::String).unwrap() {
1330 ColumnData::String(Some(s)) => assert_eq!(s.as_ref(), "hi"),
1331 other => panic!("expected String, got {other:?}"),
1332 }
1333 match value_to_column_data(&Value::Bytes(vec![1, 2, 3]), TypeHint::Bytes).unwrap() {
1334 ColumnData::Binary(Some(b)) => assert_eq!(b.as_ref(), &[1u8, 2, 3]),
1335 other => panic!("expected Binary, got {other:?}"),
1336 }
1337 match value_to_column_data(
1338 &Value::Uuid("550e8400-e29b-41d4-a716-446655440000".into()),
1339 TypeHint::Uuid,
1340 )
1341 .unwrap()
1342 {
1343 ColumnData::Guid(Some(u)) => {
1344 assert_eq!(u.to_string(), "550e8400-e29b-41d4-a716-446655440000");
1345 }
1346 other => panic!("expected Guid, got {other:?}"),
1347 }
1348 }
1349
1350 #[test]
1351 fn value_to_column_data_json_and_array_serialize_as_nvarchar() {
1352 let j = serde_json::json!({"role": "admin"});
1353 match value_to_column_data(&Value::Json(j), TypeHint::Json).unwrap() {
1354 ColumnData::String(Some(s)) => {
1355 assert!(s.contains("\"role\":\"admin\""));
1356 }
1357 other => panic!("expected String for JSON, got {other:?}"),
1358 }
1359 let a = Value::Array(vec![Value::Int64(1), Value::Int64(2)]);
1360 match value_to_column_data(&a, TypeHint::Array).unwrap() {
1361 ColumnData::String(Some(s)) => assert_eq!(s.as_ref(), "[1,2]"),
1362 other => panic!("expected String for Array, got {other:?}"),
1363 }
1364 }
1365
1366 #[test]
1367 fn value_to_column_data_null_picks_typed_none() {
1368 assert!(matches!(
1370 value_to_column_data(&Value::Null, TypeHint::Bool).unwrap(),
1371 ColumnData::Bit(None)
1372 ));
1373 assert!(matches!(
1374 value_to_column_data(&Value::Null, TypeHint::Int64).unwrap(),
1375 ColumnData::I64(None)
1376 ));
1377 assert!(matches!(
1378 value_to_column_data(&Value::Null, TypeHint::Decimal).unwrap(),
1379 ColumnData::Numeric(None)
1380 ));
1381 assert!(matches!(
1382 value_to_column_data(&Value::Null, TypeHint::Bytes).unwrap(),
1383 ColumnData::Binary(None)
1384 ));
1385 assert!(matches!(
1386 value_to_column_data(&Value::Null, TypeHint::DateTimeTz).unwrap(),
1387 ColumnData::DateTimeOffset(None)
1388 ));
1389 assert!(matches!(
1390 value_to_column_data(&Value::Null, TypeHint::Uuid).unwrap(),
1391 ColumnData::Guid(None)
1392 ));
1393 assert!(matches!(
1395 value_to_column_data(&Value::Null, TypeHint::Json).unwrap(),
1396 ColumnData::String(None)
1397 ));
1398 assert!(matches!(
1399 value_to_column_data(&Value::Null, TypeHint::Other).unwrap(),
1400 ColumnData::String(None)
1401 ));
1402 }
1403
1404 #[test]
1412 fn test_mssql_bulk_insert_rows_round_trip() {
1413 let Some(mut conn) = try_connect() else {
1414 eprintln!(
1415 "MSSQL test container not available, skipping test_mssql_bulk_insert_rows_round_trip"
1416 );
1417 return;
1418 };
1419
1420 let pid = std::process::id();
1421 let table = format!("ferrule_bulk_test_{pid}");
1422 let _ = conn.execute(&format!(
1423 "IF OBJECT_ID('{table}', 'U') IS NOT NULL DROP TABLE {table}"
1424 ));
1425 conn.execute(&format!(
1426 "CREATE TABLE {table} (\
1427 id BIGINT NOT NULL, \
1428 name NVARCHAR(255) NULL, \
1429 active BIT NULL, \
1430 score DECIMAL(10,2) NULL, \
1431 meta NVARCHAR(MAX) NULL, \
1432 uid UNIQUEIDENTIFIER NULL\
1433 )"
1434 ))
1435 .expect("CREATE TABLE");
1436
1437 let columns = vec![
1438 ColumnInfo {
1439 name: "id".into(),
1440 type_hint: TypeHint::Int64,
1441 nullable: false,
1442 },
1443 ColumnInfo {
1444 name: "name".into(),
1445 type_hint: TypeHint::String,
1446 nullable: true,
1447 },
1448 ColumnInfo {
1449 name: "active".into(),
1450 type_hint: TypeHint::Bool,
1451 nullable: true,
1452 },
1453 ColumnInfo {
1454 name: "score".into(),
1455 type_hint: TypeHint::Decimal,
1456 nullable: true,
1457 },
1458 ColumnInfo {
1459 name: "meta".into(),
1460 type_hint: TypeHint::Json,
1461 nullable: true,
1462 },
1463 ColumnInfo {
1464 name: "uid".into(),
1465 type_hint: TypeHint::Uuid,
1466 nullable: true,
1467 },
1468 ];
1469
1470 let rows: Vec<Row> = vec![
1471 vec![
1472 Value::Int64(1),
1473 Value::String("Alice".into()),
1474 Value::Bool(true),
1475 Value::Decimal("99.50".into()),
1476 Value::Json(serde_json::json!({"role": "admin"})),
1477 Value::Uuid("550e8400-e29b-41d4-a716-446655440000".into()),
1478 ],
1479 vec![
1480 Value::Int64(2),
1481 Value::String("Bob".into()),
1482 Value::Bool(false),
1483 Value::Decimal("-7.25".into()),
1484 Value::Json(serde_json::json!({"role": "user"})),
1485 Value::Null,
1486 ],
1487 vec![
1488 Value::Int64(3),
1489 Value::Null,
1490 Value::Null,
1491 Value::Null,
1492 Value::Null,
1493 Value::Null,
1494 ],
1495 ];
1496
1497 let n = conn
1498 .bulk_insert_rows(BulkInsert {
1499 table: &table,
1500 columns: &columns,
1501 rows: &rows,
1502 copy_format: crate::copy::CopyFormat::Text,
1503 })
1504 .expect("bulk_insert_rows");
1505 assert_eq!(n, 3);
1506
1507 let result = conn
1509 .query(&format!(
1510 "SELECT id, name, active, score, meta, uid FROM {table} ORDER BY id"
1511 ))
1512 .expect("read-back query");
1513 assert_eq!(result.rows.len(), 3);
1514
1515 if let Value::Decimal(s) = &result.rows[0][3] {
1517 assert!(
1518 s.starts_with("99.5"),
1519 "row 1 score should be ~99.50, got {s:?}"
1520 );
1521 } else if let Value::Float64(f) = result.rows[0][3] {
1522 assert!((f - 99.5).abs() < 1e-6, "row 1 score got {f}");
1523 } else {
1524 panic!(
1525 "row 1 score should be Decimal or Float64, got {:?}",
1526 result.rows[0][3]
1527 );
1528 }
1529
1530 assert!(matches!(&result.rows[1][5], Value::Null));
1532
1533 assert!(matches!(&result.rows[2][1], Value::Null));
1535 assert!(matches!(&result.rows[2][2], Value::Null));
1536 assert!(matches!(&result.rows[2][3], Value::Null));
1537
1538 conn.execute(&format!("DROP TABLE {table}"))
1540 .expect("DROP TABLE");
1541 }
1542
1543 #[test]
1544 fn test_mssql_primary_key() {
1545 let Some(mut conn) = try_connect() else {
1546 eprintln!("MSSQL test container not available, skipping test_mssql_primary_key");
1547 return;
1548 };
1549 let pk = conn.primary_key(None, "test_users").expect("primary_key");
1550 assert_eq!(pk, vec!["id".to_string()]);
1551 }
1552
1553 #[test]
1554 fn test_mssql_list_foreign_keys() {
1555 let Some(mut conn) = try_connect() else {
1556 eprintln!("MSSQL test container not available, skipping test_mssql_list_foreign_keys");
1557 return;
1558 };
1559 let pid = std::process::id();
1560 let child = format!("ferrule_fk_test_orders_{pid}");
1561 let _ = conn.execute(&format!("DROP TABLE IF EXISTS {child}"));
1562 conn.execute(&format!(
1563 "CREATE TABLE {child} (\
1564 id INT IDENTITY(1,1) PRIMARY KEY, \
1565 user_id INT FOREIGN KEY REFERENCES test_users(id) ON DELETE CASCADE\
1566 )"
1567 ))
1568 .expect("CREATE TABLE");
1569
1570 let fks = conn.list_foreign_keys(None).expect("list_foreign_keys");
1571 let matching: Vec<_> = fks.iter().filter(|fk| fk.child_table == child).collect();
1572 assert_eq!(matching.len(), 1, "expected 1 FK from {child}, got {fks:?}");
1573 let fk = matching[0];
1574 assert_eq!(fk.child_columns, vec!["user_id".to_string()]);
1575 assert_eq!(fk.parent_table, "test_users");
1576 assert_eq!(fk.parent_columns, vec!["id".to_string()]);
1577 assert_eq!(fk.on_delete.as_deref(), Some("CASCADE"));
1578
1579 let _ = conn.execute(&format!("DROP TABLE {child}"));
1580 }
1581
1582 #[test]
1586 fn test_mssql_copy_skip_then_upsert() {
1587 use crate::backend::Backend;
1588 use crate::copy::{CopyOptions, CopySource, IfExists, copy_rows};
1589
1590 let (Some(mut src), Some(mut dst)) = (try_connect(), try_connect()) else {
1591 eprintln!(
1592 "MSSQL test container not available, skipping test_mssql_copy_skip_then_upsert"
1593 );
1594 return;
1595 };
1596
1597 let pid = std::process::id();
1598 let src_table = format!("ferrule_ms_skip_src_{pid}");
1599 let dst_table = format!("ferrule_ms_skip_dst_{pid}");
1600 let _ = src.execute(&format!("DROP TABLE IF EXISTS {src_table}"));
1601 let _ = dst.execute(&format!("DROP TABLE IF EXISTS {dst_table}"));
1602 src.execute(&format!(
1603 "CREATE TABLE {src_table} (id INT PRIMARY KEY, name NVARCHAR(64), val INT)"
1604 ))
1605 .expect("CREATE src");
1606 dst.execute(&format!(
1607 "CREATE TABLE {dst_table} (id INT PRIMARY KEY, name NVARCHAR(64), val INT)"
1608 ))
1609 .expect("CREATE dst");
1610 src.execute(&format!(
1611 "INSERT INTO {src_table} VALUES (1, 'new-1', 10), (2, 'new-2', 20)"
1612 ))
1613 .expect("seed src");
1614 dst.execute(&format!("INSERT INTO {dst_table} VALUES (1, 'old-1', 99)"))
1615 .expect("seed dst");
1616
1617 let opts = CopyOptions {
1619 source: CopySource::Query {
1620 sql: format!("SELECT * FROM {src_table} ORDER BY id"),
1621 into: dst_table.clone(),
1622 },
1623 if_exists: IfExists::Skip,
1624 ..Default::default()
1625 };
1626 copy_rows(&mut src, Backend::MsSql, &mut dst, Backend::MsSql, &opts)
1627 .expect("copy_rows skip");
1628
1629 let out = dst
1630 .query(&format!(
1631 "SELECT id, name, val FROM {dst_table} ORDER BY id"
1632 ))
1633 .expect("verify skip");
1634 assert_eq!(out.rows.len(), 2);
1635 assert!(matches!(&out.rows[0][1], Value::String(s) if s == "old-1"));
1636 assert!(matches!(&out.rows[1][1], Value::String(s) if s == "new-2"));
1637
1638 let opts = CopyOptions {
1640 source: CopySource::Query {
1641 sql: format!("SELECT * FROM {src_table} ORDER BY id"),
1642 into: dst_table.clone(),
1643 },
1644 if_exists: IfExists::Upsert,
1645 ..Default::default()
1646 };
1647 copy_rows(&mut src, Backend::MsSql, &mut dst, Backend::MsSql, &opts)
1648 .expect("copy_rows upsert");
1649
1650 let out = dst
1651 .query(&format!(
1652 "SELECT id, name, val FROM {dst_table} ORDER BY id"
1653 ))
1654 .expect("verify upsert");
1655 assert_eq!(out.rows.len(), 2);
1656 assert!(matches!(&out.rows[0][1], Value::String(s) if s == "new-1"));
1657 assert!(matches!(&out.rows[0][2], Value::Int64(10)));
1658 assert!(matches!(&out.rows[1][1], Value::String(s) if s == "new-2"));
1659
1660 let _ = src.execute(&format!("DROP TABLE {src_table}"));
1661 let _ = dst.execute(&format!("DROP TABLE {dst_table}"));
1662 }
1663}