// nirv_engine/protocol/postgres_protocol.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4use tokio::net::TcpStream;
5
6use crate::protocol::{ProtocolAdapter, ProtocolType, Connection, Credentials, ProtocolQuery, ProtocolResponse, ResponseFormat};
7use crate::utils::{NirvResult, ProtocolError, QueryResult, ColumnMetadata, Row, Value, DataType};
8
/// Wire value of PostgreSQL frontend/backend protocol version 3.0.
///
/// The major version lives in the high 16 bits and the minor version in the
/// low 16 bits, so 3.0 encodes as `3 << 16` (196608).
const POSTGRES_PROTOCOL_VERSION: u32 = 3 << 16;

/// Frontend (client-to-server) PostgreSQL message kinds, keyed by the type
/// byte that prefixes them on the wire.
///
/// The startup message is the one frontend message without a type byte, so it
/// is given discriminant 0.
#[derive(Clone, Debug, PartialEq)]
pub enum PostgresMessageType {
    /// Initial connection packet (no type byte).
    StartupMessage = 0,
    /// Simple query, type byte `'Q'`.
    Query = b'Q' as isize,
    /// Connection termination, type byte `'X'`.
    Terminate = b'X' as isize,
    /// Password response, type byte `'p'`.
    PasswordMessage = b'p' as isize,
}

/// Backend (server-to-client) PostgreSQL message kinds, keyed by the type
/// byte that prefixes them on the wire.
#[derive(Clone, Debug, PartialEq)]
pub enum PostgresResponseType {
    /// Authentication request/result, type byte `'R'`.
    AuthenticationOk = b'R' as isize,
    /// Run-time parameter report, type byte `'S'`.
    ParameterStatus = b'S' as isize,
    /// Server is ready for a new query, type byte `'Z'`.
    ReadyForQuery = b'Z' as isize,
    /// Column layout of the following rows, type byte `'T'`.
    RowDescription = b'T' as isize,
    /// One result row, type byte `'D'`.
    DataRow = b'D' as isize,
    /// Command finished, type byte `'C'`.
    CommandComplete = b'C' as isize,
    /// Error report, type byte `'E'`.
    ErrorResponse = b'E' as isize,
}

/// PostgreSQL wire-protocol adapter.
///
/// The adapter currently carries no state; configuration or per-adapter state
/// can be added as fields later.
#[derive(Debug)]
pub struct PostgresProtocol {}

39impl PostgresProtocol {
40    /// Create a new PostgreSQL protocol adapter
41    pub fn new() -> Self {
42        Self {}
43    }
44    
45    /// Parse a startup message from the client
46    async fn parse_startup_message(&self, data: &[u8]) -> NirvResult<(u32, HashMap<String, String>)> {
47        if data.len() < 8 {
48            return Err(ProtocolError::InvalidMessageFormat("Startup message too short".to_string()).into());
49        }
50        
51        // Read protocol version (4 bytes, big-endian)
52        let protocol_version = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
53        
54        if protocol_version != POSTGRES_PROTOCOL_VERSION {
55            return Err(ProtocolError::UnsupportedVersion(format!("Protocol version {} not supported", protocol_version)).into());
56        }
57        
58        // Parse parameters (null-terminated strings)
59        let mut parameters = HashMap::new();
60        let mut pos = 8;
61        
62        while pos < data.len() - 1 {
63            // Find null terminator for key
64            let key_end = data[pos..].iter().position(|&b| b == 0)
65                .ok_or_else(|| ProtocolError::InvalidMessageFormat("Unterminated parameter key".to_string()))?;
66            
67            let key = String::from_utf8_lossy(&data[pos..pos + key_end]).to_string();
68            pos += key_end + 1;
69            
70            if pos >= data.len() {
71                break;
72            }
73            
74            // Find null terminator for value
75            let value_end = data[pos..].iter().position(|&b| b == 0)
76                .ok_or_else(|| ProtocolError::InvalidMessageFormat("Unterminated parameter value".to_string()))?;
77            
78            let value = String::from_utf8_lossy(&data[pos..pos + value_end]).to_string();
79            pos += value_end + 1;
80            
81            parameters.insert(key, value);
82        }
83        
84        Ok((protocol_version, parameters))
85    }
86    
87    /// Create an authentication OK response
88    fn create_auth_ok_response(&self) -> Vec<u8> {
89        let mut response = Vec::new();
90        response.push(b'R'); // Authentication response
91        response.extend_from_slice(&8u32.to_be_bytes()); // Message length
92        response.extend_from_slice(&0u32.to_be_bytes()); // Authentication OK
93        response
94    }
95    
96    /// Create a parameter status message
97    fn create_parameter_status(&self, name: &str, value: &str) -> Vec<u8> {
98        let mut response = Vec::new();
99        response.push(b'S'); // Parameter status
100        
101        let content_len = name.len() + value.len() + 2; // +2 for null terminators
102        response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes()); // Message length
103        
104        response.extend_from_slice(name.as_bytes());
105        response.push(0); // Null terminator
106        response.extend_from_slice(value.as_bytes());
107        response.push(0); // Null terminator
108        
109        response
110    }
111    
112    /// Create a ready for query message
113    fn create_ready_for_query(&self) -> Vec<u8> {
114        let mut response = Vec::new();
115        response.push(b'Z'); // Ready for query
116        response.extend_from_slice(&5u32.to_be_bytes()); // Message length
117        response.push(b'I'); // Transaction status: Idle
118        response
119    }
120    
121    /// Create a row description message
122    fn create_row_description(&self, columns: &[ColumnMetadata]) -> Vec<u8> {
123        let mut response = Vec::new();
124        response.push(b'T'); // Row description
125        
126        // Calculate message length
127        let mut content_len = 2; // Field count (2 bytes)
128        for col in columns {
129            content_len += col.name.len() + 1; // Name + null terminator
130            content_len += 18; // Table OID (4) + Column attr (2) + Type OID (4) + Type size (2) + Type modifier (4) + Format code (2)
131        }
132        
133        response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes());
134        response.extend_from_slice(&(columns.len() as u16).to_be_bytes()); // Field count
135        
136        for col in columns {
137            response.extend_from_slice(col.name.as_bytes());
138            response.push(0); // Null terminator
139            response.extend_from_slice(&0u32.to_be_bytes()); // Table OID
140            response.extend_from_slice(&0u16.to_be_bytes()); // Column attribute number
141            
142            // Map NIRV data types to PostgreSQL OIDs
143            let type_oid = match col.data_type {
144                DataType::Text => 25u32,      // TEXT
145                DataType::Integer => 23u32,   // INT4
146                DataType::Float => 701u32,    // FLOAT8
147                DataType::Boolean => 16u32,   // BOOL
148                DataType::Date => 1082u32,    // DATE
149                DataType::DateTime => 1114u32, // TIMESTAMP
150                DataType::Json => 114u32,     // JSON
151                DataType::Binary => 17u32,    // BYTEA
152            };
153            
154            response.extend_from_slice(&type_oid.to_be_bytes()); // Type OID
155            response.extend_from_slice(&(-1i16).to_be_bytes()); // Type size (-1 = variable)
156            response.extend_from_slice(&(-1i32).to_be_bytes()); // Type modifier
157            response.extend_from_slice(&0u16.to_be_bytes()); // Format code (0 = text)
158        }
159        
160        response
161    }
162    
163    /// Create a data row message
164    fn create_data_row(&self, row: &Row) -> Vec<u8> {
165        let mut response = Vec::new();
166        response.push(b'D'); // Data row
167        
168        // Calculate message length
169        let mut content_len = 2; // Field count (2 bytes)
170        for value in &row.values {
171            match value {
172                Value::Null => content_len += 4, // Length field only
173                _ => {
174                    let value_str = self.value_to_string(value);
175                    content_len += 4 + value_str.len(); // Length field + data
176                }
177            }
178        }
179        
180        response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes());
181        response.extend_from_slice(&(row.values.len() as u16).to_be_bytes()); // Field count
182        
183        for value in &row.values {
184            match value {
185                Value::Null => {
186                    response.extend_from_slice(&(-1i32).to_be_bytes()); // NULL value
187                }
188                _ => {
189                    let value_str = self.value_to_string(value);
190                    response.extend_from_slice(&(value_str.len() as u32).to_be_bytes());
191                    response.extend_from_slice(value_str.as_bytes());
192                }
193            }
194        }
195        
196        response
197    }
198    
199    /// Create a command complete message
200    fn create_command_complete(&self, tag: &str) -> Vec<u8> {
201        let mut response = Vec::new();
202        response.push(b'C'); // Command complete
203        
204        let content_len = tag.len() + 1; // +1 for null terminator
205        response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes());
206        response.extend_from_slice(tag.as_bytes());
207        response.push(0); // Null terminator
208        
209        response
210    }
211    
212    /// Create an error response message
213    fn create_error_response(&self, message: &str) -> Vec<u8> {
214        let mut response = Vec::new();
215        response.push(b'E'); // Error response
216        
217        let content_len = 1 + message.len() + 1 + 1; // Severity + message + null + terminator
218        response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes());
219        
220        response.push(b'S'); // Severity field
221        response.extend_from_slice(b"ERROR");
222        response.push(0); // Null terminator
223        
224        response.push(b'M'); // Message field
225        response.extend_from_slice(message.as_bytes());
226        response.push(0); // Null terminator
227        
228        response.push(0); // End of error message
229        
230        response
231    }
232    
233    /// Convert a NIRV Value to PostgreSQL string representation
234    fn value_to_string(&self, value: &Value) -> String {
235        match value {
236            Value::Text(s) => s.clone(),
237            Value::Integer(i) => i.to_string(),
238            Value::Float(f) => f.to_string(),
239            Value::Boolean(b) => if *b { "t".to_string() } else { "f".to_string() },
240            Value::Date(d) => d.clone(),
241            Value::DateTime(dt) => dt.clone(),
242            Value::Json(j) => j.clone(),
243            Value::Binary(b) => {
244                // Simple hex encoding without external dependency
245                let mut hex_string = String::with_capacity(b.len() * 2 + 2);
246                hex_string.push_str("\\x");
247                for byte in b {
248                    hex_string.push_str(&format!("{:02x}", byte));
249                }
250                hex_string
251            },
252            Value::Null => String::new(), // Should not be called for NULL values
253        }
254    }
255}
256
257impl Default for PostgresProtocol {
258    fn default() -> Self {
259        Self::new()
260    }
261}
262
263#[async_trait]
264impl ProtocolAdapter for PostgresProtocol {
265    async fn accept_connection(&self, stream: TcpStream) -> NirvResult<Connection> {
266        let connection = Connection::new(stream, ProtocolType::PostgreSQL);
267        Ok(connection)
268    }
269    
270    async fn authenticate(&self, conn: &mut Connection, credentials: Credentials) -> NirvResult<()> {
271        // Read startup message
272        let mut buffer = vec![0u8; 8192];
273        let bytes_read = conn.stream.read(&mut buffer).await
274            .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to read startup message: {}", e)))?;
275        
276        if bytes_read < 8 {
277            return Err(ProtocolError::InvalidMessageFormat("Startup message too short".to_string()).into());
278        }
279        
280        // Parse startup message
281        let (_protocol_version, parameters) = self.parse_startup_message(&buffer[..bytes_read]).await?;
282        
283        // Validate credentials match startup parameters
284        if let Some(user) = parameters.get("user") {
285            if user != &credentials.username {
286                return Err(ProtocolError::AuthenticationFailed("Username mismatch".to_string()).into());
287            }
288        }
289        
290        if let Some(database) = parameters.get("database") {
291            if database != &credentials.database {
292                return Err(ProtocolError::AuthenticationFailed("Database mismatch".to_string()).into());
293            }
294        }
295        
296        // Send authentication OK
297        let auth_response = self.create_auth_ok_response();
298        conn.stream.write_all(&auth_response).await
299            .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send auth response: {}", e)))?;
300        
301        // Send parameter status messages
302        let param_status = self.create_parameter_status("server_version", "13.0 (NIRV Engine)");
303        conn.stream.write_all(&param_status).await
304            .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send parameter status: {}", e)))?;
305        
306        let encoding_status = self.create_parameter_status("client_encoding", "UTF8");
307        conn.stream.write_all(&encoding_status).await
308            .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send encoding status: {}", e)))?;
309        
310        // Send ready for query
311        let ready_response = self.create_ready_for_query();
312        conn.stream.write_all(&ready_response).await
313            .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send ready response: {}", e)))?;
314        
315        // Update connection state
316        conn.authenticated = true;
317        conn.database = credentials.database;
318        conn.parameters = parameters;
319        
320        Ok(())
321    }
322    
323    async fn handle_query(&self, _conn: &Connection, _query: ProtocolQuery) -> NirvResult<ProtocolResponse> {
324        // For now, create a mock response
325        // In the full implementation, this would parse the query and execute it
326        let columns = vec![
327            ColumnMetadata {
328                name: "id".to_string(),
329                data_type: DataType::Integer,
330                nullable: false,
331            },
332            ColumnMetadata {
333                name: "name".to_string(),
334                data_type: DataType::Text,
335                nullable: true,
336            },
337        ];
338        
339        let rows = vec![
340            Row::new(vec![Value::Integer(1), Value::Text("Test User".to_string())]),
341            Row::new(vec![Value::Integer(2), Value::Text("Another User".to_string())]),
342        ];
343        
344        let result = QueryResult {
345            columns,
346            rows,
347            affected_rows: Some(2),
348            execution_time: std::time::Duration::from_millis(10),
349        };
350        
351        Ok(ProtocolResponse::new(result, ProtocolType::PostgreSQL))
352    }
353    
354    fn get_protocol_type(&self) -> ProtocolType {
355        ProtocolType::PostgreSQL
356    }
357    
358    async fn parse_message(&self, _conn: &Connection, data: &[u8]) -> NirvResult<ProtocolQuery> {
359        if data.is_empty() {
360            return Err(ProtocolError::InvalidMessageFormat("Empty message".to_string()).into());
361        }
362        
363        let message_type = data[0];
364        
365        match message_type {
366            b'Q' => {
367                // Query message
368                if data.len() < 5 {
369                    return Err(ProtocolError::InvalidMessageFormat("Query message too short".to_string()).into());
370                }
371                
372                // Skip message type (1 byte) and length (4 bytes)
373                let query_data = &data[5..];
374                
375                // Find null terminator
376                let query_end = query_data.iter().position(|&b| b == 0)
377                    .unwrap_or(query_data.len());
378                
379                let query_string = String::from_utf8_lossy(&query_data[..query_end]).to_string();
380                
381                Ok(ProtocolQuery::new(query_string, ProtocolType::PostgreSQL))
382            }
383            b'X' => {
384                // Terminate message
385                Ok(ProtocolQuery::new("TERMINATE".to_string(), ProtocolType::PostgreSQL))
386            }
387            _ => {
388                Err(ProtocolError::InvalidMessageFormat(format!("Unknown message type: {}", message_type)).into())
389            }
390        }
391    }
392    
393    async fn format_response(&self, _conn: &Connection, result: QueryResult) -> NirvResult<Vec<u8>> {
394        let mut response = Vec::new();
395        
396        // Send row description
397        let row_desc = self.create_row_description(&result.columns);
398        response.extend_from_slice(&row_desc);
399        
400        // Send data rows
401        for row in &result.rows {
402            let data_row = self.create_data_row(row);
403            response.extend_from_slice(&data_row);
404        }
405        
406        // Send command complete
407        let tag = format!("SELECT {}", result.rows.len());
408        let cmd_complete = self.create_command_complete(&tag);
409        response.extend_from_slice(&cmd_complete);
410        
411        // Send ready for query
412        let ready = self.create_ready_for_query();
413        response.extend_from_slice(&ready);
414        
415        Ok(response)
416    }
417    
418    async fn terminate_connection(&self, conn: &mut Connection) -> NirvResult<()> {
419        conn.stream.shutdown().await
420            .map_err(|_e| ProtocolError::ConnectionClosed)?;
421        Ok(())
422    }
423}