1mod arrow_convert;
23mod from_parse;
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use postgres::types::Type;
30use postgres::{Client, NoTls};
31
32use crate::config::{SourceType, TlsConfig};
33use crate::error::Result;
34use crate::source::batch_controller::AdaptiveBatchController;
35use crate::source::query::build_export_query;
36use crate::source::tls::build_native_tls;
37use crate::tuning::SourceTuning;
38use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
39
40use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
41use from_parse::try_parse_pg_simple_from_regclass_literal;
42
43pub struct PostgresSource {
44 client: Client,
45 transaction_pooler: bool,
48}
49
50fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
58 let pid1: Option<i32> = client
59 .query_one("SELECT pg_backend_pid()", &[])
60 .ok()
61 .and_then(|r| r.try_get(0).ok());
62 let pid2: Option<i32> = client
63 .query_one("SELECT pg_backend_pid()", &[])
64 .ok()
65 .and_then(|r| r.try_get(0).ok());
66 matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
67}
68
69impl PostgresSource {
70 pub fn connect(url: &str) -> Result<Self> {
73 let mut client = Client::connect(url, NoTls)?;
74 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
75 if transaction_pooler {
76 log::warn!(
77 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
78 SET LOCAL tuning is transaction-scoped; \
79 LISTEN/NOTIFY and advisory locks are unavailable"
80 );
81 }
82 Ok(Self {
83 client,
84 transaction_pooler,
85 })
86 }
87
88 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
91 crate::source::require_tls_or_loopback(url, tls)?;
93 match tls {
94 Some(cfg) if cfg.mode.is_enforced() => {
95 let connector = build_native_tls(cfg)?;
96 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
97 let mut client = Client::connect(url, make_tls)?;
98 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
99 if transaction_pooler {
100 log::warn!(
101 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
102 SET LOCAL tuning is transaction-scoped; \
103 LISTEN/NOTIFY and advisory locks are unavailable"
104 );
105 }
106 Ok(Self {
107 client,
108 transaction_pooler,
109 })
110 }
111 _ => Self::connect(url),
112 }
113 }
114}
115
116struct PgTxnGuard<'a> {
124 client: &'a mut Client,
125 committed: bool,
126}
127
128impl<'a> PgTxnGuard<'a> {
129 fn begin(client: &'a mut Client) -> Result<Self> {
130 client.batch_execute("BEGIN")?;
131 Ok(Self {
132 client,
133 committed: false,
134 })
135 }
136
137 fn client_mut(&mut self) -> &mut Client {
138 self.client
139 }
140
141 fn commit(mut self) -> Result<()> {
142 self.client.batch_execute("COMMIT")?;
143 self.committed = true;
144 Ok(())
145 }
146}
147
148impl Drop for PgTxnGuard<'_> {
149 fn drop(&mut self) {
150 if !self.committed
151 && let Err(e) = self.client.batch_execute("ROLLBACK")
152 {
153 log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
156 }
157 }
158}
159
160pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
172 let mut client = connect_client(url, tls).ok()?;
173 client
174 .query_one(
175 "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
176 &[],
177 )
178 .ok()
179 .and_then(|r| r.try_get::<_, i64>(0).ok())
180}
181
182pub(crate) fn sample_harm_counters(
195 url: &str,
196 tls: Option<&TlsConfig>,
197) -> Option<Vec<(String, i64)>> {
198 let mut client = connect_client(url, tls).ok()?;
199 let row = client
204 .query_one(
205 "SELECT blks_read::bigint, blks_hit::bigint, tup_returned::bigint, \
206 tup_fetched::bigint, temp_files::bigint, deadlocks::bigint \
207 FROM pg_stat_database WHERE datname = current_database()",
208 &[],
209 )
210 .ok()?;
211 let names = [
212 "pg_blks_read",
213 "pg_blks_hit",
214 "pg_tup_returned",
215 "pg_tup_fetched",
216 "pg_temp_files",
217 "pg_deadlocks",
218 ];
219 let mut out = Vec::with_capacity(names.len());
220 for (i, name) in names.iter().enumerate() {
221 if let Ok(v) = row.try_get::<_, i64>(i) {
222 out.push(((*name).to_string(), v));
223 }
224 }
225 Some(out)
226}
227
228fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
239 let raw: Option<String> = client
240 .query_one("SHOW work_mem", &[])
241 .ok()
242 .and_then(|r| r.try_get::<_, String>(0).ok());
243 raw.as_deref().and_then(parse_work_mem)
244}
245
/// Parses a PostgreSQL memory setting such as `"4MB"`, `"64kB"` or `"65536"`
/// into bytes.
///
/// The leading run of digits, `.` and `-` is the number. The remainder,
/// trimmed and case-insensitive, is the unit: no unit or `kB` means kibibytes,
/// and `MB`/`GB`/`TB` are powers of 1024. Unknown units, an empty number,
/// unparsable numbers and non-positive results all yield `None`.
fn parse_work_mem(raw: &str) -> Option<i64> {
    const KIB: f64 = 1024.0;

    let trimmed = raw.trim();
    let is_number_char = |c: char| c.is_ascii_digit() || c == '.' || c == '-';
    let boundary = trimmed
        .find(|c: char| !is_number_char(c))
        .unwrap_or(trimmed.len());
    if boundary == 0 {
        return None;
    }

    let (number_part, unit_part) = trimmed.split_at(boundary);
    let value: f64 = number_part.parse().ok()?;
    let scale = match unit_part.trim().to_ascii_lowercase().as_str() {
        "" | "kb" => KIB,
        "mb" => KIB * KIB,
        "gb" => KIB * KIB * KIB,
        "tb" => KIB * KIB * KIB * KIB,
        _ => return None,
    };

    // `as` on f64 -> i64 saturates rather than wrapping.
    let bytes = (value * scale) as i64;
    if bytes > 0 { Some(bytes) } else { None }
}
278
279fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
285 let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
286 client
287 .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
288 .ok()
289 .and_then(|r| r.try_get::<_, i64>(0).ok())
290}
291
292pub(crate) fn introspect_pg_table_for_chunking(
304 url: &str,
305 tls: Option<&TlsConfig>,
306 qualified_table: &str,
307) -> Result<crate::source::TableIntrospection> {
308 let (schema, table) = match qualified_table.split_once('.') {
309 Some((s, t)) => (s.to_string(), t.to_string()),
310 None => ("public".to_string(), qualified_table.to_string()),
311 };
312 let mut client = connect_client(url, tls)?;
313
314 let (row_estimate, rel_size_bytes) = match client.query_opt(
316 "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
317 FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
318 WHERE n.nspname = $1::text AND c.relname = $2::text",
319 &[&schema, &table],
320 )? {
321 Some(row) => {
322 let rt: i64 = row.try_get(0).unwrap_or(0);
323 let sz: i64 = row.try_get(1).unwrap_or(0);
324 (rt.max(0), sz.max(0))
325 }
326 None => (0, 0),
327 };
328 let avg_row_bytes = if row_estimate > 0 {
329 Some(rel_size_bytes / row_estimate)
330 } else {
331 None
332 };
333
334 let pk_rows = client.query(
336 "SELECT a.attname::text, t.typname::text \
337 FROM pg_index i \
338 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
339 JOIN pg_type t ON t.oid = a.atttypid \
340 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
341 AND i.indisprimary",
342 &[&schema, &table],
343 )?;
344 let single_int_pk = if pk_rows.len() == 1 {
345 let col: String = pk_rows[0].get(0);
346 let pg_type: String = pk_rows[0].get(1);
347 if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
351 Some(col)
352 } else {
353 log::debug!(
354 "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
355 );
356 None
357 }
358 } else {
359 None
360 };
361
362 let keyset_rows = client.query(
368 "SELECT a.attname::text, i.indisprimary \
369 FROM pg_index i \
370 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = i.indkey[0] \
371 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
372 AND i.indisunique AND i.indnkeyatts = 1 AND a.attnotnull",
373 &[&schema, &table],
374 )?;
375 let mut keyset_keys: Vec<String> = Vec::new();
376 for primary in [true, false] {
377 for row in &keyset_rows {
378 let col: String = row.get(0);
379 let is_primary: bool = row.get(1);
380 if is_primary == primary && !keyset_keys.contains(&col) {
381 keyset_keys.push(col);
382 }
383 }
384 }
385
386 Ok(crate::source::TableIntrospection {
387 single_int_pk,
388 keyset_keys,
389 row_estimate,
390 avg_row_bytes,
391 })
392}
393
394pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
404 crate::source::require_tls_or_loopback(url, tls)?;
406 match tls {
407 Some(cfg) if cfg.mode.is_enforced() => {
408 let connector = build_native_tls(cfg)?;
409 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
410 Ok(Client::connect(url, make_tls)?)
411 }
412 _ => Ok(Client::connect(url, NoTls)?),
413 }
414}
415
/// Streams the result of `built_sql` into `sink` in Arrow record batches.
///
/// Everything runs inside one explicit transaction (`PgTxnGuard`) holding a
/// server-side cursor named `_rivet`. Rows are pulled with `FETCH n`, where
/// `n` comes from the `AdaptiveBatchController`. The Arrow schema is derived
/// from the first non-empty batch and sent to `sink.on_schema` exactly once.
///
/// Any early `?` return drops the guard, whose `Drop` issues `ROLLBACK`.
///
/// Returns `(total_rows, had_schema)`. `had_schema` is `false` when the cursor
/// produced no rows, in which case `sink.on_schema` was never called here.
fn pg_run_export(
    client: &mut Client,
    built_sql: &str,
    tuning: &SourceTuning,
    column_overrides: &ColumnOverrides,
    sink: &mut dyn super::BatchSink,
    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
) -> Result<(usize, bool)> {
    let mut guard = PgTxnGuard::begin(client)?;
    // SET LOCAL keeps these timeouts scoped to this transaction only.
    if tuning.statement_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL statement_timeout = '{}s'",
            tuning.statement_timeout_s
        ))?;
    }
    if tuning.lock_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL lock_timeout = '{}s'",
            tuning.lock_timeout_s
        ))?;
    }
    // `None` if work_mem could not be read/parsed; then the schema-based cap
    // below is used instead of the observed-row-size cap.
    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());

    // The cursor is declared without WITH HOLD, so it lives only inside this
    // transaction.
    guard
        .client_mut()
        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;

    let configured_batch_size = tuning.batch_size;
    let mut ctl = AdaptiveBatchController::new(tuning, configured_batch_size);
    // Baseline for pressure detection. The sample is taken only in adaptive
    // mode and runs inside this transaction, so a server-side error from the
    // sampling query would abort the transaction.
    ctl.seed_pressure(if tuning.adaptive {
        pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64)
    } else {
        None
    });
    let mut schema: Option<SchemaRef> = None;
    let mut columns_cache: Option<Vec<(String, Type)>> = None;
    let mut total_rows: usize = 0;
    // Set once a memory-based cap has been applied to the controller.
    let mut cap_applied = false;
    let max_value_bytes = tuning.max_value_bytes();

    loop {
        let requested = ctl.target();
        let fetch_sql = format!("FETCH {} FROM _rivet", requested);
        let rows = guard.client_mut().query(&fetch_sql, &[])?;
        if rows.is_empty() {
            break;
        }

        // First non-empty batch: build and publish the Arrow schema and cache
        // the (name, PG type) list used for row conversion.
        if schema.is_none() {
            let stmt_cols: Vec<(String, Type)> = rows[0]
                .columns()
                .iter()
                .map(|c| (c.name().to_string(), c.type_().clone()))
                .collect();
            let s = Arc::new(pg_columns_to_schema(
                rows[0].columns(),
                column_overrides,
                numeric_hints,
            )?);
            sink.on_schema(s.clone())?;
            // Without work_mem, cap from the schema-derived effective batch
            // size, never below the size already requested.
            if work_mem_bytes.is_none() {
                let effective = tuning.effective_batch_size(Some(&s));
                ctl.apply_memory_cap(effective.max(requested));
                cap_applied = true;
            }
            schema = Some(s);
            columns_cache = Some(stmt_cols);
        }

        let row_count = rows.len();
        total_rows += row_count;

        let s = schema.as_ref().expect("schema set on first iteration");
        let cols = columns_cache
            .as_ref()
            .expect("columns set on first iteration");
        let batch = rows_to_record_batch_typed(s, cols, &rows, max_value_bytes)?;
        // Release the PG rows before the Arrow batch is handed to the sink.
        drop(rows);

        // One-shot cap from the first converted batch when work_mem is known.
        // PG per-row size is estimated as 1.2x the observed Arrow bytes/row
        // (at least 64 B). The target is 70% of work_mem divided by that
        // estimate (at least 100 rows), optionally tightened by
        // `batch_size_memory_mb`.
        if !cap_applied
            && let Some(wm) = work_mem_bytes
            && row_count > 0
        {
            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
            let arrow_per_row = (arrow_bytes / row_count).max(1);
            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
            let mut target = safe;
            if let Some(mem_mb) = tuning.batch_size_memory_mb {
                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
                target = target.min(arrow_target.max(100));
            }
            if let Some(new) = ctl.apply_memory_cap(target) {
                log::info!(
                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N → {} (configured={})",
                    wm,
                    arrow_per_row,
                    pg_per_row,
                    new,
                    configured_batch_size,
                );
            }
            cap_applied = true;
        }

        sink.on_batch(&batch)?;

        // Let the controller resize based on checkpoint pressure; the sampler
        // closure runs on this transaction's connection.
        if let Some((new, under_pressure)) =
            ctl.after_batch(|| pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64))
        {
            log::info!(
                "adaptive batch size → {} ({})",
                new,
                if under_pressure {
                    "pressure"
                } else {
                    "recovery"
                }
            );
        }

        log::info!("fetched {} rows so far...", total_rows);

        // A short batch means the cursor is exhausted; skip the extra FETCH.
        if row_count < requested {
            break;
        }
        ctl.throttle();
    }

    guard.client_mut().batch_execute("CLOSE _rivet")?;
    guard.commit()?;
    Ok((total_rows, schema.is_some()))
}
583
584impl super::Source for PostgresSource {
585 fn export(
586 &mut self,
587 request: &super::ExportRequest<'_>,
588 sink: &mut dyn super::BatchSink,
589 ) -> Result<()> {
590 let built = build_export_query(request, SourceType::Postgres);
591 debug_assert!(
592 built.cursor_param.is_none(),
593 "Postgres path inlines cursor values as E'…' literals — binding is unused"
594 );
595 log::debug!(
596 "executing query (connection={}): {}",
597 if self.transaction_pooler {
598 "transaction-pooler"
599 } else {
600 "direct"
601 },
602 built.sql
603 );
604
605 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
609 let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, hint_query);
610
611 let (total_rows, had_schema) = pg_run_export(
614 &mut self.client,
615 &built.sql,
616 request.tuning,
617 request.column_overrides,
618 sink,
619 numeric_hints.as_ref(),
620 )?;
621
622 if !had_schema {
623 sink.on_schema(Arc::new(Schema::empty()))?;
624 }
625
626 log::info!("total: {} rows", total_rows);
627 Ok(())
628 }
629
630 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
631 let rows = self.client.query(sql, &[])?;
632 if rows.is_empty() {
633 return Ok(None);
634 }
635 let row = &rows[0];
636 if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
637 return Ok(Some(v.to_string()));
638 }
639 if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
640 return Ok(Some(v.to_string()));
641 }
642 if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
643 return Ok(Some(v.to_string()));
644 }
645 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
647 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
648 }
649 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
650 return Ok(Some(v.format("%Y-%m-%d").to_string()));
651 }
652 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
653 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
654 }
655 if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
656 return Ok(Some(v));
657 }
658 Ok(None)
659 }
660
661 fn type_mappings(
662 &mut self,
663 query: &str,
664 column_overrides: &ColumnOverrides,
665 ) -> Result<Vec<TypeMapping>> {
666 let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
667 let stmt = self.client.prepare(&wrapped)?;
668 let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
669 let mappings = stmt
670 .columns()
671 .iter()
672 .map(|col| {
673 let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
674 let source = SourceColumn::simple(col.name(), col.type_().name(), true);
675 TypeMapping::from_source(&source, rivet)
676 })
677 .collect();
678 Ok(mappings)
679 }
680
681 fn sample_pressure(&mut self) -> Option<u64> {
685 pg_sample_checkpoints_req(&mut self.client).map(|v| v.max(0) as u64)
686 }
687}
688
689fn pg_numeric_catalog_hints_opt(
695 client: &mut Client,
696 query: &str,
697) -> Option<HashMap<String, (u8, i8)>> {
698 match pg_fetch_numeric_catalog_hints(client, query) {
699 Ok(m) => m,
700 Err(e) => {
701 log::warn!(
707 "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
708 );
709 None
710 }
711 }
712}
713
714fn pg_fetch_numeric_catalog_hints(
715 client: &mut Client,
716 query: &str,
717) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
718 let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
719 return Ok(None);
720 };
721 let locate_sql = "SELECT n.nspname::text, c.relname::text \
722 FROM pg_catalog.pg_class c \
723 JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
724 WHERE c.oid = ($1::text)::regclass";
725 let row_opt = match client.query_opt(locate_sql, &[®class_lit]) {
726 Ok(r) => r,
727 Err(e) => {
728 log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
729 return Ok(None);
730 }
731 };
732 let Some(row) = row_opt else {
733 return Ok(None);
734 };
735 let schema: String = row.get(0);
736 let table: String = row.get(1);
737 let rows = client.query(
738 "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
739 FROM information_schema.columns \
740 WHERE table_schema = $1 AND table_name = $2 \
741 ORDER BY ordinal_position",
742 &[&schema, &table],
743 )?;
744
745 let mut map = HashMap::new();
746 for row in rows {
747 let col: String = row.get(0);
748 let dt: String = row.get(1);
749 if !is_pg_numeric_information_type(&dt) {
750 continue;
751 }
752 let p: Option<i32> = row.get(2);
753 let s: Option<i32> = row.get(3);
754 if let (Some(p), Some(s)) = (p, s)
755 && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
756 {
757 map.insert(col, pair);
758 }
759 }
760
761 if map.is_empty() {
762 Ok(None)
763 } else {
764 log::debug!(
765 "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
766 map.len(),
767 );
768 Ok(Some(map))
769 }
770}
771
/// Returns `true` if an `information_schema` `data_type` string denotes
/// NUMERIC/DECIMAL.
///
/// Matching ignores surrounding whitespace and ASCII case. It accepts the bare
/// name or the name immediately followed by `(`.
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let normalized = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"].iter().any(|base| {
        normalized == *base
            || normalized
                .strip_prefix(base)
                .is_some_and(|rest| rest.starts_with('('))
    })
}
778
/// Converts catalog NUMERIC precision/scale into Arrow-style decimal
/// parameters `(u8 precision, i8 scale)`.
///
/// Returns `None` in these cases:
/// * precision is outside `1..=76`;
/// * scale does not fit in `i8`;
/// * scale exceeds precision.
///
/// Negative scales within `i8` are accepted.
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    let p = u8::try_from(precision)
        .ok()
        .filter(|p| (1..=76).contains(p))?;
    let s = i8::try_from(scale).ok()?;
    (i16::from(s) <= i16::from(p)).then_some((p, s))
}
794
#[cfg(test)]
mod tests {
    use super::{catalog_numeric_to_decimal_params, is_pg_numeric_information_type, parse_work_mem};

    /// Precision must be 1..=76, scale must fit in i8 and not exceed precision.
    #[test]
    fn catalog_decimal_bounds() {
        assert_eq!(catalog_numeric_to_decimal_params(18, 2), Some((18, 2)));
        assert!(catalog_numeric_to_decimal_params(0, 2).is_none());
        assert!(catalog_numeric_to_decimal_params(77, 0).is_none());
        assert!(catalog_numeric_to_decimal_params(18, 19).is_none());
        // Boundaries and negative scale.
        assert_eq!(catalog_numeric_to_decimal_params(76, 0), Some((76, 0)));
        assert_eq!(catalog_numeric_to_decimal_params(10, -2), Some((10, -2)));
        assert!(catalog_numeric_to_decimal_params(-1, 0).is_none());
        // Scale outside the i8 range.
        assert!(catalog_numeric_to_decimal_params(10, 200).is_none());
        assert!(catalog_numeric_to_decimal_params(10, -129).is_none());
    }

    #[test]
    fn parse_work_mem_handles_pg_units() {
        assert_eq!(parse_work_mem("4MB"), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("16384kB"), Some(16384 * 1024));
        assert_eq!(parse_work_mem("1GB"), Some(1024 * 1024 * 1024));
        assert_eq!(parse_work_mem("1TB"), Some(1024_i64 * 1024 * 1024 * 1024));
        assert_eq!(parse_work_mem(" 4MB "), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("4mb"), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("1.5MB"), Some(1_572_864));
        // No unit means kB, the base unit of work_mem.
        assert_eq!(parse_work_mem("65536"), Some(65536 * 1024));
        assert_eq!(parse_work_mem(""), None);
        assert_eq!(parse_work_mem("garbage"), None);
        assert_eq!(parse_work_mem("4s"), None);
        // Non-positive results are rejected.
        assert_eq!(parse_work_mem("0kB"), None);
        assert_eq!(parse_work_mem("-1"), None);
    }

    #[test]
    fn numeric_information_type_detection() {
        assert!(is_pg_numeric_information_type("numeric"));
        assert!(is_pg_numeric_information_type(" DECIMAL "));
        assert!(is_pg_numeric_information_type("numeric(10,2)"));
        assert!(!is_pg_numeric_information_type("integer"));
        assert!(!is_pg_numeric_information_type("numerical"));
    }
}