Skip to main content

rivet/source/
mod.rs

1pub(crate) mod batch_controller;
2pub(crate) mod cdc;
3pub mod mssql;
4pub mod mysql;
5pub(crate) mod pg_numeric_wire;
6pub mod postgres;
7pub(crate) mod query;
8pub(crate) mod tls;
9pub(crate) mod value_checksum;
10
11use arrow::datatypes::SchemaRef;
12use arrow::record_batch::RecordBatch;
13
14use crate::config::{SourceConfig, TlsConfig};
15use crate::error::Result;
16use crate::plan::IncrementalCursorPlan;
17use crate::tuning::SourceTuning;
18use crate::types::{ColumnOverrides, CursorState, TypeMapping};
19
/// Error raised when **rivet itself** aborts a statement for exceeding its
/// duration budget — as opposed to a driver-native timeout that arrives with a
/// structured server code (PostgreSQL 57014, MySQL 3024).
///
/// MSSQL offers no server-side statement-duration `SET`, so rivet enforces
/// `tuning.statement_timeout_s` on the client and produces this error once the
/// budget is spent (see [`mssql`]). Previously the retry classifier decided
/// permanence by substring-matching rivet's OWN wording ("statement timeout
/// after …"). Any rewording silently turned the error *transient* again, and
/// the identical query was retried until it had burned the budget N times
/// (measured: 3×300 s = 20 min for 0 rows). With a dedicated type,
/// [`crate::pipeline::retry::classify_error`] downcasts on the TYPE, so
/// permanence no longer depends on the human-facing text. The classifier's
/// string branches remain only as a fallback for driver-native timeout
/// messages rivet does not control.
#[derive(Debug)]
pub struct StatementDurationTimeout {
    /// Actionable operator-facing text, used solely by `Display`; the
    /// classifier never inspects it.
    message: String,
}

impl StatementDurationTimeout {
    /// Timeout for MSSQL's client-enforced statement-duration budget (the
    /// engine has no server-side `SET` for it). `seconds` is the budget that
    /// was exceeded and is echoed in the message.
    pub fn mssql(seconds: u64) -> Self {
        let message = format!(
            "mssql: statement timeout after {}s (tuning.statement_timeout_s) — \
             this query cannot finish within the budget; split it with `mode: chunked` \
             (per-chunk statements stay under the limit) or raise \
             `tuning.statement_timeout_s`",
            seconds
        );
        Self { message }
    }
}

impl std::fmt::Display for StatementDurationTimeout {
    /// Renders the stored operator message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

/// Marker impl so the value can travel inside error chains and be recovered by
/// downcasting on its type.
impl std::error::Error for StatementDurationTimeout {}
62
/// Source-neutral summary of a table, limited to what chunked-mode planning
/// needs, so plan-build can ask Postgres or MySQL the same question and get the
/// same shape back.
///
/// Filled in by `crate::source::postgres::introspect_pg_table_for_chunking`
/// and `crate::source::mysql::introspect_mysql_table_for_chunking`. Both read
/// catalog statistics (`pg_class` / `information_schema.TABLES`), so the
/// numbers are only as fresh as the most recent `ANALYZE` / autoanalyze.
///
/// # Why this is a data-shape seam, not a trait
///
/// Both per-engine introspection functions share one signature
/// (`fn(url, tls, qualified_table) -> Result<TableIntrospection>`) and return
/// this struct. That symmetry tends to suggest a
/// `trait Introspector { fn introspect_table(...) }` with one impl per engine.
/// Such a refactor adds ceremony and removes no duplication, because the
/// *bodies* have nothing useful in common:
/// - PG walks `pg_class` / `pg_index` / `pg_attribute` / `pg_type` (with
///   PG-specific type names such as `int2`/`int4`/`int8`) through the
///   `postgres` client.
/// - MySQL reads `information_schema.TABLES` / `STATISTICS`, applying the
///   InnoDB `AVG_ROW_LENGTH` overflow correction, through the `mysql` client.
///
/// There is no shared logic to hoist into trait-default methods. A trait would
/// only relocate the engine dispatch (`match config.source.source_type { … }`
/// at the call site) into a factory returning `Box<dyn Introspector>`; the
/// `match` would still exist.
///
/// So the seam sits at the **data shape**: this struct is the shared contract,
/// the two free functions are the adapters, and each call dispatches through an
/// `enum`-driven `match`. ADR-0015 records the full rationale and the
/// architecture-review walks behind it.
#[derive(Debug, Clone, Default)]
pub(crate) struct TableIntrospection {
    /// The single integer-family PK column, when one exists and can be
    /// range-chunked safely. `None` for tables without a PK, with a composite
    /// PK, or whose PK type is not an integer family (text, uuid, decimal, …).
    pub single_int_pk: Option<String>,
    /// Columns usable as a keyset (seek) pagination key (OPT-4): each is a
    /// single-column, NOT NULL, **unique** index, listed PK first (any type)
    /// and then other UNIQUE indexes. Being index-backed and unique by
    /// construction:
    /// - `ORDER BY key LIMIT n` is a bounded index range scan, never a filesort;
    /// - `WHERE key > last` cannot skip rows that share a key.
    ///
    /// Empty when the table has no such key.
    pub keyset_keys: Vec<String>,
    /// Best-effort row count (PG `reltuples`, MySQL `TABLE_ROWS`). A value of
    /// `0` means either an empty table or unavailable statistics.
    pub row_estimate: i64,
    /// Heap bytes per row; `None` for empty or never-analysed tables. Converts
    /// `chunk_size_memory_mb` into a row count.
    pub avg_row_bytes: Option<i64>,
}

impl TableIntrospection {
    /// Keyset key chosen automatically: the first entry of `keyset_keys`
    /// (the PK when it qualifies), or `None` if the table offers none.
    pub fn auto_keyset_key(&self) -> Option<&str> {
        match self.keyset_keys.as_slice() {
            [preferred, ..] => Some(preferred.as_str()),
            [] => None,
        }
    }

    /// True when `col` appears among the usable keyset keys (single-column,
    /// unique, NOT NULL, index-backed). Validates an explicit `chunk_by_key`.
    pub fn is_usable_keyset_key(&self, col: &str) -> bool {
        self.keyset_keys
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == col)
    }
}
126
/// Receives schema and batches from a source, one at a time.
///
/// This is the output side of [`Source::export`], which takes the sink as
/// `&mut dyn BatchSink`; the trait therefore has to stay object-safe.
pub trait BatchSink {
    /// Accept the Arrow schema of the exported result set.
    ///
    /// NOTE(review): presumably invoked once, before the first `on_batch`;
    /// confirm against the engine drivers.
    fn on_schema(&mut self, schema: SchemaRef) -> Result<()>;
    /// Accept one Arrow record batch.
    ///
    /// The batch is only borrowed for the duration of the call, so a sink that
    /// needs it afterwards must take its own copy/clone.
    fn on_batch(&mut self, batch: &RecordBatch) -> Result<()>;
}
132
133/// Read-only inputs for a single export call.
134///
135/// Packs the parameters that used to live as 5 positional args on
136/// `Source::export` into a named struct. `sink` is **not** part of this struct
137/// — it is `&mut` and conceptually the output channel, separate from the
138/// read-only request configuration.
139pub struct ExportRequest<'a> {
140    /// Already-materialized SQL (after `resolve_query`). The driver still wraps
141    /// it with the dialect-specific incremental predicate via
142    /// [`crate::source::query::build_incremental_query`] when `incremental` is set.
143    pub query: &'a str,
144    /// The *unwrapped* base query to resolve catalog-dependent type hints from
145    /// (PostgreSQL `NUMERIC` precision/scale, which the wire protocol omits — the
146    /// driver parses the `FROM` clause and asks `pg_catalog`). Chunked, dense and
147    /// keyset runners wrap `query` in a `SELECT … FROM (<base>) …` subquery that
148    /// hides the source table from the catalog parser, so they pass the original
149    /// base query here. `None` ⇒ resolve from `query` (full/incremental, where it
150    /// is already the unwrapped form). Drivers that read precision from the wire
151    /// (MySQL) ignore this field.
152    pub catalog_hint_query: Option<&'a str>,
153    pub incremental: Option<&'a IncrementalCursorPlan>,
154    pub cursor: Option<&'a CursorState>,
155    pub tuning: &'a SourceTuning,
156    /// Per-column type declarations from `rivet.yaml` (`exports[].columns:`).
157    /// Drivers apply them during schema building so e.g. a `NUMERIC` column
158    /// without declared precision can still be exported as `Decimal128(18,2)`
159    /// when the user has stated the type explicitly.
160    pub column_overrides: &'a ColumnOverrides,
161    /// Keyset (seek) pagination page size (OPT-4). When `Some(n)` *and*
162    /// `incremental` carries the key plan, the driver builds one keyset page
163    /// (`WHERE key > cursor ORDER BY key LIMIT n`) instead of the unbounded
164    /// incremental/snapshot query. The keyset runner drives the outer loop.
165    pub page_limit: Option<usize>,
166}
167
168impl<'a> ExportRequest<'a> {
169    /// A request whose `query` is already the **unwrapped base** form, so
170    /// catalog type hints resolve directly from it. Use for snapshot,
171    /// incremental and keyset runners: the driver applies any incremental /
172    /// keyset predicate internally, so the source table stays visible to the
173    /// catalog parser and `catalog_hint_query` is `None`.
174    pub fn unwrapped(
175        query: &'a str,
176        tuning: &'a SourceTuning,
177        column_overrides: &'a ColumnOverrides,
178    ) -> Self {
179        Self {
180            query,
181            catalog_hint_query: None,
182            incremental: None,
183            cursor: None,
184            tuning,
185            column_overrides,
186            page_limit: None,
187        }
188    }
189
190    /// A request whose `query` is a `SELECT … FROM (<base>) …` **wrapper** that
191    /// hides the source table (chunked / dense / time-window). `base` — the
192    /// unwrapped query catalog hints resolve from — is a required argument, so a
193    /// wrapping runner cannot silently fall back to the table-hiding wrapper and
194    /// lose PG `NUMERIC` precision (the bug the catalog-hint fix / ADR-0020
195    /// closed). Drivers that read precision from the wire (MySQL) ignore it.
196    pub fn wrapped(
197        query: &'a str,
198        base: &'a str,
199        tuning: &'a SourceTuning,
200        column_overrides: &'a ColumnOverrides,
201    ) -> Self {
202        Self {
203            query,
204            catalog_hint_query: Some(base),
205            incremental: None,
206            cursor: None,
207            tuning,
208            column_overrides,
209            page_limit: None,
210        }
211    }
212
213    /// Attach the incremental cursor plan (the driver builds the `WHERE cursor >
214    /// ? ORDER BY` predicate). Pass-through `Option` so mode-polymorphic callers
215    /// can forward `strategy.incremental_plan()` directly.
216    pub fn with_incremental(mut self, plan: Option<&'a IncrementalCursorPlan>) -> Self {
217        self.incremental = plan;
218        self
219    }
220
221    /// Attach the last committed cursor value the next run resumes after.
222    pub fn with_cursor(mut self, cursor: Option<&'a CursorState>) -> Self {
223        self.cursor = cursor;
224        self
225    }
226
227    /// Set the keyset (seek) page size — one bounded `… WHERE key > cursor ORDER
228    /// BY key LIMIT n` page instead of the unbounded query.
229    pub fn with_page_limit(mut self, page_limit: usize) -> Self {
230        self.page_limit = Some(page_limit);
231        self
232    }
233}
234
/// A connection to one source database engine that rivet exports from.
///
/// Instances are produced by [`create_source`] (one implementation per
/// `SourceType`: Postgres, MySQL, MSSQL) and handled as `Box<dyn Source>`.
/// The `Send` bound lets a boxed source be moved across threads.
pub trait Source: Send {
    /// Execute `request.query` and stream batches into `sink`.
    ///
    /// `request` carries the read-only configuration (incremental plan,
    /// cursor, tuning, column overrides, keyset page size; see
    /// [`ExportRequest`]). `sink` is the output channel that receives the
    /// schema and the batches.
    fn export(&mut self, request: &ExportRequest<'_>, sink: &mut dyn BatchSink) -> Result<()>;

    /// Run `sql` and return a single value rendered as text.
    ///
    /// NOTE(review): presumably the first column of the first row, with `None`
    /// for an empty result or SQL `NULL`. Confirm against the engine
    /// implementations.
    fn query_scalar(&mut self, sql: &str) -> Result<Option<String>>;

    /// Return `TypeMapping` for every column in `query` without fetching rows.
    ///
    /// Used by `rivet check --type-report` to show the full type provenance
    /// (source native type → RivetType → Arrow type → fidelity) before export.
    /// Implementations execute `SELECT * FROM (...) AS _q LIMIT 0` so only
    /// server-side type metadata is transferred.
    /// NOTE(review): MSSQL has no `LIMIT`; its implementation presumably uses
    /// an equivalent such as `TOP 0`. Confirm and reword.
    fn type_mappings(
        &mut self,
        query: &str,
        column_overrides: &ColumnOverrides,
    ) -> Result<Vec<TypeMapping>>;

    /// Sample a monotonic source-pressure counter for the OPT-2 concurrency
    /// governor (`pipeline::chunked::exec`).
    ///
    /// Higher = more pressure. The governor compares successive samples
    /// (`cur > prev` ⇒ under pressure) — the same convention the adaptive
    /// batch-size loop already uses. Returns `None` when the engine can't
    /// cheaply sample a pressure proxy, in which case the governor holds
    /// parallelism flat. Default: `None`.
    fn sample_pressure(&mut self) -> Option<u64> {
        None
    }
}
265
266pub fn create_source(config: &SourceConfig) -> Result<Box<dyn Source>> {
267    use crate::config::SourceType;
268    let url = config.resolve_url()?;
269    warn_if_tls_disabled(config);
270    match config.source_type {
271        SourceType::Postgres => Ok(Box::new(postgres::PostgresSource::connect_with_tls(
272            &url,
273            config.tls.as_ref(),
274        )?)),
275        SourceType::Mysql => Ok(Box::new(mysql::MysqlSource::connect_with_tls(
276            &url,
277            config.tls.as_ref(),
278        )?)),
279        SourceType::Mssql => Ok(Box::new(mssql::MssqlSource::connect_with_tls(
280            &url,
281            config.tls.as_ref(),
282        )?)),
283    }
284}
285
286/// Pre-allocation per-value size guard, shared by every engine's
287/// `arrow_convert`. The sink-side `check_value_ceiling`
288/// (`pipeline::sink::mod`) scans the *already-built* Arrow batch, so an
289/// oversized cell costs the driver-decode copy **and** the Arrow-build copy
290/// before that guard fires. This check runs at the decode/`Value` stage — after
291/// the unavoidable driver copy, but *before* the value is appended into the
292/// `StringBuilder` / `BinaryBuilder` — so the Arrow allocation never grows to
293/// hold it. Only variable-length values (Utf8 / Binary) can be individually
294/// huge; fixed-width arms (ints/floats/dates) never call this.
295///
296/// `max_value_bytes` is `tuning.max_value_bytes()` (MB → bytes with the
297/// `Some(0)`/`None` ⇒ disabled semantics). The message mirrors the sink guard's
298/// `RIVET_VALUE_TOO_LARGE` so both read identically; the sink guard stays as the
299/// backstop (it also covers meta / enriched columns and is the contract test).
300pub(crate) fn value_within_ceiling(
301    column: &str,
302    len: usize,
303    max_value_bytes: Option<usize>,
304) -> Result<()> {
305    if let Some(limit) = max_value_bytes
306        && len > limit
307    {
308        anyhow::bail!(
309            "RIVET_VALUE_TOO_LARGE: column '{}' has a single value of {:.1} MB, exceeding the \
310             per-value ceiling of {} MB. One oversized cell can OOM the process regardless of \
311             batch size. Raise `tuning.max_value_mb` (or set it to 0 to disable the guard) if \
312             this value is expected.",
313            column,
314            len as f64 / (1024.0 * 1024.0),
315            limit / (1024 * 1024),
316        );
317    }
318    Ok(())
319}
320
#[cfg(test)]
mod value_ceiling_tests {
    use super::value_within_ceiling;

    /// One mebibyte, the unit the byte limits below are built from.
    const MIB: usize = 1024 * 1024;

    #[test]
    fn sec_value_ceiling_pre_alloc_over_limit_errors() {
        let err = value_within_ceiling("payload", 2 * MIB, Some(MIB))
            .expect_err("a 2 MiB value must trip a 1 MiB ceiling");
        let rendered = format!("{err:#}");
        assert!(rendered.contains("RIVET_VALUE_TOO_LARGE"), "got: {rendered}");
        assert!(rendered.contains("payload"), "names the column: {rendered}");
    }

    #[test]
    fn sec_value_ceiling_pre_alloc_at_or_under_limit_ok() {
        // Exactly at the limit and an empty value must both pass.
        for len in [MIB, 0] {
            assert!(value_within_ceiling("c", len, Some(MIB)).is_ok(), "len = {len}");
        }
    }

    #[test]
    fn sec_value_ceiling_pre_alloc_disabled_never_errors() {
        // `None` (tuning.max_value_mb set to 0 or left unset) switches the guard off.
        assert!(value_within_ceiling("c", usize::MAX, None).is_ok());
    }
}
345
346/// One-time nudge to enable TLS when the current config connects in plaintext.
347/// Emitted at `warn` level so operators see it even at the default log level.
348/// `create_source` is called multiple times per run (plan/preflight/exec/chunk
349/// workers), so we gate the warning behind a `Once` to fire exactly once per
350/// process rather than 3-4 times in stderr.
351pub(crate) fn warn_if_tls_disabled(config: &SourceConfig) {
352    let enforced = config.tls.as_ref().is_some_and(|t| t.mode.is_enforced());
353    if enforced {
354        return;
355    }
356    // Loopback (localhost / 127.0.0.0/8 / ::1) is the local-dev / docker case:
357    // the bytes never leave the box, so the plaintext warning is just noise on
358    // a newcomer's laptop. Resolve best-effort — if the URL can't be resolved we
359    // fall through and warn (fail-safe). The real CWE-319 signal still fires for
360    // any remote host.
361    if config.resolve_url().is_ok_and(|u| host_is_loopback(&u)) {
362        return;
363    }
364    static WARNED: std::sync::Once = std::sync::Once::new();
365    WARNED.call_once(|| {
366        log::warn!(
367            "source: TLS is not enforced — credentials and result rows cross the network in plaintext. \
368             Add `source.tls.mode: verify-full` (with `ca_file:` if your CA is private) to enable transport security."
369        );
370    });
371}
372
/// Whether **every** host named by a
/// `scheme://[user[:pass]@]host[:port][,host[:port]…][/db][?…]` connection URL
/// is a loopback address (`127.0.0.0/8`, `::1`) or the literal `localhost`.
///
/// Used by [`require_tls_or_loopback`] (and by the plaintext warning in
/// [`warn_if_tls_disabled`]) to decide TLS posture from the host. Loopback is
/// the docker / local-dev case where the bytes never leave the box, so
/// plaintext is fine; a remote host without TLS leaks credentials and rows.
///
/// libpq-style URLs can name several hosts, and a libpq-compatible driver
/// tries them in turn:
/// - a comma-separated list in the authority (`localhost:5432,10.0.0.5:5432`);
/// - `host=` / `hostaddr=` query parameters.
///
/// Looking only at the first host would report such a URL as loopback and let
/// a fallback to a remote host go out in plaintext. The URL is therefore
/// loopback only if *all* of its hosts are.
///
/// Fails **closed**: any URL we cannot confidently parse a loopback host out of
/// is treated as non-loopback, so a parse gap can only ever *tighten* the gate
/// (refuse a connection), never silently allow plaintext to an unverified host.
pub(crate) fn host_is_loopback(url: &str) -> bool {
    // Strip the scheme (`postgresql://`, `mysql://`, `sqlserver://`, …).
    let after_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
    // The authority ends at the first `/`, `?` or `#`; the rest (path, query,
    // fragment) is kept for the query-parameter host check.
    let (authority, tail) = after_scheme
        .find(['/', '?', '#'])
        .map_or((after_scheme, ""), |i| after_scheme.split_at(i));
    // Drop `user[:pass]@`. Splitting on the LAST `@` means an `@` inside the
    // password stays with the userinfo instead of being taken as the host.
    let host_list = authority.rsplit_once('@').map_or(authority, |(_, hp)| hp);

    // `split(',')` always yields at least one item, so an empty authority is
    // checked as the empty host and rejected.
    host_list.split(',').all(host_entry_is_loopback)
        && query_host_params(tail).all(host_entry_is_loopback)
}

/// Whether one `host[:port]` / `[ipv6][:port]` entry names a loopback host.
fn host_entry_is_loopback(entry: &str) -> bool {
    // IPv6 literals are bracketed (`[::1]:5432`): the host is the bracketed
    // span, and any `:` inside it belongs to the address.
    let host = if let Some(rest) = entry.strip_prefix('[') {
        match rest.split_once(']') {
            Some((h, _)) => h,
            None => return false, // unterminated bracket — fail closed
        }
    } else {
        // Bare host or IPv4: the host ends at the port `:`. An unbracketed
        // IPv6 address yields "" here and is therefore rejected.
        entry.split(':').next().unwrap_or(entry)
    };

    // `IpAddr::is_loopback` covers the whole 127.0.0.0/8 block and `::1`.
    host.eq_ignore_ascii_case("localhost")
        || host
            .parse::<std::net::IpAddr>()
            .is_ok_and(|ip| ip.is_loopback())
}

/// Host values supplied through `host=` / `hostaddr=` query parameters
/// (comma-separated values are split). `tail` is everything after the URL
/// authority; yields nothing when there is no query string.
fn query_host_params(tail: &str) -> impl Iterator<Item = &str> {
    let query = tail
        .split_once('?')
        .map_or("", |(_, q)| q.split('#').next().unwrap_or(q));
    query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .filter(|(key, _)| key.eq_ignore_ascii_case("host") || key.eq_ignore_ascii_case("hostaddr"))
        .flat_map(|(_, value)| value.split(','))
}
420
421/// Gate plaintext / trust-any-cert connections by host (CWE-319 / CWE-295).
422///
423/// When no `tls:` block is configured (`tls == None`) **and** the resolved host
424/// is not loopback, refuse the connection *before any network I/O* with a
425/// TLS-required policy error. This stops the per-engine connect helpers from
426/// silently dialing a remote database in cleartext (Postgres/MySQL `NoTls`) or
427/// trusting any server certificate (MSSQL `trust_cert`).
428///
429/// Loopback hosts (docker / local dev) keep today's behaviour — plaintext is
430/// allowed there because the bytes never leave the box. An explicit
431/// `tls: { mode: disable }` is `Some(..)`, so it is the operator's opt-in to
432/// remote plaintext and is **not** refused here.
433pub(crate) fn require_tls_or_loopback(url: &str, tls: Option<&TlsConfig>) -> Result<()> {
434    if tls.is_none() && !host_is_loopback(url) {
435        // The message must name TLS *and* that it is a policy refusal for a
436        // remote host. Emit it at `error` level (→ stderr) as well as returning
437        // it: callers like `doctor` print the `Err` to stdout in their own
438        // `[FAIL]` style and only re-raise a generic summary, so the log line is
439        // what guarantees the TLS-required reason reaches stderr. Deliberately
440        // avoids socket-error vocabulary ("could not connect", "timeout", "os
441        // error") so it is never mistaken for a connect-time failure.
442        let msg = "source: TLS required — refusing to connect to a remote (non-loopback) \
443             host without TLS; credentials and every exported row would cross the network \
444             in cleartext. Add `source.tls: { mode: verify-full }` (with `ca_file:` for a \
445             private CA) to enable transport security, or explicitly opt into remote \
446             plaintext with `source.tls: { mode: disable }` if this network path is \
447             already trusted.";
448        log::error!("{msg}");
449        anyhow::bail!("{msg}");
450    }
451    Ok(())
452}
453
#[cfg(test)]
mod tls_gate_tests {
    use super::{host_is_loopback, require_tls_or_loopback};
    use crate::config::{TlsConfig, TlsMode};

    #[test]
    fn loopback_variants_are_loopback() {
        let urls = [
            "postgresql://rivet:rivet@127.0.0.1:5432/rivet",
            "postgresql://rivet:rivet@localhost:5432/rivet",
            "mysql://root@127.0.0.1:3306/db",
            // Any address in 127.0.0.0/8 counts.
            "postgresql://u:p@127.255.0.9/db",
            // Bracketed IPv6 loopback, with and without a port.
            "postgresql://u:p@[::1]:5432/db",
            "sqlserver://sa:pw@[::1]/master",
            // Case-insensitive host; no port, no database.
            "mysql://root@LOCALHOST",
            // An `@` inside the password is not the userinfo/host boundary.
            "postgresql://u:p@ss@127.0.0.1:5432/db",
        ];
        for url in urls {
            assert!(host_is_loopback(url), "expected loopback: {url}");
        }
    }

    #[test]
    fn remote_hosts_are_not_loopback() {
        let urls = [
            "postgresql://rivet:rivet@10.255.255.1:5432/rivet",
            "postgresql://u:p@db.example.com:5432/app",
            "mysql://root@192.168.1.10:3306/db",
            "sqlserver://sa:pw@10.0.0.5:1433/master",
            // An unbracketed IPv6-looking address does not parse here, so it
            // fails closed and counts as remote.
            "postgresql://u:p@::1:5432/db",
        ];
        for url in urls {
            assert!(!host_is_loopback(url), "expected remote: {url}");
        }
    }

    #[test]
    fn gate_refuses_remote_plaintext_only() {
        let remote = "postgresql://rivet:rivet@10.255.255.1:5432/rivet";
        let loopback = "postgresql://rivet:rivet@127.0.0.1:5432/rivet";
        let tls_with = |mode: TlsMode| TlsConfig {
            mode,
            ..Default::default()
        };
        let disable = tls_with(TlsMode::Disable);
        let verify = tls_with(TlsMode::VerifyFull);

        // No tls block to a remote host: refused.
        assert!(require_tls_or_loopback(remote, None).is_err());
        // No tls block on loopback: allowed (docker / dev path).
        assert!(require_tls_or_loopback(loopback, None).is_ok());
        // `mode: disable` is the explicit remote-plaintext opt-in: allowed.
        assert!(require_tls_or_loopback(remote, Some(&disable)).is_ok());
        // Enforced TLS to a remote host: allowed (the connect path uses TLS).
        assert!(require_tls_or_loopback(remote, Some(&verify)).is_ok());
    }
}