1mod arrow_convert;
19mod proxy;
20
21pub use proxy::MssqlProxyKind;
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use arrow::datatypes::SchemaRef;
27use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
28use tokio::net::TcpStream;
29use tokio::runtime::Runtime;
30use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
31
32use proxy::{detect_mssql_proxy_kind, warn_proxy_kind};
33
34use crate::config::{TlsConfig, TlsMode};
35use crate::error::Result;
36use crate::source::batch_controller::{
37 AdaptiveBatchController, DEFAULT_BATCH_TARGET_MB, PROBE_BATCH_SIZE,
38};
39use crate::source::query::build_export_query;
40use crate::source::{BatchSink, ExportRequest, Source, TableIntrospection};
41use crate::types::{ColumnOverrides, TypeMapping};
42
/// Concrete tiberius client type used by this module: a client over a tokio
/// `TcpStream` wrapped in the `tokio_util` compat adapter.
type MssqlClient = Client<Compat<TcpStream>>;
44
/// SQL Server source: a single tiberius client together with the tokio
/// runtime that every client call is blocked on.
pub struct MssqlSource {
    /// Runtime used to drive the async client through `block_on`.
    rt: Runtime,
    /// The open tiberius connection.
    client: MssqlClient,
    /// Proxy classification; starts as `Direct` and is overwritten with the
    /// detection result at the end of `connect_with_tls`.
    proxy_kind: MssqlProxyKind,
    /// Set by `export` when a non-zero lock timeout is requested, so that
    /// `Drop` knows it has to reset `LOCK_TIMEOUT` on the session.
    lock_timeout_applied: bool,
}
59
60impl Drop for MssqlSource {
61 fn drop(&mut self) {
75 if !self.lock_timeout_applied {
76 return;
77 }
78 let Self { rt, client, .. } = self;
79 let _ = rt.block_on(async {
80 tokio::time::timeout(
81 std::time::Duration::from_secs(2),
82 client.execute("SET LOCK_TIMEOUT -1", &[]),
83 )
84 .await
85 });
86 }
87}
88
/// Components extracted from a `sqlserver://` / `mssql://` connection URL by
/// `parse_mssql_url`.
struct MssqlUrl {
    /// Host part, exactly as written in the URL.
    host: String,
    /// TCP port; 1433 when the URL does not specify one.
    port: u16,
    /// Login name (text before the first `:` of the userinfo section).
    user: String,
    /// Password; empty when the userinfo section has no `:`.
    password: String,
    /// Database name (text after the first `/` of the host section); never empty.
    database: String,
}
97
98fn parse_mssql_url(url: &str) -> Result<MssqlUrl> {
99 let rest = url
100 .strip_prefix("sqlserver://")
101 .or_else(|| url.strip_prefix("mssql://"))
102 .ok_or_else(|| anyhow::anyhow!("mssql url must start with sqlserver:// — got {url}"))?;
103 let (userinfo, hostpart) = rest
106 .rsplit_once('@')
107 .ok_or_else(|| anyhow::anyhow!("mssql url missing user@host: {url}"))?;
108 let (user, password) = match userinfo.split_once(':') {
109 Some((u, p)) => (u.to_string(), p.to_string()),
110 None => (userinfo.to_string(), String::new()),
111 };
112 let (hostport, database) = hostpart
113 .split_once('/')
114 .map(|(h, d)| (h, d.to_string()))
115 .unwrap_or((hostpart, String::new()));
116 let (host, port) = match hostport.rsplit_once(':') {
117 Some((h, p)) => (
118 h.to_string(),
119 p.parse::<u16>()
120 .map_err(|_| anyhow::anyhow!("mssql url port not a number: {p}"))?,
121 ),
122 None => (hostport.to_string(), 1433),
123 };
124 if database.is_empty() {
125 anyhow::bail!("mssql url must include a database: sqlserver://user:pass@host:port/<db>");
126 }
127 Ok(MssqlUrl {
128 host,
129 port,
130 user,
131 password,
132 database,
133 })
134}
135
136impl MssqlSource {
137 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
141 crate::source::require_tls_or_loopback(url, tls)?;
147 let parts = parse_mssql_url(url)?;
148 let mut config = Config::new();
149 config.host(&parts.host);
150 config.port(parts.port);
151 config.database(&parts.database);
152 config.authentication(AuthMethod::sql_server(&parts.user, &parts.password));
153
154 config.encryption(EncryptionLevel::Required);
159 match tls {
160 Some(cfg) if cfg.mode == TlsMode::Disable || cfg.accept_invalid_certs => {
166 config.trust_cert()
167 }
168 Some(cfg) => {
169 static WEBPKI_WARNED: std::sync::Once = std::sync::Once::new();
178 WEBPKI_WARNED.call_once(|| {
179 log::warn!(
180 "mssql: TLS certificate validation is enabled, but the SQL Server \
181 engine pins an old rustls-webpki (via tiberius) with known CA \
182 name-constraint advisories (RUSTSEC-2026-0098/0099). Validation \
183 against a name-constraint-asserting private CA may accept a \
184 mis-issued certificate. Track tiberius for a rustls upgrade."
185 );
186 });
187 if let Some(ca) = cfg.ca_file.as_deref() {
188 config.trust_cert_ca(ca);
189 }
190 }
191 None => {
192 static WARNED: std::sync::Once = std::sync::Once::new();
201 WARNED.call_once(|| {
202 log::warn!(
203 "mssql: connecting with TLS certificate validation disabled \
204 (no `source.tls:` block) — the connection is encrypted but the \
205 server certificate is not verified (MITM not detected). Add \
206 `source.tls: {{ mode: verify-full, ca_file: <ca.pem> }}` to enable \
207 strict validation (or `mode: verify-ca` to skip only hostname checks)."
208 );
209 });
210 config.trust_cert();
211 }
212 }
213
214 let rt = tokio::runtime::Builder::new_current_thread()
215 .enable_all()
216 .build()
217 .map_err(|e| anyhow::anyhow!("mssql: tokio runtime build failed: {e}"))?;
218
219 let client = rt.block_on(async {
220 let tcp = TcpStream::connect(config.get_addr())
221 .await
222 .map_err(|e| anyhow::anyhow!("mssql: TCP connect failed: {e}"))?;
223 tcp.set_nodelay(true).ok();
224 Client::connect(config, tcp.compat_write())
225 .await
226 .map_err(|e| anyhow::anyhow!("mssql: login failed: {e}"))
227 })?;
228
229 let mut src = Self {
230 rt,
231 client,
232 proxy_kind: MssqlProxyKind::Direct,
233 lock_timeout_applied: false,
234 };
235 src.query_scalar("SELECT 1")?;
238 let kind = detect_mssql_proxy_kind(&src.rt, &mut src.client);
243 warn_proxy_kind(kind);
244 src.proxy_kind = kind;
245 Ok(src)
246 }
247
248 #[allow(dead_code)]
252 pub fn proxy_kind(&self) -> MssqlProxyKind {
253 self.proxy_kind
254 }
255
256 fn mssql_decimal_catalog_hints_opt(
264 &mut self,
265 query: &str,
266 ) -> Option<HashMap<String, (u8, i8)>> {
267 let (schema, table) = parse_mssql_simple_from_table(query)?;
268 match self.fetch_mssql_decimal_catalog_hints(&schema, &table) {
269 Ok(m) => m,
270 Err(e) => {
271 log::warn!(
277 "mssql decimal catalog lookup failed for {schema}.{table} — decimal scale \
278 will fall back to first-batch inference (declare it with a `columns:` \
279 override if an all-NULL first batch truncates it): {e}"
280 );
281 None
282 }
283 }
284 }
285
286 fn fetch_mssql_decimal_catalog_hints(
290 &mut self,
291 schema: &str,
292 table: &str,
293 ) -> Result<Option<HashMap<String, (u8, i8)>>> {
294 let sql = format!(
298 "SELECT c.name, c.precision, c.scale \
299 FROM sys.columns c \
300 JOIN sys.types t ON t.user_type_id = c.user_type_id \
301 JOIN sys.objects o ON o.object_id = c.object_id \
302 JOIN sys.schemas s ON s.schema_id = o.schema_id \
303 WHERE s.name = N'{}' AND o.name = N'{}' \
304 AND t.name IN ('decimal', 'numeric')",
305 schema.replace('\'', "''"),
306 table.replace('\'', "''")
307 );
308 let Self { rt, client, .. } = self;
309 let rows = rt.block_on(async {
310 client
311 .query(sql.as_str(), &[])
312 .await
313 .map_err(|e| anyhow::anyhow!("mssql: sys.columns probe failed: {e}"))?
314 .into_first_result()
315 .await
316 .map_err(|e| anyhow::anyhow!("mssql: reading sys.columns rows failed: {e}"))
317 })?;
318
319 let mut map = HashMap::new();
320 for row in &rows {
321 let name: Option<&str> = row.try_get(0).ok().flatten();
325 let precision: Option<u8> = row.try_get(1).ok().flatten();
326 let scale: Option<u8> = row.try_get(2).ok().flatten();
327 if let (Some(name), Some(p), Some(s)) = (name, precision, scale)
328 && let Some(pair) = catalog_decimal_to_params(p, s)
329 {
330 map.insert(name.to_string(), pair);
331 }
332 }
333
334 if map.is_empty() {
335 Ok(None)
336 } else {
337 log::debug!(
338 "mssql decimal catalog: resolved {} DECIMAL/NUMERIC column(s) for {schema}.{table}",
339 map.len(),
340 );
341 Ok(Some(map))
342 }
343 }
344}
345
/// Validates a catalog `(precision, scale)` pair and converts it to the
/// `(u8, i8)` form used for decimal hints.
///
/// Yields `None` unless `1 <= precision <= 38` and `scale <= precision`
/// (and the scale fits in an `i8`).
fn catalog_decimal_to_params(precision: u8, scale: u8) -> Option<(u8, i8)> {
    let precision_ok = (1..=38).contains(&precision);
    if !precision_ok || scale > precision {
        return None;
    }
    i8::try_from(scale).ok().map(|signed| (precision, signed))
}
359
360fn parse_mssql_simple_from_table(query: &str) -> Option<(String, String)> {
367 let from_idx = mssql_find_outer_from_keyword(query)?;
368 let tail = trim_sql_ws(query.get(from_idx + 4..)?);
369 let (first, after1) = parse_mssql_ident_piece(tail)?;
370 let after1 = trim_sql_ws(after1);
371 let (schema, table, after) = if after1.starts_with('.') {
373 let (second, after2) = parse_mssql_ident_piece(trim_sql_ws(after1.get(1..)?))?;
374 let after2 = trim_sql_ws(after2);
375 if after2.starts_with('.') {
376 let (third, after3) = parse_mssql_ident_piece(trim_sql_ws(after2.get(1..)?))?;
378 (second, third, trim_sql_ws(after3))
379 } else {
380 (first, second, after2)
381 }
382 } else {
383 ("dbo".to_string(), first, after1)
384 };
385 let after = skip_mssql_optional_alias(after)?;
388 if mssql_joins_or_comma(after) {
389 return None;
390 }
391 Some((schema, table))
392}
393
/// Strips spaces, tabs, line feeds and carriage returns from both ends of
/// `s`; no other whitespace characters are touched.
fn trim_sql_ws(s: &str) -> &str {
    const SQL_WS: &[char] = &[' ', '\t', '\n', '\r'];
    s.trim_start_matches(SQL_WS).trim_end_matches(SQL_WS)
}
397
/// `true` for bytes that may appear in an unquoted ASCII identifier
/// (`A-Z`, `a-z`, `0-9`, `_`).
fn is_sql_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_'
}

/// Whole-word keyword test: does `haystack` contain `kw_lower` at `idx`
/// (ASCII case-insensitive) with no identifier byte directly before or after?
fn sql_keyword_at(haystack: &[u8], idx: usize, kw_lower: &[u8]) -> bool {
    let end = idx + kw_lower.len();
    match haystack.get(idx..end) {
        Some(word) if word.eq_ignore_ascii_case(kw_lower) => {}
        _ => return false,
    }
    let before_ok = idx == 0 || !is_sql_ident_byte(haystack[idx - 1]);
    let after_ok = haystack.get(end).map_or(true, |&next| !is_sql_ident_byte(next));
    before_ok && after_ok
}

/// Skips a delimited token whose body starts at `i` (just past the opening
/// delimiter) and returns the index just past its closing delimiter.
/// A doubled `close` byte is an escaped delimiter. An unterminated token
/// consumes the rest of the input.
fn mssql_skip_delimited(b: &[u8], mut i: usize, close: u8) -> usize {
    while i < b.len() {
        if b[i] == close {
            if b.get(i + 1) == Some(&close) {
                i += 2;
                continue;
            }
            return i + 1;
        }
        i += 1;
    }
    b.len()
}

/// Skips a `/* ... */` comment whose body starts at `i` (just past the
/// opening `/*`), honouring nested comments, and returns the index just past
/// the matching `*/`. An unterminated comment consumes the rest of the input.
fn mssql_skip_block_comment(b: &[u8], mut i: usize) -> usize {
    let mut nesting = 1usize;
    while i < b.len() {
        let pair = (b[i], b.get(i + 1).copied());
        match pair {
            (b'/', Some(b'*')) => {
                nesting += 1;
                i += 2;
            }
            (b'*', Some(b'/')) => {
                nesting -= 1;
                i += 2;
                if nesting == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    b.len()
}

/// Returns the byte offset of the first `FROM` keyword that sits outside any
/// parentheses, or `None` when there is none.
///
/// Text that cannot contain a real keyword is skipped: `'string literals'`,
/// `"quoted"` and `[bracketed]` identifiers, `-- line comments` and
/// `/* block comments */`. Without this, a `from` inside a comment or a
/// bracketed name would be mistaken for the clause keyword.
fn mssql_find_outer_from_keyword(sql: &str) -> Option<usize> {
    let b = sql.as_bytes();
    let mut i = 0usize;
    let mut depth = 0usize;
    while i < b.len() {
        // Every arm yields an index strictly greater than `i`.
        i = match b[i] {
            b'\'' => mssql_skip_delimited(b, i + 1, b'\''),
            b'"' => mssql_skip_delimited(b, i + 1, b'"'),
            b'[' => mssql_skip_delimited(b, i + 1, b']'),
            b'-' if b.get(i + 1) == Some(&b'-') => b[i..]
                .iter()
                .position(|&c| c == b'\n')
                .map_or(b.len(), |off| i + off),
            b'/' if b.get(i + 1) == Some(&b'*') => mssql_skip_block_comment(b, i + 2),
            b'(' => {
                depth += 1;
                i + 1
            }
            b')' => {
                depth = depth.saturating_sub(1);
                i + 1
            }
            _ if depth == 0 && sql_keyword_at(b, i, b"from") => return Some(i),
            _ => i + 1,
        };
    }
    None
}
446
447fn parse_mssql_ident_piece(rest: &str) -> Option<(String, &str)> {
450 let rest = trim_sql_ws(rest);
451 if let Some(after_open) = rest.strip_prefix('[') {
452 let mut out = String::new();
453 let mut chars = after_open.chars();
454 while let Some(ch) = chars.next() {
455 if ch == ']' {
456 if chars.as_str().starts_with(']') {
457 chars.next();
458 out.push(']');
459 continue;
460 }
461 return Some((out, chars.as_str()));
462 }
463 out.push(ch);
464 }
465 return None; }
467 let bytes = rest.as_bytes();
468 if bytes.is_empty() || (!bytes[0].is_ascii_alphabetic() && bytes[0] != b'_') {
469 return None;
470 }
471 let mut i = 1usize;
472 while i < bytes.len() && is_sql_ident_byte(bytes[i]) {
473 i += 1;
474 }
475 Some((rest.get(0..i)?.to_string(), rest.get(i..)?))
476}
477
478fn mssql_joins_or_comma(rest: &str) -> bool {
481 let r = trim_sql_ws(rest);
482 if r.starts_with(',') || r.starts_with('.') {
483 return true;
484 }
485 let b = r.as_bytes();
486 ["inner", "left", "right", "full", "cross", "join"]
487 .iter()
488 .any(|kw| sql_keyword_at(b, 0, kw.as_bytes()))
489}
490
491fn skip_mssql_optional_alias(rest: &str) -> Option<&str> {
495 let rest = trim_sql_ws(rest);
496 if rest.is_empty() || mssql_starts_clause_boundary(rest) || mssql_joins_or_comma(rest) {
497 return Some(rest);
498 }
499 let mut rest = rest;
500 if sql_keyword_at(rest.as_bytes(), 0, b"as") {
501 rest = trim_sql_ws(rest.get(2..)?);
502 }
503 let (_, tail) = parse_mssql_ident_piece(rest)?;
504 Some(trim_sql_ws(tail))
505}
506
507fn mssql_starts_clause_boundary(rest: &str) -> bool {
508 let r = trim_sql_ws(rest);
509 if r.is_empty() {
510 return true;
511 }
512 const KWS: &[&[u8]] = &[
513 b"where",
514 b"group",
515 b"having",
516 b"order",
517 b"union",
518 b"except",
519 b"intersect",
520 b"for",
521 b"option",
522 b"offset",
523 ];
524 let b = r.as_bytes();
525 KWS.iter().any(|kw| sql_keyword_at(b, 0, kw))
526}
527
528impl Source for MssqlSource {
529 fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()> {
530 let built = build_export_query(request, crate::config::SourceType::Mssql);
533 let sql = built.sql.clone();
534 let overrides = request.column_overrides.clone();
535 let mut ctl =
544 AdaptiveBatchController::new(request.tuning, request.tuning.batch_size.max(1));
545 let mut cap_applied = false;
546 let lock_timeout_ms = request.tuning.lock_timeout_s.saturating_mul(1000);
556 let stmt_timeout = (request.tuning.statement_timeout_s > 0)
557 .then(|| std::time::Duration::from_secs(request.tuning.statement_timeout_s));
558
559 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
565 let decimal_hints = self.mssql_decimal_catalog_hints_opt(hint_query);
566
567 if lock_timeout_ms > 0 {
570 self.lock_timeout_applied = true;
571 }
572
573 let Self { rt, client, .. } = self;
574 rt.block_on(async {
575 use futures_util::stream::TryStreamExt;
576 use tiberius::QueryItem;
577
578 if lock_timeout_ms > 0 {
579 client
580 .execute(format!("SET LOCK_TIMEOUT {lock_timeout_ms}"), &[])
581 .await
582 .map_err(|e| anyhow::anyhow!("mssql: SET LOCK_TIMEOUT failed: {e}"))?;
583 }
584
585 let started = std::time::Instant::now();
586 let mut stream = client
587 .query(sql.as_str(), &[])
588 .await
589 .map_err(|e| anyhow::anyhow!("mssql: query failed: {e}"))?;
590
591 let mut columns: Vec<tiberius::Column> = Vec::new();
592 let mut buf: Vec<tiberius::Row> = Vec::with_capacity(ctl.target());
593 let mut schema: Option<SchemaRef> = None;
594 let max_value_bytes = request.tuning.max_value_bytes();
598
599 while let Some(item) = stream
600 .try_next()
601 .await
602 .map_err(|e| anyhow::anyhow!("mssql: streaming rows failed: {e}"))?
603 {
604 if let Some(budget) = stmt_timeout
605 && started.elapsed() > budget
606 {
607 return Err(crate::source::StatementDurationTimeout::mssql(
612 budget.as_secs(),
613 )
614 .into());
615 }
616 match item {
617 QueryItem::Metadata(meta) if columns.is_empty() => {
620 columns = meta.columns().to_vec();
621 }
622 QueryItem::Metadata(_) => {}
623 QueryItem::Row(row) => {
624 buf.push(row);
625 if buf.len() >= ctl.target() {
626 let arrow_bytes = emit_mssql_batch(
627 &columns,
628 &overrides,
629 decimal_hints.as_ref(),
630 &mut schema,
631 &buf,
632 sink,
633 max_value_bytes,
634 )?;
635 let n = buf.len();
636 buf.clear();
637 if !cap_applied && n > 0 {
642 let arrow_per_row = (arrow_bytes / n).max(64);
643 let target_mb = request
644 .tuning
645 .batch_size_memory_mb
646 .unwrap_or(DEFAULT_BATCH_TARGET_MB);
647 let safe = ((target_mb * 1024 * 1024) / arrow_per_row)
648 .max(PROBE_BATCH_SIZE);
649 if let Some(new) = ctl.apply_memory_cap(safe) {
650 log::info!(
651 "MSSQL batch cap: arrow≈{} B/row, target={} MB → batch_size → {}",
652 arrow_per_row,
653 target_mb,
654 new
655 );
656 buf.reserve(new.saturating_sub(buf.capacity()));
657 }
658 cap_applied = true;
659 }
660 ctl.after_batch(|| None);
662 ctl.throttle();
663 }
664 }
665 }
666 }
667 if !buf.is_empty() || schema.is_none() {
673 emit_mssql_batch(
674 &columns,
675 &overrides,
676 decimal_hints.as_ref(),
677 &mut schema,
678 &buf,
679 sink,
680 max_value_bytes,
681 )?;
682 }
683 Ok::<_, anyhow::Error>(())
684 })?;
685 Ok(())
686 }
687
688 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
689 let Self { rt, client, .. } = self;
690 rt.block_on(async {
691 let row = client
692 .query(sql, &[])
693 .await
694 .map_err(|e| anyhow::anyhow!("mssql: scalar query failed: {e}"))?
695 .into_row()
696 .await
697 .map_err(|e| anyhow::anyhow!("mssql: reading scalar row failed: {e}"))?;
698 Ok(row.and_then(|r| scalar_to_string(&r)))
699 })
700 }
701
702 fn type_mappings(
703 &mut self,
704 query: &str,
705 column_overrides: &ColumnOverrides,
706 ) -> Result<Vec<TypeMapping>> {
707 let wrapped = format!("SELECT * FROM ({query}) AS _rivet_q WHERE 1 = 0");
709 let overrides = column_overrides.clone();
710 let Self { rt, client, .. } = self;
711 rt.block_on(async {
712 let mut stream = client
713 .query(wrapped.as_str(), &[])
714 .await
715 .map_err(|e| anyhow::anyhow!("mssql: type-probe query failed: {e}"))?;
716 let columns = stream
717 .columns()
718 .await
719 .map_err(|e| anyhow::anyhow!("mssql: type-probe metadata failed: {e}"))?
720 .map(<[_]>::to_vec)
721 .unwrap_or_default();
722 let _ = stream.into_first_result().await;
724 Ok(arrow_convert::mssql_type_mappings(&columns, &overrides))
725 })
726 }
727
728 fn sample_pressure(&mut self) -> Option<u64> {
729 let Self { rt, client, .. } = self;
730 let sql = "SELECT SUM(cntr_value) FROM sys.dm_os_performance_counters \
740 WHERE counter_name IN ('Workfiles Created/sec', 'Worktables Created/sec')";
741 rt.block_on(async {
742 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
743 row.get::<i64, _>(0).map(|v| v.max(0) as u64)
744 })
745 }
746}
747
748fn emit_mssql_batch(
758 columns: &[tiberius::Column],
759 overrides: &ColumnOverrides,
760 decimal_hints: Option<&HashMap<String, (u8, i8)>>,
761 schema: &mut Option<SchemaRef>,
762 rows: &[tiberius::Row],
763 sink: &mut dyn BatchSink,
764 max_value_bytes: Option<usize>,
765) -> Result<usize> {
766 let schema_ref = match schema {
767 Some(s) => s.clone(),
768 None => {
769 let (built, _decoders) =
770 arrow_convert::mssql_columns_to_schema(columns, overrides, rows, decimal_hints)?;
771 let s: SchemaRef = Arc::new(built);
772 sink.on_schema(s.clone())?;
773 *schema = Some(s.clone());
774 s
775 }
776 };
777 if !rows.is_empty() {
778 let batch = arrow_convert::mssql_rows_to_record_batch(&schema_ref, rows, max_value_bytes)?;
779 let bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
780 sink.on_batch(&batch)?;
781 return Ok(bytes);
782 }
783 Ok(0)
784}
785
786fn scalar_to_string(row: &tiberius::Row) -> Option<String> {
790 use tiberius::ColumnData;
791 let cell = row.cells().next().map(|(_, d)| d)?;
792 match cell {
793 ColumnData::U8(v) => v.map(|x| x.to_string()),
794 ColumnData::I16(v) => v.map(|x| x.to_string()),
795 ColumnData::I32(v) => v.map(|x| x.to_string()),
796 ColumnData::I64(v) => v.map(|x| x.to_string()),
797 ColumnData::F32(v) => v.map(|x| x.to_string()),
798 ColumnData::F64(v) => v.map(|x| x.to_string()),
799 ColumnData::Bit(v) => v.map(|x| x.to_string()),
800 ColumnData::String(v) => v.as_ref().map(|s| s.to_string()),
801 ColumnData::Numeric(v) => v.map(|n| {
802 let raw = n.value();
804 let scale = n.scale() as usize;
805 if scale == 0 {
806 raw.to_string()
807 } else {
808 let neg = raw < 0;
809 let digits = raw.unsigned_abs().to_string();
810 let digits = format!("{digits:0>width$}", width = scale + 1);
811 let (int, frac) = digits.split_at(digits.len() - scale);
812 format!("{}{int}.{frac}", if neg { "-" } else { "" })
813 }
814 }),
815 ColumnData::Guid(v) => v.map(|g| g.to_string()),
816 other => Some(format!("{other:?}")),
817 }
818}
819
820pub(crate) fn introspect_mssql_table_for_chunking(
823 url: &str,
824 tls: Option<&TlsConfig>,
825 qualified_table: &str,
826) -> Result<TableIntrospection> {
827 let (schema, table) = match qualified_table.split_once('.') {
828 Some((s, t)) => (s.to_string(), t.to_string()),
829 None => ("dbo".to_string(), qualified_table.to_string()),
830 };
831 let mut src = MssqlSource::connect_with_tls(url, tls)?;
832
833 let count_sql = format!(
836 "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
837 JOIN sys.objects o ON o.object_id = p.object_id \
838 JOIN sys.schemas s ON s.schema_id = o.schema_id \
839 WHERE s.name = N'{}' AND o.name = N'{}' AND p.index_id IN (0,1)",
840 schema.replace('\'', "''"),
841 table.replace('\'', "''")
842 );
843 let row_estimate = src
844 .query_scalar(&count_sql)?
845 .and_then(|s| s.parse::<i64>().ok())
846 .unwrap_or(0);
847
848 let pk_sql = format!(
851 "SELECT TOP 1 c.name, t.name FROM sys.indexes i \
852 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
853 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
854 JOIN sys.types t ON t.user_type_id = c.user_type_id \
855 JOIN sys.objects o ON o.object_id = i.object_id \
856 JOIN sys.schemas s ON s.schema_id = o.schema_id \
857 WHERE i.is_primary_key = 1 AND s.name = N'{}' AND o.name = N'{}' \
858 GROUP BY c.name, t.name HAVING COUNT(*) = 1",
859 schema.replace('\'', "''"),
860 table.replace('\'', "''")
861 );
862 let keyset_sql = format!(
872 "SELECT STRING_AGG(col, CHAR(31)) WITHIN GROUP (ORDER BY is_pk DESC, col) FROM ( \
873 SELECT col, MAX(is_pk) AS is_pk FROM ( \
874 SELECT MIN(c.name) AS col, MAX(CONVERT(int, i.is_primary_key)) AS is_pk \
875 FROM sys.indexes i \
876 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.key_ordinal > 0 \
877 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
878 JOIN sys.objects o ON o.object_id = i.object_id \
879 JOIN sys.schemas s ON s.schema_id = o.schema_id \
880 WHERE i.is_unique = 1 AND c.is_nullable = 0 AND s.name = N'{}' AND o.name = N'{}' \
881 GROUP BY i.object_id, i.index_id HAVING COUNT(*) = 1 \
882 ) per_index GROUP BY col \
883 ) deduped",
884 schema.replace('\'', "''"),
885 table.replace('\'', "''")
886 );
887 let keyset_keys: Vec<String> = src
888 .query_scalar(&keyset_sql)?
889 .map(|s| {
890 s.split('\u{1f}')
891 .filter(|c| !c.is_empty())
892 .map(str::to_string)
893 .collect()
894 })
895 .unwrap_or_default();
896
897 let mut single_int_pk = None;
900 if let Some(pk_col) = src.query_scalar(&pk_sql)? {
901 let type_sql = format!(
904 "SELECT t.name FROM sys.columns c \
905 JOIN sys.types t ON t.user_type_id = c.user_type_id \
906 JOIN sys.objects o ON o.object_id = c.object_id \
907 JOIN sys.schemas s ON s.schema_id = o.schema_id \
908 WHERE s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
909 schema.replace('\'', "''"),
910 table.replace('\'', "''"),
911 pk_col.replace('\'', "''")
912 );
913 if let Some(ty) = src.query_scalar(&type_sql)?
914 && matches!(ty.as_str(), "tinyint" | "smallint" | "int" | "bigint")
915 {
916 single_int_pk = Some(pk_col);
917 }
918 }
919
920 Ok(TableIntrospection {
921 single_int_pk,
922 keyset_keys,
923 row_estimate,
924 avg_row_bytes: None,
925 })
926}
927
/// Unit tests for the pure helpers: the single-table `FROM` parser and the
/// catalog precision/scale validation. No database connection is involved.
#[cfg(test)]
mod tests {
    use super::{catalog_decimal_to_params, parse_mssql_simple_from_table};

    /// Shorthand for the parser under test.
    fn parse(q: &str) -> Option<(String, String)> {
        parse_mssql_simple_from_table(q)
    }

    /// A bare table name is resolved under the `dbo` schema.
    #[test]
    fn parse_unqualified_table_defaults_to_dbo() {
        assert_eq!(
            parse("SELECT id, amount FROM transactions ORDER BY id"),
            Some(("dbo".into(), "transactions".into()))
        );
    }

    /// `schema.table` is split into its two parts.
    #[test]
    fn parse_schema_qualified() {
        assert_eq!(
            parse("SELECT id FROM sales.orders WHERE id > 1"),
            Some(("sales".into(), "orders".into()))
        );
    }

    /// For `db.schema.table` the database part is dropped.
    #[test]
    fn parse_db_schema_table_takes_last_two() {
        assert_eq!(
            parse("SELECT * FROM mydb.sales.orders"),
            Some(("sales".into(), "orders".into()))
        );
    }

    /// Bracketed names may contain spaces; the brackets are removed.
    #[test]
    fn parse_bracketed_identifiers() {
        assert_eq!(
            parse("SELECT * FROM [my schema].[order items]"),
            Some(("my schema".into(), "order items".into()))
        );
    }

    /// An alias, with or without `AS`, does not affect the parsed table.
    #[test]
    fn parse_table_with_alias() {
        assert_eq!(
            parse("SELECT t.id FROM transactions AS t WHERE t.x = 1"),
            Some(("dbo".into(), "transactions".into()))
        );
        assert_eq!(
            parse("SELECT t.id FROM transactions t ORDER BY t.id"),
            Some(("dbo".into(), "transactions".into()))
        );
    }

    /// Joined queries are not "simple" and yield no table.
    #[test]
    fn parse_rejects_join() {
        assert_eq!(parse("SELECT * FROM a INNER JOIN b ON a.id = b.id"), None);
        assert_eq!(parse("SELECT * FROM a JOIN b ON a.id = b.id"), None);
    }

    /// Comma-separated table lists yield no table.
    #[test]
    fn parse_rejects_comma_list() {
        assert_eq!(parse("SELECT * FROM a, b WHERE a.id = b.id"), None);
    }

    /// A derived table in the `FROM` clause yields no table.
    #[test]
    fn parse_rejects_subquery_from() {
        assert_eq!(parse("SELECT * FROM (SELECT * FROM t) AS s"), None);
    }

    /// The word `from` inside string literals is not taken for the keyword.
    #[test]
    fn parse_ignores_from_inside_string_literal() {
        assert_eq!(
            parse("SELECT 'from x', amount FROM ledger WHERE note = 'paid from cash'"),
            Some(("dbo".into(), "ledger".into()))
        );
    }

    /// Precision must lie in 1..=38 and the scale must not exceed it.
    #[test]
    fn catalog_bounds_accept_well_formed_and_reject_degenerate() {
        assert_eq!(catalog_decimal_to_params(10, 2), Some((10, 2)));
        assert_eq!(catalog_decimal_to_params(38, 0), Some((38, 0)));
        assert_eq!(catalog_decimal_to_params(38, 38), Some((38, 38)));
        assert_eq!(catalog_decimal_to_params(0, 0), None);
        assert_eq!(catalog_decimal_to_params(39, 0), None);
        assert_eq!(catalog_decimal_to_params(10, 11), None);
    }
}