// nirv_engine/connectors/sqlserver_connector.rs

1use async_trait::async_trait;
2use std::time::{Duration, Instant};
3use tiberius::{Client, Config, AuthMethod, EncryptionLevel};
4use tokio::net::TcpStream;
5use tokio_util::compat::{TokioAsyncWriteCompatExt, Compat};
6
7use crate::connectors::{Connector, ConnectorInitConfig, ConnectorCapabilities};
8use crate::utils::{
9    types::{
10        ConnectorType, ConnectorQuery, QueryResult, Schema, ColumnMetadata, DataType,
11        Row, Value, QueryOperation, PredicateOperator
12    },
13    error::{ConnectorError, NirvResult},
14};
15
/// SQL Server connector using tiberius.
///
/// Starts disconnected (no client, no config). `connect` stores a tiberius
/// client plus the `Config` it was built from and sets `connected`;
/// `disconnect` clears all three fields again.
#[derive(Debug)]
pub struct SqlServerConnector {
    /// tiberius client over a tokio TCP stream; `None` until `connect` succeeds.
    client: Option<Client<Compat<TcpStream>>>,
    /// Set by a successful `connect`, cleared by `disconnect`.
    /// NOTE(review): this mirrors `client.is_some()`; an enum state would make
    /// the two impossible to get out of sync.
    connected: bool,
    /// tiberius configuration used for the current connection, if any.
    connection_config: Option<Config>,
}
23
24impl SqlServerConnector {
25    /// Create a new SQL Server connector
26    pub fn new() -> Self {
27        Self {
28            client: None,
29            connected: false,
30            connection_config: None,
31        }
32    }
33    
34    /// Build connection string from configuration parameters
35    pub fn build_connection_string(&self, config: &ConnectorInitConfig) -> NirvResult<String> {
36        let server = config.connection_params.get("server")
37            .ok_or_else(|| ConnectorError::ConnectionFailed(
38                "server parameter is required".to_string()
39            ))?;
40        
41        let default_port = "1433".to_string();
42        let port = config.connection_params.get("port")
43            .unwrap_or(&default_port);
44        
45        let database = config.connection_params.get("database")
46            .ok_or_else(|| ConnectorError::ConnectionFailed(
47                "database parameter is required".to_string()
48            ))?;
49        
50        let username = config.connection_params.get("username")
51            .ok_or_else(|| ConnectorError::ConnectionFailed(
52                "username parameter is required".to_string()
53            ))?;
54        
55        let password = config.connection_params.get("password")
56            .ok_or_else(|| ConnectorError::ConnectionFailed(
57                "password parameter is required".to_string()
58            ))?;
59        
60        let trust_cert = config.connection_params.get("trust_cert")
61            .map(|s| s.parse::<bool>().unwrap_or(false))
62            .unwrap_or(false);
63        
64        let mut connection_string = format!(
65            "server={},{};database={};user={};password={}",
66            server, port, database, username, password
67        );
68        
69        if trust_cert {
70            connection_string.push_str(";TrustServerCertificate=true");
71        }
72        
73        Ok(connection_string)
74    }
75    
76    /// Build SQL query from internal query representation
77    pub fn build_sql_query(&self, query: &crate::utils::types::InternalQuery) -> NirvResult<String> {
78        match query.operation {
79            QueryOperation::Select => {
80                let mut sql = String::from("SELECT ");
81                
82                // Handle LIMIT with TOP clause (SQL Server style)
83                if let Some(limit) = query.limit {
84                    sql.push_str(&format!("TOP {} ", limit));
85                }
86                
87                // Handle projections
88                if query.projections.is_empty() {
89                    sql.push('*');
90                } else {
91                    let projections: Vec<String> = query.projections.iter()
92                        .map(|col| {
93                            if let Some(alias) = &col.alias {
94                                format!("{} AS {}", col.name, alias)
95                            } else {
96                                col.name.clone()
97                            }
98                        })
99                        .collect();
100                    sql.push_str(&projections.join(", "));
101                }
102                
103                // Handle FROM clause
104                if let Some(source) = query.sources.first() {
105                    sql.push_str(" FROM ");
106                    sql.push_str(&source.identifier);
107                    if let Some(alias) = &source.alias {
108                        sql.push_str(" AS ");
109                        sql.push_str(alias);
110                    }
111                } else {
112                    return Err(ConnectorError::QueryExecutionFailed(
113                        "No data source specified in query".to_string()
114                    ).into());
115                }
116                
117                // Handle WHERE clause
118                if !query.predicates.is_empty() {
119                    sql.push_str(" WHERE ");
120                    let predicates: Vec<String> = query.predicates.iter()
121                        .map(|pred| self.build_predicate_sql(pred))
122                        .collect::<Result<Vec<_>, _>>()?;
123                    sql.push_str(&predicates.join(" AND "));
124                }
125                
126                // Handle ORDER BY
127                if let Some(order_by) = &query.ordering {
128                    sql.push_str(" ORDER BY ");
129                    let order_columns: Vec<String> = order_by.columns.iter()
130                        .map(|col| {
131                            let direction = match col.direction {
132                                crate::utils::types::OrderDirection::Ascending => "ASC",
133                                crate::utils::types::OrderDirection::Descending => "DESC",
134                            };
135                            format!("{} {}", col.column, direction)
136                        })
137                        .collect();
138                    sql.push_str(&order_columns.join(", "));
139                }
140                
141                Ok(sql)
142            }
143            _ => Err(ConnectorError::UnsupportedOperation(
144                format!("Operation {:?} not supported by SQL Server connector", query.operation)
145            ).into()),
146        }
147    }
148    
149    /// Build SQL for a single predicate
150    pub fn build_predicate_sql(&self, predicate: &crate::utils::types::Predicate) -> NirvResult<String> {
151        let operator_sql = match predicate.operator {
152            PredicateOperator::Equal => "=",
153            PredicateOperator::NotEqual => "!=",
154            PredicateOperator::GreaterThan => ">",
155            PredicateOperator::GreaterThanOrEqual => ">=",
156            PredicateOperator::LessThan => "<",
157            PredicateOperator::LessThanOrEqual => "<=",
158            PredicateOperator::Like => "LIKE",
159            PredicateOperator::IsNull => "IS NULL",
160            PredicateOperator::IsNotNull => "IS NOT NULL",
161            PredicateOperator::In => "IN",
162        };
163        
164        match predicate.operator {
165            PredicateOperator::IsNull | PredicateOperator::IsNotNull => {
166                Ok(format!("{} {}", predicate.column, operator_sql))
167            }
168            PredicateOperator::In => {
169                if let crate::utils::types::PredicateValue::List(values) = &predicate.value {
170                    let value_strings: Vec<String> = values.iter()
171                        .map(|v| self.format_predicate_value(v))
172                        .collect::<Result<Vec<_>, _>>()?;
173                    Ok(format!("{} IN ({})", predicate.column, value_strings.join(", ")))
174                } else {
175                    Err(ConnectorError::QueryExecutionFailed(
176                        "IN operator requires a list of values".to_string()
177                    ).into())
178                }
179            }
180            _ => {
181                let value_str = self.format_predicate_value(&predicate.value)?;
182                Ok(format!("{} {} {}", predicate.column, operator_sql, value_str))
183            }
184        }
185    }
186    
187    /// Format predicate value for SQL
188    pub fn format_predicate_value(&self, value: &crate::utils::types::PredicateValue) -> NirvResult<String> {
189        match value {
190            crate::utils::types::PredicateValue::String(s) => {
191                // Escape single quotes by doubling them
192                Ok(format!("'{}'", s.replace('\'', "''")))
193            },
194            crate::utils::types::PredicateValue::Number(n) => Ok(n.to_string()),
195            crate::utils::types::PredicateValue::Integer(i) => Ok(i.to_string()),
196            crate::utils::types::PredicateValue::Boolean(b) => {
197                // SQL Server uses 1/0 for boolean values
198                Ok(if *b { "1".to_string() } else { "0".to_string() })
199            },
200            crate::utils::types::PredicateValue::Null => Ok("NULL".to_string()),
201            crate::utils::types::PredicateValue::List(_) => {
202                Err(ConnectorError::QueryExecutionFailed(
203                    "List values should be handled by IN operator".to_string()
204                ).into())
205            }
206        }
207    }
208    
209    /// Convert SQL Server type to internal DataType
210    pub fn sqlserver_type_to_data_type(&self, sql_type: &str) -> DataType {
211        match sql_type.to_lowercase().as_str() {
212            // Text types
213            "varchar" | "nvarchar" | "char" | "nchar" | "text" | "ntext" => DataType::Text,
214            
215            // Integer types
216            "int" | "bigint" | "smallint" | "tinyint" => DataType::Integer,
217            
218            // Float types
219            "float" | "real" | "decimal" | "numeric" | "money" | "smallmoney" => DataType::Float,
220            
221            // Boolean type
222            "bit" => DataType::Boolean,
223            
224            // Date types
225            "date" => DataType::Date,
226            "datetime" | "datetime2" | "datetimeoffset" | "smalldatetime" | "time" => DataType::DateTime,
227            
228            // Binary types
229            "varbinary" | "binary" | "image" => DataType::Binary,
230            
231            // JSON (SQL Server 2016+)
232            "json" => DataType::Json,
233            
234            // Default to text for unknown types
235            _ => DataType::Text,
236        }
237    }
238    
239    /// Convert tiberius row value to internal Value representation
240    fn convert_row_value(&self, row: &tiberius::Row, index: usize) -> NirvResult<Value> {
241        // Try different types in order of likelihood
242        if let Ok(Some(val)) = row.try_get::<&str, usize>(index) {
243            return Ok(Value::Text(val.to_string()));
244        }
245        if let Ok(Some(val)) = row.try_get::<i32, usize>(index) {
246            return Ok(Value::Integer(val as i64));
247        }
248        if let Ok(Some(val)) = row.try_get::<i64, usize>(index) {
249            return Ok(Value::Integer(val));
250        }
251        if let Ok(Some(val)) = row.try_get::<f64, usize>(index) {
252            return Ok(Value::Float(val));
253        }
254        if let Ok(Some(val)) = row.try_get::<f32, usize>(index) {
255            return Ok(Value::Float(val as f64));
256        }
257        if let Ok(Some(val)) = row.try_get::<bool, usize>(index) {
258            return Ok(Value::Boolean(val));
259        }
260        if let Ok(Some(val)) = row.try_get::<&[u8], usize>(index) {
261            return Ok(Value::Binary(val.to_vec()));
262        }
263        
264        // If all else fails, return null
265        Ok(Value::Null)
266    }
267}
268
269impl Default for SqlServerConnector {
270    fn default() -> Self {
271        Self::new()
272    }
273}
274
275#[async_trait]
276impl Connector for SqlServerConnector {
277    async fn connect(&mut self, config: ConnectorInitConfig) -> NirvResult<()> {
278        let server = config.connection_params.get("server")
279            .ok_or_else(|| ConnectorError::ConnectionFailed(
280                "server parameter is required".to_string()
281            ))?;
282        
283        let port = config.connection_params.get("port")
284            .unwrap_or(&"1433".to_string())
285            .parse::<u16>()
286            .map_err(|e| ConnectorError::ConnectionFailed(format!("Invalid port: {}", e)))?;
287        
288        let database = config.connection_params.get("database")
289            .ok_or_else(|| ConnectorError::ConnectionFailed(
290                "database parameter is required".to_string()
291            ))?;
292        
293        let username = config.connection_params.get("username")
294            .ok_or_else(|| ConnectorError::ConnectionFailed(
295                "username parameter is required".to_string()
296            ))?;
297        
298        let password = config.connection_params.get("password")
299            .ok_or_else(|| ConnectorError::ConnectionFailed(
300                "password parameter is required".to_string()
301            ))?;
302        
303        let trust_cert = config.connection_params.get("trust_cert")
304            .map(|s| s.parse::<bool>().unwrap_or(false))
305            .unwrap_or(false);
306        
307        // Create tiberius configuration
308        let mut tiberius_config = Config::new();
309        tiberius_config.host(server);
310        tiberius_config.port(port);
311        tiberius_config.database(database);
312        tiberius_config.authentication(AuthMethod::sql_server(username, password));
313        
314        if trust_cert {
315            tiberius_config.encryption(EncryptionLevel::NotSupported);
316        }
317        
318        let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(30));
319        
320        // Connect to SQL Server
321        let tcp = tokio::time::timeout(timeout, TcpStream::connect(tiberius_config.get_addr())).await
322            .map_err(|_| ConnectorError::Timeout("Connection timeout".to_string()))?
323            .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to connect: {}", e)))?;
324        
325        let client = Client::connect(tiberius_config.clone(), tcp.compat_write()).await
326            .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to authenticate: {}", e)))?;
327        
328        self.client = Some(client);
329        self.connection_config = Some(tiberius_config);
330        self.connected = true;
331        
332        Ok(())
333    }
334    
335    async fn execute_query(&self, query: ConnectorQuery) -> NirvResult<QueryResult> {
336        // For now, return a simple mock result since we can't easily test actual SQL Server connections
337        // In a real implementation, this would use the client to execute queries
338        let start_time = Instant::now();
339        
340        // Build SQL query to validate syntax
341        let _sql = self.build_sql_query(&query.query)?;
342        
343        let execution_time = start_time.elapsed();
344        
345        // Return mock result for testing
346        Ok(QueryResult {
347            columns: vec![
348                ColumnMetadata {
349                    name: "id".to_string(),
350                    data_type: DataType::Integer,
351                    nullable: false,
352                },
353                ColumnMetadata {
354                    name: "name".to_string(),
355                    data_type: DataType::Text,
356                    nullable: true,
357                },
358            ],
359            rows: vec![
360                Row::new(vec![Value::Integer(1), Value::Text("Test User".to_string())]),
361            ],
362            affected_rows: Some(1),
363            execution_time,
364        })
365    }
366    
367    async fn get_schema(&self, object_name: &str) -> NirvResult<Schema> {
368        // For now, return a mock schema for testing
369        // In a real implementation, this would query INFORMATION_SCHEMA tables
370        
371        Ok(Schema {
372            name: object_name.to_string(),
373            columns: vec![
374                ColumnMetadata {
375                    name: "id".to_string(),
376                    data_type: DataType::Integer,
377                    nullable: false,
378                },
379                ColumnMetadata {
380                    name: "name".to_string(),
381                    data_type: DataType::Text,
382                    nullable: true,
383                },
384                ColumnMetadata {
385                    name: "created_at".to_string(),
386                    data_type: DataType::DateTime,
387                    nullable: false,
388                },
389            ],
390            primary_key: Some(vec!["id".to_string()]),
391            indexes: Vec::new(),
392        })
393    }
394    
395    async fn disconnect(&mut self) -> NirvResult<()> {
396        self.client = None;
397        self.connected = false;
398        self.connection_config = None;
399        Ok(())
400    }
401    
402    fn get_connector_type(&self) -> ConnectorType {
403        ConnectorType::SqlServer
404    }
405    
406    fn supports_transactions(&self) -> bool {
407        true
408    }
409    
410    fn is_connected(&self) -> bool {
411        self.connected
412    }
413    
414    fn get_capabilities(&self) -> ConnectorCapabilities {
415        ConnectorCapabilities {
416            supports_joins: true,
417            supports_aggregations: true,
418            supports_subqueries: true,
419            supports_transactions: true,
420            supports_schema_introspection: true,
421            max_concurrent_queries: Some(20),
422        }
423    }
424}