use super::ExportDiagnostic;
use super::analysis::*;
use super::cursor_expr::incremental_key_expr;
use super::schema_error::PreflightSchemaError;
use crate::config::{ExportConfig, ExportMode, SourceType, TlsConfig};
use crate::error::Result;
/// Runs the preflight diagnosis for every export against one Postgres connection.
///
/// A single client is opened for `url` (with `tls` when provided) and the
/// server's `max_connections` is probed once, then shared by all exports.
/// Each export is diagnosed in slice order; the diagnostic is printed unless
/// `silent` is set. Diagnosis still runs in silent mode so that errors surface.
///
/// # Errors
/// Returns the connection error, or the first error produced by `diagnose_pg`;
/// exports after a failing one are not examined.
pub(super) fn check_postgres(
    url: &str,
    tls: Option<&TlsConfig>,
    exports: &[&ExportConfig],
    silent: bool,
) -> Result<()> {
    let mut conn = crate::source::postgres::connect_client(url, tls)?;
    let max_conns = fetch_max_connections_pg(&mut conn);
    for &export in exports {
        let report = diagnose_pg(&mut conn, export, max_conns)?;
        if silent {
            continue;
        }
        super::print_diagnostic(&report);
    }
    Ok(())
}
/// Diagnoses a single export on a freshly opened Postgres connection.
///
/// Opens a client for `url` (with `tls` when provided), probes
/// `max_connections`, and returns the diagnostic built by `diagnose_pg`
/// without printing anything.
///
/// # Errors
/// Returns the connection error or whatever `diagnose_pg` reports.
pub(super) fn diagnose_export_pg(
    url: &str,
    tls: Option<&TlsConfig>,
    export: &ExportConfig,
) -> Result<super::ExportDiagnostic> {
    let mut conn = crate::source::postgres::connect_client(url, tls)?;
    let max_conns = fetch_max_connections_pg(&mut conn);
    diagnose_pg(&mut conn, export, max_conns)
}
/// Best-effort read of the server's `max_connections` setting.
///
/// Returns `None` when the query fails (the error is logged at debug level),
/// when no row comes back, or when the value does not fit in a `u32`
/// (i.e. it is negative).
fn fetch_max_connections_pg(client: &mut postgres::Client) -> Option<u32> {
    const PROBE: &str = "SELECT current_setting('max_connections')::int";
    let rows = client
        .query(PROBE, &[])
        .inspect_err(|e| log::debug!("preflight: max_connections probe failed: {e}"))
        .ok()?;
    let raw: i32 = rows.first()?.get(0);
    u32::try_from(raw).ok()
}
/// Builds the preflight [`ExportDiagnostic`] for one export over an open connection.
///
/// The database is touched in this order:
/// 1. `EXPLAIN` of the effective query (the base query, wrapped in an
///    `ORDER BY` on the incremental key when one exists) for the row-count
///    and row-width estimates.
/// 2. A `min`/`max` probe: on the incremental key expression in incremental
///    mode, otherwise on the chunk column (falling back to the cursor column).
/// 3. A second `EXPLAIN` to extract the scan type and plan-level index usage.
/// 4. For chunked/incremental exports whose table is known (configured, or
///    recoverable from a simple query), a catalog lookup for a btree index
///    led by the range column.
///
/// `uses_index` is true when either the catalog lookup confirms a btree index
/// or the plan itself shows an index scan.
///
/// # Errors
/// Only the row-estimate step is propagated with `?`; range-probe failures
/// are logged at debug level and reported as `None`.
fn diagnose_pg(
    client: &mut postgres::Client,
    export: &ExportConfig,
    db_max_connections: Option<u32>,
) -> Result<ExportDiagnostic> {
    let mode_str = diagnose_mode_str(export);
    let resolved_query = resolve_preflight_base_query(export);
    let base_query = resolved_query.as_str();
    // Chunk column takes precedence over the cursor column for range/index checks.
    let range_col = export
        .chunk_column
        .as_deref()
        .or(export.cursor_column.as_deref());

    let effective_query = match incremental_key_expr(export, SourceType::Postgres) {
        Some(order) => format!(
            "SELECT * FROM ({}) AS _rivet ORDER BY {}",
            base_query, order
        ),
        None => base_query.to_string(),
    };
    let (row_estimate, avg_row_bytes) = estimate_rows_pg(client, &effective_query)?;

    let (range_min, range_max): (Option<String>, Option<String>) =
        if export.mode == ExportMode::Incremental {
            match incremental_key_expr(export, SourceType::Postgres) {
                None => (None, None),
                Some(expr) => {
                    // Aggregate directly over the table when the helper yields one;
                    // otherwise wrap the base query as a subquery.
                    let range_query =
                        match crate::pipeline::chunked::strip_select_star_from(base_query) {
                            Some(tbl) => {
                                format!("SELECT min({expr})::text, max({expr})::text FROM {tbl}")
                            }
                            None => format!(
                                "SELECT min({expr})::text, max({expr})::text FROM ({base_query}) AS _rivet"
                            ),
                        };
                    match client.query(&range_query, &[]) {
                        Ok(rows) => match rows.first() {
                            Some(row) => {
                                let lowest: Option<String> = row.get(0);
                                let highest: Option<String> = row.get(1);
                                (lowest, highest)
                            }
                            None => (None, None),
                        },
                        Err(e) => {
                            log::debug!(
                                "preflight: incremental key range probe failed for export '{}': {e}",
                                export.name
                            );
                            (None, None)
                        }
                    }
                }
            }
        } else {
            match range_col {
                Some(col) => get_cursor_range_pg(client, base_query, col),
                None => (None, None),
            }
        };

    let (scan_type, plan_uses_index) = analyze_plan_pg(client, &effective_query);

    // The catalog can only upgrade the answer to "indexed"; a negative or
    // failed lookup defers to what the plan showed.
    let ranged_mode = matches!(export.mode, ExportMode::Chunked | ExportMode::Incremental);
    let btree_on_range_col = if ranged_mode
        && let Some(col) = range_col
        && let Some(table) = export
            .table
            .as_deref()
            .or_else(|| table_from_simple_query(base_query))
    {
        column_has_btree_pg(client, table, col) == Some(true)
    } else {
        false
    };
    let uses_index = btree_on_range_col || plan_uses_index;

    let strategy = derive_strategy(export);
    let verdict = compute_verdict(row_estimate, uses_index, export.cursor_column.is_some());
    let recommended_profile = recommend_profile(row_estimate, uses_index, export);
    let recommended_parallel = recommend_parallelism(export, row_estimate, uses_index);
    let warnings = collect_warnings(
        export,
        row_estimate,
        avg_row_bytes,
        range_min.as_deref(),
        range_max.as_deref(),
        db_max_connections,
    );
    let suggestion = build_suggestion(&verdict, row_estimate, uses_index, export);

    Ok(ExportDiagnostic {
        export_name: export.name.clone(),
        strategy,
        mode: mode_str,
        cursor_column: export.cursor_column.clone(),
        row_estimate,
        avg_row_bytes,
        cursor_min: range_min,
        cursor_max: range_max,
        scan_type,
        uses_index,
        verdict,
        recommended_profile,
        recommended_parallel,
        warnings,
        suggestion,
    })
}
/// Runs `EXPLAIN` on `query` and extracts the planner's row-count and
/// row-width estimates from the plan text.
///
/// Returns `(rows, width)`; either half is `None` when the plan text does not
/// contain the corresponding `rows=` / `width=` figure.
///
/// # Errors
/// A failure that `schema_fail_pg` classifies as a schema/query problem is
/// returned as an error. Any other `EXPLAIN` failure is logged at debug level
/// and reported as `Ok((None, None))`.
fn estimate_rows_pg(
    client: &mut postgres::Client,
    query: &str,
) -> Result<(Option<i64>, Option<i64>)> {
    let explain = format!("EXPLAIN {}", query);
    let rows = match client.query(&explain, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            return match schema_fail_pg(&e) {
                Some(fail) => Err(fail),
                None => {
                    log::debug!("preflight: EXPLAIN for row-estimate failed: {e}");
                    Ok((None, None))
                }
            };
        }
    };
    // EXPLAIN returns one text column per plan line; stitch them back together.
    let plan_text = rows
        .iter()
        .map(|row| row.get::<_, String>(0))
        .collect::<Vec<_>>()
        .join("\n");
    let estimated_rows = parse_pg_row_estimate(&plan_text);
    let estimated_width = parse_pg_row_width(&plan_text);
    Ok((estimated_rows, estimated_width))
}
/// Converts a Postgres error whose SQLSTATE starts with `42` into a
/// preflight schema error.
///
/// Returns `None` when the error carries no SQLSTATE or the code is in a
/// different class. The detail text is the server's message when available,
/// otherwise the generic `"schema/query error"`.
fn schema_fail_pg(e: &postgres::Error) -> Option<anyhow::Error> {
    let code = e.code()?;
    if !code.code().starts_with("42") {
        return None;
    }
    let detail = match e.as_db_error() {
        Some(db) => db.message(),
        None => "schema/query error",
    };
    let hint = format!("SQLSTATE {}", code.code());
    Some(PreflightSchemaError::new(detail, hint).into_error())
}
/// Extracts the planner's row estimate from `EXPLAIN` text.
///
/// Scans the plan line by line and returns the number following the first
/// `rows=` on the first line where that number parses as an `i64`. Because the
/// top plan node is printed first, its estimate wins over nested nodes.
/// Returns `None` when no line yields a parsable figure.
pub(crate) fn parse_pg_row_estimate(plan: &str) -> Option<i64> {
    plan.lines().find_map(|line| {
        let (_, tail) = line.split_once("rows=")?;
        // Keep only the leading run of ASCII digits after the marker.
        let digits_end = tail
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(tail.len());
        tail[..digits_end].parse::<i64>().ok()
    })
}
/// Extracts the planner's row-width estimate from `EXPLAIN` text.
///
/// Scans the plan line by line and returns the number following the first
/// `width=` on the first line where that number parses as an `i64`. Because
/// the top plan node is printed first, its width wins over nested nodes.
/// Returns `None` when no line yields a parsable figure.
pub(crate) fn parse_pg_row_width(plan: &str) -> Option<i64> {
    plan.lines().find_map(|line| {
        let (_, tail) = line.split_once("width=")?;
        // Keep only the leading run of ASCII digits after the marker.
        let digits_end = tail
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(tail.len());
        tail[..digits_end].parse::<i64>().ok()
    })
}
/// Best-effort probe of the minimum and maximum of `cursor_col`, both cast to
/// text by the server.
///
/// The column name is quoted for Postgres. When `strip_select_star_from`
/// yields a relation for `base_query`, the aggregates run directly against
/// it; otherwise `base_query` is wrapped as a subquery.
///
/// Returns `(None, None)` when the query fails (logged at debug level) or
/// returns no rows.
fn get_cursor_range_pg(
    client: &mut postgres::Client,
    base_query: &str,
    cursor_col: &str,
) -> (Option<String>, Option<String>) {
    let expr = crate::sql::quote_ident(SourceType::Postgres, cursor_col);
    let range_query =
        if let Some(tbl) = crate::pipeline::chunked::strip_select_star_from(base_query) {
            format!("SELECT min({expr})::text, max({expr})::text FROM {tbl}")
        } else {
            format!("SELECT min({expr})::text, max({expr})::text FROM ({base_query}) AS _rivet")
        };

    let rows = match client.query(&range_query, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            log::debug!("preflight: cursor range probe on '{cursor_col}' failed: {e}");
            return (None, None);
        }
    };
    let Some(row) = rows.first() else {
        return (None, None);
    };
    let lowest: Option<String> = row.get(0);
    let highest: Option<String> = row.get(1);
    (lowest, highest)
}
/// Extracts the table name from the top-level `FROM` clause of a simple,
/// single-relation query.
///
/// The scan walks the query while tracking parenthesis depth, so a `FROM`
/// inside a parenthesised subquery is ignored. The first top-level `FROM`
/// keyword (case-insensitive, preceded by whitespace or `)`, followed by
/// whitespace or `(`) decides the result:
///
/// * the following identifier — ASCII letters, digits, `_` and `.` — is
///   returned as a slice of `query`, so schema-qualified names come back
///   intact (`public.orders`);
/// * `None` is returned when no identifier follows (quoted identifiers,
///   subqueries), when the identifier is followed by `,`, or when the next
///   word is a join keyword.
///
/// Known limits of the heuristic: a join placed after a table alias
/// (`FROM users u JOIN …`) is not detected, and string literals or comments
/// are not parsed, so a `from` inside a literal can be picked up.
///
/// Returns `None` when the query has no top-level `FROM`.
pub(crate) fn table_from_simple_query(query: &str) -> Option<&str> {
    // Words that, directly after the table name, mean more relations follow.
    const JOIN_KEYWORDS: [&str; 8] = [
        "join", "inner", "left", "right", "outer", "full", "cross", "natural",
    ];

    // Length of the leading run of `bytes` whose elements satisfy `pred`.
    fn run_len(bytes: &[u8], pred: impl Fn(u8) -> bool) -> usize {
        bytes.iter().take_while(|&&b| pred(b)).count()
    }

    let bytes = query.as_bytes();
    let mut depth = 0u32;
    for (idx, ch) in query.char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            _ if depth == 0 => {
                let head_ok = idx == 0
                    || matches!(bytes[idx - 1], b' ' | b'\t' | b'\n' | b'\r' | b')');
                // Compare raw bytes: slicing the `str` at a fixed offset would
                // panic when byte 4 falls inside a multi-byte character.
                let rest = &bytes[idx..];
                let at_from =
                    head_ok && rest.len() >= 5 && rest[..4].eq_ignore_ascii_case(b"from");
                if !at_from {
                    continue;
                }

                // The four matched bytes are ASCII, so `kw_end` is a char boundary.
                let kw_end = idx + 4;
                let after = query[kw_end..].chars().next();
                if !matches!(after, Some(c) if c.is_whitespace() || c == '(') {
                    continue;
                }

                let id_start = kw_end + run_len(&bytes[kw_end..], |b| b.is_ascii_whitespace());
                let id_end = id_start
                    + run_len(&bytes[id_start..], |b| {
                        b.is_ascii_alphanumeric() || b == b'_' || b == b'.'
                    });
                if id_end == id_start {
                    // Quoted identifier, subquery, or nothing at all after FROM.
                    return None;
                }
                let token = &query[id_start..id_end];

                // Only ASCII bytes were skipped, so `tail_start` is a char boundary.
                let tail_start =
                    id_end + run_len(&bytes[id_end..], |b| b.is_ascii_whitespace());
                let tail = &query[tail_start..];
                if tail.starts_with(',') {
                    return None;
                }
                let word_len = run_len(tail.as_bytes(), |b| b.is_ascii_alphabetic());
                let next_word = &tail[..word_len];
                if JOIN_KEYWORDS
                    .iter()
                    .any(|kw| next_word.eq_ignore_ascii_case(kw))
                {
                    return None;
                }
                return Some(token);
            }
            _ => {}
        }
    }
    None
}
/// Asks the system catalogs whether `column` is the leading key of a valid,
/// ready btree index on `qualified_table`.
///
/// `qualified_table` is split at the first `.` into schema and table; an
/// unqualified name is looked up in `public`. Names are compared to the
/// catalog verbatim (no case folding).
///
/// Returns `Some(true)` / `Some(false)` for a definite answer and `None` when
/// the catalog query itself fails (logged at debug level).
pub(crate) fn column_has_btree_pg(
    client: &mut postgres::Client,
    qualified_table: &str,
    column: &str,
) -> Option<bool> {
    // `indkey[0]` restricts the match to the index's first key column.
    const LEADING_BTREE_SQL: &str = "SELECT 1 \
        FROM pg_index i \
        JOIN pg_class c ON c.oid = i.indrelid \
        JOIN pg_namespace n ON n.oid = c.relnamespace \
        JOIN pg_class ic ON ic.oid = i.indexrelid \
        JOIN pg_am am ON am.oid = ic.relam \
        JOIN pg_attribute a ON a.attrelid = i.indrelid \
        AND a.attnum = i.indkey[0] \
        WHERE n.nspname = $1::text \
        AND c.relname = $2::text \
        AND a.attname = $3::text \
        AND am.amname = 'btree' \
        AND i.indisvalid AND i.indisready \
        LIMIT 1";

    let (schema, table) = qualified_table
        .split_once('.')
        .unwrap_or(("public", qualified_table));

    client
        .query(LEADING_BTREE_SQL, &[&schema, &table, &column])
        .map(|rows| !rows.is_empty())
        .inspect_err(|e| {
            log::debug!("preflight: btree index probe failed for {schema}.{table}.{column}: {e}")
        })
        .ok()
}
/// Runs `EXPLAIN` on `query` and summarises the plan.
///
/// Returns the scan description picked by `extract_scan_type` together with a
/// flag that is true when the plan text mentions an index scan node. When
/// `EXPLAIN` fails, the error is logged at debug level and `(None, false)` is
/// returned.
fn analyze_plan_pg(client: &mut postgres::Client, query: &str) -> (Option<String>, bool) {
    const INDEX_NODE_MARKERS: [&str; 3] = ["Index Scan", "Index Only Scan", "Bitmap Index Scan"];

    let explain = format!("EXPLAIN {}", query);
    let rows = match client.query(&explain, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            log::debug!("preflight: EXPLAIN for plan analysis failed: {e}");
            return (None, false);
        }
    };
    // EXPLAIN returns one text column per plan line; stitch them back together.
    let plan_text = rows
        .iter()
        .map(|row| row.get::<_, String>(0))
        .collect::<Vec<_>>()
        .join("\n");
    let uses_index = INDEX_NODE_MARKERS
        .iter()
        .any(|marker| plan_text.contains(*marker));
    (Some(extract_scan_type(&plan_text)), uses_index)
}
/// Picks a one-line description of how the plan reads its data.
///
/// Returns the first plan line containing `Scan` or `scan`, trimmed and with
/// any leading `-> ` arrows removed. When no line mentions a scan, the first
/// line of the plan is returned trimmed; an empty plan yields `"unknown"`.
pub(crate) fn extract_scan_type(plan: &str) -> String {
    let scan_node = plan
        .lines()
        .map(|raw| raw.trim().trim_start_matches("-> "))
        .find(|node| node.contains("Scan") || node.contains("scan"));
    match scan_node {
        Some(node) => node.to_string(),
        None => plan.lines().next().unwrap_or("unknown").trim().to_string(),
    }
}
// Unit tests for the helpers in this module that need no database: EXPLAIN
// text parsing, scan-type extraction, FROM-clause table extraction, and how
// the table/query config forms resolve to a base query.
#[cfg(test)]
mod tests {
use super::*;
// --- parse_pg_row_estimate ---
// A single-node plan yields its `rows=` figure.
#[test]
fn parse_pg_row_estimate_typical_seq_scan() {
let plan = "Seq Scan on orders (cost=0.00..1250.00 rows=5000 width=120)";
assert_eq!(parse_pg_row_estimate(plan), Some(5000));
}
// With nested nodes, the first line's estimate is the one reported.
#[test]
fn parse_pg_row_estimate_nested_plan_first_rows_wins() {
let plan = "Aggregate (cost=0.00..50.00 rows=1 width=8)\n -> Seq Scan on t (cost=0.00..100.00 rows=10000 width=4)";
assert_eq!(parse_pg_row_estimate(plan), Some(1));
}
// A nine-digit estimate is parsed in full.
#[test]
fn parse_pg_row_estimate_large_row_count() {
let plan =
"Index Scan using idx on orders (cost=0.00..1234567.00 rows=123456789 width=50)";
assert_eq!(parse_pg_row_estimate(plan), Some(123_456_789));
}
// No `rows=` marker anywhere means no estimate.
#[test]
fn parse_pg_row_estimate_no_rows_keyword_returns_none() {
assert!(parse_pg_row_estimate("Seq Scan on t (cost=0.00..100.00 width=8)").is_none());
}
// Empty plan text means no estimate.
#[test]
fn parse_pg_row_estimate_empty_plan_returns_none() {
assert!(parse_pg_row_estimate("").is_none());
}
// --- parse_pg_row_width ---
// A single-node plan yields its `width=` figure.
#[test]
fn parse_pg_row_width_typical_seq_scan() {
let plan = "Seq Scan on orders (cost=0.00..1250.00 rows=5000 width=120)";
assert_eq!(parse_pg_row_width(plan), Some(120));
}
// With nested nodes, the first line's width is the one reported.
#[test]
fn parse_pg_row_width_top_node_wins() {
let plan = "Sort (cost=0.00..50.00 rows=1000 width=64)\n -> Seq Scan on t (cost=0.00..100.00 rows=1000 width=200)";
assert_eq!(parse_pg_row_width(plan), Some(64));
}
// Missing `width=` marker or empty text means no width.
#[test]
fn parse_pg_row_width_absent_returns_none() {
assert!(parse_pg_row_width("Result (cost=0.00..0.01 rows=1)").is_none());
assert!(parse_pg_row_width("").is_none());
}
// --- extract_scan_type ---
// A top-level sequential scan is reported as such.
#[test]
fn extract_scan_type_seq_scan() {
let plan = "Seq Scan on orders (cost=0.00..1000.00 rows=50000 width=8)";
let result = extract_scan_type(plan);
assert!(result.contains("Seq Scan"), "got: {result}");
}
// An indented, arrow-prefixed index scan line is still recognised.
#[test]
fn extract_scan_type_index_scan() {
let plan = " -> Index Scan using orders_pkey on orders (cost=0.43..8.45 rows=1 width=8)";
let result = extract_scan_type(plan);
assert!(result.contains("Index Scan"), "got: {result}");
}
// Bitmap heap scans are recognised as scan lines too.
#[test]
fn extract_scan_type_bitmap_heap_scan() {
let plan = " -> Bitmap Heap Scan on orders (cost=4.35..16.20 rows=4 width=8)";
let result = extract_scan_type(plan);
assert!(result.contains("Bitmap Heap Scan"), "got: {result}");
}
// When no line mentions a scan, the first plan line is returned.
#[test]
fn extract_scan_type_no_scan_returns_first_line() {
let plan = "Aggregate (cost=10.00..10.01 rows=1 width=8)\n -> Sort (...)";
let result = extract_scan_type(plan);
assert!(result.starts_with("Aggregate"), "got: {result}");
}
// An empty plan falls back to the literal "unknown".
#[test]
fn extract_scan_type_empty_plan_returns_unknown() {
assert_eq!(extract_scan_type(""), "unknown");
}
// --- table_from_simple_query ---
// Plain single-table SELECT.
#[test]
fn table_from_simple_query_bare_select() {
assert_eq!(
table_from_simple_query("SELECT id, name FROM users"),
Some("users")
);
}
// The schema qualifier is kept as part of the returned name.
#[test]
fn table_from_simple_query_schema_qualified() {
assert_eq!(
table_from_simple_query("SELECT * FROM public.orders"),
Some("public.orders")
);
}
// FROM at the start of a line (preceded by a newline) is recognised.
#[test]
fn table_from_simple_query_multiline_rivet_init_shape() {
let q =
"\nSELECT id, name, email, age, balance,\n is_active, bio, created_at\nFROM users\n";
assert_eq!(table_from_simple_query(q), Some("users"));
}
// The keyword match ignores case; the table name's case is returned untouched.
#[test]
fn table_from_simple_query_case_insensitive_keyword() {
assert_eq!(
table_from_simple_query("select * from Users"),
Some("Users")
);
assert_eq!(
table_from_simple_query("Select * From users"),
Some("users")
);
}
// Explicit joins and comma-separated relations are rejected.
#[test]
fn table_from_simple_query_rejects_join() {
assert_eq!(
table_from_simple_query("SELECT * FROM users JOIN orders USING (id)"),
None
);
assert_eq!(table_from_simple_query("SELECT * FROM users, orders"), None);
}
// A table alias, with or without AS, does not prevent extraction.
#[test]
fn table_from_simple_query_accepts_aliased_table() {
assert_eq!(
table_from_simple_query("SELECT * FROM users u"),
Some("users")
);
assert_eq!(
table_from_simple_query("SELECT * FROM users AS u"),
Some("users")
);
}
// WHERE / ORDER BY / LIMIT after the table are not mistaken for joins.
#[test]
fn table_from_simple_query_accepts_trailing_clauses() {
assert_eq!(
table_from_simple_query("SELECT * FROM users WHERE id > 0"),
Some("users")
);
assert_eq!(
table_from_simple_query("SELECT * FROM users ORDER BY id"),
Some("users")
);
assert_eq!(
table_from_simple_query("SELECT * FROM users LIMIT 100"),
Some("users")
);
}
// Every join spelling directly after the table name is rejected.
#[test]
fn table_from_simple_query_rejects_all_join_flavors() {
for kw in [
"JOIN",
"INNER JOIN",
"LEFT JOIN",
"LEFT OUTER JOIN",
"RIGHT JOIN",
"FULL OUTER JOIN",
"CROSS JOIN",
"NATURAL JOIN",
] {
let q = format!("SELECT * FROM users {kw} orders ON users.id = orders.user_id");
assert_eq!(
table_from_simple_query(&q),
None,
"{kw}: should reject multi-relation"
);
}
}
// A FROM inside a parenthesised subquery is skipped in favour of the outer one.
#[test]
fn table_from_simple_query_skips_subquery_from() {
assert_eq!(
table_from_simple_query("SELECT (SELECT max(x) FROM events) FROM users"),
Some("users")
);
}
// If the only FROM is inside parentheses, nothing is extracted.
#[test]
fn table_from_simple_query_subquery_only_returns_none() {
assert_eq!(
table_from_simple_query("SELECT (SELECT max(x) FROM events)"),
None
);
}
// Queries without FROM, including the empty string, yield None.
#[test]
fn table_from_simple_query_handles_no_from_clause() {
assert_eq!(table_from_simple_query("SELECT 1"), None);
assert_eq!(table_from_simple_query(""), None);
}
// --- config resolution feeding the helpers above ---
// With `table = "orders"` and no `query`, the resolved base query is
// `SELECT * FROM orders`, and both table extractors recover `orders`.
#[test]
fn table_shortcut_resolves_to_real_table_not_select_one() {
let mut export = crate::config::sample_export("orders");
export.query = None;
export.table = Some("orders".into());
let base = export
.resolve_query(std::path::Path::new(""), None)
.expect("table shortcut resolves");
assert_eq!(base, "SELECT * FROM orders");
assert_ne!(base, "SELECT 1");
assert_eq!(
crate::pipeline::chunked::strip_select_star_from(&base),
Some("orders")
);
assert_eq!(table_from_simple_query(&base), Some("orders"));
}
// Same as above for a schema-qualified table name.
#[test]
fn schema_qualified_table_shortcut_resolves_to_real_table() {
let mut export = crate::config::sample_export("orders");
export.query = None;
export.table = Some("public.orders".into());
let base = export
.resolve_query(std::path::Path::new(""), None)
.expect("schema-qualified table shortcut resolves");
assert_eq!(base, "SELECT * FROM public.orders");
assert_eq!(
crate::pipeline::chunked::strip_select_star_from(&base),
Some("public.orders")
);
}
// An explicit inline `query` is returned by `resolve_query` unchanged.
#[test]
fn inline_query_form_is_left_untouched() {
let mut export = crate::config::sample_export("custom");
export.table = None;
export.query = Some("SELECT id, total FROM orders WHERE total > 0".into());
let base = export
.resolve_query(std::path::Path::new(""), None)
.expect("inline query resolves");
assert_eq!(base, "SELECT id, total FROM orders WHERE total > 0");
}
// Double-quoted identifiers are not parsed; extraction gives up with None.
#[test]
fn table_from_simple_query_rejects_quoted_identifier() {
assert_eq!(
table_from_simple_query("SELECT * FROM \"User Table\""),
None
);
}
}