nirv_engine/protocol/sqlserver_protocol.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use tokio::net::TcpStream;
4
5use crate::protocol::{ProtocolAdapter, ProtocolType, Connection, Credentials, ProtocolQuery, ProtocolResponse};
6use crate::utils::{NirvResult, ProtocolError, QueryResult, ColumnMetadata, Row, Value, DataType};
7
/// Tabular Data Stream protocol revision this adapter advertises: TDS 7.4.
///
/// Kept as a plain number; each writer picks the byte order the target
/// structure requires when serializing it.
const TDS_VERSION: u32 = 0x7400_0004;
10
/// TDS packet types: the first byte of every 8-byte TDS packet header.
///
/// The enum is `repr(u8)` so `variant as u8` yields the wire code directly, and
/// `Copy` so values can be passed around freely.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TdsPacketType {
    /// SQL batch sent by the client.
    SqlBatch = 0x01,
    /// Pre-TDS 7 login.
    PreTds7Login = 0x02,
    /// Remote procedure call.
    Rpc = 0x03,
    /// Tabular result (server response).
    TabularResult = 0x04,
    /// Attention (cancel) signal.
    AttentionSignal = 0x06,
    /// Bulk load data.
    BulkLoadData = 0x07,
    /// Federated authentication token.
    FederatedAuthToken = 0x08,
    /// Transaction manager request.
    TransactionManagerRequest = 0x0E,
    /// TDS 7+ login (LOGIN7).
    Tds7Login = 0x10,
    /// SSPI message.
    Sspi = 0x11,
    /// Pre-login handshake.
    PreLogin = 0x12,
}
26
/// TDS token types that appear in server response token streams.
///
/// `repr(u8)` so `variant as u8` is the token byte on the wire; `Copy` because
/// these are plain codes.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TdsTokenType {
    /// Column metadata describing the following rows.
    ColMetadata = 0x81,
    /// One data row.
    Row = 0xD1,
    /// Completion of a statement.
    Done = 0xFD,
    /// Completion of a statement inside a stored procedure.
    DoneInProc = 0xFF,
    /// Completion of a stored procedure.
    DoneProc = 0xFE,
    /// Error message.
    Error = 0xAA,
    /// Informational message.
    Info = 0xAB,
    /// Login acknowledgement.
    LoginAck = 0xAD,
    /// Environment change notification.
    EnvChange = 0xE3,
}
40
/// SQL Server data types as TDS `TYPE_INFO` type codes.
///
/// `repr(u8)` so `variant as u8` is the type byte on the wire; `Copy` because
/// these are plain codes.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TdsDataType {
    /// NULL type.
    Null = 0x1F,
    /// Fixed 1-byte integer (tinyint).
    Int1 = 0x30,
    /// Fixed bit.
    Bit = 0x32,
    /// Fixed 2-byte integer (smallint).
    Int2 = 0x34,
    /// Fixed 4-byte integer (int).
    Int4 = 0x38,
    /// Fixed 4-byte float (real).
    Float4 = 0x3B,
    /// Fixed 8-byte money.
    Money = 0x3C,
    /// Fixed 8-byte datetime.
    DateTime = 0x3D,
    /// Fixed 8-byte float.
    Float8 = 0x3E,
    /// Fixed 8-byte integer (bigint).
    Int8 = 0x7F,
    /// Fixed 4-byte money (smallmoney).
    Money4 = 0x7A,
    /// Nullable integer with a 1-byte length prefix.
    IntN = 0x26,
    /// Nullable bit with a 1-byte length prefix.
    BitN = 0x68,
    /// Nullable float with a 1-byte length prefix.
    FloatN = 0x6D,
    /// Nullable datetime with a 1-byte length prefix.
    DatetimeN = 0x6F,
    /// Variable-length binary.
    VarBinary = 0xA5,
    /// Variable-length single-byte character data.
    VarChar = 0xA7,
    /// Fixed-length binary.
    Binary = 0xAD,
    /// Variable-length UTF-16 character data.
    NVarChar = 0xE7,
}
64
/// SQL Server (TDS) protocol adapter.
///
/// The adapter currently holds no state, so it is cheap to clone and share.
/// Configuration and state can be added here later.
#[derive(Debug, Clone)]
pub struct SqlServerProtocol {
    // Configuration and state can be added here
}
70
71impl SqlServerProtocol {
72    /// Create a new SQL Server protocol adapter
73    pub fn new() -> Self {
74        Self {}
75    }
76    
77    /// Parse a TDS login packet
78    pub fn parse_login_packet(&self, data: &[u8]) -> NirvResult<HashMap<String, String>> {
79        if data.len() < 8 {
80            return Err(ProtocolError::InvalidMessageFormat("TDS packet too short".to_string()).into());
81        }
82        
83        // Parse TDS header
84        let packet_type = data[0];
85        let _status = data[1];
86        let length = u16::from_be_bytes([data[2], data[3]]) as usize;
87        
88        if packet_type != TdsPacketType::Tds7Login as u8 {
89            return Err(ProtocolError::InvalidMessageFormat(
90                format!("Expected login packet, got type {}", packet_type)
91            ).into());
92        }
93        
94        if data.len() < length {
95            return Err(ProtocolError::InvalidMessageFormat("Incomplete TDS packet".to_string()).into());
96        }
97        
98        // For simplicity, return a mock parsed login
99        let mut params = HashMap::new();
100        params.insert("username".to_string(), "sa".to_string());
101        params.insert("database".to_string(), "master".to_string());
102        params.insert("application".to_string(), "NIRV Engine".to_string());
103        
104        Ok(params)
105    }
106    
107    /// Parse a SQL batch packet
108    pub fn parse_sql_batch(&self, data: &[u8]) -> NirvResult<String> {
109        if data.is_empty() {
110            return Err(ProtocolError::InvalidMessageFormat("Empty SQL batch".to_string()).into());
111        }
112        
113        // SQL Server sends SQL text as UTF-16LE
114        if data.len() % 2 != 0 {
115            return Err(ProtocolError::InvalidMessageFormat("Invalid UTF-16 data length".to_string()).into());
116        }
117        
118        let mut utf16_chars = Vec::new();
119        for chunk in data.chunks_exact(2) {
120            let char_code = u16::from_le_bytes([chunk[0], chunk[1]]);
121            utf16_chars.push(char_code);
122        }
123        
124        String::from_utf16(&utf16_chars)
125            .map_err(|e| ProtocolError::InvalidMessageFormat(format!("Invalid UTF-16: {}", e)).into())
126    }
127    
128    /// Create a TDS header
129    fn create_tds_header(&self, packet_type: TdsPacketType, length: u16) -> Vec<u8> {
130        let mut header = Vec::with_capacity(8);
131        header.push(packet_type as u8);
132        header.push(0x01); // Status: End of message
133        header.extend_from_slice(&length.to_be_bytes());
134        header.extend_from_slice(&0u16.to_be_bytes()); // SPID
135        header.push(0x01); // Packet ID
136        header.push(0x00); // Window
137        header
138    }
139    
140    /// Create a login acknowledgment response
141    fn create_login_ack(&self) -> Vec<u8> {
142        let mut response = Vec::new();
143        
144        // LoginAck token
145        response.push(TdsTokenType::LoginAck as u8);
146        
147        // Token length (placeholder, will be updated)
148        let length_pos = response.len();
149        response.extend_from_slice(&0u16.to_le_bytes());
150        
151        // Interface (1 byte) - SQL Server
152        response.push(0x01);
153        
154        // TDS version (4 bytes)
155        response.extend_from_slice(&TDS_VERSION.to_le_bytes());
156        
157        // Program name (variable length)
158        let program_name = "Microsoft SQL Server";
159        response.push(program_name.len() as u8);
160        response.extend_from_slice(program_name.as_bytes());
161        
162        // Program version (4 bytes)
163        response.extend_from_slice(&0x10000000u32.to_le_bytes());
164        
165        // Update token length
166        let token_length = (response.len() - length_pos - 2) as u16;
167        response[length_pos..length_pos + 2].copy_from_slice(&token_length.to_le_bytes());
168        
169        response
170    }
171    
172    /// Create an environment change token
173    fn create_env_change(&self, change_type: u8, new_value: &str, old_value: &str) -> Vec<u8> {
174        let mut token = Vec::new();
175        
176        // EnvChange token
177        token.push(TdsTokenType::EnvChange as u8);
178        
179        // Token length (placeholder)
180        let length_pos = token.len();
181        token.extend_from_slice(&0u16.to_le_bytes());
182        
183        // Change type
184        token.push(change_type);
185        
186        // New value
187        token.push(new_value.len() as u8);
188        token.extend_from_slice(new_value.as_bytes());
189        
190        // Old value
191        token.push(old_value.len() as u8);
192        token.extend_from_slice(old_value.as_bytes());
193        
194        // Update token length
195        let token_length = (token.len() - length_pos - 2) as u16;
196        token[length_pos..length_pos + 2].copy_from_slice(&token_length.to_le_bytes());
197        
198        token
199    }
200    
201    /// Create column metadata token
202    pub fn create_colmetadata(&self, columns: &[ColumnMetadata]) -> Vec<u8> {
203        let mut token = Vec::new();
204        
205        // ColMetadata token
206        token.push(TdsTokenType::ColMetadata as u8);
207        
208        // Column count
209        token.extend_from_slice(&(columns.len() as u16).to_le_bytes());
210        
211        for column in columns {
212            // Column metadata
213            let tds_type = self.datatype_to_tds_type(&column.data_type);
214            token.push(tds_type);
215            
216            // Type-specific metadata
217            match column.data_type {
218                DataType::Text => {
219                    token.extend_from_slice(&0xFFFFu16.to_le_bytes()); // Max length
220                    token.extend_from_slice(&0u32.to_le_bytes()); // Collation
221                    token.push(0); // Collation flags
222                }
223                DataType::Integer => {
224                    token.push(4); // Length
225                }
226                DataType::Float => {
227                    token.push(8); // Length
228                }
229                DataType::Boolean => {
230                    token.push(1); // Length
231                }
232                _ => {
233                    token.push(0); // Default length
234                }
235            }
236            
237            // Column name
238            let name_utf16: Vec<u16> = column.name.encode_utf16().collect();
239            token.push(name_utf16.len() as u8);
240            for ch in name_utf16 {
241                token.extend_from_slice(&ch.to_le_bytes());
242            }
243        }
244        
245        token
246    }
247    
248    /// Create a data row token
249    pub fn create_row(&self, row: &Row, columns: &[ColumnMetadata]) -> Vec<u8> {
250        let mut token = Vec::new();
251        
252        // Row token
253        token.push(TdsTokenType::Row as u8);
254        
255        for (i, value) in row.values.iter().enumerate() {
256            let _column_type = if i < columns.len() {
257                &columns[i].data_type
258            } else {
259                &DataType::Text
260            };
261            
262            match value {
263                Value::Null => {
264                    token.push(0); // NULL indicator
265                }
266                Value::Integer(val) => {
267                    token.push(4); // Length
268                    token.extend_from_slice(&(*val as i32).to_le_bytes());
269                }
270                Value::Float(val) => {
271                    token.push(8); // Length
272                    token.extend_from_slice(&val.to_le_bytes());
273                }
274                Value::Boolean(val) => {
275                    token.push(1); // Length
276                    token.push(if *val { 1 } else { 0 });
277                }
278                Value::Text(val) => {
279                    let utf16: Vec<u16> = val.encode_utf16().collect();
280                    let byte_len = utf16.len() * 2;
281                    token.extend_from_slice(&(byte_len as u16).to_le_bytes());
282                    for ch in utf16 {
283                        token.extend_from_slice(&ch.to_le_bytes());
284                    }
285                }
286                _ => {
287                    // Convert other types to string
288                    let str_val = format!("{:?}", value);
289                    let utf16: Vec<u16> = str_val.encode_utf16().collect();
290                    let byte_len = utf16.len() * 2;
291                    token.extend_from_slice(&(byte_len as u16).to_le_bytes());
292                    for ch in utf16 {
293                        token.extend_from_slice(&ch.to_le_bytes());
294                    }
295                }
296            }
297        }
298        
299        token
300    }
301    
302    /// Create a DONE token
303    pub fn create_done(&self, status: u16, cur_cmd: u16, row_count: u64) -> Vec<u8> {
304        let mut token = Vec::new();
305        
306        // Done token
307        token.push(TdsTokenType::Done as u8);
308        
309        // Status
310        token.extend_from_slice(&status.to_le_bytes());
311        
312        // Current command
313        token.extend_from_slice(&cur_cmd.to_le_bytes());
314        
315        // Row count
316        token.extend_from_slice(&row_count.to_le_bytes());
317        
318        token
319    }
320    
321    /// Create an error response
322    pub fn create_error_response(&self, error_number: u32, message: &str, severity: u8) -> Vec<u8> {
323        let mut response = Vec::new();
324        
325        // TDS header
326        let header = self.create_tds_header(TdsPacketType::TabularResult, 0);
327        response.extend_from_slice(&header);
328        
329        // Error token
330        response.push(TdsTokenType::Error as u8);
331        
332        // Token length (placeholder)
333        let length_pos = response.len();
334        response.extend_from_slice(&0u16.to_le_bytes());
335        
336        // Error number
337        response.extend_from_slice(&error_number.to_le_bytes());
338        
339        // State
340        response.push(1);
341        
342        // Severity
343        response.push(severity);
344        
345        // Message length and text
346        response.extend_from_slice(&(message.len() as u16).to_le_bytes());
347        response.extend_from_slice(message.as_bytes());
348        
349        // Server name (empty)
350        response.push(0);
351        
352        // Procedure name (empty)
353        response.push(0);
354        
355        // Line number
356        response.extend_from_slice(&0u32.to_le_bytes());
357        
358        // Update token length
359        let token_length = (response.len() - length_pos - 2) as u16;
360        response[length_pos..length_pos + 2].copy_from_slice(&token_length.to_le_bytes());
361        
362        // Update TDS header length
363        let total_length = response.len() as u16;
364        response[2..4].copy_from_slice(&total_length.to_be_bytes());
365        
366        response
367    }
368    
369    /// Convert internal DataType to TDS type code
370    fn datatype_to_tds_type(&self, data_type: &DataType) -> u8 {
371        match data_type {
372            DataType::Text => TdsDataType::NVarChar as u8,
373            DataType::Integer => TdsDataType::IntN as u8,
374            DataType::Float => TdsDataType::FloatN as u8,
375            DataType::Boolean => TdsDataType::BitN as u8,
376            DataType::Date => TdsDataType::DatetimeN as u8,
377            DataType::DateTime => TdsDataType::DatetimeN as u8,
378            DataType::Binary => TdsDataType::VarBinary as u8,
379            DataType::Json => TdsDataType::NVarChar as u8,
380        }
381    }
382    
383    /// Convert Value to TDS type code
384    pub fn value_to_tds_type(&self, value: &Value) -> u8 {
385        match value {
386            Value::Null => TdsDataType::Null as u8,
387            Value::Integer(_) => TdsDataType::IntN as u8,
388            Value::Float(_) => TdsDataType::FloatN as u8,
389            Value::Boolean(_) => TdsDataType::BitN as u8,
390            Value::Text(_) => TdsDataType::NVarChar as u8,
391            Value::Date(_) => TdsDataType::DatetimeN as u8,
392            Value::DateTime(_) => TdsDataType::DatetimeN as u8,
393            Value::Binary(_) => TdsDataType::VarBinary as u8,
394            Value::Json(_) => TdsDataType::NVarChar as u8,
395        }
396    }
397    
398
399}
400
401impl Default for SqlServerProtocol {
402    fn default() -> Self {
403        Self::new()
404    }
405}
406
407#[async_trait]
408impl ProtocolAdapter for SqlServerProtocol {
409    async fn accept_connection(&self, stream: TcpStream) -> NirvResult<Connection> {
410        let connection = Connection::new(stream, ProtocolType::SqlServer);
411        
412        // SQL Server connection setup would happen here
413        // For now, just return the connection
414        
415        Ok(connection)
416    }
417    
418    async fn authenticate(&self, conn: &mut Connection, credentials: Credentials) -> NirvResult<()> {
419        // In a real implementation, this would validate credentials
420        // For testing, we'll just mark as authenticated
421        
422        conn.authenticated = true;
423        conn.database = credentials.database;
424        conn.parameters.insert("username".to_string(), credentials.username);
425        
426        if let Some(password) = credentials.password {
427            conn.parameters.insert("password".to_string(), password);
428        }
429        
430        // Merge additional parameters
431        for (key, value) in credentials.parameters {
432            conn.parameters.insert(key, value);
433        }
434        
435        // Send login acknowledgment
436        let login_ack = self.create_login_ack();
437        let env_change = self.create_env_change(1, &conn.database, "");
438        
439        let mut response = Vec::new();
440        let header = self.create_tds_header(
441            TdsPacketType::TabularResult, 
442            (login_ack.len() + env_change.len()) as u16 + 8
443        );
444        response.extend_from_slice(&header);
445        response.extend_from_slice(&login_ack);
446        response.extend_from_slice(&env_change);
447        
448        // In a real implementation, we would write this to the stream
449        // conn.stream.write_all(&response).await?;
450        
451        Ok(())
452    }
453    
454    async fn handle_query(&self, conn: &Connection, _query: ProtocolQuery) -> NirvResult<ProtocolResponse> {
455        if !conn.authenticated {
456            return Err(ProtocolError::AuthenticationFailed("Connection not authenticated".to_string()).into());
457        }
458        
459        // For testing, return a mock result
460        let mock_result = QueryResult {
461            columns: vec![
462                ColumnMetadata {
463                    name: "id".to_string(),
464                    data_type: DataType::Integer,
465                    nullable: false,
466                },
467                ColumnMetadata {
468                    name: "name".to_string(),
469                    data_type: DataType::Text,
470                    nullable: true,
471                },
472            ],
473            rows: vec![
474                Row::new(vec![Value::Integer(1), Value::Text("Test User".to_string())]),
475            ],
476            affected_rows: Some(1),
477            execution_time: std::time::Duration::from_millis(5),
478        };
479        
480        Ok(ProtocolResponse::new(mock_result, ProtocolType::SqlServer))
481    }
482    
483    fn get_protocol_type(&self) -> ProtocolType {
484        ProtocolType::SqlServer
485    }
486    
487    async fn parse_message(&self, _conn: &Connection, data: &[u8]) -> NirvResult<ProtocolQuery> {
488        if data.len() < 8 {
489            return Err(ProtocolError::InvalidMessageFormat("TDS packet too short".to_string()).into());
490        }
491        
492        let packet_type = data[0];
493        
494        match packet_type {
495            x if x == TdsPacketType::SqlBatch as u8 => {
496                let sql_text = self.parse_sql_batch(&data[8..])?;
497                Ok(ProtocolQuery::new(sql_text, ProtocolType::SqlServer))
498            }
499            x if x == TdsPacketType::Tds7Login as u8 => {
500                // Return a dummy query for login packets
501                Ok(ProtocolQuery::new("LOGIN".to_string(), ProtocolType::SqlServer))
502            }
503            _ => {
504                Err(ProtocolError::UnsupportedFeature(
505                    format!("Unsupported TDS packet type: {}", packet_type)
506                ).into())
507            }
508        }
509    }
510    
511    async fn format_response(&self, _conn: &Connection, result: QueryResult) -> NirvResult<Vec<u8>> {
512        let mut response = Vec::new();
513        
514        // Create column metadata
515        let colmetadata = self.create_colmetadata(&result.columns);
516        
517        // Create data rows
518        let mut rows_data = Vec::new();
519        for row in &result.rows {
520            let row_data = self.create_row(row, &result.columns);
521            rows_data.extend_from_slice(&row_data);
522        }
523        
524        // Create DONE token
525        let done = self.create_done(0x0010, 0xC1, result.rows.len() as u64); // DONE_COUNT
526        
527        // Combine all tokens
528        let mut tokens = Vec::new();
529        tokens.extend_from_slice(&colmetadata);
530        tokens.extend_from_slice(&rows_data);
531        tokens.extend_from_slice(&done);
532        
533        // Create TDS header
534        let header = self.create_tds_header(TdsPacketType::TabularResult, (tokens.len() + 8) as u16);
535        
536        response.extend_from_slice(&header);
537        response.extend_from_slice(&tokens);
538        
539        Ok(response)
540    }
541    
542    async fn terminate_connection(&self, conn: &mut Connection) -> NirvResult<()> {
543        conn.authenticated = false;
544        conn.database.clear();
545        conn.parameters.clear();
546        
547        // In a real implementation, we would close the stream gracefully
548        // conn.stream.shutdown().await?;
549        
550        Ok(())
551    }
552}