1mod arrow_convert;
23mod from_parse;
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use postgres::types::Type;
30use postgres::{Client, NoTls};
31
32use crate::config::{SourceType, TlsConfig};
33use crate::error::Result;
34use crate::source::query::build_export_query;
35use crate::source::tls::build_native_tls;
36use crate::tuning::{ADAPTIVE_SAMPLE_INTERVAL, SourceTuning, next_adaptive_batch_size};
37use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
38
39use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
40use from_parse::try_parse_pg_simple_from_regclass_literal;
41
/// PostgreSQL implementation of the crate's `Source` trait, backed by a single
/// synchronous `postgres::Client`.
pub struct PostgresSource {
    /// Live connection used for exports, scalar queries and type probes.
    client: Client,
    /// `true` when `detect_pg_transaction_pooler` saw two different backend
    /// PIDs on consecutive statements right after connecting.
    transaction_pooler: bool,
}
48
49fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
57 let pid1: Option<i32> = client
58 .query_one("SELECT pg_backend_pid()", &[])
59 .ok()
60 .and_then(|r| r.try_get(0).ok());
61 let pid2: Option<i32> = client
62 .query_one("SELECT pg_backend_pid()", &[])
63 .ok()
64 .and_then(|r| r.try_get(0).ok());
65 matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
66}
67
68impl PostgresSource {
69 pub fn connect(url: &str) -> Result<Self> {
72 let mut client = Client::connect(url, NoTls)?;
73 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
74 if transaction_pooler {
75 log::warn!(
76 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
77 SET LOCAL tuning is transaction-scoped; \
78 LISTEN/NOTIFY and advisory locks are unavailable"
79 );
80 }
81 Ok(Self {
82 client,
83 transaction_pooler,
84 })
85 }
86
87 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
90 match tls {
91 Some(cfg) if cfg.mode.is_enforced() => {
92 let connector = build_native_tls(cfg)?;
93 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
94 let mut client = Client::connect(url, make_tls)?;
95 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
96 if transaction_pooler {
97 log::warn!(
98 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
99 SET LOCAL tuning is transaction-scoped; \
100 LISTEN/NOTIFY and advisory locks are unavailable"
101 );
102 }
103 Ok(Self {
104 client,
105 transaction_pooler,
106 })
107 }
108 _ => Self::connect(url),
109 }
110 }
111}
112
113struct PgTxnGuard<'a> {
121 client: &'a mut Client,
122 committed: bool,
123}
124
125impl<'a> PgTxnGuard<'a> {
126 fn begin(client: &'a mut Client) -> Result<Self> {
127 client.batch_execute("BEGIN")?;
128 Ok(Self {
129 client,
130 committed: false,
131 })
132 }
133
134 fn client_mut(&mut self) -> &mut Client {
135 self.client
136 }
137
138 fn commit(mut self) -> Result<()> {
139 self.client.batch_execute("COMMIT")?;
140 self.committed = true;
141 Ok(())
142 }
143}
144
145impl Drop for PgTxnGuard<'_> {
146 fn drop(&mut self) {
147 if !self.committed
148 && let Err(e) = self.client.batch_execute("ROLLBACK")
149 {
150 log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
153 }
154 }
155}
156
157pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
169 let mut client = connect_client(url, tls).ok()?;
170 client
171 .query_one(
172 "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
173 &[],
174 )
175 .ok()
176 .and_then(|r| r.try_get::<_, i64>(0).ok())
177}
178
179fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
190 let raw: Option<String> = client
191 .query_one("SHOW work_mem", &[])
192 .ok()
193 .and_then(|r| r.try_get::<_, String>(0).ok());
194 raw.as_deref().and_then(parse_work_mem)
195}
196
/// Parses a PostgreSQL memory setting as printed by `SHOW` (e.g. `"4MB"`,
/// `"16384kB"`, `"1GB"`) into a byte count.
///
/// The numeric part may contain digits, `.` and `-`. A bare number is treated
/// as kB, matching the kB base unit PostgreSQL uses for `work_mem`. Units are
/// matched case-insensitively (`kB`, `MB`, `GB`, `TB`), and surrounding
/// whitespace is ignored.
///
/// Returns `None` for empty input, an unparsable number, any other unit
/// (including `B`), or a result that is not strictly positive.
fn parse_work_mem(raw: &str) -> Option<i64> {
    let text = raw.trim();
    // Byte offset where the numeric prefix ends and the unit begins.
    let unit_start = text
        .find(|c: char| !(c.is_ascii_digit() || c == '.' || c == '-'))
        .unwrap_or(text.len());
    if unit_start == 0 {
        return None;
    }
    let (number, suffix) = text.split_at(unit_start);
    let value: f64 = number.parse().ok()?;
    // Each step up the kB → MB → GB → TB ladder is another factor of 1024 (2^10).
    let shift: u32 = match suffix.trim().to_ascii_lowercase().as_str() {
        "" | "kb" => 10,
        "mb" => 20,
        "gb" => 30,
        "tb" => 40,
        _ => return None,
    };
    let bytes = (value * (1u64 << shift) as f64) as i64;
    if bytes > 0 { Some(bytes) } else { None }
}
229
230fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
236 let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
237 client
238 .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
239 .ok()
240 .and_then(|r| r.try_get::<_, i64>(0).ok())
241}
242
243pub(crate) fn introspect_pg_table_for_chunking(
255 url: &str,
256 tls: Option<&TlsConfig>,
257 qualified_table: &str,
258) -> Result<crate::source::TableIntrospection> {
259 let (schema, table) = match qualified_table.split_once('.') {
260 Some((s, t)) => (s.to_string(), t.to_string()),
261 None => ("public".to_string(), qualified_table.to_string()),
262 };
263 let mut client = connect_client(url, tls)?;
264
265 let (row_estimate, rel_size_bytes) = match client.query_opt(
267 "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
268 FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
269 WHERE n.nspname = $1::text AND c.relname = $2::text",
270 &[&schema, &table],
271 )? {
272 Some(row) => {
273 let rt: i64 = row.try_get(0).unwrap_or(0);
274 let sz: i64 = row.try_get(1).unwrap_or(0);
275 (rt.max(0), sz.max(0))
276 }
277 None => (0, 0),
278 };
279 let avg_row_bytes = if row_estimate > 0 {
280 Some(rel_size_bytes / row_estimate)
281 } else {
282 None
283 };
284
285 let pk_rows = client.query(
287 "SELECT a.attname::text, t.typname::text \
288 FROM pg_index i \
289 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
290 JOIN pg_type t ON t.oid = a.atttypid \
291 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
292 AND i.indisprimary",
293 &[&schema, &table],
294 )?;
295 let single_int_pk = if pk_rows.len() == 1 {
296 let col: String = pk_rows[0].get(0);
297 let pg_type: String = pk_rows[0].get(1);
298 if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
302 Some(col)
303 } else {
304 log::debug!(
305 "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
306 );
307 None
308 }
309 } else {
310 None
311 };
312
313 let keyset_rows = client.query(
319 "SELECT a.attname::text, i.indisprimary \
320 FROM pg_index i \
321 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = i.indkey[0] \
322 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
323 AND i.indisunique AND i.indnkeyatts = 1 AND a.attnotnull",
324 &[&schema, &table],
325 )?;
326 let mut keyset_keys: Vec<String> = Vec::new();
327 for primary in [true, false] {
328 for row in &keyset_rows {
329 let col: String = row.get(0);
330 let is_primary: bool = row.get(1);
331 if is_primary == primary && !keyset_keys.contains(&col) {
332 keyset_keys.push(col);
333 }
334 }
335 }
336
337 Ok(crate::source::TableIntrospection {
338 single_int_pk,
339 keyset_keys,
340 row_estimate,
341 avg_row_bytes,
342 })
343}
344
345pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
352 match tls {
353 Some(cfg) if cfg.mode.is_enforced() => {
354 let connector = build_native_tls(cfg)?;
355 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
356 Ok(Client::connect(url, make_tls)?)
357 }
358 _ => Ok(Client::connect(url, NoTls)?),
359 }
360}
361
/// Streams the result of `built_sql` into `sink` through a server-side cursor.
///
/// Every statement runs inside one transaction owned by a `PgTxnGuard`. The
/// optional `SET LOCAL` timeouts and the `_rivet` cursor are scoped to that
/// transaction, and any early `?` return drops the guard, which sends `ROLLBACK`.
///
/// Fetch sizing:
/// * The first `FETCH` requests at most `PROBE_FETCH_SIZE` rows, or fewer if
///   `tuning.batch_size` is smaller.
/// * On the first non-empty batch, the Arrow schema is derived from the row
///   metadata and handed to `sink.on_schema`. If `work_mem` could not be read,
///   the fetch size becomes the larger of the current size and
///   `tuning.effective_batch_size(schema)`.
/// * If `work_mem` was read, the Arrow bytes per row of the first batch (padded
///   by 20%, minimum 64 B) are used once to pick a row count that fits in
///   about 70% of `work_mem`. That count is never below 100 before the caps
///   apply, and is capped by `tuning.batch_size` and, when set,
///   `tuning.batch_size_memory_mb`.
/// * With `tuning.adaptive`, the requested-checkpoint counter is re-read every
///   `ADAPTIVE_SAMPLE_INTERVAL` batches. An increase since the previous sample
///   counts as pressure, and `next_adaptive_batch_size` picks the next size
///   relative to `base_fetch_size`.
///
/// The loop ends on an empty `FETCH` or a short one (fewer rows than
/// requested). `tuning.throttle_ms` adds a sleep between full batches.
///
/// Returns `(total_rows, schema_was_sent)`. The flag is `false` only when the
/// query produced no rows, in which case `sink.on_schema` was never called.
///
/// NOTE(review): if `tuning.batch_size` can be 0, the first request is
/// `FETCH 0`, which returns no rows on a freshly declared cursor and ends the
/// export immediately. Confirm it is validated to be >= 1 upstream.
fn pg_run_export(
    client: &mut Client,
    built_sql: &str,
    tuning: &SourceTuning,
    column_overrides: &ColumnOverrides,
    sink: &mut dyn super::BatchSink,
    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
) -> Result<(usize, bool)> {
    let mut guard = PgTxnGuard::begin(client)?;
    // `SET LOCAL` only lasts until the guard's COMMIT/ROLLBACK, so these
    // timeouts never leak into later use of the connection.
    if tuning.statement_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL statement_timeout = '{}s'",
            tuning.statement_timeout_s
        ))?;
    }
    if tuning.lock_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL lock_timeout = '{}s'",
            tuning.lock_timeout_s
        ))?;
    }
    // Best-effort; `None` switches the first-batch sizing to `effective_batch_size`.
    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());

    guard
        .client_mut()
        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;

    // Small first FETCH so row width can be measured before committing to a
    // large batch size.
    const PROBE_FETCH_SIZE: usize = 500;
    let configured_batch_size = tuning.batch_size;
    let mut fetch_size = configured_batch_size.min(PROBE_FETCH_SIZE);
    let mut fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
    let mut work_mem_cap_applied = false;
    let mut schema: Option<SchemaRef> = None;
    let mut columns_cache: Option<Vec<(String, Type)>> = None;
    let mut total_rows: usize = 0;
    // Reference size the adaptive controller recovers towards.
    let mut base_fetch_size = fetch_size;
    let mut adaptive_last_ckpt: Option<i64> = if tuning.adaptive {
        pg_sample_checkpoints_req(guard.client_mut())
    } else {
        None
    };
    let mut batch_count: usize = 0;

    loop {
        // Remembered so a short batch can be recognised as the end of the cursor.
        let requested_this_iter = fetch_size;
        let rows = guard.client_mut().query(&fetch_sql, &[])?;
        if rows.is_empty() {
            break;
        }

        // First non-empty batch: build the schema once from the row metadata
        // and cache the column (name, type) pairs for every later conversion.
        if schema.is_none() {
            let stmt_cols: Vec<(String, Type)> = rows[0]
                .columns()
                .iter()
                .map(|c| (c.name().to_string(), c.type_().clone()))
                .collect();
            let s = Arc::new(pg_columns_to_schema(
                rows[0].columns(),
                column_overrides,
                numeric_hints,
            )?);
            sink.on_schema(s.clone())?;
            schema = Some(s.clone());
            columns_cache = Some(stmt_cols);

            let effective = tuning.effective_batch_size(Some(&s));
            // Only when work_mem is unknown; otherwise the work_mem cap below
            // decides the size after measuring this batch.
            if effective != fetch_size && work_mem_bytes.is_none() {
                fetch_size = effective.max(fetch_size);
                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
            }
            base_fetch_size = fetch_size;
        }

        let row_count = rows.len();
        total_rows += row_count;

        let s = schema.as_ref().expect("schema set on first iteration");
        let cols = columns_cache
            .as_ref()
            .expect("columns set on first iteration");
        let batch = rows_to_record_batch_typed(s, cols, &rows)?;
        // Release the raw rows before the sink sees the Arrow batch.
        drop(rows);

        // One-time cap based on the measured width of the first batch.
        if !work_mem_cap_applied
            && let Some(wm) = work_mem_bytes
            && row_count > 0
        {
            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
            let arrow_per_row = (arrow_bytes / row_count).max(1);
            // Heuristic: assume the server-side row is ~1.2x the Arrow size, at least 64 B.
            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
            let mut target = safe.min(configured_batch_size);
            if let Some(mem_mb) = tuning.batch_size_memory_mb {
                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
                target = target.min(arrow_target.max(100));
            }
            if target != fetch_size {
                log::info!(
                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N {} → {} (configured={})",
                    wm,
                    arrow_per_row,
                    pg_per_row,
                    fetch_size,
                    target,
                    configured_batch_size,
                );
                fetch_size = target;
                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
                base_fetch_size = fetch_size;
            }
            work_mem_cap_applied = true;
        }

        sink.on_batch(&batch)?;

        // Adaptive sizing: a rise in requested checkpoints between samples is
        // treated as write pressure on the server.
        batch_count += 1;
        if tuning.adaptive
            && batch_count.is_multiple_of(ADAPTIVE_SAMPLE_INTERVAL)
            && let Some(cur) = pg_sample_checkpoints_req(guard.client_mut())
        {
            let under_pressure = adaptive_last_ckpt.is_some_and(|prev| cur > prev);
            adaptive_last_ckpt = Some(cur);
            let next = next_adaptive_batch_size(fetch_size, base_fetch_size, under_pressure);
            if next != fetch_size {
                fetch_size = next;
                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
                log::info!(
                    "adaptive batch size → {} ({})",
                    fetch_size,
                    if under_pressure {
                        "pressure"
                    } else {
                        "recovery"
                    }
                );
            }
        }

        log::info!("fetched {} rows so far...", total_rows);

        // A short FETCH means the cursor is exhausted; skip the extra round trip.
        if row_count < requested_this_iter {
            break;
        }

        if tuning.throttle_ms > 0 {
            std::thread::sleep(std::time::Duration::from_millis(tuning.throttle_ms));
        }
    }

    guard.client_mut().batch_execute("CLOSE _rivet")?;
    guard.commit()?;
    Ok((total_rows, schema.is_some()))
}
549
550impl super::Source for PostgresSource {
551 fn export(
552 &mut self,
553 request: &super::ExportRequest<'_>,
554 sink: &mut dyn super::BatchSink,
555 ) -> Result<()> {
556 let built = build_export_query(request, SourceType::Postgres);
557 debug_assert!(
558 built.cursor_param.is_none(),
559 "Postgres path inlines cursor values as E'…' literals — binding is unused"
560 );
561 log::debug!(
562 "executing query (connection={}): {}",
563 if self.transaction_pooler {
564 "transaction-pooler"
565 } else {
566 "direct"
567 },
568 built.sql
569 );
570
571 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
575 let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, hint_query);
576
577 let (total_rows, had_schema) = pg_run_export(
580 &mut self.client,
581 &built.sql,
582 request.tuning,
583 request.column_overrides,
584 sink,
585 numeric_hints.as_ref(),
586 )?;
587
588 if !had_schema {
589 sink.on_schema(Arc::new(Schema::empty()))?;
590 }
591
592 log::info!("total: {} rows", total_rows);
593 Ok(())
594 }
595
596 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
597 let rows = self.client.query(sql, &[])?;
598 if rows.is_empty() {
599 return Ok(None);
600 }
601 let row = &rows[0];
602 if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
603 return Ok(Some(v.to_string()));
604 }
605 if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
606 return Ok(Some(v.to_string()));
607 }
608 if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
609 return Ok(Some(v.to_string()));
610 }
611 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
613 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
614 }
615 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
616 return Ok(Some(v.format("%Y-%m-%d").to_string()));
617 }
618 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
619 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
620 }
621 if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
622 return Ok(Some(v));
623 }
624 Ok(None)
625 }
626
627 fn type_mappings(
628 &mut self,
629 query: &str,
630 column_overrides: &ColumnOverrides,
631 ) -> Result<Vec<TypeMapping>> {
632 let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
633 let stmt = self.client.prepare(&wrapped)?;
634 let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
635 let mappings = stmt
636 .columns()
637 .iter()
638 .map(|col| {
639 let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
640 let source = SourceColumn::simple(col.name(), col.type_().name(), true);
641 TypeMapping::from_source(&source, rivet)
642 })
643 .collect();
644 Ok(mappings)
645 }
646
647 fn sample_pressure(&mut self) -> Option<u64> {
651 pg_sample_checkpoints_req(&mut self.client).map(|v| v.max(0) as u64)
652 }
653}
654
655fn pg_numeric_catalog_hints_opt(
661 client: &mut Client,
662 query: &str,
663) -> Option<HashMap<String, (u8, i8)>> {
664 match pg_fetch_numeric_catalog_hints(client, query) {
665 Ok(m) => m,
666 Err(e) => {
667 log::warn!(
673 "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
674 );
675 None
676 }
677 }
678}
679
680fn pg_fetch_numeric_catalog_hints(
681 client: &mut Client,
682 query: &str,
683) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
684 let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
685 return Ok(None);
686 };
687 let locate_sql = "SELECT n.nspname::text, c.relname::text \
688 FROM pg_catalog.pg_class c \
689 JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
690 WHERE c.oid = ($1::text)::regclass";
691 let row_opt = match client.query_opt(locate_sql, &[®class_lit]) {
692 Ok(r) => r,
693 Err(e) => {
694 log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
695 return Ok(None);
696 }
697 };
698 let Some(row) = row_opt else {
699 return Ok(None);
700 };
701 let schema: String = row.get(0);
702 let table: String = row.get(1);
703 let rows = client.query(
704 "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
705 FROM information_schema.columns \
706 WHERE table_schema = $1 AND table_name = $2 \
707 ORDER BY ordinal_position",
708 &[&schema, &table],
709 )?;
710
711 let mut map = HashMap::new();
712 for row in rows {
713 let col: String = row.get(0);
714 let dt: String = row.get(1);
715 if !is_pg_numeric_information_type(&dt) {
716 continue;
717 }
718 let p: Option<i32> = row.get(2);
719 let s: Option<i32> = row.get(3);
720 if let (Some(p), Some(s)) = (p, s)
721 && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
722 {
723 map.insert(col, pair);
724 }
725 }
726
727 if map.is_empty() {
728 Ok(None)
729 } else {
730 log::debug!(
731 "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
732 map.len(),
733 );
734 Ok(Some(map))
735 }
736}
737
/// Returns `true` when an `information_schema.columns.data_type` value names
/// NUMERIC or DECIMAL, either bare or immediately followed by a `(`, ignoring
/// case and surrounding whitespace.
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let lowered = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"].into_iter().any(|base| {
        lowered == base
            || lowered
                .strip_prefix(base)
                .is_some_and(|rest| rest.starts_with('('))
    })
}
744
/// Converts catalog NUMERIC precision/scale into `(precision, scale)` decimal parameters.
///
/// Accepts precision in `1..=76` and a scale that fits in `i8` and does not
/// exceed the precision; negative scales are allowed. Anything else yields
/// `None`. (76 presumably mirrors the widest Arrow decimal, Decimal256 — confirm.)
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    let p = u8::try_from(precision)
        .ok()
        .filter(|p| (1..=76).contains(p))?;
    let s = i8::try_from(scale).ok()?;
    (i16::from(s) <= i16::from(p)).then_some((p, s))
}
760
#[cfg(test)]
mod tests {
    use super::{catalog_numeric_to_decimal_params, parse_work_mem};

    /// Precision must be within 1..=76 and scale may not exceed precision.
    #[test]
    fn catalog_decimal_bounds() {
        assert_eq!(catalog_numeric_to_decimal_params(18, 2), Some((18, 2)));
        for (precision, scale) in [(0, 2), (77, 0), (18, 19)] {
            assert!(
                catalog_numeric_to_decimal_params(precision, scale).is_none(),
                "expected ({precision}, {scale}) to be rejected"
            );
        }
    }

    /// `SHOW work_mem` style values in every supported unit, plus rejects.
    #[test]
    fn parse_work_mem_handles_pg_units() {
        const KB: i64 = 1024;
        let accepted = [
            ("4MB", 4 * KB * KB),
            ("16384kB", 16384 * KB),
            ("1GB", KB * KB * KB),
            (" 4MB ", 4 * KB * KB),
            ("4mb", 4 * KB * KB),
            ("65536", 65536 * KB),
        ];
        for (input, bytes) in accepted {
            assert_eq!(parse_work_mem(input), Some(bytes), "input {input:?}");
        }
        for input in ["", "garbage", "4s"] {
            assert_eq!(parse_work_mem(input), None, "input {input:?}");
        }
    }
}