1mod arrow_convert;
23mod from_parse;
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use postgres::types::Type;
30use postgres::{Client, NoTls};
31
32use crate::config::{SourceType, TlsConfig};
33use crate::error::Result;
34use crate::source::query::build_incremental_query;
35use crate::source::tls::build_native_tls;
36use crate::tuning::{ADAPTIVE_SAMPLE_INTERVAL, SourceTuning, next_adaptive_batch_size};
37use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
38
39use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
40use from_parse::try_parse_pg_simple_from_regclass_literal;
41
/// PostgreSQL implementation of the crate's `Source` trait, wrapping a single
/// synchronous `postgres::Client` connection.
pub struct PostgresSource {
    /// Connection used by every query this source issues.
    client: Client,
    /// Set once at connect time by `detect_pg_transaction_pooler`: `true` when two
    /// consecutive `pg_backend_pid()` calls returned different backend PIDs.
    /// Within this file it is only read for the export debug log.
    transaction_pooler: bool,
}
48
49fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
57 let pid1: Option<i32> = client
58 .query_one("SELECT pg_backend_pid()", &[])
59 .ok()
60 .and_then(|r| r.try_get(0).ok());
61 let pid2: Option<i32> = client
62 .query_one("SELECT pg_backend_pid()", &[])
63 .ok()
64 .and_then(|r| r.try_get(0).ok());
65 matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
66}
67
68impl PostgresSource {
69 pub fn connect(url: &str) -> Result<Self> {
72 let mut client = Client::connect(url, NoTls)?;
73 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
74 if transaction_pooler {
75 log::warn!(
76 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
77 SET LOCAL tuning is transaction-scoped; \
78 LISTEN/NOTIFY and advisory locks are unavailable"
79 );
80 }
81 Ok(Self {
82 client,
83 transaction_pooler,
84 })
85 }
86
87 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
90 match tls {
91 Some(cfg) if cfg.mode.is_enforced() => {
92 let connector = build_native_tls(cfg)?;
93 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
94 let mut client = Client::connect(url, make_tls)?;
95 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
96 if transaction_pooler {
97 log::warn!(
98 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
99 SET LOCAL tuning is transaction-scoped; \
100 LISTEN/NOTIFY and advisory locks are unavailable"
101 );
102 }
103 Ok(Self {
104 client,
105 transaction_pooler,
106 })
107 }
108 _ => Self::connect(url),
109 }
110 }
111}
112
/// Guard for an explicit `BEGIN` … `COMMIT` block on a borrowed client.
///
/// If the guard is dropped before `commit` has succeeded (early `?` return, panic
/// unwind, or a failed `COMMIT`), its `Drop` impl issues `ROLLBACK` and only logs a
/// failure of that statement.
struct PgTxnGuard<'a> {
    /// Connection the transaction was opened on. It is borrowed exclusively for the
    /// guard's lifetime, so no other code can issue statements on it meanwhile.
    client: &'a mut Client,
    /// Becomes `true` only after `COMMIT` returned successfully.
    committed: bool,
}
124
125impl<'a> PgTxnGuard<'a> {
126 fn begin(client: &'a mut Client) -> Result<Self> {
127 client.batch_execute("BEGIN")?;
128 Ok(Self {
129 client,
130 committed: false,
131 })
132 }
133
134 fn client_mut(&mut self) -> &mut Client {
135 self.client
136 }
137
138 fn commit(mut self) -> Result<()> {
139 self.client.batch_execute("COMMIT")?;
140 self.committed = true;
141 Ok(())
142 }
143}
144
145impl Drop for PgTxnGuard<'_> {
146 fn drop(&mut self) {
147 if !self.committed
148 && let Err(e) = self.client.batch_execute("ROLLBACK")
149 {
150 log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
153 }
154 }
155}
156
157pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
169 let mut client = connect_client(url, tls).ok()?;
170 client
171 .query_one(
172 "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
173 &[],
174 )
175 .ok()
176 .and_then(|r| r.try_get::<_, i64>(0).ok())
177}
178
179fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
190 let raw: Option<String> = client
191 .query_one("SHOW work_mem", &[])
192 .ok()
193 .and_then(|r| r.try_get::<_, String>(0).ok());
194 raw.as_deref().and_then(parse_work_mem)
195}
196
/// Parse a PostgreSQL memory setting such as `"4MB"` or `"16384kB"` into bytes.
///
/// - The numeric prefix may contain digits, `.` and `-`.
/// - The unit suffix is matched case-insensitively against `kB`/`MB`/`GB`/`TB`.
/// - A bare number is interpreted as kB.
/// - Returns `None` for an empty or non-numeric prefix, an unknown unit, or a result
///   that is not strictly positive.
fn parse_work_mem(raw: &str) -> Option<i64> {
    const KIB: f64 = 1024.0;

    let text = raw.trim();
    // Byte offset of the first char that cannot be part of the numeric prefix.
    let numeric_len = text
        .find(|c: char| !(c.is_ascii_digit() || c == '.' || c == '-'))
        .unwrap_or(text.len());
    if numeric_len == 0 {
        return None;
    }

    let (number, suffix) = text.split_at(numeric_len);
    let value = number.parse::<f64>().ok()?;
    let scale = match suffix.trim().to_ascii_lowercase().as_str() {
        "" | "kb" => KIB,
        "mb" => KIB * KIB,
        "gb" => KIB * KIB * KIB,
        "tb" => KIB * KIB * KIB * KIB,
        _ => return None,
    };

    let bytes = (value * scale) as i64;
    if bytes > 0 { Some(bytes) } else { None }
}
229
230fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
236 let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
237 client
238 .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
239 .ok()
240 .and_then(|r| r.try_get::<_, i64>(0).ok())
241}
242
243pub(crate) fn introspect_pg_table_for_chunking(
255 url: &str,
256 tls: Option<&TlsConfig>,
257 qualified_table: &str,
258) -> Result<crate::source::TableIntrospection> {
259 let (schema, table) = match qualified_table.split_once('.') {
260 Some((s, t)) => (s.to_string(), t.to_string()),
261 None => ("public".to_string(), qualified_table.to_string()),
262 };
263 let mut client = connect_client(url, tls)?;
264
265 let (row_estimate, rel_size_bytes) = match client.query_opt(
267 "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
268 FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
269 WHERE n.nspname = $1::text AND c.relname = $2::text",
270 &[&schema, &table],
271 )? {
272 Some(row) => {
273 let rt: i64 = row.try_get(0).unwrap_or(0);
274 let sz: i64 = row.try_get(1).unwrap_or(0);
275 (rt.max(0), sz.max(0))
276 }
277 None => (0, 0),
278 };
279 let avg_row_bytes = if row_estimate > 0 {
280 Some(rel_size_bytes / row_estimate)
281 } else {
282 None
283 };
284
285 let pk_rows = client.query(
287 "SELECT a.attname::text, t.typname::text \
288 FROM pg_index i \
289 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
290 JOIN pg_type t ON t.oid = a.atttypid \
291 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
292 AND i.indisprimary",
293 &[&schema, &table],
294 )?;
295 let single_int_pk = if pk_rows.len() == 1 {
296 let col: String = pk_rows[0].get(0);
297 let pg_type: String = pk_rows[0].get(1);
298 if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
302 Some(col)
303 } else {
304 log::debug!(
305 "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
306 );
307 None
308 }
309 } else {
310 None
311 };
312
313 Ok(crate::source::TableIntrospection {
314 single_int_pk,
315 row_estimate,
316 avg_row_bytes,
317 })
318}
319
320pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
327 match tls {
328 Some(cfg) if cfg.mode.is_enforced() => {
329 let connector = build_native_tls(cfg)?;
330 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
331 Ok(Client::connect(url, make_tls)?)
332 }
333 _ => Ok(Client::connect(url, NoTls)?),
334 }
335}
336
/// Stream the result of `built_sql` into `sink` through a server-side cursor.
///
/// Everything runs inside one explicit transaction held by [`PgTxnGuard`], so any
/// early `?` return rolls back via the guard's `Drop`.
///
/// Batch sizing:
/// - The first page is a small probe (at most `PROBE_FETCH_SIZE` rows).
/// - The page size is then re-derived either from the observed per-row size against
///   the session's `work_mem`, or, when `work_mem` is unknown, from
///   `tuning.effective_batch_size`.
/// - It can later be adjusted by the adaptive checkpoint-pressure sampler.
///
/// `numeric_hints` is forwarded to `pg_columns_to_schema` for NUMERIC precision/scale.
///
/// Returns `(total_rows, schema_emitted)`. `schema_emitted` is `false` only when the
/// very first FETCH returned no rows, in which case `sink.on_schema` was never called.
///
/// # Errors
/// Any failing statement (SET, DECLARE, FETCH, CLOSE, COMMIT), schema conversion, or
/// sink callback error.
fn pg_run_export(
    client: &mut Client,
    built_sql: &str,
    tuning: &SourceTuning,
    column_overrides: &ColumnOverrides,
    sink: &mut dyn super::BatchSink,
    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
) -> Result<(usize, bool)> {
    let mut guard = PgTxnGuard::begin(client)?;
    // SET LOCAL scopes these timeouts to this transaction only.
    if tuning.statement_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL statement_timeout = '{}s'",
            tuning.statement_timeout_s
        ))?;
    }
    if tuning.lock_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL lock_timeout = '{}s'",
            tuning.lock_timeout_s
        ))?;
    }
    // None when `SHOW work_mem` fails or cannot be parsed; that disables the
    // work_mem-based page cap below.
    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());

    guard
        .client_mut()
        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;

    // The first FETCH is capped at this many rows so the per-row size can be observed
    // before committing to a larger page.
    const PROBE_FETCH_SIZE: usize = 500;
    let configured_batch_size = tuning.batch_size;
    let mut fetch_size = configured_batch_size.min(PROBE_FETCH_SIZE);
    let mut fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
    let mut work_mem_cap_applied = false;
    let mut schema: Option<SchemaRef> = None;
    let mut columns_cache: Option<Vec<(String, Type)>> = None;
    let mut total_rows: usize = 0;
    // Page size the adaptive sampler recovers towards after pressure subsides.
    let mut base_fetch_size = fetch_size;
    // NOTE(review): this sampler (and the one inside the loop) runs inside the open
    // transaction; if its SQL errors on the server, the transaction is aborted and
    // the next FETCH fails.
    let mut adaptive_last_ckpt: Option<i64> = if tuning.adaptive {
        pg_sample_checkpoints_req(guard.client_mut())
    } else {
        None
    };
    let mut batch_count: usize = 0;

    loop {
        // Remember what was asked for this round; fetch_size may change below.
        let requested_this_iter = fetch_size;
        let rows = guard.client_mut().query(&fetch_sql, &[])?;
        if rows.is_empty() {
            break;
        }

        // The first non-empty page fixes the Arrow schema and the column list reused
        // for every later batch.
        if schema.is_none() {
            let stmt_cols: Vec<(String, Type)> = rows[0]
                .columns()
                .iter()
                .map(|c| (c.name().to_string(), c.type_().clone()))
                .collect();
            let s = Arc::new(pg_columns_to_schema(
                rows[0].columns(),
                column_overrides,
                numeric_hints,
            )?);
            sink.on_schema(s.clone())?;
            schema = Some(s.clone());
            columns_cache = Some(stmt_cols);

            // Schema-based sizing applies only when work_mem is unknown. It can only
            // grow the page beyond the probe size, never shrink it.
            let effective = tuning.effective_batch_size(Some(&s));
            if effective != fetch_size && work_mem_bytes.is_none() {
                fetch_size = effective.max(fetch_size);
                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
            }
            base_fetch_size = fetch_size;
        }

        let row_count = rows.len();
        total_rows += row_count;

        let s = schema.as_ref().expect("schema set on first iteration");
        let cols = columns_cache
            .as_ref()
            .expect("columns set on first iteration");
        let batch = rows_to_record_batch_typed(s, cols, &rows)?;
        // Release the driver rows before the (possibly large) batch is handed on.
        drop(rows);

        // One-shot work_mem cap, computed from the first batch:
        // - Server-side per-row size is estimated as 1.2x the Arrow per-row size,
        //   with a floor of 64 B.
        // - The page targets 70% of work_mem, is at least 100 rows, and is at most
        //   the configured batch size.
        // - It is optionally further capped by batch_size_memory_mb.
        if !work_mem_cap_applied
            && let Some(wm) = work_mem_bytes
            && row_count > 0
        {
            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
            let arrow_per_row = (arrow_bytes / row_count).max(1);
            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
            let mut target = safe.min(configured_batch_size);
            if let Some(mem_mb) = tuning.batch_size_memory_mb {
                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
                target = target.min(arrow_target.max(100));
            }
            if target != fetch_size {
                log::info!(
                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N {} → {} (configured={})",
                    wm,
                    arrow_per_row,
                    pg_per_row,
                    fetch_size,
                    target,
                    configured_batch_size,
                );
                fetch_size = target;
                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
                base_fetch_size = fetch_size;
            }
            work_mem_cap_applied = true;
        }

        sink.on_batch(&batch)?;

        // Every ADAPTIVE_SAMPLE_INTERVAL batches, treat a rise in the requested-
        // checkpoint counter as write pressure and let next_adaptive_batch_size
        // move fetch_size relative to base_fetch_size.
        batch_count += 1;
        if tuning.adaptive
            && batch_count.is_multiple_of(ADAPTIVE_SAMPLE_INTERVAL)
            && let Some(cur) = pg_sample_checkpoints_req(guard.client_mut())
        {
            let under_pressure = adaptive_last_ckpt.is_some_and(|prev| cur > prev);
            adaptive_last_ckpt = Some(cur);
            let next = next_adaptive_batch_size(fetch_size, base_fetch_size, under_pressure);
            if next != fetch_size {
                fetch_size = next;
                fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
                log::info!(
                    "adaptive batch size → {} ({})",
                    fetch_size,
                    if under_pressure {
                        "pressure"
                    } else {
                        "recovery"
                    }
                );
            }
        }

        log::info!("fetched {} rows so far...", total_rows);

        // A short page means the cursor is exhausted; skip the extra empty FETCH.
        if row_count < requested_this_iter {
            break;
        }

        // Optional pause between pages to reduce load on the source.
        if tuning.throttle_ms > 0 {
            std::thread::sleep(std::time::Duration::from_millis(tuning.throttle_ms));
        }
    }

    guard.client_mut().batch_execute("CLOSE _rivet")?;
    guard.commit()?;
    Ok((total_rows, schema.is_some()))
}
524
525impl super::Source for PostgresSource {
526 fn export(
527 &mut self,
528 request: &super::ExportRequest<'_>,
529 sink: &mut dyn super::BatchSink,
530 ) -> Result<()> {
531 let built = build_incremental_query(
532 request.query,
533 request.incremental,
534 request.cursor,
535 SourceType::Postgres,
536 );
537 debug_assert!(
538 built.cursor_param.is_none(),
539 "Postgres path inlines cursor values as E'…' literals — binding is unused"
540 );
541 log::debug!(
542 "executing query (connection={}): {}",
543 if self.transaction_pooler {
544 "transaction-pooler"
545 } else {
546 "direct"
547 },
548 built.sql
549 );
550
551 let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, request.query);
552
553 let (total_rows, had_schema) = pg_run_export(
556 &mut self.client,
557 &built.sql,
558 request.tuning,
559 request.column_overrides,
560 sink,
561 numeric_hints.as_ref(),
562 )?;
563
564 if !had_schema {
565 sink.on_schema(Arc::new(Schema::empty()))?;
566 }
567
568 log::info!("total: {} rows", total_rows);
569 Ok(())
570 }
571
572 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
573 let rows = self.client.query(sql, &[])?;
574 if rows.is_empty() {
575 return Ok(None);
576 }
577 let row = &rows[0];
578 if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
579 return Ok(Some(v.to_string()));
580 }
581 if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
582 return Ok(Some(v.to_string()));
583 }
584 if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
585 return Ok(Some(v.to_string()));
586 }
587 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
589 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
590 }
591 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
592 return Ok(Some(v.format("%Y-%m-%d").to_string()));
593 }
594 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
595 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
596 }
597 if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
598 return Ok(Some(v));
599 }
600 Ok(None)
601 }
602
603 fn type_mappings(
604 &mut self,
605 query: &str,
606 column_overrides: &ColumnOverrides,
607 ) -> Result<Vec<TypeMapping>> {
608 let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
609 let stmt = self.client.prepare(&wrapped)?;
610 let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
611 let mappings = stmt
612 .columns()
613 .iter()
614 .map(|col| {
615 let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
616 let source = SourceColumn::simple(col.name(), col.type_().name(), true);
617 TypeMapping::from_source(&source, rivet)
618 })
619 .collect();
620 Ok(mappings)
621 }
622}
623
624fn pg_numeric_catalog_hints_opt(
630 client: &mut Client,
631 query: &str,
632) -> Option<HashMap<String, (u8, i8)>> {
633 match pg_fetch_numeric_catalog_hints(client, query) {
634 Ok(m) => m,
635 Err(e) => {
636 log::warn!(
642 "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
643 );
644 None
645 }
646 }
647}
648
649fn pg_fetch_numeric_catalog_hints(
650 client: &mut Client,
651 query: &str,
652) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
653 let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
654 return Ok(None);
655 };
656 let locate_sql = "SELECT n.nspname::text, c.relname::text \
657 FROM pg_catalog.pg_class c \
658 JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
659 WHERE c.oid = ($1::text)::regclass";
660 let row_opt = match client.query_opt(locate_sql, &[®class_lit]) {
661 Ok(r) => r,
662 Err(e) => {
663 log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
664 return Ok(None);
665 }
666 };
667 let Some(row) = row_opt else {
668 return Ok(None);
669 };
670 let schema: String = row.get(0);
671 let table: String = row.get(1);
672 let rows = client.query(
673 "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
674 FROM information_schema.columns \
675 WHERE table_schema = $1 AND table_name = $2 \
676 ORDER BY ordinal_position",
677 &[&schema, &table],
678 )?;
679
680 let mut map = HashMap::new();
681 for row in rows {
682 let col: String = row.get(0);
683 let dt: String = row.get(1);
684 if !is_pg_numeric_information_type(&dt) {
685 continue;
686 }
687 let p: Option<i32> = row.get(2);
688 let s: Option<i32> = row.get(3);
689 if let (Some(p), Some(s)) = (p, s)
690 && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
691 {
692 map.insert(col, pair);
693 }
694 }
695
696 if map.is_empty() {
697 Ok(None)
698 } else {
699 log::debug!(
700 "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
701 map.len(),
702 );
703 Ok(Some(map))
704 }
705}
706
/// Whether an `information_schema.columns.data_type` value denotes NUMERIC/DECIMAL.
///
/// The check is case- and surrounding-whitespace-insensitive. It accepts the bare
/// type names and any parenthesised modifier form such as `numeric(10,2)`.
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let lowered = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"].iter().any(|base| {
        lowered
            .strip_prefix(base)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('('))
    })
}
713
/// Convert catalog NUMERIC `(precision, scale)` into Arrow-style decimal parameters.
///
/// Accepts:
/// - precision in `1..=76`;
/// - scale representable as `i8`, which may be negative;
/// - scale no larger than precision.
///
/// Anything else yields `None`.
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    if !(1..=76).contains(&precision) {
        return None;
    }
    let scale = i8::try_from(scale).ok()?;
    // Cannot fail: the range check above keeps precision within u8.
    let precision = u8::try_from(precision).ok()?;
    (i32::from(scale) <= i32::from(precision)).then_some((precision, scale))
}
729
#[cfg(test)]
mod tests {
    use super::{catalog_numeric_to_decimal_params, parse_work_mem};

    /// Precision/scale combinations at and beyond the accepted bounds.
    #[test]
    fn catalog_decimal_bounds() {
        let cases: &[(i32, i32, Option<(u8, i8)>)] = &[
            (18, 2, Some((18, 2))),
            (0, 2, None),
            (77, 0, None),
            (18, 19, None),
        ];
        for &(precision, scale, expected) in cases {
            assert_eq!(
                catalog_numeric_to_decimal_params(precision, scale),
                expected,
                "precision={precision}, scale={scale}"
            );
        }
    }

    /// `SHOW work_mem`-style strings in each supported unit, plus rejects.
    #[test]
    fn parse_work_mem_handles_pg_units() {
        let cases: &[(&str, Option<i64>)] = &[
            ("4MB", Some(4 * 1024 * 1024)),
            ("16384kB", Some(16384 * 1024)),
            ("1GB", Some(1024 * 1024 * 1024)),
            (" 4MB ", Some(4 * 1024 * 1024)),
            ("4mb", Some(4 * 1024 * 1024)),
            ("65536", Some(65536 * 1024)),
            ("", None),
            ("garbage", None),
            ("4s", None),
        ];
        for &(input, expected) in cases {
            assert_eq!(parse_work_mem(input), expected, "input={input:?}");
        }
    }
}