Skip to main content

heliosdb_proxy/
protocol.rs

1//! Protocol Handling
2//!
3//! Wire protocol parsing and serialization for HeliosDB proxy.
4
5use crate::{ProxyError, Result};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use std::collections::HashMap;
8
/// Kinds of protocol message the proxy distinguishes.
///
/// The variants cover both client-originated and server-originated
/// messages. Several of them share a tag byte on the wire (for example
/// `Describe` and `DataRow` are both `b'D'`), so the tag-to-variant mapping
/// is not one-to-one; see [`MessageType::from_tag`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// Startup message from client (sent without a tag byte)
    Startup,
    /// SSL request (sent without a tag byte)
    SSLRequest,
    /// Authentication request
    AuthRequest,
    /// Password message
    Password,
    /// Query message
    Query,
    /// Parse message (prepared statement)
    Parse,
    /// Bind message
    Bind,
    /// Describe message
    Describe,
    /// Execute message
    Execute,
    /// Sync message
    Sync,
    /// Flush message
    Flush,
    /// Close message
    Close,
    /// Terminate message
    Terminate,
    /// Copy data
    CopyData,
    /// Copy done
    CopyDone,
    /// Copy fail
    CopyFail,
    /// Function call (deprecated)
    FunctionCall,
    /// Backend key data
    BackendKeyData,
    /// Parameter status
    ParameterStatus,
    /// Ready for query
    ReadyForQuery,
    /// Row description
    RowDescription,
    /// Data row
    DataRow,
    /// Command complete
    CommandComplete,
    /// Empty query response
    EmptyQueryResponse,
    /// Error response
    ErrorResponse,
    /// Notice response
    NoticeResponse,
    /// Notification response
    NotificationResponse,
    /// Parse complete
    ParseComplete,
    /// Bind complete
    BindComplete,
    /// Close complete
    CloseComplete,
    /// Portal suspended
    PortalSuspended,
    /// No data
    NoData,
    /// Parameter description
    ParameterDescription,
    /// Message whose tag byte is not recognised; the raw tag is preserved
    Unknown(u8),
}

impl MessageType {
    /// Get message type from tag byte.
    ///
    /// The lookup is direction-agnostic. Tags that are used by both a
    /// client message and a server message (`D`, `E`, `C`, `S`) always
    /// resolve to the client-side variant; callers that know they are
    /// reading server traffic must disambiguate themselves. Any tag not
    /// listed here is returned as [`MessageType::Unknown`] carrying the
    /// original byte.
    pub fn from_tag(tag: u8) -> Self {
        match tag {
            b'Q' => MessageType::Query,
            b'P' => MessageType::Parse,
            b'B' => MessageType::Bind,
            b'D' => MessageType::Describe,
            b'E' => MessageType::Execute,
            b'S' => MessageType::Sync,
            b'H' => MessageType::Flush,
            b'C' => MessageType::Close,
            b'X' => MessageType::Terminate,
            b'd' => MessageType::CopyData,
            b'c' => MessageType::CopyDone,
            b'f' => MessageType::CopyFail,
            b'F' => MessageType::FunctionCall,
            b'p' => MessageType::Password,
            b'R' => MessageType::AuthRequest,
            b'K' => MessageType::BackendKeyData,
            // Note: server-side D/E/C/S tags (DataRow, ErrorResponse,
            // CommandComplete, ParameterStatus) collide with client-side
            // Describe/Execute/Close/Sync above; from_tag() is direction-
            // agnostic and resolves them to the client-side variants.
            // Disambiguation, when needed, lives at the call site.
            b'Z' => MessageType::ReadyForQuery,
            b'T' => MessageType::RowDescription,
            b'I' => MessageType::EmptyQueryResponse,
            b'N' => MessageType::NoticeResponse,
            b'A' => MessageType::NotificationResponse,
            b'1' => MessageType::ParseComplete,
            b'2' => MessageType::BindComplete,
            b'3' => MessageType::CloseComplete,
            b's' => MessageType::PortalSuspended,
            b'n' => MessageType::NoData,
            b't' => MessageType::ParameterDescription,
            _ => MessageType::Unknown(tag),
        }
    }

    /// Get tag byte for message type.
    ///
    /// Returns `None` only for `Startup` and `SSLRequest`, which have no
    /// tag byte. `Unknown(tag)` yields the byte it was decoded from, so an
    /// unrecognised message keeps its tag when it is re-encoded instead of
    /// being emitted as an untagged (and therefore corrupt) frame.
    ///
    /// The match is deliberately exhaustive (no `_` arm) so that adding a
    /// variant forces a decision here at compile time.
    pub fn to_tag(&self) -> Option<u8> {
        match self {
            // Untagged handshake messages.
            MessageType::Startup | MessageType::SSLRequest => None,

            MessageType::Query => Some(b'Q'),
            MessageType::Parse => Some(b'P'),
            MessageType::Bind => Some(b'B'),
            MessageType::Describe => Some(b'D'),
            MessageType::Execute => Some(b'E'),
            MessageType::Sync => Some(b'S'),
            MessageType::Flush => Some(b'H'),
            MessageType::Close => Some(b'C'),
            MessageType::Terminate => Some(b'X'),
            MessageType::CopyData => Some(b'd'),
            MessageType::CopyDone => Some(b'c'),
            MessageType::CopyFail => Some(b'f'),
            MessageType::FunctionCall => Some(b'F'),
            MessageType::Password => Some(b'p'),
            MessageType::AuthRequest => Some(b'R'),
            MessageType::BackendKeyData => Some(b'K'),
            MessageType::ParameterStatus => Some(b'S'),
            MessageType::ReadyForQuery => Some(b'Z'),
            MessageType::RowDescription => Some(b'T'),
            MessageType::DataRow => Some(b'D'),
            MessageType::CommandComplete => Some(b'C'),
            MessageType::EmptyQueryResponse => Some(b'I'),
            MessageType::ErrorResponse => Some(b'E'),
            MessageType::NoticeResponse => Some(b'N'),
            MessageType::NotificationResponse => Some(b'A'),
            MessageType::ParseComplete => Some(b'1'),
            MessageType::BindComplete => Some(b'2'),
            MessageType::CloseComplete => Some(b'3'),
            MessageType::PortalSuspended => Some(b's'),
            MessageType::NoData => Some(b'n'),
            MessageType::ParameterDescription => Some(b't'),

            // Preserve the original byte so decode -> encode round-trips.
            MessageType::Unknown(tag) => Some(*tag),
        }
    }
}
160
161/// A protocol message
162#[derive(Debug, Clone)]
163pub struct Message {
164    /// Message type
165    pub msg_type: MessageType,
166    /// Message payload
167    pub payload: BytesMut,
168}
169
170impl Message {
171    /// Create a new message
172    pub fn new(msg_type: MessageType, payload: BytesMut) -> Self {
173        Self { msg_type, payload }
174    }
175
176    /// Create an empty message
177    pub fn empty(msg_type: MessageType) -> Self {
178        Self {
179            msg_type,
180            payload: BytesMut::new(),
181        }
182    }
183
184    /// Encode message to bytes
185    pub fn encode(&self) -> BytesMut {
186        let mut buf = BytesMut::new();
187
188        if let Some(tag) = self.msg_type.to_tag() {
189            buf.put_u8(tag);
190        }
191
192        // Length includes itself (4 bytes)
193        let len = self.payload.len() as u32 + 4;
194        buf.put_u32(len);
195        buf.extend_from_slice(&self.payload);
196
197        buf
198    }
199}
200
201/// Protocol codec for framing messages
202pub struct ProtocolCodec {
203    /// Maximum message size
204    max_message_size: usize,
205}
206
207impl Default for ProtocolCodec {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213impl ProtocolCodec {
214    /// Create a new codec
215    pub fn new() -> Self {
216        Self {
217            max_message_size: 100 * 1024 * 1024, // 100MB max
218        }
219    }
220
221    /// Create codec with custom max message size
222    pub fn with_max_size(max_message_size: usize) -> Self {
223        Self { max_message_size }
224    }
225
226    /// Decode a startup message (no tag byte)
227    pub fn decode_startup(&self, src: &mut BytesMut) -> Result<Option<StartupMessage>> {
228        if src.len() < 4 {
229            return Ok(None);
230        }
231
232        let len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
233
234        if len > self.max_message_size {
235            return Err(ProxyError::Protocol(format!(
236                "Message too large: {} bytes",
237                len
238            )));
239        }
240
241        if src.len() < len {
242            return Ok(None);
243        }
244
245        src.advance(4);
246        let protocol_version = src.get_u32();
247
248        // Check for SSL request
249        if protocol_version == 80877103 {
250            return Ok(Some(StartupMessage::SSLRequest));
251        }
252
253        // Check for cancel request
254        if protocol_version == 80877102 {
255            let pid = src.get_u32();
256            let key = src.get_u32();
257            return Ok(Some(StartupMessage::CancelRequest { pid, key }));
258        }
259
260        // Parse parameters
261        let mut params = HashMap::new();
262        let remaining = len - 8; // Already read length and version
263        let mut param_bytes = src.split_to(remaining);
264
265        while param_bytes.has_remaining() {
266            let key = read_cstring(&mut param_bytes)?;
267            if key.is_empty() {
268                break;
269            }
270            let value = read_cstring(&mut param_bytes)?;
271            params.insert(key, value);
272        }
273
274        Ok(Some(StartupMessage::Startup {
275            protocol_version,
276            params,
277        }))
278    }
279
280    /// Decode a regular message (with tag byte)
281    pub fn decode_message(&self, src: &mut BytesMut) -> Result<Option<Message>> {
282        if src.len() < 5 {
283            return Ok(None);
284        }
285
286        let tag = src[0];
287        let len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
288
289        if len > self.max_message_size {
290            return Err(ProxyError::Protocol(format!(
291                "Message too large: {} bytes",
292                len
293            )));
294        }
295
296        // Length includes itself, so total message is 1 (tag) + len
297        let total_len = 1 + len;
298        if src.len() < total_len {
299            return Ok(None);
300        }
301
302        src.advance(5); // Skip tag and length
303        let payload = src.split_to(len - 4); // Length includes the 4-byte length field
304
305        let msg_type = MessageType::from_tag(tag);
306        Ok(Some(Message::new(msg_type, payload)))
307    }
308
309    /// Encode a message
310    pub fn encode_message(&self, msg: &Message) -> BytesMut {
311        msg.encode()
312    }
313}
314
/// The three kinds of untagged message a client may open a connection with.
#[derive(Debug, Clone)]
pub enum StartupMessage {
    /// Ordinary startup packet.
    Startup {
        /// Raw 32-bit protocol version as read from the wire.
        protocol_version: u32,
        /// Key/value parameters that followed the version.
        params: HashMap<String, String>,
    },
    /// Client asks to negotiate SSL before sending a real startup packet.
    SSLRequest,
    /// Client asks to cancel a running query on another connection.
    CancelRequest {
        /// First 32-bit value of the request body.
        pid: u32,
        /// Second 32-bit value of the request body.
        key: u32,
    },
}
328
329/// Read a null-terminated string from the buffer.
330///
331/// Scans for the null terminator in a single pass (no per-byte `get_u8`
332/// loop, no Vec growth), then hands the exact-size byte slice to `String`.
333/// On `BytesMut`, `split_to` is O(1), and `BytesMut -> Vec<u8>` is
334/// zero-copy when (as here) the split-off buffer has a single owner.
335fn read_cstring(buf: &mut BytesMut) -> Result<String> {
336    let end = buf
337        .iter()
338        .position(|&b| b == 0)
339        .ok_or_else(|| ProxyError::Protocol(
340            "unterminated cstring in protocol buffer".to_string(),
341        ))?;
342
343    let bytes = buf.split_to(end);
344    buf.advance(1); // consume the null terminator
345
346    String::from_utf8(bytes.into())
347        .map_err(|e| ProxyError::Protocol(format!("Invalid UTF-8 in cstring: {}", e)))
348}
349
350/// Write a null-terminated string to buffer
351fn write_cstring(buf: &mut BytesMut, s: &str) {
352    buf.extend_from_slice(s.as_bytes());
353    buf.put_u8(0);
354}
355
/// View the leading NUL-terminated string of a payload as `&str`.
///
/// Returns the bytes before the first zero byte, borrowed from `payload`
/// (no copy is made). Yields `None` when there is no zero byte at all or
/// when those bytes are not valid UTF-8.
pub fn query_text(payload: &[u8]) -> Option<&str> {
    let nul_at = payload.iter().position(|byte| *byte == 0)?;
    let (text, _rest) = payload.split_at(nul_at);
    match std::str::from_utf8(text) {
        Ok(sql) => Some(sql),
        Err(_) => None,
    }
}
364
/// Report whether `s` begins with `prefix`, ignoring ASCII case.
///
/// The comparison is byte-wise on the first `prefix.len()` bytes of `s`;
/// only ASCII letters are case-folded. An empty `prefix` always matches.
pub fn starts_with_ci(s: &str, prefix: &str) -> bool {
    let wanted = prefix.as_bytes();
    match s.as_bytes().get(..wanted.len()) {
        Some(head) => head.eq_ignore_ascii_case(wanted),
        // `s` is shorter than `prefix`.
        None => false,
    }
}
370
/// Report whether `needle` occurs anywhere in `haystack`, ignoring ASCII
/// case.
///
/// Works on the raw bytes with a sliding window, so nothing is allocated.
/// An empty `needle` is considered to be contained in every haystack.
pub fn contains_ci(haystack: &str, needle: &str) -> bool {
    let hay = haystack.as_bytes();
    let pat = needle.as_bytes();

    match pat.len() {
        // `windows(0)` would panic, and the empty string matches anyway.
        0 => true,
        width if width > hay.len() => false,
        width => hay
            .windows(width)
            .any(|window| window.eq_ignore_ascii_case(pat)),
    }
}
384
385/// Query message payload
386#[derive(Debug, Clone)]
387pub struct QueryMessage {
388    pub query: String,
389}
390
391impl QueryMessage {
392    /// Parse from message payload
393    pub fn parse(mut payload: BytesMut) -> Result<Self> {
394        let query = read_cstring(&mut payload)?;
395        Ok(Self { query })
396    }
397
398    /// Encode to message
399    pub fn encode(&self) -> Message {
400        let mut payload = BytesMut::new();
401        write_cstring(&mut payload, &self.query);
402        Message::new(MessageType::Query, payload)
403    }
404}
405
406/// Parse message payload (prepared statement)
407#[derive(Debug, Clone)]
408pub struct ParseMessage {
409    pub name: String,
410    pub query: String,
411    pub param_types: Vec<u32>,
412}
413
414impl ParseMessage {
415    /// Parse from message payload
416    pub fn parse(mut payload: BytesMut) -> Result<Self> {
417        let name = read_cstring(&mut payload)?;
418        let query = read_cstring(&mut payload)?;
419
420        let num_params = payload.get_u16() as usize;
421        let mut param_types = Vec::with_capacity(num_params);
422
423        for _ in 0..num_params {
424            param_types.push(payload.get_u32());
425        }
426
427        Ok(Self {
428            name,
429            query,
430            param_types,
431        })
432    }
433
434    /// Encode to message
435    pub fn encode(&self) -> Message {
436        let mut payload = BytesMut::new();
437        write_cstring(&mut payload, &self.name);
438        write_cstring(&mut payload, &self.query);
439        payload.put_u16(self.param_types.len() as u16);
440        for &t in &self.param_types {
441            payload.put_u32(t);
442        }
443        Message::new(MessageType::Parse, payload)
444    }
445}
446
447/// Bind message payload
448///
449/// `param_values` uses [`bytes::Bytes`] so parameter values are held by
450/// reference into the original protocol buffer — no per-parameter `Vec`
451/// allocation during parse.
452#[derive(Debug, Clone)]
453pub struct BindMessage {
454    pub portal: String,
455    pub statement: String,
456    pub param_formats: Vec<i16>,
457    pub param_values: Vec<Option<Bytes>>,
458    pub result_formats: Vec<i16>,
459}
460
461impl BindMessage {
462    /// Parse from message payload
463    pub fn parse(mut payload: BytesMut) -> Result<Self> {
464        let portal = read_cstring(&mut payload)?;
465        let statement = read_cstring(&mut payload)?;
466
467        // Parameter formats
468        let num_formats = payload.get_u16() as usize;
469        let mut param_formats = Vec::with_capacity(num_formats);
470        for _ in 0..num_formats {
471            param_formats.push(payload.get_i16());
472        }
473
474        // Parameter values — zero-copy: `split_to` slices the Arc'd buffer
475        // and `freeze()` turns the split-off `BytesMut` into a shared
476        // `Bytes` without allocating.
477        let num_values = payload.get_u16() as usize;
478        let mut param_values = Vec::with_capacity(num_values);
479        for _ in 0..num_values {
480            let len = payload.get_i32();
481            if len == -1 {
482                param_values.push(None);
483            } else {
484                let value = payload.split_to(len as usize).freeze();
485                param_values.push(Some(value));
486            }
487        }
488
489        // Result formats
490        let num_result_formats = payload.get_u16() as usize;
491        let mut result_formats = Vec::with_capacity(num_result_formats);
492        for _ in 0..num_result_formats {
493            result_formats.push(payload.get_i16());
494        }
495
496        Ok(Self {
497            portal,
498            statement,
499            param_formats,
500            param_values,
501            result_formats,
502        })
503    }
504}
505
506/// Execute message payload
507#[derive(Debug, Clone)]
508pub struct ExecuteMessage {
509    pub portal: String,
510    pub max_rows: i32,
511}
512
513impl ExecuteMessage {
514    /// Parse from message payload
515    pub fn parse(mut payload: BytesMut) -> Result<Self> {
516        let portal = read_cstring(&mut payload)?;
517        let max_rows = payload.get_i32();
518        Ok(Self { portal, max_rows })
519    }
520
521    /// Encode to message
522    pub fn encode(&self) -> Message {
523        let mut payload = BytesMut::new();
524        write_cstring(&mut payload, &self.portal);
525        payload.put_i32(self.max_rows);
526        Message::new(MessageType::Execute, payload)
527    }
528}
529
530/// Error response message
531#[derive(Debug, Clone)]
532pub struct ErrorResponse {
533    pub fields: HashMap<char, String>,
534}
535
536impl ErrorResponse {
537    /// Parse from message payload
538    pub fn parse(mut payload: BytesMut) -> Result<Self> {
539        let mut fields = HashMap::new();
540
541        while payload.has_remaining() {
542            let code = payload.get_u8();
543            if code == 0 {
544                break;
545            }
546            let value = read_cstring(&mut payload)?;
547            fields.insert(code as char, value);
548        }
549
550        Ok(Self { fields })
551    }
552
553    /// Get severity
554    pub fn severity(&self) -> Option<&str> {
555        self.fields.get(&'S').map(|s| s.as_str())
556    }
557
558    /// Get error code
559    pub fn code(&self) -> Option<&str> {
560        self.fields.get(&'C').map(|s| s.as_str())
561    }
562
563    /// Get message
564    pub fn message(&self) -> Option<&str> {
565        self.fields.get(&'M').map(|s| s.as_str())
566    }
567
568    /// Encode to message
569    pub fn encode(&self) -> Message {
570        let mut payload = BytesMut::new();
571        for (&code, value) in &self.fields {
572            payload.put_u8(code as u8);
573            write_cstring(&mut payload, value);
574        }
575        payload.put_u8(0);
576        Message::new(MessageType::ErrorResponse, payload)
577    }
578}
579
/// Transaction state indicator, encoded on the wire as a single byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Not inside a transaction (`b'I'`)
    Idle,
    /// Inside a transaction block (`b'T'`)
    InTransaction,
    /// Inside a failed transaction block (`b'E'`)
    Failed,
}

impl TransactionStatus {
    /// Decode a status byte.
    ///
    /// `b'T'` and `b'E'` map to their variants; every other byte,
    /// including `b'I'`, is reported as `Idle`.
    pub fn from_byte(b: u8) -> Self {
        if b == b'T' {
            Self::InTransaction
        } else if b == b'E' {
            Self::Failed
        } else {
            Self::Idle
        }
    }

    /// Encode the status as its wire byte.
    pub fn to_byte(&self) -> u8 {
        match *self {
            Self::Failed => b'E',
            Self::InTransaction => b'T',
            Self::Idle => b'I',
        }
    }
}
611
612/// Command complete message
613#[derive(Debug, Clone)]
614pub struct CommandComplete {
615    pub tag: String,
616}
617
618impl CommandComplete {
619    /// Parse from message payload
620    pub fn parse(mut payload: BytesMut) -> Result<Self> {
621        let tag = read_cstring(&mut payload)?;
622        Ok(Self { tag })
623    }
624
625    /// Encode to message
626    pub fn encode(&self) -> Message {
627        let mut payload = BytesMut::new();
628        write_cstring(&mut payload, &self.tag);
629        Message::new(MessageType::CommandComplete, payload)
630    }
631
632    /// Get rows affected for INSERT/UPDATE/DELETE
633    pub fn rows_affected(&self) -> Option<u64> {
634        let parts: Vec<&str> = self.tag.split_whitespace().collect();
635        if parts.len() >= 2 {
636            parts.last()?.parse().ok()
637        } else {
638            None
639        }
640    }
641}
642
643/// Authentication request types
644#[derive(Debug, Clone)]
645pub enum AuthRequest {
646    /// Authentication OK
647    Ok,
648    /// Cleartext password
649    CleartextPassword,
650    /// MD5 password
651    Md5Password { salt: [u8; 4] },
652    /// SASL
653    SASL { mechanisms: Vec<String> },
654    /// SASL continue
655    SASLContinue { data: Vec<u8> },
656    /// SASL final
657    SASLFinal { data: Vec<u8> },
658    /// Unknown
659    Unknown(i32),
660}
661
662impl AuthRequest {
663    /// Parse from message payload
664    pub fn parse(mut payload: BytesMut) -> Result<Self> {
665        let auth_type = payload.get_i32();
666
667        Ok(match auth_type {
668            0 => AuthRequest::Ok,
669            3 => AuthRequest::CleartextPassword,
670            5 => {
671                let mut salt = [0u8; 4];
672                payload.copy_to_slice(&mut salt);
673                AuthRequest::Md5Password { salt }
674            }
675            10 => {
676                let mut mechanisms = Vec::new();
677                loop {
678                    let mech = read_cstring(&mut payload)?;
679                    if mech.is_empty() {
680                        break;
681                    }
682                    mechanisms.push(mech);
683                }
684                AuthRequest::SASL { mechanisms }
685            }
686            11 => {
687                let data = payload.to_vec();
688                AuthRequest::SASLContinue { data }
689            }
690            12 => {
691                let data = payload.to_vec();
692                AuthRequest::SASLFinal { data }
693            }
694            _ => AuthRequest::Unknown(auth_type),
695        })
696    }
697
698    /// Encode to message
699    pub fn encode(&self) -> Message {
700        let mut payload = BytesMut::new();
701
702        match self {
703            AuthRequest::Ok => {
704                payload.put_i32(0);
705            }
706            AuthRequest::CleartextPassword => {
707                payload.put_i32(3);
708            }
709            AuthRequest::Md5Password { salt } => {
710                payload.put_i32(5);
711                payload.extend_from_slice(salt);
712            }
713            AuthRequest::SASL { mechanisms } => {
714                payload.put_i32(10);
715                for mech in mechanisms {
716                    write_cstring(&mut payload, mech);
717                }
718                payload.put_u8(0);
719            }
720            AuthRequest::SASLContinue { data } => {
721                payload.put_i32(11);
722                payload.extend_from_slice(data);
723            }
724            AuthRequest::SASLFinal { data } => {
725                payload.put_i32(12);
726                payload.extend_from_slice(data);
727            }
728            AuthRequest::Unknown(t) => {
729                payload.put_i32(*t);
730            }
731        }
732
733        Message::new(MessageType::AuthRequest, payload)
734    }
735}
736
#[cfg(test)]
mod tests {
    use super::*;

    /// Client-side message types must come back unchanged after going
    /// through `to_tag` and then `from_tag`.
    #[test]
    fn test_message_type_round_trip() {
        let client_types = [
            MessageType::Query,
            MessageType::Parse,
            MessageType::Bind,
            MessageType::Execute,
            MessageType::Sync,
        ];

        for original in client_types.iter() {
            if let Some(tag) = original.to_tag() {
                assert_eq!(&MessageType::from_tag(tag), original);
            }
        }
    }

    /// Regression: 'R' (AuthenticationRequest) must decode to AuthRequest,
    /// not Unknown(82); the backend client matches on this to authenticate.
    #[test]
    fn test_auth_request_tag_mapping() {
        assert_eq!(MessageType::from_tag(b'R'), MessageType::AuthRequest);
        assert_eq!(MessageType::AuthRequest.to_tag(), Some(b'R'));
    }

    /// A query survives encode followed by parse.
    #[test]
    fn test_query_message() {
        let encoded = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        assert_eq!(encoded.msg_type, MessageType::Query);

        let round_tripped = QueryMessage::parse(encoded.payload).unwrap();
        assert_eq!(round_tripped.query, "SELECT 1");
    }

    /// The named accessors read the S, C and M fields.
    #[test]
    fn test_error_response() {
        let mut fields = HashMap::new();
        for (code, text) in [
            ('S', "ERROR"),
            ('C', "42P01"),
            ('M', "relation does not exist"),
        ]
        .iter()
        {
            fields.insert(*code, text.to_string());
        }

        let err = ErrorResponse { fields };
        assert_eq!(err.severity(), Some("ERROR"));
        assert_eq!(err.code(), Some("42P01"));
        assert_eq!(err.message(), Some("relation does not exist"));
    }

    /// The trailing number of a multi-word tag is the row count.
    #[test]
    fn test_command_complete() {
        let insert = CommandComplete {
            tag: "INSERT 0 5".to_string(),
        };
        assert_eq!(insert.rows_affected(), Some(5));

        let select = CommandComplete {
            tag: "SELECT 100".to_string(),
        };
        assert_eq!(select.rows_affected(), Some(100));
    }

    /// Status bytes map to variants and back.
    #[test]
    fn test_transaction_status() {
        assert_eq!(TransactionStatus::from_byte(b'I'), TransactionStatus::Idle);
        assert_eq!(
            TransactionStatus::from_byte(b'T'),
            TransactionStatus::InTransaction
        );
        assert_eq!(TransactionStatus::from_byte(b'E'), TransactionStatus::Failed);

        assert_eq!(TransactionStatus::Idle.to_byte(), b'I');
        assert_eq!(TransactionStatus::InTransaction.to_byte(), b'T');
        assert_eq!(TransactionStatus::Failed.to_byte(), b'E');
    }

    /// The codec's encoder emits the tag byte first, then more than a
    /// bare header.
    #[test]
    fn test_protocol_codec() {
        let codec = ProtocolCodec::new();
        let msg = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        let wire = codec.encode_message(&msg);

        assert!(wire.len() > 5);
        assert_eq!(wire[0], b'Q');
    }

    /// An unterminated cstring must surface a protocol error, not be
    /// silently treated as the full remaining buffer.
    #[test]
    fn test_read_cstring_unterminated() {
        let mut buf = BytesMut::from("not-null-terminated");
        let err = read_cstring(&mut buf).expect_err("should reject unterminated cstring");
        assert!(
            matches!(err, ProxyError::Protocol(_)),
            "expected Protocol error, got {err:?}"
        );
    }

    /// Multiple cstrings back-to-back in the same buffer must parse
    /// independently and leave the tail intact for subsequent fields.
    #[test]
    fn test_read_cstring_sequence() {
        let mut buf = BytesMut::new();
        buf.put_slice(b"first\0second\0tail");

        let first = read_cstring(&mut buf).unwrap();
        let second = read_cstring(&mut buf).unwrap();

        assert_eq!(first, "first");
        assert_eq!(second, "second");
        assert_eq!(&buf[..], b"tail");
    }

    /// BindMessage parameter values are `Bytes`. Parse a synthetic payload
    /// and confirm the values match.
    #[test]
    fn test_bind_message_param_values_are_bytes() {
        let mut payload = BytesMut::new();
        // portal, statement (both empty)
        payload.put_slice(&[0, 0]);
        // one param format: 0 (text)
        payload.put_u16(1);
        payload.put_i16(0);
        // two params: "hi" (2 bytes) and NULL (-1)
        payload.put_u16(2);
        payload.put_i32(2);
        payload.put_slice(b"hi");
        payload.put_i32(-1);
        // zero result formats
        payload.put_u16(0);

        let bind = BindMessage::parse(payload).expect("parse failed");
        assert_eq!(bind.param_values.len(), 2);
        match bind.param_values[0].as_ref() {
            Some(value) => assert_eq!(value.as_ref(), b"hi"),
            None => panic!("first param must be Some"),
        }
        assert!(bind.param_values[1].is_none());
    }
}