1use async_trait::async_trait;
2use std::collections::HashMap;
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4use tokio::net::TcpStream;
5
6use crate::protocol::{ProtocolAdapter, ProtocolType, Connection, Credentials, ProtocolQuery, ProtocolResponse, ResponseFormat};
7use crate::utils::{NirvResult, ProtocolError, QueryResult, ColumnMetadata, Row, Value, DataType};
8
/// `sqlite3_open_v2` flag: open the database read-only.
const SQLITE_OPEN_READONLY: u32 = 1 << 0;
/// `sqlite3_open_v2` flag: open the database for reading and writing.
const SQLITE_OPEN_READWRITE: u32 = 1 << 1;
/// `sqlite3_open_v2` flag: create the database if it does not exist.
const SQLITE_OPEN_CREATE: u32 = 1 << 2;
/// `sqlite3_open_v2` flag: interpret the filename as a URI.
const SQLITE_OPEN_URI: u32 = 1 << 6;
/// `sqlite3_open_v2` flag: open an in-memory database.
const SQLITE_OPEN_MEMORY: u32 = 1 << 7;

/// SQLite primary result code: success.
const SQLITE_OK: u32 = 0;
/// SQLite primary result code: generic error.
const SQLITE_ERROR: u32 = 1;
/// SQLite primary result code: database file is locked.
const SQLITE_BUSY: u32 = 5;
/// SQLite primary result code: memory allocation failed.
const SQLITE_NOMEM: u32 = 7;
/// SQLite primary result code: attempt to write a read-only database.
const SQLITE_READONLY: u32 = 8;
/// SQLite primary result code: library used incorrectly.
const SQLITE_MISUSE: u32 = 21;
23
/// SQLite storage classes, used as the one-byte type tag in encoded frames.
///
/// `#[repr(u8)]` pins each discriminant to a single byte, so `tag as u8` is
/// exactly the byte written on the wire. The explicit values must not change,
/// because peers decode them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum SQLiteDataType {
    /// SQL `NULL`.
    Null = 0,
    /// Signed integer.
    Integer = 1,
    /// Floating-point number.
    Real = 2,
    /// Text string.
    Text = 3,
    /// Raw bytes.
    Blob = 4,
}
33
/// Client command, selected by the first byte of every incoming message.
///
/// Byte values: `0` Connect, `1` Query, `2` Prepare, `3` Execute, `4` Close.
#[derive(Debug, PartialEq, Clone)]
pub enum SQLiteCommand {
    /// Connection handshake.
    Connect,
    /// Run a SQL string directly.
    Query,
    /// Prepare a SQL string.
    Prepare,
    /// Execute a prepared statement identified by a `u32` id.
    Execute,
    /// End the session.
    Close,
}
43
/// Protocol adapter that speaks a simple SQLite-flavoured binary wire format.
#[derive(Debug)]
pub struct SQLiteProtocolAdapter {
    /// Database file path; `":memory:"` (or empty) selects an in-memory database.
    database_path: String,
    /// `SQLITE_OPEN_*` bits chosen by the constructors.
    connection_flags: u32,
    /// NOTE(review): presumably prepared-statement id -> SQL text; not read or
    /// written by any method visible in this file — confirm or remove.
    prepared_statements: HashMap<u32, String>,
    /// NOTE(review): presumably the next id to hand out for `prepared_statements`;
    /// only initialised (to 1) in the visible code.
    next_statement_id: u32,
}
56
57impl SQLiteProtocolAdapter {
58 pub fn new() -> Self {
60 Self {
61 database_path: ":memory:".to_string(),
62 connection_flags: SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
63 prepared_statements: HashMap::new(),
64 next_statement_id: 1,
65 }
66 }
67
68 pub fn with_database_path(database_path: String) -> Self {
70 let flags = if database_path == ":memory:" || database_path.is_empty() {
71 SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_MEMORY
72 } else {
73 SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE
74 };
75
76 Self {
77 database_path,
78 connection_flags: flags,
79 prepared_statements: HashMap::new(),
80 next_statement_id: 1,
81 }
82 }
83
84 fn parse_connection_request(&self, data: &[u8]) -> NirvResult<(String, u32)> {
86 if data.len() < 8 {
87 return Err(ProtocolError::InvalidMessageFormat("Connection request too short".to_string()).into());
88 }
89
90 let flags = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
92
93 let path_start = 4;
95 let path_end = data[path_start..].iter().position(|&b| b == 0)
96 .map(|pos| path_start + pos)
97 .unwrap_or(data.len());
98
99 let database_path = String::from_utf8_lossy(&data[path_start..path_end]).to_string();
100
101 Ok((database_path, flags))
102 }
103
104 fn create_ok_response(&self, changes: u32, last_insert_rowid: i64) -> Vec<u8> {
106 let mut response = Vec::new();
107
108 response.push(0);
110
111 response.extend_from_slice(&SQLITE_OK.to_le_bytes());
113
114 response.extend_from_slice(&changes.to_le_bytes());
116
117 response.extend_from_slice(&last_insert_rowid.to_le_bytes());
119
120 response
121 }
122
123 fn create_error_response(&self, error_code: u32, message: &str) -> Vec<u8> {
125 let mut response = Vec::new();
126
127 response.push(1);
129
130 response.extend_from_slice(&error_code.to_le_bytes());
132
133 response.extend_from_slice(&(message.len() as u32).to_le_bytes());
135
136 response.extend_from_slice(message.as_bytes());
138
139 response
140 }
141
142 fn create_row_response(&self, columns: &[ColumnMetadata], rows: &[Row]) -> Vec<u8> {
144 let mut response = Vec::new();
145
146 response.push(2);
148
149 response.extend_from_slice(&(columns.len() as u32).to_le_bytes());
151
152 for column in columns {
154 response.extend_from_slice(&(column.name.len() as u32).to_le_bytes());
156
157 response.extend_from_slice(column.name.as_bytes());
159
160 let sqlite_type = self.nirv_type_to_sqlite_type(&column.data_type);
162 response.push(sqlite_type as u8);
163
164 response.push(if column.nullable { 1 } else { 0 });
166 }
167
168 response.extend_from_slice(&(rows.len() as u32).to_le_bytes());
170
171 for row in rows {
173 for value in &row.values {
174 match value {
175 Value::Null => {
176 response.push(SQLiteDataType::Null as u8);
177 response.extend_from_slice(&0u32.to_le_bytes()); }
179 Value::Integer(i) => {
180 response.push(SQLiteDataType::Integer as u8);
181 response.extend_from_slice(&8u32.to_le_bytes()); response.extend_from_slice(&i.to_le_bytes());
183 }
184 Value::Float(f) => {
185 response.push(SQLiteDataType::Real as u8);
186 response.extend_from_slice(&8u32.to_le_bytes()); response.extend_from_slice(&f.to_le_bytes());
188 }
189 Value::Text(s) => {
190 response.push(SQLiteDataType::Text as u8);
191 response.extend_from_slice(&(s.len() as u32).to_le_bytes());
192 response.extend_from_slice(s.as_bytes());
193 }
194 Value::Binary(b) => {
195 response.push(SQLiteDataType::Blob as u8);
196 response.extend_from_slice(&(b.len() as u32).to_le_bytes());
197 response.extend_from_slice(b);
198 }
199 Value::Boolean(b) => {
200 response.push(SQLiteDataType::Integer as u8);
201 response.extend_from_slice(&8u32.to_le_bytes());
202 let int_val = if *b { 1i64 } else { 0i64 };
203 response.extend_from_slice(&int_val.to_le_bytes());
204 }
205 Value::Date(d) | Value::DateTime(d) => {
206 response.push(SQLiteDataType::Text as u8);
207 response.extend_from_slice(&(d.len() as u32).to_le_bytes());
208 response.extend_from_slice(d.as_bytes());
209 }
210 Value::Json(j) => {
211 response.push(SQLiteDataType::Text as u8);
212 response.extend_from_slice(&(j.len() as u32).to_le_bytes());
213 response.extend_from_slice(j.as_bytes());
214 }
215 }
216 }
217 }
218
219 response
220 }
221
222 fn nirv_type_to_sqlite_type(&self, data_type: &DataType) -> SQLiteDataType {
224 match data_type {
225 DataType::Text => SQLiteDataType::Text,
226 DataType::Integer => SQLiteDataType::Integer,
227 DataType::Float => SQLiteDataType::Real,
228 DataType::Boolean => SQLiteDataType::Integer,
229 DataType::Date => SQLiteDataType::Text,
230 DataType::DateTime => SQLiteDataType::Text,
231 DataType::Json => SQLiteDataType::Text,
232 DataType::Binary => SQLiteDataType::Blob,
233 }
234 }
235
236 fn parse_command(&self, data: &[u8]) -> NirvResult<(SQLiteCommand, Vec<u8>)> {
238 if data.is_empty() {
239 return Err(ProtocolError::InvalidMessageFormat("Empty command".to_string()).into());
240 }
241
242 let command_byte = data[0];
243 let command_data = if data.len() > 1 { &data[1..] } else { &[] };
244
245 let command = match command_byte {
246 0 => SQLiteCommand::Connect,
247 1 => SQLiteCommand::Query,
248 2 => SQLiteCommand::Prepare,
249 3 => SQLiteCommand::Execute,
250 4 => SQLiteCommand::Close,
251 _ => return Err(ProtocolError::UnsupportedFeature(format!("Unknown SQLite command: {}", command_byte)).into()),
252 };
253
254 Ok((command, command_data.to_vec()))
255 }
256
257 fn process_sqlite_sql(&self, sql: &str) -> String {
259 let mut processed_sql = sql.to_string();
260
261 processed_sql = processed_sql.replace("datetime('now')", "CURRENT_TIMESTAMP");
266 processed_sql = processed_sql.replace("date('now')", "CURRENT_DATE");
267 processed_sql = processed_sql.replace("time('now')", "CURRENT_TIME");
268
269 processed_sql
271 }
272
273 fn validate_connection_flags(&self, flags: u32) -> NirvResult<()> {
275 if (flags & SQLITE_OPEN_READONLY) != 0 && (flags & SQLITE_OPEN_READWRITE) != 0 {
277 return Err(ProtocolError::InvalidMessageFormat("Cannot specify both READONLY and READWRITE flags".to_string()).into());
278 }
279
280 if (flags & (SQLITE_OPEN_READONLY | SQLITE_OPEN_READWRITE)) == 0 {
282 return Err(ProtocolError::InvalidMessageFormat("Must specify either READONLY or READWRITE flag".to_string()).into());
283 }
284
285 Ok(())
286 }
287}
288
289impl Default for SQLiteProtocolAdapter {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
295#[async_trait]
296impl ProtocolAdapter for SQLiteProtocolAdapter {
297 async fn accept_connection(&self, stream: TcpStream) -> NirvResult<Connection> {
298 let connection = Connection::new(stream, ProtocolType::SQLite);
299 Ok(connection)
300 }
301
302 async fn authenticate(&self, conn: &mut Connection, credentials: Credentials) -> NirvResult<()> {
303 let mut buffer = vec![0u8; 1024];
308 let bytes_read = match conn.stream.read(&mut buffer).await {
309 Ok(n) => n,
310 Err(_) => {
311 conn.authenticated = true;
313 conn.database = credentials.database.clone();
314 return Ok(());
315 }
316 };
317
318 if bytes_read > 0 {
319 let (database_path, flags) = self.parse_connection_request(&buffer[..bytes_read])?;
321
322 self.validate_connection_flags(flags)?;
324
325 conn.database = if database_path.is_empty() {
327 credentials.database
328 } else {
329 database_path
330 };
331
332 conn.parameters.insert("flags".to_string(), flags.to_string());
333
334 let ok_response = self.create_ok_response(0, 0);
336 conn.stream.write_all(&ok_response).await
337 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send OK response: {}", e)))?;
338 }
339
340 conn.authenticated = true;
341 Ok(())
342 }
343
344 async fn handle_query(&self, _conn: &Connection, _query: ProtocolQuery) -> NirvResult<ProtocolResponse> {
345 let columns = vec![
348 ColumnMetadata {
349 name: "id".to_string(),
350 data_type: DataType::Integer,
351 nullable: false,
352 },
353 ColumnMetadata {
354 name: "name".to_string(),
355 data_type: DataType::Text,
356 nullable: true,
357 },
358 ];
359
360 let rows = vec![
361 Row::new(vec![Value::Integer(1), Value::Text("SQLite Test User".to_string())]),
362 Row::new(vec![Value::Integer(2), Value::Text("Another SQLite User".to_string())]),
363 ];
364
365 let result = QueryResult {
366 columns,
367 rows,
368 affected_rows: Some(2),
369 execution_time: std::time::Duration::from_millis(5),
370 };
371
372 Ok(ProtocolResponse::new(result, ProtocolType::SQLite))
373 }
374
375 fn get_protocol_type(&self) -> ProtocolType {
376 ProtocolType::SQLite
377 }
378
379 async fn parse_message(&self, _conn: &Connection, data: &[u8]) -> NirvResult<ProtocolQuery> {
380 let (command, command_data) = self.parse_command(data)?;
381
382 match command {
383 SQLiteCommand::Connect => {
384 Ok(ProtocolQuery::new("CONNECT".to_string(), ProtocolType::SQLite))
385 }
386 SQLiteCommand::Query => {
387 let sql = String::from_utf8_lossy(&command_data).to_string();
388 let processed_sql = self.process_sqlite_sql(&sql);
389 Ok(ProtocolQuery::new(processed_sql, ProtocolType::SQLite))
390 }
391 SQLiteCommand::Prepare => {
392 let sql = String::from_utf8_lossy(&command_data).to_string();
393 let processed_sql = self.process_sqlite_sql(&sql);
394 Ok(ProtocolQuery::new(format!("PREPARE {}", processed_sql), ProtocolType::SQLite))
395 }
396 SQLiteCommand::Execute => {
397 if command_data.len() < 4 {
399 return Err(ProtocolError::InvalidMessageFormat("Execute command missing statement ID".to_string()).into());
400 }
401
402 let statement_id = u32::from_le_bytes([command_data[0], command_data[1], command_data[2], command_data[3]]);
403 Ok(ProtocolQuery::new(format!("EXECUTE {}", statement_id), ProtocolType::SQLite))
404 }
405 SQLiteCommand::Close => {
406 Ok(ProtocolQuery::new("CLOSE".to_string(), ProtocolType::SQLite))
407 }
408 }
409 }
410
411 async fn format_response(&self, _conn: &Connection, result: QueryResult) -> NirvResult<Vec<u8>> {
412 if result.columns.is_empty() {
413 let ok_response = self.create_ok_response(result.affected_rows.unwrap_or(0) as u32, 0);
415 Ok(ok_response)
416 } else {
417 let row_response = self.create_row_response(&result.columns, &result.rows);
419 Ok(row_response)
420 }
421 }
422
423 async fn terminate_connection(&self, conn: &mut Connection) -> NirvResult<()> {
424 let close_response = self.create_ok_response(0, 0);
426 let _ = conn.stream.write_all(&close_response).await;
427
428 conn.stream.shutdown().await
429 .map_err(|_e| ProtocolError::ConnectionClosed)?;
430 Ok(())
431 }
432}