1mod arrow_convert;
19mod proxy;
20
21pub use proxy::MssqlProxyKind;
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use arrow::datatypes::SchemaRef;
27use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
28use tokio::net::TcpStream;
29use tokio::runtime::Runtime;
30use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
31
32use proxy::{detect_mssql_proxy_kind, warn_proxy_kind};
33
34use crate::config::{TlsConfig, TlsMode};
35use crate::error::Result;
36use crate::source::batch_controller::{
37 AdaptiveBatchController, DEFAULT_BATCH_TARGET_MB, PROBE_BATCH_SIZE,
38};
39use crate::source::query::build_export_query;
40use crate::source::{BatchSink, ExportRequest, Source, TableIntrospection};
41use crate::types::{ColumnOverrides, TypeMapping};
42
/// tiberius client running over a tokio `TcpStream`; `Compat` adapts the tokio
/// I/O traits to the futures-io traits tiberius is written against.
type MssqlClient = Client<Compat<TcpStream>>;
44
/// SQL Server source backed by a tiberius client.
///
/// tiberius is async; this type owns its own current-thread tokio runtime and
/// drives every client call through `block_on`, so the `Source` trait methods
/// stay synchronous.
pub struct MssqlSource {
    /// Runtime that executes all client futures via `block_on`.
    rt: Runtime,
    /// Logged-in tiberius connection.
    client: MssqlClient,
    /// Proxy/front-end kind detected right after login.
    proxy_kind: MssqlProxyKind,
    /// Set once an export issues `SET LOCK_TIMEOUT`; tells `Drop` to reset it.
    lock_timeout_applied: bool,
}
59
60impl Drop for MssqlSource {
61 fn drop(&mut self) {
75 if !self.lock_timeout_applied {
76 return;
77 }
78 let Self { rt, client, .. } = self;
79 let _ = rt.block_on(async {
80 tokio::time::timeout(
81 std::time::Duration::from_secs(2),
82 client.execute("SET LOCK_TIMEOUT -1", &[]),
83 )
84 .await
85 });
86 }
87}
88
/// Connection parameters parsed from a `sqlserver://` / `mssql://` URL.
struct MssqlUrl {
    host: String,
    /// TCP port; `parse_mssql_url` defaults it to 1433 when absent.
    port: u16,
    user: String,
    /// Empty when the URL has no `:password` part.
    password: String,
    /// Always non-empty: `parse_mssql_url` rejects URLs without a database.
    database: String,
}
97
98fn parse_mssql_url(url: &str) -> Result<MssqlUrl> {
99 let rest = url
100 .strip_prefix("sqlserver://")
101 .or_else(|| url.strip_prefix("mssql://"))
102 .ok_or_else(|| anyhow::anyhow!("mssql url must start with sqlserver:// — got {url}"))?;
103 let (userinfo, hostpart) = rest
106 .rsplit_once('@')
107 .ok_or_else(|| anyhow::anyhow!("mssql url missing user@host: {url}"))?;
108 let (user, password) = match userinfo.split_once(':') {
109 Some((u, p)) => (u.to_string(), p.to_string()),
110 None => (userinfo.to_string(), String::new()),
111 };
112 let (hostport, database) = hostpart
113 .split_once('/')
114 .map(|(h, d)| (h, d.to_string()))
115 .unwrap_or((hostpart, String::new()));
116 let (host, port) = match hostport.rsplit_once(':') {
117 Some((h, p)) => (
118 h.to_string(),
119 p.parse::<u16>()
120 .map_err(|_| anyhow::anyhow!("mssql url port not a number: {p}"))?,
121 ),
122 None => (hostport.to_string(), 1433),
123 };
124 if database.is_empty() {
125 anyhow::bail!("mssql url must include a database: sqlserver://user:pass@host:port/<db>");
126 }
127 Ok(MssqlUrl {
128 host,
129 port,
130 user,
131 password,
132 database,
133 })
134}
135
136impl MssqlSource {
137 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
141 crate::source::require_tls_or_loopback(url, tls)?;
147 let parts = parse_mssql_url(url)?;
148 let mut config = Config::new();
149 config.host(&parts.host);
150 config.port(parts.port);
151 config.database(&parts.database);
152 config.authentication(AuthMethod::sql_server(&parts.user, &parts.password));
153
154 config.encryption(EncryptionLevel::Required);
159 match tls {
160 Some(cfg) if cfg.mode == TlsMode::Disable || cfg.accept_invalid_certs => {
166 config.trust_cert()
167 }
168 Some(cfg) => {
169 static WEBPKI_WARNED: std::sync::Once = std::sync::Once::new();
178 WEBPKI_WARNED.call_once(|| {
179 log::warn!(
180 "mssql: TLS certificate validation is enabled, but the SQL Server \
181 engine pins an old rustls-webpki (via tiberius) with known CA \
182 name-constraint advisories (RUSTSEC-2026-0098/0099). Validation \
183 against a name-constraint-asserting private CA may accept a \
184 mis-issued certificate. Track tiberius for a rustls upgrade."
185 );
186 });
187 if let Some(ca) = cfg.ca_file.as_deref() {
188 config.trust_cert_ca(ca);
189 }
190 }
191 None => {
192 static WARNED: std::sync::Once = std::sync::Once::new();
201 WARNED.call_once(|| {
202 log::warn!(
203 "mssql: connecting with TLS certificate validation disabled \
204 (no `source.tls:` block) — the connection is encrypted but the \
205 server certificate is not verified (MITM not detected). Add \
206 `source.tls: {{ mode: verify-full, ca_file: <ca.pem> }}` to enable \
207 strict validation (or `mode: verify-ca` to skip only hostname checks)."
208 );
209 });
210 config.trust_cert();
211 }
212 }
213
214 let rt = tokio::runtime::Builder::new_current_thread()
215 .enable_all()
216 .build()
217 .map_err(|e| anyhow::anyhow!("mssql: tokio runtime build failed: {e}"))?;
218
219 let client = rt.block_on(async {
220 let tcp = TcpStream::connect(config.get_addr())
221 .await
222 .map_err(|e| anyhow::anyhow!("mssql: TCP connect failed: {e}"))?;
223 tcp.set_nodelay(true).ok();
224 Client::connect(config, tcp.compat_write())
225 .await
226 .map_err(|e| anyhow::anyhow!("mssql: login failed: {e}"))
227 })?;
228
229 let mut src = Self {
230 rt,
231 client,
232 proxy_kind: MssqlProxyKind::Direct,
233 lock_timeout_applied: false,
234 };
235 src.query_scalar("SELECT 1")?;
238 let kind = detect_mssql_proxy_kind(&src.rt, &mut src.client);
243 warn_proxy_kind(kind);
244 src.proxy_kind = kind;
245 Ok(src)
246 }
247
248 #[allow(dead_code)]
252 pub fn proxy_kind(&self) -> MssqlProxyKind {
253 self.proxy_kind
254 }
255
256 fn mssql_decimal_catalog_hints_opt(
264 &mut self,
265 query: &str,
266 ) -> Option<HashMap<String, (u8, i8)>> {
267 let (schema, table) = parse_mssql_simple_from_table(query)?;
268 match self.fetch_mssql_decimal_catalog_hints(&schema, &table) {
269 Ok(m) => m,
270 Err(e) => {
271 log::warn!(
277 "mssql decimal catalog lookup failed for {schema}.{table} — decimal scale \
278 will fall back to first-batch inference (declare it with a `columns:` \
279 override if an all-NULL first batch truncates it): {e}"
280 );
281 None
282 }
283 }
284 }
285
286 fn fetch_mssql_decimal_catalog_hints(
290 &mut self,
291 schema: &str,
292 table: &str,
293 ) -> Result<Option<HashMap<String, (u8, i8)>>> {
294 let sql = format!(
298 "SELECT c.name, c.precision, c.scale \
299 FROM sys.columns c \
300 JOIN sys.types t ON t.user_type_id = c.user_type_id \
301 JOIN sys.objects o ON o.object_id = c.object_id \
302 JOIN sys.schemas s ON s.schema_id = o.schema_id \
303 WHERE s.name = N'{}' AND o.name = N'{}' \
304 AND t.name IN ('decimal', 'numeric')",
305 schema.replace('\'', "''"),
306 table.replace('\'', "''")
307 );
308 let Self { rt, client, .. } = self;
309 let rows = rt.block_on(async {
310 client
311 .query(sql.as_str(), &[])
312 .await
313 .map_err(|e| anyhow::anyhow!("mssql: sys.columns probe failed: {e}"))?
314 .into_first_result()
315 .await
316 .map_err(|e| anyhow::anyhow!("mssql: reading sys.columns rows failed: {e}"))
317 })?;
318
319 let mut map = HashMap::new();
320 for row in &rows {
321 let name: Option<&str> = row.try_get(0).ok().flatten();
325 let precision: Option<u8> = row.try_get(1).ok().flatten();
326 let scale: Option<u8> = row.try_get(2).ok().flatten();
327 if let (Some(name), Some(p), Some(s)) = (name, precision, scale)
328 && let Some(pair) = catalog_decimal_to_params(p, s)
329 {
330 map.insert(name.to_string(), pair);
331 }
332 }
333
334 if map.is_empty() {
335 Ok(None)
336 } else {
337 log::debug!(
338 "mssql decimal catalog: resolved {} DECIMAL/NUMERIC column(s) for {schema}.{table}",
339 map.len(),
340 );
341 Ok(Some(map))
342 }
343 }
344}
345
/// Validate catalog DECIMAL/NUMERIC metadata and convert it to `(precision, scale)`.
///
/// Accepts precision in `1..=38` (SQL Server's range) with `scale <= precision`.
/// Anything else is treated as degenerate and yields `None`.
fn catalog_decimal_to_params(precision: u8, scale: u8) -> Option<(u8, i8)> {
    let precision_in_range = (1..=38).contains(&precision);
    if !precision_in_range || scale > precision {
        return None;
    }
    // Always succeeds once scale <= precision <= 38; kept checked rather than `as`.
    i8::try_from(scale).ok().map(|s| (precision, s))
}
359
360fn parse_mssql_simple_from_table(query: &str) -> Option<(String, String)> {
367 let from_idx = mssql_find_outer_from_keyword(query)?;
368 let tail = trim_sql_ws(query.get(from_idx + 4..)?);
369 let (first, after1) = parse_mssql_ident_piece(tail)?;
370 let after1 = trim_sql_ws(after1);
371 let (schema, table, after) = if after1.starts_with('.') {
373 let (second, after2) = parse_mssql_ident_piece(trim_sql_ws(after1.get(1..)?))?;
374 let after2 = trim_sql_ws(after2);
375 if after2.starts_with('.') {
376 let (third, after3) = parse_mssql_ident_piece(trim_sql_ws(after2.get(1..)?))?;
378 (second, third, trim_sql_ws(after3))
379 } else {
380 (first, second, after2)
381 }
382 } else {
383 ("dbo".to_string(), first, after1)
384 };
385 let after = skip_mssql_optional_alias(after)?;
388 if mssql_joins_or_comma(after) {
389 return None;
390 }
391 Some((schema, table))
392}
393
/// Trim SQL whitespace (space, tab, LF, CR) from both ends of `s`.
///
/// Other Unicode whitespace is intentionally left in place.
fn trim_sql_ws(s: &str) -> &str {
    const SQL_WHITESPACE: [char; 4] = [' ', '\t', '\n', '\r'];
    s.trim_matches(&SQL_WHITESPACE[..])
}
397
/// Whether byte `b` may continue a T-SQL regular identifier.
///
/// ASCII letters, digits and `_`, plus `@`, `$` and `#`, which T-SQL allows
/// after the first character. Treating `@`/`#`/`$` as identifier bytes
/// prevents words like `@from` or `x$from` from being read as the `FROM`
/// keyword. Non-ASCII bytes are conservatively treated as non-identifier.
fn is_sql_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || matches!(b, b'_' | b'@' | b'$' | b'#')
}
401
402fn sql_keyword_at(haystack: &[u8], idx: usize, kw_lower: &[u8]) -> bool {
405 let n = kw_lower.len();
406 if idx + n > haystack.len() || !haystack[idx..idx + n].eq_ignore_ascii_case(kw_lower) {
407 return false;
408 }
409 let before_ok = idx == 0 || !is_sql_ident_byte(haystack[idx - 1]);
410 let after_ok = idx + n >= haystack.len() || !is_sql_ident_byte(haystack[idx + n]);
411 before_ok && after_ok
412}
413
414fn mssql_find_outer_from_keyword(sql: &str) -> Option<usize> {
417 let b = sql.as_bytes();
418 let mut i = 0usize;
419 let mut depth = 0usize;
420 let mut in_quote = false;
421 while i < b.len() {
422 if in_quote {
423 if b[i] == b'\'' {
424 if i + 1 < b.len() && b[i + 1] == b'\'' {
425 i += 2;
426 } else {
427 in_quote = false;
428 i += 1;
429 }
430 continue;
431 }
432 i += 1;
433 continue;
434 }
435 match b[i] {
436 b'\'' => in_quote = true,
437 b'(' => depth += 1,
438 b')' => depth = depth.saturating_sub(1),
439 _ if depth == 0 && sql_keyword_at(b, i, b"from") => return Some(i),
440 _ => {}
441 }
442 i += 1;
443 }
444 None
445}
446
447fn parse_mssql_ident_piece(rest: &str) -> Option<(String, &str)> {
450 let rest = trim_sql_ws(rest);
451 if let Some(after_open) = rest.strip_prefix('[') {
452 let mut out = String::new();
453 let mut chars = after_open.chars();
454 while let Some(ch) = chars.next() {
455 if ch == ']' {
456 if chars.as_str().starts_with(']') {
457 chars.next();
458 out.push(']');
459 continue;
460 }
461 return Some((out, chars.as_str()));
462 }
463 out.push(ch);
464 }
465 return None; }
467 let bytes = rest.as_bytes();
468 if bytes.is_empty() || (!bytes[0].is_ascii_alphabetic() && bytes[0] != b'_') {
469 return None;
470 }
471 let mut i = 1usize;
472 while i < bytes.len() && is_sql_ident_byte(bytes[i]) {
473 i += 1;
474 }
475 Some((rest.get(0..i)?.to_string(), rest.get(i..)?))
476}
477
478fn mssql_joins_or_comma(rest: &str) -> bool {
481 let r = trim_sql_ws(rest);
482 if r.starts_with(',') || r.starts_with('.') {
483 return true;
484 }
485 let b = r.as_bytes();
486 ["inner", "left", "right", "full", "cross", "join"]
487 .iter()
488 .any(|kw| sql_keyword_at(b, 0, kw.as_bytes()))
489}
490
491fn skip_mssql_optional_alias(rest: &str) -> Option<&str> {
495 let rest = trim_sql_ws(rest);
496 if rest.is_empty() || mssql_starts_clause_boundary(rest) || mssql_joins_or_comma(rest) {
497 return Some(rest);
498 }
499 let mut rest = rest;
500 if sql_keyword_at(rest.as_bytes(), 0, b"as") {
501 rest = trim_sql_ws(rest.get(2..)?);
502 }
503 let (_, tail) = parse_mssql_ident_piece(rest)?;
504 Some(trim_sql_ws(tail))
505}
506
507fn mssql_starts_clause_boundary(rest: &str) -> bool {
508 let r = trim_sql_ws(rest);
509 if r.is_empty() {
510 return true;
511 }
512 const KWS: &[&[u8]] = &[
513 b"where",
514 b"group",
515 b"having",
516 b"order",
517 b"union",
518 b"except",
519 b"intersect",
520 b"for",
521 b"option",
522 b"offset",
523 ];
524 let b = r.as_bytes();
525 KWS.iter().any(|kw| sql_keyword_at(b, 0, kw))
526}
527
528impl Source for MssqlSource {
529 fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()> {
530 let built = build_export_query(request, crate::config::SourceType::Mssql);
533 let sql = built.sql.clone();
534 let overrides = request.column_overrides.clone();
535 let mut ctl =
544 AdaptiveBatchController::new(request.tuning, request.tuning.batch_size.max(1));
545 let mut cap_applied = false;
546 let lock_timeout_ms = request.tuning.lock_timeout_s.saturating_mul(1000);
556 let stmt_timeout = (request.tuning.statement_timeout_s > 0)
557 .then(|| std::time::Duration::from_secs(request.tuning.statement_timeout_s));
558
559 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
565 let decimal_hints = self.mssql_decimal_catalog_hints_opt(hint_query);
566
567 if lock_timeout_ms > 0 {
570 self.lock_timeout_applied = true;
571 }
572
573 let Self { rt, client, .. } = self;
574 rt.block_on(async {
575 use futures_util::stream::TryStreamExt;
576 use tiberius::QueryItem;
577
578 if lock_timeout_ms > 0 {
579 client
580 .execute(format!("SET LOCK_TIMEOUT {lock_timeout_ms}"), &[])
581 .await
582 .map_err(|e| anyhow::anyhow!("mssql: SET LOCK_TIMEOUT failed: {e}"))?;
583 }
584
585 let started = std::time::Instant::now();
586 let mut stream = client
587 .query(sql.as_str(), &[])
588 .await
589 .map_err(|e| anyhow::anyhow!("mssql: query failed: {e}"))?;
590
591 let mut columns: Vec<tiberius::Column> = Vec::new();
592 let mut buf: Vec<tiberius::Row> = Vec::with_capacity(ctl.target());
593 let mut schema: Option<SchemaRef> = None;
594 let max_value_bytes = request.tuning.max_value_bytes();
598
599 while let Some(item) = stream
600 .try_next()
601 .await
602 .map_err(|e| anyhow::anyhow!("mssql: streaming rows failed: {e}"))?
603 {
604 if let Some(budget) = stmt_timeout
605 && started.elapsed() > budget
606 {
607 return Err(crate::source::StatementDurationTimeout::mssql(
612 budget.as_secs(),
613 )
614 .into());
615 }
616 match item {
617 QueryItem::Metadata(meta) if columns.is_empty() => {
620 columns = meta.columns().to_vec();
621 if let Ok((provisional, _)) = arrow_convert::mssql_columns_to_schema(
632 &columns,
633 &overrides,
634 &[],
635 decimal_hints.as_ref(),
636 ) {
637 let eff = request
638 .tuning
639 .effective_batch_size(Some(&Arc::new(provisional)));
640 ctl.raise_configured_ceiling(eff);
641 }
642 }
643 QueryItem::Metadata(_) => {}
644 QueryItem::Row(row) => {
645 buf.push(row);
646 if buf.len() >= ctl.target() {
647 let arrow_bytes = emit_mssql_batch(
648 &columns,
649 &overrides,
650 decimal_hints.as_ref(),
651 &mut schema,
652 &buf,
653 sink,
654 max_value_bytes,
655 )?;
656 let n = buf.len();
657 buf.clear();
658 if !cap_applied && n > 0 {
663 let arrow_per_row = (arrow_bytes / n).max(64);
664 let target_mb = request
665 .tuning
666 .batch_size_memory_mb
667 .unwrap_or(DEFAULT_BATCH_TARGET_MB);
668 let safe = ((target_mb * 1024 * 1024) / arrow_per_row)
669 .max(PROBE_BATCH_SIZE);
670 if let Some(new) = ctl.apply_memory_cap(safe) {
671 log::info!(
672 "MSSQL batch cap: arrow≈{} B/row, target={} MB → batch_size → {}",
673 arrow_per_row,
674 target_mb,
675 new
676 );
677 buf.reserve(new.saturating_sub(buf.capacity()));
678 }
679 cap_applied = true;
680 }
681 ctl.after_batch(|| None);
683 ctl.throttle();
684 }
685 }
686 }
687 }
688 if !buf.is_empty() || schema.is_none() {
694 emit_mssql_batch(
695 &columns,
696 &overrides,
697 decimal_hints.as_ref(),
698 &mut schema,
699 &buf,
700 sink,
701 max_value_bytes,
702 )?;
703 }
704 Ok::<_, anyhow::Error>(())
705 })?;
706 Ok(())
707 }
708
709 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
710 let Self { rt, client, .. } = self;
711 rt.block_on(async {
712 let row = client
713 .query(sql, &[])
714 .await
715 .map_err(|e| anyhow::anyhow!("mssql: scalar query failed: {e}"))?
716 .into_row()
717 .await
718 .map_err(|e| anyhow::anyhow!("mssql: reading scalar row failed: {e}"))?;
719 Ok(row.and_then(|r| scalar_to_string(&r)))
720 })
721 }
722
723 fn type_mappings(
724 &mut self,
725 query: &str,
726 column_overrides: &ColumnOverrides,
727 ) -> Result<Vec<TypeMapping>> {
728 let wrapped = format!("SELECT * FROM ({query}) AS _rivet_q WHERE 1 = 0");
730 let overrides = column_overrides.clone();
731 let Self { rt, client, .. } = self;
732 rt.block_on(async {
733 let mut stream = client
734 .query(wrapped.as_str(), &[])
735 .await
736 .map_err(|e| anyhow::anyhow!("mssql: type-probe query failed: {e}"))?;
737 let columns = stream
738 .columns()
739 .await
740 .map_err(|e| anyhow::anyhow!("mssql: type-probe metadata failed: {e}"))?
741 .map(<[_]>::to_vec)
742 .unwrap_or_default();
743 let _ = stream.into_first_result().await;
745 Ok(arrow_convert::mssql_type_mappings(&columns, &overrides))
746 })
747 }
748
749 fn sample_pressure(&mut self) -> Option<u64> {
750 let Self { rt, client, .. } = self;
751 let sql = "SELECT SUM(cntr_value) FROM sys.dm_os_performance_counters \
761 WHERE counter_name IN ('Workfiles Created/sec', 'Worktables Created/sec')";
762 rt.block_on(async {
763 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
764 row.get::<i64, _>(0).map(|v| v.max(0) as u64)
765 })
766 }
767}
768
769impl MssqlSource {
770 pub(crate) fn harm_counters(&mut self) -> Option<Vec<(String, i64)>> {
777 let Self { rt, client, .. } = self;
778 let sql = "SELECT SUM(waiting_tasks_count), SUM(wait_time_ms) \
779 FROM sys.dm_os_wait_stats WHERE wait_type LIKE 'LCK%'";
780 rt.block_on(async {
781 let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
782 let waits = row.get::<i64, _>(0).unwrap_or(0);
783 let wait_ms = row.get::<i64, _>(1).unwrap_or(0);
784 Some(vec![
785 ("mssql_lock_waits".to_string(), waits),
786 ("mssql_lock_wait_ms".to_string(), wait_ms),
787 ])
788 })
789 }
790
791 pub(crate) fn has_view_server_state(&mut self) -> Option<bool> {
796 let Self { rt, client, .. } = self;
797 rt.block_on(async {
798 let row = client
799 .query(
800 "SELECT HAS_PERMS_BY_NAME(NULL, NULL, 'VIEW SERVER STATE')",
801 &[],
802 )
803 .await
804 .ok()?
805 .into_row()
806 .await
807 .ok()??;
808 row.get::<i32, _>(0).map(|v| v == 1)
809 })
810 }
811}
812
813pub(crate) fn sample_harm_counters(
816 url: &str,
817 tls: Option<&TlsConfig>,
818) -> Option<Vec<(String, i64)>> {
819 let mut src = MssqlSource::connect_with_tls(url, tls).ok()?;
820 src.harm_counters()
821}
822
823pub(crate) fn sample_view_server_state(url: &str, tls: Option<&TlsConfig>) -> Option<bool> {
828 let mut src = MssqlSource::connect_with_tls(url, tls).ok()?;
829 src.has_view_server_state()
830}
831
832fn emit_mssql_batch(
842 columns: &[tiberius::Column],
843 overrides: &ColumnOverrides,
844 decimal_hints: Option<&HashMap<String, (u8, i8)>>,
845 schema: &mut Option<SchemaRef>,
846 rows: &[tiberius::Row],
847 sink: &mut dyn BatchSink,
848 max_value_bytes: Option<usize>,
849) -> Result<usize> {
850 let schema_ref = match schema {
851 Some(s) => s.clone(),
852 None => {
853 let (built, _decoders) =
854 arrow_convert::mssql_columns_to_schema(columns, overrides, rows, decimal_hints)?;
855 let s: SchemaRef = Arc::new(built);
856 sink.on_schema(s.clone())?;
857 *schema = Some(s.clone());
858 s
859 }
860 };
861 if !rows.is_empty() {
862 let batch = arrow_convert::mssql_rows_to_record_batch(&schema_ref, rows, max_value_bytes)?;
863 let bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
864 sink.on_batch(&batch)?;
865 return Ok(bytes);
866 }
867 Ok(0)
868}
869
870fn scalar_to_string(row: &tiberius::Row) -> Option<String> {
874 use tiberius::ColumnData;
875 let cell = row.cells().next().map(|(_, d)| d)?;
876 match cell {
877 ColumnData::U8(v) => v.map(|x| x.to_string()),
878 ColumnData::I16(v) => v.map(|x| x.to_string()),
879 ColumnData::I32(v) => v.map(|x| x.to_string()),
880 ColumnData::I64(v) => v.map(|x| x.to_string()),
881 ColumnData::F32(v) => v.map(|x| x.to_string()),
882 ColumnData::F64(v) => v.map(|x| x.to_string()),
883 ColumnData::Bit(v) => v.map(|x| x.to_string()),
884 ColumnData::String(v) => v.as_ref().map(|s| s.to_string()),
885 ColumnData::Numeric(v) => v.map(|n| {
886 let raw = n.value();
888 let scale = n.scale() as usize;
889 if scale == 0 {
890 raw.to_string()
891 } else {
892 let neg = raw < 0;
893 let digits = raw.unsigned_abs().to_string();
894 let digits = format!("{digits:0>width$}", width = scale + 1);
895 let (int, frac) = digits.split_at(digits.len() - scale);
896 format!("{}{int}.{frac}", if neg { "-" } else { "" })
897 }
898 }),
899 ColumnData::Guid(v) => v.map(|g| g.to_string()),
900 other => Some(format!("{other:?}")),
901 }
902}
903
904pub(crate) fn introspect_mssql_table_for_chunking(
907 url: &str,
908 tls: Option<&TlsConfig>,
909 qualified_table: &str,
910) -> Result<TableIntrospection> {
911 let (schema, table) = match qualified_table.split_once('.') {
912 Some((s, t)) => (s.to_string(), t.to_string()),
913 None => ("dbo".to_string(), qualified_table.to_string()),
914 };
915 let mut src = MssqlSource::connect_with_tls(url, tls)?;
916
917 let count_sql = format!(
920 "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
921 JOIN sys.objects o ON o.object_id = p.object_id \
922 JOIN sys.schemas s ON s.schema_id = o.schema_id \
923 WHERE s.name = N'{}' AND o.name = N'{}' AND p.index_id IN (0,1)",
924 schema.replace('\'', "''"),
925 table.replace('\'', "''")
926 );
927 let row_estimate = src
928 .query_scalar(&count_sql)?
929 .and_then(|s| s.parse::<i64>().ok())
930 .unwrap_or(0);
931
932 let pk_sql = format!(
935 "SELECT TOP 1 c.name, t.name FROM sys.indexes i \
936 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
937 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
938 JOIN sys.types t ON t.user_type_id = c.user_type_id \
939 JOIN sys.objects o ON o.object_id = i.object_id \
940 JOIN sys.schemas s ON s.schema_id = o.schema_id \
941 WHERE i.is_primary_key = 1 AND s.name = N'{}' AND o.name = N'{}' \
942 GROUP BY c.name, t.name HAVING COUNT(*) = 1",
943 schema.replace('\'', "''"),
944 table.replace('\'', "''")
945 );
946 let keyset_sql = format!(
956 "SELECT STRING_AGG(col, CHAR(31)) WITHIN GROUP (ORDER BY is_pk DESC, col) FROM ( \
957 SELECT col, MAX(is_pk) AS is_pk FROM ( \
958 SELECT MIN(c.name) AS col, MAX(CONVERT(int, i.is_primary_key)) AS is_pk \
959 FROM sys.indexes i \
960 JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.key_ordinal > 0 \
961 JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
962 JOIN sys.objects o ON o.object_id = i.object_id \
963 JOIN sys.schemas s ON s.schema_id = o.schema_id \
964 WHERE i.is_unique = 1 AND c.is_nullable = 0 AND s.name = N'{}' AND o.name = N'{}' \
965 GROUP BY i.object_id, i.index_id HAVING COUNT(*) = 1 \
966 ) per_index GROUP BY col \
967 ) deduped",
968 schema.replace('\'', "''"),
969 table.replace('\'', "''")
970 );
971 let keyset_keys: Vec<String> = src
972 .query_scalar(&keyset_sql)?
973 .map(|s| {
974 s.split('\u{1f}')
975 .filter(|c| !c.is_empty())
976 .map(str::to_string)
977 .collect()
978 })
979 .unwrap_or_default();
980
981 let mut single_int_pk = None;
984 if let Some(pk_col) = src.query_scalar(&pk_sql)? {
985 let type_sql = format!(
988 "SELECT t.name FROM sys.columns c \
989 JOIN sys.types t ON t.user_type_id = c.user_type_id \
990 JOIN sys.objects o ON o.object_id = c.object_id \
991 JOIN sys.schemas s ON s.schema_id = o.schema_id \
992 WHERE s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
993 schema.replace('\'', "''"),
994 table.replace('\'', "''"),
995 pk_col.replace('\'', "''")
996 );
997 if let Some(ty) = src.query_scalar(&type_sql)?
998 && matches!(ty.as_str(), "tinyint" | "smallint" | "int" | "bigint")
999 {
1000 single_int_pk = Some(pk_col);
1001 }
1002 }
1003
1004 Ok(TableIntrospection {
1005 single_int_pk,
1006 keyset_keys,
1007 row_estimate,
1008 avg_row_bytes: None,
1009 })
1010}
1011
#[cfg(test)]
mod tests {
    use super::{catalog_decimal_to_params, parse_mssql_simple_from_table};

    /// Assert that `query` resolves to exactly `(schema, table)`.
    fn assert_resolves(query: &str, schema: &str, table: &str) {
        let expected = Some((schema.to_string(), table.to_string()));
        assert_eq!(parse_mssql_simple_from_table(query), expected, "query: {query}");
    }

    /// Assert that `query` is not treated as a single-table FROM.
    fn assert_unresolved(query: &str) {
        assert_eq!(parse_mssql_simple_from_table(query), None, "query: {query}");
    }

    #[test]
    fn parse_unqualified_table_defaults_to_dbo() {
        assert_resolves(
            "SELECT id, amount FROM transactions ORDER BY id",
            "dbo",
            "transactions",
        );
    }

    #[test]
    fn parse_schema_qualified() {
        assert_resolves("SELECT id FROM sales.orders WHERE id > 1", "sales", "orders");
    }

    #[test]
    fn parse_db_schema_table_takes_last_two() {
        assert_resolves("SELECT * FROM mydb.sales.orders", "sales", "orders");
    }

    #[test]
    fn parse_bracketed_identifiers() {
        assert_resolves(
            "SELECT * FROM [my schema].[order items]",
            "my schema",
            "order items",
        );
    }

    #[test]
    fn parse_table_with_alias() {
        for query in [
            "SELECT t.id FROM transactions AS t WHERE t.x = 1",
            "SELECT t.id FROM transactions t ORDER BY t.id",
        ] {
            assert_resolves(query, "dbo", "transactions");
        }
    }

    #[test]
    fn parse_rejects_join() {
        assert_unresolved("SELECT * FROM a INNER JOIN b ON a.id = b.id");
        assert_unresolved("SELECT * FROM a JOIN b ON a.id = b.id");
    }

    #[test]
    fn parse_rejects_comma_list() {
        assert_unresolved("SELECT * FROM a, b WHERE a.id = b.id");
    }

    #[test]
    fn parse_rejects_subquery_from() {
        assert_unresolved("SELECT * FROM (SELECT * FROM t) AS s");
    }

    #[test]
    fn parse_ignores_from_inside_string_literal() {
        // Both literals contain the word "from"; only the real clause counts.
        assert_resolves(
            "SELECT 'from x', amount FROM ledger WHERE note = 'paid from cash'",
            "dbo",
            "ledger",
        );
    }

    #[test]
    fn catalog_bounds_accept_well_formed_and_reject_degenerate() {
        for ((precision, scale), expected) in [((10, 2), (10, 2)), ((38, 0), (38, 0)), ((38, 38), (38, 38))] {
            assert_eq!(catalog_decimal_to_params(precision, scale), Some(expected));
        }
        for (precision, scale) in [(0, 0), (39, 0), (10, 11)] {
            assert_eq!(catalog_decimal_to_params(precision, scale), None);
        }
    }
}