use ciborium::Value as CborValue;
use sqlx::postgres::PgRow;
use sqlx::{Column, Row, TypeInfo};
use vantage_types::Record;
use super::types::{AnyPostgresType, PostgresTypeVariants};
pub(crate) fn bind_postgres_value<'q>(
query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
value: &'q AnyPostgresType,
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
let cbor = value.value();
match value.type_variant() {
Some(PostgresTypeVariants::Null) => query.bind(None::<String>),
None => bind_by_cbor(query, cbor),
Some(PostgresTypeVariants::Bool) => match cbor {
CborValue::Null => query.bind(None::<bool>),
CborValue::Bool(b) => query.bind(*b),
CborValue::Integer(i) => match i64::try_from(*i) {
Ok(n) => query.bind(n != 0),
Err(_) => query.bind(None::<bool>),
},
_ => query.bind(None::<bool>),
},
Some(PostgresTypeVariants::Int2) => match cbor {
CborValue::Null => query.bind(None::<i16>),
CborValue::Integer(i) => {
query.bind(i64::try_from(*i).ok().and_then(|n| i16::try_from(n).ok()))
}
_ => query.bind(None::<i16>),
},
Some(PostgresTypeVariants::Int4) => match cbor {
CborValue::Null => query.bind(None::<i32>),
CborValue::Integer(i) => {
query.bind(i64::try_from(*i).ok().and_then(|n| i32::try_from(n).ok()))
}
_ => query.bind(None::<i32>),
},
Some(PostgresTypeVariants::Int8) => match cbor {
CborValue::Null => query.bind(None::<i64>),
CborValue::Integer(i) => query.bind(i64::try_from(*i).ok()),
_ => query.bind(None::<i64>),
},
Some(PostgresTypeVariants::Float4) => match cbor {
CborValue::Null => query.bind(None::<f32>),
CborValue::Float(f) => query.bind(*f as f32),
CborValue::Integer(i) => query.bind(i64::try_from(*i).ok().map(|n| n as f32)),
_ => query.bind(None::<f32>),
},
Some(PostgresTypeVariants::Float8) => match cbor {
CborValue::Null => query.bind(None::<f64>),
CborValue::Float(f) => query.bind(*f),
CborValue::Integer(i) => query.bind(i64::try_from(*i).ok().map(|n| n as f64)),
_ => query.bind(None::<f64>),
},
Some(PostgresTypeVariants::Text) => match cbor {
CborValue::Null => query.bind(None::<String>),
CborValue::Text(s) => query.bind(s.as_str()),
CborValue::Tag(_, inner) => {
if let CborValue::Text(s) = inner.as_ref() {
query.bind(s.as_str())
} else {
query.bind(None::<String>)
}
}
_ => query.bind(None::<String>),
},
Some(PostgresTypeVariants::Decimal) => {
let s = match cbor {
CborValue::Null => return query.bind(None::<rust_decimal::Decimal>),
CborValue::Tag(10, inner) => match inner.as_ref() {
CborValue::Text(s) => s.as_str(),
_ => return query.bind(None::<rust_decimal::Decimal>),
},
CborValue::Text(s) => s.as_str(),
_ => return query.bind(None::<rust_decimal::Decimal>),
};
match s.parse::<rust_decimal::Decimal>() {
Ok(d) => query.bind(d),
Err(_) => query.bind(None::<rust_decimal::Decimal>),
}
}
Some(PostgresTypeVariants::DateTime) => {
let s = match cbor {
CborValue::Null => return query.bind(None::<chrono::NaiveDateTime>),
CborValue::Tag(0, inner) => match inner.as_ref() {
CborValue::Text(s) => s.clone(),
_ => return query.bind(None::<chrono::NaiveDateTime>),
},
CborValue::Text(s) => s.clone(),
_ => return query.bind(None::<chrono::NaiveDateTime>),
};
if let Ok(dt) = chrono::DateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%#z") {
query.bind(dt.with_timezone(&chrono::Utc))
} else if let Ok(dt) = chrono::DateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f%#z") {
query.bind(dt.with_timezone(&chrono::Utc))
} else if let Ok(dt) = s.parse::<chrono::DateTime<chrono::Utc>>() {
query.bind(dt)
} else if let Ok(ndt) = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S")
.or_else(|_| chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f"))
.or_else(|_| chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S"))
.or_else(|_| chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S%.f"))
{
query.bind(ndt)
} else {
query.bind(None::<chrono::NaiveDateTime>)
}
}
Some(PostgresTypeVariants::Date) => {
let s = match cbor {
CborValue::Null => return query.bind(None::<chrono::NaiveDate>),
CborValue::Tag(100, inner) => match inner.as_ref() {
CborValue::Text(s) => s.clone(),
_ => return query.bind(None::<chrono::NaiveDate>),
},
CborValue::Text(s) => s.clone(),
_ => return query.bind(None::<chrono::NaiveDate>),
};
match chrono::NaiveDate::parse_from_str(&s, "%Y-%m-%d") {
Ok(d) => query.bind(d),
Err(_) => query.bind(None::<chrono::NaiveDate>),
}
}
Some(PostgresTypeVariants::Time) => {
let s = match cbor {
CborValue::Null => return query.bind(None::<chrono::NaiveTime>),
CborValue::Tag(101, inner) => match inner.as_ref() {
CborValue::Text(s) => s.clone(),
_ => return query.bind(None::<chrono::NaiveTime>),
},
CborValue::Text(s) => s.clone(),
_ => return query.bind(None::<chrono::NaiveTime>),
};
match chrono::NaiveTime::parse_from_str(&s, "%H:%M:%S")
.or_else(|_| chrono::NaiveTime::parse_from_str(&s, "%H:%M:%S%.f"))
{
Ok(t) => query.bind(t),
Err(_) => query.bind(None::<chrono::NaiveTime>),
}
}
Some(PostgresTypeVariants::Uuid) => match cbor {
CborValue::Null => query.bind(None::<String>),
CborValue::Tag(9, inner) => {
if let CborValue::Text(s) = inner.as_ref() {
query.bind(s.as_str())
} else {
query.bind(None::<String>)
}
}
CborValue::Text(s) => query.bind(s.as_str()),
_ => query.bind(None::<String>),
},
Some(PostgresTypeVariants::Blob) => match cbor {
CborValue::Null => query.bind(None::<Vec<u8>>),
CborValue::Bytes(b) => query.bind(b.as_slice()),
CborValue::Text(s) => query.bind(s.as_bytes()),
_ => query.bind(None::<Vec<u8>>),
},
}
}
fn bind_by_cbor<'q>(
query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
cbor: &'q CborValue,
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
match cbor {
CborValue::Null => query.bind(None::<String>),
CborValue::Bool(b) => query.bind(*b),
CborValue::Integer(i) => {
if let Ok(n) = i64::try_from(*i) {
query.bind(n)
} else {
query.bind(i128::from(*i).to_string())
}
}
CborValue::Float(f) => query.bind(*f),
CborValue::Text(s) => query.bind(s.as_str()),
CborValue::Bytes(b) => query.bind(b.as_slice()),
CborValue::Tag(10, inner) => {
if let CborValue::Text(s) = inner.as_ref() {
query.bind(s.as_str())
} else {
query.bind(None::<String>)
}
}
CborValue::Tag(0 | 100 | 101, inner) => {
if let CborValue::Text(s) = inner.as_ref() {
query.bind(s.as_str())
} else {
query.bind(None::<String>)
}
}
CborValue::Tag(9, inner) => {
if let CborValue::Text(s) = inner.as_ref() {
query.bind(s.as_str())
} else {
query.bind(None::<String>)
}
}
_ => query.bind(None::<String>),
}
}
/// Builds a [`Record`] from one PostgreSQL result row, keyed by column name.
///
/// Each column is decoded with [`pg_column_to_cbor`]. When that yields a
/// [`PostgresTypeVariants`] hint, the value keeps it; otherwise the value is stored
/// untyped.
pub(crate) fn row_to_record(row: &PgRow) -> Record<AnyPostgresType> {
    row.columns().iter().fold(Record::new(), |mut record, column| {
        let decoded = pg_column_to_cbor(row, column.ordinal(), column.type_info().name());
        let value = match decoded {
            (cbor, Some(kind)) => AnyPostgresType::with_variant(cbor, kind),
            (cbor, None) => AnyPostgresType::untyped(cbor),
        };
        record.insert(column.name().to_string(), value);
        record
    })
}
/// Decodes column `ordinal` of `row` into a CBOR value. When the column type is
/// recognised, it also returns the [`PostgresTypeVariants`] hint to attach.
///
/// Decoding runs in three stages:
/// 1. SQL `NULL`, or a column whose raw value cannot be read, yields `(Null, None)`.
/// 2. Known type names (as reported by `TypeInfo::name`) are decoded with the matching
///    Rust type. Decimal, UUID, datetime, date and time values are stored as text inside
///    CBOR tags 10, 9, 0, 100 and 101 respectively. JSON/JSONB is converted with
///    `crate::types::json_to_cbor` and left without a hint.
/// 3. If stage 2 produced nothing, generic decodes are tried in order: `bool`, `i64`,
///    `i32`, `f64`, `String`. If all of them fail, a diagnostic is written to stderr
///    and `(Null, None)` is returned.
fn pg_column_to_cbor(
    row: &PgRow,
    ordinal: usize,
    type_name: &str,
) -> (CborValue, Option<PostgresTypeVariants>) {
    use sqlx::ValueRef;

    // A raw read that fails is treated the same as SQL NULL.
    let is_null = match row.try_get_raw(ordinal) {
        Ok(raw) => raw.is_null(),
        Err(_) => true,
    };
    if is_null {
        return (CborValue::Null, None);
    }

    // Small constructors shared by the arms below.
    let int = |n: i64| CborValue::Integer(n.into());
    let tagged_text =
        |tag: u64, text: String| CborValue::Tag(tag, Box::new(CborValue::Text(text)));

    // Stage 2: decode according to the reported type name.
    let by_name = match type_name {
        "BOOL" => row
            .try_get::<bool, _>(ordinal)
            .ok()
            .map(|b| (CborValue::Bool(b), Some(PostgresTypeVariants::Bool))),
        "INT2" | "SMALLINT" | "SMALLSERIAL" => row
            .try_get::<i16, _>(ordinal)
            .ok()
            .map(|n| (int(i64::from(n)), Some(PostgresTypeVariants::Int2))),
        "INT4" | "INT" | "INTEGER" | "SERIAL" => row
            .try_get::<i32, _>(ordinal)
            .ok()
            .map(|n| (int(i64::from(n)), Some(PostgresTypeVariants::Int4))),
        "INT8" | "BIGINT" | "BIGSERIAL" => row
            .try_get::<i64, _>(ordinal)
            .ok()
            .map(|n| (int(n), Some(PostgresTypeVariants::Int8))),
        "FLOAT4" | "REAL" => row
            .try_get::<f32, _>(ordinal)
            .ok()
            .map(|x| (CborValue::Float(f64::from(x)), Some(PostgresTypeVariants::Float4))),
        "FLOAT8" | "DOUBLE PRECISION" => row
            .try_get::<f64, _>(ordinal)
            .ok()
            .map(|x| (CborValue::Float(x), Some(PostgresTypeVariants::Float8))),
        "NUMERIC" | "DECIMAL" => row
            .try_get::<rust_decimal::Decimal, _>(ordinal)
            .ok()
            .map(|d| (tagged_text(10, d.to_string()), Some(PostgresTypeVariants::Decimal))),
        "_TEXT" | "TEXT[]" => row.try_get::<Vec<String>, _>(ordinal).ok().map(|items| {
            let array = items.into_iter().map(CborValue::Text).collect();
            (CborValue::Array(array), Some(PostgresTypeVariants::Text))
        }),
        "_INT4" | "INT4[]" | "INTEGER[]" => {
            row.try_get::<Vec<i32>, _>(ordinal).ok().map(|items| {
                let array = items.into_iter().map(|n| int(i64::from(n))).collect();
                (CborValue::Array(array), Some(PostgresTypeVariants::Int4))
            })
        }
        "UUID" => row
            .try_get::<uuid::Uuid, _>(ordinal)
            .ok()
            .map(|id| (tagged_text(9, id.to_string()), Some(PostgresTypeVariants::Uuid))),
        "TIME" | "TIME WITHOUT TIME ZONE" => {
            row.try_get::<chrono::NaiveTime, _>(ordinal).ok().map(|t| {
                let text = t.format("%H:%M:%S%.f").to_string();
                (tagged_text(101, text), Some(PostgresTypeVariants::Time))
            })
        }
        "DATE" => row.try_get::<chrono::NaiveDate, _>(ordinal).ok().map(|d| {
            let text = d.format("%Y-%m-%d").to_string();
            (tagged_text(100, text), Some(PostgresTypeVariants::Date))
        }),
        "TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => row
            .try_get::<chrono::DateTime<chrono::Utc>, _>(ordinal)
            .ok()
            .map(|ts| {
                let text = ts.format("%Y-%m-%d %H:%M:%S%.f+00").to_string();
                (tagged_text(0, text), Some(PostgresTypeVariants::DateTime))
            }),
        "TIMESTAMP" | "TIMESTAMP WITHOUT TIME ZONE" => row
            .try_get::<chrono::NaiveDateTime, _>(ordinal)
            .ok()
            .map(|ts| {
                let text = ts.format("%Y-%m-%d %H:%M:%S%.f").to_string();
                (tagged_text(0, text), Some(PostgresTypeVariants::DateTime))
            }),
        "JSONB" | "JSON" => row
            .try_get::<serde_json::Value, _>(ordinal)
            .ok()
            .map(|json| (crate::types::json_to_cbor(json), None)),
        "BYTEA" => row
            .try_get::<Vec<u8>, _>(ordinal)
            .ok()
            .map(|bytes| (CborValue::Bytes(bytes), Some(PostgresTypeVariants::Blob))),
        _ => None,
    };
    if let Some(decoded) = by_name {
        return decoded;
    }

    // Stage 3: generic fallbacks, each tried only if the previous one failed.
    let fallback = row
        .try_get::<bool, _>(ordinal)
        .map(|b| (CborValue::Bool(b), Some(PostgresTypeVariants::Bool)))
        .or_else(|_| {
            row.try_get::<i64, _>(ordinal)
                .map(|n| (int(n), Some(PostgresTypeVariants::Int8)))
        })
        .or_else(|_| {
            row.try_get::<i32, _>(ordinal)
                .map(|n| (int(i64::from(n)), Some(PostgresTypeVariants::Int4)))
        })
        .or_else(|_| {
            row.try_get::<f64, _>(ordinal)
                .map(|x| (CborValue::Float(x), Some(PostgresTypeVariants::Float8)))
        })
        .or_else(|_| row.try_get::<String, _>(ordinal).map(|s| (CborValue::Text(s), None)));

    fallback.unwrap_or_else(|_| {
        eprintln!(
            "vantage: failed to decode PostgreSQL column '{}' (type '{}') — returning NULL",
            row.columns()[ordinal].name(),
            type_name,
        );
        (CborValue::Null, None)
    })
}