1use datafusion::arrow::array::RecordBatch;
2use datafusion::arrow::array::*;
3use datafusion::arrow::datatypes::{DataType, Field, Schema};
4use postgres::{Client, NoTls};
5use std::sync::Arc;
6
7use crate::engine::DataEngine;
8
9fn pg_error_detail(e: &postgres::Error) -> String {
13 if let Some(db_err) = e.as_db_error() {
14 let mut msg = format!("{}: {}", db_err.severity(), db_err.message());
15 if let Some(detail) = db_err.detail() {
16 msg.push_str(&format!(" DETAIL: {detail}"));
17 }
18 if let Some(hint) = db_err.hint() {
19 msg.push_str(&format!(" HINT: {hint}"));
20 }
21 if let Some(where_) = db_err.where_() {
22 msg.push_str(&format!(" WHERE: {where_}"));
23 }
24 msg.push_str(&format!(" (SQLSTATE {})", db_err.code().code()));
25 msg
26 } else {
27 format!("{e}")
28 }
29}
30
31fn pg_connect(conn_str: &str) -> Result<Client, String> {
33 if let Ok(tls_connector) = native_tls::TlsConnector::builder()
35 .danger_accept_invalid_certs(true)
36 .build()
37 {
38 let connector = postgres_native_tls::MakeTlsConnector::new(tls_connector);
39 if let Ok(client) = Client::connect(conn_str, connector) {
40 return Ok(client);
41 }
42 }
43 Client::connect(conn_str, NoTls)
45 .map_err(|e| format!("PostgreSQL connection error: {}", pg_error_detail(&e)))
46}
47
/// Rows requested per `FETCH` round-trip when streaming results through a
/// server-side cursor (see `query_postgres_cursor`).
const PG_CURSOR_FETCH_SIZE: usize = 1_000_000;

/// Maximum rows per Arrow `RecordBatch`; each fetched row set is chunked to
/// this size before conversion.
const PG_BATCH_SIZE: usize = 100_000;
55fn pg_type_to_arrow(pg_type: &postgres::types::Type) -> DataType {
57 match *pg_type {
58 postgres::types::Type::BOOL => DataType::Boolean,
59 postgres::types::Type::INT2 => DataType::Int16,
60 postgres::types::Type::INT4 => DataType::Int32,
61 postgres::types::Type::INT8 => DataType::Int64,
62 postgres::types::Type::FLOAT4 => DataType::Float32,
63 postgres::types::Type::FLOAT8 => DataType::Float64,
64 postgres::types::Type::TEXT
65 | postgres::types::Type::VARCHAR
66 | postgres::types::Type::BPCHAR => DataType::Utf8,
67 _ => DataType::Utf8, }
69}
70
71fn build_record_batch(rows: &[postgres::Row], schema: &Arc<Schema>) -> Result<RecordBatch, String> {
73 let mut arrays: Vec<Arc<dyn Array>> = Vec::new();
74 for (col_idx, field) in schema.fields().iter().enumerate() {
75 let array: Arc<dyn Array> = match field.data_type() {
76 DataType::Boolean => {
77 let values: Vec<Option<bool>> = rows.iter().map(|r| r.get(col_idx)).collect();
78 Arc::new(BooleanArray::from(values))
79 }
80 DataType::Int16 => {
81 let values: Vec<Option<i16>> = rows.iter().map(|r| r.get(col_idx)).collect();
82 Arc::new(Int16Array::from(values))
83 }
84 DataType::Int32 => {
85 let values: Vec<Option<i32>> = rows.iter().map(|r| r.get(col_idx)).collect();
86 Arc::new(Int32Array::from(values))
87 }
88 DataType::Int64 => {
89 let values: Vec<Option<i64>> = rows.iter().map(|r| r.get(col_idx)).collect();
90 Arc::new(Int64Array::from(values))
91 }
92 DataType::Float32 => {
93 let values: Vec<Option<f32>> = rows.iter().map(|r| r.get(col_idx)).collect();
94 Arc::new(Float32Array::from(values))
95 }
96 DataType::Float64 => {
97 let values: Vec<Option<f64>> = rows.iter().map(|r| r.get(col_idx)).collect();
98 Arc::new(Float64Array::from(values))
99 }
100 _ => {
101 let values: Vec<Option<String>> = rows
102 .iter()
103 .map(|r| r.try_get::<_, String>(col_idx).ok())
104 .collect();
105 Arc::new(StringArray::from(values))
106 }
107 };
108 arrays.push(array);
109 }
110
111 RecordBatch::try_new(schema.clone(), arrays)
112 .map_err(|e| format!("Arrow RecordBatch creation error: {e}"))
113}
114
impl DataEngine {
    /// Read an entire PostgreSQL table into DataFusion, registering the
    /// result under the table's own name.
    ///
    /// The table name is quoted as an identifier (embedded `"` doubled),
    /// so it cannot break out of the `SELECT * FROM "…"` statement.
    pub fn read_postgres(
        &self,
        conn_str: &str,
        table_name: &str,
    ) -> Result<datafusion::prelude::DataFrame, String> {
        let query = format!("SELECT * FROM \"{}\"", table_name.replace('"', "\"\""));
        self.query_postgres(conn_str, &query, table_name)
    }

    /// Run an arbitrary SQL query against PostgreSQL and register the
    /// result with DataFusion under `register_as`.
    ///
    /// Strategy: stream through a server-side cursor first (bounded
    /// memory); if the cursor path fails for ANY reason, reconnect and
    /// re-run the query in one shot via `query_postgres_simple`. Note the
    /// cursor error itself is discarded — only the simple path's error is
    /// surfaced to the caller.
    ///
    /// Errors if the query yields zero rows (schema cannot be inferred
    /// from an empty result), on connection failure, or if DataFusion
    /// registration fails. Any existing table named `register_as` is
    /// replaced.
    pub fn query_postgres(
        &self,
        conn_str: &str,
        query: &str,
        register_as: &str,
    ) -> Result<datafusion::prelude::DataFrame, String> {
        let mut client = pg_connect(conn_str)?;

        let result = self.query_postgres_cursor(&mut client, query);
        let (batches, final_schema) = match result {
            Ok(v) => v,
            Err(_) => {
                // Fresh connection: the cursor attempt may have left the
                // previous session/transaction in a failed state.
                let mut client2 = pg_connect(conn_str)?;
                self.query_postgres_simple(&mut client2, query)?
            }
        };

        if batches.is_empty() {
            return Err(format!("Query returned no rows: {query}"));
        }

        // Best-effort: ignore "no such table" on first registration.
        let _ = self.ctx.deregister_table(register_as);
        self.register_batches(register_as, final_schema, batches)?;

        self.rt
            .block_on(self.ctx.table(register_as))
            .map_err(|e| format!("Table reference error: {e}"))
    }

    /// Stream query results through a NO SCROLL server-side cursor inside
    /// a transaction, fetching `PG_CURSOR_FETCH_SIZE` rows per round trip
    /// and chunking each fetch into `PG_BATCH_SIZE`-row RecordBatches.
    ///
    /// The Arrow schema is inferred from the column metadata of the first
    /// non-empty FETCH. A query with zero rows therefore never produces a
    /// schema and returns Err ("No schema from cursor query"), which the
    /// caller treats as a signal to retry via the simple path.
    ///
    /// On mid-stream failure the transaction guard is dropped without
    /// commit, so the server rolls back and disposes of the cursor.
    fn query_postgres_cursor(
        &self,
        client: &mut Client,
        query: &str,
    ) -> Result<(Vec<RecordBatch>, Arc<Schema>), String> {
        let mut txn = client
            .transaction()
            .map_err(|e| format!("PostgreSQL transaction error: {}", pg_error_detail(&e)))?;

        // NOTE(review): `query` is interpolated directly into DECLARE —
        // it must be a single SELECT-like statement; anything else will
        // fail here and trigger the caller's fallback path.
        txn.execute(
            &format!("DECLARE _tl_cursor NO SCROLL CURSOR FOR {query}"),
            &[],
        )
        .map_err(|e| format!("PostgreSQL DECLARE CURSOR error: {}", pg_error_detail(&e)))?;

        let mut batches: Vec<RecordBatch> = Vec::new();
        let mut schema: Option<Arc<Schema>> = None;

        loop {
            let rows = txn
                .query(
                    &format!("FETCH {PG_CURSOR_FETCH_SIZE} FROM _tl_cursor"),
                    &[],
                )
                .map_err(|e| format!("PostgreSQL FETCH error: {}", pg_error_detail(&e)))?;

            // An empty FETCH means the cursor is exhausted.
            if rows.is_empty() {
                break;
            }

            // Infer the Arrow schema once, from the first non-empty fetch;
            // all columns are marked nullable.
            if schema.is_none() {
                let columns = rows[0].columns();
                let fields: Vec<Field> = columns
                    .iter()
                    .map(|col| Field::new(col.name(), pg_type_to_arrow(col.type_()), true))
                    .collect();
                schema = Some(Arc::new(Schema::new(fields)));
            }

            let s = schema.as_ref().unwrap();
            for chunk in rows.chunks(PG_BATCH_SIZE) {
                let batch = build_record_batch(chunk, s)?;
                batches.push(batch);
            }
        }

        // CLOSE is best-effort; COMMIT ends the transaction either way.
        txn.execute("CLOSE _tl_cursor", &[]).ok();
        txn.commit()
            .map_err(|e| format!("PostgreSQL COMMIT error: {}", pg_error_detail(&e)))?;

        let final_schema = schema.ok_or_else(|| "No schema from cursor query".to_string())?;
        Ok((batches, final_schema))
    }

    /// Fallback path: run the query in one shot, materializing all rows in
    /// memory, then convert them to RecordBatches of `PG_BATCH_SIZE` rows.
    ///
    /// Unlike the cursor path this works for statements a cursor cannot
    /// wrap, but holds the full result set in memory at once. Errors if
    /// the query returns zero rows (no column metadata to infer a schema
    /// from).
    fn query_postgres_simple(
        &self,
        client: &mut Client,
        query: &str,
    ) -> Result<(Vec<RecordBatch>, Arc<Schema>), String> {
        let rows = client
            .query(query, &[])
            .map_err(|e| format!("PostgreSQL query error: {}", pg_error_detail(&e)))?;

        if rows.is_empty() {
            return Err("Query returned no rows".to_string());
        }

        // Schema from the first row's column metadata; all nullable.
        let columns = rows[0].columns();
        let fields: Vec<Field> = columns
            .iter()
            .map(|col| Field::new(col.name(), pg_type_to_arrow(col.type_()), true))
            .collect();
        let schema = Arc::new(Schema::new(fields));

        let mut batches: Vec<RecordBatch> = Vec::new();
        for chunk in rows.chunks(PG_BATCH_SIZE) {
            let batch = build_record_batch(chunk, &schema)?;
            batches.push(batch);
        }

        Ok((batches, schema))
    }
}
248
#[cfg(test)]
mod tests {
    use super::*;

    // Integration test — requires a live PostgreSQL at localhost with a
    // `testdb` database containing a `users` table, hence #[ignore].
    // Run explicitly with: cargo test -- --ignored
    #[test]
    #[ignore] fn test_read_postgres() {
        let engine = DataEngine::new();
        let df = engine
            .read_postgres(
                "host=localhost user=postgres password=postgres dbname=testdb",
                "users",
            )
            .unwrap();
        let batches = engine.collect(df).unwrap();
        assert!(!batches.is_empty());
    }

    // Integration test — same live-database prerequisite as above.
    // Verifies that an arbitrary SQL query is registered and that the
    // LIMIT is respected across all collected batches.
    #[test]
    #[ignore] fn test_query_postgres_custom_sql() {
        let engine = DataEngine::new();
        let df = engine
            .query_postgres(
                "host=localhost user=postgres password=postgres dbname=testdb",
                "SELECT * FROM users LIMIT 10",
                "limited_users",
            )
            .unwrap();
        let batches = engine.collect(df).unwrap();
        assert!(!batches.is_empty());
        let total: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert!(total <= 10);
    }
}