// nirv_engine/connectors/postgres_connector.rs

1use async_trait::async_trait;
2use deadpool_postgres::{Config, Pool, Runtime};
3
4use std::time::{Duration, Instant};
5use tokio_postgres::{NoTls, Row as PgRow};
6
7use crate::connectors::connector_trait::{Connector, ConnectorInitConfig, ConnectorCapabilities};
8use crate::utils::{
9    types::{
10        ConnectorType, ConnectorQuery, QueryResult, Schema, ColumnMetadata, 
11        DataType, Row, Value, Index, QueryOperation, PredicateOperator
12    },
13    error::{ConnectorError, NirvResult},
14};
15
16/// PostgreSQL connector using tokio-postgres with connection pooling
17#[derive(Debug)]
18pub struct PostgresConnector {
19    pool: Option<Pool>,
20    connected: bool,
21}
22
23impl PostgresConnector {
24    /// Create a new PostgreSQL connector
25    pub fn new() -> Self {
26        Self {
27            pool: None,
28            connected: false,
29        }
30    }
31    
32    /// Convert PostgreSQL row to internal Row representation
33    fn convert_pg_row(&self, pg_row: &PgRow) -> NirvResult<Row> {
34        let mut values = Vec::new();
35        
36        for i in 0..pg_row.len() {
37            let value = self.convert_pg_value(pg_row, i)?;
38            values.push(value);
39        }
40        
41        Ok(Row::new(values))
42    }
43    
44    /// Convert PostgreSQL value to internal Value representation
45    fn convert_pg_value(&self, row: &PgRow, index: usize) -> NirvResult<Value> {
46        let column = &row.columns()[index];
47        let type_oid = column.type_().oid();
48        
49        // Handle NULL values first
50        if row.try_get::<_, Option<String>>(index).unwrap_or(None).is_none() {
51            return Ok(Value::Null);
52        }
53        
54        // Convert based on PostgreSQL type OID
55        match type_oid {
56            // Text types
57            25 | 1043 | 1042 => { // TEXT, VARCHAR, CHAR
58                let val: String = row.try_get(index)
59                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get text value: {}", e)))?;
60                Ok(Value::Text(val))
61            }
62            // Integer types
63            23 => { // INT4
64                let val: i32 = row.try_get(index)
65                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get int4 value: {}", e)))?;
66                Ok(Value::Integer(val as i64))
67            }
68            20 => { // INT8
69                let val: i64 = row.try_get(index)
70                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get int8 value: {}", e)))?;
71                Ok(Value::Integer(val))
72            }
73            21 => { // INT2
74                let val: i16 = row.try_get(index)
75                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get int2 value: {}", e)))?;
76                Ok(Value::Integer(val as i64))
77            }
78            // Float types
79            700 => { // FLOAT4
80                let val: f32 = row.try_get(index)
81                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get float4 value: {}", e)))?;
82                Ok(Value::Float(val as f64))
83            }
84            701 => { // FLOAT8
85                let val: f64 = row.try_get(index)
86                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get float8 value: {}", e)))?;
87                Ok(Value::Float(val))
88            }
89            // Boolean type
90            16 => { // BOOL
91                let val: bool = row.try_get(index)
92                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get bool value: {}", e)))?;
93                Ok(Value::Boolean(val))
94            }
95            // JSON types
96            114 | 3802 => { // JSON, JSONB
97                let val: String = row.try_get(index)
98                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get json value: {}", e)))?;
99                Ok(Value::Json(val))
100            }
101            // Date/Time types
102            1082 => { // DATE
103                let val: String = row.try_get(index)
104                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get date value: {}", e)))?;
105                Ok(Value::Date(val))
106            }
107            1114 | 1184 => { // TIMESTAMP, TIMESTAMPTZ
108                let val: String = row.try_get(index)
109                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get timestamp value: {}", e)))?;
110                Ok(Value::DateTime(val))
111            }
112            // Binary types
113            17 => { // BYTEA
114                let val: Vec<u8> = row.try_get(index)
115                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get bytea value: {}", e)))?;
116                Ok(Value::Binary(val))
117            }
118            // Default: convert to string
119            _ => {
120                let val: String = row.try_get(index)
121                    .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Failed to get value as string: {}", e)))?;
122                Ok(Value::Text(val))
123            }
124        }
125    }
126    
127    /// Convert PostgreSQL type OID to internal DataType
128    fn pg_type_to_data_type(&self, type_oid: u32) -> DataType {
129        match type_oid {
130            25 | 1043 | 1042 => DataType::Text,     // TEXT, VARCHAR, CHAR
131            23 | 20 | 21 => DataType::Integer,      // INT4, INT8, INT2
132            700 | 701 => DataType::Float,           // FLOAT4, FLOAT8
133            16 => DataType::Boolean,                // BOOL
134            114 | 3802 => DataType::Json,           // JSON, JSONB
135            1082 => DataType::Date,                 // DATE
136            1114 | 1184 => DataType::DateTime,      // TIMESTAMP, TIMESTAMPTZ
137            17 => DataType::Binary,                 // BYTEA
138            _ => DataType::Text,                    // Default to text
139        }
140    }
141    
142    /// Build SQL query from internal query representation
143    fn build_sql_query(&self, query: &crate::utils::types::InternalQuery) -> NirvResult<String> {
144        match query.operation {
145            QueryOperation::Select => {
146                let mut sql = String::from("SELECT ");
147                
148                // Handle projections
149                if query.projections.is_empty() {
150                    sql.push('*');
151                } else {
152                    let projections: Vec<String> = query.projections.iter()
153                        .map(|col| {
154                            if let Some(alias) = &col.alias {
155                                format!("{} AS {}", col.name, alias)
156                            } else {
157                                col.name.clone()
158                            }
159                        })
160                        .collect();
161                    sql.push_str(&projections.join(", "));
162                }
163                
164                // Handle FROM clause
165                if let Some(source) = query.sources.first() {
166                    sql.push_str(" FROM ");
167                    sql.push_str(&source.identifier);
168                    if let Some(alias) = &source.alias {
169                        sql.push_str(" AS ");
170                        sql.push_str(alias);
171                    }
172                } else {
173                    return Err(ConnectorError::QueryExecutionFailed(
174                        "No data source specified in query".to_string()
175                    ).into());
176                }
177                
178                // Handle WHERE clause
179                if !query.predicates.is_empty() {
180                    sql.push_str(" WHERE ");
181                    let predicates: Vec<String> = query.predicates.iter()
182                        .map(|pred| self.build_predicate_sql(pred))
183                        .collect::<Result<Vec<_>, _>>()?;
184                    sql.push_str(&predicates.join(" AND "));
185                }
186                
187                // Handle ORDER BY
188                if let Some(order_by) = &query.ordering {
189                    sql.push_str(" ORDER BY ");
190                    let order_columns: Vec<String> = order_by.columns.iter()
191                        .map(|col| {
192                            let direction = match col.direction {
193                                crate::utils::types::OrderDirection::Ascending => "ASC",
194                                crate::utils::types::OrderDirection::Descending => "DESC",
195                            };
196                            format!("{} {}", col.column, direction)
197                        })
198                        .collect();
199                    sql.push_str(&order_columns.join(", "));
200                }
201                
202                // Handle LIMIT
203                if let Some(limit) = query.limit {
204                    sql.push_str(&format!(" LIMIT {}", limit));
205                }
206                
207                Ok(sql)
208            }
209            _ => Err(ConnectorError::UnsupportedOperation(
210                format!("Operation {:?} not supported by PostgreSQL connector", query.operation)
211            ).into()),
212        }
213    }
214    
215    /// Build SQL for a single predicate
216    fn build_predicate_sql(&self, predicate: &crate::utils::types::Predicate) -> NirvResult<String> {
217        let operator_sql = match predicate.operator {
218            PredicateOperator::Equal => "=",
219            PredicateOperator::NotEqual => "!=",
220            PredicateOperator::GreaterThan => ">",
221            PredicateOperator::GreaterThanOrEqual => ">=",
222            PredicateOperator::LessThan => "<",
223            PredicateOperator::LessThanOrEqual => "<=",
224            PredicateOperator::Like => "LIKE",
225            PredicateOperator::IsNull => "IS NULL",
226            PredicateOperator::IsNotNull => "IS NOT NULL",
227            PredicateOperator::In => "IN",
228        };
229        
230        match predicate.operator {
231            PredicateOperator::IsNull | PredicateOperator::IsNotNull => {
232                Ok(format!("{} {}", predicate.column, operator_sql))
233            }
234            PredicateOperator::In => {
235                if let crate::utils::types::PredicateValue::List(values) = &predicate.value {
236                    let value_strings: Vec<String> = values.iter()
237                        .map(|v| self.format_predicate_value(v))
238                        .collect::<Result<Vec<_>, _>>()?;
239                    Ok(format!("{} IN ({})", predicate.column, value_strings.join(", ")))
240                } else {
241                    Err(ConnectorError::QueryExecutionFailed(
242                        "IN operator requires a list of values".to_string()
243                    ).into())
244                }
245            }
246            _ => {
247                let value_str = self.format_predicate_value(&predicate.value)?;
248                Ok(format!("{} {} {}", predicate.column, operator_sql, value_str))
249            }
250        }
251    }
252    
253    /// Format predicate value for SQL
254    fn format_predicate_value(&self, value: &crate::utils::types::PredicateValue) -> NirvResult<String> {
255        match value {
256            crate::utils::types::PredicateValue::String(s) => Ok(format!("'{}'", s.replace('\'', "''"))),
257            crate::utils::types::PredicateValue::Number(n) => Ok(n.to_string()),
258            crate::utils::types::PredicateValue::Integer(i) => Ok(i.to_string()),
259            crate::utils::types::PredicateValue::Boolean(b) => Ok(b.to_string()),
260            crate::utils::types::PredicateValue::Null => Ok("NULL".to_string()),
261            crate::utils::types::PredicateValue::List(_) => {
262                Err(ConnectorError::QueryExecutionFailed(
263                    "List values should be handled by IN operator".to_string()
264                ).into())
265            }
266        }
267    }
268}
269
270impl Default for PostgresConnector {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276#[async_trait]
277impl Connector for PostgresConnector {
278    async fn connect(&mut self, config: ConnectorInitConfig) -> NirvResult<()> {
279        let host = config.connection_params.get("host")
280            .unwrap_or(&"localhost".to_string()).clone();
281        let port = config.connection_params.get("port")
282            .unwrap_or(&"5432".to_string())
283            .parse::<u16>()
284            .map_err(|e| ConnectorError::ConnectionFailed(format!("Invalid port: {}", e)))?;
285        let user = config.connection_params.get("user")
286            .unwrap_or(&"postgres".to_string()).clone();
287        let password = config.connection_params.get("password")
288            .unwrap_or(&"".to_string()).clone();
289        let dbname = config.connection_params.get("dbname")
290            .unwrap_or(&"postgres".to_string()).clone();
291        
292        let max_size = config.max_connections.unwrap_or(10) as usize;
293        let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(30));
294        
295        // Create deadpool configuration
296        let mut pg_config = Config::new();
297        pg_config.host = Some(host);
298        pg_config.port = Some(port);
299        pg_config.user = Some(user);
300        pg_config.password = Some(password);
301        pg_config.dbname = Some(dbname);
302        pg_config.pool = Some(deadpool_postgres::PoolConfig::new(max_size));
303        
304        // Create connection pool
305        let pool = pg_config.create_pool(Some(Runtime::Tokio1), NoTls)
306            .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to create pool: {}", e)))?;
307        
308        // Test the connection
309        let _client = tokio::time::timeout(timeout, pool.get()).await
310            .map_err(|_| ConnectorError::Timeout("Connection timeout".to_string()))?
311            .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to get connection: {}", e)))?;
312        
313        self.pool = Some(pool);
314        self.connected = true;
315        
316        Ok(())
317    }
318    
319    async fn execute_query(&self, query: ConnectorQuery) -> NirvResult<QueryResult> {
320        if !self.connected {
321            return Err(ConnectorError::ConnectionFailed("Not connected".to_string()).into());
322        }
323        
324        let pool = self.pool.as_ref()
325            .ok_or_else(|| ConnectorError::ConnectionFailed("No connection pool available".to_string()))?;
326        
327        let start_time = Instant::now();
328        
329        // Build SQL query
330        let sql = self.build_sql_query(&query.query)?;
331        
332        // Get connection from pool
333        let client = pool.get().await
334            .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to get connection from pool: {}", e)))?;
335        
336        // Execute query
337        let pg_rows = client.query(&sql, &[]).await
338            .map_err(|e| ConnectorError::QueryExecutionFailed(format!("Query execution failed: {}", e)))?;
339        
340        // Convert results
341        let mut columns = Vec::new();
342        let mut rows = Vec::new();
343        
344        if let Some(first_row) = pg_rows.first() {
345            // Extract column metadata
346            for column in first_row.columns() {
347                columns.push(ColumnMetadata {
348                    name: column.name().to_string(),
349                    data_type: self.pg_type_to_data_type(column.type_().oid()),
350                    nullable: true, // PostgreSQL doesn't provide nullable info in query results
351                });
352            }
353        }
354        
355        // Convert all rows
356        for pg_row in &pg_rows {
357            let row = self.convert_pg_row(pg_row)?;
358            rows.push(row);
359        }
360        
361        let execution_time = start_time.elapsed();
362        
363        Ok(QueryResult {
364            columns,
365            rows,
366            affected_rows: Some(pg_rows.len() as u64),
367            execution_time,
368        })
369    }
370    
371    async fn get_schema(&self, object_name: &str) -> NirvResult<Schema> {
372        if !self.connected {
373            return Err(ConnectorError::ConnectionFailed("Not connected".to_string()).into());
374        }
375        
376        let pool = self.pool.as_ref()
377            .ok_or_else(|| ConnectorError::ConnectionFailed("No connection pool available".to_string()))?;
378        
379        let client = pool.get().await
380            .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to get connection from pool: {}", e)))?;
381        
382        // Parse table name (handle schema.table format)
383        let (schema_name, table_name) = if object_name.contains('.') {
384            let parts: Vec<&str> = object_name.splitn(2, '.').collect();
385            (parts[0].to_string(), parts[1].to_string())
386        } else {
387            ("public".to_string(), object_name.to_string())
388        };
389        
390        // Query column information
391        let column_query = "
392            SELECT 
393                column_name,
394                data_type,
395                is_nullable,
396                udt_name,
397                ordinal_position
398            FROM information_schema.columns 
399            WHERE table_schema = $1 AND table_name = $2
400            ORDER BY ordinal_position
401        ";
402        
403        let column_rows = client.query(column_query, &[&schema_name, &table_name]).await
404            .map_err(|e| ConnectorError::SchemaRetrievalFailed(format!("Failed to retrieve column info: {}", e)))?;
405        
406        if column_rows.is_empty() {
407            return Err(ConnectorError::SchemaRetrievalFailed(
408                format!("Table '{}' not found", object_name)
409            ).into());
410        }
411        
412        let mut columns = Vec::new();
413        for row in &column_rows {
414            let column_name: String = row.get("column_name");
415            let data_type_str: String = row.get("data_type");
416            let is_nullable: String = row.get("is_nullable");
417            
418            let data_type = match data_type_str.as_str() {
419                "character varying" | "text" | "character" => DataType::Text,
420                "integer" | "bigint" | "smallint" => DataType::Integer,
421                "real" | "double precision" | "numeric" => DataType::Float,
422                "boolean" => DataType::Boolean,
423                "date" => DataType::Date,
424                "timestamp without time zone" | "timestamp with time zone" => DataType::DateTime,
425                "json" | "jsonb" => DataType::Json,
426                "bytea" => DataType::Binary,
427                _ => DataType::Text,
428            };
429            
430            columns.push(ColumnMetadata {
431                name: column_name,
432                data_type,
433                nullable: is_nullable == "YES",
434            });
435        }
436        
437        // Query primary key information
438        let pk_query = "
439            SELECT column_name
440            FROM information_schema.key_column_usage
441            WHERE table_schema = $1 AND table_name = $2
442            AND constraint_name IN (
443                SELECT constraint_name
444                FROM information_schema.table_constraints
445                WHERE table_schema = $1 AND table_name = $2
446                AND constraint_type = 'PRIMARY KEY'
447            )
448            ORDER BY ordinal_position
449        ";
450        
451        let pk_rows = client.query(pk_query, &[&schema_name, &table_name]).await
452            .map_err(|e| ConnectorError::SchemaRetrievalFailed(format!("Failed to retrieve primary key info: {}", e)))?;
453        
454        let primary_key = if pk_rows.is_empty() {
455            None
456        } else {
457            Some(pk_rows.iter().map(|row| row.get::<_, String>("column_name")).collect())
458        };
459        
460        // Query index information
461        let index_query = "
462            SELECT 
463                i.indexname,
464                array_agg(a.attname ORDER BY a.attnum) as columns,
465                i.indexdef LIKE '%UNIQUE%' as is_unique
466            FROM pg_indexes i
467            JOIN pg_class c ON c.relname = i.tablename
468            JOIN pg_namespace n ON n.oid = c.relnamespace
469            JOIN pg_index idx ON idx.indexrelid = (
470                SELECT oid FROM pg_class WHERE relname = i.indexname
471            )
472            JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(idx.indkey)
473            WHERE n.nspname = $1 AND i.tablename = $2
474            AND i.indexname NOT LIKE '%_pkey'
475            GROUP BY i.indexname, i.indexdef
476        ";
477        
478        let index_rows = client.query(index_query, &[&schema_name, &table_name]).await
479            .unwrap_or_else(|_| Vec::new()); // Ignore errors for index retrieval
480        
481        let mut indexes = Vec::new();
482        for row in &index_rows {
483            let index_name: String = row.get("indexname");
484            let columns_array: Vec<String> = row.get("columns");
485            let is_unique: bool = row.get("is_unique");
486            
487            indexes.push(Index {
488                name: index_name,
489                columns: columns_array,
490                unique: is_unique,
491            });
492        }
493        
494        Ok(Schema {
495            name: object_name.to_string(),
496            columns,
497            primary_key,
498            indexes,
499        })
500    }
501    
502    async fn disconnect(&mut self) -> NirvResult<()> {
503        self.pool = None;
504        self.connected = false;
505        Ok(())
506    }
507    
508    fn get_connector_type(&self) -> ConnectorType {
509        ConnectorType::PostgreSQL
510    }
511    
512    fn supports_transactions(&self) -> bool {
513        true
514    }
515    
516    fn is_connected(&self) -> bool {
517        self.connected
518    }
519    
520    fn get_capabilities(&self) -> ConnectorCapabilities {
521        ConnectorCapabilities {
522            supports_joins: true,
523            supports_aggregations: true,
524            supports_subqueries: true,
525            supports_transactions: true,
526            supports_schema_introspection: true,
527            max_concurrent_queries: Some(10),
528        }
529    }
530}