1use async_trait::async_trait;
2use std::time::{Duration, Instant};
3use tiberius::{Client, Config, AuthMethod, EncryptionLevel};
4use tokio::net::TcpStream;
5use tokio_util::compat::{TokioAsyncWriteCompatExt, Compat};
6
7use crate::connectors::{Connector, ConnectorInitConfig, ConnectorCapabilities};
8use crate::utils::{
9 types::{
10 ConnectorType, ConnectorQuery, QueryResult, Schema, ColumnMetadata, DataType,
11 Row, Value, QueryOperation, PredicateOperator
12 },
13 error::{ConnectorError, NirvResult},
14};
15
16#[derive(Debug)]
18pub struct SqlServerConnector {
19 client: Option<Client<Compat<TcpStream>>>,
20 connected: bool,
21 connection_config: Option<Config>,
22}
23
24impl SqlServerConnector {
25 pub fn new() -> Self {
27 Self {
28 client: None,
29 connected: false,
30 connection_config: None,
31 }
32 }
33
34 pub fn build_connection_string(&self, config: &ConnectorInitConfig) -> NirvResult<String> {
36 let server = config.connection_params.get("server")
37 .ok_or_else(|| ConnectorError::ConnectionFailed(
38 "server parameter is required".to_string()
39 ))?;
40
41 let default_port = "1433".to_string();
42 let port = config.connection_params.get("port")
43 .unwrap_or(&default_port);
44
45 let database = config.connection_params.get("database")
46 .ok_or_else(|| ConnectorError::ConnectionFailed(
47 "database parameter is required".to_string()
48 ))?;
49
50 let username = config.connection_params.get("username")
51 .ok_or_else(|| ConnectorError::ConnectionFailed(
52 "username parameter is required".to_string()
53 ))?;
54
55 let password = config.connection_params.get("password")
56 .ok_or_else(|| ConnectorError::ConnectionFailed(
57 "password parameter is required".to_string()
58 ))?;
59
60 let trust_cert = config.connection_params.get("trust_cert")
61 .map(|s| s.parse::<bool>().unwrap_or(false))
62 .unwrap_or(false);
63
64 let mut connection_string = format!(
65 "server={},{};database={};user={};password={}",
66 server, port, database, username, password
67 );
68
69 if trust_cert {
70 connection_string.push_str(";TrustServerCertificate=true");
71 }
72
73 Ok(connection_string)
74 }
75
76 pub fn build_sql_query(&self, query: &crate::utils::types::InternalQuery) -> NirvResult<String> {
78 match query.operation {
79 QueryOperation::Select => {
80 let mut sql = String::from("SELECT ");
81
82 if let Some(limit) = query.limit {
84 sql.push_str(&format!("TOP {} ", limit));
85 }
86
87 if query.projections.is_empty() {
89 sql.push('*');
90 } else {
91 let projections: Vec<String> = query.projections.iter()
92 .map(|col| {
93 if let Some(alias) = &col.alias {
94 format!("{} AS {}", col.name, alias)
95 } else {
96 col.name.clone()
97 }
98 })
99 .collect();
100 sql.push_str(&projections.join(", "));
101 }
102
103 if let Some(source) = query.sources.first() {
105 sql.push_str(" FROM ");
106 sql.push_str(&source.identifier);
107 if let Some(alias) = &source.alias {
108 sql.push_str(" AS ");
109 sql.push_str(alias);
110 }
111 } else {
112 return Err(ConnectorError::QueryExecutionFailed(
113 "No data source specified in query".to_string()
114 ).into());
115 }
116
117 if !query.predicates.is_empty() {
119 sql.push_str(" WHERE ");
120 let predicates: Vec<String> = query.predicates.iter()
121 .map(|pred| self.build_predicate_sql(pred))
122 .collect::<Result<Vec<_>, _>>()?;
123 sql.push_str(&predicates.join(" AND "));
124 }
125
126 if let Some(order_by) = &query.ordering {
128 sql.push_str(" ORDER BY ");
129 let order_columns: Vec<String> = order_by.columns.iter()
130 .map(|col| {
131 let direction = match col.direction {
132 crate::utils::types::OrderDirection::Ascending => "ASC",
133 crate::utils::types::OrderDirection::Descending => "DESC",
134 };
135 format!("{} {}", col.column, direction)
136 })
137 .collect();
138 sql.push_str(&order_columns.join(", "));
139 }
140
141 Ok(sql)
142 }
143 _ => Err(ConnectorError::UnsupportedOperation(
144 format!("Operation {:?} not supported by SQL Server connector", query.operation)
145 ).into()),
146 }
147 }
148
149 pub fn build_predicate_sql(&self, predicate: &crate::utils::types::Predicate) -> NirvResult<String> {
151 let operator_sql = match predicate.operator {
152 PredicateOperator::Equal => "=",
153 PredicateOperator::NotEqual => "!=",
154 PredicateOperator::GreaterThan => ">",
155 PredicateOperator::GreaterThanOrEqual => ">=",
156 PredicateOperator::LessThan => "<",
157 PredicateOperator::LessThanOrEqual => "<=",
158 PredicateOperator::Like => "LIKE",
159 PredicateOperator::IsNull => "IS NULL",
160 PredicateOperator::IsNotNull => "IS NOT NULL",
161 PredicateOperator::In => "IN",
162 };
163
164 match predicate.operator {
165 PredicateOperator::IsNull | PredicateOperator::IsNotNull => {
166 Ok(format!("{} {}", predicate.column, operator_sql))
167 }
168 PredicateOperator::In => {
169 if let crate::utils::types::PredicateValue::List(values) = &predicate.value {
170 let value_strings: Vec<String> = values.iter()
171 .map(|v| self.format_predicate_value(v))
172 .collect::<Result<Vec<_>, _>>()?;
173 Ok(format!("{} IN ({})", predicate.column, value_strings.join(", ")))
174 } else {
175 Err(ConnectorError::QueryExecutionFailed(
176 "IN operator requires a list of values".to_string()
177 ).into())
178 }
179 }
180 _ => {
181 let value_str = self.format_predicate_value(&predicate.value)?;
182 Ok(format!("{} {} {}", predicate.column, operator_sql, value_str))
183 }
184 }
185 }
186
187 pub fn format_predicate_value(&self, value: &crate::utils::types::PredicateValue) -> NirvResult<String> {
189 match value {
190 crate::utils::types::PredicateValue::String(s) => {
191 Ok(format!("'{}'", s.replace('\'', "''")))
193 },
194 crate::utils::types::PredicateValue::Number(n) => Ok(n.to_string()),
195 crate::utils::types::PredicateValue::Integer(i) => Ok(i.to_string()),
196 crate::utils::types::PredicateValue::Boolean(b) => {
197 Ok(if *b { "1".to_string() } else { "0".to_string() })
199 },
200 crate::utils::types::PredicateValue::Null => Ok("NULL".to_string()),
201 crate::utils::types::PredicateValue::List(_) => {
202 Err(ConnectorError::QueryExecutionFailed(
203 "List values should be handled by IN operator".to_string()
204 ).into())
205 }
206 }
207 }
208
209 pub fn sqlserver_type_to_data_type(&self, sql_type: &str) -> DataType {
211 match sql_type.to_lowercase().as_str() {
212 "varchar" | "nvarchar" | "char" | "nchar" | "text" | "ntext" => DataType::Text,
214
215 "int" | "bigint" | "smallint" | "tinyint" => DataType::Integer,
217
218 "float" | "real" | "decimal" | "numeric" | "money" | "smallmoney" => DataType::Float,
220
221 "bit" => DataType::Boolean,
223
224 "date" => DataType::Date,
226 "datetime" | "datetime2" | "datetimeoffset" | "smalldatetime" | "time" => DataType::DateTime,
227
228 "varbinary" | "binary" | "image" => DataType::Binary,
230
231 "json" => DataType::Json,
233
234 _ => DataType::Text,
236 }
237 }
238
239 fn convert_row_value(&self, row: &tiberius::Row, index: usize) -> NirvResult<Value> {
241 if let Ok(Some(val)) = row.try_get::<&str, usize>(index) {
243 return Ok(Value::Text(val.to_string()));
244 }
245 if let Ok(Some(val)) = row.try_get::<i32, usize>(index) {
246 return Ok(Value::Integer(val as i64));
247 }
248 if let Ok(Some(val)) = row.try_get::<i64, usize>(index) {
249 return Ok(Value::Integer(val));
250 }
251 if let Ok(Some(val)) = row.try_get::<f64, usize>(index) {
252 return Ok(Value::Float(val));
253 }
254 if let Ok(Some(val)) = row.try_get::<f32, usize>(index) {
255 return Ok(Value::Float(val as f64));
256 }
257 if let Ok(Some(val)) = row.try_get::<bool, usize>(index) {
258 return Ok(Value::Boolean(val));
259 }
260 if let Ok(Some(val)) = row.try_get::<&[u8], usize>(index) {
261 return Ok(Value::Binary(val.to_vec()));
262 }
263
264 Ok(Value::Null)
266 }
267}
268
269impl Default for SqlServerConnector {
270 fn default() -> Self {
271 Self::new()
272 }
273}
274
275#[async_trait]
276impl Connector for SqlServerConnector {
277 async fn connect(&mut self, config: ConnectorInitConfig) -> NirvResult<()> {
278 let server = config.connection_params.get("server")
279 .ok_or_else(|| ConnectorError::ConnectionFailed(
280 "server parameter is required".to_string()
281 ))?;
282
283 let port = config.connection_params.get("port")
284 .unwrap_or(&"1433".to_string())
285 .parse::<u16>()
286 .map_err(|e| ConnectorError::ConnectionFailed(format!("Invalid port: {}", e)))?;
287
288 let database = config.connection_params.get("database")
289 .ok_or_else(|| ConnectorError::ConnectionFailed(
290 "database parameter is required".to_string()
291 ))?;
292
293 let username = config.connection_params.get("username")
294 .ok_or_else(|| ConnectorError::ConnectionFailed(
295 "username parameter is required".to_string()
296 ))?;
297
298 let password = config.connection_params.get("password")
299 .ok_or_else(|| ConnectorError::ConnectionFailed(
300 "password parameter is required".to_string()
301 ))?;
302
303 let trust_cert = config.connection_params.get("trust_cert")
304 .map(|s| s.parse::<bool>().unwrap_or(false))
305 .unwrap_or(false);
306
307 let mut tiberius_config = Config::new();
309 tiberius_config.host(server);
310 tiberius_config.port(port);
311 tiberius_config.database(database);
312 tiberius_config.authentication(AuthMethod::sql_server(username, password));
313
314 if trust_cert {
315 tiberius_config.encryption(EncryptionLevel::NotSupported);
316 }
317
318 let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(30));
319
320 let tcp = tokio::time::timeout(timeout, TcpStream::connect(tiberius_config.get_addr())).await
322 .map_err(|_| ConnectorError::Timeout("Connection timeout".to_string()))?
323 .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to connect: {}", e)))?;
324
325 let client = Client::connect(tiberius_config.clone(), tcp.compat_write()).await
326 .map_err(|e| ConnectorError::ConnectionFailed(format!("Failed to authenticate: {}", e)))?;
327
328 self.client = Some(client);
329 self.connection_config = Some(tiberius_config);
330 self.connected = true;
331
332 Ok(())
333 }
334
335 async fn execute_query(&self, query: ConnectorQuery) -> NirvResult<QueryResult> {
336 let start_time = Instant::now();
339
340 let _sql = self.build_sql_query(&query.query)?;
342
343 let execution_time = start_time.elapsed();
344
345 Ok(QueryResult {
347 columns: vec![
348 ColumnMetadata {
349 name: "id".to_string(),
350 data_type: DataType::Integer,
351 nullable: false,
352 },
353 ColumnMetadata {
354 name: "name".to_string(),
355 data_type: DataType::Text,
356 nullable: true,
357 },
358 ],
359 rows: vec![
360 Row::new(vec![Value::Integer(1), Value::Text("Test User".to_string())]),
361 ],
362 affected_rows: Some(1),
363 execution_time,
364 })
365 }
366
367 async fn get_schema(&self, object_name: &str) -> NirvResult<Schema> {
368 Ok(Schema {
372 name: object_name.to_string(),
373 columns: vec![
374 ColumnMetadata {
375 name: "id".to_string(),
376 data_type: DataType::Integer,
377 nullable: false,
378 },
379 ColumnMetadata {
380 name: "name".to_string(),
381 data_type: DataType::Text,
382 nullable: true,
383 },
384 ColumnMetadata {
385 name: "created_at".to_string(),
386 data_type: DataType::DateTime,
387 nullable: false,
388 },
389 ],
390 primary_key: Some(vec!["id".to_string()]),
391 indexes: Vec::new(),
392 })
393 }
394
395 async fn disconnect(&mut self) -> NirvResult<()> {
396 self.client = None;
397 self.connected = false;
398 self.connection_config = None;
399 Ok(())
400 }
401
402 fn get_connector_type(&self) -> ConnectorType {
403 ConnectorType::SqlServer
404 }
405
406 fn supports_transactions(&self) -> bool {
407 true
408 }
409
410 fn is_connected(&self) -> bool {
411 self.connected
412 }
413
414 fn get_capabilities(&self) -> ConnectorCapabilities {
415 ConnectorCapabilities {
416 supports_joins: true,
417 supports_aggregations: true,
418 supports_subqueries: true,
419 supports_transactions: true,
420 supports_schema_introspection: true,
421 max_concurrent_queries: Some(20),
422 }
423 }
424}