// tl_data/pg.rs — PostgreSQL reader: streams query results (cursor-based, with a
// simple-query fallback) into Arrow RecordBatches registered with DataFusion.
1use datafusion::arrow::array::RecordBatch;
2use datafusion::arrow::array::*;
3use datafusion::arrow::datatypes::{DataType, Field, Schema};
4use postgres::{Client, NoTls};
5use std::sync::Arc;
6
7use crate::engine::DataEngine;
8
9/// Extract detailed error message from a postgres::Error.
10/// The Display impl only shows "db error" — this extracts the actual
11/// PostgreSQL error message, detail, hint, and SQLSTATE code.
12fn pg_error_detail(e: &postgres::Error) -> String {
13    if let Some(db_err) = e.as_db_error() {
14        let mut msg = format!("{}: {}", db_err.severity(), db_err.message());
15        if let Some(detail) = db_err.detail() {
16            msg.push_str(&format!(" DETAIL: {detail}"));
17        }
18        if let Some(hint) = db_err.hint() {
19            msg.push_str(&format!(" HINT: {hint}"));
20        }
21        if let Some(where_) = db_err.where_() {
22            msg.push_str(&format!(" WHERE: {where_}"));
23        }
24        msg.push_str(&format!(" (SQLSTATE {})", db_err.code().code()));
25        msg
26    } else {
27        format!("{e}")
28    }
29}
30
31/// Connect to PostgreSQL, trying TLS first then falling back to NoTls.
32fn pg_connect(conn_str: &str) -> Result<Client, String> {
33    // Try TLS first (required by most cloud/remote PostgreSQL)
34    if let Ok(tls_connector) = native_tls::TlsConnector::builder()
35        .danger_accept_invalid_certs(true)
36        .build()
37    {
38        let connector = postgres_native_tls::MakeTlsConnector::new(tls_connector);
39        if let Ok(client) = Client::connect(conn_str, connector) {
40            return Ok(client);
41        }
42    }
43    // Fall back to NoTls (local connections, Unix sockets)
44    Client::connect(conn_str, NoTls)
45        .map_err(|e| format!("PostgreSQL connection error: {}", pg_error_detail(&e)))
46}
47
/// Cursor FETCH size: rows transferred per network round trip.
/// 1M rows per fetch = ~50 round trips for a 50M row table.
/// Larger values cut round trips but raise peak memory per fetch.
const PG_CURSOR_FETCH_SIZE: usize = 1_000_000;

/// RecordBatch size handed to DataFusion (local re-chunking of an already
/// fetched chunk — no extra network traffic). Smaller batches give
/// DataFusion more independent units of parallelism.
const PG_BATCH_SIZE: usize = 100_000;
54
55/// Map PostgreSQL type names to Arrow DataTypes.
56fn pg_type_to_arrow(pg_type: &postgres::types::Type) -> DataType {
57    match *pg_type {
58        postgres::types::Type::BOOL => DataType::Boolean,
59        postgres::types::Type::INT2 => DataType::Int16,
60        postgres::types::Type::INT4 => DataType::Int32,
61        postgres::types::Type::INT8 => DataType::Int64,
62        postgres::types::Type::FLOAT4 => DataType::Float32,
63        postgres::types::Type::FLOAT8 => DataType::Float64,
64        postgres::types::Type::TEXT
65        | postgres::types::Type::VARCHAR
66        | postgres::types::Type::BPCHAR => DataType::Utf8,
67        _ => DataType::Utf8, // fallback: convert to string
68    }
69}
70
71/// Build a RecordBatch from a slice of postgres Rows using the given schema.
72fn build_record_batch(rows: &[postgres::Row], schema: &Arc<Schema>) -> Result<RecordBatch, String> {
73    let mut arrays: Vec<Arc<dyn Array>> = Vec::new();
74    for (col_idx, field) in schema.fields().iter().enumerate() {
75        let array: Arc<dyn Array> = match field.data_type() {
76            DataType::Boolean => {
77                let values: Vec<Option<bool>> = rows.iter().map(|r| r.get(col_idx)).collect();
78                Arc::new(BooleanArray::from(values))
79            }
80            DataType::Int16 => {
81                let values: Vec<Option<i16>> = rows.iter().map(|r| r.get(col_idx)).collect();
82                Arc::new(Int16Array::from(values))
83            }
84            DataType::Int32 => {
85                let values: Vec<Option<i32>> = rows.iter().map(|r| r.get(col_idx)).collect();
86                Arc::new(Int32Array::from(values))
87            }
88            DataType::Int64 => {
89                let values: Vec<Option<i64>> = rows.iter().map(|r| r.get(col_idx)).collect();
90                Arc::new(Int64Array::from(values))
91            }
92            DataType::Float32 => {
93                let values: Vec<Option<f32>> = rows.iter().map(|r| r.get(col_idx)).collect();
94                Arc::new(Float32Array::from(values))
95            }
96            DataType::Float64 => {
97                let values: Vec<Option<f64>> = rows.iter().map(|r| r.get(col_idx)).collect();
98                Arc::new(Float64Array::from(values))
99            }
100            _ => {
101                let values: Vec<Option<String>> = rows
102                    .iter()
103                    .map(|r| r.try_get::<_, String>(col_idx).ok())
104                    .collect();
105                Arc::new(StringArray::from(values))
106            }
107        };
108        arrays.push(array);
109    }
110
111    RecordBatch::try_new(schema.clone(), arrays)
112        .map_err(|e| format!("Arrow RecordBatch creation error: {e}"))
113}
114
115impl DataEngine {
116    /// Read a PostgreSQL table into a DataFusion DataFrame.
117    pub fn read_postgres(
118        &self,
119        conn_str: &str,
120        table_name: &str,
121    ) -> Result<datafusion::prelude::DataFrame, String> {
122        let query = format!("SELECT * FROM \"{}\"", table_name.replace('"', "\"\""));
123        self.query_postgres(conn_str, &query, table_name)
124    }
125
126    /// Execute a custom SQL query against PostgreSQL and return a DataFusion DataFrame.
127    /// Fetches 1M rows per cursor round trip, then subdivides into 100K-row RecordBatches
128    /// for DataFusion parallelism. Falls back to a simple query if cursors are not
129    /// supported (e.g., PgBouncer, some cloud proxies, or restricted permissions).
130    pub fn query_postgres(
131        &self,
132        conn_str: &str,
133        query: &str,
134        register_as: &str,
135    ) -> Result<datafusion::prelude::DataFrame, String> {
136        let mut client = pg_connect(conn_str)?;
137
138        // Try cursor-based streaming first, fall back to simple query
139        let result = self.query_postgres_cursor(&mut client, query);
140        let (batches, final_schema) = match result {
141            Ok(v) => v,
142            Err(_) => {
143                // Cursor failed — reconnect and use simple query fallback
144                let mut client2 = pg_connect(conn_str)?;
145                self.query_postgres_simple(&mut client2, query)?
146            }
147        };
148
149        if batches.is_empty() {
150            return Err(format!("Query returned no rows: {query}"));
151        }
152
153        let _ = self.ctx.deregister_table(register_as);
154        self.register_batches(register_as, final_schema, batches)?;
155
156        self.rt
157            .block_on(self.ctx.table(register_as))
158            .map_err(|e| format!("Table reference error: {e}"))
159    }
160
161    /// Cursor-based streaming: DECLARE CURSOR + FETCH in large chunks,
162    /// then subdivide locally into smaller RecordBatches for DataFusion parallelism.
163    fn query_postgres_cursor(
164        &self,
165        client: &mut Client,
166        query: &str,
167    ) -> Result<(Vec<RecordBatch>, Arc<Schema>), String> {
168        let mut txn = client
169            .transaction()
170            .map_err(|e| format!("PostgreSQL transaction error: {}", pg_error_detail(&e)))?;
171
172        txn.execute(
173            &format!("DECLARE _tl_cursor NO SCROLL CURSOR FOR {query}"),
174            &[],
175        )
176        .map_err(|e| format!("PostgreSQL DECLARE CURSOR error: {}", pg_error_detail(&e)))?;
177
178        let mut batches: Vec<RecordBatch> = Vec::new();
179        let mut schema: Option<Arc<Schema>> = None;
180
181        loop {
182            let rows = txn
183                .query(
184                    &format!("FETCH {PG_CURSOR_FETCH_SIZE} FROM _tl_cursor"),
185                    &[],
186                )
187                .map_err(|e| format!("PostgreSQL FETCH error: {}", pg_error_detail(&e)))?;
188
189            if rows.is_empty() {
190                break;
191            }
192
193            if schema.is_none() {
194                let columns = rows[0].columns();
195                let fields: Vec<Field> = columns
196                    .iter()
197                    .map(|col| Field::new(col.name(), pg_type_to_arrow(col.type_()), true))
198                    .collect();
199                schema = Some(Arc::new(Schema::new(fields)));
200            }
201
202            // Subdivide the fetched chunk into smaller RecordBatches
203            let s = schema.as_ref().unwrap();
204            for chunk in rows.chunks(PG_BATCH_SIZE) {
205                let batch = build_record_batch(chunk, s)?;
206                batches.push(batch);
207            }
208        }
209
210        txn.execute("CLOSE _tl_cursor", &[]).ok();
211        txn.commit()
212            .map_err(|e| format!("PostgreSQL COMMIT error: {}", pg_error_detail(&e)))?;
213
214        let final_schema = schema.ok_or_else(|| "No schema from cursor query".to_string())?;
215        Ok((batches, final_schema))
216    }
217
218    /// Simple query fallback: runs the full query and batches results locally.
219    fn query_postgres_simple(
220        &self,
221        client: &mut Client,
222        query: &str,
223    ) -> Result<(Vec<RecordBatch>, Arc<Schema>), String> {
224        let rows = client
225            .query(query, &[])
226            .map_err(|e| format!("PostgreSQL query error: {}", pg_error_detail(&e)))?;
227
228        if rows.is_empty() {
229            return Err("Query returned no rows".to_string());
230        }
231
232        let columns = rows[0].columns();
233        let fields: Vec<Field> = columns
234            .iter()
235            .map(|col| Field::new(col.name(), pg_type_to_arrow(col.type_()), true))
236            .collect();
237        let schema = Arc::new(Schema::new(fields));
238
239        let mut batches: Vec<RecordBatch> = Vec::new();
240        for chunk in rows.chunks(PG_BATCH_SIZE) {
241            let batch = build_record_batch(chunk, &schema)?;
242            batches.push(batch);
243        }
244
245        Ok((batches, schema))
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    // Connection string for the local test database used by the ignored
    // integration tests below.
    const TEST_CONN: &str = "host=localhost user=postgres password=postgres dbname=testdb";

    #[test]
    #[ignore] // Requires a running PostgreSQL instance
    fn test_read_postgres() {
        let engine = DataEngine::new();
        let frame = engine.read_postgres(TEST_CONN, "users").unwrap();
        let batches = engine.collect(frame).unwrap();
        assert!(!batches.is_empty());
    }

    #[test]
    #[ignore] // Requires a running PostgreSQL instance
    fn test_query_postgres_custom_sql() {
        let engine = DataEngine::new();
        let frame = engine
            .query_postgres(TEST_CONN, "SELECT * FROM users LIMIT 10", "limited_users")
            .unwrap();
        let batches = engine.collect(frame).unwrap();
        assert!(!batches.is_empty());
        let row_count: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert!(row_count <= 10);
    }
}
283}