// nirv_engine/protocol/sqlite_protocol.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4use tokio::net::TcpStream;
5
6use crate::protocol::{ProtocolAdapter, ProtocolType, Connection, Credentials, ProtocolQuery, ProtocolResponse, ResponseFormat};
7use crate::utils::{NirvResult, ProtocolError, QueryResult, ColumnMetadata, Row, Value, DataType};
8
// SQLite connection (open) flags. The values mirror the `SQLITE_OPEN_*`
// constants of the SQLite C API; each one is a single bit so they can be OR-ed.

/// Open the database for reading only.
const SQLITE_OPEN_READONLY: u32 = 1 << 0;
/// Open the database for reading and writing.
const SQLITE_OPEN_READWRITE: u32 = 1 << 1;
/// Create the database if it does not already exist.
const SQLITE_OPEN_CREATE: u32 = 1 << 2;
/// Interpret the database name as a URI.
const SQLITE_OPEN_URI: u32 = 1 << 6;
/// Use an in-memory database.
const SQLITE_OPEN_MEMORY: u32 = 1 << 7;
15
// SQLite result codes. The values mirror the primary result codes of the same
// names in the SQLite C API and are sent as little-endian u32 on the wire.

/// The operation completed successfully.
const SQLITE_OK: u32 = 0;
/// Generic error.
const SQLITE_ERROR: u32 = 1;
/// The database is busy (locked by another connection).
const SQLITE_BUSY: u32 = 5;
/// A memory allocation failed.
const SQLITE_NOMEM: u32 = 7;
/// Attempt to write to a read-only database.
const SQLITE_READONLY: u32 = 8;
/// The library was used incorrectly.
const SQLITE_MISUSE: u32 = 21;
23
/// Value/column type tags used by this adapter's row responses.
///
/// Each variant is written to the wire as a single byte equal to its
/// discriminant (`tag as u8`), so the representation is pinned to `u8`.
/// NOTE(review): these tag numbers are this adapter's own; they do not match
/// the SQLite C API's `SQLITE_INTEGER`/`SQLITE_NULL` constants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum SQLiteDataType {
    /// SQL NULL (no payload).
    Null = 0,
    /// Signed integer, sent as an 8-byte little-endian payload.
    Integer = 1,
    /// Floating-point value, sent as an 8-byte little-endian payload.
    Real = 2,
    /// UTF-8 text.
    Text = 3,
    /// Raw bytes.
    Blob = 4,
}
33
/// Commands understood by the simplified SQLite protocol.
///
/// A client message starts with one command byte (`0`..=`4`, in the order of
/// the variants below) followed by the command's payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SQLiteCommand {
    /// Command byte `0`: connection request.
    Connect,
    /// Command byte `1`: payload is SQL text to run.
    Query,
    /// Command byte `2`: payload is SQL text to prepare.
    Prepare,
    /// Command byte `3`: payload starts with a little-endian u32 statement ID.
    Execute,
    /// Command byte `4`: close the session.
    Close,
}
43
/// Protocol adapter for NIRV's simplified, SQLite-flavoured wire format.
///
/// SQLite has no native network protocol (unlike PostgreSQL or MySQL), so this
/// adapter defines a minimal framing of its own: a single command byte per
/// client message, and tagged OK / error / row responses sent back.
#[derive(Debug)]
pub struct SQLiteProtocolAdapter {
    /// Database the adapter was configured with (`":memory:"` by default).
    database_path: String,
    /// `SQLITE_OPEN_*` bit flags chosen at construction time.
    connection_flags: u32,
    /// SQL text of prepared statements, keyed by statement ID.
    /// NOTE(review): not populated or read by any method in this file yet.
    prepared_statements: HashMap<u32, String>,
    /// ID to assign to the next prepared statement; constructors start it at 1.
    next_statement_id: u32,
}
56
57impl SQLiteProtocolAdapter {
58    /// Create a new SQLite protocol adapter
59    pub fn new() -> Self {
60        Self {
61            database_path: ":memory:".to_string(),
62            connection_flags: SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
63            prepared_statements: HashMap::new(),
64            next_statement_id: 1,
65        }
66    }
67    
68    /// Create SQLite protocol adapter with specific database path
69    pub fn with_database_path(database_path: String) -> Self {
70        let flags = if database_path == ":memory:" || database_path.is_empty() {
71            SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_MEMORY
72        } else {
73            SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE
74        };
75        
76        Self {
77            database_path,
78            connection_flags: flags,
79            prepared_statements: HashMap::new(),
80            next_statement_id: 1,
81        }
82    }
83    
84    /// Parse SQLite connection request
85    fn parse_connection_request(&self, data: &[u8]) -> NirvResult<(String, u32)> {
86        if data.len() < 8 {
87            return Err(ProtocolError::InvalidMessageFormat("Connection request too short".to_string()).into());
88        }
89        
90        // Simple protocol: 4 bytes for flags, then null-terminated database path
91        let flags = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
92        
93        // Find null terminator for database path
94        let path_start = 4;
95        let path_end = data[path_start..].iter().position(|&b| b == 0)
96            .map(|pos| path_start + pos)
97            .unwrap_or(data.len());
98        
99        let database_path = String::from_utf8_lossy(&data[path_start..path_end]).to_string();
100        
101        Ok((database_path, flags))
102    }
103    
104    /// Create SQLite OK response
105    fn create_ok_response(&self, changes: u32, last_insert_rowid: i64) -> Vec<u8> {
106        let mut response = Vec::new();
107        
108        // Response type (1 byte): 0 = OK
109        response.push(0);
110        
111        // Result code (4 bytes)
112        response.extend_from_slice(&SQLITE_OK.to_le_bytes());
113        
114        // Changes (4 bytes)
115        response.extend_from_slice(&changes.to_le_bytes());
116        
117        // Last insert rowid (8 bytes)
118        response.extend_from_slice(&last_insert_rowid.to_le_bytes());
119        
120        response
121    }
122    
123    /// Create SQLite error response
124    fn create_error_response(&self, error_code: u32, message: &str) -> Vec<u8> {
125        let mut response = Vec::new();
126        
127        // Response type (1 byte): 1 = Error
128        response.push(1);
129        
130        // Error code (4 bytes)
131        response.extend_from_slice(&error_code.to_le_bytes());
132        
133        // Message length (4 bytes)
134        response.extend_from_slice(&(message.len() as u32).to_le_bytes());
135        
136        // Message
137        response.extend_from_slice(message.as_bytes());
138        
139        response
140    }
141    
142    /// Create SQLite row response
143    fn create_row_response(&self, columns: &[ColumnMetadata], rows: &[Row]) -> Vec<u8> {
144        let mut response = Vec::new();
145        
146        // Response type (1 byte): 2 = Rows
147        response.push(2);
148        
149        // Column count (4 bytes)
150        response.extend_from_slice(&(columns.len() as u32).to_le_bytes());
151        
152        // Column definitions
153        for column in columns {
154            // Column name length (4 bytes)
155            response.extend_from_slice(&(column.name.len() as u32).to_le_bytes());
156            
157            // Column name
158            response.extend_from_slice(column.name.as_bytes());
159            
160            // Column type
161            let sqlite_type = self.nirv_type_to_sqlite_type(&column.data_type);
162            response.push(sqlite_type as u8);
163            
164            // Nullable flag
165            response.push(if column.nullable { 1 } else { 0 });
166        }
167        
168        // Row count (4 bytes)
169        response.extend_from_slice(&(rows.len() as u32).to_le_bytes());
170        
171        // Row data
172        for row in rows {
173            for value in &row.values {
174                match value {
175                    Value::Null => {
176                        response.push(SQLiteDataType::Null as u8);
177                        response.extend_from_slice(&0u32.to_le_bytes()); // No data length
178                    }
179                    Value::Integer(i) => {
180                        response.push(SQLiteDataType::Integer as u8);
181                        response.extend_from_slice(&8u32.to_le_bytes()); // 8 bytes for i64
182                        response.extend_from_slice(&i.to_le_bytes());
183                    }
184                    Value::Float(f) => {
185                        response.push(SQLiteDataType::Real as u8);
186                        response.extend_from_slice(&8u32.to_le_bytes()); // 8 bytes for f64
187                        response.extend_from_slice(&f.to_le_bytes());
188                    }
189                    Value::Text(s) => {
190                        response.push(SQLiteDataType::Text as u8);
191                        response.extend_from_slice(&(s.len() as u32).to_le_bytes());
192                        response.extend_from_slice(s.as_bytes());
193                    }
194                    Value::Binary(b) => {
195                        response.push(SQLiteDataType::Blob as u8);
196                        response.extend_from_slice(&(b.len() as u32).to_le_bytes());
197                        response.extend_from_slice(b);
198                    }
199                    Value::Boolean(b) => {
200                        response.push(SQLiteDataType::Integer as u8);
201                        response.extend_from_slice(&8u32.to_le_bytes());
202                        let int_val = if *b { 1i64 } else { 0i64 };
203                        response.extend_from_slice(&int_val.to_le_bytes());
204                    }
205                    Value::Date(d) | Value::DateTime(d) => {
206                        response.push(SQLiteDataType::Text as u8);
207                        response.extend_from_slice(&(d.len() as u32).to_le_bytes());
208                        response.extend_from_slice(d.as_bytes());
209                    }
210                    Value::Json(j) => {
211                        response.push(SQLiteDataType::Text as u8);
212                        response.extend_from_slice(&(j.len() as u32).to_le_bytes());
213                        response.extend_from_slice(j.as_bytes());
214                    }
215                }
216            }
217        }
218        
219        response
220    }
221    
222    /// Convert NIRV data type to SQLite data type
223    fn nirv_type_to_sqlite_type(&self, data_type: &DataType) -> SQLiteDataType {
224        match data_type {
225            DataType::Text => SQLiteDataType::Text,
226            DataType::Integer => SQLiteDataType::Integer,
227            DataType::Float => SQLiteDataType::Real,
228            DataType::Boolean => SQLiteDataType::Integer,
229            DataType::Date => SQLiteDataType::Text,
230            DataType::DateTime => SQLiteDataType::Text,
231            DataType::Json => SQLiteDataType::Text,
232            DataType::Binary => SQLiteDataType::Blob,
233        }
234    }
235    
236    /// Parse SQLite command from message
237    fn parse_command(&self, data: &[u8]) -> NirvResult<(SQLiteCommand, Vec<u8>)> {
238        if data.is_empty() {
239            return Err(ProtocolError::InvalidMessageFormat("Empty command".to_string()).into());
240        }
241        
242        let command_byte = data[0];
243        let command_data = if data.len() > 1 { &data[1..] } else { &[] };
244        
245        let command = match command_byte {
246            0 => SQLiteCommand::Connect,
247            1 => SQLiteCommand::Query,
248            2 => SQLiteCommand::Prepare,
249            3 => SQLiteCommand::Execute,
250            4 => SQLiteCommand::Close,
251            _ => return Err(ProtocolError::UnsupportedFeature(format!("Unknown SQLite command: {}", command_byte)).into()),
252        };
253        
254        Ok((command, command_data.to_vec()))
255    }
256    
257    /// Handle SQLite-specific SQL functions and syntax
258    fn process_sqlite_sql(&self, sql: &str) -> String {
259        let mut processed_sql = sql.to_string();
260        
261        // Handle SQLite-specific functions that might need translation
262        // For now, we'll pass through most SQL as-is since NIRV handles the source() function
263        
264        // Handle common SQLite functions
265        processed_sql = processed_sql.replace("datetime('now')", "CURRENT_TIMESTAMP");
266        processed_sql = processed_sql.replace("date('now')", "CURRENT_DATE");
267        processed_sql = processed_sql.replace("time('now')", "CURRENT_TIME");
268        
269        // SQLite uses different syntax for some operations, but we'll keep it compatible
270        processed_sql
271    }
272    
273    /// Validate SQLite connection flags
274    fn validate_connection_flags(&self, flags: u32) -> NirvResult<()> {
275        // Check for conflicting flags
276        if (flags & SQLITE_OPEN_READONLY) != 0 && (flags & SQLITE_OPEN_READWRITE) != 0 {
277            return Err(ProtocolError::InvalidMessageFormat("Cannot specify both READONLY and READWRITE flags".to_string()).into());
278        }
279        
280        // Ensure at least one access mode is specified
281        if (flags & (SQLITE_OPEN_READONLY | SQLITE_OPEN_READWRITE)) == 0 {
282            return Err(ProtocolError::InvalidMessageFormat("Must specify either READONLY or READWRITE flag".to_string()).into());
283        }
284        
285        Ok(())
286    }
287}
288
289impl Default for SQLiteProtocolAdapter {
290    fn default() -> Self {
291        Self::new()
292    }
293}
294
295#[async_trait]
296impl ProtocolAdapter for SQLiteProtocolAdapter {
297    async fn accept_connection(&self, stream: TcpStream) -> NirvResult<Connection> {
298        let connection = Connection::new(stream, ProtocolType::SQLite);
299        Ok(connection)
300    }
301    
302    async fn authenticate(&self, conn: &mut Connection, credentials: Credentials) -> NirvResult<()> {
303        // SQLite doesn't have traditional authentication, but we can simulate it
304        // for compatibility with the NIRV protocol interface
305        
306        // Read connection request if present
307        let mut buffer = vec![0u8; 1024];
308        let bytes_read = match conn.stream.read(&mut buffer).await {
309            Ok(n) => n,
310            Err(_) => {
311                // No connection request, use default settings
312                conn.authenticated = true;
313                conn.database = credentials.database.clone();
314                return Ok(());
315            }
316        };
317        
318        if bytes_read > 0 {
319            // Parse connection request
320            let (database_path, flags) = self.parse_connection_request(&buffer[..bytes_read])?;
321            
322            // Validate flags
323            self.validate_connection_flags(flags)?;
324            
325            // Set connection parameters
326            conn.database = if database_path.is_empty() { 
327                credentials.database 
328            } else { 
329                database_path 
330            };
331            
332            conn.parameters.insert("flags".to_string(), flags.to_string());
333            
334            // Send OK response
335            let ok_response = self.create_ok_response(0, 0);
336            conn.stream.write_all(&ok_response).await
337                .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send OK response: {}", e)))?;
338        }
339        
340        conn.authenticated = true;
341        Ok(())
342    }
343    
344    async fn handle_query(&self, _conn: &Connection, _query: ProtocolQuery) -> NirvResult<ProtocolResponse> {
345        // Create a mock response for now
346        // In the full implementation, this would execute the query through the engine
347        let columns = vec![
348            ColumnMetadata {
349                name: "id".to_string(),
350                data_type: DataType::Integer,
351                nullable: false,
352            },
353            ColumnMetadata {
354                name: "name".to_string(),
355                data_type: DataType::Text,
356                nullable: true,
357            },
358        ];
359        
360        let rows = vec![
361            Row::new(vec![Value::Integer(1), Value::Text("SQLite Test User".to_string())]),
362            Row::new(vec![Value::Integer(2), Value::Text("Another SQLite User".to_string())]),
363        ];
364        
365        let result = QueryResult {
366            columns,
367            rows,
368            affected_rows: Some(2),
369            execution_time: std::time::Duration::from_millis(5),
370        };
371        
372        Ok(ProtocolResponse::new(result, ProtocolType::SQLite))
373    }
374    
375    fn get_protocol_type(&self) -> ProtocolType {
376        ProtocolType::SQLite
377    }
378    
379    async fn parse_message(&self, _conn: &Connection, data: &[u8]) -> NirvResult<ProtocolQuery> {
380        let (command, command_data) = self.parse_command(data)?;
381        
382        match command {
383            SQLiteCommand::Connect => {
384                Ok(ProtocolQuery::new("CONNECT".to_string(), ProtocolType::SQLite))
385            }
386            SQLiteCommand::Query => {
387                let sql = String::from_utf8_lossy(&command_data).to_string();
388                let processed_sql = self.process_sqlite_sql(&sql);
389                Ok(ProtocolQuery::new(processed_sql, ProtocolType::SQLite))
390            }
391            SQLiteCommand::Prepare => {
392                let sql = String::from_utf8_lossy(&command_data).to_string();
393                let processed_sql = self.process_sqlite_sql(&sql);
394                Ok(ProtocolQuery::new(format!("PREPARE {}", processed_sql), ProtocolType::SQLite))
395            }
396            SQLiteCommand::Execute => {
397                // Parse statement ID and parameters
398                if command_data.len() < 4 {
399                    return Err(ProtocolError::InvalidMessageFormat("Execute command missing statement ID".to_string()).into());
400                }
401                
402                let statement_id = u32::from_le_bytes([command_data[0], command_data[1], command_data[2], command_data[3]]);
403                Ok(ProtocolQuery::new(format!("EXECUTE {}", statement_id), ProtocolType::SQLite))
404            }
405            SQLiteCommand::Close => {
406                Ok(ProtocolQuery::new("CLOSE".to_string(), ProtocolType::SQLite))
407            }
408        }
409    }
410    
411    async fn format_response(&self, _conn: &Connection, result: QueryResult) -> NirvResult<Vec<u8>> {
412        if result.columns.is_empty() {
413            // Non-SELECT query - return OK response
414            let ok_response = self.create_ok_response(result.affected_rows.unwrap_or(0) as u32, 0);
415            Ok(ok_response)
416        } else {
417            // SELECT query - return row data
418            let row_response = self.create_row_response(&result.columns, &result.rows);
419            Ok(row_response)
420        }
421    }
422    
423    async fn terminate_connection(&self, conn: &mut Connection) -> NirvResult<()> {
424        // Send close acknowledgment if possible
425        let close_response = self.create_ok_response(0, 0);
426        let _ = conn.stream.write_all(&close_response).await;
427        
428        conn.stream.shutdown().await
429            .map_err(|_e| ProtocolError::ConnectionClosed)?;
430        Ok(())
431    }
432}