Skip to main content

ocular_protocol/
postgres.rs

//! PostgreSQL wire protocol parser (v3)
//!
//! Typed message format: [type:1][length:4 (includes itself)][payload...]
//! Startup-phase packets (StartupMessage, SSLRequest, GSSENCRequest, CancelRequest) have no
//! type byte: [length:4][protocol version or request code:4][params...]
5
6/// Parse a client→server message, return human-readable summary
7pub fn parse_postgres_request(buf: &[u8]) -> Option<String> {
8    if buf.is_empty() { return None; }
9
10    // Startup message or SSL request: no type byte, starts with [length:4][code:4]
11    // Detect by checking if first byte could be a valid message type
12    let first = buf[0];
13    let is_typed_msg = matches!(first, b'Q' | b'P' | b'B' | b'E' | b'D' | b'S' | b'X' | b'C' | b'p' | b'H' | b'F' | b'd' | b'c' | b'f');
14
15    if !is_typed_msg && buf.len() >= 8 {
16        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
17        let version = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
18        if version == 196608 {
19            // Protocol 3.0 startup
20            let end = len.min(buf.len());
21            let params = parse_startup_params(&buf[8..end]);
22            return Some(format!("Startup user={}", params));
23        }
24        if version == 80877103 {
25            return Some("SSLRequest".into());
26        }
27        // Cancel request
28        if version == 80877102 {
29            return Some("CancelRequest".into());
30        }
31    }
32
33    if !is_typed_msg { return None; }
34
35    let msg_type = first;
36    if buf.len() < 5 { return None; }
37    let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
38    if buf.len() < 1 + len { return None; }
39    let payload = &buf[5..1 + len];
40
41    match msg_type {
42        b'Q' => {
43            // Simple query
44            let sql = read_cstr(payload);
45            let truncated: String = sql.chars().take(120).collect();
46            if truncated.len() < sql.len() {
47                Some(format!("{}...", truncated))
48            } else {
49                Some(truncated)
50            }
51        }
52        b'P' => {
53            // Parse (prepared statement)
54            let stmt = read_cstr(payload);
55            let rest = &payload[stmt.len() + 1..];
56            let query = read_cstr(rest);
57            let q: String = query.chars().take(100).collect();
58            if stmt.is_empty() {
59                Some(format!("PREPARE {}", q))
60            } else {
61                Some(format!("PREPARE [{}] {}", stmt, q))
62            }
63        }
64        b'B' => Some("BIND".into()),
65        b'E' => {
66            // Execute
67            let portal = read_cstr(payload);
68            if portal.is_empty() {
69                Some("EXECUTE".into())
70            } else {
71                Some(format!("EXECUTE [{}]", portal))
72            }
73        }
74        b'D' => {
75            // Describe
76            let kind = if !payload.is_empty() { payload[0] } else { 0 };
77            let name = if payload.len() > 1 { read_cstr(&payload[1..]) } else { String::new() };
78            match kind {
79                b'S' => Some(format!("DESCRIBE STMT {}", name)),
80                b'P' => Some(format!("DESCRIBE PORTAL {}", name)),
81                _ => Some("DESCRIBE".into()),
82            }
83        }
84        b'S' => Some("SYNC".into()),
85        b'X' => Some("TERMINATE".into()),
86        b'C' => {
87            // Close
88            let kind = if !payload.is_empty() { payload[0] } else { 0 };
89            let name = if payload.len() > 1 { read_cstr(&payload[1..]) } else { String::new() };
90            match kind {
91                b'S' => Some(format!("CLOSE STMT {}", name)),
92                b'P' => Some(format!("CLOSE PORTAL {}", name)),
93                _ => Some("CLOSE".into()),
94            }
95        }
96        b'p' => Some("PasswordMessage".into()),
97        b'H' => Some("FLUSH".into()),
98        _ => None,
99    }
100}
101
102/// Extract full SQL from request (no truncation)
103pub fn extract_postgres_full_command(buf: &[u8]) -> Option<String> {
104    if buf.is_empty() { return None; }
105    let first = buf[0];
106    let is_typed = matches!(first, b'Q' | b'P' | b'B' | b'E' | b'D' | b'S' | b'X' | b'C' | b'p' | b'H' | b'F' | b'd' | b'c' | b'f');
107    if !is_typed { return parse_postgres_request(buf); }
108    if buf.len() < 5 { return None; }
109    let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
110    if buf.len() < 1 + len { return None; }
111    let payload = &buf[5..1 + len];
112    match first {
113        b'Q' => Some(read_cstr(payload)),
114        b'P' => {
115            let stmt = read_cstr(payload);
116            let rest = &payload[stmt.len() + 1..];
117            Some(read_cstr(rest))
118        }
119        _ => parse_postgres_request(buf),
120    }
121}
122
123/// Parse a server→client message, return short summary
124/// Scans for the most important message in a multi-message buffer.
125pub fn parse_postgres_response(buf: &[u8]) -> Option<String> {
126    if buf.is_empty() { return None; }
127
128    // SSL response: single byte 'N' (no SSL) or 'S' (SSL)
129    if buf.len() == 1 {
130        return match buf[0] {
131            b'N' => Some("SSLResponse: No".into()),
132            b'S' => Some("SSLResponse: Yes".into()),
133            _ => None,
134        };
135    }
136
137    // Scan all messages, prefer Error/CommandComplete over Auth/ReadyForQuery
138    let mut result: Option<String> = None;
139    let mut pos = 0;
140    while pos + 5 <= buf.len() {
141        let msg_type = buf[pos];
142        let len = u32::from_be_bytes([buf[pos+1], buf[pos+2], buf[pos+3], buf[pos+4]]) as usize;
143        if pos + 1 + len > buf.len() { break; }
144        let payload = &buf[pos+5..pos+1+len];
145
146        let parsed = parse_single_response(msg_type, payload);
147        if let Some(ref _p) = parsed {
148            // Error/CommandComplete take priority
149            if msg_type == b'E' || msg_type == b'C' {
150                return parsed;
151            }
152            // Keep first meaningful result as fallback
153            if result.is_none() {
154                result = parsed;
155            }
156        }
157        pos += 1 + len;
158    }
159    result
160}
161
162fn parse_single_response(msg_type: u8, payload: &[u8]) -> Option<String> {
163
164    match msg_type {
165        b'R' => {
166            // Authentication
167            if payload.len() >= 4 {
168                let auth_type = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
169                match auth_type {
170                    0 => Some("AuthenticationOk".into()),
171                    3 => Some("AuthenticationCleartextPassword".into()),
172                    5 => Some("AuthenticationMD5Password".into()),
173                    10 => Some("AuthenticationSASL".into()),
174                    11 => Some("AuthenticationSASLContinue".into()),
175                    12 => Some("AuthenticationSASLFinal".into()),
176                    _ => Some(format!("Authentication({})", auth_type)),
177                }
178            } else {
179                Some("Authentication".into())
180            }
181        }
182        b'T' => {
183            // RowDescription
184            if payload.len() >= 2 {
185                let col_count = u16::from_be_bytes([payload[0], payload[1]]);
186                Some(format!("RowDescription ({} cols)", col_count))
187            } else {
188                Some("RowDescription".into())
189            }
190        }
191        b'D' => Some("DataRow".into()),
192        b'C' => {
193            // CommandComplete
194            let tag = read_cstr(payload);
195            Some(format!("OK: {}", tag))
196        }
197        b'Z' => {
198            // ReadyForQuery
199            let status = if !payload.is_empty() {
200                match payload[0] {
201                    b'I' => "idle",
202                    b'T' => "in transaction",
203                    b'E' => "failed transaction",
204                    _ => "?",
205                }
206            } else { "?" };
207            Some(format!("Ready ({})", status))
208        }
209        b'E' => {
210            // ErrorResponse
211            let msg = parse_error_fields(payload);
212            Some(format!("ERROR: {}", msg))
213        }
214        b'N' => {
215            // NoticeResponse
216            let msg = parse_error_fields(payload);
217            Some(format!("NOTICE: {}", msg))
218        }
219        b'S' => {
220            // ParameterStatus
221            let name = read_cstr(payload);
222            let rest = &payload[name.len() + 1..];
223            let value = read_cstr(rest);
224            Some(format!("Set {} = {}", name, value))
225        }
226        b'K' => Some("BackendKeyData".into()),
227        b'1' => Some("ParseComplete".into()),
228        b'2' => Some("BindComplete".into()),
229        b'3' => Some("CloseComplete".into()),
230        b'n' => Some("NoData".into()),
231        b't' => Some("ParameterDescription".into()),
232        b'I' => Some("EmptyQueryResponse".into()),
233        _ => None,
234    }
235}
236
237/// Format response detail for the detail panel
238pub fn format_postgres_response_detail(buf: &[u8]) -> Option<String> {
239    if buf.is_empty() { return None; }
240    // SSL response
241    if buf.len() == 1 {
242        return parse_postgres_response(buf);
243    }
244    // Try to parse multiple messages for a complete result
245    let mut detail = String::new();
246    let mut pos = 0;
247    let mut row_count = 0u64;
248
249    while pos < buf.len() {
250        if pos + 5 > buf.len() { break; }
251        let msg_type = buf[pos];
252        let len = u32::from_be_bytes([buf[pos+1], buf[pos+2], buf[pos+3], buf[pos+4]]) as usize;
253        if pos + 1 + len > buf.len() { break; }
254        let payload = &buf[pos+5..pos+1+len];
255
256        match msg_type {
257            b'T' if payload.len() >= 2 => {
258                // RowDescription - extract column names
259                let col_count = u16::from_be_bytes([payload[0], payload[1]]) as usize;
260                    let mut p = 2;
261                    let mut cols = Vec::new();
262                    for _ in 0..col_count {
263                        let name = read_cstr(&payload[p..]);
264                        p += name.len() + 1 + 18; // name + null + 18 bytes of field info
265                        cols.push(name);
266                    }
267                    detail.push_str(&format!("Columns: {}\n", cols.join(" | ")));
268            }
269            b'D' => {
270                row_count += 1;
271                if row_count <= 20 {
272                    // DataRow: [col_count:2][for each: len:4 (or -1 for NULL), data]
273                    if payload.len() >= 2 {
274                        let ncols = u16::from_be_bytes([payload[0], payload[1]]) as usize;
275                        let mut p = 2;
276                        let mut fields = Vec::new();
277                        for _ in 0..ncols {
278                            if p + 4 > payload.len() { break; }
279                            let flen = i32::from_be_bytes([payload[p], payload[p+1], payload[p+2], payload[p+3]]);
280                            p += 4;
281                            if flen < 0 {
282                                fields.push("NULL".to_string());
283                            } else {
284                                let end = p + flen as usize;
285                                if end <= payload.len() {
286                                    fields.push(String::from_utf8_lossy(&payload[p..end]).to_string());
287                                }
288                                p = end;
289                            }
290                        }
291                        detail.push_str(&fields.join(" | "));
292                        detail.push('\n');
293                    }
294                }
295            }
296            b'C' if row_count > 0 => {
297                detail.push_str(&format!("{} rows\n", row_count));
298            }
299            b'E' => {
300                let msg = parse_error_fields(payload);
301                detail.push_str(&format!("ERROR: {}\n", msg));
302            }
303            _ => {}
304        }
305        pos += 1 + len;
306    }
307
308    if detail.is_empty() {
309        parse_postgres_response(buf)
310    } else {
311        Some(detail)
312    }
313}
314
/// Report whether `buf` holds a complete PostgreSQL backend response.
///
/// A response is complete when the buffer consists entirely of whole, well-formed
/// messages and the last one is ReadyForQuery ('Z'). A lone 'N' or 'S' byte, the
/// server's reply to an `SSLRequest`, also counts as complete. A length field below
/// 4 (the length counts itself) is treated as malformed, so the result is `false`.
pub fn postgres_response_complete(buf: &[u8]) -> bool {
    if matches!(buf, [b'N'] | [b'S']) {
        return true;
    }
    // The shortest complete stream is a single ReadyForQuery: 'Z' + length(5) + status.
    if buf.len() < 6 {
        return false;
    }

    let mut pos = 0usize;
    let mut last_type = None;
    while let Some(header) = buf.get(pos..pos + 5) {
        let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
        if len < 4 {
            return false;
        }
        match (pos + 1).checked_add(len) {
            Some(end) if end <= buf.len() => {
                last_type = Some(header[0]);
                pos = end;
            }
            // The last message has not been fully received yet.
            _ => return false,
        }
    }
    // Leftover bytes shorter than a header also mean the stream is incomplete.
    pos == buf.len() && last_type == Some(b'Z')
}
337
/// Decode the bytes before the first NUL (or all of `buf` when no NUL is present)
/// as UTF-8, substituting U+FFFD for invalid sequences.
fn read_cstr(buf: &[u8]) -> String {
    // `split` always yields at least one piece, possibly empty; the first piece is
    // everything before the terminator.
    let text = buf.split(|&byte| byte == 0).next().unwrap_or(buf);
    String::from_utf8_lossy(text).into_owned()
}
342
/// Return the `user` value from a StartupMessage parameter list, or an empty string
/// if it is absent.
///
/// The list is a sequence of `key\0value\0` pairs ended by an empty key. It is
/// walked on raw bytes, so invalid UTF-8 in earlier values cannot shift later keys,
/// and a missing terminator cannot index past the end. If `user` appears more than
/// once, the last occurrence wins.
fn parse_startup_params(buf: &[u8]) -> String {
    let mut fields = buf.split(|&b| b == 0);
    let mut user = String::new();
    while let Some(key) = fields.next() {
        // An empty key is the list terminator.
        if key.is_empty() {
            break;
        }
        let value = fields.next().unwrap_or_default();
        if key == b"user" {
            user = String::from_utf8_lossy(value).into_owned();
        }
    }
    user
}
356
/// Extract the human-readable message (field type 'M') from an ErrorResponse or
/// NoticeResponse body, or return an empty string if there is none.
///
/// The body is a sequence of `[field type:1][value\0]` entries ended by a 0 byte.
/// Offsets are computed on raw bytes, so a value containing invalid UTF-8 (whose
/// decoded form is longer than its encoding) cannot misalign the following fields.
fn parse_error_fields(buf: &[u8]) -> String {
    let mut message = String::new();
    let mut rest = buf;
    while let Some((&field_type, tail)) = rest.split_first() {
        if field_type == 0 {
            break;
        }
        let value_end = tail.iter().position(|&b| b == 0).unwrap_or(tail.len());
        if field_type == b'M' {
            message = String::from_utf8_lossy(&tail[..value_end]).into_owned();
        }
        rest = tail.get(value_end + 1..).unwrap_or_default();
    }
    message
}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a typed protocol frame: [tag][length including itself][payload].
    fn frame(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut buf = vec![tag];
        buf.extend_from_slice(&(payload.len() as u32 + 4).to_be_bytes());
        buf.extend_from_slice(payload);
        buf
    }

    #[test]
    fn test_parse_simple_query() {
        let buf = frame(b'Q', b"SELECT 1\0");
        assert_eq!(parse_postgres_request(&buf).unwrap(), "SELECT 1");
    }

    #[test]
    fn test_parse_command_complete() {
        let buf = frame(b'C', b"INSERT 0 1\0");
        assert_eq!(parse_postgres_response(&buf).unwrap(), "OK: INSERT 0 1");
    }

    #[test]
    fn test_parse_prepared_statement() {
        // Parse: unnamed statement, query text, zero parameter types.
        let buf = frame(b'P', b"\0SELECT $1\0\0\0");
        assert_eq!(parse_postgres_request(&buf).unwrap(), "PREPARE SELECT $1");
    }

    #[test]
    fn test_parse_startup_message() {
        let mut body = 196_608u32.to_be_bytes().to_vec();
        body.extend_from_slice(b"user\0alice\0database\0app\0\0");
        let mut buf = (body.len() as u32 + 4).to_be_bytes().to_vec();
        buf.extend_from_slice(&body);
        assert_eq!(parse_postgres_request(&buf).unwrap(), "Startup user=alice");
    }

    #[test]
    fn test_parse_ssl_request() {
        let mut buf = 8u32.to_be_bytes().to_vec();
        buf.extend_from_slice(&80_877_103u32.to_be_bytes());
        assert_eq!(parse_postgres_request(&buf).unwrap(), "SSLRequest");
    }

    #[test]
    fn test_ssl_response_byte() {
        assert_eq!(parse_postgres_response(b"S").unwrap(), "SSLResponse: Yes");
        assert_eq!(parse_postgres_response(b"N").unwrap(), "SSLResponse: No");
    }

    #[test]
    fn test_error_takes_priority_over_ready_for_query() {
        let mut buf = frame(b'E', b"SERROR\0Mboom\0\0");
        buf.extend(frame(b'Z', b"I"));
        assert_eq!(parse_postgres_response(&buf).unwrap(), "ERROR: boom");
    }
}