1mod arrow_convert;
19pub(crate) mod cdc;
20mod proxy;
21
22pub use proxy::MssqlProxyKind;
23
24use std::collections::HashMap;
25use std::sync::Arc;
26
27use arrow::datatypes::SchemaRef;
28use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
29use tokio::net::TcpStream;
30use tokio::runtime::Runtime;
31use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
32
33use proxy::{detect_mssql_proxy_kind, warn_proxy_kind};
34
35use crate::config::{TlsConfig, TlsMode};
36use crate::error::Result;
37use crate::source::batch_controller::{
38 AdaptiveBatchController, DEFAULT_BATCH_TARGET_MB, PROBE_BATCH_SIZE,
39};
40use crate::source::query::build_export_query;
41use crate::source::{BatchSink, ExportRequest, Source, TableIntrospection};
42use crate::types::{ColumnOverrides, TypeMapping};
43
/// Concrete tiberius client type used throughout this module: a [`Client`]
/// over a Tokio [`TcpStream`] wrapped in the `tokio_util` compat adapter.
type MssqlClient = Client<Compat<TcpStream>>;
45
/// SQL Server source: one tiberius connection plus the Tokio runtime that
/// drives it.
///
/// The methods in this file run their async work to completion with
/// `rt.block_on(...)`.
pub struct MssqlSource {
    /// Runtime used to `block_on` the async calls made on `client`.
    rt: Runtime,
    /// The open SQL Server connection.
    client: MssqlClient,
    /// Proxy kind reported by the probe run in `connect_with_tls`
    /// (initialised to `Direct` before the probe runs).
    proxy_kind: MssqlProxyKind,
    /// Set by `export` when a non-zero lock timeout is requested; when set,
    /// `Drop` issues `SET LOCK_TIMEOUT -1` on the connection.
    lock_timeout_applied: bool,
}
60
61impl Drop for MssqlSource {
62 fn drop(&mut self) {
76 if !self.lock_timeout_applied {
77 return;
78 }
79 let Self { rt, client, .. } = self;
80 let _ = rt.block_on(async {
81 tokio::time::timeout(
82 std::time::Duration::from_secs(2),
83 client.execute("SET LOCK_TIMEOUT -1", &[]),
84 )
85 .await
86 });
87 }
88}
89
/// Components of a `sqlserver://user:pass@host[:port]/database` URL as split
/// by [`parse_mssql_url`].
pub(crate) struct MssqlUrl {
    /// Host portion of the URL.
    pub host: String,
    /// TCP port; `parse_mssql_url` falls back to 1433 when the URL has none.
    pub port: u16,
    /// Login name (the userinfo text before the first `:`).
    pub user: String,
    /// Password (the userinfo text after the first `:`; empty when absent).
    pub password: String,
    /// Database name; `parse_mssql_url` rejects URLs where it is empty.
    pub database: String,
}
98
99pub(crate) fn parse_mssql_url(url: &str) -> Result<MssqlUrl> {
100 let rest = url
101 .strip_prefix("sqlserver://")
102 .or_else(|| url.strip_prefix("mssql://"))
103 .ok_or_else(|| anyhow::anyhow!("mssql url must start with sqlserver:// — got {url}"))?;
104 let (userinfo, hostpart) = rest
107 .rsplit_once('@')
108 .ok_or_else(|| anyhow::anyhow!("mssql url missing user@host: {url}"))?;
109 let (user, password) = match userinfo.split_once(':') {
110 Some((u, p)) => (u.to_string(), p.to_string()),
111 None => (userinfo.to_string(), String::new()),
112 };
113 let (hostport, database) = hostpart
114 .split_once('/')
115 .map(|(h, d)| (h, d.to_string()))
116 .unwrap_or((hostpart, String::new()));
117 let (host, port) = match hostport.rsplit_once(':') {
118 Some((h, p)) => (
119 h.to_string(),
120 p.parse::<u16>()
121 .map_err(|_| anyhow::anyhow!("mssql url port not a number: {p}"))?,
122 ),
123 None => (hostport.to_string(), 1433),
124 };
125 if database.is_empty() {
126 anyhow::bail!("mssql url must include a database: sqlserver://user:pass@host:port/<db>");
127 }
128 Ok(MssqlUrl {
129 host,
130 port,
131 user,
132 password,
133 database,
134 })
135}
136
137impl MssqlSource {
138 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
142 crate::source::require_tls_or_loopback(url, tls)?;
148 let parts = parse_mssql_url(url)?;
149 let mut config = Config::new();
150 config.host(&parts.host);
151 config.port(parts.port);
152 config.database(&parts.database);
153 config.authentication(AuthMethod::sql_server(&parts.user, &parts.password));
154
155 config.encryption(EncryptionLevel::Required);
160 match tls {
161 Some(cfg) if cfg.mode == TlsMode::Disable || cfg.accept_invalid_certs => {
167 config.trust_cert()
168 }
169 Some(cfg) => {
170 static WEBPKI_WARNED: std::sync::Once = std::sync::Once::new();
179 WEBPKI_WARNED.call_once(|| {
180 log::warn!(
181 "mssql: TLS certificate validation is enabled, but the SQL Server \
182 engine pins an old rustls-webpki (via tiberius) with known CA \
183 name-constraint advisories (RUSTSEC-2026-0098/0099). Validation \
184 against a name-constraint-asserting private CA may accept a \
185 mis-issued certificate. Track tiberius for a rustls upgrade."
186 );
187 });
188 if let Some(ca) = cfg.ca_file.as_deref() {
189 config.trust_cert_ca(ca);
190 }
191 }
192 None => {
193 static WARNED: std::sync::Once = std::sync::Once::new();
202 WARNED.call_once(|| {
203 log::warn!(
204 "mssql: connecting with TLS certificate validation disabled \
205 (no `source.tls:` block) — the connection is encrypted but the \
206 server certificate is not verified (MITM not detected). Add \
207 `source.tls: {{ mode: verify-full, ca_file: <ca.pem> }}` to enable \
208 strict validation (or `mode: verify-ca` to skip only hostname checks)."
209 );
210 });
211 config.trust_cert();
212 }
213 }
214
215 let rt = tokio::runtime::Builder::new_current_thread()
216 .enable_all()
217 .build()
218 .map_err(|e| anyhow::anyhow!("mssql: tokio runtime build failed: {e}"))?;
219
220 let client = rt.block_on(async {
221 let tcp = TcpStream::connect(config.get_addr())
222 .await
223 .map_err(|e| anyhow::anyhow!("mssql: TCP connect failed: {e}"))?;
224 tcp.set_nodelay(true).ok();
225 Client::connect(config, tcp.compat_write())
226 .await
227 .map_err(|e| anyhow::anyhow!("mssql: login failed: {e}"))
228 })?;
229
230 let mut src = Self {
231 rt,
232 client,
233 proxy_kind: MssqlProxyKind::Direct,
234 lock_timeout_applied: false,
235 };
236 src.query_scalar("SELECT 1")?;
239 let kind = detect_mssql_proxy_kind(&src.rt, &mut src.client);
244 warn_proxy_kind(kind);
245 src.proxy_kind = kind;
246 Ok(src)
247 }
248
    /// Returns the proxy kind recorded by `connect_with_tls`.
    #[allow(dead_code)]
    pub fn proxy_kind(&self) -> MssqlProxyKind {
        self.proxy_kind
    }
256
257 fn mssql_decimal_catalog_hints_opt(
265 &mut self,
266 query: &str,
267 ) -> Option<HashMap<String, (u8, i8)>> {
268 let (schema, table) = parse_mssql_simple_from_table(query)?;
269 match self.fetch_mssql_decimal_catalog_hints(&schema, &table) {
270 Ok(m) => m,
271 Err(e) => {
272 log::warn!(
278 "mssql decimal catalog lookup failed for {schema}.{table} — decimal scale \
279 will fall back to first-batch inference (declare it with a `columns:` \
280 override if an all-NULL first batch truncates it): {e}"
281 );
282 None
283 }
284 }
285 }
286
287 fn fetch_mssql_decimal_catalog_hints(
291 &mut self,
292 schema: &str,
293 table: &str,
294 ) -> Result<Option<HashMap<String, (u8, i8)>>> {
295 let sql = format!(
299 "SELECT c.name, c.precision, c.scale \
300 FROM sys.columns c \
301 JOIN sys.types t ON t.user_type_id = c.user_type_id \
302 JOIN sys.objects o ON o.object_id = c.object_id \
303 JOIN sys.schemas s ON s.schema_id = o.schema_id \
304 WHERE s.name = N'{}' AND o.name = N'{}' \
305 AND t.name IN ('decimal', 'numeric')",
306 schema.replace('\'', "''"),
307 table.replace('\'', "''")
308 );
309 let Self { rt, client, .. } = self;
310 let rows = rt.block_on(async {
311 client
312 .query(sql.as_str(), &[])
313 .await
314 .map_err(|e| anyhow::anyhow!("mssql: sys.columns probe failed: {e}"))?
315 .into_first_result()
316 .await
317 .map_err(|e| anyhow::anyhow!("mssql: reading sys.columns rows failed: {e}"))
318 })?;
319
320 let mut map = HashMap::new();
321 for row in &rows {
322 let name: Option<&str> = row.try_get(0).ok().flatten();
326 let precision: Option<u8> = row.try_get(1).ok().flatten();
327 let scale: Option<u8> = row.try_get(2).ok().flatten();
328 if let (Some(name), Some(p), Some(s)) = (name, precision, scale)
329 && let Some(pair) = catalog_decimal_to_params(p, s)
330 {
331 map.insert(name.to_string(), pair);
332 }
333 }
334
335 if map.is_empty() {
336 Ok(None)
337 } else {
338 log::debug!(
339 "mssql decimal catalog: resolved {} DECIMAL/NUMERIC column(s) for {schema}.{table}",
340 map.len(),
341 );
342 Ok(Some(map))
343 }
344 }
345}
346
/// Validate a catalog `(precision, scale)` pair and convert the scale to `i8`.
///
/// Returns `None` unless `1 <= precision <= 38` and `scale <= precision`
/// (and the scale fits in an `i8`).
fn catalog_decimal_to_params(precision: u8, scale: u8) -> Option<(u8, i8)> {
    let precision_in_range = (1..=38).contains(&precision);
    if !precision_in_range || scale > precision {
        return None;
    }
    i8::try_from(scale).ok().map(|scale| (precision, scale))
}
360
361fn parse_mssql_simple_from_table(query: &str) -> Option<(String, String)> {
368 let from_idx = mssql_find_outer_from_keyword(query)?;
369 let tail = trim_sql_ws(query.get(from_idx + 4..)?);
370 let (first, after1) = parse_mssql_ident_piece(tail)?;
371 let after1 = trim_sql_ws(after1);
372 let (schema, table, after) = if after1.starts_with('.') {
374 let (second, after2) = parse_mssql_ident_piece(trim_sql_ws(after1.get(1..)?))?;
375 let after2 = trim_sql_ws(after2);
376 if after2.starts_with('.') {
377 let (third, after3) = parse_mssql_ident_piece(trim_sql_ws(after2.get(1..)?))?;
379 (second, third, trim_sql_ws(after3))
380 } else {
381 (first, second, after2)
382 }
383 } else {
384 ("dbo".to_string(), first, after1)
385 };
386 let after = skip_mssql_optional_alias(after)?;
389 if mssql_joins_or_comma(after) {
390 return None;
391 }
392 Some((schema, table))
393}
394
/// Trim spaces, tabs, line feeds and carriage returns from both ends of `s`.
/// Other whitespace characters are left in place.
fn trim_sql_ws(s: &str) -> &str {
    const SQL_WS: &[char] = &[' ', '\t', '\n', '\r'];
    s.trim_matches(SQL_WS)
}
398
/// True for bytes that may appear in a bare (unquoted) identifier as handled
/// by this parser: ASCII letters, ASCII digits and `_`.
fn is_sql_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_'
}

/// True when `kw_lower` occurs in `haystack` at byte offset `idx`
/// (ASCII case-insensitively) as a whole word, i.e. not directly preceded or
/// followed by an identifier byte.
fn sql_keyword_at(haystack: &[u8], idx: usize, kw_lower: &[u8]) -> bool {
    let n = kw_lower.len();
    if idx + n > haystack.len() || !haystack[idx..idx + n].eq_ignore_ascii_case(kw_lower) {
        return false;
    }
    let before_ok = idx == 0 || !is_sql_ident_byte(haystack[idx - 1]);
    let after_ok = idx + n >= haystack.len() || !is_sql_ident_byte(haystack[idx + n]);
    before_ok && after_ok
}

/// Byte offset of the first `FROM` keyword that sits outside parentheses.
///
/// The scan skips over regions whose content is not SQL syntax, so a `from`
/// (or a stray parenthesis or apostrophe) inside them is ignored:
/// * `'...'` string literals (`''` is an escaped quote),
/// * `"..."` and `[...]` delimited identifiers (`""` / `]]` are escapes),
/// * `-- ...` line comments and `/* ... */` block comments (nesting allowed).
///
/// Returns `None` when no such keyword exists, including when a quoted
/// region or block comment is left unterminated before one is found.
fn mssql_find_outer_from_keyword(sql: &str) -> Option<usize> {
    /// Index just past the region opened at `open` and closed by `close`,
    /// treating a doubled `close` as an escape; `b.len()` if unterminated.
    fn skip_delimited(b: &[u8], open: usize, close: u8) -> usize {
        let mut i = open + 1;
        while i < b.len() {
            if b[i] == close {
                if b.get(i + 1) == Some(&close) {
                    i += 2;
                    continue;
                }
                return i + 1;
            }
            i += 1;
        }
        b.len()
    }

    /// Index just past the block comment starting at `start` (which points
    /// at `/*`), honouring nested comments; `b.len()` if unterminated.
    fn skip_block_comment(b: &[u8], start: usize) -> usize {
        let mut i = start + 2;
        let mut nesting = 1usize;
        while i < b.len() {
            let pair = (b[i], b.get(i + 1).copied());
            if pair == (b'/', Some(b'*')) {
                nesting += 1;
                i += 2;
            } else if pair == (b'*', Some(b'/')) {
                nesting -= 1;
                i += 2;
                if nesting == 0 {
                    return i;
                }
            } else {
                i += 1;
            }
        }
        b.len()
    }

    let b = sql.as_bytes();
    let mut i = 0usize;
    let mut depth = 0usize;
    while i < b.len() {
        let next = b.get(i + 1).copied();
        match b[i] {
            b'\'' => i = skip_delimited(b, i, b'\''),
            b'"' => i = skip_delimited(b, i, b'"'),
            b'[' => i = skip_delimited(b, i, b']'),
            b'-' if next == Some(b'-') => {
                // Line comment: resume after the next newline (or at the end).
                i = b[i..]
                    .iter()
                    .position(|&c| c == b'\n')
                    .map_or(b.len(), |off| i + off + 1);
            }
            b'/' if next == Some(b'*') => i = skip_block_comment(b, i),
            b'(' => {
                depth += 1;
                i += 1;
            }
            b')' => {
                // Tolerate unbalanced ')' rather than underflowing.
                depth = depth.saturating_sub(1);
                i += 1;
            }
            _ if depth == 0 && sql_keyword_at(b, i, b"from") => return Some(i),
            _ => i += 1,
        }
    }
    None
}
447
448fn parse_mssql_ident_piece(rest: &str) -> Option<(String, &str)> {
451 let rest = trim_sql_ws(rest);
452 if let Some(after_open) = rest.strip_prefix('[') {
453 let mut out = String::new();
454 let mut chars = after_open.chars();
455 while let Some(ch) = chars.next() {
456 if ch == ']' {
457 if chars.as_str().starts_with(']') {
458 chars.next();
459 out.push(']');
460 continue;
461 }
462 return Some((out, chars.as_str()));
463 }
464 out.push(ch);
465 }
466 return None; }
468 let bytes = rest.as_bytes();
469 if bytes.is_empty() || (!bytes[0].is_ascii_alphabetic() && bytes[0] != b'_') {
470 return None;
471 }
472 let mut i = 1usize;
473 while i < bytes.len() && is_sql_ident_byte(bytes[i]) {
474 i += 1;
475 }
476 Some((rest.get(0..i)?.to_string(), rest.get(i..)?))
477}
478
479fn mssql_joins_or_comma(rest: &str) -> bool {
482 let r = trim_sql_ws(rest);
483 if r.starts_with(',') || r.starts_with('.') {
484 return true;
485 }
486 let b = r.as_bytes();
487 ["inner", "left", "right", "full", "cross", "join"]
488 .iter()
489 .any(|kw| sql_keyword_at(b, 0, kw.as_bytes()))
490}
491
492fn skip_mssql_optional_alias(rest: &str) -> Option<&str> {
496 let rest = trim_sql_ws(rest);
497 if rest.is_empty() || mssql_starts_clause_boundary(rest) || mssql_joins_or_comma(rest) {
498 return Some(rest);
499 }
500 let mut rest = rest;
501 if sql_keyword_at(rest.as_bytes(), 0, b"as") {
502 rest = trim_sql_ws(rest.get(2..)?);
503 }
504 let (_, tail) = parse_mssql_ident_piece(rest)?;
505 Some(trim_sql_ws(tail))
506}
507
508fn mssql_starts_clause_boundary(rest: &str) -> bool {
509 let r = trim_sql_ws(rest);
510 if r.is_empty() {
511 return true;
512 }
513 const KWS: &[&[u8]] = &[
514 b"where",
515 b"group",
516 b"having",
517 b"order",
518 b"union",
519 b"except",
520 b"intersect",
521 b"for",
522 b"option",
523 b"offset",
524 ];
525 let b = r.as_bytes();
526 KWS.iter().any(|kw| sql_keyword_at(b, 0, kw))
527}
528
529impl Source for MssqlSource {
530 fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()> {
531 let built = build_export_query(request, crate::config::SourceType::Mssql);
534 let sql = built.sql.clone();
535 let overrides = request.column_overrides.clone();
536 let mut ctl =
545 AdaptiveBatchController::new(request.tuning, request.tuning.batch_size.max(1));
546 let mut cap_applied = false;
547 let lock_timeout_ms = request.tuning.lock_timeout_s.saturating_mul(1000);
557 let stmt_timeout = (request.tuning.statement_timeout_s > 0)
558 .then(|| std::time::Duration::from_secs(request.tuning.statement_timeout_s));
559
560 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
566 let decimal_hints = self.mssql_decimal_catalog_hints_opt(hint_query);
567
568 if lock_timeout_ms > 0 {
571 self.lock_timeout_applied = true;
572 }
573
574 let Self { rt, client, .. } = self;
575 rt.block_on(async {
576 use futures_util::stream::TryStreamExt;
577 use tiberius::QueryItem;
578
579 if lock_timeout_ms > 0 {
580 client
581 .execute(format!("SET LOCK_TIMEOUT {lock_timeout_ms}"), &[])
582 .await
583 .map_err(|e| anyhow::anyhow!("mssql: SET LOCK_TIMEOUT failed: {e}"))?;
584 }
585
586 let started = std::time::Instant::now();
587 let mut stream = client
588 .query(sql.as_str(), &[])
589 .await
590 .map_err(|e| anyhow::anyhow!("mssql: query failed: {e}"))?;
591
592 let mut columns: Vec<tiberius::Column> = Vec::new();
593 let mut buf: Vec<tiberius::Row> = Vec::with_capacity(ctl.target());
594 let mut schema: Option<SchemaRef> = None;
595 let max_value_bytes = request.tuning.max_value_bytes();
599
600 while let Some(item) = stream
601 .try_next()
602 .await
603 .map_err(|e| anyhow::anyhow!("mssql: streaming rows failed: {e}"))?
604 {
605 if let Some(budget) = stmt_timeout
606 && started.elapsed() > budget
607 {
608 return Err(crate::source::StatementDurationTimeout::mssql(
613 budget.as_secs(),
614 )
615 .into());
616 }
617 match item {
618 QueryItem::Metadata(meta) if columns.is_empty() => {
621 columns = meta.columns().to_vec();
622 if let Ok((provisional, _)) = arrow_convert::mssql_columns_to_schema(
633 &columns,
634 &overrides,
635 &[],
636 decimal_hints.as_ref(),
637 ) {
638 let eff = request
639 .tuning
640 .effective_batch_size(Some(&Arc::new(provisional)));
641 ctl.raise_configured_ceiling(eff);
642 }
643 }
644 QueryItem::Metadata(_) => {}
645 QueryItem::Row(row) => {
646 buf.push(row);
647 if buf.len() >= ctl.target() {
648 let arrow_bytes = emit_mssql_batch(
649 &columns,
650 &overrides,
651 decimal_hints.as_ref(),
652 &mut schema,
653 &buf,
654 sink,
655 max_value_bytes,
656 )?;
657 let n = buf.len();
658 buf.clear();
659 if !cap_applied && n > 0 {
664 let arrow_per_row = (arrow_bytes / n).max(64);
665 let target_mb = request
666 .tuning
667 .batch_size_memory_mb
668 .unwrap_or(DEFAULT_BATCH_TARGET_MB);
669 let safe = ((target_mb * 1024 * 1024) / arrow_per_row)
670 .max(PROBE_BATCH_SIZE);
671 if let Some(new) = ctl.apply_memory_cap(safe) {
672 log::info!(
673 "MSSQL batch cap: arrow≈{} B/row, target={} MB → batch_size → {}",
674 arrow_per_row,
675 target_mb,
676 new
677 );
678 buf.reserve(new.saturating_sub(buf.capacity()));
679 }
680 cap_applied = true;
681 }
682 ctl.after_batch(|| None);
684 ctl.throttle(n);
685 }
686 }
687 }
688 }
689 if !buf.is_empty() || schema.is_none() {
695 emit_mssql_batch(
696 &columns,
697 &overrides,
698 decimal_hints.as_ref(),
699 &mut schema,
700 &buf,
701 sink,
702 max_value_bytes,
703 )?;
704 }
705 Ok::<_, anyhow::Error>(())
706 })?;
707 Ok(())
708 }
709
710 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
711 let Self { rt, client, .. } = self;
712 rt.block_on(async {
713 let row = client
714 .query(sql, &[])
715 .await
716 .map_err(|e| anyhow::anyhow!("mssql: scalar query failed: {e}"))?
717 .into_row()
718 .await
719 .map_err(|e| anyhow::anyhow!("mssql: reading scalar row failed: {e}"))?;
720 Ok(row.and_then(|r| scalar_to_string(&r)))
721 })
722 }
723
724 fn type_mappings(
725 &mut self,
726 query: &str,
727 column_overrides: &ColumnOverrides,
728 ) -> Result<Vec<TypeMapping>> {
729 let decimal_hints = self.mssql_decimal_catalog_hints_opt(query);
733 let wrapped = format!("SELECT * FROM ({query}) AS _rivet_q WHERE 1 = 0");
735 let overrides = column_overrides.clone();
736 let Self { rt, client, .. } = self;
737 rt.block_on(async {
738 let mut stream = client
739 .query(wrapped.as_str(), &[])
740 .await
741 .map_err(|e| anyhow::anyhow!("mssql: type-probe query failed: {e}"))?;
742 let columns = stream
743 .columns()
744 .await
745 .map_err(|e| anyhow::anyhow!("mssql: type-probe metadata failed: {e}"))?
746 .map(<[_]>::to_vec)
747 .unwrap_or_default();
748 let _ = stream.into_first_result().await;
750 Ok(arrow_convert::mssql_type_mappings(
751 &columns,
752 &overrides,
753 decimal_hints.as_ref(),
754 ))
755 })
756 }
757
758 fn sample_pressure(&mut self) -> Option<u64> {
759 let Self { rt, client, .. } = self;
760 let sql = "SELECT SUM(cntr_value) FROM sys.dm_os_performance_counters \
770 WHERE counter_name IN ('Workfiles Created/sec', 'Worktables Created/sec')";
771 rt.block_on(async {
772 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
773 row.get::<i64, _>(0).map(|v| v.max(0) as u64)
774 })
775 }
776}
777
778impl MssqlSource {
779 pub(crate) fn harm_counters(&mut self) -> Option<Vec<(String, i64)>> {
786 let Self { rt, client, .. } = self;
787 let sql = "SELECT SUM(waiting_tasks_count), SUM(wait_time_ms) \
788 FROM sys.dm_os_wait_stats WHERE wait_type LIKE 'LCK%'";
789 rt.block_on(async {
790 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
791 let waits = row.get::<i64, _>(0).unwrap_or(0);
792 let wait_ms = row.get::<i64, _>(1).unwrap_or(0);
793 Some(vec![
794 ("mssql_lock_waits".to_string(), waits),
795 ("mssql_lock_wait_ms".to_string(), wait_ms),
796 ])
797 })
798 }
799
800 pub(crate) fn has_view_server_state(&mut self) -> Option<bool> {
805 let Self { rt, client, .. } = self;
806 rt.block_on(async {
807 let row = client
808 .query(
809 "SELECT HAS_PERMS_BY_NAME(NULL, NULL, 'VIEW SERVER STATE')",
810 &[],
811 )
812 .await
813 .ok()?
814 .into_row()
815 .await
816 .ok()??;
817 row.get::<i32, _>(0).map(|v| v == 1)
818 })
819 }
820}
821
822pub(crate) fn sample_harm_counters(
825 url: &str,
826 tls: Option<&TlsConfig>,
827) -> Option<Vec<(String, i64)>> {
828 let mut src = MssqlSource::connect_with_tls(url, tls).ok()?;
829 src.harm_counters()
830}
831
832pub(crate) fn sample_view_server_state(url: &str, tls: Option<&TlsConfig>) -> Option<bool> {
837 let mut src = MssqlSource::connect_with_tls(url, tls).ok()?;
838 src.has_view_server_state()
839}
840
841fn emit_mssql_batch(
851 columns: &[tiberius::Column],
852 overrides: &ColumnOverrides,
853 decimal_hints: Option<&HashMap<String, (u8, i8)>>,
854 schema: &mut Option<SchemaRef>,
855 rows: &[tiberius::Row],
856 sink: &mut dyn BatchSink,
857 max_value_bytes: Option<usize>,
858) -> Result<usize> {
859 let schema_ref = match schema {
860 Some(s) => s.clone(),
861 None => {
862 let (built, _decoders) =
863 arrow_convert::mssql_columns_to_schema(columns, overrides, rows, decimal_hints)?;
864 let s: SchemaRef = Arc::new(built);
865 sink.on_schema(s.clone())?;
866 *schema = Some(s.clone());
867 s
868 }
869 };
870 if !rows.is_empty() {
871 let batch = arrow_convert::mssql_rows_to_record_batch(&schema_ref, rows, max_value_bytes)?;
872 let bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
873 sink.on_batch(&batch)?;
874 return Ok(bytes);
875 }
876 Ok(0)
877}
878
879fn scalar_to_string(row: &tiberius::Row) -> Option<String> {
883 use tiberius::ColumnData;
884 let cell = row.cells().next().map(|(_, d)| d)?;
885 match cell {
886 ColumnData::U8(v) => v.map(|x| x.to_string()),
887 ColumnData::I16(v) => v.map(|x| x.to_string()),
888 ColumnData::I32(v) => v.map(|x| x.to_string()),
889 ColumnData::I64(v) => v.map(|x| x.to_string()),
890 ColumnData::F32(v) => v.map(|x| x.to_string()),
891 ColumnData::F64(v) => v.map(|x| x.to_string()),
892 ColumnData::Bit(v) => v.map(|x| x.to_string()),
893 ColumnData::String(v) => v.as_ref().map(|s| s.to_string()),
894 ColumnData::Numeric(v) => v.map(|n| {
895 let raw = n.value();
897 let scale = n.scale() as usize;
898 if scale == 0 {
899 raw.to_string()
900 } else {
901 let neg = raw < 0;
902 let digits = raw.unsigned_abs().to_string();
903 let digits = format!("{digits:0>width$}", width = scale + 1);
904 let (int, frac) = digits.split_at(digits.len() - scale);
905 format!("{}{int}.{frac}", if neg { "-" } else { "" })
906 }
907 }),
908 ColumnData::Guid(v) => v.map(|g| g.to_string()),
909 other => Some(format!("{other:?}")),
910 }
911}
912
913pub(crate) fn introspect_mssql_table_for_chunking(
916 url: &str,
917 tls: Option<&TlsConfig>,
918 qualified_table: &str,
919) -> Result<TableIntrospection> {
920 let (schema, table) = match qualified_table.split_once('.') {
921 Some((s, t)) => (s.to_string(), t.to_string()),
922 None => ("dbo".to_string(), qualified_table.to_string()),
923 };
924 let mut src = MssqlSource::connect_with_tls(url, tls)?;
925
926 let count_sql = format!(
929 "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
930 JOIN sys.objects o ON o.object_id = p.object_id \
931 JOIN sys.schemas s ON s.schema_id = o.schema_id \
932 WHERE s.name = N'{}' AND o.name = N'{}' AND p.index_id IN (0,1)",
933 schema.replace('\'', "''"),
934 table.replace('\'', "''")
935 );
936 let row_estimate = src
937 .query_scalar(&count_sql)?
938 .and_then(|s| s.parse::<i64>().ok())
939 .unwrap_or(0);
940
941 let pk_sql = format!(
944 "SELECT TOP 1 c.name, t.name FROM sys.indexes i \
945 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
946 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
947 JOIN sys.types t ON t.user_type_id = c.user_type_id \
948 JOIN sys.objects o ON o.object_id = i.object_id \
949 JOIN sys.schemas s ON s.schema_id = o.schema_id \
950 WHERE i.is_primary_key = 1 AND s.name = N'{}' AND o.name = N'{}' \
951 GROUP BY c.name, t.name HAVING COUNT(*) = 1",
952 schema.replace('\'', "''"),
953 table.replace('\'', "''")
954 );
955 let keyset_sql = format!(
965 "SELECT STRING_AGG(col, CHAR(31)) WITHIN GROUP (ORDER BY is_pk DESC, col) FROM ( \
966 SELECT col, MAX(is_pk) AS is_pk FROM ( \
967 SELECT MIN(c.name) AS col, MAX(CONVERT(int, i.is_primary_key)) AS is_pk \
968 FROM sys.indexes i \
969 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.key_ordinal > 0 \
970 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
971 JOIN sys.objects o ON o.object_id = i.object_id \
972 JOIN sys.schemas s ON s.schema_id = o.schema_id \
973 WHERE i.is_unique = 1 AND c.is_nullable = 0 AND s.name = N'{}' AND o.name = N'{}' \
974 GROUP BY i.object_id, i.index_id HAVING COUNT(*) = 1 \
975 ) per_index GROUP BY col \
976 ) deduped",
977 schema.replace('\'', "''"),
978 table.replace('\'', "''")
979 );
980 let keyset_keys: Vec<String> = src
981 .query_scalar(&keyset_sql)?
982 .map(|s| {
983 s.split('\u{1f}')
984 .filter(|c| !c.is_empty())
985 .map(str::to_string)
986 .collect()
987 })
988 .unwrap_or_default();
989
990 let mut single_int_pk = None;
993 if let Some(pk_col) = src.query_scalar(&pk_sql)? {
994 let type_sql = format!(
997 "SELECT t.name FROM sys.columns c \
998 JOIN sys.types t ON t.user_type_id = c.user_type_id \
999 JOIN sys.objects o ON o.object_id = c.object_id \
1000 JOIN sys.schemas s ON s.schema_id = o.schema_id \
1001 WHERE s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
1002 schema.replace('\'', "''"),
1003 table.replace('\'', "''"),
1004 pk_col.replace('\'', "''")
1005 );
1006 if let Some(ty) = src.query_scalar(&type_sql)?
1007 && matches!(ty.as_str(), "tinyint" | "smallint" | "int" | "bigint")
1008 {
1009 single_int_pk = Some(pk_col);
1010 }
1011 }
1012
1013 Ok(TableIntrospection {
1014 single_int_pk,
1015 keyset_keys,
1016 row_estimate,
1017 avg_row_bytes: None,
1018 })
1019}
1020
/// Unit tests for the pure helpers in this module: the single-table FROM
/// parser and the decimal catalog bounds check. None of them need a server.
#[cfg(test)]
mod tests {
    use super::{catalog_decimal_to_params, parse_mssql_simple_from_table};

    /// Shorthand for the parser under test.
    fn parse(q: &str) -> Option<(String, String)> {
        parse_mssql_simple_from_table(q)
    }

    /// A bare table name is reported under the `dbo` schema.
    #[test]
    fn parse_unqualified_table_defaults_to_dbo() {
        assert_eq!(
            parse("SELECT id, amount FROM transactions ORDER BY id"),
            Some(("dbo".into(), "transactions".into()))
        );
    }

    /// `schema.table` is split into its two parts.
    #[test]
    fn parse_schema_qualified() {
        assert_eq!(
            parse("SELECT id FROM sales.orders WHERE id > 1"),
            Some(("sales".into(), "orders".into()))
        );
    }

    /// For `db.schema.table` the database part is dropped.
    #[test]
    fn parse_db_schema_table_takes_last_two() {
        assert_eq!(
            parse("SELECT * FROM mydb.sales.orders"),
            Some(("sales".into(), "orders".into()))
        );
    }

    /// Bracketed identifiers may contain spaces; the brackets are removed.
    #[test]
    fn parse_bracketed_identifiers() {
        assert_eq!(
            parse("SELECT * FROM [my schema].[order items]"),
            Some(("my schema".into(), "order items".into()))
        );
    }

    /// An alias, with or without `AS`, does not affect the result.
    #[test]
    fn parse_table_with_alias() {
        assert_eq!(
            parse("SELECT t.id FROM transactions AS t WHERE t.x = 1"),
            Some(("dbo".into(), "transactions".into()))
        );
        assert_eq!(
            parse("SELECT t.id FROM transactions t ORDER BY t.id"),
            Some(("dbo".into(), "transactions".into()))
        );
    }

    /// Queries that join another table are rejected.
    #[test]
    fn parse_rejects_join() {
        assert_eq!(parse("SELECT * FROM a INNER JOIN b ON a.id = b.id"), None);
        assert_eq!(parse("SELECT * FROM a JOIN b ON a.id = b.id"), None);
    }

    /// Comma-separated table lists are rejected.
    #[test]
    fn parse_rejects_comma_list() {
        assert_eq!(parse("SELECT * FROM a, b WHERE a.id = b.id"), None);
    }

    /// A derived table in the FROM clause is rejected.
    #[test]
    fn parse_rejects_subquery_from() {
        assert_eq!(parse("SELECT * FROM (SELECT * FROM t) AS s"), None);
    }

    /// The word `from` inside a string literal is not taken as the keyword.
    #[test]
    fn parse_ignores_from_inside_string_literal() {
        assert_eq!(
            parse("SELECT 'from x', amount FROM ledger WHERE note = 'paid from cash'"),
            Some(("dbo".into(), "ledger".into()))
        );
    }

    /// Precision must be 1..=38 and scale must not exceed precision.
    #[test]
    fn catalog_bounds_accept_well_formed_and_reject_degenerate() {
        assert_eq!(catalog_decimal_to_params(10, 2), Some((10, 2)));
        assert_eq!(catalog_decimal_to_params(38, 0), Some((38, 0)));
        assert_eq!(catalog_decimal_to_params(38, 38), Some((38, 38)));
        assert_eq!(catalog_decimal_to_params(0, 0), None);
        assert_eq!(catalog_decimal_to_params(39, 0), None);
        assert_eq!(catalog_decimal_to_params(10, 11), None);
    }
}