1mod arrow_convert;
23mod from_parse;
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use postgres::types::Type;
30use postgres::{Client, NoTls};
31
32use crate::config::{SourceType, TlsConfig};
33use crate::error::Result;
34use crate::source::batch_controller::AdaptiveBatchController;
35use crate::source::query::build_export_query;
36use crate::source::tls::build_native_tls;
37use crate::tuning::SourceTuning;
38use crate::types::{ColumnOverrides, SourceColumn, TypeMapping};
39
40use arrow_convert::{pg_columns_to_schema, rivet_type_for_pg_column, rows_to_record_batch_typed};
41use from_parse::try_parse_pg_simple_from_regclass_literal;
42
43pub struct PostgresSource {
44 client: Client,
45 transaction_pooler: bool,
48}
49
50fn detect_pg_transaction_pooler(client: &mut Client) -> bool {
58 let pid1: Option<i32> = client
59 .query_one("SELECT pg_backend_pid()", &[])
60 .ok()
61 .and_then(|r| r.try_get(0).ok());
62 let pid2: Option<i32> = client
63 .query_one("SELECT pg_backend_pid()", &[])
64 .ok()
65 .and_then(|r| r.try_get(0).ok());
66 matches!((pid1, pid2), (Some(a), Some(b)) if a != b)
67}
68
69impl PostgresSource {
70 pub fn connect(url: &str) -> Result<Self> {
73 let mut client = Client::connect(url, NoTls)?;
74 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
75 if transaction_pooler {
76 log::warn!(
77 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
78 SET LOCAL tuning is transaction-scoped; \
79 LISTEN/NOTIFY and advisory locks are unavailable"
80 );
81 }
82 Ok(Self {
83 client,
84 transaction_pooler,
85 })
86 }
87
88 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
91 match tls {
92 Some(cfg) if cfg.mode.is_enforced() => {
93 let connector = build_native_tls(cfg)?;
94 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
95 let mut client = Client::connect(url, make_tls)?;
96 let transaction_pooler = detect_pg_transaction_pooler(&mut client);
97 if transaction_pooler {
98 log::warn!(
99 "transaction-mode connection pooler detected (pgBouncer/Odyssey) — \
100 SET LOCAL tuning is transaction-scoped; \
101 LISTEN/NOTIFY and advisory locks are unavailable"
102 );
103 }
104 Ok(Self {
105 client,
106 transaction_pooler,
107 })
108 }
109 _ => Self::connect(url),
110 }
111 }
112}
113
114struct PgTxnGuard<'a> {
122 client: &'a mut Client,
123 committed: bool,
124}
125
126impl<'a> PgTxnGuard<'a> {
127 fn begin(client: &'a mut Client) -> Result<Self> {
128 client.batch_execute("BEGIN")?;
129 Ok(Self {
130 client,
131 committed: false,
132 })
133 }
134
135 fn client_mut(&mut self) -> &mut Client {
136 self.client
137 }
138
139 fn commit(mut self) -> Result<()> {
140 self.client.batch_execute("COMMIT")?;
141 self.committed = true;
142 Ok(())
143 }
144}
145
146impl Drop for PgTxnGuard<'_> {
147 fn drop(&mut self) {
148 if !self.committed
149 && let Err(e) = self.client.batch_execute("ROLLBACK")
150 {
151 log::warn!("PgTxnGuard: ROLLBACK during drop failed: {e:#}");
154 }
155 }
156}
157
158pub(crate) fn sample_temp_bytes(url: &str, tls: Option<&TlsConfig>) -> Option<i64> {
170 let mut client = connect_client(url, tls).ok()?;
171 client
172 .query_one(
173 "SELECT temp_bytes::bigint FROM pg_stat_database WHERE datname = current_database()",
174 &[],
175 )
176 .ok()
177 .and_then(|r| r.try_get::<_, i64>(0).ok())
178}
179
180fn pg_fetch_work_mem_bytes(client: &mut Client) -> Option<i64> {
191 let raw: Option<String> = client
192 .query_one("SHOW work_mem", &[])
193 .ok()
194 .and_then(|r| r.try_get::<_, String>(0).ok());
195 raw.as_deref().and_then(parse_work_mem)
196}
197
/// Parses a memory setting as printed by `SHOW` (e.g. `"4MB"`, `"16384kB"`,
/// `"65536"`) into a byte count.
///
/// The leading run of digits, `.` and `-` is the number. The remainder,
/// trimmed and compared case-insensitively, must be empty or one of `kB`,
/// `MB`, `GB`, `TB`; a bare number is treated as kB. Returns `None` for empty
/// or unparseable input, an unknown unit, or a non-positive result.
fn parse_work_mem(raw: &str) -> Option<i64> {
    const KIB: f64 = 1024.0;
    let text = raw.trim();
    // Byte offset of the first character that cannot belong to the number.
    let num_end = text
        .find(|c: char| !(c.is_ascii_digit() || c == '.' || c == '-'))
        .unwrap_or(text.len());
    if num_end == 0 {
        return None;
    }
    let (digits, suffix) = text.split_at(num_end);
    let value = digits.parse::<f64>().ok()?;
    let scale = match suffix.trim().to_ascii_lowercase().as_str() {
        "" | "kb" => KIB,
        "mb" => KIB * KIB,
        "gb" => KIB * KIB * KIB,
        "tb" => KIB * KIB * KIB * KIB,
        _ => return None,
    };
    let bytes = (value * scale) as i64;
    if bytes > 0 {
        Some(bytes)
    } else {
        None
    }
}
230
231fn pg_sample_checkpoints_req(client: &mut Client) -> Option<i64> {
237 let _ = client.execute("SELECT pg_stat_clear_snapshot()", &[]);
238 client
239 .query_one("SELECT checkpoints_req FROM pg_stat_bgwriter", &[])
240 .ok()
241 .and_then(|r| r.try_get::<_, i64>(0).ok())
242}
243
244pub(crate) fn introspect_pg_table_for_chunking(
256 url: &str,
257 tls: Option<&TlsConfig>,
258 qualified_table: &str,
259) -> Result<crate::source::TableIntrospection> {
260 let (schema, table) = match qualified_table.split_once('.') {
261 Some((s, t)) => (s.to_string(), t.to_string()),
262 None => ("public".to_string(), qualified_table.to_string()),
263 };
264 let mut client = connect_client(url, tls)?;
265
266 let (row_estimate, rel_size_bytes) = match client.query_opt(
268 "SELECT c.reltuples::bigint, pg_relation_size(c.oid)::bigint \
269 FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace \
270 WHERE n.nspname = $1::text AND c.relname = $2::text",
271 &[&schema, &table],
272 )? {
273 Some(row) => {
274 let rt: i64 = row.try_get(0).unwrap_or(0);
275 let sz: i64 = row.try_get(1).unwrap_or(0);
276 (rt.max(0), sz.max(0))
277 }
278 None => (0, 0),
279 };
280 let avg_row_bytes = if row_estimate > 0 {
281 Some(rel_size_bytes / row_estimate)
282 } else {
283 None
284 };
285
286 let pk_rows = client.query(
288 "SELECT a.attname::text, t.typname::text \
289 FROM pg_index i \
290 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) \
291 JOIN pg_type t ON t.oid = a.atttypid \
292 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
293 AND i.indisprimary",
294 &[&schema, &table],
295 )?;
296 let single_int_pk = if pk_rows.len() == 1 {
297 let col: String = pk_rows[0].get(0);
298 let pg_type: String = pk_rows[0].get(1);
299 if matches!(pg_type.as_str(), "int2" | "int4" | "int8") {
303 Some(col)
304 } else {
305 log::debug!(
306 "introspect_pg_table: PK '{col}' on {schema}.{table} has non-int type '{pg_type}' — skipping auto-resolve"
307 );
308 None
309 }
310 } else {
311 None
312 };
313
314 let keyset_rows = client.query(
320 "SELECT a.attname::text, i.indisprimary \
321 FROM pg_index i \
322 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = i.indkey[0] \
323 WHERE i.indrelid = (($1::text || '.' || $2::text)::regclass) \
324 AND i.indisunique AND i.indnkeyatts = 1 AND a.attnotnull",
325 &[&schema, &table],
326 )?;
327 let mut keyset_keys: Vec<String> = Vec::new();
328 for primary in [true, false] {
329 for row in &keyset_rows {
330 let col: String = row.get(0);
331 let is_primary: bool = row.get(1);
332 if is_primary == primary && !keyset_keys.contains(&col) {
333 keyset_keys.push(col);
334 }
335 }
336 }
337
338 Ok(crate::source::TableIntrospection {
339 single_int_pk,
340 keyset_keys,
341 row_estimate,
342 avg_row_bytes,
343 })
344}
345
346pub(crate) fn connect_client(url: &str, tls: Option<&TlsConfig>) -> Result<Client> {
353 match tls {
354 Some(cfg) if cfg.mode.is_enforced() => {
355 let connector = build_native_tls(cfg)?;
356 let make_tls = postgres_native_tls::MakeTlsConnector::new(connector);
357 Ok(Client::connect(url, make_tls)?)
358 }
359 _ => Ok(Client::connect(url, NoTls)?),
360 }
361}
362
/// Streams the result of `built_sql` into `sink` through a server-side cursor.
///
/// Everything runs inside one explicit transaction held by a `PgTxnGuard`,
/// which rolls back on any early return:
/// 1. optional `SET LOCAL statement_timeout` / `lock_timeout` from `tuning`;
/// 2. `DECLARE _rivet NO SCROLL CURSOR FOR <built_sql>`;
/// 3. repeated `FETCH n FROM _rivet`, where `n` is the controller's current
///    target and may change between batches;
/// 4. `CLOSE _rivet` and `COMMIT`.
///
/// The Arrow schema is derived from the first non-empty batch and handed to
/// `sink.on_schema` once, before the first `sink.on_batch`.
///
/// FETCH size is capped once. If `work_mem` could be read, the cap is
/// computed after the first batch from its measured Arrow bytes/row.
/// Otherwise it comes from `tuning.effective_batch_size` as soon as the
/// schema is known.
///
/// Returns `(total_rows, had_schema)`. `had_schema` is `false` when the
/// first FETCH returned no rows; the caller handles that case.
///
/// NOTE(review): the pressure sampler runs inside this transaction, so any
/// error it raises server-side aborts the transaction.
fn pg_run_export(
    client: &mut Client,
    built_sql: &str,
    tuning: &SourceTuning,
    column_overrides: &ColumnOverrides,
    sink: &mut dyn super::BatchSink,
    numeric_hints: Option<&HashMap<String, (u8, i8)>>,
) -> Result<(usize, bool)> {
    let mut guard = PgTxnGuard::begin(client)?;
    // SET LOCAL lasts only until COMMIT/ROLLBACK, so these timeouts apply to
    // this export's transaction only.
    if tuning.statement_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL statement_timeout = '{}s'",
            tuning.statement_timeout_s
        ))?;
    }
    if tuning.lock_timeout_s > 0 {
        guard.client_mut().batch_execute(&format!(
            "SET LOCAL lock_timeout = '{}s'",
            tuning.lock_timeout_s
        ))?;
    }
    // Best-effort; `None` selects the schema-based cap strategy below.
    let work_mem_bytes = pg_fetch_work_mem_bytes(guard.client_mut());

    guard
        .client_mut()
        .batch_execute(&format!("DECLARE _rivet NO SCROLL CURSOR FOR {built_sql}"))?;

    let configured_batch_size = tuning.batch_size;
    let mut ctl = AdaptiveBatchController::new(tuning, configured_batch_size);
    // Presumably seeds the controller's pressure baseline; sampled only when
    // adaptive sizing is enabled.
    ctl.seed_pressure(if tuning.adaptive {
        pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64)
    } else {
        None
    });
    let mut schema: Option<SchemaRef> = None;
    let mut columns_cache: Option<Vec<(String, Type)>> = None;
    let mut total_rows: usize = 0;
    // Set once the one-time memory cap has been applied.
    let mut cap_applied = false;

    loop {
        let requested = ctl.target();
        let fetch_sql = format!("FETCH {} FROM _rivet", requested);
        let rows = guard.client_mut().query(&fetch_sql, &[])?;
        if rows.is_empty() {
            break;
        }

        // First non-empty batch: derive and publish the Arrow schema and cache
        // column names/types for every later conversion.
        if schema.is_none() {
            let stmt_cols: Vec<(String, Type)> = rows[0]
                .columns()
                .iter()
                .map(|c| (c.name().to_string(), c.type_().clone()))
                .collect();
            let s = Arc::new(pg_columns_to_schema(
                rows[0].columns(),
                column_overrides,
                numeric_hints,
            )?);
            sink.on_schema(s.clone())?;
            // No work_mem reading: cap right away from the schema-based
            // estimate, never below the size just requested.
            if work_mem_bytes.is_none() {
                let effective = tuning.effective_batch_size(Some(&s));
                ctl.apply_memory_cap(effective.max(requested));
                cap_applied = true;
            }
            schema = Some(s);
            columns_cache = Some(stmt_cols);
        }

        let row_count = rows.len();
        total_rows += row_count;

        let s = schema.as_ref().expect("schema set on first iteration");
        let cols = columns_cache
            .as_ref()
            .expect("columns set on first iteration");
        let batch = rows_to_record_batch_typed(s, cols, &rows)?;
        // Release the raw rows before the sink sees the batch (presumably to
        // lower peak memory).
        drop(rows);

        // With a work_mem reading, cap once using this batch's measured Arrow
        // size. PG bytes/row is estimated as 1.2x Arrow (min 64 B). FETCH is
        // sized to ~70% of work_mem, never below 100 rows, and further bounded
        // by `batch_size_memory_mb` when set.
        if !cap_applied
            && let Some(wm) = work_mem_bytes
            && row_count > 0
        {
            let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
            let arrow_per_row = (arrow_bytes / row_count).max(1);
            let pg_per_row = ((arrow_per_row * 12) / 10).max(64);
            let safe = (((wm as f64) * 0.7) as usize / pg_per_row).max(100);
            let mut target = safe;
            if let Some(mem_mb) = tuning.batch_size_memory_mb {
                let arrow_target = (mem_mb * 1024 * 1024) / arrow_per_row;
                target = target.min(arrow_target.max(100));
            }
            if let Some(new) = ctl.apply_memory_cap(target) {
                log::info!(
                    "PG work_mem={} B, observed row={} B (arrow), pg≈{} B → FETCH N → {} (configured={})",
                    wm,
                    arrow_per_row,
                    pg_per_row,
                    new,
                    configured_batch_size,
                );
            }
            cap_applied = true;
        }

        sink.on_batch(&batch)?;

        // The sampler is passed as a closure, so the controller decides
        // whether to query pressure before re-sizing the next FETCH.
        if let Some((new, under_pressure)) =
            ctl.after_batch(|| pg_sample_checkpoints_req(guard.client_mut()).map(|v| v as u64))
        {
            log::info!(
                "adaptive batch size → {} ({})",
                new,
                if under_pressure {
                    "pressure"
                } else {
                    "recovery"
                }
            );
        }

        log::info!("fetched {} rows so far...", total_rows);

        // A short batch means the cursor is exhausted.
        if row_count < requested {
            break;
        }
        ctl.throttle();
    }

    guard.client_mut().batch_execute("CLOSE _rivet")?;
    guard.commit()?;
    Ok((total_rows, schema.is_some()))
}
526
527impl super::Source for PostgresSource {
528 fn export(
529 &mut self,
530 request: &super::ExportRequest<'_>,
531 sink: &mut dyn super::BatchSink,
532 ) -> Result<()> {
533 let built = build_export_query(request, SourceType::Postgres);
534 debug_assert!(
535 built.cursor_param.is_none(),
536 "Postgres path inlines cursor values as E'…' literals — binding is unused"
537 );
538 log::debug!(
539 "executing query (connection={}): {}",
540 if self.transaction_pooler {
541 "transaction-pooler"
542 } else {
543 "direct"
544 },
545 built.sql
546 );
547
548 let hint_query = request.catalog_hint_query.unwrap_or(request.query);
552 let numeric_hints = pg_numeric_catalog_hints_opt(&mut self.client, hint_query);
553
554 let (total_rows, had_schema) = pg_run_export(
557 &mut self.client,
558 &built.sql,
559 request.tuning,
560 request.column_overrides,
561 sink,
562 numeric_hints.as_ref(),
563 )?;
564
565 if !had_schema {
566 sink.on_schema(Arc::new(Schema::empty()))?;
567 }
568
569 log::info!("total: {} rows", total_rows);
570 Ok(())
571 }
572
573 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
574 let rows = self.client.query(sql, &[])?;
575 if rows.is_empty() {
576 return Ok(None);
577 }
578 let row = &rows[0];
579 if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
580 return Ok(Some(v.to_string()));
581 }
582 if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
583 return Ok(Some(v.to_string()));
584 }
585 if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
586 return Ok(Some(v.to_string()));
587 }
588 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDateTime>>(0) {
590 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
591 }
592 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::NaiveDate>>(0) {
593 return Ok(Some(v.format("%Y-%m-%d").to_string()));
594 }
595 if let Ok(Some(v)) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(0) {
596 return Ok(Some(v.format("%Y-%m-%d %H:%M:%S").to_string()));
597 }
598 if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
599 return Ok(Some(v));
600 }
601 Ok(None)
602 }
603
604 fn type_mappings(
605 &mut self,
606 query: &str,
607 column_overrides: &ColumnOverrides,
608 ) -> Result<Vec<TypeMapping>> {
609 let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
610 let stmt = self.client.prepare(&wrapped)?;
611 let hints = pg_numeric_catalog_hints_opt(&mut self.client, query);
612 let mappings = stmt
613 .columns()
614 .iter()
615 .map(|col| {
616 let rivet = rivet_type_for_pg_column(col, column_overrides, hints.as_ref());
617 let source = SourceColumn::simple(col.name(), col.type_().name(), true);
618 TypeMapping::from_source(&source, rivet)
619 })
620 .collect();
621 Ok(mappings)
622 }
623
624 fn sample_pressure(&mut self) -> Option<u64> {
628 pg_sample_checkpoints_req(&mut self.client).map(|v| v.max(0) as u64)
629 }
630}
631
632fn pg_numeric_catalog_hints_opt(
638 client: &mut Client,
639 query: &str,
640) -> Option<HashMap<String, (u8, i8)>> {
641 match pg_fetch_numeric_catalog_hints(client, query) {
642 Ok(m) => m,
643 Err(e) => {
644 log::warn!(
650 "PG numeric catalog lookup failed — NUMERIC columns will require explicit `columns:` overrides: {e}"
651 );
652 None
653 }
654 }
655}
656
657fn pg_fetch_numeric_catalog_hints(
658 client: &mut Client,
659 query: &str,
660) -> crate::error::Result<Option<HashMap<String, (u8, i8)>>> {
661 let Some(regclass_lit) = try_parse_pg_simple_from_regclass_literal(query) else {
662 return Ok(None);
663 };
664 let locate_sql = "SELECT n.nspname::text, c.relname::text \
665 FROM pg_catalog.pg_class c \
666 JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
667 WHERE c.oid = ($1::text)::regclass";
668 let row_opt = match client.query_opt(locate_sql, &[®class_lit]) {
669 Ok(r) => r,
670 Err(e) => {
671 log::warn!("PG numeric catalog: '{regclass_lit}' regclass lookup failed: {e}");
672 return Ok(None);
673 }
674 };
675 let Some(row) = row_opt else {
676 return Ok(None);
677 };
678 let schema: String = row.get(0);
679 let table: String = row.get(1);
680 let rows = client.query(
681 "SELECT column_name::text, data_type::text, numeric_precision, numeric_scale \
682 FROM information_schema.columns \
683 WHERE table_schema = $1 AND table_name = $2 \
684 ORDER BY ordinal_position",
685 &[&schema, &table],
686 )?;
687
688 let mut map = HashMap::new();
689 for row in rows {
690 let col: String = row.get(0);
691 let dt: String = row.get(1);
692 if !is_pg_numeric_information_type(&dt) {
693 continue;
694 }
695 let p: Option<i32> = row.get(2);
696 let s: Option<i32> = row.get(3);
697 if let (Some(p), Some(s)) = (p, s)
698 && let Some(pair) = catalog_numeric_to_decimal_params(p, s)
699 {
700 map.insert(col, pair);
701 }
702 }
703
704 if map.is_empty() {
705 Ok(None)
706 } else {
707 log::debug!(
708 "PG numeric catalog: resolved {} DECIMAL/NUMERIC column(s) for relation {regclass_lit}",
709 map.len(),
710 );
711 Ok(Some(map))
712 }
713}
714
/// Reports whether an `information_schema.columns.data_type` value denotes
/// a NUMERIC/DECIMAL column.
///
/// The value is trimmed and compared case-insensitively. It matches when it
/// is exactly `numeric`/`decimal`, or that word immediately followed by `(`.
fn is_pg_numeric_information_type(dt: &str) -> bool {
    let normalized = dt.trim().to_ascii_lowercase();
    ["numeric", "decimal"]
        .iter()
        .any(|&base| match normalized.strip_prefix(base) {
            Some(rest) => rest.is_empty() || rest.starts_with('('),
            None => false,
        })
}
721
/// Converts a catalog `(numeric_precision, numeric_scale)` pair into decimal
/// parameters `(precision, scale)`.
///
/// Returns `None` unless all of the following hold:
/// - `1 <= precision <= 76` (76 presumably being the widest supported Arrow
///   decimal);
/// - `scale` fits in an `i8`;
/// - `scale <= precision`.
///
/// Negative scales are accepted.
fn catalog_numeric_to_decimal_params(precision: i32, scale: i32) -> Option<(u8, i8)> {
    if !(1..=76).contains(&precision) {
        return None;
    }
    let scale = i8::try_from(scale).ok()?;
    let precision_u = u8::try_from(precision).ok()?;
    if i32::from(scale) > precision {
        return None;
    }
    Some((precision_u, scale))
}
737
#[cfg(test)]
mod tests {
    use super::{catalog_numeric_to_decimal_params, is_pg_numeric_information_type, parse_work_mem};

    /// Precision must be 1..=76, scale must fit `i8` and not exceed precision.
    #[test]
    fn catalog_decimal_bounds() {
        assert_eq!(catalog_numeric_to_decimal_params(18, 2), Some((18, 2)));
        assert!(catalog_numeric_to_decimal_params(0, 2).is_none());
        assert!(catalog_numeric_to_decimal_params(77, 0).is_none());
        assert!(catalog_numeric_to_decimal_params(18, 19).is_none());
    }

    /// Edges of the accepted range, plus negative and out-of-`i8` scales.
    #[test]
    fn catalog_decimal_edges() {
        assert_eq!(catalog_numeric_to_decimal_params(76, 0), Some((76, 0)));
        assert_eq!(catalog_numeric_to_decimal_params(1, 1), Some((1, 1)));
        assert_eq!(catalog_numeric_to_decimal_params(10, -3), Some((10, -3)));
        assert!(catalog_numeric_to_decimal_params(-4, 0).is_none());
        assert!(catalog_numeric_to_decimal_params(18, 200).is_none());
        assert!(catalog_numeric_to_decimal_params(18, -200).is_none());
    }

    #[test]
    fn parse_work_mem_handles_pg_units() {
        assert_eq!(parse_work_mem("4MB"), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("16384kB"), Some(16384 * 1024));
        assert_eq!(parse_work_mem("1GB"), Some(1024 * 1024 * 1024));
        assert_eq!(parse_work_mem("1TB"), Some(1024_i64.pow(4)));
        assert_eq!(parse_work_mem(" 4MB "), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("4mb"), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("4 MB"), Some(4 * 1024 * 1024));
        assert_eq!(parse_work_mem("65536"), Some(65536 * 1024));
        assert_eq!(parse_work_mem(""), None);
        assert_eq!(parse_work_mem("garbage"), None);
        assert_eq!(parse_work_mem("4s"), None);
    }

    /// Zero and negative sizes are rejected rather than used as a cap.
    #[test]
    fn parse_work_mem_rejects_non_positive() {
        assert_eq!(parse_work_mem("0kB"), None);
        assert_eq!(parse_work_mem("-1"), None);
        assert_eq!(parse_work_mem("."), None);
    }

    #[test]
    fn numeric_information_type_detection() {
        assert!(is_pg_numeric_information_type("numeric"));
        assert!(is_pg_numeric_information_type("DECIMAL"));
        assert!(is_pg_numeric_information_type("  Numeric(12,4) "));
        assert!(!is_pg_numeric_information_type("numerical"));
        assert!(!is_pg_numeric_information_type("integer"));
        assert!(!is_pg_numeric_information_type(""));
    }
}