Skip to main content

rivet/source/postgres/
mod.rs

1//! PostgreSQL `Source` implementation.
2//!
3//! Module layout:
4//!
5//! - `mod.rs` (this file) — `PostgresSource` struct + connect/TLS path, the
6//!   transaction-pooler detector, `PgTxnGuard`, sampling helpers
7//!   (`sample_temp_bytes`, `pg_sample_checkpoints_req`, `pg_fetch_work_mem_bytes`),
8//!   `introspect_pg_table_for_chunking`, the cursor + FETCH export loop
9//!   (`pg_run_export`), the `Source` trait impl, and the catalog-hint
10//!   resolver that bridges parsed FROM clauses to `pg_catalog`.
11//! - [`arrow_convert`] — the entire row → Arrow `RecordBatch` pipeline: type
12//!   mapping (`pg_columns_to_schema`, `rivet_type_for_pg_column`), per-cell
13//!   decoders (INTERVAL, UUID, enum, NUMERIC), and the array builders. Kept
14//!   in a sibling because it is the largest single-purpose cluster in this
15//!   driver (~620 LoC) and has zero reverse dependency back into the
16//!   connection / cursor layer.
17//! - [`from_parse`] — pure `&str`/`&[u8]` parser that extracts the simple
18//!   `<schema>.<table>` literal from a user query so the catalog-hint path
19//!   can cast it to `regclass`.  Zero postgres-crate dependency, fully
20//!   unit-tested in isolation.
21
22mod arrow_convert;
23mod from_parse;
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use postgres::types::Type;
30use postgres::{Client, NoTls};
31
32use crate::config::{SourceType, TlsConfig};
33use crate::error::Result;
34use crate::source::batch_controller::AdaptiveBatchController;
35use crate::source::query::build_export_query;
36use crate::source::tls::build_native_tls;
37use crate::tuning::SourceTuning;
38use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
39
40use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
41use from_parse::try_parse_pg_simple_from_regclass_literal;
42
43pub struct PostgresSource {
44    client: Client,
45    /// True when two consecutive pg_backend_pid() calls returned different values,
46    /// indicating a transaction-mode connection pooler (pgBouncer, Odyssey, etc.).
47    transaction_pooler: bool,
48}
49
50/// Detect whether the connection is going through a transaction-mode pooler
51/// (pgBouncer, Odyssey, etc.) by comparing backend PIDs across two implicit
52/// transactions. Returns true when PIDs differ — impossible on a direct
53/// connection or session-mode pooler where the same physical backend is kept.
54///
55/// False negatives are possible when pool_size = 1 (the same backend is always
56/// reused), so this is a best-effort warning rather than a hard guarantee.
57fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
58    let pid1: Option<i32> = client
59        .query_one("SELECT pg_backend_pid()", &[])
60        .ok()
61        .and_then(|r| r.try_get(0).ok());
62    let pid2: Option<i32> = client
63        .query_one("SELECT pg_backend_pid()", &[])
64        .ok()
65        .and_then(|r| r.try_get(0).ok());
66    matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
67}
68
69impl PostgresSource {
70    /// Connect with no transport security (legacy path). Prefer [`Self::connect_with_tls`]
71    /// for production workloads so credentials and result sets are not visible on the wire.
72    pub fn connect(url: &str) -> Result<Self> {
73        let mut client = Client::connect(url, NoTls)?;
74        let transaction_pooler = detect_pg_transaction_pooler(&mut client);
75        if transaction_pooler {
76            log::warn!(
77                "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
78                 SET LOCAL tuning is transaction-scoped; \
79                 LISTEN/NOTIFY and advisory locks are unavailable"
80            );
81        }
82        Ok(Self {
83            client,
84            transaction_pooler,
85        })
86    }
87
88    /// Connect honoring the user's [`TlsConfig`]. When `tls.mode` is
89    /// [`TlsMode::Disable`] this falls back to [`Self::connect`].
90    pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
91        // Refuse remote plaintext (no `tls:` block) before any dial (CWE-319).
92        crate::source::require_tls_or_loopback(url, tls)?;
93        match tls {
94            Some(cfg) if cfg.mode.is_enforced() => {
95                let connector = build_native_tls(cfg)?;
96                let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
97                let mut client = Client::connect(url, make_tls)?;
98                let transaction_pooler = detect_pg_transaction_pooler(&mut client);
99                if transaction_pooler {
100                    log::warn!(
101                        "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
102                         SET LOCAL tuning is transaction-scoped; \
103                         LISTEN/NOTIFY and advisory locks are unavailable"
104                    );
105                }
106                Ok(Self {
107                    client,
108                    transaction_pooler,
109                })
110            }
111            _ => Self::connect(url),
112        }
113    }
114}
115
116/// RAII guard for an open `BEGIN ... COMMIT` block.
117///
118/// `commit()` runs `COMMIT` and marks the txn done; if the guard is dropped
119/// before `commit()` (early return, `?`-bubbled error, or panic-driven unwind),
120/// `Drop` issues a best-effort `ROLLBACK`. Postgres releases any open cursors
121/// as part of ROLLBACK, so the cursor declared inside the txn is also cleaned
122/// up. Closes the **G1** gap from the DBA audit (cursor leak on panic).
123struct PgTxnGuard<'a> {
124    client: &'a mut Client,
125    committed: bool,
126}
127
128impl<'a> PgTxnGuard<'a> {
129    fn begin(client: &'a mut Client) -> Result<Self> {
130        client.batch_execute("BEGIN")?;
131        Ok(Self {
132            client,
133            committed: false,
134        })
135    }
136
137    fn client_mut(&mut self) -> &mut Client {
138        self.client
139    }
140
141    fn commit(mut self) -> Result<()> {
142        self.client.batch_execute("COMMIT")?;
143        self.committed = true;
144        Ok(())
145    }
146}
147
148impl Drop for PgTxnGuard<'_> {
149    fn drop(&mut self) {
150        if !self.committed
151            && let Err(e) = self.client.batch_execute("ROLLBACK")
152        {
153            // Drop must not panic. Worst case the connection is poisoned and
154            // the pool recycles it; log so operators see it.
155            log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
156        }
157    }
158}
159
160/// Snapshot `pg_stat_database.temp_bytes` for the current database.
161///
162/// Used by the pipeline job to compute per-run cursor / sort spill: we capture
163/// the cluster-wide counter immediately before and after each export and
164/// surface the delta on the run summary card. Failures (connect, query) return
165/// `None` — the metric is informational, not a correctness signal.
166///
167/// Note this is a cluster-level counter: concurrent activity from other
168/// connections during the run inflates the delta. For a single-tenant test
169/// box (the common pilot setup) it is accurate; for shared hosts it is a
170/// noisy upper bound, useful as a "your workload was loud" signal.
171pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
172    let mut client = connect_client(url, tls).ok()?;
173    client
174        .query_one(
175            "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
176            &[],
177        )
178        .ok()
179        .and_then(|r| r.try_get::<_, i64>(0).ok())
180}
181
182/// Probe `SHOW work_mem` and return the value in bytes.
183///
184/// PostgreSQL spills FETCH-cursor output to `pgsql_tmp/` once the in-flight
185/// row set exceeds `work_mem` — on wide rows with the default 4 MB the spill
186/// fires on every chunk and dominates `pg_stat_database.temp_bytes`. Knowing
187/// the value lets the cursor loop cap FETCH N below `work_mem × 0.7`, keeping
188/// the result set in memory.
189///
190/// Returns None on any parse / query failure — the cursor loop falls back to
191/// the configured static batch_size in that case.
192fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
193    let raw: Option<String> = client
194        .query_one("SHOW work_mem", &[])
195        .ok()
196        .and_then(|r| r.try_get::<_, String>(0).ok());
197    raw.as_deref().and_then(parse_work_mem)
198}
199
/// Convert a `SHOW work_mem` string into a byte count.
///
/// Accepted shapes are a number followed by an optional unit, e.g. `"4MB"`,
/// `"16384kB"`, `"1GB"`, `"1TB"`; a bare number is taken as kB. Units are
/// matched case-insensitively and whitespace around the value or between the
/// number and the unit is ignored. All units are 1024-based, matching the
/// syntax PostgreSQL accepts in `postgresql.conf`.
///
/// Returns `None` for an empty string, an unparsable number, an unknown unit,
/// or a result that is not strictly positive, so callers can fall back.
fn parse_work_mem(raw: &str) -> Option<i64> {
    const KIB: f64 = 1024.0;

    let text = raw.trim();
    // The numeric prefix ends at the first character that cannot be part of a
    // number; if there is none, the whole string is the number.
    let is_numeric_char = |c: char| c.is_ascii_digit() || c == '.' || c == '-';
    let boundary = text
        .find(|c: char| !is_numeric_char(c))
        .unwrap_or(text.len());
    if boundary == 0 {
        return None;
    }

    let (digits, suffix) = text.split_at(boundary);
    let value: f64 = digits.parse().ok()?;

    let suffix = suffix.trim().to_ascii_lowercase();
    let scale = if suffix.is_empty() || suffix == "kb" {
        KIB
    } else if suffix == "mb" {
        KIB * KIB
    } else if suffix == "gb" {
        KIB * KIB * KIB
    } else if suffix == "tb" {
        KIB * KIB * KIB * KIB
    } else {
        return None;
    };

    let bytes = (value * scale) as i64;
    if bytes > 0 {
        Some(bytes)
    } else {
        None
    }
}
232
233/// Sample `checkpoints_req` from `pg_stat_bgwriter`.
234///
235/// PostgreSQL caches the statistics snapshot at the start of each transaction.
236/// We call `pg_stat_clear_snapshot()` first to discard that cache so every
237/// adaptive sample sees fresh counters rather than the frozen value from BEGIN.
238fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
239    let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
240    client
241        .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
242        .ok()
243        .and_then(|r| r.try_get::<_, i64>(0).ok())
244}
245
246/// Probe `pg_class` and `pg_index` for the stats chunked-mode planning needs.
247///
248/// Returns a [`crate::source::TableIntrospection`] populated from one connection
249/// (two round-trips total: one stats query, one PK query). Failure to connect
250/// or to query bubbles up as `Err`; missing rows or unanalyzed tables are
251/// represented as zero/None in the result so callers can decide policy.
252///
253/// The `qualified_table` argument is `<schema>.<table>` (e.g. `public.users`)
254/// or bare `<table>` (resolved under `public`). It is split internally with
255/// the same strict rules as the `table:` YAML shortcut — anything more
256/// elaborate must use the explicit-column path.
257pub(crate) fn introspect_pg_table_for_chunking(
258    url: &str,
259    tls: Option<&TlsConfig>,
260    qualified_table: &str,
261) -> Result<crate::source::TableIntrospection> {
262    let (schema, table) = match qualified_table.split_once('.') {
263        Some((s, t)) => (s.to_string(), t.to_string()),
264        None => ("public".to_string(), qualified_table.to_string()),
265    };
266    let mut client = connect_client(url, tls)?;
267
268    // ── reltuples + heap size, in one shot ──────────────────────────────
269    let (row_estimate, rel_size_bytes) = match client.query_opt(
270        "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
271         FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
272         WHERE n.nspname = $1::text AND c.relname = $2::text",
273        &[&schema, &table],
274    )? {
275        Some(row) => {
276            let rt: i64 = row.try_get(0).unwrap_or(0);
277            let sz: i64 = row.try_get(1).unwrap_or(0);
278            (rt.max(0), sz.max(0))
279        }
280        None => (0, 0),
281    };
282    let avg_row_bytes = if row_estimate > 0 {
283        Some(rel_size_bytes / row_estimate)
284    } else {
285        None
286    };
287
288    // ── single int PK probe ─────────────────────────────────────────────
289    let pk_rows = client.query(
290        "SELECT a.attname::text, t.typname::text \
291         FROM pg_index i \
292         JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
293         JOIN pg_type t ON t.oid = a.atttypid \
294         WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
295           AND i.indisprimary",
296        &[&schema, &table],
297    )?;
298    let single_int_pk = if pk_rows.len() == 1 {
299        let col: String = pk_rows[0].get(0);
300        let pg_type: String = pk_rows[0].get(1);
301        // Only integer-family types are safe for range chunking via min/max →
302        // BETWEEN slicing. Text/UUID/decimal would need different splitting
303        // logic and are excluded from auto-resolution.
304        if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
305            Some(col)
306        } else {
307            log::debug!(
308                "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
309            );
310            None
311        }
312    } else {
313        None
314    };
315
316    // ── keyset keys (OPT-4): single-column, NOT NULL, UNIQUE indexes ────
317    // `indnkeyatts = 1` keeps single-column indexes; `indkey[0] = a.attnum`
318    // binds to a real column (not an expression index); `attnotnull` removes
319    // NULL-ordering ambiguity. Index-backed + unique ⇒ keyset's `ORDER BY key
320    // LIMIT n` is a range scan and `WHERE key > last` never skips dup keys.
321    let keyset_rows = client.query(
322        "SELECT a.attname::text, i.indisprimary \
323         FROM pg_index i \
324         JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = i.indkey[0] \
325         WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
326           AND i.indisunique AND i.indnkeyatts = 1 AND a.attnotnull",
327        &[&schema, &table],
328    )?;
329    let mut keyset_keys: Vec<String> = Vec::new();
330    for primary in [true, false] {
331        for row in &keyset_rows {
332            let col: String = row.get(0);
333            let is_primary: bool = row.get(1);
334            if is_primary == primary && !keyset_keys.contains(&col) {
335                keyset_keys.push(col);
336            }
337        }
338    }
339
340    Ok(crate::source::TableIntrospection {
341        single_int_pk,
342        keyset_keys,
343        row_estimate,
344        avg_row_bytes,
345    })
346}
347
348/// Open a bare `postgres::Client` honoring the configured TLS policy.
349///
350/// Shared by preflight, doctor, and `rivet init` so every code path that
351/// connects to Postgres applies the same transport-security rules. Preflight
352/// and doctor pass the YAML `tls:` block; init runs before any YAML exists,
353/// so it derives a `TlsConfig` from the URL's `sslmode` parameter (see
354/// `crate::init::postgres::connect`). `tls = None` or `mode: disable` falls
355/// back to the insecure `NoTls` transport — a warning is logged from
356/// `create_source` so operators know TLS is off.
357pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
358    // Refuse remote plaintext (no `tls:` block) before any dial (CWE-319).
359    crate::source::require_tls_or_loopback(url, tls)?;
360    match tls {
361        Some(cfg) if cfg.mode.is_enforced() => {
362            let connector = build_native_tls(cfg)?;
363            let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
364            Ok(Client::connect(url, make_tls)?)
365        }
366        _ => Ok(Client::connect(url, NoTls)?),
367    }
368}
369
370/// Run the full export transaction against an open Postgres client.
371///
372/// All session-mutating SET commands use SET LOCAL so they are scoped to
373/// the transaction and reset automatically on COMMIT or ROLLBACK. The caller
374/// is responsible for issuing ROLLBACK if this function returns Err.
375///
376/// Returns (total_rows, had_schema). had_schema is false only when the query
377/// returned zero rows; the caller must emit an empty schema in that case.
378fn pg_run_export(
379    client: &mut Client,
380    built_sql: &str,
381    tuning: &SourceTuning,
382    column_overrides: &ColumnOverrides,
383    sink: &mut dyn super::BatchSink,
384    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
385) -> Result<(usize, bool)> {
386    // Open the txn under guard *first* — if SET LOCAL or DECLARE fails below,
387    // Drop will roll back. Without the guard, a failure between BEGIN and the
388    // explicit ROLLBACK in the caller would leak a half-set-up txn into the pool.
389    let mut guard = PgTxnGuard::begin(client)?;
390    if tuning.statement_timeout_s > 0 {
391        guard.client_mut().batch_execute(&format!(
392            "SET LOCAL statement_timeout = '{}s'",
393            tuning.statement_timeout_s
394        ))?;
395    }
396    if tuning.lock_timeout_s > 0 {
397        guard.client_mut().batch_execute(&format!(
398            "SET LOCAL lock_timeout = '{}s'",
399            tuning.lock_timeout_s
400        ))?;
401    }
402    // Cap FETCH N under `work_mem × 0.7` so the cursor never spills to
403    // `pgsql_tmp/`. Without this, a wide-row chunk with the default
404    // `batch_size: 50000` × ~4 KB/row = ~200 MB easily exceeds the typical
405    // `work_mem: 4 MB` and writes the entire chunk to disk before the first
406    // FETCH returns. Measured cost on the content_items bench: ~3.2 GB of
407    // temp_bytes per export, dominating the DB-side signal report.
408    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());
409
410    guard
411        .client_mut()
412        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;
413
414    // The first FETCH is intentionally a small `PROBE_BATCH_SIZE` row-width
415    // probe (the controller starts there): without it we can't know
416    // `arrow_bytes/row` before the cursor runs, and a single FETCH of
417    // `tuning.batch_size` × wide rows already triggers a `pgsql_tmp/` spill.
418    let configured_batch_size = tuning.batch_size;
419    // Shared batch-size state machine; PG provides the FETCH N row source, the
420    // work_mem (or schema-derived) cap target, and the checkpoint pressure proxy.
421    let mut ctl = AdaptiveBatchController::new(tuning, configured_batch_size);
422    ctl.seed_pressure(if tuning.adaptive {
423        pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64)
424    } else {
425        None
426    });
427    let mut schema: Option<SchemaRef> = None;
428    let mut columns_cache: Option<Vec<(String, Type)>> = None;
429    let mut total_rows: usize = 0;
430    let mut cap_applied = false;
431    // Per-value ceiling (MB→bytes; `0`/None disables), enforced pre-allocation
432    // inside the batch builder so an oversized cell bails before Arrow reserves
433    // the buffer. Same source of truth as the sink's backstop guard.
434    let max_value_bytes = tuning.max_value_bytes();
435
436    loop {
437        let requested = ctl.target();
438        let fetch_sql = format!("FETCH {} FROM _rivet", requested);
439        let rows = guard.client_mut().query(&fetch_sql, &[])?;
440        if rows.is_empty() {
441            break;
442        }
443
444        if schema.is_none() {
445            let stmt_cols: Vec<(String, Type)> = rows[0]
446                .columns()
447                .iter()
448                .map(|c| (c.name().to_string(), c.type_().clone()))
449                .collect();
450            let s = Arc::new(pg_columns_to_schema(
451                rows[0].columns(),
452                column_overrides,
453                numeric_hints,
454            )?);
455            sink.on_schema(s.clone())?;
456            // When work_mem can't be read, fall back to the schema-derived
457            // effective batch size as the cap target (controller clamps it).
458            if work_mem_bytes.is_none() {
459                let effective = tuning.effective_batch_size(Some(&s));
460                ctl.apply_memory_cap(effective.max(requested));
461                cap_applied = true;
462            }
463            schema = Some(s);
464            columns_cache = Some(stmt_cols);
465        }
466
467        let row_count = rows.len();
468        total_rows += row_count;
469
470        let s = schema.as_ref().expect("schema set on first iteration");
471        let cols = columns_cache
472            .as_ref()
473            .expect("columns set on first iteration");
474        let batch = rows_to_record_batch_typed(s, cols, &rows, max_value_bytes)?;
475        drop(rows);
476
477        // After the first (probe) batch we know the actual row width. Cap the
478        // FETCH N below `work_mem × 0.7` so the cursor never spills:
479        //   pg_row_bytes ≈ arrow_per_row × 1.2 ; safe = work_mem×0.7 / pg_row_bytes
480        // The controller clamps it to the configured `batch_size`.
481        if !cap_applied
482            && let Some(wm) = work_mem_bytes
483            && row_count > 0
484        {
485            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
486            let arrow_per_row = (arrow_bytes / row_count).max(1);
487            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
488            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
489            let mut target = safe;
490            if let Some(mem_mb) = tuning.batch_size_memory_mb {
491                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
492                target = target.min(arrow_target.max(100));
493            }
494            if let Some(new) = ctl.apply_memory_cap(target) {
495                log::info!(
496                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N → {} (configured={})",
497                    wm,
498                    arrow_per_row,
499                    pg_per_row,
500                    new,
501                    configured_batch_size,
502                );
503            }
504            cap_applied = true;
505        }
506
507        sink.on_batch(&batch)?;
508
509        if let Some((new, under_pressure)) =
510            ctl.after_batch(|| pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64))
511        {
512            log::info!(
513                "adaptive batch size → {} ({})",
514                new,
515                if under_pressure {
516                    "pressure"
517                } else {
518                    "recovery"
519                }
520            );
521        }
522
523        log::info!("fetched {} rows so far...", total_rows);
524
525        if row_count < requested {
526            break;
527        }
528        ctl.throttle();
529    }
530
531    // Explicit CLOSE is technically redundant — COMMIT releases the cursor —
532    // but it documents intent and surfaces any close errors before COMMIT.
533    guard.client_mut().batch_execute("CLOSE _rivet")?;
534    guard.commit()?;
535    Ok((total_rows, schema.is_some()))
536}
537
538impl super::Source for PostgresSource {
539    fn export(
540        &mut self,
541        request: &super::ExportRequest<'_>,
542        sink: &mut dyn super::BatchSink,
543    ) -> Result<()> {
544        let built = build_export_query(request, SourceType::Postgres);
545        debug_assert!(
546            built.cursor_param.is_none(),
547            "Postgres path inlines cursor values as E'…' literals — binding is unused"
548        );
549        log::debug!(
550            "executing query (connection={}): {}",
551            if self.transaction_pooler {
552                "transaction-pooler"
553            } else {
554                "direct"
555            },
556            built.sql
557        );
558
559        // Resolve NUMERIC precision from the *unwrapped* base query when the
560        // caller wrapped `query` in a chunk/keyset subquery (which hides the
561        // source table from the catalog parser). Falls back to `query`.
562        let hint_query = request.catalog_hint_query.unwrap_or(request.query);
563        let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, hint_query);
564
565        // PgTxnGuard inside pg_run_export rolls the txn back automatically on
566        // any error or panic, so no explicit ROLLBACK is needed here.
567        let (total_rows, had_schema) = pg_run_export(
568            &mut self.client,
569            &built.sql,
570            request.tuning,
571            request.column_overrides,
572            sink,
573            numeric_hints.as_ref(),
574        )?;
575
576        if !had_schema {
577            sink.on_schema(Arc::new(Schema::empty()))?;
578        }
579
580        log::info!("total: {} rows", total_rows);
581        Ok(())
582    }
583
584    fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
585        let rows = self.client.query(sql, &[])?;
586        if rows.is_empty() {
587            return Ok(None);
588        }
589        let row = &rows[0];
590        if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
591            return Ok(Some(v.to_string()));
592        }
593        if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
594            return Ok(Some(v.to_string()));
595        }
596        if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
597            return Ok(Some(v.to_string()));
598        }
599        // TIMESTAMP / DATE / TIMESTAMPTZ — required for MIN/MAX on time columns (e.g. chunk_by_days)
600        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
601            return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
602        }
603        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
604            return Ok(Some(v.format("%Y-%m-%d").to_string()));
605        }
606        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
607            return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
608        }
609        if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
610            return Ok(Some(v));
611        }
612        Ok(None)
613    }
614
615    fn type_mappings(
616        &mut self,
617        query: &str,
618        column_overrides: &ColumnOverrides,
619    ) -> Result<Vec<TypeMapping>> {
620        let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
621        let stmt = self.client.prepare(&wrapped)?;
622        let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
623        let mappings = stmt
624            .columns()
625            .iter()
626            .map(|col| {
627                let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
628                let source = SourceColumn::simple(col.name(), col.type_().name(), true);
629                TypeMapping::from_source(&source, rivet)
630            })
631            .collect();
632        Ok(mappings)
633    }
634
635    /// Governor pressure proxy: `pg_stat_bgwriter.checkpoints_req` — the same
636    /// monotonic counter the adaptive batch loop samples. Rising between samples
637    /// means the source is checkpointing harder under write pressure.
638    fn sample_pressure(&mut self) -> Option<u64> {
639        pg_sample_checkpoints_req(&mut self.client).map(|v| v.max(0) as u64)
640    }
641}
642
643/// When the query is a single-table `SELECT … FROM rel` (no joins, no subquery
644/// in `FROM`), PostgreSQL result metadata does not carry `NUMERIC` typmod, but
645/// `information_schema` / the table DDL does. We resolve the base relation with
646/// a small parser and fetch declared precision/scale so `rivet init`-style
647/// exports work without hand-written `columns:` overrides.
648fn pg_numeric_catalog_hints_opt(
649    client: &mut Client,
650    query: &str,
651) -> Option<HashMap<String, (u8, i8)>> {
652    match pg_fetch_numeric_catalog_hints(client, query) {
653        Ok(m) => m,
654        Err(e) => {
655            // Reaching this arm means the parser identified a single-table query
656            // and we tried catalog lookup, but the lookup itself failed. That is
657            // unexpected (not "this query has a JOIN"), so surface it — otherwise
658            // a downstream NUMERIC mapping failure looks like a config problem
659            // when the real cause is here.
660            log::warn!(
661                "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
662            );
663            None
664        }
665    }
666}
667
668fn pg_fetch_numeric_catalog_hints(
669    client: &mut Client,
670    query: &str,
671) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
672    let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
673        return Ok(None);
674    };
675    let locate_sql = "SELECT n.nspname::text, c.relname::text \
676         FROM pg_catalog.pg_class c \
677         JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
678         WHERE c.oid = ($1::text)::regclass";
679    let row_opt = match client.query_opt(locate_sql, &[&regclass_lit]) {
680        Ok(r) => r,
681        Err(e) => {
682            log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
683            return Ok(None);
684        }
685    };
686    let Some(row) = row_opt else {
687        return Ok(None);
688    };
689    let schema: String = row.get(0);
690    let table: String = row.get(1);
691    let rows = client.query(
692        "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
693             FROM information_schema.columns \
694             WHERE table_schema = $1 AND table_name = $2 \
695             ORDER BY ordinal_position",
696        &[&schema, &table],
697    )?;
698
699    let mut map = HashMap::new();
700    for row in rows {
701        let col: String = row.get(0);
702        let dt: String = row.get(1);
703        if !is_pg_numeric_information_type(&dt) {
704            continue;
705        }
706        let p: Option<i32> = row.get(2);
707        let s: Option<i32> = row.get(3);
708        if let (Some(p), Some(s)) = (p, s)
709            && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
710        {
711            map.insert(col, pair);
712        }
713    }
714
715    if map.is_empty() {
716        Ok(None)
717    } else {
718        log::debug!(
719            "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
720            map.len(),
721        );
722        Ok(Some(map))
723    }
724}
725
/// Whether an `information_schema.columns.data_type` value names a
/// NUMERIC/DECIMAL column.
///
/// Matching ignores surrounding whitespace and ASCII case. Accepted forms are
/// the bare names `numeric` / `decimal` and anything starting with
/// `numeric(` / `decimal(`.
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let normalized = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"].iter().any(|base| {
        match normalized.strip_prefix(base) {
            // Either the bare type name, or the name followed by its modifier.
            Some(rest) => rest.is_empty() || rest.starts_with('('),
            None => false,
        }
    })
}
732
/// Validate catalog precision/scale and narrow them to decimal parameters.
///
/// Uses the same bounds as Rivet YAML `decimal(p,s)` overrides / Arrow:
/// precision must be in `1..=76`, scale must fit in an `i8`, and scale may
/// not exceed precision. Returns `None` when any of those checks fails.
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    if !(1..=76).contains(&precision) {
        return None;
    }
    let narrow_precision = u8::try_from(precision).ok()?;
    let narrow_scale = i8::try_from(scale).ok()?;
    // Compare in the wide type; precision is at most 76 here.
    if i32::from(narrow_scale) > precision {
        return None;
    }
    Some((narrow_precision, narrow_scale))
}
748
#[cfg(test)]
mod tests {
    use super::{catalog_numeric_to_decimal_params, parse_work_mem};

    // FROM-clause parser tests live in `from_parse.rs` alongside the parser.

    /// Precision must be 1..=76 and scale may not exceed precision.
    #[test]
    fn catalog_decimal_bounds() {
        assert_eq!(catalog_numeric_to_decimal_params(18, 2), Some((18, 2)));

        let rejected = [(0, 2), (77, 0), (18, 19)];
        for (precision, scale) in rejected {
            assert!(
                catalog_numeric_to_decimal_params(precision, scale).is_none(),
                "expected ({precision}, {scale}) to be rejected"
            );
        }
    }

    /// Postgres `SHOW work_mem` normally returns "<N>kB", "<N>MB", "<N>GB".
    /// A bare integer is interpreted as kB (matches postgresql.conf parsing).
    #[test]
    fn parse_work_mem_handles_pg_units() {
        const MIB: i64 = 1024 * 1024;
        let cases: [(&str, Option<i64>); 9] = [
            ("4MB", Some(4 * MIB)),
            ("16384kB", Some(16384 * 1024)),
            ("1GB", Some(1024 * MIB)),
            ("  4MB  ", Some(4 * MIB)),
            ("4mb", Some(4 * MIB)),
            ("65536", Some(65536 * 1024)),
            ("", None),
            ("garbage", None),
            // We don't accept seconds / units PG would never emit for work_mem.
            ("4s", None),
        ];
        for (raw, expected) in cases {
            assert_eq!(parse_work_mem(raw), expected, "input {raw:?}");
        }
    }
}
779}