Skip to main content

rivet/source/mssql/
mod.rs

1//! **Layer: Execution** — MSSQL / SQL Server source engine.
2//!
3//! Third SQL engine after PostgreSQL and MySQL. The `tiberius` driver is
4//! async (tokio); the `Source` trait is sync `&mut self` (ADR-0011), so each
5//! `MssqlSource` owns a current-thread `tokio` runtime and `block_on`s every
6//! driver call — no async leaks into the runner.
7//!
8//! Dialect deltas vs PG/MySQL (routed through the shared seams):
9//! - identifier quoting `[col]` (`sql::quote_ident`)
10//! - cursor literal `N'…'` with `''` escaping (`query::cursor_rhs`)
11//! - introspection via `sys.*` catalog views
12//!
13//! Supported today: snapshot / incremental / chunked (range + dense) and keyset
14//! (seek) export, `check --type-report`, `doctor`, chunked-mode planning. The
15//! keyset page builder emits a dialect-correct
16//! `OFFSET 0 ROWS FETCH NEXT n ROWS ONLY` clause (T-SQL has no `LIMIT`).
17
18mod arrow_convert;
19mod proxy;
20
21pub use proxy::MssqlProxyKind;
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use arrow::datatypes::SchemaRef;
27use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
28use tokio::net::TcpStream;
29use tokio::runtime::Runtime;
30use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
31
32use proxy::{detect_mssql_proxy_kind, warn_proxy_kind};
33
34use crate::config::{TlsConfig, TlsMode};
35use crate::error::Result;
36use crate::source::batch_controller::{
37    AdaptiveBatchController, DEFAULT_BATCH_TARGET_MB, PROBE_BATCH_SIZE,
38};
39use crate::source::query::build_export_query;
40use crate::source::{BatchSink, ExportRequest, Source, TableIntrospection};
41use crate::types::{ColumnOverrides, TypeMapping};
42
43type MssqlClient = Client<Compat<TcpStream>>;
44
45/// SQL Server source. Owns the async driver + the runtime that drives it.
46///
47/// `pub` (not `pub(crate)`) so integration tests can reach `proxy_kind()` the
48/// same way they reach `MysqlSource::proxy_kind()`; the rest of the type
49/// carries the same "no external API contract" disclaimer as `MysqlSource`.
50pub struct MssqlSource {
51    rt: Runtime,
52    client: MssqlClient,
53    /// Pooler/gateway classification, sampled once at connect time.
54    proxy_kind: MssqlProxyKind,
55    /// Whether the export issued `SET LOCK_TIMEOUT` on this connection, so the
56    /// `Drop` teardown knows to reset it (Epic 18 B2 — pooler-safe session).
57    lock_timeout_applied: bool,
58}
59
60impl Drop for MssqlSource {
61    /// Pooler-safe session teardown (Epic 18 B2). rivet never opens a
62    /// transaction on this connection — every read is an autocommit `SELECT`,
63    /// so there is no transaction to leave dangling across the `block_on`
64    /// bridge (ADR-0011). The only session state the export mutates is
65    /// `SET LOCK_TIMEOUT`; reset it to the SQL Server default (`-1`, wait
66    /// indefinitely) before the connection closes so a *multiplexed* pooler
67    /// that keeps the backend connection alive cannot hand our non-default
68    /// `LOCK_TIMEOUT` to the next session that reuses it.
69    ///
70    /// Best-effort and time-boxed: after a failed read the stream is
71    /// half-drained and the connection is dying anyway, so the reset (and the
72    /// physical connection) just goes away; the 2 s cap guarantees `Drop`
73    /// can never hang on a wedged connection.
74    fn drop(&mut self) {
75        if !self.lock_timeout_applied {
76            return;
77        }
78        let Self { rt, client, .. } = self;
79        let _ = rt.block_on(async {
80            tokio::time::timeout(
81                std::time::Duration::from_secs(2),
82                client.execute("SET LOCK_TIMEOUT -1", &[]),
83            )
84            .await
85        });
86    }
87}
88
89/// Parsed `sqlserver://user[:password]@host[:port]/db` connection parts.
90struct MssqlUrl {
91    host: String,
92    port: u16,
93    user: String,
94    password: String,
95    database: String,
96}
97
98fn parse_mssql_url(url: &str) -> Result<MssqlUrl> {
99    let rest = url
100        .strip_prefix("sqlserver://")
101        .or_else(|| url.strip_prefix("mssql://"))
102        .ok_or_else(|| anyhow::anyhow!("mssql url must start with sqlserver:// — got {url}"))?;
103    // userinfo @ host:port / db   (rsplit the last '@' so a '@' in a password
104    // is tolerated; '/' splits host from db).
105    let (userinfo, hostpart) = rest
106        .rsplit_once('@')
107        .ok_or_else(|| anyhow::anyhow!("mssql url missing user@host: {url}"))?;
108    let (user, password) = match userinfo.split_once(':') {
109        Some((u, p)) => (u.to_string(), p.to_string()),
110        None => (userinfo.to_string(), String::new()),
111    };
112    let (hostport, database) = hostpart
113        .split_once('/')
114        .map(|(h, d)| (h, d.to_string()))
115        .unwrap_or((hostpart, String::new()));
116    let (host, port) = match hostport.rsplit_once(':') {
117        Some((h, p)) => (
118            h.to_string(),
119            p.parse::<u16>()
120                .map_err(|_| anyhow::anyhow!("mssql url port not a number: {p}"))?,
121        ),
122        None => (hostport.to_string(), 1433),
123    };
124    if database.is_empty() {
125        anyhow::bail!("mssql url must include a database: sqlserver://user:pass@host:port/<db>");
126    }
127    Ok(MssqlUrl {
128        host,
129        port,
130        user,
131        password,
132        database,
133    })
134}
135
136impl MssqlSource {
137    /// Connect to SQL Server, honouring the shared `TlsConfig`. `url` is the
138    /// resolved `sqlserver://user:pass@host:port/db` form. A successful return
139    /// has completed a TLS login handshake and a `SELECT 1` round-trip.
140    pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
141        // Refuse trust-any-cert to a remote host with no `tls:` block before any
142        // dial (CWE-295): SQL Server always encrypts the login handshake, but
143        // with `trust_cert` that handshake is unauthenticated, so a MITM is not
144        // detected. Loopback keeps trust-cert (dev); a remote host must opt in
145        // explicitly via `tls: { mode: ... }`.
146        crate::source::require_tls_or_loopback(url, tls)?;
147        let parts = parse_mssql_url(url)?;
148        let mut config = Config::new();
149        config.host(&parts.host);
150        config.port(parts.port);
151        config.database(&parts.database);
152        config.authentication(AuthMethod::sql_server(&parts.user, &parts.password));
153
154        // SQL Server forces TLS on the login handshake regardless; map the
155        // shared TlsConfig onto tiberius' cert-trust knobs. A private CA goes
156        // through `trust_cert_ca`; otherwise dev self-signed certs need
157        // `trust_cert` (accept-invalid). Default keeps full verification.
158        config.encryption(EncryptionLevel::Required);
159        match tls {
160            // `mode: disable` is the operator's explicit opt-in to an
161            // unauthenticated (trust-any-cert) connection — the SQL Server
162            // analogue of PG/MySQL remote plaintext. It is the documented way
163            // to keep trust-cert against a remote host the gate above would
164            // otherwise have refused.
165            Some(cfg) if cfg.mode == TlsMode::Disable || cfg.accept_invalid_certs => {
166                config.trust_cert()
167            }
168            Some(cfg) => {
169                // Strict cert validation is ON here (mode verify-ca/verify-full,
170                // no accept_invalid_certs). This is the ONLY MSSQL path that
171                // exercises rustls-webpki, which is pinned to a vulnerable 0.101
172                // via tiberius 0.12 (no newer tiberius exists; see
173                // .cargo/audit.toml). The CA name-constraint advisories bite
174                // only when validating against a name-constraint-asserting
175                // private CA — narrow, but the operator who turned on strict
176                // validation is exactly who should know. Warn once.
177                static WEBPKI_WARNED: std::sync::Once = std::sync::Once::new();
178                WEBPKI_WARNED.call_once(|| {
179                    log::warn!(
180                        "mssql: TLS certificate validation is enabled, but the SQL Server \
181                         engine pins an old rustls-webpki (via tiberius) with known CA \
182                         name-constraint advisories (RUSTSEC-2026-0098/0099). Validation \
183                         against a name-constraint-asserting private CA may accept a \
184                         mis-issued certificate. Track tiberius for a rustls upgrade."
185                    );
186                });
187                if let Some(ca) = cfg.ca_file.as_deref() {
188                    config.trust_cert_ca(ca);
189                }
190            }
191            None => {
192                // Reached only for a LOOPBACK host (the gate above refuses a
193                // remote host with no `tls:` block). On loopback, tiberius
194                // trusts the server certificate without verifying issuer or
195                // hostname: the handshake is encrypted but unauthenticated. That
196                // is safe here because the bytes never leave the box, and it
197                // keeps dev / self-signed docker setups working without opt-in.
198                // Warn once, naming the config key that turns on strict
199                // validation.
200                static WARNED: std::sync::Once = std::sync::Once::new();
201                WARNED.call_once(|| {
202                    log::warn!(
203                        "mssql: connecting with TLS certificate validation disabled \
204                         (no `source.tls:` block) — the connection is encrypted but the \
205                         server certificate is not verified (MITM not detected). Add \
206                         `source.tls: {{ mode: verify-full, ca_file: <ca.pem> }}` to enable \
207                         strict validation (or `mode: verify-ca` to skip only hostname checks)."
208                    );
209                });
210                config.trust_cert();
211            }
212        }
213
214        let rt = tokio::runtime::Builder::new_current_thread()
215            .enable_all()
216            .build()
217            .map_err(|e| anyhow::anyhow!("mssql: tokio runtime build failed: {e}"))?;
218
219        let client = rt.block_on(async {
220            let tcp = TcpStream::connect(config.get_addr())
221                .await
222                .map_err(|e| anyhow::anyhow!("mssql: TCP connect failed: {e}"))?;
223            tcp.set_nodelay(true).ok();
224            Client::connect(config, tcp.compat_write())
225                .await
226                .map_err(|e| anyhow::anyhow!("mssql: login failed: {e}"))
227        })?;
228
229        let mut src = Self {
230            rt,
231            client,
232            proxy_kind: MssqlProxyKind::Direct,
233            lock_timeout_applied: false,
234        };
235        // Health round-trip — surfaces auth/permission errors at connect time
236        // (doctor relies on this).
237        src.query_scalar("SELECT 1")?;
238        // Best-effort pooler/gateway detection (mirrors PG `pg_backend_pid`
239        // drift and MySQL `CONNECTION_ID()` drift): one warning at connect
240        // time, never breaks the export. Disjoint borrows of `rt` (&) and
241        // `client` (&mut).
242        let kind = detect_mssql_proxy_kind(&src.rt, &mut src.client);
243        warn_proxy_kind(kind);
244        src.proxy_kind = kind;
245        Ok(src)
246    }
247
248    /// Expose the proxy classification for diagnostics (preflight, integration
249    /// tests). Not part of the `Source` trait — same internal-may-change
250    /// contract as the rest of `rivet::source::mssql::*`.
251    #[allow(dead_code)]
252    pub fn proxy_kind(&self) -> MssqlProxyKind {
253        self.proxy_kind
254    }
255
256    /// Declared `(precision, scale)` per decimal/numeric column, read from
257    /// `sys.columns`, for a simple single-table `SELECT … FROM [schema.]table`.
258    /// `None` for any query the FROM parser does not handle (joins, comma lists,
259    /// subqueries) or when the lookup fails — the schema builder then falls back
260    /// to data-inference, today's behaviour. Never fails the export: a lookup
261    /// error is logged (mirrors `pg_numeric_catalog_hints_opt`) and downgraded
262    /// to `None`.
263    fn mssql_decimal_catalog_hints_opt(
264        &mut self,
265        query: &str,
266    ) -> Option<HashMap<String, (u8, i8)>> {
267        let (schema, table) = parse_mssql_simple_from_table(query)?;
268        match self.fetch_mssql_decimal_catalog_hints(&schema, &table) {
269            Ok(m) => m,
270            Err(e) => {
271                // The parser identified a single-table query but the catalog
272                // lookup itself failed (permissions, gateway). Surface it —
273                // otherwise a downstream decimal scale-0 freeze on an all-NULL
274                // first batch looks like a config problem when the real cause is
275                // a missing `sys.columns` read here.
276                log::warn!(
277                    "mssql decimal catalog lookup failed for {schema}.{table} — decimal scale \
278                     will fall back to first-batch inference (declare it with a `columns:` \
279                     override if an all-NULL first batch truncates it): {e}"
280                );
281                None
282            }
283        }
284    }
285
286    /// Probe `sys.columns` for each `decimal`/`numeric` column's declared
287    /// `(precision, scale)`. Joined through `sys.schemas`/`sys.objects` so the
288    /// `(schema, table)` pair resolves the exact base table the export reads.
289    fn fetch_mssql_decimal_catalog_hints(
290        &mut self,
291        schema: &str,
292        table: &str,
293    ) -> Result<Option<HashMap<String, (u8, i8)>>> {
294        // `decimal` and `numeric` are synonyms in SQL Server and share one
295        // `sys.types` entry per scale; filter on the base type name so only
296        // fixed-point columns (not money / int / float) carry a hint.
297        let sql = format!(
298            "SELECT c.name, c.precision, c.scale \
299             FROM sys.columns c \
300             JOIN sys.types t ON t.user_type_id = c.user_type_id \
301             JOIN sys.objects o ON o.object_id = c.object_id \
302             JOIN sys.schemas s ON s.schema_id = o.schema_id \
303             WHERE s.name = N'{}' AND o.name = N'{}' \
304             AND t.name IN ('decimal', 'numeric')",
305            schema.replace('\'', "''"),
306            table.replace('\'', "''")
307        );
308        let Self { rt, client, .. } = self;
309        let rows = rt.block_on(async {
310            client
311                .query(sql.as_str(), &[])
312                .await
313                .map_err(|e| anyhow::anyhow!("mssql: sys.columns probe failed: {e}"))?
314                .into_first_result()
315                .await
316                .map_err(|e| anyhow::anyhow!("mssql: reading sys.columns rows failed: {e}"))
317        })?;
318
319        let mut map = HashMap::new();
320        for row in &rows {
321            // sys.columns: name = sysname (nvarchar), precision/scale = tinyint.
322            // `try_get` (not `get`) so an unexpected cell type downgrades to a
323            // skipped hint rather than panicking the export.
324            let name: Option<&str> = row.try_get(0).ok().flatten();
325            let precision: Option<u8> = row.try_get(1).ok().flatten();
326            let scale: Option<u8> = row.try_get(2).ok().flatten();
327            if let (Some(name), Some(p), Some(s)) = (name, precision, scale)
328                && let Some(pair) = catalog_decimal_to_params(p, s)
329            {
330                map.insert(name.to_string(), pair);
331            }
332        }
333
334        if map.is_empty() {
335            Ok(None)
336        } else {
337            log::debug!(
338                "mssql decimal catalog: resolved {} DECIMAL/NUMERIC column(s) for {schema}.{table}",
339                map.len(),
340            );
341            Ok(Some(map))
342        }
343    }
344}
345
346/// Convert `sys.columns` `(precision, scale)` into Rivet `decimal(p, s)`
347/// parameters, rejecting anything outside the bounds the YAML overrides accept.
348/// SQL Server caps precision at 38 and scale ≤ precision, so a well-formed
349/// catalog row always passes; the guard defends against a degenerate row.
350fn catalog_decimal_to_params(precision: u8, scale: u8) -> Option<(u8, i8)> {
351    if precision == 0 || precision > 38 {
352        return None;
353    }
354    if scale > precision || scale > i8::MAX as u8 {
355        return None;
356    }
357    Some((precision, scale as i8))
358}
359
360/// Extract the `(schema, table)` of a simple single-table T-SQL
361/// `SELECT … FROM [schema.]table` (no joins, no comma list, no subquery in
362/// `FROM`). Returns `None` for anything more complex — the caller falls back to
363/// data-inference rather than guessing. Schema defaults to `dbo` when the table
364/// is unqualified. Handles `[bracketed]` and bare identifiers; pure `&str`
365/// work, so it is unit-testable without a live server.
366fn parse_mssql_simple_from_table(query: &str) -> Option<(String, String)> {
367    let from_idx = mssql_find_outer_from_keyword(query)?;
368    let tail = trim_sql_ws(query.get(from_idx + 4..)?);
369    let (first, after1) = parse_mssql_ident_piece(tail)?;
370    let after1 = trim_sql_ws(after1);
371    // `schema.table` (optionally `db.schema.table` → take the last two parts).
372    let (schema, table, after) = if after1.starts_with('.') {
373        let (second, after2) = parse_mssql_ident_piece(trim_sql_ws(after1.get(1..)?))?;
374        let after2 = trim_sql_ws(after2);
375        if after2.starts_with('.') {
376            // db.schema.table — `first` is the database, drop it.
377            let (third, after3) = parse_mssql_ident_piece(trim_sql_ws(after2.get(1..)?))?;
378            (second, third, trim_sql_ws(after3))
379        } else {
380            (first, second, after2)
381        }
382    } else {
383        ("dbo".to_string(), first, after1)
384    };
385    // Reject joins / comma-lists / a trailing dotted continuation we didn't
386    // consume; only a clause boundary (WHERE/ORDER/…/end) or an alias may follow.
387    let after = skip_mssql_optional_alias(after)?;
388    if mssql_joins_or_comma(after) {
389        return None;
390    }
391    Some((schema, table))
392}
393
394fn trim_sql_ws(s: &str) -> &str {
395    s.trim_matches(|c: char| matches!(c, ' ' | '\t' | '\n' | '\r'))
396}
397
398fn is_sql_ident_byte(b: u8) -> bool {
399    b.is_ascii_alphanumeric() || b == b'_'
400}
401
402/// Case-insensitive keyword match at byte `idx` with identifier-boundary checks
403/// on both sides (so `from_x` does not match `from`).
404fn sql_keyword_at(haystack: &[u8], idx: usize, kw_lower: &[u8]) -> bool {
405    let n = kw_lower.len();
406    if idx + n > haystack.len() || !haystack[idx..idx + n].eq_ignore_ascii_case(kw_lower) {
407        return false;
408    }
409    let before_ok = idx == 0 || !is_sql_ident_byte(haystack[idx - 1]);
410    let after_ok = idx + n >= haystack.len() || !is_sql_ident_byte(haystack[idx + n]);
411    before_ok && after_ok
412}
413
414/// Byte offset of the top-level `FROM`, skipping nested parentheses
415/// (subqueries) and `'…'` string literals (with `''` escapes).
416fn mssql_find_outer_from_keyword(sql: &str) -> Option<usize> {
417    let b = sql.as_bytes();
418    let mut i = 0usize;
419    let mut depth = 0usize;
420    let mut in_quote = false;
421    while i < b.len() {
422        if in_quote {
423            if b[i] == b'\'' {
424                if i + 1 < b.len() && b[i + 1] == b'\'' {
425                    i += 2;
426                } else {
427                    in_quote = false;
428                    i += 1;
429                }
430                continue;
431            }
432            i += 1;
433            continue;
434        }
435        match b[i] {
436            b'\'' => in_quote = true,
437            b'(' => depth += 1,
438            b')' => depth = depth.saturating_sub(1),
439            _ if depth == 0 && sql_keyword_at(b, i, b"from") => return Some(i),
440            _ => {}
441        }
442        i += 1;
443    }
444    None
445}
446
447/// Parse one T-SQL identifier piece: `[bracketed name]` (with `]]` escapes) or
448/// a bare `ident`. Returns the unquoted name and the remaining tail.
449fn parse_mssql_ident_piece(rest: &str) -> Option<(String, &str)> {
450    let rest = trim_sql_ws(rest);
451    if let Some(after_open) = rest.strip_prefix('[') {
452        let mut out = String::new();
453        let mut chars = after_open.chars();
454        while let Some(ch) = chars.next() {
455            if ch == ']' {
456                if chars.as_str().starts_with(']') {
457                    chars.next();
458                    out.push(']');
459                    continue;
460                }
461                return Some((out, chars.as_str()));
462            }
463            out.push(ch);
464        }
465        return None; // unterminated bracket
466    }
467    let bytes = rest.as_bytes();
468    if bytes.is_empty() || (!bytes[0].is_ascii_alphabetic() && bytes[0] != b'_') {
469        return None;
470    }
471    let mut i = 1usize;
472    while i < bytes.len() && is_sql_ident_byte(bytes[i]) {
473        i += 1;
474    }
475    Some((rest.get(0..i)?.to_string(), rest.get(i..)?))
476}
477
478/// `true` when a join / comma-list follows the relation — the parser rejects
479/// these (catalog hints only resolve for a single base table).
480fn mssql_joins_or_comma(rest: &str) -> bool {
481    let r = trim_sql_ws(rest);
482    if r.starts_with(',') || r.starts_with('.') {
483        return true;
484    }
485    let b = r.as_bytes();
486    ["inner", "left", "right", "full", "cross", "join"]
487        .iter()
488        .any(|kw| sql_keyword_at(b, 0, kw.as_bytes()))
489}
490
491/// Consume an optional table alias (`[AS] alias`) after the relation, stopping
492/// at a clause boundary. Returns the tail after the alias, or `None` if what
493/// follows is a join/comma (so the caller rejects the query).
494fn skip_mssql_optional_alias(rest: &str) -> Option<&str> {
495    let rest = trim_sql_ws(rest);
496    if rest.is_empty() || mssql_starts_clause_boundary(rest) || mssql_joins_or_comma(rest) {
497        return Some(rest);
498    }
499    let mut rest = rest;
500    if sql_keyword_at(rest.as_bytes(), 0, b"as") {
501        rest = trim_sql_ws(rest.get(2..)?);
502    }
503    let (_, tail) = parse_mssql_ident_piece(rest)?;
504    Some(trim_sql_ws(tail))
505}
506
507fn mssql_starts_clause_boundary(rest: &str) -> bool {
508    let r = trim_sql_ws(rest);
509    if r.is_empty() {
510        return true;
511    }
512    const KWS: &[&[u8]] = &[
513        b"where",
514        b"group",
515        b"having",
516        b"order",
517        b"union",
518        b"except",
519        b"intersect",
520        b"for",
521        b"option",
522        b"offset",
523    ];
524    let b = r.as_bytes();
525    KWS.iter().any(|kw| sql_keyword_at(b, 0, kw))
526}
527
528impl Source for MssqlSource {
529    fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()> {
530        // Keyset (seek) pages build a dialect-correct
531        // `OFFSET 0 ROWS FETCH NEXT n ROWS ONLY` clause (T-SQL has no `LIMIT`).
532        let built = build_export_query(request, crate::config::SourceType::Mssql);
533        let sql = built.sql.clone();
534        let overrides = request.column_overrides.clone();
535        // Stream the result one Arrow batch at a time (peak RSS ≈ one batch,
536        // independent of `chunk_size`) through the shared `AdaptiveBatchController`
537        // — it starts at a probe size and caps the batch to a memory target once
538        // the real row width is known (the cap is computed in the loop). The SQL
539        // Server analogue of the PostgreSQL cursor's `FETCH N`. (`adaptive` resize
540        // is a no-op here: a single streaming connection can't sample DB pressure
541        // mid-stream; the OPT-2 concurrency governor handles that at the chunk
542        // layer instead.)
543        let mut ctl =
544            AdaptiveBatchController::new(request.tuning, request.tuning.batch_size.max(1));
545        let mut cap_applied = false;
546        // Source-safety knobs (parity with the PG/MySQL export loops):
547        //  - lock_timeout → server-side `SET LOCK_TIMEOUT` so a blocked read
548        //    fails fast instead of waiting on a writer's lock indefinitely.
549        //  - statement_timeout → enforced client-side: SQL Server has no
550        //    statement-duration `SET` (unlike PG's `statement_timeout` / MySQL's
551        //    `max_execution_time`), so we stop pulling and error out once the
552        //    wall-clock budget is spent. The half-drained stream is dropped with
553        //    the (errored) source, so nothing leaks.
554        //  - throttle_ms → applied by the controller between batches.
555        let lock_timeout_ms = request.tuning.lock_timeout_s.saturating_mul(1000);
556        let stmt_timeout = (request.tuning.statement_timeout_s > 0)
557            .then(|| std::time::Duration::from_secs(request.tuning.statement_timeout_s));
558
559        // Resolve declared decimal precision/scale from `sys.columns` for the
560        // *unwrapped* base query (the chunk/keyset wrapper hides the source
561        // table from the FROM parser, so resolve from the base — same restriction
562        // as PG's catalog hints). `None` ⇒ not a simple single-table SELECT, so
563        // the schema builder falls back to data-inference, today's behaviour.
564        let hint_query = request.catalog_hint_query.unwrap_or(request.query);
565        let decimal_hints = self.mssql_decimal_catalog_hints_opt(hint_query);
566
567        // Record that we are about to mutate session state so `Drop` resets it
568        // (Epic 18 B2). Set before the disjoint-borrow destructure below.
569        if lock_timeout_ms > 0 {
570            self.lock_timeout_applied = true;
571        }
572
573        let Self { rt, client, .. } = self;
574        rt.block_on(async {
575            use futures_util::stream::TryStreamExt;
576            use tiberius::QueryItem;
577
578            if lock_timeout_ms > 0 {
579                client
580                    .execute(format!("SET LOCK_TIMEOUT {lock_timeout_ms}"), &[])
581                    .await
582                    .map_err(|e| anyhow::anyhow!("mssql: SET LOCK_TIMEOUT failed: {e}"))?;
583            }
584
585            let started = std::time::Instant::now();
586            let mut stream = client
587                .query(sql.as_str(), &[])
588                .await
589                .map_err(|e| anyhow::anyhow!("mssql: query failed: {e}"))?;
590
591            let mut columns: Vec<tiberius::Column> = Vec::new();
592            let mut buf: Vec<tiberius::Row> = Vec::with_capacity(ctl.target());
593            let mut schema: Option<SchemaRef> = None;
594            // Per-value ceiling (MB→bytes; `0`/None disables), enforced
595            // pre-allocation inside the batch builder so an oversized cell bails
596            // before Arrow reserves the buffer. Same source of truth as the sink.
597            let max_value_bytes = request.tuning.max_value_bytes();
598
599            while let Some(item) = stream
600                .try_next()
601                .await
602                .map_err(|e| anyhow::anyhow!("mssql: streaming rows failed: {e}"))?
603            {
604                if let Some(budget) = stmt_timeout
605                    && started.elapsed() > budget
606                {
607                    // Typed marker (not a bare string): the retry classifier
608                    // downcasts the TYPE → permanent, so a reworded message can
609                    // never silently make this deterministic timeout retryable.
610                    // Its Display carries the same actionable hint for the user.
611                    return Err(crate::source::StatementDurationTimeout::mssql(
612                        budget.as_secs(),
613                    )
614                    .into());
615                }
616                match item {
617                    // A single SELECT yields one metadata token (the column
618                    // shape) ahead of its rows.
619                    QueryItem::Metadata(meta) if columns.is_empty() => {
620                        columns = meta.columns().to_vec();
621                        // First moment the column shape is known. SQL Server
622                        // can't seed the controller from `effective_batch_size`
623                        // up-front (the schema isn't known until now), so raise
624                        // the ceiling here — otherwise it stays pinned at the
625                        // static `batch_size` and the post-probe memory cap
626                        // (shrink-only) can never grow a narrow table's batch
627                        // past it, the way the PG/MySQL loops do. A provisional
628                        // schema (no rows) is enough for the row-byte estimate;
629                        // the real, decimal-scale-correct schema is still built
630                        // per batch in `emit_mssql_batch`.
631                        if let Ok((provisional, _)) = arrow_convert::mssql_columns_to_schema(
632                            &columns,
633                            &overrides,
634                            &[],
635                            decimal_hints.as_ref(),
636                        ) {
637                            let eff = request
638                                .tuning
639                                .effective_batch_size(Some(&Arc::new(provisional)));
640                            ctl.raise_configured_ceiling(eff);
641                        }
642                    }
643                    QueryItem::Metadata(_) => {}
644                    QueryItem::Row(row) => {
645                        buf.push(row);
646                        if buf.len() >= ctl.target() {
647                            let arrow_bytes = emit_mssql_batch(
648                                &columns,
649                                &overrides,
650                                decimal_hints.as_ref(),
651                                &mut schema,
652                                &buf,
653                                sink,
654                                max_value_bytes,
655                            )?;
656                            let n = buf.len();
657                            buf.clear();
658                            // First batch: cap to a memory target now that the
659                            // real Arrow width is known (same probe→cap the
660                            // PG/MySQL loops do, clamped to the configured
661                            // batch_size by the controller).
662                            if !cap_applied && n > 0 {
663                                let arrow_per_row = (arrow_bytes / n).max(64);
664                                let target_mb = request
665                                    .tuning
666                                    .batch_size_memory_mb
667                                    .unwrap_or(DEFAULT_BATCH_TARGET_MB);
668                                let safe = ((target_mb * 1024 * 1024) / arrow_per_row)
669                                    .max(PROBE_BATCH_SIZE);
670                                if let Some(new) = ctl.apply_memory_cap(safe) {
671                                    log::info!(
672                                        "MSSQL batch cap: arrow≈{} B/row, target={} MB → batch_size → {}",
673                                        arrow_per_row,
674                                        target_mb,
675                                        new
676                                    );
677                                    buf.reserve(new.saturating_sub(buf.capacity()));
678                                }
679                                cap_applied = true;
680                            }
681                            // adaptive no-op mid-stream (sample → None); throttle.
682                            ctl.after_batch(|| None);
683                            ctl.throttle();
684                        }
685                    }
686                }
687            }
688            // Final partial batch — or, for an empty result set, a single call
689            // that still emits the (empty) schema so the sink writes a
690            // correctly-typed empty output. Rows arrive in the query's
691            // `ORDER BY` order, so the last batch's last row carries the max
692            // cursor the sink extracts.
693            if !buf.is_empty() || schema.is_none() {
694                emit_mssql_batch(
695                    &columns,
696                    &overrides,
697                    decimal_hints.as_ref(),
698                    &mut schema,
699                    &buf,
700                    sink,
701                    max_value_bytes,
702                )?;
703            }
704            Ok::<_, anyhow::Error>(())
705        })?;
706        Ok(())
707    }
708
709    fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
710        let Self { rt, client, .. } = self;
711        rt.block_on(async {
712            let row = client
713                .query(sql, &[])
714                .await
715                .map_err(|e| anyhow::anyhow!("mssql: scalar query failed: {e}"))?
716                .into_row()
717                .await
718                .map_err(|e| anyhow::anyhow!("mssql: reading scalar row failed: {e}"))?;
719            Ok(row.and_then(|r| scalar_to_string(&r)))
720        })
721    }
722
723    fn type_mappings(
724        &mut self,
725        query: &str,
726        column_overrides: &ColumnOverrides,
727    ) -> Result<Vec<TypeMapping>> {
728        // Zero-row wrapper so the server returns column metadata without a scan.
729        let wrapped = format!("SELECT * FROM ({query}) AS _rivet_q WHERE 1 = 0");
730        let overrides = column_overrides.clone();
731        let Self { rt, client, .. } = self;
732        rt.block_on(async {
733            let mut stream = client
734                .query(wrapped.as_str(), &[])
735                .await
736                .map_err(|e| anyhow::anyhow!("mssql: type-probe query failed: {e}"))?;
737            let columns = stream
738                .columns()
739                .await
740                .map_err(|e| anyhow::anyhow!("mssql: type-probe metadata failed: {e}"))?
741                .map(<[_]>::to_vec)
742                .unwrap_or_default();
743            // Drain so the connection is reusable.
744            let _ = stream.into_first_result().await;
745            Ok(arrow_convert::mssql_type_mappings(&columns, &overrides))
746        })
747    }
748
749    fn sample_pressure(&mut self) -> Option<u64> {
750        let Self { rt, client, .. } = self;
751        // Extraction-pressure proxy (Epic 18 C2): cumulative `Workfiles Created`
752        // + `Worktables Created` (SQLServer:Access Methods). A workfile /
753        // worktable is created when a sort or hash spills to tempdb — the SQL
754        // Server analogue of PG `temp_bytes` / MySQL `Created_tmp_disk_tables`.
755        // The `cntr_value` of these `*/sec`-named perfmon counters is the raw
756        // cumulative count, so their sum is monotonic — exactly what the governor
757        // compares deltas of. Replaces `Log Flush Waits`, which is redo-**write**
758        // pressure and barely moves during a read-only export. Instance-level
759        // (no per-database `instance_name`), so no parameter is bound.
760        let sql = "SELECT SUM(cntr_value) FROM sys.dm_os_performance_counters \
761                   WHERE counter_name IN ('Workfiles Created/sec', 'Worktables Created/sec')";
762        rt.block_on(async {
763            let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
764            row.get::<i64, _>(0).map(|v| v.max(0) as u64)
765        })
766    }
767}
768
769impl MssqlSource {
770    /// Snapshot lock-wait counters from `sys.dm_os_wait_stats` (LCK_* waits) —
771    /// the SQL Server contention signal the 0.12 harm A/B tracked. This is a
772    /// server-scoped DMV: it needs `VIEW SERVER STATE`; a missing grant (or any
773    /// query error) yields `None`, so the metric is simply skipped, never failing
774    /// the export. Cumulative since server start; the pipeline deltas it around
775    /// the run.
776    pub(crate) fn harm_counters(&mut self) -> Option<Vec<(String, i64)>> {
777        let Self { rt, client, .. } = self;
778        let sql = "SELECT SUM(waiting_tasks_count), SUM(wait_time_ms) \
779                   FROM sys.dm_os_wait_stats WHERE wait_type LIKE 'LCK%'";
780        rt.block_on(async {
781            let row = client.query(sql, &[]).await.ok()?.into_row().await.ok()??;
782            let waits = row.get::<i64, _>(0).unwrap_or(0);
783            let wait_ms = row.get::<i64, _>(1).unwrap_or(0);
784            Some(vec![
785                ("mssql_lock_waits".to_string(), waits),
786                ("mssql_lock_wait_ms".to_string(), wait_ms),
787            ])
788        })
789    }
790
791    /// Does the current login hold `VIEW SERVER STATE` — the permission
792    /// [`harm_counters`] needs? `Some(true/false)` via `HAS_PERMS_BY_NAME`
793    /// (callable by any login for its own permissions, so this probe itself
794    /// never needs a grant); `None` only if even that round-trip fails.
795    pub(crate) fn has_view_server_state(&mut self) -> Option<bool> {
796        let Self { rt, client, .. } = self;
797        rt.block_on(async {
798            let row = client
799                .query(
800                    "SELECT HAS_PERMS_BY_NAME(NULL, NULL, 'VIEW SERVER STATE')",
801                    &[],
802                )
803                .await
804                .ok()?
805                .into_row()
806                .await
807                .ok()??;
808            row.get::<i32, _>(0).map(|v| v == 1)
809        })
810    }
811}
812
813/// Connect and snapshot MSSQL harm counters; see [`MssqlSource::harm_counters`].
814/// `None` on connect failure or a missing `VIEW SERVER STATE` grant.
815pub(crate) fn sample_harm_counters(
816    url: &str,
817    tls: Option<&TlsConfig>,
818) -> Option<Vec<(String, i64)>> {
819    let mut src = MssqlSource::connect_with_tls(url, tls).ok()?;
820    src.harm_counters()
821}
822
823/// Connect and check whether the login has `VIEW SERVER STATE` — used by
824/// `rivet doctor` to *advise* (never block) that source-harm metrics will be
825/// skipped without it. `None` on connect failure, in which case doctor stays
826/// silent rather than guess.
827pub(crate) fn sample_view_server_state(url: &str, tls: Option<&TlsConfig>) -> Option<bool> {
828    let mut src = MssqlSource::connect_with_tls(url, tls).ok()?;
829    src.has_view_server_state()
830}
831
832/// Emit one Arrow batch from `rows`, building (and emitting) the schema on the
833/// first call and reusing it thereafter. tiberius drops a decimal column's
834/// declared precision/scale, so the scale is recovered from the `decimal_hints`
835/// catalog lookup (the upstream, lossless source that survives an all-NULL
836/// first batch); only an expression/computed column with no catalog entry falls
837/// back to inferring the scale from the first batch's data.
838///
839/// Returns the emitted batch's Arrow memory footprint (bytes), so the export
840/// loop can size the memory cap from the real row width; `0` for an empty batch.
841fn emit_mssql_batch(
842    columns: &[tiberius::Column],
843    overrides: &ColumnOverrides,
844    decimal_hints: Option<&HashMap<String, (u8, i8)>>,
845    schema: &mut Option<SchemaRef>,
846    rows: &[tiberius::Row],
847    sink: &mut dyn BatchSink,
848    max_value_bytes: Option<usize>,
849) -> Result<usize> {
850    let schema_ref = match schema {
851        Some(s) => s.clone(),
852        None => {
853            let (built, _decoders) =
854                arrow_convert::mssql_columns_to_schema(columns, overrides, rows, decimal_hints)?;
855            let s: SchemaRef = Arc::new(built);
856            sink.on_schema(s.clone())?;
857            *schema = Some(s.clone());
858            s
859        }
860    };
861    if !rows.is_empty() {
862        let batch = arrow_convert::mssql_rows_to_record_batch(&schema_ref, rows, max_value_bytes)?;
863        let bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
864        sink.on_batch(&batch)?;
865        return Ok(bytes);
866    }
867    Ok(0)
868}
869
870/// Render a row's first column as a display string for `query_scalar`
871/// (min/max bounds, COUNT(*), SELECT 1). Covers the scalar shapes the planner
872/// asks for; richer typing flows through the export path, not here.
873fn scalar_to_string(row: &tiberius::Row) -> Option<String> {
874    use tiberius::ColumnData;
875    let cell = row.cells().next().map(|(_, d)| d)?;
876    match cell {
877        ColumnData::U8(v) => v.map(|x| x.to_string()),
878        ColumnData::I16(v) => v.map(|x| x.to_string()),
879        ColumnData::I32(v) => v.map(|x| x.to_string()),
880        ColumnData::I64(v) => v.map(|x| x.to_string()),
881        ColumnData::F32(v) => v.map(|x| x.to_string()),
882        ColumnData::F64(v) => v.map(|x| x.to_string()),
883        ColumnData::Bit(v) => v.map(|x| x.to_string()),
884        ColumnData::String(v) => v.as_ref().map(|s| s.to_string()),
885        ColumnData::Numeric(v) => v.map(|n| {
886            // unscaled value with an inserted decimal point at `scale`.
887            let raw = n.value();
888            let scale = n.scale() as usize;
889            if scale == 0 {
890                raw.to_string()
891            } else {
892                let neg = raw < 0;
893                let digits = raw.unsigned_abs().to_string();
894                let digits = format!("{digits:0>width$}", width = scale + 1);
895                let (int, frac) = digits.split_at(digits.len() - scale);
896                format!("{}{int}.{frac}", if neg { "-" } else { "" })
897            }
898        }),
899        ColumnData::Guid(v) => v.map(|g| g.to_string()),
900        other => Some(format!("{other:?}")),
901    }
902}
903
904/// Probe `sys.*` for the stats chunked-mode planning needs (ADR-0015 seam).
905/// Mirrors `introspect_pg_table_for_chunking` / `introspect_mysql_table_for_chunking`.
906pub(crate) fn introspect_mssql_table_for_chunking(
907    url: &str,
908    tls: Option<&TlsConfig>,
909    qualified_table: &str,
910) -> Result<TableIntrospection> {
911    let (schema, table) = match qualified_table.split_once('.') {
912        Some((s, t)) => (s.to_string(), t.to_string()),
913        None => ("dbo".to_string(), qualified_table.to_string()),
914    };
915    let mut src = MssqlSource::connect_with_tls(url, tls)?;
916
917    // Row estimate from `sys.dm_db_partition_stats` (rows in the heap/clustered
918    // index, index_id 0/1).
919    let count_sql = format!(
920        "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
921         JOIN sys.objects o ON o.object_id = p.object_id \
922         JOIN sys.schemas s ON s.schema_id = o.schema_id \
923         WHERE s.name = N'{}' AND o.name = N'{}' AND p.index_id IN (0,1)",
924        schema.replace('\'', "''"),
925        table.replace('\'', "''")
926    );
927    let row_estimate = src
928        .query_scalar(&count_sql)?
929        .and_then(|s| s.parse::<i64>().ok())
930        .unwrap_or(0);
931
932    // Single-column integer PK → range chunking. `sys.indexes (is_primary_key)`
933    // + one `index_columns` row + an integer base type.
934    let pk_sql = format!(
935        "SELECT TOP 1 c.name, t.name FROM sys.indexes i \
936         JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
937         JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
938         JOIN sys.types t ON t.user_type_id = c.user_type_id \
939         JOIN sys.objects o ON o.object_id = i.object_id \
940         JOIN sys.schemas s ON s.schema_id = o.schema_id \
941         WHERE i.is_primary_key = 1 AND s.name = N'{}' AND o.name = N'{}' \
942         GROUP BY c.name, t.name HAVING COUNT(*) = 1",
943        schema.replace('\'', "''"),
944        table.replace('\'', "''")
945    );
946    // Keyset keys (OPT-4) — parity with `postgres/mod.rs:314-340`: every
947    // single-column, NOT NULL, UNIQUE index (the PK *plus* any unique
948    // constraint/index), PK-first and de-duplicated, not just the PK. SQL
949    // Server: `sys.indexes.is_unique = 1`, exactly one key column
950    // (`ic.key_ordinal > 0` + `HAVING COUNT(*) = 1`), and the column is NOT NULL
951    // — so `ORDER BY key LIMIT n` is an index range scan and `WHERE key > last`
952    // never skips dup keys. Aggregated with a `CHAR(31)` (unit-separator)
953    // delimiter because the introspection seam only exposes `query_scalar`; that
954    // byte cannot appear in a real identifier, so the split is unambiguous.
955    let keyset_sql = format!(
956        "SELECT STRING_AGG(col, CHAR(31)) WITHIN GROUP (ORDER BY is_pk DESC, col) FROM ( \
957           SELECT col, MAX(is_pk) AS is_pk FROM ( \
958             SELECT MIN(c.name) AS col, MAX(CONVERT(int, i.is_primary_key)) AS is_pk \
959             FROM sys.indexes i \
960             JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.key_ordinal > 0 \
961             JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
962             JOIN sys.objects o ON o.object_id = i.object_id \
963             JOIN sys.schemas s ON s.schema_id = o.schema_id \
964             WHERE i.is_unique = 1 AND c.is_nullable = 0 AND s.name = N'{}' AND o.name = N'{}' \
965             GROUP BY i.object_id, i.index_id HAVING COUNT(*) = 1 \
966           ) per_index GROUP BY col \
967         ) deduped",
968        schema.replace('\'', "''"),
969        table.replace('\'', "''")
970    );
971    let keyset_keys: Vec<String> = src
972        .query_scalar(&keyset_sql)?
973        .map(|s| {
974            s.split('\u{1f}')
975                .filter(|c| !c.is_empty())
976                .map(str::to_string)
977                .collect()
978        })
979        .unwrap_or_default();
980
981    // Single-column integer PK → range chunking. Its own probe (the keyset list
982    // above doesn't carry the type, and range-chunk eligibility needs it).
983    let mut single_int_pk = None;
984    if let Some(pk_col) = src.query_scalar(&pk_sql)? {
985        // The scalar query returns only the column name; re-probe the type to
986        // decide range-chunk eligibility.
987        let type_sql = format!(
988            "SELECT t.name FROM sys.columns c \
989             JOIN sys.types t ON t.user_type_id = c.user_type_id \
990             JOIN sys.objects o ON o.object_id = c.object_id \
991             JOIN sys.schemas s ON s.schema_id = o.schema_id \
992             WHERE s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
993            schema.replace('\'', "''"),
994            table.replace('\'', "''"),
995            pk_col.replace('\'', "''")
996        );
997        if let Some(ty) = src.query_scalar(&type_sql)?
998            && matches!(ty.as_str(), "tinyint" | "smallint" | "int" | "bigint")
999        {
1000            single_int_pk = Some(pk_col);
1001        }
1002    }
1003
1004    Ok(TableIntrospection {
1005        single_int_pk,
1006        keyset_keys,
1007        row_estimate,
1008        avg_row_bytes: None,
1009    })
1010}
1011
#[cfg(test)]
mod tests {
    use super::{catalog_decimal_to_params, parse_mssql_simple_from_table};

    /// Owned `(schema, table)` pair in the shape the parser returns on success.
    fn resolved(schema: &str, table: &str) -> Option<(String, String)> {
        Some((schema.to_owned(), table.to_owned()))
    }

    #[test]
    fn parse_unqualified_table_defaults_to_dbo() {
        let sql = "SELECT id, amount FROM transactions ORDER BY id";
        assert_eq!(parse_mssql_simple_from_table(sql), resolved("dbo", "transactions"));
    }

    #[test]
    fn parse_schema_qualified() {
        let sql = "SELECT id FROM sales.orders WHERE id > 1";
        assert_eq!(parse_mssql_simple_from_table(sql), resolved("sales", "orders"));
    }

    #[test]
    fn parse_db_schema_table_takes_last_two() {
        let sql = "SELECT * FROM mydb.sales.orders";
        assert_eq!(parse_mssql_simple_from_table(sql), resolved("sales", "orders"));
    }

    #[test]
    fn parse_bracketed_identifiers() {
        let sql = "SELECT * FROM [my schema].[order items]";
        assert_eq!(
            parse_mssql_simple_from_table(sql),
            resolved("my schema", "order items")
        );
    }

    #[test]
    fn parse_table_with_alias() {
        // Both the `AS alias` and the bare-alias forms resolve to the base table.
        for sql in [
            "SELECT t.id FROM transactions AS t WHERE t.x = 1",
            "SELECT t.id FROM transactions t ORDER BY t.id",
        ] {
            assert_eq!(
                parse_mssql_simple_from_table(sql),
                resolved("dbo", "transactions"),
                "{sql}"
            );
        }
    }

    #[test]
    fn parse_rejects_join() {
        for sql in [
            "SELECT * FROM a INNER JOIN b ON a.id = b.id",
            "SELECT * FROM a JOIN b ON a.id = b.id",
        ] {
            assert_eq!(parse_mssql_simple_from_table(sql), None, "{sql}");
        }
    }

    #[test]
    fn parse_rejects_comma_list() {
        let sql = "SELECT * FROM a, b WHERE a.id = b.id";
        assert_eq!(parse_mssql_simple_from_table(sql), None);
    }

    #[test]
    fn parse_rejects_subquery_from() {
        let sql = "SELECT * FROM (SELECT * FROM t) AS s";
        assert_eq!(parse_mssql_simple_from_table(sql), None);
    }

    #[test]
    fn parse_ignores_from_inside_string_literal() {
        // The first top-level FROM is the real one, not the literal's bytes.
        let sql = "SELECT 'from x', amount FROM ledger WHERE note = 'paid from cash'";
        assert_eq!(parse_mssql_simple_from_table(sql), resolved("dbo", "ledger"));
    }

    #[test]
    fn catalog_bounds_accept_well_formed_and_reject_degenerate() {
        let cases = [
            // DECIMAL(10,2), the column from the original bug, rides through losslessly.
            (10, 2, Some((10, 2))),
            // SQL Server's maximum precision.
            (38, 0, Some((38, 0))),
            (38, 38, Some((38, 38))),
            // Degenerate rows are rejected (a defence against a corrupt catalog
            // row), so the builder falls back to data inference rather than
            // emitting a nonsensical decimal type.
            (0, 0, None),
            (39, 0, None),
            (10, 11, None),
        ];
        for (precision, scale, expected) in cases {
            assert_eq!(
                catalog_decimal_to_params(precision, scale),
                expected,
                "DECIMAL({precision},{scale})"
            );
        }
    }
}