1mod arrow_convert;
19mod proxy;
20
21pub use proxy::MssqlProxyKind;
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use arrow::datatypes::SchemaRef;
27use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
28use tokio::net::TcpStream;
29use tokio::runtime::Runtime;
30use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
31
32use proxy::{detect_mssql_proxy_kind, warn_proxy_kind};
33
34use crate::config::{TlsConfig, TlsMode};
35use crate::error::Result;
36use crate::source::batch_controller::{
37 AdaptiveBatchController, DEFAULT_BATCH_TARGET_MB, PROBE_BATCH_SIZE,
38};
39use crate::source::query::build_export_query;
40use crate::source::{BatchSink, ExportRequest, Source, TableIntrospection};
41use crate::types::{ColumnOverrides, TypeMapping};
42
/// Tiberius client over a Tokio `TcpStream`, adapted to the `futures` I/O
/// traits via `compat_write()` (see `MssqlSource::connect_with_tls`).
type MssqlClient = Client<Compat<TcpStream>>;
44
/// SQL Server source: one tiberius connection driven synchronously through a
/// private current-thread Tokio runtime (every call goes through `rt.block_on`).
pub struct MssqlSource {
    /// Runtime that executes every async tiberius call made by this source.
    rt: Runtime,
    /// The live connection opened by `connect_with_tls`.
    client: MssqlClient,
    /// Proxy classification detected right after connecting.
    proxy_kind: MssqlProxyKind,
    /// Set by `export` when it issues `SET LOCK_TIMEOUT`; `Drop` checks it to
    /// decide whether to run `SET LOCK_TIMEOUT -1` before the connection closes.
    lock_timeout_applied: bool,
}
59
60impl Drop for MssqlSource {
61 fn drop(&mut self) {
75 if !self.lock_timeout_applied {
76 return;
77 }
78 let Self { rt, client, .. } = self;
79 let _ = rt.block_on(async {
80 tokio::time::timeout(
81 std::time::Duration::from_secs(2),
82 client.execute("SET LOCK_TIMEOUT -1", &[]),
83 )
84 .await
85 });
86 }
87}
88
/// Connection parameters extracted from a `sqlserver://` / `mssql://` URL by
/// `parse_mssql_url`.
///
/// Holds the plaintext password, so keep values of this type out of logs and
/// error messages.
struct MssqlUrl {
    /// Host name or address as written in the URL.
    host: String,
    /// TCP port; `parse_mssql_url` uses 1433 when the URL has none.
    port: u16,
    /// SQL Server login name (passed to `AuthMethod::sql_server`).
    user: String,
    /// Login password; empty when the URL has no `:password` part.
    password: String,
    /// Database name; `parse_mssql_url` rejects URLs without one.
    database: String,
}
97
98fn parse_mssql_url(url: &str) -> Result<MssqlUrl> {
99 let rest = url
100 .strip_prefix("sqlserver://")
101 .or_else(|| url.strip_prefix("mssql://"))
102 .ok_or_else(|| anyhow::anyhow!("mssql url must start with sqlserver:// — got {url}"))?;
103 let (userinfo, hostpart) = rest
106 .rsplit_once('@')
107 .ok_or_else(|| anyhow::anyhow!("mssql url missing user@host: {url}"))?;
108 let (user, password) = match userinfo.split_once(':') {
109 Some((u, p)) => (u.to_string(), p.to_string()),
110 None => (userinfo.to_string(), String::new()),
111 };
112 let (hostport, database) = hostpart
113 .split_once('/')
114 .map(|(h, d)| (h, d.to_string()))
115 .unwrap_or((hostpart, String::new()));
116 let (host, port) = match hostport.rsplit_once(':') {
117 Some((h, p)) => (
118 h.to_string(),
119 p.parse::<u16>()
120 .map_err(|_| anyhow::anyhow!("mssql url port not a number: {p}"))?,
121 ),
122 None => (hostport.to_string(), 1433),
123 };
124 if database.is_empty() {
125 anyhow::bail!("mssql url must include a database: sqlserver://user:pass@host:port/<db>");
126 }
127 Ok(MssqlUrl {
128 host,
129 port,
130 user,
131 password,
132 database,
133 })
134}
135
136impl MssqlSource {
137 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
141 crate::source::require_tls_or_loopback(url, tls)?;
147 let parts = parse_mssql_url(url)?;
148 let mut config = Config::new();
149 config.host(&parts.host);
150 config.port(parts.port);
151 config.database(&parts.database);
152 config.authentication(AuthMethod::sql_server(&parts.user, &parts.password));
153
154 config.encryption(EncryptionLevel::Required);
159 match tls {
160 Some(cfg) if cfg.mode == TlsMode::Disable || cfg.accept_invalid_certs => {
166 config.trust_cert()
167 }
168 Some(cfg) => {
169 static WEBPKI_WARNED: std::sync::Once = std::sync::Once::new();
178 WEBPKI_WARNED.call_once(|| {
179 log::warn!(
180 "mssql: TLS certificate validation is enabled, but the SQL Server \
181 engine pins an old rustls-webpki (via tiberius) with known CA \
182 name-constraint advisories (RUSTSEC-2026-0098/0099). Validation \
183 against a name-constraint-asserting private CA may accept a \
184 mis-issued certificate. Track tiberius for a rustls upgrade."
185 );
186 });
187 if let Some(ca) = cfg.ca_file.as_deref() {
188 config.trust_cert_ca(ca);
189 }
190 }
191 None => {
192 static WARNED: std::sync::Once = std::sync::Once::new();
201 WARNED.call_once(|| {
202 log::warn!(
203 "mssql: connecting with TLS certificate validation disabled \
204 (no `source.tls:` block) — the connection is encrypted but the \
205 server certificate is not verified (MITM not detected). Add \
206 `source.tls: {{ mode: verify-full, ca_file: <ca.pem> }}` to enable \
207 strict validation (or `mode: verify-ca` to skip only hostname checks)."
208 );
209 });
210 config.trust_cert();
211 }
212 }
213
214 let rt = tokio::runtime::Builder::new_current_thread()
215 .enable_all()
216 .build()
217 .map_err(|e| anyhow::anyhow!("mssql: tokio runtime build failed: {e}"))?;
218
219 let client = rt.block_on(async {
220 let tcp = TcpStream::connect(config.get_addr())
221 .await
222 .map_err(|e| anyhow::anyhow!("mssql: TCP connect failed: {e}"))?;
223 tcp.set_nodelay(true).ok();
224 Client::connect(config, tcp.compat_write())
225 .await
226 .map_err(|e| anyhow::anyhow!("mssql: login failed: {e}"))
227 })?;
228
229 let mut src = Self {
230 rt,
231 client,
232 proxy_kind: MssqlProxyKind::Direct,
233 lock_timeout_applied: false,
234 };
235 src.query_scalar("SELECT 1")?;
238 let kind = detect_mssql_proxy_kind(&src.rt, &mut src.client);
243 warn_proxy_kind(kind);
244 src.proxy_kind = kind;
245 Ok(src)
246 }
247
248 #[allow(dead_code)]
252 pub fn proxy_kind(&self) -> MssqlProxyKind {
253 self.proxy_kind
254 }
255
256 fn mssql_decimal_catalog_hints_opt(
264 &mut self,
265 query: &str,
266 ) -> Option<HashMap<String, (u8, i8)>> {
267 let (schema, table) = parse_mssql_simple_from_table(query)?;
268 match self.fetch_mssql_decimal_catalog_hints(&schema, &table) {
269 Ok(m) => m,
270 Err(e) => {
271 log::warn!(
277 "mssql decimal catalog lookup failed for {schema}.{table} — decimal scale \
278 will fall back to first-batch inference (declare it with a `columns:` \
279 override if an all-NULL first batch truncates it): {e}"
280 );
281 None
282 }
283 }
284 }
285
286 fn fetch_mssql_decimal_catalog_hints(
290 &mut self,
291 schema: &str,
292 table: &str,
293 ) -> Result<Option<HashMap<String, (u8, i8)>>> {
294 let sql = format!(
298 "SELECT c.name, c.precision, c.scale \
299 FROM sys.columns c \
300 JOIN sys.types t ON t.user_type_id = c.user_type_id \
301 JOIN sys.objects o ON o.object_id = c.object_id \
302 JOIN sys.schemas s ON s.schema_id = o.schema_id \
303 WHERE s.name = N'{}' AND o.name = N'{}' \
304 AND t.name IN ('decimal', 'numeric')",
305 schema.replace('\'', "''"),
306 table.replace('\'', "''")
307 );
308 let Self { rt, client, .. } = self;
309 let rows = rt.block_on(async {
310 client
311 .query(sql.as_str(), &[])
312 .await
313 .map_err(|e| anyhow::anyhow!("mssql: sys.columns probe failed: {e}"))?
314 .into_first_result()
315 .await
316 .map_err(|e| anyhow::anyhow!("mssql: reading sys.columns rows failed: {e}"))
317 })?;
318
319 let mut map = HashMap::new();
320 for row in &rows {
321 let name: Option<&str> = row.try_get(0).ok().flatten();
325 let precision: Option<u8> = row.try_get(1).ok().flatten();
326 let scale: Option<u8> = row.try_get(2).ok().flatten();
327 if let (Some(name), Some(p), Some(s)) = (name, precision, scale)
328 && let Some(pair) = catalog_decimal_to_params(p, s)
329 {
330 map.insert(name.to_string(), pair);
331 }
332 }
333
334 if map.is_empty() {
335 Ok(None)
336 } else {
337 log::debug!(
338 "mssql decimal catalog: resolved {} DECIMAL/NUMERIC column(s) for {schema}.{table}",
339 map.len(),
340 );
341 Ok(Some(map))
342 }
343 }
344}
345
/// Validates a `(precision, scale)` pair read from the catalog and converts it
/// to the `(u8, i8)` form used for decimal hints.
///
/// Accepts precision `1..=38` with `scale <= precision`; anything else
/// (e.g. precision 0, precision 39+, scale larger than precision) yields `None`.
fn catalog_decimal_to_params(precision: u8, scale: u8) -> Option<(u8, i8)> {
    let well_formed = (1..=38).contains(&precision) && scale <= precision;
    if !well_formed {
        return None;
    }
    i8::try_from(scale).ok().map(|signed_scale| (precision, signed_scale))
}
359
360fn parse_mssql_simple_from_table(query: &str) -> Option<(String, String)> {
367 let from_idx = mssql_find_outer_from_keyword(query)?;
368 let tail = trim_sql_ws(query.get(from_idx + 4..)?);
369 let (first, after1) = parse_mssql_ident_piece(tail)?;
370 let after1 = trim_sql_ws(after1);
371 let (schema, table, after) = if after1.starts_with('.') {
373 let (second, after2) = parse_mssql_ident_piece(trim_sql_ws(after1.get(1..)?))?;
374 let after2 = trim_sql_ws(after2);
375 if after2.starts_with('.') {
376 let (third, after3) = parse_mssql_ident_piece(trim_sql_ws(after2.get(1..)?))?;
378 (second, third, trim_sql_ws(after3))
379 } else {
380 (first, second, after2)
381 }
382 } else {
383 ("dbo".to_string(), first, after1)
384 };
385 let after = skip_mssql_optional_alias(after)?;
388 if mssql_joins_or_comma(after) {
389 return None;
390 }
391 Some((schema, table))
392}
393
/// Strips leading and trailing SQL whitespace from `s`: space, tab, LF and CR.
///
/// Unlike `str::trim`, other Unicode whitespace (form feed, NBSP, …) is kept.
fn trim_sql_ws(s: &str) -> &str {
    const SQL_WS: [char; 4] = [' ', '\t', '\n', '\r'];
    s.trim_matches(SQL_WS)
}
397
/// `true` for bytes that may continue an unquoted identifier in this
/// tokenizer: ASCII letters, ASCII digits, and `_`.
fn is_sql_ident_byte(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
401
402fn sql_keyword_at(haystack: &[u8], idx: usize, kw_lower: &[u8]) -> bool {
405 let n = kw_lower.len();
406 if idx + n > haystack.len() || !haystack[idx..idx + n].eq_ignore_ascii_case(kw_lower) {
407 return false;
408 }
409 let before_ok = idx == 0 || !is_sql_ident_byte(haystack[idx - 1]);
410 let after_ok = idx + n >= haystack.len() || !is_sql_ident_byte(haystack[idx + n]);
411 before_ok && after_ok
412}
413
414fn mssql_find_outer_from_keyword(sql: &str) -> Option<usize> {
417 let b = sql.as_bytes();
418 let mut i = 0usize;
419 let mut depth = 0usize;
420 let mut in_quote = false;
421 while i < b.len() {
422 if in_quote {
423 if b[i] == b'\'' {
424 if i + 1 < b.len() && b[i + 1] == b'\'' {
425 i += 2;
426 } else {
427 in_quote = false;
428 i += 1;
429 }
430 continue;
431 }
432 i += 1;
433 continue;
434 }
435 match b[i] {
436 b'\'' => in_quote = true,
437 b'(' => depth += 1,
438 b')' => depth = depth.saturating_sub(1),
439 _ if depth == 0 && sql_keyword_at(b, i, b"from") => return Some(i),
440 _ => {}
441 }
442 i += 1;
443 }
444 None
445}
446
447fn parse_mssql_ident_piece(rest: &str) -> Option<(String, &str)> {
450 let rest = trim_sql_ws(rest);
451 if let Some(after_open) = rest.strip_prefix('[') {
452 let mut out = String::new();
453 let mut chars = after_open.chars();
454 while let Some(ch) = chars.next() {
455 if ch == ']' {
456 if chars.as_str().starts_with(']') {
457 chars.next();
458 out.push(']');
459 continue;
460 }
461 return Some((out, chars.as_str()));
462 }
463 out.push(ch);
464 }
465 return None; }
467 let bytes = rest.as_bytes();
468 if bytes.is_empty() || (!bytes[0].is_ascii_alphabetic() && bytes[0] != b'_') {
469 return None;
470 }
471 let mut i = 1usize;
472 while i < bytes.len() && is_sql_ident_byte(bytes[i]) {
473 i += 1;
474 }
475 Some((rest.get(0..i)?.to_string(), rest.get(i..)?))
476}
477
478fn mssql_joins_or_comma(rest: &str) -> bool {
481 let r = trim_sql_ws(rest);
482 if r.starts_with(',') || r.starts_with('.') {
483 return true;
484 }
485 let b = r.as_bytes();
486 ["inner", "left", "right", "full", "cross", "join"]
487 .iter()
488 .any(|kw| sql_keyword_at(b, 0, kw.as_bytes()))
489}
490
491fn skip_mssql_optional_alias(rest: &str) -> Option<&str> {
495 let rest = trim_sql_ws(rest);
496 if rest.is_empty() || mssql_starts_clause_boundary(rest) || mssql_joins_or_comma(rest) {
497 return Some(rest);
498 }
499 let mut rest = rest;
500 if sql_keyword_at(rest.as_bytes(), 0, b"as") {
501 rest = trim_sql_ws(rest.get(2..)?);
502 }
503 let (_, tail) = parse_mssql_ident_piece(rest)?;
504 Some(trim_sql_ws(tail))
505}
506
507fn mssql_starts_clause_boundary(rest: &str) -> bool {
508 let r = trim_sql_ws(rest);
509 if r.is_empty() {
510 return true;
511 }
512 const KWS: &[&[u8]] = &[
513 b"where",
514 b"group",
515 b"having",
516 b"order",
517 b"union",
518 b"except",
519 b"intersect",
520 b"for",
521 b"option",
522 b"offset",
523 ];
524 let b = r.as_bytes();
525 KWS.iter().any(|kw| sql_keyword_at(b, 0, kw))
526}
527
528impl Source for MssqlSource {
529 fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()> {
530 let built = build_export_query(request, crate::config::SourceType::Mssql);
533 let sql = built.sql.clone();
534 let overrides = request.column_overrides.clone();
535 let mut ctl =
544 AdaptiveBatchController::new(request.tuning, request.tuning.batch_size.max(1));
545 let mut cap_applied = false;
546 let lock_timeout_ms = request.tuning.lock_timeout_s.saturating_mul(1000);
556 let stmt_timeout = (request.tuning.statement_timeout_s > 0)
557 .then(|| std::time::Duration::from_secs(request.tuning.statement_timeout_s));
558
559 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
565 let decimal_hints = self.mssql_decimal_catalog_hints_opt(hint_query);
566
567 if lock_timeout_ms > 0 {
570 self.lock_timeout_applied = true;
571 }
572
573 let Self { rt, client, .. } = self;
574 rt.block_on(async {
575 use futures_util::stream::TryStreamExt;
576 use tiberius::QueryItem;
577
578 if lock_timeout_ms > 0 {
579 client
580 .execute(format!("SET LOCK_TIMEOUT {lock_timeout_ms}"), &[])
581 .await
582 .map_err(|e| anyhow::anyhow!("mssql: SET LOCK_TIMEOUT failed: {e}"))?;
583 }
584
585 let started = std::time::Instant::now();
586 let mut stream = client
587 .query(sql.as_str(), &[])
588 .await
589 .map_err(|e| anyhow::anyhow!("mssql: query failed: {e}"))?;
590
591 let mut columns: Vec<tiberius::Column> = Vec::new();
592 let mut buf: Vec<tiberius::Row> = Vec::with_capacity(ctl.target());
593 let mut schema: Option<SchemaRef> = None;
594 let max_value_bytes = request.tuning.max_value_bytes();
598
599 while let Some(item) = stream
600 .try_next()
601 .await
602 .map_err(|e| anyhow::anyhow!("mssql: streaming rows failed: {e}"))?
603 {
604 if let Some(budget) = stmt_timeout
605 && started.elapsed() > budget
606 {
607 return Err(crate::source::StatementDurationTimeout::mssql(
612 budget.as_secs(),
613 )
614 .into());
615 }
616 match item {
617 QueryItem::Metadata(meta) if columns.is_empty() => {
620 columns = meta.columns().to_vec();
621 if let Ok((provisional, _)) = arrow_convert::mssql_columns_to_schema(
632 &columns,
633 &overrides,
634 &[],
635 decimal_hints.as_ref(),
636 ) {
637 let eff = request
638 .tuning
639 .effective_batch_size(Some(&Arc::new(provisional)));
640 ctl.raise_configured_ceiling(eff);
641 }
642 }
643 QueryItem::Metadata(_) => {}
644 QueryItem::Row(row) => {
645 buf.push(row);
646 if buf.len() >= ctl.target() {
647 let arrow_bytes = emit_mssql_batch(
648 &columns,
649 &overrides,
650 decimal_hints.as_ref(),
651 &mut schema,
652 &buf,
653 sink,
654 max_value_bytes,
655 )?;
656 let n = buf.len();
657 buf.clear();
658 if !cap_applied && n > 0 {
663 let arrow_per_row = (arrow_bytes / n).max(64);
664 let target_mb = request
665 .tuning
666 .batch_size_memory_mb
667 .unwrap_or(DEFAULT_BATCH_TARGET_MB);
668 let safe = ((target_mb * 1024 * 1024) / arrow_per_row)
669 .max(PROBE_BATCH_SIZE);
670 if let Some(new) = ctl.apply_memory_cap(safe) {
671 log::info!(
672 "MSSQL batch cap: arrow≈{} B/row, target={} MB → batch_size → {}",
673 arrow_per_row,
674 target_mb,
675 new
676 );
677 buf.reserve(new.saturating_sub(buf.capacity()));
678 }
679 cap_applied = true;
680 }
681 ctl.after_batch(|| None);
683 ctl.throttle();
684 }
685 }
686 }
687 }
688 if !buf.is_empty() || schema.is_none() {
694 emit_mssql_batch(
695 &columns,
696 &overrides,
697 decimal_hints.as_ref(),
698 &mut schema,
699 &buf,
700 sink,
701 max_value_bytes,
702 )?;
703 }
704 Ok::<_, anyhow::Error>(())
705 })?;
706 Ok(())
707 }
708
709 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
710 let Self { rt, client, .. } = self;
711 rt.block_on(async {
712 let row = client
713 .query(sql, &[])
714 .await
715 .map_err(|e| anyhow::anyhow!("mssql: scalar query failed: {e}"))?
716 .into_row()
717 .await
718 .map_err(|e| anyhow::anyhow!("mssql: reading scalar row failed: {e}"))?;
719 Ok(row.and_then(|r| scalar_to_string(&r)))
720 })
721 }
722
723 fn type_mappings(
724 &mut self,
725 query: &str,
726 column_overrides: &ColumnOverrides,
727 ) -> Result<Vec<TypeMapping>> {
728 let wrapped = format!("SELECT * FROM ({query}) AS _rivet_q WHERE 1 = 0");
730 let overrides = column_overrides.clone();
731 let Self { rt, client, .. } = self;
732 rt.block_on(async {
733 let mut stream = client
734 .query(wrapped.as_str(), &[])
735 .await
736 .map_err(|e| anyhow::anyhow!("mssql: type-probe query failed: {e}"))?;
737 let columns = stream
738 .columns()
739 .await
740 .map_err(|e| anyhow::anyhow!("mssql: type-probe metadata failed: {e}"))?
741 .map(<[_]>::to_vec)
742 .unwrap_or_default();
743 let _ = stream.into_first_result().await;
745 Ok(arrow_convert::mssql_type_mappings(&columns, &overrides))
746 })
747 }
748
749 fn sample_pressure(&mut self) -> Option<u64> {
750 let Self { rt, client, .. } = self;
751 let sql = "SELECT SUM(cntr_value) FROM sys.dm_os_performance_counters \
761 WHERE counter_name IN ('Workfiles Created/sec', 'Worktables Created/sec')";
762 rt.block_on(async {
763 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
764 row.get::<i64, _>(0).map(|v| v.max(0) as u64)
765 })
766 }
767}
768
769fn emit_mssql_batch(
779 columns: &[tiberius::Column],
780 overrides: &ColumnOverrides,
781 decimal_hints: Option<&HashMap<String, (u8, i8)>>,
782 schema: &mut Option<SchemaRef>,
783 rows: &[tiberius::Row],
784 sink: &mut dyn BatchSink,
785 max_value_bytes: Option<usize>,
786) -> Result<usize> {
787 let schema_ref = match schema {
788 Some(s) => s.clone(),
789 None => {
790 let (built, _decoders) =
791 arrow_convert::mssql_columns_to_schema(columns, overrides, rows, decimal_hints)?;
792 let s: SchemaRef = Arc::new(built);
793 sink.on_schema(s.clone())?;
794 *schema = Some(s.clone());
795 s
796 }
797 };
798 if !rows.is_empty() {
799 let batch = arrow_convert::mssql_rows_to_record_batch(&schema_ref, rows, max_value_bytes)?;
800 let bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
801 sink.on_batch(&batch)?;
802 return Ok(bytes);
803 }
804 Ok(0)
805}
806
807fn scalar_to_string(row: &tiberius::Row) -> Option<String> {
811 use tiberius::ColumnData;
812 let cell = row.cells().next().map(|(_, d)| d)?;
813 match cell {
814 ColumnData::U8(v) => v.map(|x| x.to_string()),
815 ColumnData::I16(v) => v.map(|x| x.to_string()),
816 ColumnData::I32(v) => v.map(|x| x.to_string()),
817 ColumnData::I64(v) => v.map(|x| x.to_string()),
818 ColumnData::F32(v) => v.map(|x| x.to_string()),
819 ColumnData::F64(v) => v.map(|x| x.to_string()),
820 ColumnData::Bit(v) => v.map(|x| x.to_string()),
821 ColumnData::String(v) => v.as_ref().map(|s| s.to_string()),
822 ColumnData::Numeric(v) => v.map(|n| {
823 let raw = n.value();
825 let scale = n.scale() as usize;
826 if scale == 0 {
827 raw.to_string()
828 } else {
829 let neg = raw < 0;
830 let digits = raw.unsigned_abs().to_string();
831 let digits = format!("{digits:0>width$}", width = scale + 1);
832 let (int, frac) = digits.split_at(digits.len() - scale);
833 format!("{}{int}.{frac}", if neg { "-" } else { "" })
834 }
835 }),
836 ColumnData::Guid(v) => v.map(|g| g.to_string()),
837 other => Some(format!("{other:?}")),
838 }
839}
840
841pub(crate) fn introspect_mssql_table_for_chunking(
844 url: &str,
845 tls: Option<&TlsConfig>,
846 qualified_table: &str,
847) -> Result<TableIntrospection> {
848 let (schema, table) = match qualified_table.split_once('.') {
849 Some((s, t)) => (s.to_string(), t.to_string()),
850 None => ("dbo".to_string(), qualified_table.to_string()),
851 };
852 let mut src = MssqlSource::connect_with_tls(url, tls)?;
853
854 let count_sql = format!(
857 "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
858 JOIN sys.objects o ON o.object_id = p.object_id \
859 JOIN sys.schemas s ON s.schema_id = o.schema_id \
860 WHERE s.name = N'{}' AND o.name = N'{}' AND p.index_id IN (0,1)",
861 schema.replace('\'', "''"),
862 table.replace('\'', "''")
863 );
864 let row_estimate = src
865 .query_scalar(&count_sql)?
866 .and_then(|s| s.parse::<i64>().ok())
867 .unwrap_or(0);
868
869 let pk_sql = format!(
872 "SELECT TOP 1 c.name, t.name FROM sys.indexes i \
873 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
874 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
875 JOIN sys.types t ON t.user_type_id = c.user_type_id \
876 JOIN sys.objects o ON o.object_id = i.object_id \
877 JOIN sys.schemas s ON s.schema_id = o.schema_id \
878 WHERE i.is_primary_key = 1 AND s.name = N'{}' AND o.name = N'{}' \
879 GROUP BY c.name, t.name HAVING COUNT(*) = 1",
880 schema.replace('\'', "''"),
881 table.replace('\'', "''")
882 );
883 let keyset_sql = format!(
893 "SELECT STRING_AGG(col, CHAR(31)) WITHIN GROUP (ORDER BY is_pk DESC, col) FROM ( \
894 SELECT col, MAX(is_pk) AS is_pk FROM ( \
895 SELECT MIN(c.name) AS col, MAX(CONVERT(int, i.is_primary_key)) AS is_pk \
896 FROM sys.indexes i \
897 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.key_ordinal > 0 \
898 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
899 JOIN sys.objects o ON o.object_id = i.object_id \
900 JOIN sys.schemas s ON s.schema_id = o.schema_id \
901 WHERE i.is_unique = 1 AND c.is_nullable = 0 AND s.name = N'{}' AND o.name = N'{}' \
902 GROUP BY i.object_id, i.index_id HAVING COUNT(*) = 1 \
903 ) per_index GROUP BY col \
904 ) deduped",
905 schema.replace('\'', "''"),
906 table.replace('\'', "''")
907 );
908 let keyset_keys: Vec<String> = src
909 .query_scalar(&keyset_sql)?
910 .map(|s| {
911 s.split('\u{1f}')
912 .filter(|c| !c.is_empty())
913 .map(str::to_string)
914 .collect()
915 })
916 .unwrap_or_default();
917
918 let mut single_int_pk = None;
921 if let Some(pk_col) = src.query_scalar(&pk_sql)? {
922 let type_sql = format!(
925 "SELECT t.name FROM sys.columns c \
926 JOIN sys.types t ON t.user_type_id = c.user_type_id \
927 JOIN sys.objects o ON o.object_id = c.object_id \
928 JOIN sys.schemas s ON s.schema_id = o.schema_id \
929 WHERE s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
930 schema.replace('\'', "''"),
931 table.replace('\'', "''"),
932 pk_col.replace('\'', "''")
933 );
934 if let Some(ty) = src.query_scalar(&type_sql)?
935 && matches!(ty.as_str(), "tinyint" | "smallint" | "int" | "bigint")
936 {
937 single_int_pk = Some(pk_col);
938 }
939 }
940
941 Ok(TableIntrospection {
942 single_int_pk,
943 keyset_keys,
944 row_estimate,
945 avg_row_bytes: None,
946 })
947}
948
#[cfg(test)]
mod tests {
    use super::{catalog_decimal_to_params, parse_mssql_simple_from_table};

    /// The `Some((schema, table))` value the parser should produce.
    fn expect(schema: &str, table: &str) -> Option<(String, String)> {
        Some((schema.to_owned(), table.to_owned()))
    }

    #[test]
    fn parse_unqualified_table_defaults_to_dbo() {
        let parsed =
            parse_mssql_simple_from_table("SELECT id, amount FROM transactions ORDER BY id");
        assert_eq!(parsed, expect("dbo", "transactions"));
    }

    #[test]
    fn parse_schema_qualified() {
        let parsed = parse_mssql_simple_from_table("SELECT id FROM sales.orders WHERE id > 1");
        assert_eq!(parsed, expect("sales", "orders"));
    }

    #[test]
    fn parse_db_schema_table_takes_last_two() {
        let parsed = parse_mssql_simple_from_table("SELECT * FROM mydb.sales.orders");
        assert_eq!(parsed, expect("sales", "orders"));
    }

    #[test]
    fn parse_bracketed_identifiers() {
        let parsed = parse_mssql_simple_from_table("SELECT * FROM [my schema].[order items]");
        assert_eq!(parsed, expect("my schema", "order items"));
    }

    #[test]
    fn parse_table_with_alias() {
        let queries = [
            "SELECT t.id FROM transactions AS t WHERE t.x = 1",
            "SELECT t.id FROM transactions t ORDER BY t.id",
        ];
        for query in queries {
            assert_eq!(
                parse_mssql_simple_from_table(query),
                expect("dbo", "transactions"),
                "query: {query}"
            );
        }
    }

    #[test]
    fn parse_rejects_join() {
        let queries = [
            "SELECT * FROM a INNER JOIN b ON a.id = b.id",
            "SELECT * FROM a JOIN b ON a.id = b.id",
        ];
        for query in queries {
            assert_eq!(parse_mssql_simple_from_table(query), None, "query: {query}");
        }
    }

    #[test]
    fn parse_rejects_comma_list() {
        let parsed = parse_mssql_simple_from_table("SELECT * FROM a, b WHERE a.id = b.id");
        assert_eq!(parsed, None);
    }

    #[test]
    fn parse_rejects_subquery_from() {
        let parsed = parse_mssql_simple_from_table("SELECT * FROM (SELECT * FROM t) AS s");
        assert_eq!(parsed, None);
    }

    #[test]
    fn parse_ignores_from_inside_string_literal() {
        // Both the select-list literal and the WHERE literal contain "from";
        // only the real FROM clause may be picked up.
        let parsed = parse_mssql_simple_from_table(
            "SELECT 'from x', amount FROM ledger WHERE note = 'paid from cash'",
        );
        assert_eq!(parsed, expect("dbo", "ledger"));
    }

    #[test]
    fn catalog_bounds_accept_well_formed_and_reject_degenerate() {
        let accepted = [((10, 2), (10, 2)), ((38, 0), (38, 0)), ((38, 38), (38, 38))];
        for ((precision, scale), want) in accepted {
            assert_eq!(catalog_decimal_to_params(precision, scale), Some(want));
        }
        // Zero precision, precision above 38, and scale above precision.
        for (precision, scale) in [(0, 0), (39, 0), (10, 11)] {
            assert_eq!(catalog_decimal_to_params(precision, scale), None);
        }
    }
}