use crate::connection::{
AsyncConnection, BulkInsert, ConnectOptions, ExecutionSummary, ForeignKey, QueryResult,
SchemaInfo, StatementResult,
};
use crate::error::SqlError;
use crate::stream::BoxRowStream;
use crate::url::DatabaseUrl;
use crate::value::{ColumnInfo, Row, TypeHint, Value};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::sink::SinkExt;
use secrecy::ExposeSecret;
use std::sync::Arc;
use tokio_postgres::types::Type;
pub struct PostgresConnection {
client: tokio_postgres::Client,
}
#[async_trait]
impl AsyncConnection for PostgresConnection {
async fn execute(&mut self, sql: &str) -> Result<ExecutionSummary, SqlError> {
let rows_affected = self
.client
.execute(sql, &[])
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
Ok(ExecutionSummary {
rows_affected: Some(rows_affected),
command_tag: None,
})
}
async fn query(&mut self, sql: &str) -> Result<QueryResult, SqlError> {
let rows = self
.client
.query(sql, &[])
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
if rows.is_empty() {
return Ok(QueryResult {
columns: Vec::new(),
rows: Vec::new(),
});
}
let first = &rows[0];
let columns: Vec<ColumnInfo> = first
.columns()
.iter()
.map(|c| ColumnInfo {
name: c.name().to_string(),
type_hint: pg_type_to_hint(c.type_()),
nullable: true,
})
.collect();
let data_rows: Vec<Row> = rows
.iter()
.map(|row| {
(0..columns.len())
.map(|i| pg_to_value(row, i, row.columns()[i].type_()))
.collect()
})
.collect();
Ok(QueryResult {
columns,
rows: data_rows,
})
}
async fn query_stream(
&mut self,
sql: &str,
) -> Result<(Vec<ColumnInfo>, BoxRowStream<'_>), SqlError> {
use futures_util::stream::TryStreamExt;
let statement = self
.client
.prepare(sql)
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
let columns: Vec<ColumnInfo> = statement
.columns()
.iter()
.map(|c| ColumnInfo {
name: c.name().to_string(),
type_hint: pg_type_to_hint(c.type_()),
nullable: true,
})
.collect();
let ncols = columns.len();
let params: [&(dyn tokio_postgres::types::ToSql + Sync); 0] = [];
let row_stream = self
.client
.query_raw(&statement, params)
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
let mapped = row_stream
.map_ok(move |row| {
(0..ncols)
.map(|i| pg_to_value(&row, i, row.columns()[i].type_()))
.collect::<Row>()
})
.map_err(|e| SqlError::QueryFailed(e.to_string()));
Ok((columns, Box::pin(mapped)))
}
async fn execute_multi(&mut self, sql: &str) -> Result<Vec<StatementResult>, SqlError> {
let msgs = self
.client
.simple_query(sql)
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
let mut results = Vec::new();
let mut current_columns: Vec<ColumnInfo> = Vec::new();
let mut current_rows: Vec<Row> = Vec::new();
for msg in msgs {
use tokio_postgres::SimpleQueryMessage;
match msg {
SimpleQueryMessage::Row(row) => {
if current_columns.is_empty() {
current_columns = row
.columns()
.iter()
.map(|c| ColumnInfo {
name: c.name().to_string(),
type_hint: TypeHint::Other,
nullable: true,
})
.collect();
}
let values: Vec<Value> = (0..row.len())
.map(|i| match row.get(i) {
Some(s) => Value::String(s.to_string()),
None => Value::Null,
})
.collect();
current_rows.push(values);
}
SimpleQueryMessage::CommandComplete(n) => {
if !current_columns.is_empty() {
results.push(StatementResult::Query(QueryResult {
columns: std::mem::take(&mut current_columns),
rows: std::mem::take(&mut current_rows),
}));
} else {
results.push(StatementResult::Summary(ExecutionSummary {
rows_affected: Some(n),
command_tag: None,
}));
}
}
_ => {}
}
}
if !current_columns.is_empty() {
results.push(StatementResult::Query(QueryResult {
columns: std::mem::take(&mut current_columns),
rows: std::mem::take(&mut current_rows),
}));
}
Ok(results)
}
async fn ping(&mut self) -> Result<(), SqlError> {
self.client
.execute("SELECT 1", &[])
.await
.map_err(|e| SqlError::ConnectionFailed(e.to_string()))?;
Ok(())
}
async fn list_tables(&mut self, schema: Option<&str>) -> Result<Vec<String>, SqlError> {
let schema = schema.unwrap_or("public");
let rows = self
.client
.query(
"SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_type = 'BASE TABLE' ORDER BY table_name",
&[&schema,
],
)
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
let names = rows
.into_iter()
.map(|row| row.get::<_, String>(0))
.collect();
Ok(names)
}
async fn list_schemas(&mut self) -> Result<Vec<SchemaInfo>, SqlError> {
let rows = self
.client
.query(
"SELECT schema_name, schema_name = current_schema() AS is_default FROM information_schema.schemata ORDER BY schema_name",
&[],
)
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
let schemas = rows
.into_iter()
.map(|row| SchemaInfo {
name: row.get::<_, String>(0),
is_default: row.try_get::<_, bool>(1).unwrap_or(false),
})
.collect();
Ok(schemas)
}
async fn describe_table(
&mut self,
schema: Option<&str>,
table: &str,
) -> Result<QueryResult, SqlError> {
let schema = schema.unwrap_or("public");
let rows = self
.client
.query(
"SELECT column_name, data_type, is_nullable, column_default, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position",
&[&schema,
&table,
],
)
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
let columns = vec![
ColumnInfo {
name: "column_name".to_string(),
type_hint: TypeHint::String,
nullable: true,
},
ColumnInfo {
name: "data_type".to_string(),
type_hint: TypeHint::String,
nullable: true,
},
ColumnInfo {
name: "is_nullable".to_string(),
type_hint: TypeHint::String,
nullable: true,
},
ColumnInfo {
name: "column_default".to_string(),
type_hint: TypeHint::String,
nullable: true,
},
ColumnInfo {
name: "numeric_precision".to_string(),
type_hint: TypeHint::Int64,
nullable: true,
},
ColumnInfo {
name: "numeric_scale".to_string(),
type_hint: TypeHint::Int64,
nullable: true,
},
];
let data_rows: Vec<Row> = rows
.iter()
.map(|row| {
vec![
row.try_get::<_, Option<String>>("column_name")
.unwrap_or(None)
.map(Value::String)
.unwrap_or(Value::Null),
row.try_get::<_, Option<String>>("data_type")
.unwrap_or(None)
.map(Value::String)
.unwrap_or(Value::Null),
row.try_get::<_, Option<String>>("is_nullable")
.unwrap_or(None)
.map(Value::String)
.unwrap_or(Value::Null),
row.try_get::<_, Option<String>>("column_default")
.unwrap_or(None)
.map(Value::String)
.unwrap_or(Value::Null),
row.try_get::<_, Option<i32>>("numeric_precision")
.unwrap_or(None)
.map(|v| Value::Int64(i64::from(v)))
.unwrap_or(Value::Null),
row.try_get::<_, Option<i32>>("numeric_scale")
.unwrap_or(None)
.map(|v| Value::Int64(i64::from(v)))
.unwrap_or(Value::Null),
]
})
.collect();
Ok(QueryResult {
columns,
rows: data_rows,
})
}
async fn primary_key(
&mut self,
schema: Option<&str>,
table: &str,
) -> Result<Vec<String>, SqlError> {
let schema = schema.unwrap_or("public");
let sql = "SELECT a.attname \
FROM pg_index i \
JOIN pg_class c ON c.oid = i.indrelid \
JOIN pg_namespace n ON n.oid = c.relnamespace \
JOIN unnest(i.indkey) WITH ORDINALITY AS k(attnum, ord) ON true \
JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = k.attnum \
WHERE i.indisprimary AND n.nspname = $1 AND c.relname = $2 \
ORDER BY k.ord";
let rows = self
.client
.query(sql, &[&schema, &table])
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
Ok(rows.into_iter().map(|r| r.get::<_, String>(0)).collect())
}
async fn list_foreign_keys(
&mut self,
schema: Option<&str>,
) -> Result<Vec<ForeignKey>, SqlError> {
let schema = schema.unwrap_or("public");
let sql = "SELECT c.conname, \
cl_child.relname AS child_table, \
a_child.attname AS child_col, \
cl_parent.relname AS parent_table, \
a_parent.attname AS parent_col, \
c.confdeltype, \
k.ord \
FROM pg_constraint c \
JOIN pg_class cl_child ON cl_child.oid = c.conrelid \
JOIN pg_namespace n_child ON n_child.oid = cl_child.relnamespace \
JOIN pg_class cl_parent ON cl_parent.oid = c.confrelid \
JOIN pg_namespace n_parent ON n_parent.oid = cl_parent.relnamespace \
JOIN unnest(c.conkey) WITH ORDINALITY AS k(attnum, ord) ON true \
JOIN pg_attribute a_child ON a_child.attrelid = cl_child.oid AND a_child.attnum = k.attnum \
JOIN unnest(c.confkey) WITH ORDINALITY AS kp(attnum, ord) ON kp.ord = k.ord \
JOIN pg_attribute a_parent ON a_parent.attrelid = cl_parent.oid AND a_parent.attnum = kp.attnum \
WHERE c.contype = 'f' AND n_child.nspname = $1 \
ORDER BY c.conname, k.ord";
let rows = self
.client
.query(sql, &[&schema])
.await
.map_err(|e| SqlError::QueryFailed(e.to_string()))?;
let mut map: indexmap::IndexMap<String, ForeignKey> = indexmap::IndexMap::new();
for row in rows {
let conname: String = row.get(0);
let child_table: String = row.get(1);
let child_col: String = row.get(2);
let parent_table: String = row.get(3);
let parent_col: String = row.get(4);
let confdeltype: i8 = row.get(5);
let on_delete = pg_confdeltype(confdeltype);
let entry = map.entry(conname).or_insert_with(|| ForeignKey {
child_table: child_table.clone(),
child_columns: Vec::new(),
parent_table: parent_table.clone(),
parent_columns: Vec::new(),
on_delete,
});
entry.child_columns.push(child_col);
entry.parent_columns.push(parent_col);
}
Ok(map.into_values().collect())
}
async fn bulk_insert_rows(&mut self, target: BulkInsert<'_>) -> Result<usize, SqlError> {
if target.rows.is_empty() {
return Ok(0);
}
let table = crate::copy::quote_identifier(target.table, crate::backend::Backend::Postgres);
let cols = target
.columns
.iter()
.map(|c| crate::copy::quote_identifier(&c.name, crate::backend::Backend::Postgres))
.collect::<Vec<_>>()
.join(", ");
match target.copy_format {
crate::copy::CopyFormat::Text => {
let stmt = format!("COPY {table} ({cols}) FROM STDIN WITH (FORMAT TEXT)");
let sink = self
.client
.copy_in::<_, Bytes>(stmt.as_str())
.await
.map_err(|e| pg_text_copy::classify_copy_error(&e))?;
tokio::pin!(sink);
let hints: Vec<TypeHint> = target.columns.iter().map(|c| c.type_hint).collect();
for row in target.rows {
let buf = pg_text_copy::encode_row(row, &hints)?;
sink.send(buf)
.await
.map_err(|e| SqlError::QueryFailed(format!("COPY send: {e}")))?;
}
let rows = sink
.as_mut()
.finish()
.await
.map_err(|e| SqlError::QueryFailed(format!("COPY finish: {e}")))?;
Ok(rows as usize)
}
crate::copy::CopyFormat::Binary => {
pg_binary_copy::run(&mut self.client, &table, &cols, &target).await
}
}
}
}
mod pg_text_copy {
    //! Encoder for the `COPY ... FROM STDIN WITH (FORMAT TEXT)` line format,
    //! plus classification of errors raised while opening a COPY.
    use crate::error::SqlError;
    use crate::value::{TypeHint, Value};
    use bytes::Bytes;

    /// Encodes one row as a single COPY text line.
    ///
    /// Cells are separated by `\t` and the line ends with `\n`. Each cell is
    /// paired with the hint at the same index; cells past the end of `hints`
    /// use `TypeHint::Other`.
    ///
    /// # Errors
    /// Returns `QueryFailed` if a JSON or array cell fails to serialize.
    pub fn encode_row(row: &[Value], hints: &[TypeHint]) -> Result<Bytes, SqlError> {
        let mut line = String::with_capacity(row.len() * 12 + 1);
        let padded_hints = hints
            .iter()
            .copied()
            .chain(std::iter::repeat(TypeHint::Other));
        for (idx, (cell, hint)) in row.iter().zip(padded_hints).enumerate() {
            if idx != 0 {
                line.push('\t');
            }
            encode_value(&mut line, cell, hint)?;
        }
        line.push('\n');
        Ok(Bytes::from(line.into_bytes()))
    }

    /// Appends the COPY text form of `v` to `out`.
    ///
    /// `hint` is accepted but not consulted: each variant is encoded from the
    /// value alone. Arrays are rendered as compact JSON.
    fn encode_value(out: &mut String, v: &Value, hint: TypeHint) -> Result<(), SqlError> {
        use std::fmt::Write;
        let _ = hint;
        match v {
            Value::Null => out.push_str("\\N"),
            Value::Bool(flag) => out.push_str(if *flag { "t" } else { "f" }),
            Value::Int64(n) => {
                let _ = write!(out, "{n}");
            }
            Value::Float64(f) if f.is_nan() => out.push_str("NaN"),
            Value::Float64(f) if f.is_infinite() => {
                out.push_str(if f.is_sign_positive() { "Infinity" } else { "-Infinity" });
            }
            Value::Float64(f) => {
                let _ = write!(out, "{f}");
            }
            Value::Decimal(text) | Value::String(text) | Value::Uuid(text) => {
                push_escaped(out, text)
            }
            Value::Bytes(raw) => {
                // `\\x` in the stream unescapes to `\x`, bytea's hex prefix.
                out.push_str("\\\\x");
                for byte in raw {
                    let _ = write!(out, "{byte:02x}");
                }
            }
            Value::Date(date) => {
                let _ = write!(out, "{date}");
            }
            Value::Time(time) => {
                let _ = write!(out, "{time}");
            }
            Value::DateTime(stamp) => {
                let _ = write!(out, "{stamp}");
            }
            Value::DateTimeTz(stamp) => out.push_str(&stamp.to_rfc3339()),
            Value::Json(doc) => {
                let rendered = serde_json::to_string(doc)
                    .map_err(|e| SqlError::QueryFailed(format!("PG bulk: JSON serialize: {e}")))?;
                push_escaped(out, &rendered);
            }
            Value::Array(items) => {
                let rendered = serde_json::to_string(items).map_err(|e| {
                    SqlError::QueryFailed(format!("PG bulk: array serialize: {e}"))
                })?;
                push_escaped(out, &rendered);
            }
        }
        Ok(())
    }

    /// Appends `s` to `out`, backslash-escaping backslash, tab, newline and
    /// carriage return, and writing NUL as `\x00`.
    fn push_escaped(out: &mut String, s: &str) {
        for ch in s.chars() {
            let escape = match ch {
                '\\' => "\\\\",
                '\t' => "\\t",
                '\n' => "\\n",
                '\r' => "\\r",
                '\0' => "\\x00",
                _ => {
                    out.push(ch);
                    continue;
                }
            };
            out.push_str(escape);
        }
    }

    /// Converts an error raised while opening a COPY into a [`SqlError`].
    ///
    /// SQLSTATE `WRONG_OBJECT_TYPE` maps to `BulkUnavailable`; anything else
    /// maps to `QueryFailed`.
    pub fn classify_copy_error(e: &tokio_postgres::Error) -> SqlError {
        use tokio_postgres::error::SqlState;
        match e.code() {
            Some(code) if *code == SqlState::WRONG_OBJECT_TYPE => {
                SqlError::BulkUnavailable(format!("PG rejected COPY: {e}"))
            }
            _ => SqlError::QueryFailed(format!("COPY setup: {e}")),
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use chrono::{NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};

        /// Encodes a one-cell row and returns the cell text, asserting the
        /// trailing newline is present and stripping it.
        fn encode_one(value: Value, hint: TypeHint) -> String {
            let encoded = encode_row(&[value], &[hint]).expect("encode_row");
            let line = std::str::from_utf8(&encoded).unwrap();
            assert!(line.ends_with('\n'));
            line.trim_end_matches('\n').to_owned()
        }

        #[test]
        fn encode_null_is_backslash_n() {
            assert_eq!(encode_one(Value::Null, TypeHint::Null), r"\N");
        }

        #[test]
        fn encode_bool_is_t_or_f() {
            for (flag, expected) in [(true, "t"), (false, "f")] {
                assert_eq!(encode_one(Value::Bool(flag), TypeHint::Bool), expected);
            }
        }

        #[test]
        fn encode_int_and_float() {
            let cases = [
                (Value::Int64(42), TypeHint::Int64, "42"),
                (Value::Int64(-7), TypeHint::Int64, "-7"),
                (Value::Float64(1.5), TypeHint::Float64, "1.5"),
            ];
            for (value, hint, expected) in cases {
                assert_eq!(encode_one(value, hint), expected);
            }
        }

        #[test]
        fn encode_float_nan_and_inf() {
            let cases = [
                (f64::NAN, "NaN"),
                (f64::INFINITY, "Infinity"),
                (f64::NEG_INFINITY, "-Infinity"),
            ];
            for (f, expected) in cases {
                assert_eq!(encode_one(Value::Float64(f), TypeHint::Float64), expected);
            }
        }

        #[test]
        fn encode_string_escapes_backslash_first() {
            let encoded = encode_one(Value::String("\\.\n".into()), TypeHint::String);
            assert_eq!(encoded, r"\\.\n");
        }

        #[test]
        fn encode_string_escapes_tab_cr_lf() {
            let encoded = encode_one(Value::String("a\tb\nc\rd".into()), TypeHint::String);
            assert_eq!(encoded, r"a\tb\nc\rd");
        }

        #[test]
        fn encode_string_passes_through_normal_chars() {
            let text = "héllo, world 🐈";
            assert_eq!(encode_one(Value::String(text.into()), TypeHint::String), text);
        }

        #[test]
        fn encode_string_replaces_nul() {
            let encoded = encode_one(Value::String("a\0b".into()), TypeHint::String);
            assert_eq!(encoded, r"a\x00b");
        }

        #[test]
        fn encode_bytes_is_hex_with_double_backslash_x() {
            let raw = vec![0xDE, 0xAD, 0xBE, 0xEF];
            assert_eq!(encode_one(Value::Bytes(raw), TypeHint::Bytes), r"\\xdeadbeef");
        }

        #[test]
        fn encode_date_time_datetime() {
            let date = NaiveDate::from_ymd_opt(2026, 5, 14).unwrap();
            let time = NaiveTime::from_hms_opt(12, 34, 56).unwrap();
            let stamp = NaiveDateTime::new(date, time);
            assert_eq!(encode_one(Value::Date(date), TypeHint::Date), "2026-05-14");
            assert_eq!(encode_one(Value::Time(time), TypeHint::Time), "12:34:56");
            assert_eq!(
                encode_one(Value::DateTime(stamp), TypeHint::DateTime),
                "2026-05-14 12:34:56"
            );
        }

        #[test]
        fn encode_datetimetz_is_rfc3339() {
            let stamp = Utc.with_ymd_and_hms(2026, 5, 14, 12, 34, 56).unwrap();
            assert_eq!(
                encode_one(Value::DateTimeTz(stamp), TypeHint::DateTimeTz),
                "2026-05-14T12:34:56+00:00"
            );
        }

        #[test]
        fn encode_json_is_compact_with_escapes() {
            let doc = serde_json::json!({"role": "admin", "active": true});
            let encoded = encode_one(Value::Json(doc), TypeHint::Json);
            for fragment in ["\"role\":\"admin\"", "\"active\":true"] {
                assert!(encoded.contains(fragment));
            }
        }

        #[test]
        fn encode_uuid_passes_through() {
            let id = "550e8400-e29b-41d4-a716-446655440000";
            assert_eq!(encode_one(Value::Uuid(id.into()), TypeHint::Uuid), id);
        }

        #[test]
        fn encode_array_is_compact_json() {
            let items = (1..=3).map(Value::Int64).collect();
            assert_eq!(encode_one(Value::Array(items), TypeHint::Array), "[1,2,3]");
        }

        #[test]
        fn encode_decimal_passes_through_with_escapes() {
            assert_eq!(
                encode_one(Value::Decimal("99.5".into()), TypeHint::Decimal),
                "99.5"
            );
        }

        #[test]
        fn encode_row_with_multiple_cells_uses_tab_separator() {
            let row = [
                Value::Int64(1),
                Value::String("Alice".into()),
                Value::Null,
                Value::Bool(true),
            ];
            let hints = [
                TypeHint::Int64,
                TypeHint::String,
                TypeHint::Null,
                TypeHint::Bool,
            ];
            let encoded = encode_row(&row, &hints).unwrap();
            assert_eq!(std::str::from_utf8(&encoded).unwrap(), "1\tAlice\t\\N\tt\n");
        }

        #[test]
        fn encode_row_empty_row_is_just_newline() {
            let encoded = encode_row(&[], &[]).unwrap();
            assert_eq!(std::str::from_utf8(&encoded).unwrap(), "\n");
        }
    }
}
mod pg_binary_copy {
    //! `COPY ... FROM STDIN WITH (FORMAT BINARY)` support.
    //!
    //! Each column's [`TypeHint`] is mapped to a concrete PostgreSQL type, and
    //! each [`Value`] to a typed bind that `BinaryCopyInWriter` can serialize.
    use super::pg_text_copy;
    use crate::connection::BulkInsert;
    use crate::error::SqlError;
    use crate::value::{TypeHint, Value};
    use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
    use rust_decimal::Decimal;
    use std::str::FromStr;
    use tokio_postgres::Client;
    use tokio_postgres::binary_copy::BinaryCopyInWriter;
    use tokio_postgres::types::{ToSql, Type};
    use uuid::Uuid;

    /// Streams `target.rows` into `table` with a binary COPY and returns the
    /// row count reported by the server.
    ///
    /// `table` and `cols` must already be quoted identifiers (`cols` is a
    /// comma-separated list in the same order as `target.columns`).
    ///
    /// # Errors
    /// - `BulkUnavailable` if any column hint has no concrete PG type, or when
    ///   `classify_copy_error` classifies the COPY setup failure that way.
    /// - `QueryFailed` for a row whose length differs from the column list,
    ///   for values that cannot be converted, and for write/finish failures.
    pub async fn run(
        client: &mut Client,
        table: &str,
        cols: &str,
        target: &BulkInsert<'_>,
    ) -> Result<usize, SqlError> {
        let types: Vec<Type> = target
            .columns
            .iter()
            .map(|c| pg_type_for_hint(c.type_hint))
            .collect::<Result<_, _>>()?;
        let stmt = format!("COPY {table} ({cols}) FROM STDIN WITH (FORMAT BINARY)");
        let sink = client
            .copy_in::<_, bytes::Bytes>(stmt.as_str())
            .await
            .map_err(|e| pg_text_copy::classify_copy_error(&e))?;
        let writer = BinaryCopyInWriter::new(sink, &types);
        tokio::pin!(writer);
        let hints: Vec<TypeHint> = target.columns.iter().map(|c| c.type_hint).collect();
        for row in target.rows {
            // `BinaryCopyInWriter::write` panics when the number of values does
            // not match the declared types, so reject ragged rows up front.
            check_row_arity(row.len(), hints.len())?;
            let cells: Vec<PgBinaryBind> = row
                .iter()
                .zip(hints.iter())
                .map(|(v, h)| value_to_pg_binary_bind(v, *h))
                .collect::<Result<_, _>>()?;
            let refs: Vec<&(dyn ToSql + Sync)> =
                cells.iter().map(PgBinaryBind::as_to_sql).collect();
            writer
                .as_mut()
                .write(&refs)
                .await
                .map_err(|e| SqlError::QueryFailed(format!("BINARY COPY write: {e}")))?;
        }
        let rows = writer
            .as_mut()
            .finish()
            .await
            .map_err(|e| SqlError::QueryFailed(format!("BINARY COPY finish: {e}")))?;
        Ok(rows as usize)
    }

    /// Returns `QueryFailed` unless a row has exactly `column_count` values.
    fn check_row_arity(row_len: usize, column_count: usize) -> Result<(), SqlError> {
        if row_len == column_count {
            return Ok(());
        }
        Err(SqlError::QueryFailed(format!(
            "PG binary COPY: row has {row_len} values but {column_count} columns were declared"
        )))
    }

    /// Maps a destination column hint to the PG type declared to the binary
    /// COPY writer. Arrays are declared as `JSONB`.
    ///
    /// # Errors
    /// `BulkUnavailable` for `TypeHint::Null` / `TypeHint::Other`, which have
    /// no concrete type to declare.
    pub(super) fn pg_type_for_hint(hint: TypeHint) -> Result<Type, SqlError> {
        Ok(match hint {
            TypeHint::Bool => Type::BOOL,
            TypeHint::Int64 => Type::INT8,
            TypeHint::Float64 => Type::FLOAT8,
            TypeHint::Decimal => Type::NUMERIC,
            TypeHint::String => Type::TEXT,
            TypeHint::Bytes => Type::BYTEA,
            TypeHint::Date => Type::DATE,
            TypeHint::Time => Type::TIME,
            TypeHint::DateTime => Type::TIMESTAMP,
            TypeHint::DateTimeTz => Type::TIMESTAMPTZ,
            TypeHint::Json => Type::JSONB,
            TypeHint::Uuid => Type::UUID,
            TypeHint::Array => Type::JSONB,
            TypeHint::Null | TypeHint::Other => {
                return Err(SqlError::BulkUnavailable(format!(
                    "PG binary COPY: cannot bind a column with TypeHint::{hint:?} \
                     (no concrete PG type to declare). Re-run with \
                     --copy-format text or --bulk-native off."
                )));
            }
        })
    }

    /// An owned, typed cell ready to be handed to the binary writer; `None`
    /// encodes SQL NULL for that type.
    #[derive(Debug)]
    pub(super) enum PgBinaryBind {
        Bool(Option<bool>),
        Int8(Option<i64>),
        Float8(Option<f64>),
        Numeric(Option<Decimal>),
        Text(Option<String>),
        Bytea(Option<Vec<u8>>),
        Date(Option<NaiveDate>),
        Time(Option<NaiveTime>),
        Timestamp(Option<NaiveDateTime>),
        TimestampTz(Option<DateTime<Utc>>),
        Json(Option<serde_json::Value>),
        Uuid(Option<Uuid>),
    }
    impl PgBinaryBind {
        /// Borrows the inner value as a `ToSql` trait object.
        pub(super) fn as_to_sql(&self) -> &(dyn ToSql + Sync) {
            match self {
                Self::Bool(v) => v,
                Self::Int8(v) => v,
                Self::Float8(v) => v,
                Self::Numeric(v) => v,
                Self::Text(v) => v,
                Self::Bytea(v) => v,
                Self::Date(v) => v,
                Self::Time(v) => v,
                Self::Timestamp(v) => v,
                Self::TimestampTz(v) => v,
                Self::Json(v) => v,
                Self::Uuid(v) => v,
            }
        }
    }

    /// Converts a cell to a typed bind.
    ///
    /// The bind variant follows the value's own variant; `hint` only matters
    /// for NULLs (which need a typed `None`) and for strings bound to a `Uuid`
    /// column, which are parsed. Arrays become JSON.
    ///
    /// # Errors
    /// `QueryFailed` for unparsable decimals/UUIDs or array serialization
    /// failures; `BulkUnavailable` for a NULL whose hint has no concrete type.
    pub(super) fn value_to_pg_binary_bind(
        v: &Value,
        hint: TypeHint,
    ) -> Result<PgBinaryBind, SqlError> {
        Ok(match (v, hint) {
            (Value::Null, _) => null_bind_for_hint(hint)?,
            (Value::Bool(b), _) => PgBinaryBind::Bool(Some(*b)),
            (Value::Int64(n), _) => PgBinaryBind::Int8(Some(*n)),
            (Value::Float64(f), _) => PgBinaryBind::Float8(Some(*f)),
            (Value::Decimal(s), _) => PgBinaryBind::Numeric(Some(parse_decimal(s)?)),
            (Value::String(s), TypeHint::Uuid) => {
                PgBinaryBind::Uuid(Some(Uuid::parse_str(s).map_err(|e| {
                    SqlError::QueryFailed(format!("PG binary COPY: bad UUID '{s}': {e}"))
                })?))
            }
            (Value::String(s), _) => PgBinaryBind::Text(Some(s.clone())),
            (Value::Bytes(b), _) => PgBinaryBind::Bytea(Some(b.clone())),
            (Value::Date(d), _) => PgBinaryBind::Date(Some(*d)),
            (Value::Time(t), _) => PgBinaryBind::Time(Some(*t)),
            (Value::DateTime(dt), _) => PgBinaryBind::Timestamp(Some(*dt)),
            (Value::DateTimeTz(dt), _) => PgBinaryBind::TimestampTz(Some(*dt)),
            (Value::Json(j), _) => PgBinaryBind::Json(Some(j.clone())),
            (Value::Array(arr), _) => {
                let json = serde_json::to_value(arr).map_err(|e| {
                    SqlError::QueryFailed(format!("PG binary COPY: array serialize: {e}"))
                })?;
                PgBinaryBind::Json(Some(json))
            }
            (Value::Uuid(s), _) => PgBinaryBind::Uuid(Some(Uuid::parse_str(s).map_err(|e| {
                SqlError::QueryFailed(format!("PG binary COPY: bad UUID '{s}': {e}"))
            })?)),
        })
    }

    /// Picks the typed `None` used to encode NULL for a column with `hint`.
    fn null_bind_for_hint(hint: TypeHint) -> Result<PgBinaryBind, SqlError> {
        Ok(match hint {
            TypeHint::Bool => PgBinaryBind::Bool(None),
            TypeHint::Int64 => PgBinaryBind::Int8(None),
            TypeHint::Float64 => PgBinaryBind::Float8(None),
            TypeHint::Decimal => PgBinaryBind::Numeric(None),
            TypeHint::String => PgBinaryBind::Text(None),
            TypeHint::Bytes => PgBinaryBind::Bytea(None),
            TypeHint::Date => PgBinaryBind::Date(None),
            TypeHint::Time => PgBinaryBind::Time(None),
            TypeHint::DateTime => PgBinaryBind::Timestamp(None),
            TypeHint::DateTimeTz => PgBinaryBind::TimestampTz(None),
            TypeHint::Json | TypeHint::Array => PgBinaryBind::Json(None),
            TypeHint::Uuid => PgBinaryBind::Uuid(None),
            TypeHint::Null | TypeHint::Other => {
                return Err(SqlError::BulkUnavailable(format!(
                    "PG binary COPY: cannot type-encode NULL for TypeHint::{hint:?}"
                )));
            }
        })
    }

    /// Parses a decimal string for a `NUMERIC` bind.
    fn parse_decimal(s: &str) -> Result<Decimal, SqlError> {
        Decimal::from_str(s).map_err(|e| {
            SqlError::QueryFailed(format!("PG binary COPY: invalid NUMERIC '{s}': {e}"))
        })
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        #[test]
        fn pg_type_for_hint_maps_canonical_dest_types() {
            assert_eq!(pg_type_for_hint(TypeHint::Bool).unwrap(), Type::BOOL);
            assert_eq!(pg_type_for_hint(TypeHint::Int64).unwrap(), Type::INT8);
            assert_eq!(pg_type_for_hint(TypeHint::Float64).unwrap(), Type::FLOAT8);
            assert_eq!(pg_type_for_hint(TypeHint::Decimal).unwrap(), Type::NUMERIC);
            assert_eq!(pg_type_for_hint(TypeHint::String).unwrap(), Type::TEXT);
            assert_eq!(pg_type_for_hint(TypeHint::Bytes).unwrap(), Type::BYTEA);
            assert_eq!(pg_type_for_hint(TypeHint::Date).unwrap(), Type::DATE);
            assert_eq!(pg_type_for_hint(TypeHint::Time).unwrap(), Type::TIME);
            assert_eq!(
                pg_type_for_hint(TypeHint::DateTime).unwrap(),
                Type::TIMESTAMP
            );
            assert_eq!(
                pg_type_for_hint(TypeHint::DateTimeTz).unwrap(),
                Type::TIMESTAMPTZ
            );
            assert_eq!(pg_type_for_hint(TypeHint::Json).unwrap(), Type::JSONB);
            assert_eq!(pg_type_for_hint(TypeHint::Uuid).unwrap(), Type::UUID);
            assert_eq!(pg_type_for_hint(TypeHint::Array).unwrap(), Type::JSONB);
        }
        #[test]
        fn pg_type_for_hint_other_falls_back_via_bulk_unavailable() {
            let err = pg_type_for_hint(TypeHint::Other).unwrap_err();
            assert!(matches!(err, SqlError::BulkUnavailable(_)));
            let err = pg_type_for_hint(TypeHint::Null).unwrap_err();
            assert!(matches!(err, SqlError::BulkUnavailable(_)));
        }
        #[test]
        fn check_row_arity_rejects_mismatched_rows() {
            assert!(check_row_arity(3, 3).is_ok());
            assert!(check_row_arity(0, 0).is_ok());
            assert!(matches!(
                check_row_arity(2, 3).unwrap_err(),
                SqlError::QueryFailed(_)
            ));
            assert!(matches!(
                check_row_arity(4, 3).unwrap_err(),
                SqlError::QueryFailed(_)
            ));
        }
        #[test]
        fn null_bind_picks_typed_none_per_hint() {
            assert!(matches!(
                null_bind_for_hint(TypeHint::Bool).unwrap(),
                PgBinaryBind::Bool(None)
            ));
            assert!(matches!(
                null_bind_for_hint(TypeHint::Int64).unwrap(),
                PgBinaryBind::Int8(None)
            ));
            assert!(matches!(
                null_bind_for_hint(TypeHint::Json).unwrap(),
                PgBinaryBind::Json(None)
            ));
            assert!(matches!(
                null_bind_for_hint(TypeHint::Uuid).unwrap(),
                PgBinaryBind::Uuid(None)
            ));
        }
        #[test]
        fn null_bind_array_collapses_to_json_none() {
            assert!(matches!(
                null_bind_for_hint(TypeHint::Array).unwrap(),
                PgBinaryBind::Json(None)
            ));
        }
        #[test]
        fn value_to_bind_routes_int_bool_string_null() {
            assert!(matches!(
                value_to_pg_binary_bind(&Value::Int64(42), TypeHint::Int64).unwrap(),
                PgBinaryBind::Int8(Some(42))
            ));
            assert!(matches!(
                value_to_pg_binary_bind(&Value::Bool(true), TypeHint::Bool).unwrap(),
                PgBinaryBind::Bool(Some(true))
            ));
            match value_to_pg_binary_bind(&Value::String("hi".into()), TypeHint::String).unwrap() {
                PgBinaryBind::Text(Some(s)) => assert_eq!(s, "hi"),
                _ => panic!("expected Text"),
            }
            assert!(matches!(
                value_to_pg_binary_bind(&Value::Null, TypeHint::Int64).unwrap(),
                PgBinaryBind::Int8(None)
            ));
        }
        #[test]
        fn value_to_bind_decimal_roundtrips_through_rust_decimal() {
            match value_to_pg_binary_bind(&Value::Decimal("99.5".into()), TypeHint::Decimal)
                .unwrap()
            {
                PgBinaryBind::Numeric(Some(d)) => assert_eq!(d.to_string(), "99.5"),
                _ => panic!("expected Numeric"),
            }
            let err =
                value_to_pg_binary_bind(&Value::Decimal("not-a-number".into()), TypeHint::Decimal)
                    .unwrap_err();
            assert!(matches!(err, SqlError::QueryFailed(_)));
        }
        #[test]
        fn value_to_bind_string_to_uuid_when_dest_is_uuid() {
            let bind = value_to_pg_binary_bind(
                &Value::String("00112233-4455-6677-8899-aabbccddeeff".into()),
                TypeHint::Uuid,
            )
            .unwrap();
            match bind {
                PgBinaryBind::Uuid(Some(u)) => {
                    assert_eq!(u.to_string(), "00112233-4455-6677-8899-aabbccddeeff")
                }
                _ => panic!("expected Uuid"),
            }
        }
        #[test]
        fn value_to_bind_array_collapses_to_json() {
            let arr = vec![Value::String("a".into()), Value::String("b".into())];
            let bind = value_to_pg_binary_bind(&Value::Array(arr), TypeHint::Array).unwrap();
            match bind {
                PgBinaryBind::Json(Some(v)) => {
                    assert_eq!(v, serde_json::json!(["a", "b"]));
                }
                _ => panic!("expected Json"),
            }
        }
    }
}
/// Opens a PostgreSQL connection for `url` over a regular socket.
///
/// The URL's raw form is parsed as a `tokio_postgres::Config` first; if that
/// fails, the config is assembled field by field via `build_config_from_url`.
/// A password from `opts.effective_password(url)` overrides any password
/// already in the config. The connection future is spawned onto the Tokio
/// runtime, and its terminal error (if any) is printed to stderr.
pub(crate) async fn connect(
    url: &DatabaseUrl,
    opts: &ConnectOptions,
) -> Result<PostgresConnection, SqlError> {
    let mut config = url
        .raw()
        .parse::<tokio_postgres::Config>()
        .or_else(|_| build_config_from_url(url))?;
    if let Some(password) = opts.effective_password(url) {
        config.password(password.expose_secret());
    }
    let tls = build_tls_connector(opts).await.map_err(SqlError::TlsError)?;
    let (client, driver) = config
        .connect(tls)
        .await
        .map_err(|err| SqlError::ConnectionFailed(err.to_string()))?;
    // The connection half performs the actual socket I/O for `client`, so it
    // runs on its own task.
    tokio::spawn(async move {
        if let Err(err) = driver.await {
            eprintln!("[ferrule] Postgres background connection error: {}", err);
        }
    });
    Ok(PostgresConnection { client })
}
/// Opens a PostgreSQL connection over a caller-supplied byte stream `stream`
/// (for example a tunnel) instead of dialing a socket.
///
/// Config resolution and password override match `connect`. The TLS layer is
/// created for the URL's host, falling back to `"localhost"` when the URL has
/// none. The connection future is spawned onto the Tokio runtime, and its
/// terminal error (if any) is printed to stderr.
pub(crate) async fn connect_with_stream<S>(
    url: &DatabaseUrl,
    opts: &ConnectOptions,
    stream: S,
) -> Result<PostgresConnection, SqlError>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    use tokio_postgres::tls::MakeTlsConnect;
    let mut config = url
        .raw()
        .parse::<tokio_postgres::Config>()
        .or_else(|_| build_config_from_url(url))?;
    if let Some(password) = opts.effective_password(url) {
        config.password(password.expose_secret());
    }
    let mut connector_factory = build_tls_connector(opts)
        .await
        .map_err(SqlError::TlsError)?;
    let server_host = url.host().unwrap_or("localhost");
    let tls = <tokio_postgres_rustls::MakeRustlsConnect as MakeTlsConnect<S>>::make_tls_connect(
        &mut connector_factory,
        server_host,
    )
    .map_err(|e| SqlError::TlsError(format!("make_tls_connect failed: {e:?}")))?;
    let (client, driver) = config
        .connect_raw(stream, tls)
        .await
        .map_err(|err| SqlError::ConnectionFailed(err.to_string()))?;
    tokio::spawn(async move {
        if let Err(err) = driver.await {
            eprintln!("[ferrule] Postgres background connection error: {}", err);
        }
    });
    Ok(PostgresConnection { client })
}
/// Assembles a `tokio_postgres::Config` from the individual parts of `url`.
///
/// Defaults: host `localhost`, port `5432`. User and database are set only
/// when non-empty; the password only when present. This never fails; the
/// `Result` return type matches the call sites' use of `?`.
fn build_config_from_url(url: &DatabaseUrl) -> Result<tokio_postgres::Config, SqlError> {
    let mut config = tokio_postgres::Config::new();
    config.host(url.host().unwrap_or("localhost"));
    config.port(url.port().unwrap_or(5432));
    if !url.username().is_empty() {
        config.user(url.username());
    }
    if let Some(secret) = url.password() {
        config.password(secret.expose_secret());
    }
    if !url.database().is_empty() {
        config.dbname(url.database());
    }
    Ok(config)
}
/// Builds the rustls-based TLS connector used for PostgreSQL connections.
///
/// Without `opts.insecure`, server certificates are verified against the
/// bundled `webpki_roots` trust anchors. With it, `InsecureVerifier` accepts
/// any certificate. No client certificate is presented in either case. This
/// currently never returns `Err`.
async fn build_tls_connector(
    opts: &ConnectOptions,
) -> Result<tokio_postgres_rustls::MakeRustlsConnect, String> {
    use rustls::{ClientConfig, RootCertStore};
    let tls_config = if !opts.insecure {
        let mut trust_anchors = RootCertStore::empty();
        trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        ClientConfig::builder()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth()
    } else {
        ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(InsecureVerifier))
            .with_no_client_auth()
    };
    Ok(tokio_postgres_rustls::MakeRustlsConnect::new(tls_config))
}
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}
/// Translates a `pg_constraint.confdeltype` code (read as a signed byte) into
/// its `ON DELETE` action name.
///
/// Returns `None` for any code other than `a`, `r`, `c`, `n` or `d`.
fn pg_confdeltype(c: i8) -> Option<String> {
    let action = match c as u8 {
        b'a' => "NO ACTION",
        b'r' => "RESTRICT",
        b'c' => "CASCADE",
        b'n' => "SET NULL",
        b'd' => "SET DEFAULT",
        _ => return None,
    };
    Some(action.to_string())
}
fn pg_type_to_hint(ty: &Type) -> TypeHint {
match ty {
&Type::BOOL => TypeHint::Bool,
&Type::INT2 | &Type::INT4 | &Type::INT8 => TypeHint::Int64,
&Type::FLOAT4 | &Type::FLOAT8 => TypeHint::Float64,
&Type::NUMERIC => TypeHint::Decimal,
&Type::TEXT | &Type::VARCHAR | &Type::BPCHAR | &Type::NAME => TypeHint::String,
&Type::BYTEA => TypeHint::Bytes,
&Type::DATE => TypeHint::Date,
&Type::TIME => TypeHint::Time,
&Type::TIMESTAMP => TypeHint::DateTime,
&Type::TIMESTAMPTZ => TypeHint::DateTimeTz,
&Type::JSON | &Type::JSONB => TypeHint::Json,
&Type::UUID => TypeHint::Uuid,
_ if ty.name().starts_with('_') => TypeHint::Array,
_ => TypeHint::Other,
}
}
fn pg_to_value(row: &tokio_postgres::Row, col: usize, pg_type: &Type) -> Value {
use tokio_postgres::types::Type;
match pg_type {
&Type::BOOL => row
.try_get::<_, Option<bool>>(col)
.unwrap_or(None)
.map(Value::Bool)
.unwrap_or(Value::Null),
&Type::INT2 => row
.try_get::<_, Option<i16>>(col)
.unwrap_or(None)
.map(|v| Value::Int64(i64::from(v)))
.unwrap_or(Value::Null),
&Type::INT4 => row
.try_get::<_, Option<i32>>(col)
.unwrap_or(None)
.map(|v| Value::Int64(i64::from(v)))
.unwrap_or(Value::Null),
&Type::INT8 => row
.try_get::<_, Option<i64>>(col)
.unwrap_or(None)
.map(Value::Int64)
.unwrap_or(Value::Null),
&Type::FLOAT4 => row
.try_get::<_, Option<f32>>(col)
.unwrap_or(None)
.map(|v| Value::Float64(f64::from(v)))
.unwrap_or(Value::Null),
&Type::FLOAT8 => row
.try_get::<_, Option<f64>>(col)
.unwrap_or(None)
.map(Value::Float64)
.unwrap_or(Value::Null),
&Type::NUMERIC => row
.try_get::<_, Option<rust_decimal::Decimal>>(col)
.unwrap_or(None)
.map(|d| Value::Decimal(d.to_string()))
.unwrap_or(Value::Null),
&Type::TEXT | &Type::VARCHAR | &Type::BPCHAR | &Type::NAME => row
.try_get::<_, Option<String>>(col)
.unwrap_or(None)
.map(Value::String)
.unwrap_or(Value::Null),
&Type::BYTEA => row
.try_get::<_, Option<Vec<u8>>>(col)
.unwrap_or(None)
.map(Value::Bytes)
.unwrap_or(Value::Null),
&Type::DATE => row
.try_get::<_, Option<chrono::NaiveDate>>(col)
.unwrap_or(None)
.map(Value::Date)
.unwrap_or(Value::Null),
&Type::TIME => row
.try_get::<_, Option<chrono::NaiveTime>>(col)
.unwrap_or(None)
.map(Value::Time)
.unwrap_or(Value::Null),
&Type::TIMESTAMP => row
.try_get::<_, Option<chrono::NaiveDateTime>>(col)
.unwrap_or(None)
.map(Value::DateTime)
.unwrap_or(Value::Null),
&Type::TIMESTAMPTZ => row
.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(col)
.unwrap_or(None)
.map(Value::DateTimeTz)
.unwrap_or(Value::Null),
&Type::JSON | &Type::JSONB => row
.try_get::<_, Option<serde_json::Value>>(col)
.unwrap_or(None)
.map(Value::Json)
.unwrap_or(Value::Null),
&Type::UUID => row
.try_get::<_, Option<uuid::Uuid>>(col)
.unwrap_or(None)
.map(|u| Value::Uuid(u.to_string()))
.unwrap_or(Value::Null),
_ => row
.try_get::<_, Option<String>>(col)
.unwrap_or(None)
.map(Value::String)
.unwrap_or(Value::Null),
}
}
/// Integration tests against a live PostgreSQL instance.
///
/// Every test first calls `try_connect` and returns early (printing a skip
/// notice to stderr) when the database at `TEST_POSTGRES_URL` is unreachable,
/// so the suite still passes on machines without the test container.
///
/// The tests rely on a pre-seeded `test_users` table whose primary key is `id`
/// and which has (at least) the columns `name`, `age`, `score`, `active`,
/// `meta`, `uid` and `created_at`, with exactly one row named `'Alice'`.
/// Tables the tests create themselves are suffixed with the process id so
/// concurrently running test processes do not collide.
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_POSTGRES_URL: &str =
        "postgres://ferrule:ferrule@127.0.0.1:15432/ferrule?sslmode=disable";

    /// Opens a connection to the test database.
    ///
    /// Returns `None` if the URL fails to parse or the connection attempt fails;
    /// callers treat `None` as "container not available, skip this test".
    fn try_connect() -> Option<Box<dyn crate::Connection>> {
        let url = DatabaseUrl::parse(TEST_POSTGRES_URL).ok()?;
        let conn = crate::connect(&url, &ConnectOptions::default(), None).ok()?;
        Some(conn)
    }

    /// `ping` succeeds on a freshly opened connection.
    #[test]
    fn test_postgres_ping() {
        let Some(mut conn) = try_connect() else {
            eprintln!("Postgres test container not available, skipping test_postgres_ping");
            return;
        };
        conn.ping().expect("ping should succeed");
    }

    /// A plain `SELECT` over the fixture table yields both columns and rows.
    #[test]
    fn test_postgres_query() {
        let Some(mut conn) = try_connect() else {
            eprintln!("Postgres test container not available, skipping test_postgres_query");
            return;
        };
        let result = conn
            .query("SELECT * FROM test_users")
            .expect("query should succeed");
        assert!(!result.columns.is_empty(), "should have columns");
        assert!(!result.rows.is_empty(), "should have rows");
    }

    /// `execute` reports a non-zero affected-row count for an `INSERT`.
    #[test]
    fn test_postgres_execute() {
        let Some(mut conn) = try_connect() else {
            eprintln!("Postgres test container not available, skipping test_postgres_execute");
            return;
        };
        let summary = conn
            .execute("INSERT INTO test_users (name, age) VALUES ('TestUser', 99)")
            .expect("execute should succeed");
        // Remove the probe row before asserting so the shared fixture table does
        // not grow by one row on every run.
        conn.execute("DELETE FROM test_users WHERE name = 'TestUser' AND age = 99")
            .expect("cleanup of the TestUser row should succeed");
        assert!(
            summary.rows_affected.is_some_and(|n| n > 0),
            "should have affected rows"
        );
    }

    /// `list_tables(None)` includes the fixture table.
    #[test]
    fn test_postgres_list_tables() {
        let Some(mut conn) = try_connect() else {
            eprintln!("Postgres test container not available, skipping test_postgres_list_tables");
            return;
        };
        let tables = conn.list_tables(None).expect("list_tables should succeed");
        assert!(
            tables.contains(&"test_users".to_string()),
            "should contain test_users, got: {tables:?}"
        );
    }

    /// `list_schemas` includes `public` and flags exactly one schema as default.
    #[test]
    fn test_postgres_list_schemas() {
        let Some(mut conn) = try_connect() else {
            eprintln!("Postgres test container not available, skipping test_postgres_list_schemas");
            return;
        };
        let schemas = conn.list_schemas().expect("list_schemas should succeed");
        assert!(
            schemas.iter().any(|s| s.name == "public"),
            "should contain public, got: {schemas:?}"
        );
        let defaults = schemas.iter().filter(|s| s.is_default).count();
        assert_eq!(
            defaults, 1,
            "exactly one schema should be flagged is_default, got: {schemas:?}"
        );
    }

    /// `describe_table` returns the six metadata columns, one row per table column.
    #[test]
    fn test_postgres_describe_table() {
        let Some(mut conn) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_describe_table"
            );
            return;
        };
        let result = conn
            .describe_table(None, "test_users")
            .expect("describe_table should succeed");
        assert_eq!(result.columns.len(), 6, "should return 6 metadata columns");
        let col_names: Vec<String> = result.columns.iter().map(|c| c.name.clone()).collect();
        assert_eq!(
            col_names,
            vec![
                "column_name",
                "data_type",
                "is_nullable",
                "column_default",
                "numeric_precision",
                "numeric_scale",
            ]
        );
        assert!(
            result.rows.len() >= 6,
            "expected at least 6 rows, got {}",
            result.rows.len()
        );
    }

    /// Fixture column types map onto the expected `Value` variants.
    #[test]
    fn test_postgres_type_mapping() {
        let Some(mut conn) = try_connect() else {
            eprintln!("Postgres test container not available, skipping test_postgres_type_mapping");
            return;
        };
        let result = conn
            .query(
                "SELECT name, age, score, active, meta, uid FROM test_users \
                 WHERE name = 'Alice'",
            )
            .expect("query should succeed");
        assert_eq!(result.rows.len(), 1, "expected exactly Alice");
        let row = &result.rows[0];
        assert!(matches!(row[0], Value::String(_)), "name should be String");
        assert!(matches!(row[1], Value::Int64(_)), "age should be Int64");
        assert!(
            matches!(row[2], Value::Decimal(_) | Value::Float64(_)),
            "score (NUMERIC) should be Decimal or Float64"
        );
        assert!(matches!(row[3], Value::Bool(_)), "active should be Bool");
        assert!(
            matches!(row[4], Value::Json(_)),
            "meta (JSONB) should be Json"
        );
        assert!(matches!(row[5], Value::Uuid(_)), "uid should be Uuid");
    }

    /// `TIMESTAMPTZ` columns decode to `Value::DateTimeTz`.
    #[test]
    fn test_postgres_timestamptz_mapping() {
        let Some(mut conn) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_timestamptz_mapping"
            );
            return;
        };
        let result = conn
            .query("SELECT created_at FROM test_users WHERE name = 'Alice'")
            .expect("query should succeed");
        assert_eq!(result.rows.len(), 1);
        assert!(
            matches!(result.rows[0][0], Value::DateTimeTz(_)),
            "created_at (TIMESTAMPTZ) should be DateTimeTz, got {:?}",
            result.rows[0][0]
        );
    }

    /// Text-format `bulk_insert_rows` round-trips escapes, delimiters, NULLs and
    /// multibyte UTF-8.
    #[test]
    fn test_postgres_bulk_insert_rows_round_trip() {
        let Some(mut conn) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_bulk_insert_rows_round_trip"
            );
            return;
        };
        let pid = std::process::id();
        let table = format!("ferrule_bulk_test_{pid}");
        let _ = conn.execute(&format!("DROP TABLE IF EXISTS {table}"));
        conn.execute(&format!(
            "CREATE TABLE {table} (\
             id BIGINT, \
             name TEXT, \
             active BOOLEAN, \
             score DOUBLE PRECISION, \
             meta JSONB, \
             tricky TEXT\
             )"
        ))
        .expect("CREATE TABLE");
        let columns = vec![
            ColumnInfo {
                name: "id".into(),
                type_hint: TypeHint::Int64,
                nullable: false,
            },
            ColumnInfo {
                name: "name".into(),
                type_hint: TypeHint::String,
                nullable: true,
            },
            ColumnInfo {
                name: "active".into(),
                type_hint: TypeHint::Bool,
                nullable: true,
            },
            ColumnInfo {
                name: "score".into(),
                type_hint: TypeHint::Float64,
                nullable: true,
            },
            ColumnInfo {
                name: "meta".into(),
                type_hint: TypeHint::Json,
                nullable: true,
            },
            ColumnInfo {
                name: "tricky".into(),
                type_hint: TypeHint::String,
                nullable: true,
            },
        ];
        let rows: Vec<Row> = vec![
            vec![
                Value::Int64(1),
                Value::String("Alice".into()),
                Value::Bool(true),
                Value::Float64(99.5),
                Value::Json(serde_json::json!({"role": "admin"})),
                Value::String("plain".into()),
            ],
            // Embedded delimiter-like character.
            vec![
                Value::Int64(2),
                Value::String("Bob".into()),
                Value::Bool(false),
                Value::Float64(88.25),
                Value::Json(serde_json::json!({"role": "user"})),
                Value::String("comma,sep".into()),
            ],
            // Backslash, tab and newline must be escaped in COPY text format;
            // a lone `\.` is the end-of-data marker and must not end the stream.
            vec![
                Value::Int64(3),
                Value::String("Esc\\\t\nape".into()),
                Value::Bool(true),
                Value::Float64(0.0),
                Value::Json(serde_json::Value::Null),
                Value::String("\\.".into()),
            ],
            // All-NULL row (apart from the id).
            vec![
                Value::Int64(4),
                Value::Null,
                Value::Null,
                Value::Null,
                Value::Null,
                Value::Null,
            ],
            // Only +infinity is exercised for the float column here; the
            // `tricky` column carries multibyte UTF-8.
            vec![
                Value::Int64(5),
                Value::String("nan-and-inf".into()),
                Value::Bool(true),
                Value::Float64(f64::INFINITY),
                Value::Json(serde_json::json!([1, 2, 3])),
                Value::String("héllo 🐈".into()),
            ],
        ];
        let n = conn
            .bulk_insert_rows(BulkInsert {
                table: &table,
                columns: &columns,
                rows: &rows,
                copy_format: crate::copy::CopyFormat::Text,
            })
            .expect("bulk_insert_rows");
        assert_eq!(n, 5, "bulk should return rows-accepted = 5");

        let count = conn
            .query(&format!("SELECT count(*)::bigint FROM {table}"))
            .expect("count rows");
        assert!(matches!(count.rows[0][0], Value::Int64(5)));

        let r2 = conn
            .query(&format!("SELECT tricky FROM {table} WHERE id = 2"))
            .expect("read row 2");
        assert_eq!(r2.rows.len(), 1);
        assert!(
            matches!(&r2.rows[0][0], Value::String(s) if s == "comma,sep"),
            "row 2 tricky should round-trip, got {:?}",
            r2.rows[0][0]
        );

        let r3 = conn
            .query(&format!("SELECT name, tricky FROM {table} WHERE id = 3"))
            .expect("read row 3");
        assert_eq!(r3.rows.len(), 1);
        if let Value::String(name) = &r3.rows[0][0] {
            assert_eq!(
                name, "Esc\\\t\nape",
                "row 3 name should round-trip with raw bytes"
            );
        } else {
            panic!("row 3 name should be String, got {:?}", r3.rows[0][0]);
        }
        if let Value::String(tricky) = &r3.rows[0][1] {
            assert_eq!(
                tricky, "\\.",
                "row 3 tricky should be literal backslash-dot"
            );
        } else {
            panic!("row 3 tricky should be String, got {:?}", r3.rows[0][1]);
        }

        let r4 = conn
            .query(&format!(
                "SELECT name, active, score, meta, tricky FROM {table} WHERE id = 4"
            ))
            .expect("read row 4");
        assert_eq!(r4.rows.len(), 1);
        for col in &r4.rows[0] {
            assert!(matches!(col, Value::Null), "row 4 expected NULL, got {col:?}");
        }

        let r5 = conn
            .query(&format!("SELECT tricky FROM {table} WHERE id = 5"))
            .expect("read row 5");
        assert_eq!(r5.rows.len(), 1);
        assert!(
            matches!(&r5.rows[0][0], Value::String(s) if s == "héllo 🐈"),
            "row 5 tricky should round-trip multibyte UTF-8, got {:?}",
            r5.rows[0][0]
        );

        conn.execute(&format!("DROP TABLE {table}"))
            .expect("DROP TABLE");
    }

    /// `primary_key` reports the fixture table's single-column key.
    #[test]
    fn test_postgres_primary_key() {
        let Some(mut conn) = try_connect() else {
            eprintln!("Postgres test container not available, skipping test_postgres_primary_key");
            return;
        };
        let pk = conn.primary_key(None, "test_users").expect("primary_key");
        assert_eq!(pk, vec!["id".to_string()]);
    }

    /// `list_foreign_keys` reports a freshly created FK, including its `ON DELETE` action.
    #[test]
    fn test_postgres_list_foreign_keys() {
        let Some(mut conn) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_list_foreign_keys"
            );
            return;
        };
        let pid = std::process::id();
        let child = format!("ferrule_fk_test_orders_{pid}");
        let _ = conn.execute(&format!("DROP TABLE IF EXISTS {child}"));
        conn.execute(&format!(
            "CREATE TABLE {child} (\
             id SERIAL PRIMARY KEY, \
             user_id INT REFERENCES test_users(id) ON DELETE CASCADE\
             )"
        ))
        .expect("CREATE TABLE");
        let fks = conn.list_foreign_keys(None).expect("list_foreign_keys");
        let matching: Vec<_> = fks.iter().filter(|fk| fk.child_table == child).collect();
        assert_eq!(matching.len(), 1, "expected 1 FK from {child}, got {fks:?}");
        let fk = matching[0];
        assert_eq!(fk.child_columns, vec!["user_id".to_string()]);
        assert_eq!(fk.parent_table, "test_users");
        assert_eq!(fk.parent_columns, vec!["id".to_string()]);
        assert_eq!(fk.on_delete.as_deref(), Some("CASCADE"));
        conn.execute(&format!("DROP TABLE {child}"))
            .expect("DROP TABLE");
    }

    /// `copy_rows` with `IfExists::Skip` keeps conflicting rows untouched;
    /// `IfExists::Upsert` then overwrites them.
    #[test]
    fn test_postgres_copy_skip_then_upsert() {
        use crate::backend::Backend;
        use crate::copy::{CopyOptions, CopySource, IfExists, copy_rows};
        let (Some(mut src), Some(mut dst)) = (try_connect(), try_connect()) else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_copy_skip_then_upsert"
            );
            return;
        };
        let pid = std::process::id();
        let src_table = format!("ferrule_pg_skip_src_{pid}");
        let dst_table = format!("ferrule_pg_skip_dst_{pid}");
        let _ = src.execute(&format!("DROP TABLE IF EXISTS {src_table}"));
        let _ = dst.execute(&format!("DROP TABLE IF EXISTS {dst_table}"));
        src.execute(&format!(
            "CREATE TABLE {src_table} (id INT PRIMARY KEY, name TEXT, val INT)"
        ))
        .expect("CREATE src");
        dst.execute(&format!(
            "CREATE TABLE {dst_table} (id INT PRIMARY KEY, name TEXT, val INT)"
        ))
        .expect("CREATE dst");
        src.execute(&format!(
            "INSERT INTO {src_table} VALUES (1, 'new-1', 10), (2, 'new-2', 20)"
        ))
        .expect("seed src");
        dst.execute(&format!("INSERT INTO {dst_table} VALUES (1, 'old-1', 99)"))
            .expect("seed dst");

        let opts = CopyOptions {
            source: CopySource::Query {
                sql: format!("SELECT * FROM {src_table} ORDER BY id"),
                into: dst_table.clone(),
            },
            if_exists: IfExists::Skip,
            ..Default::default()
        };
        copy_rows(
            &mut src,
            Backend::Postgres,
            &mut dst,
            Backend::Postgres,
            &opts,
        )
        .expect("copy_rows skip");
        let out = dst
            .query(&format!(
                "SELECT id, name, val FROM {dst_table} ORDER BY id"
            ))
            .expect("verify skip");
        assert_eq!(out.rows.len(), 2);
        // Skip must leave the whole conflicting row (name AND val) untouched.
        assert!(matches!(&out.rows[0][1], Value::String(s) if s == "old-1"));
        assert!(
            matches!(&out.rows[0][2], Value::Int64(99)),
            "skip should keep existing val, got {:?}",
            out.rows[0][2]
        );
        assert!(matches!(&out.rows[1][1], Value::String(s) if s == "new-2"));
        assert!(matches!(&out.rows[1][2], Value::Int64(20)));

        let opts = CopyOptions {
            source: CopySource::Query {
                sql: format!("SELECT * FROM {src_table} ORDER BY id"),
                into: dst_table.clone(),
            },
            if_exists: IfExists::Upsert,
            ..Default::default()
        };
        copy_rows(
            &mut src,
            Backend::Postgres,
            &mut dst,
            Backend::Postgres,
            &opts,
        )
        .expect("copy_rows upsert");
        let out = dst
            .query(&format!(
                "SELECT id, name, val FROM {dst_table} ORDER BY id"
            ))
            .expect("verify upsert");
        assert_eq!(out.rows.len(), 2);
        assert!(matches!(&out.rows[0][1], Value::String(s) if s == "new-1"));
        assert!(matches!(&out.rows[0][2], Value::Int64(10)));
        assert!(matches!(&out.rows[1][1], Value::String(s) if s == "new-2"));
        assert!(matches!(&out.rows[1][2], Value::Int64(20)));
        let _ = src.execute(&format!("DROP TABLE {src_table}"));
        let _ = dst.execute(&format!("DROP TABLE {dst_table}"));
    }

    /// `copy_all_tables` copies an FK-linked parent/child pair from Postgres into
    /// SQLite (with foreign keys enforced), creating the tables on the way.
    #[cfg(feature = "sqlite")]
    #[test]
    fn test_postgres_to_sqlite_all_tables_round_trip() {
        use crate::backend::Backend;
        use crate::copy::{AllTablesOptions, copy_all_tables};
        let Some(mut src) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_to_sqlite_all_tables_round_trip"
            );
            return;
        };
        let pid = std::process::id();
        let parent = format!("ferrule_all_parent_{pid}");
        let child = format!("ferrule_all_child_{pid}");
        let _ = src.execute(&format!("DROP TABLE IF EXISTS {child}"));
        let _ = src.execute(&format!("DROP TABLE IF EXISTS {parent}"));
        src.execute(&format!(
            "CREATE TABLE {parent} (id INT PRIMARY KEY, name TEXT)"
        ))
        .expect("CREATE parent");
        src.execute(&format!(
            "CREATE TABLE {child} (id INT PRIMARY KEY, \
             parent_id INT REFERENCES {parent}(id), \
             note TEXT)"
        ))
        .expect("CREATE child");
        src.execute(&format!(
            "INSERT INTO {parent} VALUES (1, 'one'), (2, 'two')"
        ))
        .expect("seed parent");
        src.execute(&format!(
            "INSERT INTO {child} VALUES (10, 1, 'first'), (11, 2, 'second')"
        ))
        .expect("seed child");
        let dst_path = std::env::temp_dir().join(format!("ferrule-pg-all-tables-{pid}.db"));
        let _ = std::fs::remove_file(&dst_path);
        let dst_url = DatabaseUrl::parse(&format!("sqlite://{}", dst_path.display())).unwrap();
        let mut dst =
            crate::connect(&dst_url, &ConnectOptions::default(), None).expect("connect sqlite dst");
        dst.execute("PRAGMA foreign_keys = ON").unwrap();
        let opts = AllTablesOptions {
            include: vec![format!("ferrule_all_*_{pid}")],
            create_table: true,
            ..Default::default()
        };
        let copied = copy_all_tables(
            &mut src,
            Backend::Postgres,
            &mut dst,
            Backend::Sqlite,
            &opts,
        )
        .expect("copy_all_tables PG -> SQLite");
        assert_eq!(copied, 4, "2 parent rows + 2 child rows expected");
        let p = dst
            .query(&format!("SELECT count(*) FROM {parent}"))
            .expect("verify parent");
        let c = dst
            .query(&format!("SELECT count(*) FROM {child}"))
            .expect("verify child");
        assert!(matches!(&p.rows[0][0], Value::Int64(2)));
        assert!(matches!(&c.rows[0][0], Value::Int64(2)));
        let _ = src.execute(&format!("DROP TABLE {child}"));
        let _ = src.execute(&format!("DROP TABLE {parent}"));
        let _ = std::fs::remove_file(&dst_path);
    }

    /// Binary-format COPY between two Postgres tables preserves every column
    /// type exercised here, including an all-NULL row.
    #[test]
    fn test_postgres_binary_copy_round_trip_all_value_variants() {
        use crate::backend::Backend;
        use crate::copy::{BulkMode, CopyFormat, CopyOptions, CopySource, copy_rows};
        let (Some(mut src), Some(mut dst)) = (try_connect(), try_connect()) else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_binary_copy_round_trip_all_value_variants"
            );
            return;
        };
        let pid = std::process::id();
        let src_table = format!("ferrule_pg_bin_src_{pid}");
        let dst_table = format!("ferrule_pg_bin_dst_{pid}");
        let _ = src.execute(&format!("DROP TABLE IF EXISTS {src_table}"));
        let _ = dst.execute(&format!("DROP TABLE IF EXISTS {dst_table}"));
        let create = format!(
            "CREATE TABLE {src_table} (\
             b BOOLEAN, \
             i BIGINT, \
             f DOUBLE PRECISION, \
             n NUMERIC, \
             t TEXT, \
             by BYTEA, \
             d DATE, \
             tm TIME, \
             dt TIMESTAMP, \
             dttz TIMESTAMPTZ, \
             j JSONB, \
             u UUID\
             )"
        );
        src.execute(&create).expect("CREATE src");
        dst.execute(&create.replace(&src_table, &dst_table))
            .expect("CREATE dst");
        src.execute(&format!(
            "INSERT INTO {src_table} VALUES (\
             true, 42, 2.5, 99.5, 'hello', '\\xdeadbeef', \
             DATE '2024-05-14', TIME '12:34:56', \
             TIMESTAMP '2024-05-14 12:34:56', \
             TIMESTAMPTZ '2024-05-14 12:34:56+00', \
             '{{\"k\":\"v\"}}'::jsonb, \
             '00112233-4455-6677-8899-aabbccddeeff'::uuid\
             ), (\
             false, NULL, NULL, NULL, NULL, NULL, \
             NULL, NULL, NULL, NULL, NULL, NULL\
             )"
        ))
        .expect("seed src");
        let opts = CopyOptions {
            source: CopySource::Query {
                sql: format!("SELECT * FROM {src_table} ORDER BY i NULLS LAST"),
                into: dst_table.clone(),
            },
            bulk_mode: BulkMode::On,
            copy_format: CopyFormat::Binary,
            ..Default::default()
        };
        let copied = copy_rows(
            &mut src,
            Backend::Postgres,
            &mut dst,
            Backend::Postgres,
            &opts,
        )
        .expect("copy_rows binary COPY");
        assert_eq!(copied, 2);
        // Most columns are cast to text so the comparison is independent of how
        // the reader decodes each type.
        let out = dst
            .query(&format!(
                "SELECT b, i, f, n::text, t, by, d::text, tm::text, dt::text, \
                 dttz::text, j::text, u::text \
                 FROM {dst_table} ORDER BY i NULLS LAST"
            ))
            .expect("read back");
        assert_eq!(out.rows.len(), 2);
        let r0 = &out.rows[0];
        assert!(matches!(&r0[0], Value::Bool(true)));
        assert!(matches!(&r0[1], Value::Int64(42)));
        match &r0[2] {
            Value::Float64(f) => assert!((f - 2.5).abs() < 1e-9),
            other => panic!("expected Float64(2.5), got {other:?}"),
        }
        match &r0[3] {
            Value::String(s) => assert_eq!(s, "99.5"),
            other => panic!("expected NUMERIC text 99.5, got {other:?}"),
        }
        assert!(matches!(&r0[4], Value::String(s) if s == "hello"));
        assert!(matches!(&r0[5], Value::Bytes(b) if b == &vec![0xde, 0xad, 0xbe, 0xef]));
        // The date/time text forms below assume the server's default ISO DateStyle.
        assert!(
            matches!(&r0[6], Value::String(s) if s == "2024-05-14"),
            "DATE mismatch, got {:?}",
            r0[6]
        );
        assert!(
            matches!(&r0[7], Value::String(s) if s == "12:34:56"),
            "TIME mismatch, got {:?}",
            r0[7]
        );
        assert!(
            matches!(&r0[8], Value::String(s) if s == "2024-05-14 12:34:56"),
            "TIMESTAMP mismatch, got {:?}",
            r0[8]
        );
        // jsonb's canonical text output puts a space after the colon.
        assert!(
            matches!(&r0[10], Value::String(s) if s == r#"{"k": "v"}"#),
            "JSONB mismatch, got {:?}",
            r0[10]
        );
        assert!(matches!(&r0[11], Value::String(s) if s == "00112233-4455-6677-8899-aabbccddeeff"));
        let r1 = &out.rows[1];
        assert!(matches!(&r1[0], Value::Bool(false)));
        for col in &r1[1..] {
            assert!(matches!(col, Value::Null), "expected NULL, got {col:?}");
        }
        let _ = src.execute(&format!("DROP TABLE {src_table}"));
        let _ = dst.execute(&format!("DROP TABLE {dst_table}"));
    }

    /// `query_cursor` yields every row in batches no larger than the requested size.
    #[test]
    fn test_postgres_cursor_streams_in_bounded_batches() {
        let Some(mut conn) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_cursor_streams_in_bounded_batches"
            );
            return;
        };
        const TOTAL: i64 = 50_000;
        const BATCH: usize = 256;
        let sql = format!("SELECT i, i * 2 AS doubled FROM generate_series(1, {TOTAL}) AS g(i)");
        let mut cursor = conn.query_cursor(&sql).expect("open pg cursor");
        assert_eq!(cursor.columns().len(), 2);
        let mut total = 0u64;
        let mut batches = 0u64;
        loop {
            let batch = cursor.next_batch(BATCH).expect("pull pg batch");
            if batch.is_empty() {
                break;
            }
            assert!(batch.len() <= BATCH);
            total += batch.len() as u64;
            batches += 1;
        }
        assert_eq!(total, TOTAL as u64);
        assert_eq!(batches, (TOTAL as u64).div_ceil(BATCH as u64));
    }

    /// `write_rows` writes every row across multiple batches.
    #[test]
    fn test_postgres_write_rows_round_trip() {
        let Some(mut conn) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_write_rows_round_trip"
            );
            return;
        };
        // Process-id suffix, like the other tests, so concurrent runs don't collide.
        let table = format!("ferrule_write_test_{}", std::process::id());
        let _ = conn.execute(&format!("DROP TABLE IF EXISTS {table}"));
        conn.execute(&format!(
            "CREATE TABLE {table} (id INT PRIMARY KEY, name TEXT)"
        ))
        .expect("create write table");
        let columns = vec![
            crate::value::ColumnInfo {
                name: "id".into(),
                type_hint: TypeHint::Int64,
                nullable: false,
            },
            crate::value::ColumnInfo {
                name: "name".into(),
                type_hint: TypeHint::String,
                nullable: true,
            },
        ];
        let rows: Vec<crate::value::Row> = (1..=3000)
            .map(|i| vec![Value::Int64(i), Value::String(format!("n{i}"))])
            .collect();
        let opts = crate::write::WriteOptions {
            batch_size: 500,
            ..Default::default()
        };
        let report = crate::write::write_rows(
            &mut *conn,
            crate::Backend::Postgres,
            &table,
            &columns,
            rows,
            &opts,
        )
        .expect("write_rows");
        assert_eq!(report.rows_written, 3000);
        assert!(report.is_complete());
        let back = conn
            .query(&format!("SELECT COUNT(*) FROM {table}"))
            .expect("count");
        assert!(matches!(back.rows[0][0], Value::Int64(3000)));
        let _ = conn.execute(&format!("DROP TABLE {table}"));
    }

    /// A primary-key conflict rejects only the batch that contains it; earlier
    /// batches stay written and the rejected batch index is reported.
    #[test]
    fn test_postgres_write_rows_partial_failure() {
        let Some(mut conn) = try_connect() else {
            eprintln!(
                "Postgres test container not available, skipping test_postgres_write_rows_partial_failure"
            );
            return;
        };
        // Process-id suffix, like the other tests, so concurrent runs don't collide.
        let table = format!("ferrule_write_pf_{}", std::process::id());
        let _ = conn.execute(&format!("DROP TABLE IF EXISTS {table}"));
        conn.execute(&format!("CREATE TABLE {table} (id INT PRIMARY KEY)"))
            .expect("create");
        // Pre-existing id 5 makes the second batch (ids 5..=8) conflict.
        conn.execute(&format!("INSERT INTO {table} VALUES (5)"))
            .expect("seed");
        let columns = vec![crate::value::ColumnInfo {
            name: "id".into(),
            type_hint: TypeHint::Int64,
            nullable: false,
        }];
        let rows: Vec<crate::value::Row> = (1..=8).map(|i| vec![Value::Int64(i)]).collect();
        let opts = crate::write::WriteOptions {
            batch_size: 4,
            ..Default::default()
        };
        let report = crate::write::write_rows(
            &mut *conn,
            crate::Backend::Postgres,
            &table,
            &columns,
            rows,
            &opts,
        )
        .expect("write_rows");
        assert_eq!(report.rows_written, 4);
        assert_eq!(report.rejected_batches.len(), 1);
        assert_eq!(report.rejected_batches[0].batch_index, 1);
        let _ = conn.execute(&format!("DROP TABLE {table}"));
    }
}