Skip to main content

rivet/source/postgres/
mod.rs

1//! PostgreSQL `Source` implementation.
2//!
3//! Module layout:
4//!
5//! - `mod.rs` (this file) — `PostgresSource` struct + connect/TLS path, the
6//!   transaction-pooler detector, `PgTxnGuard`, sampling helpers
7//!   (`sample_temp_bytes`, `pg_sample_checkpoints_req`, `pg_fetch_work_mem_bytes`),
8//!   `introspect_pg_table_for_chunking`, the cursor + FETCH export loop
9//!   (`pg_run_export`), the `Source` trait impl, and the catalog-hint
10//!   resolver that bridges parsed FROM clauses to `pg_catalog`.
11//! - [`arrow_convert`] — the entire row → Arrow `RecordBatch` pipeline: type
12//!   mapping (`pg_columns_to_schema`, `rivet_type_for_pg_column`), per-cell
13//!   decoders (INTERVAL, UUID, enum, NUMERIC), and the array builders. Kept
14//!   in a sibling because it is the largest single-purpose cluster in this
15//!   driver (~620 LoC) and has zero reverse dependency back into the
16//!   connection / cursor layer.
17//! - [`from_parse`] — pure `&str`/`&[u8]` parser that extracts the simple
18//!   `<schema>.<table>` literal from a user query so the catalog-hint path
19//!   can cast it to `regclass`.  Zero postgres-crate dependency, fully
20//!   unit-tested in isolation.
21
22mod arrow_convert;
23mod from_parse;
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use postgres::types::Type;
30use postgres::{Client, NoTls};
31
32use crate::config::{SourceType, TlsConfig};
33use crate::error::Result;
34use crate::source::batch_controller::AdaptiveBatchController;
35use crate::source::query::build_export_query;
36use crate::source::tls::build_native_tls;
37use crate::tuning::SourceTuning;
38use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
39
40use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
41use from_parse::try_parse_pg_simple_from_regclass_literal;
42
43pub struct PostgresSource {
44    client: Client,
45    /// True when two consecutive pg_backend_pid() calls returned different values,
46    /// indicating a transaction-mode connection pooler (pgBouncer, Odyssey, etc.).
47    transaction_pooler: bool,
48}
49
50/// Detect whether the connection is going through a transaction-mode pooler
51/// (pgBouncer, Odyssey, etc.) by comparing backend PIDs across two implicit
52/// transactions. Returns true when PIDs differ — impossible on a direct
53/// connection or session-mode pooler where the same physical backend is kept.
54///
55/// False negatives are possible when pool_size = 1 (the same backend is always
56/// reused), so this is a best-effort warning rather than a hard guarantee.
57fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
58    let pid1: Option<i32> = client
59        .query_one("SELECT pg_backend_pid()", &[])
60        .ok()
61        .and_then(|r| r.try_get(0).ok());
62    let pid2: Option<i32> = client
63        .query_one("SELECT pg_backend_pid()", &[])
64        .ok()
65        .and_then(|r| r.try_get(0).ok());
66    matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
67}
68
69impl PostgresSource {
70    /// Connect with no transport security (legacy path). Prefer [`Self::connect_with_tls`]
71    /// for production workloads so credentials and result sets are not visible on the wire.
72    pub fn connect(url: &str) -> Result<Self> {
73        let mut client = Client::connect(url, NoTls)?;
74        let transaction_pooler = detect_pg_transaction_pooler(&mut client);
75        if transaction_pooler {
76            log::warn!(
77                "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
78                 SET LOCAL tuning is transaction-scoped; \
79                 LISTEN/NOTIFY and advisory locks are unavailable"
80            );
81        }
82        Ok(Self {
83            client,
84            transaction_pooler,
85        })
86    }
87
88    /// Connect honoring the user's [`TlsConfig`]. When `tls.mode` is
89    /// [`TlsMode::Disable`] this falls back to [`Self::connect`].
90    pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
91        // Refuse remote plaintext (no `tls:` block) before any dial (CWE-319).
92        crate::source::require_tls_or_loopback(url, tls)?;
93        match tls {
94            Some(cfg) if cfg.mode.is_enforced() => {
95                let connector = build_native_tls(cfg)?;
96                let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
97                let mut client = Client::connect(url, make_tls)?;
98                let transaction_pooler = detect_pg_transaction_pooler(&mut client);
99                if transaction_pooler {
100                    log::warn!(
101                        "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
102                         SET LOCAL tuning is transaction-scoped; \
103                         LISTEN/NOTIFY and advisory locks are unavailable"
104                    );
105                }
106                Ok(Self {
107                    client,
108                    transaction_pooler,
109                })
110            }
111            _ => Self::connect(url),
112        }
113    }
114}
115
116/// RAII guard for an open `BEGIN ... COMMIT` block.
117///
118/// `commit()` runs `COMMIT` and marks the txn done; if the guard is dropped
119/// before `commit()` (early return, `?`-bubbled error, or panic-driven unwind),
120/// `Drop` issues a best-effort `ROLLBACK`. Postgres releases any open cursors
121/// as part of ROLLBACK, so the cursor declared inside the txn is also cleaned
122/// up. Closes the **G1** gap from the DBA audit (cursor leak on panic).
123struct PgTxnGuard<'a> {
124    client: &'a mut Client,
125    committed: bool,
126}
127
128impl<'a> PgTxnGuard<'a> {
129    fn begin(client: &'a mut Client) -> Result<Self> {
130        client.batch_execute("BEGIN")?;
131        Ok(Self {
132            client,
133            committed: false,
134        })
135    }
136
137    fn client_mut(&mut self) -> &mut Client {
138        self.client
139    }
140
141    fn commit(mut self) -> Result<()> {
142        self.client.batch_execute("COMMIT")?;
143        self.committed = true;
144        Ok(())
145    }
146}
147
148impl Drop for PgTxnGuard<'_> {
149    fn drop(&mut self) {
150        if !self.committed
151            && let Err(e) = self.client.batch_execute("ROLLBACK")
152        {
153            // Drop must not panic. Worst case the connection is poisoned and
154            // the pool recycles it; log so operators see it.
155            log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
156        }
157    }
158}
159
160/// Snapshot `pg_stat_database.temp_bytes` for the current database.
161///
162/// Used by the pipeline job to compute per-run cursor / sort spill: we capture
163/// the cluster-wide counter immediately before and after each export and
164/// surface the delta on the run summary card. Failures (connect, query) return
165/// `None` — the metric is informational, not a correctness signal.
166///
167/// Note this is a cluster-level counter: concurrent activity from other
168/// connections during the run inflates the delta. For a single-tenant test
169/// box (the common pilot setup) it is accurate; for shared hosts it is a
170/// noisy upper bound, useful as a "your workload was loud" signal.
171pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
172    let mut client = connect_client(url, tls).ok()?;
173    client
174        .query_one(
175            "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
176            &[],
177        )
178        .ok()
179        .and_then(|r| r.try_get::<_, i64>(0).ok())
180}
181
182/// Snapshot the broader source-harm counters from `pg_stat_database` for the
183/// current database — a superset of [`sample_temp_bytes`] (which the run summary
184/// tracks on its own). Returns `(metric, cumulative_value)` pairs; the pipeline
185/// captures these before and after the export and stores the per-metric delta in
186/// `export_harm`.
187///
188/// All counters live in `pg_stat_database` and are readable by **any** role — no
189/// `pg_monitor` membership or superuser needed (unlike `pg_stat_activity`'s view
190/// of other sessions). These are cluster-level cumulative counters, so concurrent
191/// activity inflates the delta; on a single-tenant pilot box it is the run's own
192/// footprint. `None` on connect/query failure — informational, never blocks the
193/// export.
194pub(crate) fn sample_harm_counters(
195    url: &str,
196    tls: Option<&TlsConfig>,
197) -> Option<Vec<(String, i64)>> {
198    let mut client = connect_client(url, tls).ok()?;
199    // `tup_returned` (rows the engine had to scan) is the read-amplification
200    // signal; `blks_read`/`blks_hit` the I/O vs cache split; `temp_files` the
201    // spill count; `deadlocks` contention. temp_bytes is intentionally omitted —
202    // it's already on the run summary (export_metrics.pg_temp_bytes_delta).
203    let row = client
204        .query_one(
205            "SELECT blks_read::bigint, blks_hit::bigint, tup_returned::bigint, \
206             tup_fetched::bigint, temp_files::bigint, deadlocks::bigint \
207             FROM pg_stat_database WHERE datname = current_database()",
208            &[],
209        )
210        .ok()?;
211    let names = [
212        "pg_blks_read",
213        "pg_blks_hit",
214        "pg_tup_returned",
215        "pg_tup_fetched",
216        "pg_temp_files",
217        "pg_deadlocks",
218    ];
219    let mut out = Vec::with_capacity(names.len());
220    for (i, name) in names.iter().enumerate() {
221        if let Ok(v) = row.try_get::<_, i64>(i) {
222            out.push(((*name).to_string(), v));
223        }
224    }
225    Some(out)
226}
227
228/// Probe `SHOW work_mem` and return the value in bytes.
229///
230/// PostgreSQL spills FETCH-cursor output to `pgsql_tmp/` once the in-flight
231/// row set exceeds `work_mem` — on wide rows with the default 4 MB the spill
232/// fires on every chunk and dominates `pg_stat_database.temp_bytes`. Knowing
233/// the value lets the cursor loop cap FETCH N below `work_mem × 0.7`, keeping
234/// the result set in memory.
235///
236/// Returns None on any parse / query failure — the cursor loop falls back to
237/// the configured static batch_size in that case.
238fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
239    let raw: Option<String> = client
240        .query_one("SHOW work_mem", &[])
241        .ok()
242        .and_then(|r| r.try_get::<_, String>(0).ok());
243    raw.as_deref().and_then(parse_work_mem)
244}
245
/// Convert a `SHOW work_mem` string such as `"4MB"`, `"16384kB"`, or `"1GB"`
/// into a byte count.
///
/// The input is trimmed, split into a numeric prefix (digits, `.`, `-`) and a
/// unit suffix, and the unit is matched case-insensitively. A bare number is
/// interpreted as kB. Units are 1024-based, matching the syntax Postgres
/// accepts in `postgresql.conf`.
///
/// Returns `None` when there is no numeric prefix, the prefix does not parse,
/// the unit is not one of kB/MB/GB/TB, or the result is not strictly positive.
fn parse_work_mem(raw: &str) -> Option<i64> {
    const KIB: f64 = 1024.0;

    let trimmed = raw.trim();
    // Byte offset of the first character that cannot belong to the number.
    let unit_start = trimmed
        .find(|c: char| !(c.is_ascii_digit() || c == '.' || c == '-'))
        .unwrap_or(trimmed.len());
    if unit_start == 0 {
        return None;
    }

    let (digits, suffix) = trimmed.split_at(unit_start);
    let magnitude = digits.parse::<f64>().ok()?;
    let scale = match suffix.trim().to_ascii_lowercase().as_str() {
        "" | "kb" => KIB,
        "mb" => KIB * KIB,
        "gb" => KIB * KIB * KIB,
        "tb" => KIB * KIB * KIB * KIB,
        _ => return None,
    };

    let bytes = (magnitude * scale) as i64;
    if bytes > 0 {
        Some(bytes)
    } else {
        None
    }
}
278
279/// Sample `checkpoints_req` from `pg_stat_bgwriter`.
280///
281/// PostgreSQL caches the statistics snapshot at the start of each transaction.
282/// We call `pg_stat_clear_snapshot()` first to discard that cache so every
283/// adaptive sample sees fresh counters rather than the frozen value from BEGIN.
284fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
285    let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
286    client
287        .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
288        .ok()
289        .and_then(|r| r.try_get::<_, i64>(0).ok())
290}
291
292/// Probe `pg_class` and `pg_index` for the stats chunked-mode planning needs.
293///
294/// Returns a [`crate::source::TableIntrospection`] populated from one connection
295/// (two round-trips total: one stats query, one PK query). Failure to connect
296/// or to query bubbles up as `Err`; missing rows or unanalyzed tables are
297/// represented as zero/None in the result so callers can decide policy.
298///
299/// The `qualified_table` argument is `<schema>.<table>` (e.g. `public.users`)
300/// or bare `<table>` (resolved under `public`). It is split internally with
301/// the same strict rules as the `table:` YAML shortcut — anything more
302/// elaborate must use the explicit-column path.
303pub(crate) fn introspect_pg_table_for_chunking(
304    url: &str,
305    tls: Option<&TlsConfig>,
306    qualified_table: &str,
307) -> Result<crate::source::TableIntrospection> {
308    let (schema, table) = match qualified_table.split_once('.') {
309        Some((s, t)) => (s.to_string(), t.to_string()),
310        None => ("public".to_string(), qualified_table.to_string()),
311    };
312    let mut client = connect_client(url, tls)?;
313
314    // ── reltuples + heap size, in one shot ──────────────────────────────
315    let (row_estimate, rel_size_bytes) = match client.query_opt(
316        "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
317         FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
318         WHERE n.nspname = $1::text AND c.relname = $2::text",
319        &[&schema, &table],
320    )? {
321        Some(row) => {
322            let rt: i64 = row.try_get(0).unwrap_or(0);
323            let sz: i64 = row.try_get(1).unwrap_or(0);
324            (rt.max(0), sz.max(0))
325        }
326        None => (0, 0),
327    };
328    let avg_row_bytes = if row_estimate > 0 {
329        Some(rel_size_bytes / row_estimate)
330    } else {
331        None
332    };
333
334    // ── single int PK probe ─────────────────────────────────────────────
335    let pk_rows = client.query(
336        "SELECT a.attname::text, t.typname::text \
337         FROM pg_index i \
338         JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
339         JOIN pg_type t ON t.oid = a.atttypid \
340         WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
341           AND i.indisprimary",
342        &[&schema, &table],
343    )?;
344    let single_int_pk = if pk_rows.len() == 1 {
345        let col: String = pk_rows[0].get(0);
346        let pg_type: String = pk_rows[0].get(1);
347        // Only integer-family types are safe for range chunking via min/max →
348        // BETWEEN slicing. Text/UUID/decimal would need different splitting
349        // logic and are excluded from auto-resolution.
350        if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
351            Some(col)
352        } else {
353            log::debug!(
354                "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
355            );
356            None
357        }
358    } else {
359        None
360    };
361
362    // ── keyset keys (OPT-4): single-column, NOT NULL, UNIQUE indexes ────
363    // `indnkeyatts = 1` keeps single-column indexes; `indkey[0] = a.attnum`
364    // binds to a real column (not an expression index); `attnotnull` removes
365    // NULL-ordering ambiguity. Index-backed + unique ⇒ keyset's `ORDER BY key
366    // LIMIT n` is a range scan and `WHERE key > last` never skips dup keys.
367    let keyset_rows = client.query(
368        "SELECT a.attname::text, i.indisprimary \
369         FROM pg_index i \
370         JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = i.indkey[0] \
371         WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
372           AND i.indisunique AND i.indnkeyatts = 1 AND a.attnotnull",
373        &[&schema, &table],
374    )?;
375    let mut keyset_keys: Vec<String> = Vec::new();
376    for primary in [true, false] {
377        for row in &keyset_rows {
378            let col: String = row.get(0);
379            let is_primary: bool = row.get(1);
380            if is_primary == primary && !keyset_keys.contains(&col) {
381                keyset_keys.push(col);
382            }
383        }
384    }
385
386    Ok(crate::source::TableIntrospection {
387        single_int_pk,
388        keyset_keys,
389        row_estimate,
390        avg_row_bytes,
391    })
392}
393
394/// Open a bare `postgres::Client` honoring the configured TLS policy.
395///
396/// Shared by preflight, doctor, and `rivet init` so every code path that
397/// connects to Postgres applies the same transport-security rules. Preflight
398/// and doctor pass the YAML `tls:` block; init runs before any YAML exists,
399/// so it derives a `TlsConfig` from the URL's `sslmode` parameter (see
400/// `crate::init::postgres::connect`). `tls = None` or `mode: disable` falls
401/// back to the insecure `NoTls` transport — a warning is logged from
402/// `create_source` so operators know TLS is off.
403pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
404    // Refuse remote plaintext (no `tls:` block) before any dial (CWE-319).
405    crate::source::require_tls_or_loopback(url, tls)?;
406    match tls {
407        Some(cfg) if cfg.mode.is_enforced() => {
408            let connector = build_native_tls(cfg)?;
409            let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
410            Ok(Client::connect(url, make_tls)?)
411        }
412        _ => Ok(Client::connect(url, NoTls)?),
413    }
414}
415
416/// Run the full export transaction against an open Postgres client.
417///
418/// All session-mutating SET commands use SET LOCAL so they are scoped to
419/// the transaction and reset automatically on COMMIT or ROLLBACK. The caller
420/// is responsible for issuing ROLLBACK if this function returns Err.
421///
422/// Returns (total_rows, had_schema). had_schema is false only when the query
423/// returned zero rows; the caller must emit an empty schema in that case.
424fn pg_run_export(
425    client: &mut Client,
426    built_sql: &str,
427    tuning: &SourceTuning,
428    column_overrides: &ColumnOverrides,
429    sink: &mut dyn super::BatchSink,
430    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
431) -> Result<(usize, bool)> {
432    // Open the txn under guard *first* — if SET LOCAL or DECLARE fails below,
433    // Drop will roll back. Without the guard, a failure between BEGIN and the
434    // explicit ROLLBACK in the caller would leak a half-set-up txn into the pool.
435    let mut guard = PgTxnGuard::begin(client)?;
436    if tuning.statement_timeout_s > 0 {
437        guard.client_mut().batch_execute(&format!(
438            "SET LOCAL statement_timeout = '{}s'",
439            tuning.statement_timeout_s
440        ))?;
441    }
442    if tuning.lock_timeout_s > 0 {
443        guard.client_mut().batch_execute(&format!(
444            "SET LOCAL lock_timeout = '{}s'",
445            tuning.lock_timeout_s
446        ))?;
447    }
448    // Cap FETCH N under `work_mem × 0.7` so the cursor never spills to
449    // `pgsql_tmp/`. Without this, a wide-row chunk with the default
450    // `batch_size: 50000` × ~4 KB/row = ~200 MB easily exceeds the typical
451    // `work_mem: 4 MB` and writes the entire chunk to disk before the first
452    // FETCH returns. Measured cost on the content_items bench: ~3.2 GB of
453    // temp_bytes per export, dominating the DB-side signal report.
454    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());
455
456    guard
457        .client_mut()
458        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;
459
460    // The first FETCH is intentionally a small `PROBE_BATCH_SIZE` row-width
461    // probe (the controller starts there): without it we can't know
462    // `arrow_bytes/row` before the cursor runs, and a single FETCH of
463    // `tuning.batch_size` × wide rows already triggers a `pgsql_tmp/` spill.
464    let configured_batch_size = tuning.batch_size;
465    // Shared batch-size state machine; PG provides the FETCH N row source, the
466    // work_mem (or schema-derived) cap target, and the checkpoint pressure proxy.
467    let mut ctl = AdaptiveBatchController::new(tuning, configured_batch_size);
468    ctl.seed_pressure(if tuning.adaptive {
469        pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64)
470    } else {
471        None
472    });
473    let mut schema: Option<SchemaRef> = None;
474    let mut columns_cache: Option<Vec<(String, Type)>> = None;
475    let mut total_rows: usize = 0;
476    let mut cap_applied = false;
477    // Per-value ceiling (MB→bytes; `0`/None disables), enforced pre-allocation
478    // inside the batch builder so an oversized cell bails before Arrow reserves
479    // the buffer. Same source of truth as the sink's backstop guard.
480    let max_value_bytes = tuning.max_value_bytes();
481
482    loop {
483        let requested = ctl.target();
484        let fetch_sql = format!("FETCH {} FROM _rivet", requested);
485        let rows = guard.client_mut().query(&fetch_sql, &[])?;
486        if rows.is_empty() {
487            break;
488        }
489
490        if schema.is_none() {
491            let stmt_cols: Vec<(String, Type)> = rows[0]
492                .columns()
493                .iter()
494                .map(|c| (c.name().to_string(), c.type_().clone()))
495                .collect();
496            let s = Arc::new(pg_columns_to_schema(
497                rows[0].columns(),
498                column_overrides,
499                numeric_hints,
500            )?);
501            sink.on_schema(s.clone())?;
502            // When work_mem can't be read, fall back to the schema-derived
503            // effective batch size as the cap target (controller clamps it).
504            if work_mem_bytes.is_none() {
505                let effective = tuning.effective_batch_size(Some(&s));
506                ctl.apply_memory_cap(effective.max(requested));
507                cap_applied = true;
508            }
509            schema = Some(s);
510            columns_cache = Some(stmt_cols);
511        }
512
513        let row_count = rows.len();
514        total_rows += row_count;
515
516        let s = schema.as_ref().expect("schema set on first iteration");
517        let cols = columns_cache
518            .as_ref()
519            .expect("columns set on first iteration");
520        let batch = rows_to_record_batch_typed(s, cols, &rows, max_value_bytes)?;
521        drop(rows);
522
523        // After the first (probe) batch we know the actual row width. Cap the
524        // FETCH N below `work_mem × 0.7` so the cursor never spills:
525        //   pg_row_bytes ≈ arrow_per_row × 1.2 ; safe = work_mem×0.7 / pg_row_bytes
526        // The controller clamps it to the configured `batch_size`.
527        if !cap_applied
528            && let Some(wm) = work_mem_bytes
529            && row_count > 0
530        {
531            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
532            let arrow_per_row = (arrow_bytes / row_count).max(1);
533            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
534            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
535            let mut target = safe;
536            if let Some(mem_mb) = tuning.batch_size_memory_mb {
537                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
538                target = target.min(arrow_target.max(100));
539            }
540            if let Some(new) = ctl.apply_memory_cap(target) {
541                log::info!(
542                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N → {} (configured={})",
543                    wm,
544                    arrow_per_row,
545                    pg_per_row,
546                    new,
547                    configured_batch_size,
548                );
549            }
550            cap_applied = true;
551        }
552
553        sink.on_batch(&batch)?;
554
555        if let Some((new, under_pressure)) =
556            ctl.after_batch(|| pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64))
557        {
558            log::info!(
559                "adaptive batch size → {} ({})",
560                new,
561                if under_pressure {
562                    "pressure"
563                } else {
564                    "recovery"
565                }
566            );
567        }
568
569        log::info!("fetched {} rows so far...", total_rows);
570
571        if row_count < requested {
572            break;
573        }
574        ctl.throttle();
575    }
576
577    // Explicit CLOSE is technically redundant — COMMIT releases the cursor —
578    // but it documents intent and surfaces any close errors before COMMIT.
579    guard.client_mut().batch_execute("CLOSE _rivet")?;
580    guard.commit()?;
581    Ok((total_rows, schema.is_some()))
582}
583
584impl super::Source for PostgresSource {
585    fn export(
586        &mut self,
587        request: &super::ExportRequest<'_>,
588        sink: &mut dyn super::BatchSink,
589    ) -> Result<()> {
590        let built = build_export_query(request, SourceType::Postgres);
591        debug_assert!(
592            built.cursor_param.is_none(),
593            "Postgres path inlines cursor values as E'…' literals — binding is unused"
594        );
595        log::debug!(
596            "executing query (connection={}): {}",
597            if self.transaction_pooler {
598                "transaction-pooler"
599            } else {
600                "direct"
601            },
602            built.sql
603        );
604
605        // Resolve NUMERIC precision from the *unwrapped* base query when the
606        // caller wrapped `query` in a chunk/keyset subquery (which hides the
607        // source table from the catalog parser). Falls back to `query`.
608        let hint_query = request.catalog_hint_query.unwrap_or(request.query);
609        let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, hint_query);
610
611        // PgTxnGuard inside pg_run_export rolls the txn back automatically on
612        // any error or panic, so no explicit ROLLBACK is needed here.
613        let (total_rows, had_schema) = pg_run_export(
614            &mut self.client,
615            &built.sql,
616            request.tuning,
617            request.column_overrides,
618            sink,
619            numeric_hints.as_ref(),
620        )?;
621
622        if !had_schema {
623            sink.on_schema(Arc::new(Schema::empty()))?;
624        }
625
626        log::info!("total: {} rows", total_rows);
627        Ok(())
628    }
629
630    fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
631        let rows = self.client.query(sql, &[])?;
632        if rows.is_empty() {
633            return Ok(None);
634        }
635        let row = &rows[0];
636        if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
637            return Ok(Some(v.to_string()));
638        }
639        if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
640            return Ok(Some(v.to_string()));
641        }
642        if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
643            return Ok(Some(v.to_string()));
644        }
645        // TIMESTAMP / DATE / TIMESTAMPTZ — required for MIN/MAX on time columns (e.g. chunk_by_days)
646        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
647            return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
648        }
649        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
650            return Ok(Some(v.format("%Y-%m-%d").to_string()));
651        }
652        if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
653            return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
654        }
655        if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
656            return Ok(Some(v));
657        }
658        Ok(None)
659    }
660
661    fn type_mappings(
662        &mut self,
663        query: &str,
664        column_overrides: &ColumnOverrides,
665    ) -> Result<Vec<TypeMapping>> {
666        let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
667        let stmt = self.client.prepare(&wrapped)?;
668        let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
669        let mappings = stmt
670            .columns()
671            .iter()
672            .map(|col| {
673                let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
674                let source = SourceColumn::simple(col.name(), col.type_().name(), true);
675                TypeMapping::from_source(&source, rivet)
676            })
677            .collect();
678        Ok(mappings)
679    }
680
681    /// Governor pressure proxy: `pg_stat_bgwriter.checkpoints_req` — the same
682    /// monotonic counter the adaptive batch loop samples. Rising between samples
683    /// means the source is checkpointing harder under write pressure.
684    fn sample_pressure(&mut self) -> Option<u64> {
685        pg_sample_checkpoints_req(&mut self.client).map(|v| v.max(0) as u64)
686    }
687}
688
689/// When the query is a single-table `SELECT … FROM rel` (no joins, no subquery
690/// in `FROM`), PostgreSQL result metadata does not carry `NUMERIC` typmod, but
691/// `information_schema` / the table DDL does. We resolve the base relation with
692/// a small parser and fetch declared precision/scale so `rivet init`-style
693/// exports work without hand-written `columns:` overrides.
694fn pg_numeric_catalog_hints_opt(
695    client: &mut Client,
696    query: &str,
697) -> Option<HashMap<String, (u8, i8)>> {
698    match pg_fetch_numeric_catalog_hints(client, query) {
699        Ok(m) => m,
700        Err(e) => {
701            // Reaching this arm means the parser identified a single-table query
702            // and we tried catalog lookup, but the lookup itself failed. That is
703            // unexpected (not "this query has a JOIN"), so surface it — otherwise
704            // a downstream NUMERIC mapping failure looks like a config problem
705            // when the real cause is here.
706            log::warn!(
707                "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
708            );
709            None
710        }
711    }
712}
713
714fn pg_fetch_numeric_catalog_hints(
715    client: &mut Client,
716    query: &str,
717) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
718    let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
719        return Ok(None);
720    };
721    let locate_sql = "SELECT n.nspname::text, c.relname::text \
722         FROM pg_catalog.pg_class c \
723         JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
724         WHERE c.oid = ($1::text)::regclass";
725    let row_opt = match client.query_opt(locate_sql, &[&regclass_lit]) {
726        Ok(r) => r,
727        Err(e) => {
728            log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
729            return Ok(None);
730        }
731    };
732    let Some(row) = row_opt else {
733        return Ok(None);
734    };
735    let schema: String = row.get(0);
736    let table: String = row.get(1);
737    let rows = client.query(
738        "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
739             FROM information_schema.columns \
740             WHERE table_schema = $1 AND table_name = $2 \
741             ORDER BY ordinal_position",
742        &[&schema, &table],
743    )?;
744
745    let mut map = HashMap::new();
746    for row in rows {
747        let col: String = row.get(0);
748        let dt: String = row.get(1);
749        if !is_pg_numeric_information_type(&dt) {
750            continue;
751        }
752        let p: Option<i32> = row.get(2);
753        let s: Option<i32> = row.get(3);
754        if let (Some(p), Some(s)) = (p, s)
755            && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
756        {
757            map.insert(col, pair);
758        }
759    }
760
761    if map.is_empty() {
762        Ok(None)
763    } else {
764        log::debug!(
765            "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
766            map.len(),
767        );
768        Ok(Some(map))
769    }
770}
771
/// Report whether a `data_type` string names a NUMERIC/DECIMAL column.
///
/// Matching ignores surrounding whitespace and ASCII case, and accepts both
/// the bare name (`numeric`, `decimal`) and the parameterized spelling
/// (`numeric(…`, `decimal(…`).
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let normalized = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"].iter().any(|base| {
        normalized
            .strip_prefix(base)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('('))
    })
}
778
/// Validate catalog precision/scale against the same bounds as the Rivet YAML
/// `decimal(p,s)` override / Arrow limits.
///
/// Accepts precision in `1..=76` and a scale that fits in `i8` and does not
/// exceed the precision (negative scales are allowed). Returns the pair
/// narrowed to `(u8, i8)`, or `None` when any bound is violated.
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    if !(1..=76).contains(&precision) {
        return None;
    }
    let scale = i8::try_from(scale).ok()?;
    if i32::from(scale) > precision {
        return None;
    }
    let precision = u8::try_from(precision).ok()?;
    Some((precision, scale))
}
794
#[cfg(test)]
mod tests {
    use super::{catalog_numeric_to_decimal_params, parse_work_mem};

    // FROM-clause parser tests live in `from_parse.rs` alongside the parser.

    /// Precision must be within 1..=76 and the scale may not exceed it.
    #[test]
    fn catalog_decimal_bounds() {
        assert_eq!(catalog_numeric_to_decimal_params(18, 2), Some((18, 2)));

        let rejected = [(0, 2), (77, 0), (18, 19)];
        for (precision, scale) in rejected {
            assert!(
                catalog_numeric_to_decimal_params(precision, scale).is_none(),
                "expected ({precision}, {scale}) to be rejected"
            );
        }
    }

    /// Postgres `SHOW work_mem` normally returns "<N>kB", "<N>MB", "<N>GB";
    /// a bare integer is interpreted as kB (matches postgresql.conf parsing).
    #[test]
    fn parse_work_mem_handles_pg_units() {
        const KIB: i64 = 1024;
        const MIB: i64 = 1024 * KIB;
        const GIB: i64 = 1024 * MIB;

        let accepted = [
            ("4MB", 4 * MIB),
            ("16384kB", 16384 * KIB),
            ("1GB", GIB),
            ("  4MB  ", 4 * MIB),
            ("4mb", 4 * MIB),
            ("65536", 65536 * KIB),
        ];
        for (raw, expected) in accepted {
            assert_eq!(parse_work_mem(raw), Some(expected), "input {raw:?}");
        }

        // Empty / non-numeric input, and units PG would never emit for
        // work_mem (e.g. seconds), are rejected.
        for raw in ["", "garbage", "4s"] {
            assert_eq!(parse_work_mem(raw), None, "input {raw:?}");
        }
    }
}
825}