Skip to main content

rivet/source/postgres/
mod.rs

1//! PostgreSQL `Source` implementation.
2//!
3//! Module layout:
4//!
5//! - `mod.rs` (this file) — `PostgresSource` struct + connect/TLS path, the
6//!   transaction-pooler detector, `PgTxnGuard`, sampling helpers
7//!   (`sample_temp_bytes`, `pg_sample_checkpoints_req`, `pg_fetch_work_mem_bytes`),
8//!   `introspect_pg_table_for_chunking`, the cursor + FETCH export loop
9//!   (`pg_run_export`), the `Source` trait impl, and the catalog-hint
10//!   resolver that bridges parsed FROM clauses to `pg_catalog`.
11//! - [`arrow_convert`] — the entire row → Arrow `RecordBatch` pipeline: type
12//!   mapping (`pg_columns_to_schema`, `rivet_type_for_pg_column`), per-cell
13//!   decoders (INTERVAL, UUID, enum, NUMERIC), and the array builders. Kept
14//!   in a sibling because it is the largest single-purpose cluster in this
15//!   driver (~620 LoC) and has zero reverse dependency back into the
16//!   connection / cursor layer.
17//! - [`from_parse`] — pure `&str`/`&[u8]` parser that extracts the simple
18//!   `<schema>.<table>` literal from a user query so the catalog-hint path
19//!   can cast it to `regclass`.  Zero postgres-crate dependency, fully
20//!   unit-tested in isolation.
21
22mod arrow_convert;
23mod from_parse;
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use postgres::types::Type;
30use postgres::{Client, NoTls};
31
32use crate::config::{SourceType, TlsConfig};
33use crate::error::Result;
34use crate::source::query::build_incremental_query;
35use crate::source::tls::build_native_tls;
36use crate::tuning::{ADAPTIVE_SAMPLE_INTERVAL, SourceTuning, next_adaptive_batch_size};
37use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
38
39use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
40use from_parse::try_parse_pg_simple_from_regclass_literal;
41
42pub struct PostgresSource {
43    client: Client,
44    /// True when two consecutive pg_backend_pid() calls returned different values,
45    /// indicating a transaction-mode connection pooler (pgBouncer, Odyssey, etc.).
46    transaction_pooler: bool,
47}
48
49/// Detect whether the connection is going through a transaction-mode pooler
50/// (pgBouncer, Odyssey, etc.) by comparing backend PIDs across two implicit
51/// transactions. Returns true when PIDs differ — impossible on a direct
52/// connection or session-mode pooler where the same physical backend is kept.
53///
54/// False negatives are possible when pool_size = 1 (the same backend is always
55/// reused), so this is a best-effort warning rather than a hard guarantee.
56fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
57    let pid1: Option<i32> = client
58        .query_one("SELECT pg_backend_pid()", &[])
59        .ok()
60        .and_then(|r| r.try_get(0).ok());
61    let pid2: Option<i32> = client
62        .query_one("SELECT pg_backend_pid()", &[])
63        .ok()
64        .and_then(|r| r.try_get(0).ok());
65    matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
66}
67
68impl PostgresSource {
69    /// Connect with no transport security (legacy path). Prefer [`Self::connect_with_tls`]
70    /// for production workloads so credentials and result sets are not visible on the wire.
71    pub fn connect(url: &str) -> Result<Self> {
72        let mut client = Client::connect(url, NoTls)?;
73        let transaction_pooler = detect_pg_transaction_pooler(&mut client);
74        if transaction_pooler {
75            log::warn!(
76                "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
77                 SET LOCAL tuning is transaction-scoped; \
78                 LISTEN/NOTIFY and advisory locks are unavailable"
79            );
80        }
81        Ok(Self {
82            client,
83            transaction_pooler,
84        })
85    }
86
87    /// Connect honoring the user's [`TlsConfig`]. When `tls.mode` is
88    /// [`TlsMode::Disable`] this falls back to [`Self::connect`].
89    pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
90        match tls {
91            Some(cfg) if cfg.mode.is_enforced() => {
92                let connector = build_native_tls(cfg)?;
93                let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
94                let mut client = Client::connect(url, make_tls)?;
95                let transaction_pooler = detect_pg_transaction_pooler(&mut client);
96                if transaction_pooler {
97                    log::warn!(
98                        "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
99                         SET LOCAL tuning is transaction-scoped; \
100                         LISTEN/NOTIFY and advisory locks are unavailable"
101                    );
102                }
103                Ok(Self {
104                    client,
105                    transaction_pooler,
106                })
107            }
108            _ => Self::connect(url),
109        }
110    }
111}
112
113/// RAII guard for an open `BEGIN ... COMMIT` block.
114///
115/// `commit()` runs `COMMIT` and marks the txn done; if the guard is dropped
116/// before `commit()` (early return, `?`-bubbled error, or panic-driven unwind),
117/// `Drop` issues a best-effort `ROLLBACK`. Postgres releases any open cursors
118/// as part of ROLLBACK, so the cursor declared inside the txn is also cleaned
119/// up. Closes the **G1** gap from the DBA audit (cursor leak on panic).
120struct PgTxnGuard<'a> {
121    client: &'a mut Client,
122    committed: bool,
123}
124
125impl<'a> PgTxnGuard<'a> {
126    fn begin(client: &'a mut Client) -> Result<Self> {
127        client.batch_execute("BEGIN")?;
128        Ok(Self {
129            client,
130            committed: false,
131        })
132    }
133
134    fn client_mut(&mut self) -> &mut Client {
135        self.client
136    }
137
138    fn commit(mut self) -> Result<()> {
139        self.client.batch_execute("COMMIT")?;
140        self.committed = true;
141        Ok(())
142    }
143}
144
145impl Drop for PgTxnGuard<'_> {
146    fn drop(&mut self) {
147        if !self.committed
148            && let Err(e) = self.client.batch_execute("ROLLBACK")
149        {
150            // Drop must not panic. Worst case the connection is poisoned and
151            // the pool recycles it; log so operators see it.
152            log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
153        }
154    }
155}
156
157/// Snapshot `pg_stat_database.temp_bytes` for the current database.
158///
159/// Used by the pipeline job to compute per-run cursor / sort spill: we capture
160/// the cluster-wide counter immediately before and after each export and
161/// surface the delta on the run summary card. Failures (connect, query) return
162/// `None` — the metric is informational, not a correctness signal.
163///
164/// Note this is a cluster-level counter: concurrent activity from other
165/// connections during the run inflates the delta. For a single-tenant test
166/// box (the common pilot setup) it is accurate; for shared hosts it is a
167/// noisy upper bound, useful as a "your workload was loud" signal.
168pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
169    let mut client = connect_client(url, tls).ok()?;
170    client
171        .query_one(
172            "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
173            &[],
174        )
175        .ok()
176        .and_then(|r| r.try_get::<_, i64>(0).ok())
177}
178
179/// Probe `SHOW work_mem` and return the value in bytes.
180///
181/// PostgreSQL spills FETCH-cursor output to `pgsql_tmp/` once the in-flight
182/// row set exceeds `work_mem` — on wide rows with the default 4 MB the spill
183/// fires on every chunk and dominates `pg_stat_database.temp_bytes`. Knowing
184/// the value lets the cursor loop cap FETCH N below `work_mem × 0.7`, keeping
185/// the result set in memory.
186///
187/// Returns None on any parse / query failure — the cursor loop falls back to
188/// the configured static batch_size in that case.
189fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
190    let raw: Option<String> = client
191        .query_one("SHOW work_mem", &[])
192        .ok()
193        .and_then(|r| r.try_get::<_, String>(0).ok());
194    raw.as_deref().and_then(parse_work_mem)
195}
196
/// Convert a `SHOW work_mem` string into a byte count.
///
/// Accepted shapes are a number followed by an optional, case-insensitive
/// unit: `"4MB"`, `"16384kB"`, `"1GB"`, `"1TB"`, or a bare number, which is
/// read as kB. Surrounding whitespace and whitespace between number and unit
/// are ignored. Anything else — unknown unit, unparsable number, or a result
/// that is not strictly positive — gives `None` so the caller can fall back.
fn parse_work_mem(raw: &str) -> Option<i64> {
    let s = raw.trim();

    // The numeric prefix is the longest leading run of digits, '.' and '-'.
    let is_numeric_char = |c: char| c.is_ascii_digit() || c == '.' || c == '-';
    let split = s.find(|c: char| !is_numeric_char(c)).unwrap_or(s.len());
    if split == 0 {
        return None;
    }

    let (digits, suffix) = s.split_at(split);
    let value: f64 = digits.parse().ok()?;

    // Postgres memory units are 1024-based, matching postgresql.conf syntax;
    // express each one as a power-of-two shift so the factor is exact.
    let suffix = suffix.trim();
    let shift: u32 = if suffix.is_empty() || suffix.eq_ignore_ascii_case("kb") {
        10
    } else if suffix.eq_ignore_ascii_case("mb") {
        20
    } else if suffix.eq_ignore_ascii_case("gb") {
        30
    } else if suffix.eq_ignore_ascii_case("tb") {
        40
    } else {
        return None;
    };
    let factor = (1u64 << shift) as f64;

    let bytes = (value * factor) as i64;
    if bytes > 0 { Some(bytes) } else { None }
}
229
230/// Sample `checkpoints_req` from `pg_stat_bgwriter`.
231///
232/// PostgreSQL caches the statistics snapshot at the start of each transaction.
233/// We call `pg_stat_clear_snapshot()` first to discard that cache so every
234/// adaptive sample sees fresh counters rather than the frozen value from BEGIN.
235fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
236    let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
237    client
238        .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
239        .ok()
240        .and_then(|r| r.try_get::<_, i64>(0).ok())
241}
242
243/// Probe `pg_class` and `pg_index` for the stats chunked-mode planning needs.
244///
245/// Returns a [`crate::source::TableIntrospection`] populated from one connection
246/// (two round-trips total: one stats query, one PK query). Failure to connect
247/// or to query bubbles up as `Err`; missing rows or unanalyzed tables are
248/// represented as zero/None in the result so callers can decide policy.
249///
250/// The `qualified_table` argument is `<schema>.<table>` (e.g. `public.users`)
251/// or bare `<table>` (resolved under `public`). It is split internally with
252/// the same strict rules as the `table:` YAML shortcut — anything more
253/// elaborate must use the explicit-column path.
254pub(crate) fn introspect_pg_table_for_chunking(
255    url: &str,
256    tls: Option<&TlsConfig>,
257    qualified_table: &str,
258) -> Result<crate::source::TableIntrospection> {
259    let (schema, table) = match qualified_table.split_once('.') {
260        Some((s, t)) => (s.to_string(), t.to_string()),
261        None => ("public".to_string(), qualified_table.to_string()),
262    };
263    let mut client = connect_client(url, tls)?;
264
265    // ── reltuples + heap size, in one shot ──────────────────────────────
266    let (row_estimate, rel_size_bytes) = match client.query_opt(
267        "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
268         FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
269         WHERE n.nspname = $1::text AND c.relname = $2::text",
270        &[&schema, &table],
271    )? {
272        Some(row) => {
273            let rt: i64 = row.try_get(0).unwrap_or(0);
274            let sz: i64 = row.try_get(1).unwrap_or(0);
275            (rt.max(0), sz.max(0))
276        }
277        None => (0, 0),
278    };
279    let avg_row_bytes = if row_estimate > 0 {
280        Some(rel_size_bytes / row_estimate)
281    } else {
282        None
283    };
284
285    // ── single int PK probe ─────────────────────────────────────────────
286    let pk_rows = client.query(
287        "SELECT a.attname::text, t.typname::text \
288         FROM pg_index i \
289         JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
290         JOIN pg_type t ON t.oid = a.atttypid \
291         WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
292           AND i.indisprimary",
293        &[&schema, &table],
294    )?;
295    let single_int_pk = if pk_rows.len() == 1 {
296        let col: String = pk_rows[0].get(0);
297        let pg_type: String = pk_rows[0].get(1);
298        // Only integer-family types are safe for range chunking via min/max →
299        // BETWEEN slicing. Text/UUID/decimal would need different splitting
300        // logic and are excluded from auto-resolution.
301        if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
302            Some(col)
303        } else {
304            log::debug!(
305                "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
306            );
307            None
308        }
309    } else {
310        None
311    };
312
313    Ok(crate::source::TableIntrospection {
314        single_int_pk,
315        row_estimate,
316        avg_row_bytes,
317    })
318}
319
320/// Open a bare `postgres::Client` honoring the configured TLS policy.
321///
322/// Shared by preflight, doctor, and init so every code path that connects to
323/// Postgres applies the same transport-security rules. `tls = None` or
324/// `mode: disable` falls back to the insecure `NoTls` transport — a warning is
325/// logged from `create_source` so operators know TLS is off.
326pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
327    match tls {
328        Some(cfg) if cfg.mode.is_enforced() => {
329            let connector = build_native_tls(cfg)?;
330            let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
331            Ok(Client::connect(url, make_tls)?)
332        }
333        _ => Ok(Client::connect(url, NoTls)?),
334    }
335}
336
337/// Run the full export transaction against an open Postgres client.
338///
339/// All session-mutating SET commands use SET LOCAL so they are scoped to
340/// the transaction and reset automatically on COMMIT or ROLLBACK. The caller
341/// is responsible for issuing ROLLBACK if this function returns Err.
342///
343/// Returns (total_rows, had_schema). had_schema is false only when the query
344/// returned zero rows; the caller must emit an empty schema in that case.
345fn pg_run_export(
346    client: &mut Client,
347    built_sql: &str,
348    tuning: &SourceTuning,
349    column_overrides: &ColumnOverrides,
350    sink: &mut dyn super::BatchSink,
351    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
352) -> Result<(usize, bool)> {
353    // Open the txn under guard *first* — if SET LOCAL or DECLARE fails below,
354    // Drop will roll back. Without the guard, a failure between BEGIN and the
355    // explicit ROLLBACK in the caller would leak a half-set-up txn into the pool.
356    let mut guard = PgTxnGuard::begin(client)?;
357    if tuning.statement_timeout_s > 0 {
358        guard.client_mut().batch_execute(&format!(
359            "SET LOCAL statement_timeout = '{}s'",
360            tuning.statement_timeout_s
361        ))?;
362    }
363    if tuning.lock_timeout_s > 0 {
364        guard.client_mut().batch_execute(&format!(
365            "SET LOCAL lock_timeout = '{}s'",
366            tuning.lock_timeout_s
367        ))?;
368    }
369    // Cap FETCH N under `work_mem × 0.7` so the cursor never spills to
370    // `pgsql_tmp/`. Without this, a wide-row chunk with the default
371    // `batch_size: 50000` × ~4 KB/row = ~200 MB easily exceeds the typical
372    // `work_mem: 4 MB` and writes the entire chunk to disk before the first
373    // FETCH returns. Measured cost on the content_items bench: ~3.2 GB of
374    // temp_bytes per export, dominating the DB-side signal report.
375    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());
376
377    guard
378        .client_mut()
379        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;
380
381    // First FETCH is intentionally small — it acts as a row-width probe.
382    // Without it we can't know `arrow_bytes/row` before the cursor runs, and a
383    // single FETCH of `tuning.batch_size` × wide rows already triggers spill.
384    // 500 wide rows × 4 KB ≈ 2 MB, well under the typical work_mem of 4 MB.
385    const PROBE_FETCH_SIZE: usize = 500;
386    let configured_batch_size = tuning.batch_size;
387    let mut fetch_size = configured_batch_size.min(PROBE_FETCH_SIZE);
388    let mut fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
389    let mut work_mem_cap_applied = false;
390    let mut schema: Option<SchemaRef> = None;
391    let mut columns_cache: Option<Vec<(String, Type)>> = None;
392    let mut total_rows: usize = 0;
393    let mut base_fetch_size = fetch_size;
394    let mut adaptive_last_ckpt: Option<i64> = if tuning.adaptive {
395        pg_sample_checkpoints_req(guard.client_mut())
396    } else {
397        None
398    };
399    let mut batch_count: usize = 0;
400
401    loop {
402        // `fetch_size` may be resized mid-iteration (probe → work_mem cap
403        // adjustment). The end-of-loop "exhausted" check must compare against
404        // the size we actually requested in this round, not the resized one.
405        let requested_this_iter = fetch_size;
406        let rows = guard.client_mut().query(&fetch_sql, &[])?;
407        if rows.is_empty() {
408            break;
409        }
410
411        if schema.is_none() {
412            let stmt_cols: Vec<(String, Type)> = rows[0]
413                .columns()
414                .iter()
415                .map(|c| (c.name().to_string(), c.type_().clone()))
416                .collect();
417            let s = Arc::new(pg_columns_to_schema(
418                rows[0].columns(),
419                column_overrides,
420                numeric_hints,
421            )?);
422            sink.on_schema(s.clone())?;
423            schema = Some(s.clone());
424            columns_cache = Some(stmt_cols);
425
426            // Defer the work_mem cap to after we measure the actual batch
427            // bytes below — schema-only estimates under-count wide TEXT.
428            let effective = tuning.effective_batch_size(Some(&s));
429            if effective != fetch_size && work_mem_bytes.is_none() {
430                fetch_size = effective.max(fetch_size);
431                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
432            }
433            base_fetch_size = fetch_size;
434        }
435
436        let row_count = rows.len();
437        total_rows += row_count;
438
439        let s = schema.as_ref().expect("schema set on first iteration");
440        let cols = columns_cache
441            .as_ref()
442            .expect("columns set on first iteration");
443        let batch = rows_to_record_batch_typed(s, cols, &rows)?;
444        drop(rows);
445
446        // After the first (small probe) batch we know the actual row width.
447        // Compute the cursor-spill-safe FETCH size from observed Arrow bytes:
448        //   pg_row_bytes ≈ arrow_per_row × 1.2  (small fudge for tuple header)
449        //   safe = work_mem × 0.7 / pg_row_bytes
450        // Subsequent FETCHes use that, clamped never to exceed the user's
451        // configured `batch_size` (so an explicit knob still wins).
452        if !work_mem_cap_applied
453            && let Some(wm) = work_mem_bytes
454            && row_count > 0
455        {
456            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
457            let arrow_per_row = (arrow_bytes / row_count).max(1);
458            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
459            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
460            let mut target = safe.min(configured_batch_size);
461            if let Some(mem_mb) = tuning.batch_size_memory_mb {
462                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
463                target = target.min(arrow_target.max(100));
464            }
465            if target != fetch_size {
466                log::info!(
467                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N {} → {} (configured={})",
468                    wm,
469                    arrow_per_row,
470                    pg_per_row,
471                    fetch_size,
472                    target,
473                    configured_batch_size,
474                );
475                fetch_size = target;
476                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
477                base_fetch_size = fetch_size;
478            }
479            work_mem_cap_applied = true;
480        }
481
482        sink.on_batch(&batch)?;
483
484        batch_count += 1;
485        if tuning.adaptive
486            && batch_count.is_multiple_of(ADAPTIVE_SAMPLE_INTERVAL)
487            && let Some(cur) = pg_sample_checkpoints_req(guard.client_mut())
488        {
489            let under_pressure = adaptive_last_ckpt.is_some_and(|prev| cur > prev);
490            adaptive_last_ckpt = Some(cur);
491            let next = next_adaptive_batch_size(fetch_size, base_fetch_size, under_pressure);
492            if next != fetch_size {
493                fetch_size = next;
494                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
495                log::info!(
496                    "adaptive batch size → {} ({})",
497                    fetch_size,
498                    if under_pressure {
499                        "pressure"
500                    } else {
501                        "recovery"
502                    }
503                );
504            }
505        }
506
507        log::info!("fetched {} rows so far...", total_rows);
508
509        if row_count < requested_this_iter {
510            break;
511        }
512
513        if tuning.throttle_ms > 0 {
514            std::thread::sleep(std::time::Duration::from_millis(tuning.throttle_ms));
515        }
516    }
517
518    // Explicit CLOSE is technically redundant — COMMIT releases the cursor —
519    // but it documents intent and surfaces any close errors before COMMIT.
520    guard.client_mut().batch_execute("CLOSE _rivet")?;
521    guard.commit()?;
522    Ok((total_rows, schema.is_some()))
523}
524
525impl super::Source for PostgresSource {
526    fn export(
527        &mut self,
528        request: &super::ExportRequest<'_>,
529        sink: &mut dyn super::BatchSink,
530    ) -> Result<()> {
531        let built = build_incremental_query(
532            request.query,
533            request.incremental,
534            request.cursor,
535            SourceType::Postgres,
536        );
537        debug_assert!(
538            built.cursor_param.is_none(),
539            "Postgres path inlines cursor values as E'…' literals — binding is unused"
540        );
541        log::debug!(
542            "executing query (connection={}): {}",
543            if self.transaction_pooler {
544                "transaction-pooler"
545            } else {
546                "direct"
547            },
548            built.sql
549        );
550
551        let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, request.query);
552
553        // PgTxnGuard inside pg_run_export rolls the txn back automatically on
554        // any error or panic, so no explicit ROLLBACK is needed here.
555        let (total_rows, had_schema) = pg_run_export(
556            &mut self.client,
557            &built.sql,
558            request.tuning,
559            request.column_overrides,
560            sink,
561            numeric_hints.as_ref(),
562        )?;
563
564        if !had_schema {
565            sink.on_schema(Arc::new(Schema::empty()))?;
566        }
567
568        log::info!("total: {} rows", total_rows);
569        Ok(())
570    }
571
572    fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
573        let rows = self.client.query(sql, &[])?;
574        if rows.is_empty() {
575            return Ok(None);
576        }
577        let row = &rows[0];
578        if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
579            return Ok(Some(v.to_string()));
580        }
581        if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
582            return Ok(Some(v.to_string()));
583        }
584        if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
585            return Ok(Some(v.to_string()));
586        }
587        // TIMESTAMP / DATE / TIMESTAMPTZ — required for MIN/MAX on time columns (e.g. chunk_by_days)
588        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
589            return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
590        }
591        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
592            return Ok(Some(v.format("%Y-%m-%d").to_string()));
593        }
594        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
595            return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
596        }
597        if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
598            return Ok(Some(v));
599        }
600        Ok(None)
601    }
602
603    fn type_mappings(
604        &mut self,
605        query: &str,
606        column_overrides: &ColumnOverrides,
607    ) -> Result<Vec<TypeMapping>> {
608        let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
609        let stmt = self.client.prepare(&wrapped)?;
610        let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
611        let mappings = stmt
612            .columns()
613            .iter()
614            .map(|col| {
615                let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
616                let source = SourceColumn::simple(col.name(), col.type_().name(), true);
617                TypeMapping::from_source(&source, rivet)
618            })
619            .collect();
620        Ok(mappings)
621    }
622}
623
624/// When the query is a single-table `SELECT … FROM rel` (no joins, no subquery
625/// in `FROM`), PostgreSQL result metadata does not carry `NUMERIC` typmod, but
626/// `information_schema` / the table DDL does. We resolve the base relation with
627/// a small parser and fetch declared precision/scale so `rivet init`-style
628/// exports work without hand-written `columns:` overrides.
629fn pg_numeric_catalog_hints_opt(
630    client: &mut Client,
631    query: &str,
632) -> Option<HashMap<String, (u8, i8)>> {
633    match pg_fetch_numeric_catalog_hints(client, query) {
634        Ok(m) => m,
635        Err(e) => {
636            // Reaching this arm means the parser identified a single-table query
637            // and we tried catalog lookup, but the lookup itself failed. That is
638            // unexpected (not "this query has a JOIN"), so surface it — otherwise
639            // a downstream NUMERIC mapping failure looks like a config problem
640            // when the real cause is here.
641            log::warn!(
642                "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
643            );
644            None
645        }
646    }
647}
648
649fn pg_fetch_numeric_catalog_hints(
650    client: &mut Client,
651    query: &str,
652) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
653    let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
654        return Ok(None);
655    };
656    let locate_sql = "SELECT n.nspname::text, c.relname::text \
657         FROM pg_catalog.pg_class c \
658         JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
659         WHERE c.oid = ($1::text)::regclass";
660    let row_opt = match client.query_opt(locate_sql, &[&regclass_lit]) {
661        Ok(r) => r,
662        Err(e) => {
663            log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
664            return Ok(None);
665        }
666    };
667    let Some(row) = row_opt else {
668        return Ok(None);
669    };
670    let schema: String = row.get(0);
671    let table: String = row.get(1);
672    let rows = client.query(
673        "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
674             FROM information_schema.columns \
675             WHERE table_schema = $1 AND table_name = $2 \
676             ORDER BY ordinal_position",
677        &[&schema, &table],
678    )?;
679
680    let mut map = HashMap::new();
681    for row in rows {
682        let col: String = row.get(0);
683        let dt: String = row.get(1);
684        if !is_pg_numeric_information_type(&dt) {
685            continue;
686        }
687        let p: Option<i32> = row.get(2);
688        let s: Option<i32> = row.get(3);
689        if let (Some(p), Some(s)) = (p, s)
690            && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
691        {
692            map.insert(col, pair);
693        }
694    }
695
696    if map.is_empty() {
697        Ok(None)
698    } else {
699        log::debug!(
700            "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
701            map.len(),
702        );
703        Ok(Some(map))
704    }
705}
706
/// Whether an `information_schema` `data_type` string names a NUMERIC/DECIMAL
/// column.
///
/// Matching ignores surrounding whitespace and ASCII case, and accepts either
/// the bare name (`numeric`, `decimal`) or the name immediately followed by a
/// parenthesised typmod (`numeric(10,2)`).
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let normalized = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"].iter().any(|base| {
        normalized
            .strip_prefix(base)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('('))
    })
}
713
/// Validate catalog precision/scale against the limits Rivet's YAML
/// `decimal(p,s)` overrides use (same bound as overrides / Arrow).
///
/// Returns `Some((precision, scale))` when `1 <= precision <= 76`, `scale`
/// fits in an `i8`, and `scale <= precision`; otherwise `None`.
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    let precision_u = u8::try_from(precision)
        .ok()
        .filter(|p| (1..=76).contains(p))?;
    let scale_i = i8::try_from(scale).ok()?;
    if i32::from(scale_i) > precision {
        return None;
    }
    Some((precision_u, scale_i))
}
729
#[cfg(test)]
mod tests {
    use super::{catalog_numeric_to_decimal_params, parse_work_mem};

    // FROM-clause parser tests live in `from_parse.rs` alongside the parser.

    /// One accepted pair plus the three rejection reasons: zero precision,
    /// precision above the limit, and scale larger than precision.
    #[test]
    fn catalog_decimal_bounds() {
        assert_eq!(catalog_numeric_to_decimal_params(18, 2), Some((18, 2)));

        let rejected = [(0, 2), (77, 0), (18, 19)];
        for (precision, scale) in rejected {
            assert!(
                catalog_numeric_to_decimal_params(precision, scale).is_none(),
                "expected ({precision}, {scale}) to be rejected"
            );
        }
    }

    /// Postgres `SHOW work_mem` normally returns "<N>kB", "<N>MB", "<N>GB".
    /// A bare integer is interpreted as kB (matches postgresql.conf parsing).
    #[test]
    fn parse_work_mem_handles_pg_units() {
        let cases: [(&str, Option<i64>); 9] = [
            ("4MB", Some(4 * 1024 * 1024)),
            ("16384kB", Some(16384 * 1024)),
            ("1GB", Some(1024 * 1024 * 1024)),
            ("  4MB  ", Some(4 * 1024 * 1024)),
            ("4mb", Some(4 * 1024 * 1024)),
            ("65536", Some(65536 * 1024)),
            ("", None),
            ("garbage", None),
            // We don't accept seconds / units PG would never emit for work_mem.
            ("4s", None),
        ];
        for (raw, expected) in cases {
            assert_eq!(parse_work_mem(raw), expected, "input {raw:?}");
        }
    }
}
760}