Skip to main content

tl_data/
pg.rs

1use datafusion::arrow::array::RecordBatch;
2use datafusion::arrow::array::*;
3use datafusion::arrow::datatypes::{DataType, Field, Schema};
4use postgres::{Client, NoTls};
5use std::sync::Arc;
6
7use crate::engine::DataEngine;
8
9/// Map PostgreSQL type names to Arrow DataTypes.
10fn pg_type_to_arrow(pg_type: &postgres::types::Type) -> DataType {
11    match *pg_type {
12        postgres::types::Type::BOOL => DataType::Boolean,
13        postgres::types::Type::INT2 => DataType::Int16,
14        postgres::types::Type::INT4 => DataType::Int32,
15        postgres::types::Type::INT8 => DataType::Int64,
16        postgres::types::Type::FLOAT4 => DataType::Float32,
17        postgres::types::Type::FLOAT8 => DataType::Float64,
18        postgres::types::Type::TEXT
19        | postgres::types::Type::VARCHAR
20        | postgres::types::Type::BPCHAR => DataType::Utf8,
21        _ => DataType::Utf8, // fallback: convert to string
22    }
23}
24
25impl DataEngine {
26    /// Read a PostgreSQL table into a DataFusion DataFrame.
27    pub fn read_postgres(
28        &self,
29        conn_str: &str,
30        table_name: &str,
31    ) -> Result<datafusion::prelude::DataFrame, String> {
32        let mut client = Client::connect(conn_str, NoTls)
33            .map_err(|e| format!("PostgreSQL connection error: {e}"))?;
34
35        let query = format!("SELECT * FROM {table_name}");
36        let rows = client
37            .query(&query, &[])
38            .map_err(|e| format!("PostgreSQL query error: {e}"))?;
39
40        if rows.is_empty() {
41            return Err(format!("Table '{table_name}' is empty or does not exist"));
42        }
43
44        let columns = rows[0].columns();
45        let fields: Vec<Field> = columns
46            .iter()
47            .map(|col| Field::new(col.name(), pg_type_to_arrow(col.type_()), true))
48            .collect();
49        let schema = Arc::new(Schema::new(fields));
50
51        let mut arrays: Vec<Arc<dyn Array>> = Vec::new();
52        for (col_idx, col) in columns.iter().enumerate() {
53            let arrow_type = pg_type_to_arrow(col.type_());
54            let array: Arc<dyn Array> = match arrow_type {
55                DataType::Boolean => {
56                    let values: Vec<Option<bool>> = rows.iter().map(|r| r.get(col_idx)).collect();
57                    Arc::new(BooleanArray::from(values))
58                }
59                DataType::Int16 => {
60                    let values: Vec<Option<i16>> = rows.iter().map(|r| r.get(col_idx)).collect();
61                    Arc::new(Int16Array::from(values))
62                }
63                DataType::Int32 => {
64                    let values: Vec<Option<i32>> = rows.iter().map(|r| r.get(col_idx)).collect();
65                    Arc::new(Int32Array::from(values))
66                }
67                DataType::Int64 => {
68                    let values: Vec<Option<i64>> = rows.iter().map(|r| r.get(col_idx)).collect();
69                    Arc::new(Int64Array::from(values))
70                }
71                DataType::Float32 => {
72                    let values: Vec<Option<f32>> = rows.iter().map(|r| r.get(col_idx)).collect();
73                    Arc::new(Float32Array::from(values))
74                }
75                DataType::Float64 => {
76                    let values: Vec<Option<f64>> = rows.iter().map(|r| r.get(col_idx)).collect();
77                    Arc::new(Float64Array::from(values))
78                }
79                _ => {
80                    let values: Vec<Option<String>> = rows
81                        .iter()
82                        .map(|r| r.try_get::<_, String>(col_idx).ok())
83                        .collect();
84                    Arc::new(StringArray::from(values))
85                }
86            };
87            arrays.push(array);
88        }
89
90        let batch = RecordBatch::try_new(schema, arrays)
91            .map_err(|e| format!("Arrow RecordBatch creation error: {e}"))?;
92
93        self.register_batch(table_name, batch)?;
94
95        self.rt
96            .block_on(self.ctx.table(table_name))
97            .map_err(|e| format!("Table reference error: {e}"))
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: load the `users` table from a local database and check
    /// that collecting the resulting DataFrame yields at least one batch.
    #[test]
    #[ignore] // Requires a running PostgreSQL instance
    fn test_read_postgres() {
        let conn_str = "host=localhost user=postgres password=postgres dbname=testdb";
        let table = "users";

        let engine = DataEngine::new();
        let frame = engine.read_postgres(conn_str, table).unwrap();
        let collected = engine.collect(frame).unwrap();

        assert!(!collected.is_empty());
    }
}