1mod arrow_convert;
23mod from_parse;
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use postgres::types::Type;
30use postgres::{Client, NoTls};
31
32use crate::config::{SourceType, TlsConfig};
33use crate::error::Result;
34use crate::source::batch_controller::AdaptiveBatchController;
35use crate::source::query::build_export_query;
36use crate::source::tls::build_native_tls;
37use crate::tuning::SourceTuning;
38use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
39
40use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
41use from_parse::try_parse_pg_simple_from_regclass_literal;
42
/// PostgreSQL-backed source: a client connection plus the outcome of the
/// connect-time pooler probe.
pub struct PostgresSource {
    /// Connection this source runs its queries on.
    client: Client,
    /// `true` when the connect-time probe (`detect_pg_transaction_pooler`)
    /// saw two different backend PIDs on consecutive statements.
    transaction_pooler: bool,
}
49
50fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
58 let pid1: Option<i32> = client
59 .query_one("SELECT pg_backend_pid()", &[])
60 .ok()
61 .and_then(|r| r.try_get(0).ok());
62 let pid2: Option<i32> = client
63 .query_one("SELECT pg_backend_pid()", &[])
64 .ok()
65 .and_then(|r| r.try_get(0).ok());
66 matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
67}
68
69impl PostgresSource {
70 pub fn connect(url: &str) -> Result<Self> {
73 let mut client = Client::connect(url, NoTls)?;
74 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
75 if transaction_pooler {
76 log::warn!(
77 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
78 SET LOCAL tuning is transaction-scoped; \
79 LISTEN/NOTIFY and advisory locks are unavailable"
80 );
81 }
82 Ok(Self {
83 client,
84 transaction_pooler,
85 })
86 }
87
88 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
91 crate::source::require_tls_or_loopback(url, tls)?;
93 match tls {
94 Some(cfg) if cfg.mode.is_enforced() => {
95 let connector = build_native_tls(cfg)?;
96 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
97 let mut client = Client::connect(url, make_tls)?;
98 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
99 if transaction_pooler {
100 log::warn!(
101 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
102 SET LOCAL tuning is transaction-scoped; \
103 LISTEN/NOTIFY and advisory locks are unavailable"
104 );
105 }
106 Ok(Self {
107 client,
108 transaction_pooler,
109 })
110 }
111 _ => Self::connect(url),
112 }
113 }
114}
115
/// Transaction guard over a borrowed client: `begin` issues `BEGIN`,
/// `commit` issues `COMMIT`, and dropping a guard that was not committed
/// issues `ROLLBACK`.
struct PgTxnGuard<'a> {
    /// Connection the transaction runs on, borrowed for the guard's lifetime.
    client: &'a mut Client,
    /// Set after `COMMIT` succeeds; the `Drop` impl skips `ROLLBACK` when set.
    committed: bool,
}
127
128impl<'a> PgTxnGuard<'a> {
129 fn begin(client: &'a mut Client) -> Result<Self> {
130 client.batch_execute("BEGIN")?;
131 Ok(Self {
132 client,
133 committed: false,
134 })
135 }
136
137 fn client_mut(&mut self) -> &mut Client {
138 self.client
139 }
140
141 fn commit(mut self) -> Result<()> {
142 self.client.batch_execute("COMMIT")?;
143 self.committed = true;
144 Ok(())
145 }
146}
147
148impl Drop for PgTxnGuard<'_> {
149 fn drop(&mut self) {
150 if !self.committed
151 && let Err(e) = self.client.batch_execute("ROLLBACK")
152 {
153 log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
156 }
157 }
158}
159
160pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
172 let mut client = connect_client(url, tls).ok()?;
173 client
174 .query_one(
175 "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
176 &[],
177 )
178 .ok()
179 .and_then(|r| r.try_get::<_, i64>(0).ok())
180}
181
182fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
193 let raw: Option<String> = client
194 .query_one("SHOW work_mem", &[])
195 .ok()
196 .and_then(|r| r.try_get::<_, String>(0).ok());
197 raw.as_deref().and_then(parse_work_mem)
198}
199
/// Parses a PostgreSQL memory setting such as `"4MB"` or `"16384kB"` into a
/// byte count.
///
/// The input is trimmed, then split into a leading numeric part (ASCII
/// digits, `.` and `-`) and a unit suffix. Units are matched
/// case-insensitively: `B`, `kB`, `MB`, `GB`, `TB`. A bare number is
/// interpreted as kilobytes.
///
/// Returns `None` when there is no leading number, the number does not
/// parse, the unit is unknown, or the resulting byte count is not positive.
fn parse_work_mem(raw: &str) -> Option<i64> {
    let s = raw.trim();
    // Index of the first character that cannot belong to the number; the
    // whole string is numeric when no such character exists.
    let split = s
        .find(|ch: char| !(ch.is_ascii_digit() || ch == '.' || ch == '-'))
        .unwrap_or(s.len());
    if split == 0 {
        return None;
    }
    let (num_str, unit) = s.split_at(split);
    let num: f64 = num_str.parse().ok()?;
    let multiplier: f64 = match unit.trim().to_ascii_lowercase().as_str() {
        // Plain bytes are a valid PostgreSQL memory unit as well.
        "b" => 1.0,
        "" | "kb" => 1024.0,
        "mb" => 1024.0 * 1024.0,
        "gb" => 1024.0 * 1024.0 * 1024.0,
        "tb" => 1024.0 * 1024.0 * 1024.0 * 1024.0,
        _ => return None,
    };
    // Float-to-int `as` saturates, so absurdly large inputs clamp to i64::MAX
    // instead of wrapping.
    let bytes = (num * multiplier) as i64;
    (bytes > 0).then_some(bytes)
}
232
233fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
239 let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
240 client
241 .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
242 .ok()
243 .and_then(|r| r.try_get::<_, i64>(0).ok())
244}
245
246pub(crate) fn introspect_pg_table_for_chunking(
258 url: &str,
259 tls: Option<&TlsConfig>,
260 qualified_table: &str,
261) -> Result<crate::source::TableIntrospection> {
262 let (schema, table) = match qualified_table.split_once('.') {
263 Some((s, t)) => (s.to_string(), t.to_string()),
264 None => ("public".to_string(), qualified_table.to_string()),
265 };
266 let mut client = connect_client(url, tls)?;
267
268 let (row_estimate, rel_size_bytes) = match client.query_opt(
270 "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
271 FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
272 WHERE n.nspname = $1::text AND c.relname = $2::text",
273 &[&schema, &table],
274 )? {
275 Some(row) => {
276 let rt: i64 = row.try_get(0).unwrap_or(0);
277 let sz: i64 = row.try_get(1).unwrap_or(0);
278 (rt.max(0), sz.max(0))
279 }
280 None => (0, 0),
281 };
282 let avg_row_bytes = if row_estimate > 0 {
283 Some(rel_size_bytes / row_estimate)
284 } else {
285 None
286 };
287
288 let pk_rows = client.query(
290 "SELECT a.attname::text, t.typname::text \
291 FROM pg_index i \
292 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
293 JOIN pg_type t ON t.oid = a.atttypid \
294 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
295 AND i.indisprimary",
296 &[&schema, &table],
297 )?;
298 let single_int_pk = if pk_rows.len() == 1 {
299 let col: String = pk_rows[0].get(0);
300 let pg_type: String = pk_rows[0].get(1);
301 if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
305 Some(col)
306 } else {
307 log::debug!(
308 "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
309 );
310 None
311 }
312 } else {
313 None
314 };
315
316 let keyset_rows = client.query(
322 "SELECT a.attname::text, i.indisprimary \
323 FROM pg_index i \
324 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = i.indkey[0] \
325 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
326 AND i.indisunique AND i.indnkeyatts = 1 AND a.attnotnull",
327 &[&schema, &table],
328 )?;
329 let mut keyset_keys: Vec<String> = Vec::new();
330 for primary in [true, false] {
331 for row in &keyset_rows {
332 let col: String = row.get(0);
333 let is_primary: bool = row.get(1);
334 if is_primary == primary && !keyset_keys.contains(&col) {
335 keyset_keys.push(col);
336 }
337 }
338 }
339
340 Ok(crate::source::TableIntrospection {
341 single_int_pk,
342 keyset_keys,
343 row_estimate,
344 avg_row_bytes,
345 })
346}
347
348pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
358 crate::source::require_tls_or_loopback(url, tls)?;
360 match tls {
361 Some(cfg) if cfg.mode.is_enforced() => {
362 let connector = build_native_tls(cfg)?;
363 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
364 Ok(Client::connect(url, make_tls)?)
365 }
366 _ => Ok(Client::connect(url, NoTls)?),
367 }
368}
369
/// Streams the result of `built_sql` to `sink` through a server-side cursor
/// inside a single transaction.
///
/// Sequence: `BEGIN` (via `PgTxnGuard`), optional `SET LOCAL` timeouts, read
/// `work_mem`, `DECLARE _rivet ... CURSOR`, then a `FETCH` loop whose size is
/// driven by `AdaptiveBatchController`. The first non-empty fetch derives the
/// Arrow schema and hands it to `sink.on_schema`; every fetch is converted to
/// a record batch and handed to `sink.on_batch`. The cursor is closed and the
/// transaction committed at the end.
///
/// Returns `(total_rows, had_schema)`: the number of rows fetched and whether
/// a schema was delivered to the sink (it is not when the very first fetch
/// comes back empty).
///
/// # Errors
/// Any database, conversion or sink error is propagated; the guard then
/// issues `ROLLBACK` when it is dropped.
fn pg_run_export(
    client: &mut Client,
    built_sql: &str,
    tuning: &SourceTuning,
    column_overrides: &ColumnOverrides,
    sink: &mut dyn super::BatchSink,
    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
) -> Result<(usize, bool)> {
    let mut guard = PgTxnGuard::begin(client)?;
    // A timeout of 0 means "do not override": the SET LOCAL is skipped.
    if tuning.statement_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL statement_timeout = '{}s'",
            tuning.statement_timeout_s
        ))?;
    }
    if tuning.lock_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL lock_timeout = '{}s'",
            tuning.lock_timeout_s
        ))?;
    }
    // `None` when SHOW fails or the value cannot be parsed; selects which of
    // the two batch-size caps below is used.
    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());

    guard
        .client_mut()
        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;

    let configured_batch_size = tuning.batch_size;
    let mut ctl = AdaptiveBatchController::new(tuning, configured_batch_size);
    // Only adaptive mode seeds the controller with a checkpoint sample.
    ctl.seed_pressure(if tuning.adaptive {
        pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64)
    } else {
        None
    });
    // Schema and column metadata are filled in by the first non-empty fetch.
    let mut schema: Option<SchemaRef> = None;
    let mut columns_cache: Option<Vec<(String, Type)>> = None;
    let mut total_rows: usize = 0;
    // Ensures the memory-based cap is applied at most once.
    let mut cap_applied = false;
    let max_value_bytes = tuning.max_value_bytes();

    loop {
        let requested = ctl.target();
        let fetch_sql = format!("FETCH {} FROM _rivet", requested);
        let rows = guard.client_mut().query(&fetch_sql, &[])?;
        if rows.is_empty() {
            break;
        }

        if schema.is_none() {
            // Column names/types are taken from the first row and reused for
            // every later batch.
            let stmt_cols: Vec<(String, Type)> = rows[0]
                .columns()
                .iter()
                .map(|c| (c.name().to_string(), c.type_().clone()))
                .collect();
            let s = Arc::new(pg_columns_to_schema(
                rows[0].columns(),
                column_overrides,
                numeric_hints,
            )?);
            sink.on_schema(s.clone())?;
            if work_mem_bytes.is_none() {
                // No work_mem reading: cap with the schema-based estimate
                // from tuning, but not below what was just requested.
                let effective = tuning.effective_batch_size(Some(&s));
                ctl.apply_memory_cap(effective.max(requested));
                cap_applied = true;
            }
            schema = Some(s);
            columns_cache = Some(stmt_cols);
        }

        let row_count = rows.len();
        total_rows += row_count;

        let s = schema.as_ref().expect("schema set on first iteration");
        let cols = columns_cache
            .as_ref()
            .expect("columns set on first iteration");
        let batch = rows_to_record_batch_typed(s, cols, &rows, max_value_bytes)?;
        // Raw rows are released before the sink runs — presumably to keep
        // peak memory down.
        drop(rows);

        if !cap_applied
            && let Some(wm) = work_mem_bytes
            && row_count > 0
        {
            // One-time cap from the first batch: measured Arrow bytes per
            // row, padded by 20% (at least 64 B) as the server-side row
            // estimate; 70% of work_mem divided by that, never below 100.
            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
            let arrow_per_row = (arrow_bytes / row_count).max(1);
            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
            let mut target = safe;
            // An explicit memory budget (MB) can only tighten the target,
            // again with a floor of 100 rows.
            if let Some(mem_mb) = tuning.batch_size_memory_mb {
                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
                target = target.min(arrow_target.max(100));
            }
            if let Some(new) = ctl.apply_memory_cap(target) {
                log::info!(
                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N → {} (configured={})",
                    wm,
                    arrow_per_row,
                    pg_per_row,
                    new,
                    configured_batch_size,
                );
            }
            cap_applied = true;
        }

        sink.on_batch(&batch)?;

        // The controller decides whether to take a fresh pressure sample; a
        // returned value means the batch size changed.
        if let Some((new, under_pressure)) =
            ctl.after_batch(|| pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64))
        {
            log::info!(
                "adaptive batch size → {} ({})",
                new,
                if under_pressure {
                    "pressure"
                } else {
                    "recovery"
                }
            );
        }

        log::info!("fetched {} rows so far...", total_rows);

        // A short fetch means the cursor is exhausted, so the extra empty
        // FETCH round trip is skipped.
        if row_count < requested {
            break;
        }
        ctl.throttle();
    }

    guard.client_mut().batch_execute("CLOSE _rivet")?;
    guard.commit()?;
    Ok((total_rows, schema.is_some()))
}
537
538impl super::Source for PostgresSource {
539 fn export(
540 &mut self,
541 request: &super::ExportRequest<'_>,
542 sink: &mut dyn super::BatchSink,
543 ) -> Result<()> {
544 let built = build_export_query(request, SourceType::Postgres);
545 debug_assert!(
546 built.cursor_param.is_none(),
547 "Postgres path inlines cursor values as E'…' literals — binding is unused"
548 );
549 log::debug!(
550 "executing query (connection={}): {}",
551 if self.transaction_pooler {
552 "transaction-pooler"
553 } else {
554 "direct"
555 },
556 built.sql
557 );
558
559 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
563 let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, hint_query);
564
565 let (total_rows, had_schema) = pg_run_export(
568 &mut self.client,
569 &built.sql,
570 request.tuning,
571 request.column_overrides,
572 sink,
573 numeric_hints.as_ref(),
574 )?;
575
576 if !had_schema {
577 sink.on_schema(Arc::new(Schema::empty()))?;
578 }
579
580 log::info!("total: {} rows", total_rows);
581 Ok(())
582 }
583
584 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
585 let rows = self.client.query(sql, &[])?;
586 if rows.is_empty() {
587 return Ok(None);
588 }
589 let row = &rows[0];
590 if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
591 return Ok(Some(v.to_string()));
592 }
593 if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
594 return Ok(Some(v.to_string()));
595 }
596 if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
597 return Ok(Some(v.to_string()));
598 }
599 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
601 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
602 }
603 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
604 return Ok(Some(v.format("%Y-%m-%d").to_string()));
605 }
606 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
607 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
608 }
609 if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
610 return Ok(Some(v));
611 }
612 Ok(None)
613 }
614
615 fn type_mappings(
616 &mut self,
617 query: &str,
618 column_overrides: &ColumnOverrides,
619 ) -> Result<Vec<TypeMapping>> {
620 let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
621 let stmt = self.client.prepare(&wrapped)?;
622 let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
623 let mappings = stmt
624 .columns()
625 .iter()
626 .map(|col| {
627 let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
628 let source = SourceColumn::simple(col.name(), col.type_().name(), true);
629 TypeMapping::from_source(&source, rivet)
630 })
631 .collect();
632 Ok(mappings)
633 }
634
635 fn sample_pressure(&mut self) -> Option<u64> {
639 pg_sample_checkpoints_req(&mut self.client).map(|v| v.max(0) as u64)
640 }
641}
642
643fn pg_numeric_catalog_hints_opt(
649 client: &mut Client,
650 query: &str,
651) -> Option<HashMap<String, (u8, i8)>> {
652 match pg_fetch_numeric_catalog_hints(client, query) {
653 Ok(m) => m,
654 Err(e) => {
655 log::warn!(
661 "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
662 );
663 None
664 }
665 }
666}
667
668fn pg_fetch_numeric_catalog_hints(
669 client: &mut Client,
670 query: &str,
671) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
672 let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
673 return Ok(None);
674 };
675 let locate_sql = "SELECT n.nspname::text, c.relname::text \
676 FROM pg_catalog.pg_class c \
677 JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
678 WHERE c.oid = ($1::text)::regclass";
679 let row_opt = match client.query_opt(locate_sql, &[®class_lit]) {
680 Ok(r) => r,
681 Err(e) => {
682 log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
683 return Ok(None);
684 }
685 };
686 let Some(row) = row_opt else {
687 return Ok(None);
688 };
689 let schema: String = row.get(0);
690 let table: String = row.get(1);
691 let rows = client.query(
692 "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
693 FROM information_schema.columns \
694 WHERE table_schema = $1 AND table_name = $2 \
695 ORDER BY ordinal_position",
696 &[&schema, &table],
697 )?;
698
699 let mut map = HashMap::new();
700 for row in rows {
701 let col: String = row.get(0);
702 let dt: String = row.get(1);
703 if !is_pg_numeric_information_type(&dt) {
704 continue;
705 }
706 let p: Option<i32> = row.get(2);
707 let s: Option<i32> = row.get(3);
708 if let (Some(p), Some(s)) = (p, s)
709 && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
710 {
711 map.insert(col, pair);
712 }
713 }
714
715 if map.is_empty() {
716 Ok(None)
717 } else {
718 log::debug!(
719 "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
720 map.len(),
721 );
722 Ok(Some(map))
723 }
724}
725
/// Tells whether an `information_schema` `data_type` value names a
/// NUMERIC/DECIMAL column.
///
/// The value is trimmed and compared case-insensitively; both the bare type
/// name and a name immediately followed by `(` are accepted.
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let normalized = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"].iter().any(|base| {
        normalized
            .strip_prefix(base)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('('))
    })
}
732
/// Converts catalog precision/scale into the `(u8, i8)` pair used for
/// decimal columns.
///
/// Returns `None` when precision is outside `1..=76`, when scale does not
/// fit in an `i8`, or when scale is larger than precision. Negative scales
/// within `i8` range are accepted.
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    let precision_u = u8::try_from(precision)
        .ok()
        .filter(|p| (1..=76).contains(p))?;
    let scale_i = i8::try_from(scale).ok()?;
    if i32::from(scale_i) > precision {
        return None;
    }
    Some((precision_u, scale_i))
}
748
/// Unit tests for the pure helpers in this module; none of them needs a
/// database connection.
#[cfg(test)]
mod tests {
    use super::catalog_numeric_to_decimal_params;

    /// Precision/scale pairs: one valid pair, then each rejection rule.
    #[test]
    fn catalog_decimal_bounds() {
        // A valid pair passes through unchanged.
        assert_eq!(catalog_numeric_to_decimal_params(18, 2), Some((18, 2)));
        // Precision must be at least 1 ...
        assert!(catalog_numeric_to_decimal_params(0, 2).is_none());
        // ... and at most 76.
        assert!(catalog_numeric_to_decimal_params(77, 0).is_none());
        // Scale may not exceed precision.
        assert!(catalog_numeric_to_decimal_params(18, 19).is_none());
    }

    /// `parse_work_mem` on unit suffixes, whitespace, case and bad input.
    #[test]
    fn parse_work_mem_handles_pg_units() {
        use super::parse_work_mem;
        assert_eq!(parse_work_mem("4MB"), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("16384kB"), Some(16384 * 1024));
        assert_eq!(parse_work_mem("1GB"), Some(1024 * 1024 * 1024));
        // Surrounding whitespace and unit case are ignored.
        assert_eq!(parse_work_mem("  4MB  "), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("4mb"), Some(4 * 1024 * 1024));
        // A bare number is taken as kilobytes.
        assert_eq!(parse_work_mem("65536"), Some(65536 * 1024));
        // Empty, non-numeric and non-memory-unit inputs are rejected.
        assert_eq!(parse_work_mem(""), None);
        assert_eq!(parse_work_mem("garbage"), None);
        assert_eq!(parse_work_mem("4s"), None);
    }
}