1mod arrow_convert;
23pub(crate) mod cdc;
24mod from_parse;
25
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use arrow::datatypes::{Schema, SchemaRef};
30use postgres::types::Type;
31use postgres::{Client, NoTls};
32
33use crate::config::{SourceType, TlsConfig};
34use crate::error::Result;
35use crate::source::batch_controller::AdaptiveBatchController;
36use crate::source::query::build_export_query;
37use crate::source::tls::build_native_tls;
38use crate::tuning::SourceTuning;
39use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
40
41use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
42use from_parse::try_parse_pg_simple_from_regclass_literal;
43
/// A PostgreSQL source built around a single blocking `postgres::Client`.
pub struct PostgresSource {
    /// Connection that the `Source` implementation in this file runs its queries on.
    client: Client,
    /// Result of the `detect_pg_transaction_pooler` probe taken right after connecting:
    /// `true` when two consecutive `pg_backend_pid()` calls reported different PIDs.
    transaction_pooler: bool,
}
50
51fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
59 let pid1: Option<i32> = client
60 .query_one("SELECT pg_backend_pid()", &[])
61 .ok()
62 .and_then(|r| r.try_get(0).ok());
63 let pid2: Option<i32> = client
64 .query_one("SELECT pg_backend_pid()", &[])
65 .ok()
66 .and_then(|r| r.try_get(0).ok());
67 matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
68}
69
70impl PostgresSource {
71 pub fn connect(url: &str) -> Result<Self> {
74 let mut client = Client::connect(url, NoTls)?;
75 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
76 if transaction_pooler {
77 log::warn!(
78 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
79 SET LOCAL tuning is transaction-scoped; \
80 LISTEN/NOTIFY and advisory locks are unavailable"
81 );
82 }
83 Ok(Self {
84 client,
85 transaction_pooler,
86 })
87 }
88
89 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
92 crate::source::require_tls_or_loopback(url, tls)?;
94 match tls {
95 Some(cfg) if cfg.mode.is_enforced() => {
96 let connector = build_native_tls(cfg)?;
97 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
98 let mut client = Client::connect(url, make_tls)?;
99 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
100 if transaction_pooler {
101 log::warn!(
102 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
103 SET LOCAL tuning is transaction-scoped; \
104 LISTEN/NOTIFY and advisory locks are unavailable"
105 );
106 }
107 Ok(Self {
108 client,
109 transaction_pooler,
110 })
111 }
112 _ => Self::connect(url),
113 }
114 }
115}
116
117struct PgTxnGuard<'a> {
125 client: &'a mut Client,
126 committed: bool,
127}
128
129impl<'a> PgTxnGuard<'a> {
130 fn begin(client: &'a mut Client) -> Result<Self> {
131 client.batch_execute("BEGIN")?;
132 Ok(Self {
133 client,
134 committed: false,
135 })
136 }
137
138 fn client_mut(&mut self) -> &mut Client {
139 self.client
140 }
141
142 fn commit(mut self) -> Result<()> {
143 self.client.batch_execute("COMMIT")?;
144 self.committed = true;
145 Ok(())
146 }
147}
148
149impl Drop for PgTxnGuard<'_> {
150 fn drop(&mut self) {
151 if !self.committed
152 && let Err(e) = self.client.batch_execute("ROLLBACK")
153 {
154 log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
157 }
158 }
159}
160
161pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
173 let mut client = connect_client(url, tls).ok()?;
174 client
175 .query_one(
176 "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
177 &[],
178 )
179 .ok()
180 .and_then(|r| r.try_get::<_, i64>(0).ok())
181}
182
183pub(crate) fn sample_harm_counters(
196 url: &str,
197 tls: Option<&TlsConfig>,
198) -> Option<Vec<(String, i64)>> {
199 let mut client = connect_client(url, tls).ok()?;
200 let row = client
205 .query_one(
206 "SELECT blks_read::bigint, blks_hit::bigint, tup_returned::bigint, \
207 tup_fetched::bigint, temp_files::bigint, deadlocks::bigint \
208 FROM pg_stat_database WHERE datname = current_database()",
209 &[],
210 )
211 .ok()?;
212 let names = [
213 "pg_blks_read",
214 "pg_blks_hit",
215 "pg_tup_returned",
216 "pg_tup_fetched",
217 "pg_temp_files",
218 "pg_deadlocks",
219 ];
220 let mut out = Vec::with_capacity(names.len());
221 for (i, name) in names.iter().enumerate() {
222 if let Ok(v) = row.try_get::<_, i64>(i) {
223 out.push(((*name).to_string(), v));
224 }
225 }
226 Some(out)
227}
228
229fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
240 let raw: Option<String> = client
241 .query_one("SHOW work_mem", &[])
242 .ok()
243 .and_then(|r| r.try_get::<_, String>(0).ok());
244 raw.as_deref().and_then(parse_work_mem)
245}
246
/// Parses a memory setting such as `"4MB"` or `"16384kB"` into a byte count.
///
/// The leading run of ASCII digits, `.` and `-` is the number; the rest
/// (trimmed, case-insensitive) is the unit. A missing unit is treated as
/// kilobytes. Accepted units: `kB`, `MB`, `GB`, `TB` (powers of 1024).
///
/// Returns `None` for an empty or non-numeric prefix, an unknown unit, or a
/// result that is not strictly positive.
fn parse_work_mem(raw: &str) -> Option<i64> {
    const KIB: f64 = 1024.0;

    let text = raw.trim();
    let is_numeric = |c: char| c.is_ascii_digit() || c == '.' || c == '-';
    // Byte offset where the numeric prefix ends (whole string when no unit follows).
    let boundary = text
        .find(|c: char| !is_numeric(c))
        .unwrap_or(text.len());
    if boundary == 0 {
        return None;
    }

    let (digits, suffix) = text.split_at(boundary);
    let value = digits.parse::<f64>().ok()?;
    let scale = match suffix.trim().to_ascii_lowercase().as_str() {
        "" | "kb" => KIB,
        "mb" => KIB * KIB,
        "gb" => KIB * KIB * KIB,
        "tb" => KIB * KIB * KIB * KIB,
        _ => return None,
    };

    let bytes = (value * scale) as i64;
    if bytes > 0 { Some(bytes) } else { None }
}
279
280fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
286 let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
287 client
288 .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
289 .ok()
290 .and_then(|r| r.try_get::<_, i64>(0).ok())
291}
292
293pub(crate) fn introspect_pg_table_for_chunking(
305 url: &str,
306 tls: Option<&TlsConfig>,
307 qualified_table: &str,
308) -> Result<crate::source::TableIntrospection> {
309 let (schema, table) = match qualified_table.split_once('.') {
310 Some((s, t)) => (s.to_string(), t.to_string()),
311 None => ("public".to_string(), qualified_table.to_string()),
312 };
313 let mut client = connect_client(url, tls)?;
314
315 let (row_estimate, rel_size_bytes) = match client.query_opt(
317 "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
318 FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
319 WHERE n.nspname = $1::text AND c.relname = $2::text",
320 &[&schema, &table],
321 )? {
322 Some(row) => {
323 let rt: i64 = row.try_get(0).unwrap_or(0);
324 let sz: i64 = row.try_get(1).unwrap_or(0);
325 (rt.max(0), sz.max(0))
326 }
327 None => (0, 0),
328 };
329 let avg_row_bytes = if row_estimate > 0 {
330 Some(rel_size_bytes / row_estimate)
331 } else {
332 None
333 };
334
335 let pk_rows = client.query(
337 "SELECT a.attname::text, t.typname::text \
338 FROM pg_index i \
339 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
340 JOIN pg_type t ON t.oid = a.atttypid \
341 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
342 AND i.indisprimary",
343 &[&schema, &table],
344 )?;
345 let single_int_pk = if pk_rows.len() == 1 {
346 let col: String = pk_rows[0].get(0);
347 let pg_type: String = pk_rows[0].get(1);
348 if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
352 Some(col)
353 } else {
354 log::debug!(
355 "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
356 );
357 None
358 }
359 } else {
360 None
361 };
362
363 let keyset_rows = client.query(
369 "SELECT a.attname::text, i.indisprimary \
370 FROM pg_index i \
371 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = i.indkey[0] \
372 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
373 AND i.indisunique AND i.indnkeyatts = 1 AND a.attnotnull",
374 &[&schema, &table],
375 )?;
376 let mut keyset_keys: Vec<String> = Vec::new();
377 for primary in [true, false] {
378 for row in &keyset_rows {
379 let col: String = row.get(0);
380 let is_primary: bool = row.get(1);
381 if is_primary == primary && !keyset_keys.contains(&col) {
382 keyset_keys.push(col);
383 }
384 }
385 }
386
387 Ok(crate::source::TableIntrospection {
388 single_int_pk,
389 keyset_keys,
390 row_estimate,
391 avg_row_bytes,
392 })
393}
394
395pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
405 crate::source::require_tls_or_loopback(url, tls)?;
407 match tls {
408 Some(cfg) if cfg.mode.is_enforced() => {
409 let connector = build_native_tls(cfg)?;
410 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
411 Ok(Client::connect(url, make_tls)?)
412 }
413 _ => Ok(Client::connect(url, NoTls)?),
414 }
415}
416
/// Streams the result of `built_sql` to `sink` through a server-side cursor.
///
/// Everything runs inside one explicit transaction held by a `PgTxnGuard`:
/// optional `SET LOCAL` timeouts, `DECLARE _rivet … CURSOR`, repeated
/// `FETCH n`, `CLOSE`, then `COMMIT`. Any early return through `?` drops the
/// guard uncommitted, which rolls the transaction back.
///
/// Each non-empty fetch is converted to an Arrow record batch and handed to
/// `sink.on_batch`; `sink.on_schema` is called once, before the first batch.
/// The fetch size comes from an `AdaptiveBatchController`.
///
/// Returns `(total_rows, schema_was_emitted)`; the flag is `false` when the
/// cursor produced no rows at all.
fn pg_run_export(
    client: &mut Client,
    built_sql: &str,
    tuning: &SourceTuning,
    column_overrides: &ColumnOverrides,
    sink: &mut dyn super::BatchSink,
    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
) -> Result<(usize, bool)> {
    let mut guard = PgTxnGuard::begin(client)?;
    // Timeouts are applied with SET LOCAL, and only when configured (> 0).
    if tuning.statement_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL statement_timeout = '{}s'",
            tuning.statement_timeout_s
        ))?;
    }
    if tuning.lock_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL lock_timeout = '{}s'",
            tuning.lock_timeout_s
        ))?;
    }
    // `None` when SHOW work_mem fails or its value cannot be parsed.
    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());

    guard
        .client_mut()
        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;

    let configured_batch_size = tuning.batch_size;
    let mut ctl = AdaptiveBatchController::new(tuning, configured_batch_size);
    // The initial pressure sample is only taken when adaptive tuning is on.
    ctl.seed_pressure(if tuning.adaptive {
        pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64)
    } else {
        None
    });
    // Schema and column types are derived from the first non-empty fetch and
    // reused for every later batch.
    let mut schema: Option<SchemaRef> = None;
    let mut columns_cache: Option<Vec<(String, Type)>> = None;
    let mut total_rows: usize = 0;
    // Ensures the memory cap is computed and applied at most once.
    let mut cap_applied = false;
    let max_value_bytes = tuning.max_value_bytes();

    loop {
        let requested = ctl.target();
        let fetch_sql = format!("FETCH {} FROM _rivet", requested);
        let rows = guard.client_mut().query(&fetch_sql, &[])?;
        if rows.is_empty() {
            break;
        }

        if schema.is_none() {
            let stmt_cols: Vec<(String, Type)> = rows[0]
                .columns()
                .iter()
                .map(|c| (c.name().to_string(), c.type_().clone()))
                .collect();
            let s = Arc::new(pg_columns_to_schema(
                rows[0].columns(),
                column_overrides,
                numeric_hints,
            )?);
            sink.on_schema(s.clone())?;
            // Without a usable work_mem value the cap falls back to the
            // schema-based estimate from `tuning`, passing the larger of that
            // estimate and the size just requested.
            if work_mem_bytes.is_none() {
                let effective = tuning.effective_batch_size(Some(&s));
                ctl.apply_memory_cap(effective.max(requested));
                cap_applied = true;
            }
            schema = Some(s);
            columns_cache = Some(stmt_cols);
        }

        let row_count = rows.len();
        total_rows += row_count;

        let s = schema.as_ref().expect("schema set on first iteration");
        let cols = columns_cache
            .as_ref()
            .expect("columns set on first iteration");
        let batch = rows_to_record_batch_typed(s, cols, &rows, max_value_bytes)?;
        // The source rows are released as soon as the Arrow batch exists.
        drop(rows);

        // One-time cap from work_mem, using the row width observed in the
        // first batch: the PG-side row is estimated as 1.2x the Arrow bytes
        // per row (at least 64 B), and the fetch size is limited to what fits
        // in 70% of work_mem (at least 100 rows). A configured
        // `batch_size_memory_mb` can only lower that target further.
        if !cap_applied
            && let Some(wm) = work_mem_bytes
            && row_count > 0
        {
            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
            let arrow_per_row = (arrow_bytes / row_count).max(1);
            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
            let mut target = safe;
            if let Some(mem_mb) = tuning.batch_size_memory_mb {
                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
                target = target.min(arrow_target.max(100));
            }
            if let Some(new) = ctl.apply_memory_cap(target) {
                log::info!(
                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N → {} (configured={})",
                    wm,
                    arrow_per_row,
                    pg_per_row,
                    new,
                    configured_batch_size,
                );
            }
            cap_applied = true;
        }

        sink.on_batch(&batch)?;

        // The controller decides when to call the sampling closure; a returned
        // pair is the new fetch size and whether it was chosen under pressure.
        if let Some((new, under_pressure)) =
            ctl.after_batch(|| pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64))
        {
            log::info!(
                "adaptive batch size → {} ({})",
                new,
                if under_pressure {
                    "pressure"
                } else {
                    "recovery"
                }
            );
        }

        log::info!("fetched {} rows so far...", total_rows);

        // A short fetch means the cursor is exhausted; skip the extra round trip.
        if row_count < requested {
            break;
        }
        ctl.throttle(row_count);
    }

    guard.client_mut().batch_execute("CLOSE _rivet")?;
    guard.commit()?;
    Ok((total_rows, schema.is_some()))
}
584
585impl super::Source for PostgresSource {
586 fn export(
587 &mut self,
588 request: &super::ExportRequest<'_>,
589 sink: &mut dyn super::BatchSink,
590 ) -> Result<()> {
591 let built = build_export_query(request, SourceType::Postgres);
592 debug_assert!(
593 built.cursor_param.is_none(),
594 "Postgres path inlines cursor values as E'…' literals — binding is unused"
595 );
596 log::debug!(
597 "executing query (connection={}): {}",
598 if self.transaction_pooler {
599 "transaction-pooler"
600 } else {
601 "direct"
602 },
603 built.sql
604 );
605
606 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
610 let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, hint_query);
611
612 let (total_rows, had_schema) = pg_run_export(
615 &mut self.client,
616 &built.sql,
617 request.tuning,
618 request.column_overrides,
619 sink,
620 numeric_hints.as_ref(),
621 )?;
622
623 if !had_schema {
624 sink.on_schema(Arc::new(Schema::empty()))?;
625 }
626
627 log::info!("total: {} rows", total_rows);
628 Ok(())
629 }
630
631 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
632 let rows = self.client.query(sql, &[])?;
633 if rows.is_empty() {
634 return Ok(None);
635 }
636 let row = &rows[0];
637 if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
638 return Ok(Some(v.to_string()));
639 }
640 if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
641 return Ok(Some(v.to_string()));
642 }
643 if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
644 return Ok(Some(v.to_string()));
645 }
646 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
648 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
649 }
650 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
651 return Ok(Some(v.format("%Y-%m-%d").to_string()));
652 }
653 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
654 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
655 }
656 if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
657 return Ok(Some(v));
658 }
659 Ok(None)
660 }
661
662 fn type_mappings(
663 &mut self,
664 query: &str,
665 column_overrides: &ColumnOverrides,
666 ) -> Result<Vec<TypeMapping>> {
667 let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
668 let stmt = self.client.prepare(&wrapped)?;
669 let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
670 let mappings = stmt
671 .columns()
672 .iter()
673 .map(|col| {
674 let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
675 let source = SourceColumn::simple(col.name(), col.type_().name(), true);
676 TypeMapping::from_source(&source, rivet)
677 })
678 .collect();
679 Ok(mappings)
680 }
681
682 fn sample_pressure(&mut self) -> Option<u64> {
686 pg_sample_checkpoints_req(&mut self.client).map(|v| v.max(0) as u64)
687 }
688}
689
690fn pg_numeric_catalog_hints_opt(
696 client: &mut Client,
697 query: &str,
698) -> Option<HashMap<String, (u8, i8)>> {
699 match pg_fetch_numeric_catalog_hints(client, query) {
700 Ok(m) => m,
701 Err(e) => {
702 log::warn!(
708 "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
709 );
710 None
711 }
712 }
713}
714
715fn pg_fetch_numeric_catalog_hints(
716 client: &mut Client,
717 query: &str,
718) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
719 let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
720 return Ok(None);
721 };
722 let locate_sql = "SELECT n.nspname::text, c.relname::text \
723 FROM pg_catalog.pg_class c \
724 JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
725 WHERE c.oid = ($1::text)::regclass";
726 let row_opt = match client.query_opt(locate_sql, &[®class_lit]) {
727 Ok(r) => r,
728 Err(e) => {
729 log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
730 return Ok(None);
731 }
732 };
733 let Some(row) = row_opt else {
734 return Ok(None);
735 };
736 let schema: String = row.get(0);
737 let table: String = row.get(1);
738 let rows = client.query(
739 "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
740 FROM information_schema.columns \
741 WHERE table_schema = $1 AND table_name = $2 \
742 ORDER BY ordinal_position",
743 &[&schema, &table],
744 )?;
745
746 let mut map = HashMap::new();
747 for row in rows {
748 let col: String = row.get(0);
749 let dt: String = row.get(1);
750 if !is_pg_numeric_information_type(&dt) {
751 continue;
752 }
753 let p: Option<i32> = row.get(2);
754 let s: Option<i32> = row.get(3);
755 if let (Some(p), Some(s)) = (p, s)
756 && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
757 {
758 map.insert(col, pair);
759 }
760 }
761
762 if map.is_empty() {
763 Ok(None)
764 } else {
765 log::debug!(
766 "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
767 map.len(),
768 );
769 Ok(Some(map))
770 }
771}
772
/// Reports whether an `information_schema` `data_type` string names NUMERIC or
/// DECIMAL: the bare name, or the name immediately followed by `(`.
/// Surrounding whitespace and ASCII case are ignored.
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let lowered = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"].iter().any(|base| {
        lowered
            .strip_prefix(base)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('('))
    })
}
779
/// Converts catalog precision/scale into `(u8, i8)` decimal parameters.
///
/// Returns `None` unless `precision` is in `1..=76`, `scale` fits in `i8`,
/// and `scale` does not exceed `precision`. Negative scales are accepted.
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    if !(1..=76).contains(&precision) {
        return None;
    }
    let precision_u = u8::try_from(precision).ok()?;
    let scale_i = i8::try_from(scale).ok()?;
    if i32::from(scale_i) > precision {
        return None;
    }
    Some((precision_u, scale_i))
}
795
#[cfg(test)]
mod tests {
    use super::{catalog_numeric_to_decimal_params, parse_work_mem};

    /// Precision must be in 1..=76 and scale must not exceed precision.
    #[test]
    fn catalog_decimal_bounds() {
        assert_eq!(catalog_numeric_to_decimal_params(18, 2), Some((18, 2)));

        let rejected = [(0, 2), (77, 0), (18, 19)];
        for (precision, scale) in rejected {
            assert!(
                catalog_numeric_to_decimal_params(precision, scale).is_none(),
                "expected ({precision}, {scale}) to be rejected"
            );
        }
    }

    /// Unit suffixes are case-insensitive, a bare number is kilobytes, and
    /// unparseable input yields `None`.
    #[test]
    fn parse_work_mem_handles_pg_units() {
        let cases: [(&str, Option<i64>); 9] = [
            ("4MB", Some(4 * 1024 * 1024)),
            ("16384kB", Some(16384 * 1024)),
            ("1GB", Some(1024 * 1024 * 1024)),
            ("  4MB  ", Some(4 * 1024 * 1024)),
            ("4mb", Some(4 * 1024 * 1024)),
            ("65536", Some(65536 * 1024)),
            ("", None),
            ("garbage", None),
            ("4s", None),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_work_mem(input), expected, "input: {input:?}");
        }
    }
}