Skip to main content

rivet/source/mssql/
mod.rs

1//! **Layer: Execution** — MSSQL / SQL Server source engine.
2//!
3//! Third SQL engine after PostgreSQL and MySQL. The `tiberius` driver is
4//! async (tokio); the `Source` trait is sync `&mut self` (ADR-0011), so each
5//! `MssqlSource` owns a current-thread `tokio` runtime and `block_on`s every
6//! driver call — no async leaks into the runner.
7//!
8//! Dialect deltas vs PG/MySQL (routed through the shared seams):
9//! - identifier quoting `[col]` (`sql::quote_ident`)
10//! - cursor literal `N'…'` with `''` escaping (`query::cursor_rhs`)
11//! - introspection via `sys.*` catalog views
12//!
13//! Supported today: snapshot / incremental / chunked (range + dense) and keyset
14//! (seek) export, `check --type-report`, `doctor`, chunked-mode planning. The
15//! keyset page builder emits a dialect-correct
16//! `OFFSET 0 ROWS FETCH NEXT n ROWS ONLY` clause (T-SQL has no `LIMIT`).
17
18mod arrow_convert;
19mod proxy;
20
21pub use proxy::MssqlProxyKind;
22
23use std::sync::Arc;
24
25use arrow::datatypes::SchemaRef;
26use tiberius::{AuthMethod, Client, Config, EncryptionLevel};
27use tokio::net::TcpStream;
28use tokio::runtime::Runtime;
29use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
30
31use proxy::{detect_mssql_proxy_kind, warn_proxy_kind};
32
33use crate::config::TlsConfig;
34use crate::error::Result;
35use crate::source::batch_controller::{
36    AdaptiveBatchController, DEFAULT_BATCH_TARGET_MB, PROBE_BATCH_SIZE,
37};
38use crate::source::query::build_export_query;
39use crate::source::{BatchSink, ExportRequest, Source, TableIntrospection};
40use crate::types::{ColumnOverrides, TypeMapping};
41
42type MssqlClient = Client<Compat<TcpStream>>;
43
44/// SQL Server source. Owns the async driver + the runtime that drives it.
45///
46/// `pub` (not `pub(crate)`) so integration tests can reach `proxy_kind()` the
47/// same way they reach `MysqlSource::proxy_kind()`; the rest of the type
48/// carries the same "no external API contract" disclaimer as `MysqlSource`.
49pub struct MssqlSource {
50    rt: Runtime,
51    client: MssqlClient,
52    /// Connection database — the instance pressure proxy keys on it.
53    database: String,
54    /// Pooler/gateway classification, sampled once at connect time.
55    proxy_kind: MssqlProxyKind,
56}
57
/// The pieces of a `sqlserver://user[:password]@host[:port]/db` URL, as
/// produced by `parse_mssql_url`.
struct MssqlUrl {
    /// Server host name or address.
    host: String,
    /// TCP port (the parser falls back to 1433 when the URL omits it).
    port: u16,
    /// SQL login name.
    user: String,
    /// SQL login password; empty when the URL carries no `:password` part.
    password: String,
    /// Database to connect to.
    database: String,
}
66
67fn parse_mssql_url(url: &str) -> Result<MssqlUrl> {
68    let rest = url
69        .strip_prefix("sqlserver://")
70        .or_else(|| url.strip_prefix("mssql://"))
71        .ok_or_else(|| anyhow::anyhow!("mssql url must start with sqlserver:// — got {url}"))?;
72    // userinfo @ host:port / db   (rsplit the last '@' so a '@' in a password
73    // is tolerated; '/' splits host from db).
74    let (userinfo, hostpart) = rest
75        .rsplit_once('@')
76        .ok_or_else(|| anyhow::anyhow!("mssql url missing user@host: {url}"))?;
77    let (user, password) = match userinfo.split_once(':') {
78        Some((u, p)) => (u.to_string(), p.to_string()),
79        None => (userinfo.to_string(), String::new()),
80    };
81    let (hostport, database) = hostpart
82        .split_once('/')
83        .map(|(h, d)| (h, d.to_string()))
84        .unwrap_or((hostpart, String::new()));
85    let (host, port) = match hostport.rsplit_once(':') {
86        Some((h, p)) => (
87            h.to_string(),
88            p.parse::<u16>()
89                .map_err(|_| anyhow::anyhow!("mssql url port not a number: {p}"))?,
90        ),
91        None => (hostport.to_string(), 1433),
92    };
93    if database.is_empty() {
94        anyhow::bail!("mssql url must include a database: sqlserver://user:pass@host:port/<db>");
95    }
96    Ok(MssqlUrl {
97        host,
98        port,
99        user,
100        password,
101        database,
102    })
103}
104
105impl MssqlSource {
106    /// Connect to SQL Server, honouring the shared `TlsConfig`. `url` is the
107    /// resolved `sqlserver://user:pass@host:port/db` form. A successful return
108    /// has completed a TLS login handshake and a `SELECT 1` round-trip.
109    pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
110        let parts = parse_mssql_url(url)?;
111        let mut config = Config::new();
112        config.host(&parts.host);
113        config.port(parts.port);
114        config.database(&parts.database);
115        config.authentication(AuthMethod::sql_server(&parts.user, &parts.password));
116
117        // SQL Server forces TLS on the login handshake regardless; map the
118        // shared TlsConfig onto tiberius' cert-trust knobs. A private CA goes
119        // through `trust_cert_ca`; otherwise dev self-signed certs need
120        // `trust_cert` (accept-invalid). Default keeps full verification.
121        config.encryption(EncryptionLevel::Required);
122        match tls {
123            Some(cfg) if cfg.accept_invalid_certs => config.trust_cert(),
124            Some(cfg) => {
125                if let Some(ca) = cfg.ca_file.as_deref() {
126                    config.trust_cert_ca(ca);
127                }
128            }
129            None => config.trust_cert(),
130        }
131
132        let rt = tokio::runtime::Builder::new_current_thread()
133            .enable_all()
134            .build()
135            .map_err(|e| anyhow::anyhow!("mssql: tokio runtime build failed: {e}"))?;
136
137        let client = rt.block_on(async {
138            let tcp = TcpStream::connect(config.get_addr())
139                .await
140                .map_err(|e| anyhow::anyhow!("mssql: TCP connect failed: {e}"))?;
141            tcp.set_nodelay(true).ok();
142            Client::connect(config, tcp.compat_write())
143                .await
144                .map_err(|e| anyhow::anyhow!("mssql: login failed: {e}"))
145        })?;
146
147        let mut src = Self {
148            rt,
149            client,
150            database: parts.database,
151            proxy_kind: MssqlProxyKind::Direct,
152        };
153        // Health round-trip — surfaces auth/permission errors at connect time
154        // (doctor relies on this).
155        src.query_scalar("SELECT 1")?;
156        // Best-effort pooler/gateway detection (mirrors PG `pg_backend_pid`
157        // drift and MySQL `CONNECTION_ID()` drift): one warning at connect
158        // time, never breaks the export. Disjoint borrows of `rt` (&) and
159        // `client` (&mut).
160        let kind = detect_mssql_proxy_kind(&src.rt, &mut src.client);
161        warn_proxy_kind(kind);
162        src.proxy_kind = kind;
163        Ok(src)
164    }
165
166    /// Expose the proxy classification for diagnostics (preflight, integration
167    /// tests). Not part of the `Source` trait — same internal-may-change
168    /// contract as the rest of `rivet::source::mssql::*`.
169    #[allow(dead_code)]
170    pub fn proxy_kind(&self) -> MssqlProxyKind {
171        self.proxy_kind
172    }
173}
174
175impl Source for MssqlSource {
176    fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()> {
177        // Keyset (seek) pages build a dialect-correct
178        // `OFFSET 0 ROWS FETCH NEXT n ROWS ONLY` clause (T-SQL has no `LIMIT`).
179        let built = build_export_query(request, crate::config::SourceType::Mssql);
180        let sql = built.sql.clone();
181        let overrides = request.column_overrides.clone();
182        // Stream the result one Arrow batch at a time (peak RSS ≈ one batch,
183        // independent of `chunk_size`) through the shared `AdaptiveBatchController`
184        // — it starts at a probe size and caps the batch to a memory target once
185        // the real row width is known (the cap is computed in the loop). The SQL
186        // Server analogue of the PostgreSQL cursor's `FETCH N`. (`adaptive` resize
187        // is a no-op here: a single streaming connection can't sample DB pressure
188        // mid-stream; the OPT-2 concurrency governor handles that at the chunk
189        // layer instead.)
190        let mut ctl =
191            AdaptiveBatchController::new(request.tuning, request.tuning.batch_size.max(1));
192        let mut cap_applied = false;
193        // Source-safety knobs (parity with the PG/MySQL export loops):
194        //  - lock_timeout → server-side `SET LOCK_TIMEOUT` so a blocked read
195        //    fails fast instead of waiting on a writer's lock indefinitely.
196        //  - statement_timeout → enforced client-side: SQL Server has no
197        //    statement-duration `SET` (unlike PG's `statement_timeout` / MySQL's
198        //    `max_execution_time`), so we stop pulling and error out once the
199        //    wall-clock budget is spent. The half-drained stream is dropped with
200        //    the (errored) source, so nothing leaks.
201        //  - throttle_ms → applied by the controller between batches.
202        let lock_timeout_ms = request.tuning.lock_timeout_s.saturating_mul(1000);
203        let stmt_timeout = (request.tuning.statement_timeout_s > 0)
204            .then(|| std::time::Duration::from_secs(request.tuning.statement_timeout_s));
205
206        let Self { rt, client, .. } = self;
207        rt.block_on(async {
208            use futures_util::stream::TryStreamExt;
209            use tiberius::QueryItem;
210
211            if lock_timeout_ms > 0 {
212                client
213                    .execute(format!("SET LOCK_TIMEOUT {lock_timeout_ms}"), &[])
214                    .await
215                    .map_err(|e| anyhow::anyhow!("mssql: SET LOCK_TIMEOUT failed: {e}"))?;
216            }
217
218            let started = std::time::Instant::now();
219            let mut stream = client
220                .query(sql.as_str(), &[])
221                .await
222                .map_err(|e| anyhow::anyhow!("mssql: query failed: {e}"))?;
223
224            let mut columns: Vec<tiberius::Column> = Vec::new();
225            let mut buf: Vec<tiberius::Row> = Vec::with_capacity(ctl.target());
226            let mut schema: Option<SchemaRef> = None;
227
228            while let Some(item) = stream
229                .try_next()
230                .await
231                .map_err(|e| anyhow::anyhow!("mssql: streaming rows failed: {e}"))?
232            {
233                if let Some(budget) = stmt_timeout
234                    && started.elapsed() > budget
235                {
236                    anyhow::bail!(
237                        "mssql: statement timeout after {}s (tuning.statement_timeout_s)",
238                        budget.as_secs()
239                    );
240                }
241                match item {
242                    // A single SELECT yields one metadata token (the column
243                    // shape) ahead of its rows.
244                    QueryItem::Metadata(meta) if columns.is_empty() => {
245                        columns = meta.columns().to_vec();
246                    }
247                    QueryItem::Metadata(_) => {}
248                    QueryItem::Row(row) => {
249                        buf.push(row);
250                        if buf.len() >= ctl.target() {
251                            let arrow_bytes =
252                                emit_mssql_batch(&columns, &overrides, &mut schema, &buf, sink)?;
253                            let n = buf.len();
254                            buf.clear();
255                            // First batch: cap to a memory target now that the
256                            // real Arrow width is known (same probe→cap the
257                            // PG/MySQL loops do, clamped to the configured
258                            // batch_size by the controller).
259                            if !cap_applied && n > 0 {
260                                let arrow_per_row = (arrow_bytes / n).max(64);
261                                let target_mb = request
262                                    .tuning
263                                    .batch_size_memory_mb
264                                    .unwrap_or(DEFAULT_BATCH_TARGET_MB);
265                                let safe = ((target_mb * 1024 * 1024) / arrow_per_row)
266                                    .max(PROBE_BATCH_SIZE);
267                                if let Some(new) = ctl.apply_memory_cap(safe) {
268                                    log::info!(
269                                        "MSSQL batch cap: arrow≈{} B/row, target={} MB → batch_size → {}",
270                                        arrow_per_row,
271                                        target_mb,
272                                        new
273                                    );
274                                    buf.reserve(new.saturating_sub(buf.capacity()));
275                                }
276                                cap_applied = true;
277                            }
278                            // adaptive no-op mid-stream (sample → None); throttle.
279                            ctl.after_batch(|| None);
280                            ctl.throttle();
281                        }
282                    }
283                }
284            }
285            // Final partial batch — or, for an empty result set, a single call
286            // that still emits the (empty) schema so the sink writes a
287            // correctly-typed empty output. Rows arrive in the query's
288            // `ORDER BY` order, so the last batch's last row carries the max
289            // cursor the sink extracts.
290            if !buf.is_empty() || schema.is_none() {
291                emit_mssql_batch(&columns, &overrides, &mut schema, &buf, sink)?;
292            }
293            Ok::<_, anyhow::Error>(())
294        })?;
295        Ok(())
296    }
297
298    fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
299        let Self { rt, client, .. } = self;
300        rt.block_on(async {
301            let row = client
302                .query(sql, &[])
303                .await
304                .map_err(|e| anyhow::anyhow!("mssql: scalar query failed: {e}"))?
305                .into_row()
306                .await
307                .map_err(|e| anyhow::anyhow!("mssql: reading scalar row failed: {e}"))?;
308            Ok(row.and_then(|r| scalar_to_string(&r)))
309        })
310    }
311
312    fn type_mappings(
313        &mut self,
314        query: &str,
315        column_overrides: &ColumnOverrides,
316    ) -> Result<Vec<TypeMapping>> {
317        // Zero-row wrapper so the server returns column metadata without a scan.
318        let wrapped = format!("SELECT * FROM ({query}) AS _rivet_q WHERE 1 = 0");
319        let overrides = column_overrides.clone();
320        let Self { rt, client, .. } = self;
321        rt.block_on(async {
322            let mut stream = client
323                .query(wrapped.as_str(), &[])
324                .await
325                .map_err(|e| anyhow::anyhow!("mssql: type-probe query failed: {e}"))?;
326            let columns = stream
327                .columns()
328                .await
329                .map_err(|e| anyhow::anyhow!("mssql: type-probe metadata failed: {e}"))?
330                .map(<[_]>::to_vec)
331                .unwrap_or_default();
332            // Drain so the connection is reusable.
333            let _ = stream.into_first_result().await;
334            Ok(arrow_convert::mssql_type_mappings(&columns, &overrides))
335        })
336    }
337
338    fn sample_pressure(&mut self) -> Option<u64> {
339        let db = self.database.clone();
340        let Self { rt, client, .. } = self;
341        // `Log Flush Waits` is the SQL Server analog of MySQL `Innodb_log_waits`:
342        // a writer stalled waiting on the log = source write pressure. The
343        // `cntr_value` of a `*/sec` perfmon counter is the raw cumulative count,
344        // so it is monotonic — exactly what the governor compares deltas of.
345        let sql = "SELECT cntr_value FROM sys.dm_os_performance_counters \
346                   WHERE counter_name LIKE 'Log Flush Wait%' AND instance_name = @P1";
347        rt.block_on(async {
348            let row = client
349                .query(sql, &[&db])
350                .await
351                .ok()?
352                .into_row()
353                .await
354                .ok()??;
355            row.get::<i64, _>(0).map(|v| v.max(0) as u64)
356        })
357    }
358}
359
360/// Emit one Arrow batch from `rows`, building (and emitting) the schema on the
361/// first call and reusing it thereafter. Decimal scales are recovered from the
362/// data — tiberius drops a column's declared precision/scale — so the first
363/// batch must carry each decimal column's first non-null value (true for every
364/// table in practice; a decimal column NULL for the whole first batch falls back
365/// to scale 0, same as the pre-streaming behaviour on an all-null column).
366///
367/// Returns the emitted batch's Arrow memory footprint (bytes), so the export
368/// loop can size the memory cap from the real row width; `0` for an empty batch.
369fn emit_mssql_batch(
370    columns: &[tiberius::Column],
371    overrides: &ColumnOverrides,
372    schema: &mut Option<SchemaRef>,
373    rows: &[tiberius::Row],
374    sink: &mut dyn BatchSink,
375) -> Result<usize> {
376    let schema_ref = match schema {
377        Some(s) => s.clone(),
378        None => {
379            let (built, _decoders) =
380                arrow_convert::mssql_columns_to_schema(columns, overrides, rows)?;
381            let s: SchemaRef = Arc::new(built);
382            sink.on_schema(s.clone())?;
383            *schema = Some(s.clone());
384            s
385        }
386    };
387    if !rows.is_empty() {
388        let batch = arrow_convert::mssql_rows_to_record_batch(&schema_ref, rows)?;
389        let bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
390        sink.on_batch(&batch)?;
391        return Ok(bytes);
392    }
393    Ok(0)
394}
395
396/// Render a row's first column as a display string for `query_scalar`
397/// (min/max bounds, COUNT(*), SELECT 1). Covers the scalar shapes the planner
398/// asks for; richer typing flows through the export path, not here.
399fn scalar_to_string(row: &tiberius::Row) -> Option<String> {
400    use tiberius::ColumnData;
401    let cell = row.cells().next().map(|(_, d)| d)?;
402    match cell {
403        ColumnData::U8(v) => v.map(|x| x.to_string()),
404        ColumnData::I16(v) => v.map(|x| x.to_string()),
405        ColumnData::I32(v) => v.map(|x| x.to_string()),
406        ColumnData::I64(v) => v.map(|x| x.to_string()),
407        ColumnData::F32(v) => v.map(|x| x.to_string()),
408        ColumnData::F64(v) => v.map(|x| x.to_string()),
409        ColumnData::Bit(v) => v.map(|x| x.to_string()),
410        ColumnData::String(v) => v.as_ref().map(|s| s.to_string()),
411        ColumnData::Numeric(v) => v.map(|n| {
412            // unscaled value with an inserted decimal point at `scale`.
413            let raw = n.value();
414            let scale = n.scale() as usize;
415            if scale == 0 {
416                raw.to_string()
417            } else {
418                let neg = raw < 0;
419                let digits = raw.unsigned_abs().to_string();
420                let digits = format!("{digits:0>width$}", width = scale + 1);
421                let (int, frac) = digits.split_at(digits.len() - scale);
422                format!("{}{int}.{frac}", if neg { "-" } else { "" })
423            }
424        }),
425        ColumnData::Guid(v) => v.map(|g| g.to_string()),
426        other => Some(format!("{other:?}")),
427    }
428}
429
430/// Probe `sys.*` for the stats chunked-mode planning needs (ADR-0015 seam).
431/// Mirrors `introspect_pg_table_for_chunking` / `introspect_mysql_table_for_chunking`.
432pub(crate) fn introspect_mssql_table_for_chunking(
433    url: &str,
434    tls: Option<&TlsConfig>,
435    qualified_table: &str,
436) -> Result<TableIntrospection> {
437    let (schema, table) = match qualified_table.split_once('.') {
438        Some((s, t)) => (s.to_string(), t.to_string()),
439        None => ("dbo".to_string(), qualified_table.to_string()),
440    };
441    let mut src = MssqlSource::connect_with_tls(url, tls)?;
442
443    // Row estimate from `sys.dm_db_partition_stats` (rows in the heap/clustered
444    // index, index_id 0/1).
445    let count_sql = format!(
446        "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
447         JOIN sys.objects o ON o.object_id = p.object_id \
448         JOIN sys.schemas s ON s.schema_id = o.schema_id \
449         WHERE s.name = N'{}' AND o.name = N'{}' AND p.index_id IN (0,1)",
450        schema.replace('\'', "''"),
451        table.replace('\'', "''")
452    );
453    let row_estimate = src
454        .query_scalar(&count_sql)?
455        .and_then(|s| s.parse::<i64>().ok())
456        .unwrap_or(0);
457
458    // Single-column integer PK → range chunking. `sys.indexes (is_primary_key)`
459    // + one `index_columns` row + an integer base type.
460    let pk_sql = format!(
461        "SELECT TOP 1 c.name, t.name FROM sys.indexes i \
462         JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id \
463         JOIN sys.columns c ON c.object_id = ic.object_id AND c.column_id = ic.column_id \
464         JOIN sys.types t ON t.user_type_id = c.user_type_id \
465         JOIN sys.objects o ON o.object_id = i.object_id \
466         JOIN sys.schemas s ON s.schema_id = o.schema_id \
467         WHERE i.is_primary_key = 1 AND s.name = N'{}' AND o.name = N'{}' \
468         GROUP BY c.name, t.name HAVING COUNT(*) = 1",
469        schema.replace('\'', "''"),
470        table.replace('\'', "''")
471    );
472    let mut single_int_pk = None;
473    let mut keyset_keys = Vec::new();
474    if let Some(pk_col) = src.query_scalar(&pk_sql)? {
475        // A single-column PK is always a usable keyset key (unique, NOT NULL).
476        keyset_keys.push(pk_col.clone());
477        // The scalar query returns only the column name; re-probe the type to
478        // decide range-chunk eligibility.
479        let type_sql = format!(
480            "SELECT t.name FROM sys.columns c \
481             JOIN sys.types t ON t.user_type_id = c.user_type_id \
482             JOIN sys.objects o ON o.object_id = c.object_id \
483             JOIN sys.schemas s ON s.schema_id = o.schema_id \
484             WHERE s.name = N'{}' AND o.name = N'{}' AND c.name = N'{}'",
485            schema.replace('\'', "''"),
486            table.replace('\'', "''"),
487            pk_col.replace('\'', "''")
488        );
489        if let Some(ty) = src.query_scalar(&type_sql)?
490            && matches!(ty.as_str(), "tinyint" | "smallint" | "int" | "bigint")
491        {
492            single_int_pk = Some(pk_col);
493        }
494    }
495
496    Ok(TableIntrospection {
497        single_int_pk,
498        keyset_keys,
499        row_estimate,
500        avg_row_bytes: None,
501    })
502}