1use async_trait::async_trait;
2use std::collections::HashMap;
3use tokio::net::TcpStream;
4
5use crate::protocol::{ProtocolAdapter, ProtocolType, Connection, Credentials, ProtocolQuery, ProtocolResponse};
6use crate::utils::{NirvResult, ProtocolError, QueryResult, ColumnMetadata, Row, Value, DataType};
7
/// TDS protocol version advertised by this adapter (`0x74000004`, which MS-TDS lists as TDS 7.4).
const TDS_VERSION: u32 = 0x74000004;

/// Packet type byte stored at offset 0 of a TDS packet header.
///
/// The enum is `#[repr(u8)]` and `Copy`, so a value can be turned into its
/// wire byte with `as u8` (or compared) repeatedly without being consumed.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TdsPacketType {
    /// SQL batch request.
    SqlBatch = 0x01,
    /// Login request in the pre-TDS7 format.
    PreTds7Login = 0x02,
    /// Remote procedure call request.
    Rpc = 0x03,
    /// Tabular result sent by the server.
    TabularResult = 0x04,
    /// Attention (cancel) signal.
    AttentionSignal = 0x06,
    /// Bulk load data.
    BulkLoadData = 0x07,
    /// Federated authentication token.
    FederatedAuthToken = 0x08,
    /// Transaction manager request.
    TransactionManagerRequest = 0x0E,
    /// TDS7 login request.
    Tds7Login = 0x10,
    /// SSPI message.
    Sspi = 0x11,
    /// Pre-login message.
    PreLogin = 0x12,
}
26
/// Token type byte that introduces each token in a TDS response stream.
///
/// The enum is `#[repr(u8)]` and `Copy`, so a value can be turned into its
/// wire byte with `as u8` repeatedly without being consumed.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TdsTokenType {
    /// Column metadata for the rows that follow.
    ColMetadata = 0x81,
    /// One row of data.
    Row = 0xD1,
    /// Completion of a statement.
    Done = 0xFD,
    /// Completion of a statement inside a stored procedure.
    DoneInProc = 0xFF,
    /// Completion of a stored procedure.
    DoneProc = 0xFE,
    /// Error message.
    Error = 0xAA,
    /// Informational message.
    Info = 0xAB,
    /// Login acknowledgement.
    LoginAck = 0xAD,
    /// Environment change notification.
    EnvChange = 0xE3,
}
40
/// Data type byte used in TDS column metadata.
///
/// The enum is `#[repr(u8)]` and `Copy`, so a value can be turned into its
/// wire byte with `as u8` repeatedly without being consumed.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TdsDataType {
    /// Null type.
    Null = 0x1F,
    /// 1-byte integer.
    Int1 = 0x30,
    /// Bit.
    Bit = 0x32,
    /// 2-byte integer.
    Int2 = 0x34,
    /// 4-byte integer.
    Int4 = 0x38,
    /// Nullable date/time with a length prefix.
    DatetimeN = 0x6F,
    /// 8-byte float.
    Float8 = 0x3E,
    /// 8-byte money.
    Money = 0x3C,
    /// 8-byte date/time.
    DateTime = 0x3D,
    /// 4-byte float.
    Float4 = 0x3B,
    /// 4-byte money.
    Money4 = 0x7A,
    /// 8-byte integer.
    Int8 = 0x7F,
    /// Nullable bit with a length prefix.
    BitN = 0x68,
    /// Nullable integer with a length prefix.
    IntN = 0x26,
    /// Nullable float with a length prefix.
    FloatN = 0x6D,
    /// Variable-length Unicode string.
    NVarChar = 0xE7,
    /// Variable-length single-byte string.
    VarChar = 0xA7,
    /// Fixed-length binary.
    Binary = 0xAD,
    /// Variable-length binary.
    VarBinary = 0xA5,
}
64
/// Adapter that speaks the SQL Server (TDS) wire protocol.
///
/// The struct has no fields; all behaviour lives in its methods.
#[derive(Debug)]
pub struct SqlServerProtocol {}
70
71impl SqlServerProtocol {
72 pub fn new() -> Self {
74 Self {}
75 }
76
77 pub fn parse_login_packet(&self, data: &[u8]) -> NirvResult<HashMap<String, String>> {
79 if data.len() < 8 {
80 return Err(ProtocolError::InvalidMessageFormat("TDS packet too short".to_string()).into());
81 }
82
83 let packet_type = data[0];
85 let _status = data[1];
86 let length = u16::from_be_bytes([data[2], data[3]]) as usize;
87
88 if packet_type != TdsPacketType::Tds7Login as u8 {
89 return Err(ProtocolError::InvalidMessageFormat(
90 format!("Expected login packet, got type {}", packet_type)
91 ).into());
92 }
93
94 if data.len() < length {
95 return Err(ProtocolError::InvalidMessageFormat("Incomplete TDS packet".to_string()).into());
96 }
97
98 let mut params = HashMap::new();
100 params.insert("username".to_string(), "sa".to_string());
101 params.insert("database".to_string(), "master".to_string());
102 params.insert("application".to_string(), "NIRV Engine".to_string());
103
104 Ok(params)
105 }
106
107 pub fn parse_sql_batch(&self, data: &[u8]) -> NirvResult<String> {
109 if data.is_empty() {
110 return Err(ProtocolError::InvalidMessageFormat("Empty SQL batch".to_string()).into());
111 }
112
113 if data.len() % 2 != 0 {
115 return Err(ProtocolError::InvalidMessageFormat("Invalid UTF-16 data length".to_string()).into());
116 }
117
118 let mut utf16_chars = Vec::new();
119 for chunk in data.chunks_exact(2) {
120 let char_code = u16::from_le_bytes([chunk[0], chunk[1]]);
121 utf16_chars.push(char_code);
122 }
123
124 String::from_utf16(&utf16_chars)
125 .map_err(|e| ProtocolError::InvalidMessageFormat(format!("Invalid UTF-16: {}", e)).into())
126 }
127
128 fn create_tds_header(&self, packet_type: TdsPacketType, length: u16) -> Vec<u8> {
130 let mut header = Vec::with_capacity(8);
131 header.push(packet_type as u8);
132 header.push(0x01); header.extend_from_slice(&length.to_be_bytes());
134 header.extend_from_slice(&0u16.to_be_bytes()); header.push(0x01); header.push(0x00); header
138 }
139
140 fn create_login_ack(&self) -> Vec<u8> {
142 let mut response = Vec::new();
143
144 response.push(TdsTokenType::LoginAck as u8);
146
147 let length_pos = response.len();
149 response.extend_from_slice(&0u16.to_le_bytes());
150
151 response.push(0x01);
153
154 response.extend_from_slice(&TDS_VERSION.to_le_bytes());
156
157 let program_name = "Microsoft SQL Server";
159 response.push(program_name.len() as u8);
160 response.extend_from_slice(program_name.as_bytes());
161
162 response.extend_from_slice(&0x10000000u32.to_le_bytes());
164
165 let token_length = (response.len() - length_pos - 2) as u16;
167 response[length_pos..length_pos + 2].copy_from_slice(&token_length.to_le_bytes());
168
169 response
170 }
171
172 fn create_env_change(&self, change_type: u8, new_value: &str, old_value: &str) -> Vec<u8> {
174 let mut token = Vec::new();
175
176 token.push(TdsTokenType::EnvChange as u8);
178
179 let length_pos = token.len();
181 token.extend_from_slice(&0u16.to_le_bytes());
182
183 token.push(change_type);
185
186 token.push(new_value.len() as u8);
188 token.extend_from_slice(new_value.as_bytes());
189
190 token.push(old_value.len() as u8);
192 token.extend_from_slice(old_value.as_bytes());
193
194 let token_length = (token.len() - length_pos - 2) as u16;
196 token[length_pos..length_pos + 2].copy_from_slice(&token_length.to_le_bytes());
197
198 token
199 }
200
201 pub fn create_colmetadata(&self, columns: &[ColumnMetadata]) -> Vec<u8> {
203 let mut token = Vec::new();
204
205 token.push(TdsTokenType::ColMetadata as u8);
207
208 token.extend_from_slice(&(columns.len() as u16).to_le_bytes());
210
211 for column in columns {
212 let tds_type = self.datatype_to_tds_type(&column.data_type);
214 token.push(tds_type);
215
216 match column.data_type {
218 DataType::Text => {
219 token.extend_from_slice(&0xFFFFu16.to_le_bytes()); token.extend_from_slice(&0u32.to_le_bytes()); token.push(0); }
223 DataType::Integer => {
224 token.push(4); }
226 DataType::Float => {
227 token.push(8); }
229 DataType::Boolean => {
230 token.push(1); }
232 _ => {
233 token.push(0); }
235 }
236
237 let name_utf16: Vec<u16> = column.name.encode_utf16().collect();
239 token.push(name_utf16.len() as u8);
240 for ch in name_utf16 {
241 token.extend_from_slice(&ch.to_le_bytes());
242 }
243 }
244
245 token
246 }
247
248 pub fn create_row(&self, row: &Row, columns: &[ColumnMetadata]) -> Vec<u8> {
250 let mut token = Vec::new();
251
252 token.push(TdsTokenType::Row as u8);
254
255 for (i, value) in row.values.iter().enumerate() {
256 let _column_type = if i < columns.len() {
257 &columns[i].data_type
258 } else {
259 &DataType::Text
260 };
261
262 match value {
263 Value::Null => {
264 token.push(0); }
266 Value::Integer(val) => {
267 token.push(4); token.extend_from_slice(&(*val as i32).to_le_bytes());
269 }
270 Value::Float(val) => {
271 token.push(8); token.extend_from_slice(&val.to_le_bytes());
273 }
274 Value::Boolean(val) => {
275 token.push(1); token.push(if *val { 1 } else { 0 });
277 }
278 Value::Text(val) => {
279 let utf16: Vec<u16> = val.encode_utf16().collect();
280 let byte_len = utf16.len() * 2;
281 token.extend_from_slice(&(byte_len as u16).to_le_bytes());
282 for ch in utf16 {
283 token.extend_from_slice(&ch.to_le_bytes());
284 }
285 }
286 _ => {
287 let str_val = format!("{:?}", value);
289 let utf16: Vec<u16> = str_val.encode_utf16().collect();
290 let byte_len = utf16.len() * 2;
291 token.extend_from_slice(&(byte_len as u16).to_le_bytes());
292 for ch in utf16 {
293 token.extend_from_slice(&ch.to_le_bytes());
294 }
295 }
296 }
297 }
298
299 token
300 }
301
302 pub fn create_done(&self, status: u16, cur_cmd: u16, row_count: u64) -> Vec<u8> {
304 let mut token = Vec::new();
305
306 token.push(TdsTokenType::Done as u8);
308
309 token.extend_from_slice(&status.to_le_bytes());
311
312 token.extend_from_slice(&cur_cmd.to_le_bytes());
314
315 token.extend_from_slice(&row_count.to_le_bytes());
317
318 token
319 }
320
321 pub fn create_error_response(&self, error_number: u32, message: &str, severity: u8) -> Vec<u8> {
323 let mut response = Vec::new();
324
325 let header = self.create_tds_header(TdsPacketType::TabularResult, 0);
327 response.extend_from_slice(&header);
328
329 response.push(TdsTokenType::Error as u8);
331
332 let length_pos = response.len();
334 response.extend_from_slice(&0u16.to_le_bytes());
335
336 response.extend_from_slice(&error_number.to_le_bytes());
338
339 response.push(1);
341
342 response.push(severity);
344
345 response.extend_from_slice(&(message.len() as u16).to_le_bytes());
347 response.extend_from_slice(message.as_bytes());
348
349 response.push(0);
351
352 response.push(0);
354
355 response.extend_from_slice(&0u32.to_le_bytes());
357
358 let token_length = (response.len() - length_pos - 2) as u16;
360 response[length_pos..length_pos + 2].copy_from_slice(&token_length.to_le_bytes());
361
362 let total_length = response.len() as u16;
364 response[2..4].copy_from_slice(&total_length.to_be_bytes());
365
366 response
367 }
368
369 fn datatype_to_tds_type(&self, data_type: &DataType) -> u8 {
371 match data_type {
372 DataType::Text => TdsDataType::NVarChar as u8,
373 DataType::Integer => TdsDataType::IntN as u8,
374 DataType::Float => TdsDataType::FloatN as u8,
375 DataType::Boolean => TdsDataType::BitN as u8,
376 DataType::Date => TdsDataType::DatetimeN as u8,
377 DataType::DateTime => TdsDataType::DatetimeN as u8,
378 DataType::Binary => TdsDataType::VarBinary as u8,
379 DataType::Json => TdsDataType::NVarChar as u8,
380 }
381 }
382
383 pub fn value_to_tds_type(&self, value: &Value) -> u8 {
385 match value {
386 Value::Null => TdsDataType::Null as u8,
387 Value::Integer(_) => TdsDataType::IntN as u8,
388 Value::Float(_) => TdsDataType::FloatN as u8,
389 Value::Boolean(_) => TdsDataType::BitN as u8,
390 Value::Text(_) => TdsDataType::NVarChar as u8,
391 Value::Date(_) => TdsDataType::DatetimeN as u8,
392 Value::DateTime(_) => TdsDataType::DatetimeN as u8,
393 Value::Binary(_) => TdsDataType::VarBinary as u8,
394 Value::Json(_) => TdsDataType::NVarChar as u8,
395 }
396 }
397
398
399}
400
401impl Default for SqlServerProtocol {
402 fn default() -> Self {
403 Self::new()
404 }
405}
406
407#[async_trait]
408impl ProtocolAdapter for SqlServerProtocol {
409 async fn accept_connection(&self, stream: TcpStream) -> NirvResult<Connection> {
410 let connection = Connection::new(stream, ProtocolType::SqlServer);
411
412 Ok(connection)
416 }
417
418 async fn authenticate(&self, conn: &mut Connection, credentials: Credentials) -> NirvResult<()> {
419 conn.authenticated = true;
423 conn.database = credentials.database;
424 conn.parameters.insert("username".to_string(), credentials.username);
425
426 if let Some(password) = credentials.password {
427 conn.parameters.insert("password".to_string(), password);
428 }
429
430 for (key, value) in credentials.parameters {
432 conn.parameters.insert(key, value);
433 }
434
435 let login_ack = self.create_login_ack();
437 let env_change = self.create_env_change(1, &conn.database, "");
438
439 let mut response = Vec::new();
440 let header = self.create_tds_header(
441 TdsPacketType::TabularResult,
442 (login_ack.len() + env_change.len()) as u16 + 8
443 );
444 response.extend_from_slice(&header);
445 response.extend_from_slice(&login_ack);
446 response.extend_from_slice(&env_change);
447
448 Ok(())
452 }
453
454 async fn handle_query(&self, conn: &Connection, _query: ProtocolQuery) -> NirvResult<ProtocolResponse> {
455 if !conn.authenticated {
456 return Err(ProtocolError::AuthenticationFailed("Connection not authenticated".to_string()).into());
457 }
458
459 let mock_result = QueryResult {
461 columns: vec![
462 ColumnMetadata {
463 name: "id".to_string(),
464 data_type: DataType::Integer,
465 nullable: false,
466 },
467 ColumnMetadata {
468 name: "name".to_string(),
469 data_type: DataType::Text,
470 nullable: true,
471 },
472 ],
473 rows: vec![
474 Row::new(vec![Value::Integer(1), Value::Text("Test User".to_string())]),
475 ],
476 affected_rows: Some(1),
477 execution_time: std::time::Duration::from_millis(5),
478 };
479
480 Ok(ProtocolResponse::new(mock_result, ProtocolType::SqlServer))
481 }
482
483 fn get_protocol_type(&self) -> ProtocolType {
484 ProtocolType::SqlServer
485 }
486
487 async fn parse_message(&self, _conn: &Connection, data: &[u8]) -> NirvResult<ProtocolQuery> {
488 if data.len() < 8 {
489 return Err(ProtocolError::InvalidMessageFormat("TDS packet too short".to_string()).into());
490 }
491
492 let packet_type = data[0];
493
494 match packet_type {
495 x if x == TdsPacketType::SqlBatch as u8 => {
496 let sql_text = self.parse_sql_batch(&data[8..])?;
497 Ok(ProtocolQuery::new(sql_text, ProtocolType::SqlServer))
498 }
499 x if x == TdsPacketType::Tds7Login as u8 => {
500 Ok(ProtocolQuery::new("LOGIN".to_string(), ProtocolType::SqlServer))
502 }
503 _ => {
504 Err(ProtocolError::UnsupportedFeature(
505 format!("Unsupported TDS packet type: {}", packet_type)
506 ).into())
507 }
508 }
509 }
510
511 async fn format_response(&self, _conn: &Connection, result: QueryResult) -> NirvResult<Vec<u8>> {
512 let mut response = Vec::new();
513
514 let colmetadata = self.create_colmetadata(&result.columns);
516
517 let mut rows_data = Vec::new();
519 for row in &result.rows {
520 let row_data = self.create_row(row, &result.columns);
521 rows_data.extend_from_slice(&row_data);
522 }
523
524 let done = self.create_done(0x0010, 0xC1, result.rows.len() as u64); let mut tokens = Vec::new();
529 tokens.extend_from_slice(&colmetadata);
530 tokens.extend_from_slice(&rows_data);
531 tokens.extend_from_slice(&done);
532
533 let header = self.create_tds_header(TdsPacketType::TabularResult, (tokens.len() + 8) as u16);
535
536 response.extend_from_slice(&header);
537 response.extend_from_slice(&tokens);
538
539 Ok(response)
540 }
541
542 async fn terminate_connection(&self, conn: &mut Connection) -> NirvResult<()> {
543 conn.authenticated = false;
544 conn.database.clear();
545 conn.parameters.clear();
546
547 Ok(())
551 }
552}