1use async_trait::async_trait;
2use std::collections::HashMap;
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4use tokio::net::TcpStream;
5
6use crate::protocol::{ProtocolAdapter, ProtocolType, Connection, Credentials, ProtocolQuery, ProtocolResponse, ResponseFormat};
7use crate::utils::{NirvResult, ProtocolError, QueryResult, ColumnMetadata, Row, Value, DataType};
8
/// Protocol version accepted in startup messages: major version 3, minor
/// version 0, with the major number in the upper 16 bits (numerically 196608).
const POSTGRES_PROTOCOL_VERSION: u32 = 3 << 16;

/// Message kinds this adapter distinguishes on the client-to-server side.
///
/// Each discriminant is the ASCII tag byte of the message, except
/// `StartupMessage`, which uses `0` (presumably because the startup packet
/// carries no tag byte — confirm against callers).
///
/// The enum is fieldless, so it is `Copy` and `Eq` in addition to `Clone` and
/// `PartialEq`; values can be passed and compared without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostgresMessageType {
    StartupMessage = 0,
    Query = b'Q' as isize,
    Terminate = b'X' as isize,
    PasswordMessage = b'p' as isize,
}
20
/// Message kinds this adapter distinguishes on the server-to-client side.
///
/// Each discriminant is the ASCII tag byte that starts the corresponding
/// message on the wire.
///
/// The enum is fieldless, so it is `Copy` and `Eq` in addition to `Clone` and
/// `PartialEq`; values can be passed and compared without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostgresResponseType {
    AuthenticationOk = b'R' as isize,
    ParameterStatus = b'S' as isize,
    ReadyForQuery = b'Z' as isize,
    RowDescription = b'T' as isize,
    DataRow = b'D' as isize,
    CommandComplete = b'C' as isize,
    ErrorResponse = b'E' as isize,
}
32
/// PostgreSQL wire-protocol adapter.
///
/// The type has no fields; all behaviour lives in its inherent methods and in
/// its `ProtocolAdapter` implementation.
#[derive(Debug)]
pub struct PostgresProtocol {}
38
39impl PostgresProtocol {
40 pub fn new() -> Self {
42 Self {}
43 }
44
45 async fn parse_startup_message(&self, data: &[u8]) -> NirvResult<(u32, HashMap<String, String>)> {
47 if data.len() < 8 {
48 return Err(ProtocolError::InvalidMessageFormat("Startup message too short".to_string()).into());
49 }
50
51 let protocol_version = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
53
54 if protocol_version != POSTGRES_PROTOCOL_VERSION {
55 return Err(ProtocolError::UnsupportedVersion(format!("Protocol version {} not supported", protocol_version)).into());
56 }
57
58 let mut parameters = HashMap::new();
60 let mut pos = 8;
61
62 while pos < data.len() - 1 {
63 let key_end = data[pos..].iter().position(|&b| b == 0)
65 .ok_or_else(|| ProtocolError::InvalidMessageFormat("Unterminated parameter key".to_string()))?;
66
67 let key = String::from_utf8_lossy(&data[pos..pos + key_end]).to_string();
68 pos += key_end + 1;
69
70 if pos >= data.len() {
71 break;
72 }
73
74 let value_end = data[pos..].iter().position(|&b| b == 0)
76 .ok_or_else(|| ProtocolError::InvalidMessageFormat("Unterminated parameter value".to_string()))?;
77
78 let value = String::from_utf8_lossy(&data[pos..pos + value_end]).to_string();
79 pos += value_end + 1;
80
81 parameters.insert(key, value);
82 }
83
84 Ok((protocol_version, parameters))
85 }
86
87 fn create_auth_ok_response(&self) -> Vec<u8> {
89 let mut response = Vec::new();
90 response.push(b'R'); response.extend_from_slice(&8u32.to_be_bytes()); response.extend_from_slice(&0u32.to_be_bytes()); response
94 }
95
96 fn create_parameter_status(&self, name: &str, value: &str) -> Vec<u8> {
98 let mut response = Vec::new();
99 response.push(b'S'); let content_len = name.len() + value.len() + 2; response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes()); response.extend_from_slice(name.as_bytes());
105 response.push(0); response.extend_from_slice(value.as_bytes());
107 response.push(0); response
110 }
111
112 fn create_ready_for_query(&self) -> Vec<u8> {
114 let mut response = Vec::new();
115 response.push(b'Z'); response.extend_from_slice(&5u32.to_be_bytes()); response.push(b'I'); response
119 }
120
121 fn create_row_description(&self, columns: &[ColumnMetadata]) -> Vec<u8> {
123 let mut response = Vec::new();
124 response.push(b'T'); let mut content_len = 2; for col in columns {
129 content_len += col.name.len() + 1; content_len += 18; }
132
133 response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes());
134 response.extend_from_slice(&(columns.len() as u16).to_be_bytes()); for col in columns {
137 response.extend_from_slice(col.name.as_bytes());
138 response.push(0); response.extend_from_slice(&0u32.to_be_bytes()); response.extend_from_slice(&0u16.to_be_bytes()); let type_oid = match col.data_type {
144 DataType::Text => 25u32, DataType::Integer => 23u32, DataType::Float => 701u32, DataType::Boolean => 16u32, DataType::Date => 1082u32, DataType::DateTime => 1114u32, DataType::Json => 114u32, DataType::Binary => 17u32, };
153
154 response.extend_from_slice(&type_oid.to_be_bytes()); response.extend_from_slice(&(-1i16).to_be_bytes()); response.extend_from_slice(&(-1i32).to_be_bytes()); response.extend_from_slice(&0u16.to_be_bytes()); }
159
160 response
161 }
162
163 fn create_data_row(&self, row: &Row) -> Vec<u8> {
165 let mut response = Vec::new();
166 response.push(b'D'); let mut content_len = 2; for value in &row.values {
171 match value {
172 Value::Null => content_len += 4, _ => {
174 let value_str = self.value_to_string(value);
175 content_len += 4 + value_str.len(); }
177 }
178 }
179
180 response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes());
181 response.extend_from_slice(&(row.values.len() as u16).to_be_bytes()); for value in &row.values {
184 match value {
185 Value::Null => {
186 response.extend_from_slice(&(-1i32).to_be_bytes()); }
188 _ => {
189 let value_str = self.value_to_string(value);
190 response.extend_from_slice(&(value_str.len() as u32).to_be_bytes());
191 response.extend_from_slice(value_str.as_bytes());
192 }
193 }
194 }
195
196 response
197 }
198
199 fn create_command_complete(&self, tag: &str) -> Vec<u8> {
201 let mut response = Vec::new();
202 response.push(b'C'); let content_len = tag.len() + 1; response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes());
206 response.extend_from_slice(tag.as_bytes());
207 response.push(0); response
210 }
211
212 fn create_error_response(&self, message: &str) -> Vec<u8> {
214 let mut response = Vec::new();
215 response.push(b'E'); let content_len = 1 + message.len() + 1 + 1; response.extend_from_slice(&(content_len as u32 + 4).to_be_bytes());
219
220 response.push(b'S'); response.extend_from_slice(b"ERROR");
222 response.push(0); response.push(b'M'); response.extend_from_slice(message.as_bytes());
226 response.push(0); response.push(0); response
231 }
232
233 fn value_to_string(&self, value: &Value) -> String {
235 match value {
236 Value::Text(s) => s.clone(),
237 Value::Integer(i) => i.to_string(),
238 Value::Float(f) => f.to_string(),
239 Value::Boolean(b) => if *b { "t".to_string() } else { "f".to_string() },
240 Value::Date(d) => d.clone(),
241 Value::DateTime(dt) => dt.clone(),
242 Value::Json(j) => j.clone(),
243 Value::Binary(b) => {
244 let mut hex_string = String::with_capacity(b.len() * 2 + 2);
246 hex_string.push_str("\\x");
247 for byte in b {
248 hex_string.push_str(&format!("{:02x}", byte));
249 }
250 hex_string
251 },
252 Value::Null => String::new(), }
254 }
255}
256
257impl Default for PostgresProtocol {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262
263#[async_trait]
264impl ProtocolAdapter for PostgresProtocol {
265 async fn accept_connection(&self, stream: TcpStream) -> NirvResult<Connection> {
266 let connection = Connection::new(stream, ProtocolType::PostgreSQL);
267 Ok(connection)
268 }
269
270 async fn authenticate(&self, conn: &mut Connection, credentials: Credentials) -> NirvResult<()> {
271 let mut buffer = vec![0u8; 8192];
273 let bytes_read = conn.stream.read(&mut buffer).await
274 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to read startup message: {}", e)))?;
275
276 if bytes_read < 8 {
277 return Err(ProtocolError::InvalidMessageFormat("Startup message too short".to_string()).into());
278 }
279
280 let (_protocol_version, parameters) = self.parse_startup_message(&buffer[..bytes_read]).await?;
282
283 if let Some(user) = parameters.get("user") {
285 if user != &credentials.username {
286 return Err(ProtocolError::AuthenticationFailed("Username mismatch".to_string()).into());
287 }
288 }
289
290 if let Some(database) = parameters.get("database") {
291 if database != &credentials.database {
292 return Err(ProtocolError::AuthenticationFailed("Database mismatch".to_string()).into());
293 }
294 }
295
296 let auth_response = self.create_auth_ok_response();
298 conn.stream.write_all(&auth_response).await
299 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send auth response: {}", e)))?;
300
301 let param_status = self.create_parameter_status("server_version", "13.0 (NIRV Engine)");
303 conn.stream.write_all(¶m_status).await
304 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send parameter status: {}", e)))?;
305
306 let encoding_status = self.create_parameter_status("client_encoding", "UTF8");
307 conn.stream.write_all(&encoding_status).await
308 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send encoding status: {}", e)))?;
309
310 let ready_response = self.create_ready_for_query();
312 conn.stream.write_all(&ready_response).await
313 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send ready response: {}", e)))?;
314
315 conn.authenticated = true;
317 conn.database = credentials.database;
318 conn.parameters = parameters;
319
320 Ok(())
321 }
322
323 async fn handle_query(&self, _conn: &Connection, _query: ProtocolQuery) -> NirvResult<ProtocolResponse> {
324 let columns = vec![
327 ColumnMetadata {
328 name: "id".to_string(),
329 data_type: DataType::Integer,
330 nullable: false,
331 },
332 ColumnMetadata {
333 name: "name".to_string(),
334 data_type: DataType::Text,
335 nullable: true,
336 },
337 ];
338
339 let rows = vec![
340 Row::new(vec![Value::Integer(1), Value::Text("Test User".to_string())]),
341 Row::new(vec![Value::Integer(2), Value::Text("Another User".to_string())]),
342 ];
343
344 let result = QueryResult {
345 columns,
346 rows,
347 affected_rows: Some(2),
348 execution_time: std::time::Duration::from_millis(10),
349 };
350
351 Ok(ProtocolResponse::new(result, ProtocolType::PostgreSQL))
352 }
353
354 fn get_protocol_type(&self) -> ProtocolType {
355 ProtocolType::PostgreSQL
356 }
357
358 async fn parse_message(&self, _conn: &Connection, data: &[u8]) -> NirvResult<ProtocolQuery> {
359 if data.is_empty() {
360 return Err(ProtocolError::InvalidMessageFormat("Empty message".to_string()).into());
361 }
362
363 let message_type = data[0];
364
365 match message_type {
366 b'Q' => {
367 if data.len() < 5 {
369 return Err(ProtocolError::InvalidMessageFormat("Query message too short".to_string()).into());
370 }
371
372 let query_data = &data[5..];
374
375 let query_end = query_data.iter().position(|&b| b == 0)
377 .unwrap_or(query_data.len());
378
379 let query_string = String::from_utf8_lossy(&query_data[..query_end]).to_string();
380
381 Ok(ProtocolQuery::new(query_string, ProtocolType::PostgreSQL))
382 }
383 b'X' => {
384 Ok(ProtocolQuery::new("TERMINATE".to_string(), ProtocolType::PostgreSQL))
386 }
387 _ => {
388 Err(ProtocolError::InvalidMessageFormat(format!("Unknown message type: {}", message_type)).into())
389 }
390 }
391 }
392
393 async fn format_response(&self, _conn: &Connection, result: QueryResult) -> NirvResult<Vec<u8>> {
394 let mut response = Vec::new();
395
396 let row_desc = self.create_row_description(&result.columns);
398 response.extend_from_slice(&row_desc);
399
400 for row in &result.rows {
402 let data_row = self.create_data_row(row);
403 response.extend_from_slice(&data_row);
404 }
405
406 let tag = format!("SELECT {}", result.rows.len());
408 let cmd_complete = self.create_command_complete(&tag);
409 response.extend_from_slice(&cmd_complete);
410
411 let ready = self.create_ready_for_query();
413 response.extend_from_slice(&ready);
414
415 Ok(response)
416 }
417
418 async fn terminate_connection(&self, conn: &mut Connection) -> NirvResult<()> {
419 conn.stream.shutdown().await
420 .map_err(|_e| ProtocolError::ConnectionClosed)?;
421 Ok(())
422 }
423}