use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
use polars::prelude::*;
use postgres::types::Type;
use postgres::{Client, NoTls, Row};
use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use std::collections::HashSet;
use taxa_core::error::{Error, Result};
use taxa_core::source::Source;
/// A [`Source`] whose data is the materialized result of one Postgres query.
///
/// The frame is loaded when the value is built via `connect` and replaced
/// wholesale by `reload`; the trait accessors only read the cached frame.
pub struct SqlSource {
/// Result set from the most recent successful ingest.
df: DataFrame,
/// Connection string handed to `Client::connect` on every (re)load.
dsn: String,
/// SQL text that is prepared and executed on every (re)load.
query: String,
}
impl SqlSource {
pub fn connect(dsn: &str, query: &str) -> Result<Self> {
let df = Self::ingest(dsn, query)?;
Ok(SqlSource {
df,
dsn: dsn.to_string(),
query: query.to_string(),
})
}
pub fn reload(&mut self) -> Result<()> {
self.df = Self::ingest(&self.dsn, &self.query)?;
Ok(())
}
pub fn height(&self) -> usize {
self.df.height()
}
fn ingest(dsn: &str, query: &str) -> Result<DataFrame> {
let mut client = Client::connect(dsn, NoTls)
.map_err(|e| Error::Engine(format!("postgres connect failed ({dsn}): {e}")))?;
let stmt = client
.prepare(query)
.map_err(|e| Error::Engine(format!("postgres prepare failed: {e}")))?;
let schema: Vec<(String, Type)> = stmt
.columns()
.iter()
.map(|c| (c.name().to_string(), c.type_().clone()))
.collect();
let rows = client
.query(&stmt, &[])
.map_err(|e| Error::Engine(format!("postgres query failed: {e}")))?;
build_dataframe(&schema, &rows)
}
}
impl Source for SqlSource {
    /// Returns a lazy frame built from a clone of the cached `DataFrame`.
    fn frame(&self) -> Result<LazyFrame> {
        let snapshot = self.df.clone();
        Ok(snapshot.lazy())
    }

    /// Returns the cached frame's column names as a set of owned strings.
    fn columns(&self) -> Result<HashSet<String>> {
        let mut names = HashSet::new();
        for column in self.df.get_column_names() {
            names.insert(column.to_string());
        }
        Ok(names)
    }

    /// Returns `(column name, dtype)` pairs, with the dtype rendered through
    /// its `Display` impl, in the order `self.df.iter()` yields the columns.
    fn schema(&self) -> Result<Vec<(String, String)>> {
        let mut pairs = Vec::new();
        for series in self.df.iter() {
            let label = series.name().to_string();
            let dtype = series.dtype().to_string();
            pairs.push((label, dtype));
        }
        Ok(pairs)
    }
}
fn build_dataframe(schema: &[(String, Type)], rows: &[Row]) -> Result<DataFrame> {
let mut series: Vec<Column> = Vec::with_capacity(schema.len());
for (idx, (name, ty)) in schema.iter().enumerate() {
series.push(decode_column(rows, idx, name, ty)?.into_column());
}
DataFrame::new(series).map_err(|e| Error::Engine(format!("dataframe build failed: {e}")))
}
fn decode_column(rows: &[Row], idx: usize, name: &str, ty: &Type) -> Result<Series> {
let pl_name: PlSmallStr = name.into();
let n = rows.len();
macro_rules! collect {
($t:ty) => {{
let mut v: Vec<Option<$t>> = Vec::with_capacity(n);
for row in rows {
v.push(row.get::<usize, Option<$t>>(idx));
}
v
}};
}
let series = match *ty {
Type::BOOL => Series::new(pl_name, collect!(bool)),
Type::INT2 => Series::new(pl_name, collect!(i16)),
Type::INT4 => Series::new(pl_name, collect!(i32)),
Type::INT8 => Series::new(pl_name, collect!(i64)),
Type::FLOAT4 => Series::new(pl_name, collect!(f32)),
Type::FLOAT8 => Series::new(pl_name, collect!(f64)),
Type::NUMERIC => {
let mut v: Vec<Option<f64>> = Vec::with_capacity(n);
for row in rows {
let dec: Option<Decimal> = row.get(idx);
v.push(dec.and_then(|d| d.to_f64()));
}
Series::new(pl_name, v)
}
Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::NAME => {
Series::new(pl_name, collect!(String))
}
Type::DATE => {
let mut v: Vec<Option<String>> = Vec::with_capacity(n);
for row in rows {
let d: Option<NaiveDate> = row.get(idx);
v.push(d.map(|d| d.format("%Y-%m-%d").to_string()));
}
Series::new(pl_name, v)
}
Type::TIMESTAMP => {
let mut v: Vec<Option<String>> = Vec::with_capacity(n);
for row in rows {
let t: Option<NaiveDateTime> = row.get(idx);
v.push(t.map(|t| t.format("%Y-%m-%dT%H:%M:%S%.f").to_string()));
}
Series::new(pl_name, v)
}
Type::TIMESTAMPTZ => {
let mut v: Vec<Option<String>> = Vec::with_capacity(n);
for row in rows {
let t: Option<DateTime<Utc>> = row.get(idx);
v.push(t.map(|t| t.to_rfc3339()));
}
Series::new(pl_name, v)
}
ref other => {
return Err(Error::Engine(format!(
"taxa-sql: column '{name}' has unsupported Postgres type '{other}'. \
Cast it in the SQL query (e.g. `{name}::text` or `{name}::float8`)."
)));
}
};
Ok(series)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `schema` against a result set that contains no rows.
    fn frame_without_rows(schema: &[(String, Type)]) -> Result<DataFrame> {
        let no_rows: Vec<Row> = Vec::new();
        build_dataframe(schema, &no_rows)
    }

    /// A query returning zero rows must still yield every column, typed.
    #[test]
    fn empty_rows_preserve_schema() {
        let col = |name: &str, ty: Type| (name.to_string(), ty);
        let schema = vec![
            col("symbol", Type::TEXT),
            col("dt", Type::DATE),
            col("mcap", Type::FLOAT8),
            col("n", Type::INT8),
        ];

        let df = frame_without_rows(&schema).expect("typed empty frame");
        assert_eq!(df.shape(), (0, 4), "0 rows but all 4 columns preserved");

        let names: Vec<&str> = df.get_column_names().iter().map(|s| s.as_str()).collect();
        assert_eq!(names, ["symbol", "dt", "mcap", "n"]);

        let expected_dtypes = [
            ("symbol", DataType::String),
            ("dt", DataType::String),
            ("mcap", DataType::Float64),
            ("n", DataType::Int64),
        ];
        for (column, dtype) in expected_dtypes.iter() {
            assert_eq!(df.column(*column).unwrap().dtype(), dtype);
        }
    }

    /// Pins the Postgres type → Polars dtype mapping, plus the error raised
    /// for a type outside that mapping.
    #[test]
    fn postgres_to_polars_dtype_contract() {
        let contract = [
            (Type::BOOL, DataType::Boolean),
            (Type::INT2, DataType::Int16),
            (Type::INT4, DataType::Int32),
            (Type::INT8, DataType::Int64),
            (Type::FLOAT4, DataType::Float32),
            (Type::FLOAT8, DataType::Float64),
            (Type::NUMERIC, DataType::Float64),
            (Type::TEXT, DataType::String),
            (Type::VARCHAR, DataType::String),
            (Type::BPCHAR, DataType::String),
            (Type::NAME, DataType::String),
            (Type::DATE, DataType::String),
            (Type::TIMESTAMP, DataType::String),
            (Type::TIMESTAMPTZ, DataType::String),
        ];

        for (pg, expected) in contract.iter() {
            let schema = vec![("c".to_string(), pg.clone())];
            let df = frame_without_rows(&schema).expect("typed empty frame");
            let actual = df.column("c").unwrap().dtype();
            assert_eq!(actual, expected, "Postgres {pg:?} must map to {expected:?}");
        }

        let unsupported = vec![("c".to_string(), Type::JSON)];
        let err = frame_without_rows(&unsupported).unwrap_err().to_string();
        assert!(
            err.contains("unsupported"),
            "unsupported type errors clearly: {err}"
        );
    }
}