1use async_trait::async_trait;
2use deadpool_postgres::{Config, Pool, Runtime};
3
4use std::time::{Duration, Instant};
5use tokio_postgres::{NoTls, Row as PgRow};
6
7use crate::connectors::connector_trait::{Connector, ConnectorInitConfig, ConnectorCapabilities};
8use crate::utils::{
9 types::{
10 ConnectorType, ConnectorQuery, QueryResult, Schema, ColumnMetadata,
11 DataType, Row, Value, Index, QueryOperation, PredicateOperator
12 },
13 error::{ConnectorError, NirvResult},
14};
15
/// Connector that runs queries against PostgreSQL through a `deadpool_postgres` pool.
///
/// `new()` creates it without a pool; `connect` installs the pool and `disconnect`
/// drops it again.
#[derive(Debug)]
pub struct PostgresConnector {
    /// Connection pool: `None` until `connect` succeeds, and again after `disconnect`.
    pool: Option<Pool>,
    /// Set by `connect` after a client was checked out successfully; cleared by `disconnect`.
    connected: bool,
}
22
23impl PostgresConnector {
24 pub fn new() -> Self {
26 Self {
27 pool: None,
28 connected: false,
29 }
30 }
31
32 fn convert_pg_row(&self, pg_row: &PgRow) -> NirvResult<Row> {
34 let mut values = Vec::new();
35
36 for i in 0..pg_row.len() {
37 let value = self.convert_pg_value(pg_row, i)?;
38 values.push(value);
39 }
40
41 Ok(Row::new(values))
42 }
43
44 fn convert_pg_value(&self, row: &PgRow, index: usize) -> NirvResult<Value> {
46 let column = &row.columns()[index];
47 let type_oid = column.type_().oid();
48
49 if row.try_get::<_, Option<String>>(index).unwrap_or(None).is_none() {
51 return Ok(Value::Null);
52 }
53
54 match type_oid {
56 25 | 1043 | 1042 => { let val: String = row.try_get(index)
59 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get text value: {}", e)))?;
60 Ok(Value::Text(val))
61 }
62 23 => { let val: i32 = row.try_get(index)
65 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get int4 value: {}", e)))?;
66 Ok(Value::Integer(val as i64))
67 }
68 20 => { let val: i64 = row.try_get(index)
70 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get int8 value: {}", e)))?;
71 Ok(Value::Integer(val))
72 }
73 21 => { let val: i16 = row.try_get(index)
75 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get int2 value: {}", e)))?;
76 Ok(Value::Integer(val as i64))
77 }
78 700 => { let val: f32 = row.try_get(index)
81 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get float4 value: {}", e)))?;
82 Ok(Value::Float(val as f64))
83 }
84 701 => { let val: f64 = row.try_get(index)
86 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get float8 value: {}", e)))?;
87 Ok(Value::Float(val))
88 }
89 16 => { let val: bool = row.try_get(index)
92 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get bool value: {}", e)))?;
93 Ok(Value::Boolean(val))
94 }
95 114 | 3802 => { let val: String = row.try_get(index)
98 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get json value: {}", e)))?;
99 Ok(Value::Json(val))
100 }
101 1082 => { let val: String = row.try_get(index)
104 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get date value: {}", e)))?;
105 Ok(Value::Date(val))
106 }
107 1114 | 1184 => { let val: String = row.try_get(index)
109 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get timestamp value: {}", e)))?;
110 Ok(Value::DateTime(val))
111 }
112 17 => { let val: Vec<u8> = row.try_get(index)
115 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get bytea value: {}", e)))?;
116 Ok(Value::Binary(val))
117 }
118 _ => {
120 let val: String = row.try_get(index)
121 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get value as string: {}", e)))?;
122 Ok(Value::Text(val))
123 }
124 }
125 }
126
127 fn pg_type_to_data_type(&self, type_oid: u32) -> DataType {
129 match type_oid {
130 25 | 1043 | 1042 => DataType::Text, 23 | 20 | 21 => DataType::Integer, 700 | 701 => DataType::Float, 16 => DataType::Boolean, 114 | 3802 => DataType::Json, 1082 => DataType::Date, 1114 | 1184 => DataType::DateTime, 17 => DataType::Binary, _ => DataType::Text, }
140 }
141
142 fn build_sql_query(&self, query: &crate::utils::types::InternalQuery) -> NirvResult<String> {
144 match query.operation {
145 QueryOperation::Select => {
146 let mut sql = String::from("SELECT ");
147
148 if query.projections.is_empty() {
150 sql.push('*');
151 } else {
152 let projections: Vec<String> = query.projections.iter()
153 .map(|col| {
154 if let Some(alias) = &col.alias {
155 format!("{} AS {}", col.name, alias)
156 } else {
157 col.name.clone()
158 }
159 })
160 .collect();
161 sql.push_str(&projections.join(", "));
162 }
163
164 if let Some(source) = query.sources.first() {
166 sql.push_str(" FROM ");
167 sql.push_str(&source.identifier);
168 if let Some(alias) = &source.alias {
169 sql.push_str(" AS ");
170 sql.push_str(alias);
171 }
172 } else {
173 return Err(ConnectorError::QueryExecutionFailed(
174 "No data source specified in query".to_string()
175 ).into());
176 }
177
178 if !query.predicates.is_empty() {
180 sql.push_str(" WHERE ");
181 let predicates: Vec<String> = query.predicates.iter()
182 .map(|pred| self.build_predicate_sql(pred))
183 .collect::<Result<Vec<_>, _>>()?;
184 sql.push_str(&predicates.join(" AND "));
185 }
186
187 if let Some(order_by) = &query.ordering {
189 sql.push_str(" ORDER BY ");
190 let order_columns: Vec<String> = order_by.columns.iter()
191 .map(|col| {
192 let direction = match col.direction {
193 crate::utils::types::OrderDirection::Ascending => "ASC",
194 crate::utils::types::OrderDirection::Descending => "DESC",
195 };
196 format!("{} {}", col.column, direction)
197 })
198 .collect();
199 sql.push_str(&order_columns.join(", "));
200 }
201
202 if let Some(limit) = query.limit {
204 sql.push_str(&format!(" LIMIT {}", limit));
205 }
206
207 Ok(sql)
208 }
209 _ => Err(ConnectorError::UnsupportedOperation(
210 format!("Operation {:?} not supported by PostgreSQL connector", query.operation)
211 ).into()),
212 }
213 }
214
215 fn build_predicate_sql(&self, predicate: &crate::utils::types::Predicate) -> NirvResult<String> {
217 let operator_sql = match predicate.operator {
218 PredicateOperator::Equal => "=",
219 PredicateOperator::NotEqual => "!=",
220 PredicateOperator::GreaterThan => ">",
221 PredicateOperator::GreaterThanOrEqual => ">=",
222 PredicateOperator::LessThan => "<",
223 PredicateOperator::LessThanOrEqual => "<=",
224 PredicateOperator::Like => "LIKE",
225 PredicateOperator::IsNull => "IS NULL",
226 PredicateOperator::IsNotNull => "IS NOT NULL",
227 PredicateOperator::In => "IN",
228 };
229
230 match predicate.operator {
231 PredicateOperator::IsNull | PredicateOperator::IsNotNull => {
232 Ok(format!("{} {}", predicate.column, operator_sql))
233 }
234 PredicateOperator::In => {
235 if let crate::utils::types::PredicateValue::List(values) = &predicate.value {
236 let value_strings: Vec<String> = values.iter()
237 .map(|v| self.format_predicate_value(v))
238 .collect::<Result<Vec<_>, _>>()?;
239 Ok(format!("{} IN ({})", predicate.column, value_strings.join(", ")))
240 } else {
241 Err(ConnectorError::QueryExecutionFailed(
242 "IN operator requires a list of values".to_string()
243 ).into())
244 }
245 }
246 _ => {
247 let value_str = self.format_predicate_value(&predicate.value)?;
248 Ok(format!("{} {} {}", predicate.column, operator_sql, value_str))
249 }
250 }
251 }
252
253 fn format_predicate_value(&self, value: &crate::utils::types::PredicateValue) -> NirvResult<String> {
255 match value {
256 crate::utils::types::PredicateValue::String(s) => Ok(format!("'{}'", s.replace('\'', "''"))),
257 crate::utils::types::PredicateValue::Number(n) => Ok(n.to_string()),
258 crate::utils::types::PredicateValue::Integer(i) => Ok(i.to_string()),
259 crate::utils::types::PredicateValue::Boolean(b) => Ok(b.to_string()),
260 crate::utils::types::PredicateValue::Null => Ok("NULL".to_string()),
261 crate::utils::types::PredicateValue::List(_) => {
262 Err(ConnectorError::QueryExecutionFailed(
263 "List values should be handled by IN operator".to_string()
264 ).into())
265 }
266 }
267 }
268}
269
270impl Default for PostgresConnector {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276#[async_trait]
277impl Connector for PostgresConnector {
278 async fn connect(&mut self, config: ConnectorInitConfig) -> NirvResult<()> {
279 let host = config.connection_params.get("host")
280 .unwrap_or(&"localhost".to_string()).clone();
281 let port = config.connection_params.get("port")
282 .unwrap_or(&"5432".to_string())
283 .parse::<u16>()
284 .map_err(|e| ConnectorError::ConnectionFailed(format!("Invalid port: {}", e)))?;
285 let user = config.connection_params.get("user")
286 .unwrap_or(&"postgres".to_string()).clone();
287 let password = config.connection_params.get("password")
288 .unwrap_or(&"".to_string()).clone();
289 let dbname = config.connection_params.get("dbname")
290 .unwrap_or(&"postgres".to_string()).clone();
291
292 let max_size = config.max_connections.unwrap_or(10) as usize;
293 let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(30));
294
295 let mut pg_config = Config::new();
297 pg_config.host = Some(host);
298 pg_config.port = Some(port);
299 pg_config.user = Some(user);
300 pg_config.password = Some(password);
301 pg_config.dbname = Some(dbname);
302 pg_config.pool = Some(deadpool_postgres::PoolConfig::new(max_size));
303
304 let pool = pg_config.create_pool(Some(Runtime::Tokio1), NoTls)
306 .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to create pool: {}", e)))?;
307
308 let _client = tokio::time::timeout(timeout, pool.get()).await
310 .map_err(|_| ConnectorError::Timeout("Connection timeout".to_string()))?
311 .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to get connection: {}", e)))?;
312
313 self.pool = Some(pool);
314 self.connected = true;
315
316 Ok(())
317 }
318
319 async fn execute_query(&self, query: ConnectorQuery) -> NirvResult<QueryResult> {
320 if !self.connected {
321 return Err(ConnectorError::ConnectionFailed("Not connected".to_string()).into());
322 }
323
324 let pool = self.pool.as_ref()
325 .ok_or_else(|| ConnectorError::ConnectionFailed("No connection pool available".to_string()))?;
326
327 let start_time = Instant::now();
328
329 let sql = self.build_sql_query(&query.query)?;
331
332 let client = pool.get().await
334 .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to get connection from pool: {}", e)))?;
335
336 let pg_rows = client.query(&sql, &[]).await
338 .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Query execution failed: {}", e)))?;
339
340 let mut columns = Vec::new();
342 let mut rows = Vec::new();
343
344 if let Some(first_row) = pg_rows.first() {
345 for column in first_row.columns() {
347 columns.push(ColumnMetadata {
348 name: column.name().to_string(),
349 data_type: self.pg_type_to_data_type(column.type_().oid()),
350 nullable: true, });
352 }
353 }
354
355 for pg_row in &pg_rows {
357 let row = self.convert_pg_row(pg_row)?;
358 rows.push(row);
359 }
360
361 let execution_time = start_time.elapsed();
362
363 Ok(QueryResult {
364 columns,
365 rows,
366 affected_rows: Some(pg_rows.len() as u64),
367 execution_time,
368 })
369 }
370
371 async fn get_schema(&self, object_name: &str) -> NirvResult<Schema> {
372 if !self.connected {
373 return Err(ConnectorError::ConnectionFailed("Not connected".to_string()).into());
374 }
375
376 let pool = self.pool.as_ref()
377 .ok_or_else(|| ConnectorError::ConnectionFailed("No connection pool available".to_string()))?;
378
379 let client = pool.get().await
380 .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to get connection from pool: {}", e)))?;
381
382 let (schema_name, table_name) = if object_name.contains('.') {
384 let parts: Vec<&str> = object_name.splitn(2, '.').collect();
385 (parts[0].to_string(), parts[1].to_string())
386 } else {
387 ("public".to_string(), object_name.to_string())
388 };
389
390 let column_query = "
392 SELECT
393 column_name,
394 data_type,
395 is_nullable,
396 udt_name,
397 ordinal_position
398 FROM information_schema.columns
399 WHERE table_schema = $1 AND table_name = $2
400 ORDER BY ordinal_position
401 ";
402
403 let column_rows = client.query(column_query, &[&schema_name, &table_name]).await
404 .map_err(|e| ConnectorError::SchemaRetrievalFailed(format!("Failed to retrieve column info: {}", e)))?;
405
406 if column_rows.is_empty() {
407 return Err(ConnectorError::SchemaRetrievalFailed(
408 format!("Table '{}' not found", object_name)
409 ).into());
410 }
411
412 let mut columns = Vec::new();
413 for row in &column_rows {
414 let column_name: String = row.get("column_name");
415 let data_type_str: String = row.get("data_type");
416 let is_nullable: String = row.get("is_nullable");
417
418 let data_type = match data_type_str.as_str() {
419 "character varying" | "text" | "character" => DataType::Text,
420 "integer" | "bigint" | "smallint" => DataType::Integer,
421 "real" | "double precision" | "numeric" => DataType::Float,
422 "boolean" => DataType::Boolean,
423 "date" => DataType::Date,
424 "timestamp without time zone" | "timestamp with time zone" => DataType::DateTime,
425 "json" | "jsonb" => DataType::Json,
426 "bytea" => DataType::Binary,
427 _ => DataType::Text,
428 };
429
430 columns.push(ColumnMetadata {
431 name: column_name,
432 data_type,
433 nullable: is_nullable == "YES",
434 });
435 }
436
437 let pk_query = "
439 SELECT column_name
440 FROM information_schema.key_column_usage
441 WHERE table_schema = $1 AND table_name = $2
442 AND constraint_name IN (
443 SELECT constraint_name
444 FROM information_schema.table_constraints
445 WHERE table_schema = $1 AND table_name = $2
446 AND constraint_type = 'PRIMARY KEY'
447 )
448 ORDER BY ordinal_position
449 ";
450
451 let pk_rows = client.query(pk_query, &[&schema_name, &table_name]).await
452 .map_err(|e| ConnectorError::SchemaRetrievalFailed(format!("Failed to retrieve primary key info: {}", e)))?;
453
454 let primary_key = if pk_rows.is_empty() {
455 None
456 } else {
457 Some(pk_rows.iter().map(|row| row.get::<_, String>("column_name")).collect())
458 };
459
460 let index_query = "
462 SELECT
463 i.indexname,
464 array_agg(a.attname ORDER BY a.attnum) as columns,
465 i.indexdef LIKE '%UNIQUE%' as is_unique
466 FROM pg_indexes i
467 JOIN pg_class c ON c.relname = i.tablename
468 JOIN pg_namespace n ON n.oid = c.relnamespace
469 JOIN pg_index idx ON idx.indexrelid = (
470 SELECT oid FROM pg_class WHERE relname = i.indexname
471 )
472 JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(idx.indkey)
473 WHERE n.nspname = $1 AND i.tablename = $2
474 AND i.indexname NOT LIKE '%_pkey'
475 GROUP BY i.indexname, i.indexdef
476 ";
477
478 let index_rows = client.query(index_query, &[&schema_name, &table_name]).await
479 .unwrap_or_else(|_| Vec::new()); let mut indexes = Vec::new();
482 for row in &index_rows {
483 let index_name: String = row.get("indexname");
484 let columns_array: Vec<String> = row.get("columns");
485 let is_unique: bool = row.get("is_unique");
486
487 indexes.push(Index {
488 name: index_name,
489 columns: columns_array,
490 unique: is_unique,
491 });
492 }
493
494 Ok(Schema {
495 name: object_name.to_string(),
496 columns,
497 primary_key,
498 indexes,
499 })
500 }
501
502 async fn disconnect(&mut self) -> NirvResult<()> {
503 self.pool = None;
504 self.connected = false;
505 Ok(())
506 }
507
508 fn get_connector_type(&self) -> ConnectorType {
509 ConnectorType::PostgreSQL
510 }
511
512 fn supports_transactions(&self) -> bool {
513 true
514 }
515
516 fn is_connected(&self) -> bool {
517 self.connected
518 }
519
520 fn get_capabilities(&self) -> ConnectorCapabilities {
521 ConnectorCapabilities {
522 supports_joins: true,
523 supports_aggregations: true,
524 supports_subqueries: true,
525 supports_transactions: true,
526 supports_schema_introspection: true,
527 max_concurrent_queries: Some(10),
528 }
529 }
530}