Skip to main content

heliosdb_proxy/
protocol.rs

1//! Protocol Handling
2//!
3//! Wire protocol parsing and serialization for HeliosDB proxy.
4
5use crate::{ProxyError, Result};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use std::collections::HashMap;
8
/// Protocol message types.
///
/// Several tag bytes are shared between a client-side and a server-side
/// message (`D`, `E`, `C`, `S`), so a tag byte on its own does not always
/// identify a single variant; see [`MessageType::from_tag`] for how the
/// shared tags are resolved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// Startup message from client (framed without a tag byte)
    Startup,
    /// SSL request (framed without a tag byte)
    SSLRequest,
    /// Authentication request
    AuthRequest,
    /// Password message
    Password,
    /// Query message
    Query,
    /// Parse message (prepared statement)
    Parse,
    /// Bind message
    Bind,
    /// Describe message
    Describe,
    /// Execute message
    Execute,
    /// Sync message
    Sync,
    /// Flush message
    Flush,
    /// Close message
    Close,
    /// Terminate message
    Terminate,
    /// Copy data
    CopyData,
    /// Copy done
    CopyDone,
    /// Copy fail
    CopyFail,
    /// Function call (deprecated)
    FunctionCall,
    /// Backend key data
    BackendKeyData,
    /// Parameter status
    ParameterStatus,
    /// Ready for query
    ReadyForQuery,
    /// Row description
    RowDescription,
    /// Data row
    DataRow,
    /// Command complete
    CommandComplete,
    /// Empty query response
    EmptyQueryResponse,
    /// Error response
    ErrorResponse,
    /// Notice response
    NoticeResponse,
    /// Notification response
    NotificationResponse,
    /// Parse complete
    ParseComplete,
    /// Bind complete
    BindComplete,
    /// Close complete
    CloseComplete,
    /// Portal suspended
    PortalSuspended,
    /// No data
    NoData,
    /// Parameter description
    ParameterDescription,
    /// Message whose tag byte is not recognised; the raw tag is preserved
    Unknown(u8),
}

impl MessageType {
    /// Get message type from tag byte.
    ///
    /// The mapping is direction-agnostic. The server-side `D`/`E`/`C`/`S`
    /// tags (DataRow, ErrorResponse, CommandComplete, ParameterStatus)
    /// collide with the client-side Describe/Execute/Close/Sync tags and are
    /// resolved to the client-side variants; disambiguation, when needed,
    /// lives at the call site. Any tag not listed here is returned as
    /// [`MessageType::Unknown`] carrying the original byte.
    pub fn from_tag(tag: u8) -> Self {
        match tag {
            b'Q' => MessageType::Query,
            b'P' => MessageType::Parse,
            b'B' => MessageType::Bind,
            b'D' => MessageType::Describe,
            b'E' => MessageType::Execute,
            b'S' => MessageType::Sync,
            b'H' => MessageType::Flush,
            b'C' => MessageType::Close,
            b'X' => MessageType::Terminate,
            b'd' => MessageType::CopyData,
            b'c' => MessageType::CopyDone,
            b'f' => MessageType::CopyFail,
            b'F' => MessageType::FunctionCall,
            b'p' => MessageType::Password,
            b'R' => MessageType::AuthRequest,
            b'K' => MessageType::BackendKeyData,
            b'Z' => MessageType::ReadyForQuery,
            b'T' => MessageType::RowDescription,
            b'I' => MessageType::EmptyQueryResponse,
            b'N' => MessageType::NoticeResponse,
            b'A' => MessageType::NotificationResponse,
            b'1' => MessageType::ParseComplete,
            b'2' => MessageType::BindComplete,
            b'3' => MessageType::CloseComplete,
            b's' => MessageType::PortalSuspended,
            b'n' => MessageType::NoData,
            b't' => MessageType::ParameterDescription,
            _ => MessageType::Unknown(tag),
        }
    }

    /// Get tag byte for message type.
    ///
    /// Returns `None` only for [`MessageType::Startup`] and
    /// [`MessageType::SSLRequest`], which have no tag byte.
    /// [`MessageType::Unknown`] yields the tag it was decoded from, so an
    /// unrecognised message keeps its tag when it is re-encoded.
    ///
    /// The match is deliberately exhaustive (no wildcard arm) so that adding
    /// a variant forces a decision about its tag at compile time.
    pub fn to_tag(&self) -> Option<u8> {
        let tag = match self {
            MessageType::Startup | MessageType::SSLRequest => return None,
            MessageType::Unknown(tag) => *tag,
            MessageType::Query => b'Q',
            MessageType::Parse => b'P',
            MessageType::Bind => b'B',
            MessageType::Describe | MessageType::DataRow => b'D',
            MessageType::Execute | MessageType::ErrorResponse => b'E',
            MessageType::Sync | MessageType::ParameterStatus => b'S',
            MessageType::Flush => b'H',
            MessageType::Close | MessageType::CommandComplete => b'C',
            MessageType::Terminate => b'X',
            MessageType::CopyData => b'd',
            MessageType::CopyDone => b'c',
            MessageType::CopyFail => b'f',
            MessageType::FunctionCall => b'F',
            MessageType::Password => b'p',
            MessageType::AuthRequest => b'R',
            MessageType::BackendKeyData => b'K',
            MessageType::ReadyForQuery => b'Z',
            MessageType::RowDescription => b'T',
            MessageType::EmptyQueryResponse => b'I',
            MessageType::NoticeResponse => b'N',
            MessageType::NotificationResponse => b'A',
            MessageType::ParseComplete => b'1',
            MessageType::BindComplete => b'2',
            MessageType::CloseComplete => b'3',
            MessageType::PortalSuspended => b's',
            MessageType::NoData => b'n',
            MessageType::ParameterDescription => b't',
        };
        Some(tag)
    }
}
160
161/// A protocol message
162#[derive(Debug, Clone)]
163pub struct Message {
164    /// Message type
165    pub msg_type: MessageType,
166    /// Message payload
167    pub payload: BytesMut,
168}
169
170impl Message {
171    /// Create a new message
172    pub fn new(msg_type: MessageType, payload: BytesMut) -> Self {
173        Self { msg_type, payload }
174    }
175
176    /// Create an empty message
177    pub fn empty(msg_type: MessageType) -> Self {
178        Self {
179            msg_type,
180            payload: BytesMut::new(),
181        }
182    }
183
184    /// Encode message to bytes
185    pub fn encode(&self) -> BytesMut {
186        let mut buf = BytesMut::new();
187
188        if let Some(tag) = self.msg_type.to_tag() {
189            buf.put_u8(tag);
190        }
191
192        // Length includes itself (4 bytes)
193        let len = self.payload.len() as u32 + 4;
194        buf.put_u32(len);
195        buf.extend_from_slice(&self.payload);
196
197        buf
198    }
199}
200
201/// Protocol codec for framing messages
202pub struct ProtocolCodec {
203    /// Maximum message size
204    max_message_size: usize,
205}
206
207impl Default for ProtocolCodec {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213impl ProtocolCodec {
214    /// Create a new codec
215    pub fn new() -> Self {
216        Self {
217            max_message_size: 100 * 1024 * 1024, // 100MB max
218        }
219    }
220
221    /// Create codec with custom max message size
222    pub fn with_max_size(max_message_size: usize) -> Self {
223        Self { max_message_size }
224    }
225
226    /// Decode a startup message (no tag byte)
227    pub fn decode_startup(&self, src: &mut BytesMut) -> Result<Option<StartupMessage>> {
228        if src.len() < 4 {
229            return Ok(None);
230        }
231
232        let len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
233
234        if len > self.max_message_size {
235            return Err(ProxyError::Protocol(format!(
236                "Message too large: {} bytes",
237                len
238            )));
239        }
240
241        if src.len() < len {
242            return Ok(None);
243        }
244
245        src.advance(4);
246        let protocol_version = src.get_u32();
247
248        // Check for SSL request
249        if protocol_version == 80877103 {
250            return Ok(Some(StartupMessage::SSLRequest));
251        }
252
253        // Check for cancel request
254        if protocol_version == 80877102 {
255            let pid = src.get_u32();
256            let key = src.get_u32();
257            return Ok(Some(StartupMessage::CancelRequest { pid, key }));
258        }
259
260        // Parse parameters
261        let mut params = HashMap::new();
262        let remaining = len - 8; // Already read length and version
263        let mut param_bytes = src.split_to(remaining);
264
265        while param_bytes.has_remaining() {
266            let key = read_cstring(&mut param_bytes)?;
267            if key.is_empty() {
268                break;
269            }
270            let value = read_cstring(&mut param_bytes)?;
271            params.insert(key, value);
272        }
273
274        Ok(Some(StartupMessage::Startup {
275            protocol_version,
276            params,
277        }))
278    }
279
280    /// Decode a regular message (with tag byte)
281    pub fn decode_message(&self, src: &mut BytesMut) -> Result<Option<Message>> {
282        if src.len() < 5 {
283            return Ok(None);
284        }
285
286        let tag = src[0];
287        let len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
288
289        if len > self.max_message_size {
290            return Err(ProxyError::Protocol(format!(
291                "Message too large: {} bytes",
292                len
293            )));
294        }
295
296        // Length includes itself, so total message is 1 (tag) + len
297        let total_len = 1 + len;
298        if src.len() < total_len {
299            return Ok(None);
300        }
301
302        src.advance(5); // Skip tag and length
303        let payload = src.split_to(len - 4); // Length includes the 4-byte length field
304
305        let msg_type = MessageType::from_tag(tag);
306        Ok(Some(Message::new(msg_type, payload)))
307    }
308
309    /// Encode a message
310    pub fn encode_message(&self, msg: &Message) -> BytesMut {
311        msg.encode()
312    }
313}
314
/// Startup message variants
///
/// These are the packet kinds `ProtocolCodec::decode_startup` produces. The
/// variant is chosen from the 32-bit value that follows the packet's length
/// field.
#[derive(Debug, Clone)]
pub enum StartupMessage {
    /// Regular startup
    Startup {
        /// The 32-bit value read directly after the length field
        protocol_version: u32,
        /// Key/value pairs read from the NUL-terminated strings that follow
        /// the version
        params: HashMap<String, String>,
    },
    /// SSL request
    SSLRequest,
    /// Cancel request; `pid` and `key` are the two 32-bit values that follow
    /// the cancel code, in that order
    CancelRequest { pid: u32, key: u32 },
}
328
329/// Read a null-terminated string from the buffer.
330///
331/// Scans for the null terminator in a single pass (no per-byte `get_u8`
332/// loop, no Vec growth), then hands the exact-size byte slice to `String`.
333/// On `BytesMut`, `split_to` is O(1), and `BytesMut -> Vec<u8>` is
334/// zero-copy when (as here) the split-off buffer has a single owner.
335fn read_cstring(buf: &mut BytesMut) -> Result<String> {
336    let end = buf.iter().position(|&b| b == 0).ok_or_else(|| {
337        ProxyError::Protocol("unterminated cstring in protocol buffer".to_string())
338    })?;
339
340    let bytes = buf.split_to(end);
341    buf.advance(1); // consume the null terminator
342
343    String::from_utf8(bytes.into())
344        .map_err(|e| ProxyError::Protocol(format!("Invalid UTF-8 in cstring: {}", e)))
345}
346
347/// Write a null-terminated string to buffer
348fn write_cstring(buf: &mut BytesMut, s: &str) {
349    buf.extend_from_slice(s.as_bytes());
350    buf.put_u8(0);
351}
352
/// Borrow the text that precedes the first NUL byte of a payload.
///
/// Returns `None` when the payload contains no NUL byte or when the bytes
/// before it are not valid UTF-8. Nothing is copied: the returned `&str`
/// points into `payload`.
pub fn query_text(payload: &[u8]) -> Option<&str> {
    payload
        .iter()
        .position(|&byte| byte == 0)
        .and_then(|nul_at| std::str::from_utf8(&payload[..nul_at]).ok())
}
361
/// Tests whether `s` begins with `prefix`, ignoring ASCII case.
///
/// The comparison is done on bytes and allocates nothing; an empty `prefix`
/// always matches.
pub fn starts_with_ci(s: &str, prefix: &str) -> bool {
    let wanted = prefix.as_bytes();
    match s.as_bytes().get(..wanted.len()) {
        Some(head) => head.eq_ignore_ascii_case(wanted),
        // `s` is shorter than the prefix.
        None => false,
    }
}
367
/// Tests whether `needle` occurs anywhere in `haystack`, ignoring ASCII case.
///
/// Works on bytes and allocates nothing. An empty `needle` always matches;
/// a `needle` longer than `haystack` never does.
pub fn contains_ci(haystack: &str, needle: &str) -> bool {
    let pattern = needle.as_bytes();
    match pattern.len() {
        0 => true,
        // `windows` yields nothing when the haystack is shorter than `width`.
        width => haystack
            .as_bytes()
            .windows(width)
            .any(|candidate| candidate.eq_ignore_ascii_case(pattern)),
    }
}
381
382/// Query message payload
383#[derive(Debug, Clone)]
384pub struct QueryMessage {
385    pub query: String,
386}
387
388impl QueryMessage {
389    /// Parse from message payload
390    pub fn parse(mut payload: BytesMut) -> Result<Self> {
391        let query = read_cstring(&mut payload)?;
392        Ok(Self { query })
393    }
394
395    /// Encode to message
396    pub fn encode(&self) -> Message {
397        let mut payload = BytesMut::new();
398        write_cstring(&mut payload, &self.query);
399        Message::new(MessageType::Query, payload)
400    }
401}
402
403/// Parse message payload (prepared statement)
404#[derive(Debug, Clone)]
405pub struct ParseMessage {
406    pub name: String,
407    pub query: String,
408    pub param_types: Vec<u32>,
409}
410
411impl ParseMessage {
412    /// Parse from message payload
413    pub fn parse(mut payload: BytesMut) -> Result<Self> {
414        let name = read_cstring(&mut payload)?;
415        let query = read_cstring(&mut payload)?;
416
417        let num_params = payload.get_u16() as usize;
418        let mut param_types = Vec::with_capacity(num_params);
419
420        for _ in 0..num_params {
421            param_types.push(payload.get_u32());
422        }
423
424        Ok(Self {
425            name,
426            query,
427            param_types,
428        })
429    }
430
431    /// Encode to message
432    pub fn encode(&self) -> Message {
433        let mut payload = BytesMut::new();
434        write_cstring(&mut payload, &self.name);
435        write_cstring(&mut payload, &self.query);
436        payload.put_u16(self.param_types.len() as u16);
437        for &t in &self.param_types {
438            payload.put_u32(t);
439        }
440        Message::new(MessageType::Parse, payload)
441    }
442}
443
444/// Bind message payload
445///
446/// `param_values` uses [`bytes::Bytes`] so parameter values are held by
447/// reference into the original protocol buffer — no per-parameter `Vec`
448/// allocation during parse.
449#[derive(Debug, Clone)]
450pub struct BindMessage {
451    pub portal: String,
452    pub statement: String,
453    pub param_formats: Vec<i16>,
454    pub param_values: Vec<Option<Bytes>>,
455    pub result_formats: Vec<i16>,
456}
457
458impl BindMessage {
459    /// Parse from message payload
460    pub fn parse(mut payload: BytesMut) -> Result<Self> {
461        let portal = read_cstring(&mut payload)?;
462        let statement = read_cstring(&mut payload)?;
463
464        // Parameter formats
465        let num_formats = payload.get_u16() as usize;
466        let mut param_formats = Vec::with_capacity(num_formats);
467        for _ in 0..num_formats {
468            param_formats.push(payload.get_i16());
469        }
470
471        // Parameter values — zero-copy: `split_to` slices the Arc'd buffer
472        // and `freeze()` turns the split-off `BytesMut` into a shared
473        // `Bytes` without allocating.
474        let num_values = payload.get_u16() as usize;
475        let mut param_values = Vec::with_capacity(num_values);
476        for _ in 0..num_values {
477            let len = payload.get_i32();
478            if len == -1 {
479                param_values.push(None);
480            } else {
481                let value = payload.split_to(len as usize).freeze();
482                param_values.push(Some(value));
483            }
484        }
485
486        // Result formats
487        let num_result_formats = payload.get_u16() as usize;
488        let mut result_formats = Vec::with_capacity(num_result_formats);
489        for _ in 0..num_result_formats {
490            result_formats.push(payload.get_i16());
491        }
492
493        Ok(Self {
494            portal,
495            statement,
496            param_formats,
497            param_values,
498            result_formats,
499        })
500    }
501}
502
503/// Execute message payload
504#[derive(Debug, Clone)]
505pub struct ExecuteMessage {
506    pub portal: String,
507    pub max_rows: i32,
508}
509
510impl ExecuteMessage {
511    /// Parse from message payload
512    pub fn parse(mut payload: BytesMut) -> Result<Self> {
513        let portal = read_cstring(&mut payload)?;
514        let max_rows = payload.get_i32();
515        Ok(Self { portal, max_rows })
516    }
517
518    /// Encode to message
519    pub fn encode(&self) -> Message {
520        let mut payload = BytesMut::new();
521        write_cstring(&mut payload, &self.portal);
522        payload.put_i32(self.max_rows);
523        Message::new(MessageType::Execute, payload)
524    }
525}
526
527/// Error response message
528#[derive(Debug, Clone)]
529pub struct ErrorResponse {
530    pub fields: HashMap<char, String>,
531}
532
533impl ErrorResponse {
534    /// Parse from message payload
535    pub fn parse(mut payload: BytesMut) -> Result<Self> {
536        let mut fields = HashMap::new();
537
538        while payload.has_remaining() {
539            let code = payload.get_u8();
540            if code == 0 {
541                break;
542            }
543            let value = read_cstring(&mut payload)?;
544            fields.insert(code as char, value);
545        }
546
547        Ok(Self { fields })
548    }
549
550    /// Get severity
551    pub fn severity(&self) -> Option<&str> {
552        self.fields.get(&'S').map(|s| s.as_str())
553    }
554
555    /// Get error code
556    pub fn code(&self) -> Option<&str> {
557        self.fields.get(&'C').map(|s| s.as_str())
558    }
559
560    /// Get message
561    pub fn message(&self) -> Option<&str> {
562        self.fields.get(&'M').map(|s| s.as_str())
563    }
564
565    /// Encode to message
566    pub fn encode(&self) -> Message {
567        let mut payload = BytesMut::new();
568        for (&code, value) in &self.fields {
569            payload.put_u8(code as u8);
570            write_cstring(&mut payload, value);
571        }
572        payload.put_u8(0);
573        Message::new(MessageType::ErrorResponse, payload)
574    }
575}
576
/// Transaction status indicator, encoded on the wire as a single byte
/// (`I`, `T` or `E`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Idle (not in transaction)
    Idle,
    /// In transaction block
    InTransaction,
    /// In failed transaction block
    Failed,
}

impl TransactionStatus {
    /// Decodes a status byte.
    ///
    /// `T` and `E` select the two transaction states; `I` and every other
    /// byte decode to [`TransactionStatus::Idle`].
    pub fn from_byte(b: u8) -> Self {
        match b {
            b'T' => Self::InTransaction,
            b'E' => Self::Failed,
            _ => Self::Idle,
        }
    }

    /// Encodes the status as its wire byte.
    pub fn to_byte(&self) -> u8 {
        match *self {
            Self::Idle => b'I',
            Self::InTransaction => b'T',
            Self::Failed => b'E',
        }
    }
}
608
609/// Command complete message
610#[derive(Debug, Clone)]
611pub struct CommandComplete {
612    pub tag: String,
613}
614
615impl CommandComplete {
616    /// Parse from message payload
617    pub fn parse(mut payload: BytesMut) -> Result<Self> {
618        let tag = read_cstring(&mut payload)?;
619        Ok(Self { tag })
620    }
621
622    /// Encode to message
623    pub fn encode(&self) -> Message {
624        let mut payload = BytesMut::new();
625        write_cstring(&mut payload, &self.tag);
626        Message::new(MessageType::CommandComplete, payload)
627    }
628
629    /// Get rows affected for INSERT/UPDATE/DELETE
630    pub fn rows_affected(&self) -> Option<u64> {
631        let parts: Vec<&str> = self.tag.split_whitespace().collect();
632        if parts.len() >= 2 {
633            parts.last()?.parse().ok()
634        } else {
635            None
636        }
637    }
638}
639
640/// Authentication request types
641#[derive(Debug, Clone)]
642pub enum AuthRequest {
643    /// Authentication OK
644    Ok,
645    /// Cleartext password
646    CleartextPassword,
647    /// MD5 password
648    Md5Password { salt: [u8; 4] },
649    /// SASL
650    SASL { mechanisms: Vec<String> },
651    /// SASL continue
652    SASLContinue { data: Vec<u8> },
653    /// SASL final
654    SASLFinal { data: Vec<u8> },
655    /// Unknown
656    Unknown(i32),
657}
658
659impl AuthRequest {
660    /// Parse from message payload
661    pub fn parse(mut payload: BytesMut) -> Result<Self> {
662        let auth_type = payload.get_i32();
663
664        Ok(match auth_type {
665            0 => AuthRequest::Ok,
666            3 => AuthRequest::CleartextPassword,
667            5 => {
668                let mut salt = [0u8; 4];
669                payload.copy_to_slice(&mut salt);
670                AuthRequest::Md5Password { salt }
671            }
672            10 => {
673                let mut mechanisms = Vec::new();
674                loop {
675                    let mech = read_cstring(&mut payload)?;
676                    if mech.is_empty() {
677                        break;
678                    }
679                    mechanisms.push(mech);
680                }
681                AuthRequest::SASL { mechanisms }
682            }
683            11 => {
684                let data = payload.to_vec();
685                AuthRequest::SASLContinue { data }
686            }
687            12 => {
688                let data = payload.to_vec();
689                AuthRequest::SASLFinal { data }
690            }
691            _ => AuthRequest::Unknown(auth_type),
692        })
693    }
694
695    /// Encode to message
696    pub fn encode(&self) -> Message {
697        let mut payload = BytesMut::new();
698
699        match self {
700            AuthRequest::Ok => {
701                payload.put_i32(0);
702            }
703            AuthRequest::CleartextPassword => {
704                payload.put_i32(3);
705            }
706            AuthRequest::Md5Password { salt } => {
707                payload.put_i32(5);
708                payload.extend_from_slice(salt);
709            }
710            AuthRequest::SASL { mechanisms } => {
711                payload.put_i32(10);
712                for mech in mechanisms {
713                    write_cstring(&mut payload, mech);
714                }
715                payload.put_u8(0);
716            }
717            AuthRequest::SASLContinue { data } => {
718                payload.put_i32(11);
719                payload.extend_from_slice(data);
720            }
721            AuthRequest::SASLFinal { data } => {
722                payload.put_i32(12);
723                payload.extend_from_slice(data);
724            }
725            AuthRequest::Unknown(t) => {
726                payload.put_i32(*t);
727            }
728        }
729
730        Message::new(MessageType::AuthRequest, payload)
731    }
732}
733
#[cfg(test)]
mod tests {
    use super::*;

    /// Client-side message types must survive a `to_tag` / `from_tag` trip.
    #[test]
    fn test_message_type_round_trip() {
        let client_types = [
            MessageType::Query,
            MessageType::Parse,
            MessageType::Bind,
            MessageType::Execute,
            MessageType::Sync,
        ];

        for expected in client_types.iter() {
            if let Some(tag) = expected.to_tag() {
                assert_eq!(&MessageType::from_tag(tag), expected);
            }
        }
    }

    /// Regression: 'R' (AuthenticationRequest) must decode to AuthRequest,
    /// not Unknown(82) — the backend client matches on this to authenticate.
    #[test]
    fn test_auth_request_tag_mapping() {
        assert_eq!(MessageType::AuthRequest.to_tag(), Some(b'R'));
        assert_eq!(MessageType::from_tag(b'R'), MessageType::AuthRequest);
    }

    /// A query encodes to a `Query` message and parses back to the same text.
    #[test]
    fn test_query_message() {
        let original = QueryMessage {
            query: String::from("SELECT 1"),
        };
        let encoded = original.encode();
        assert_eq!(encoded.msg_type, MessageType::Query);

        let round_tripped = QueryMessage::parse(encoded.payload).unwrap();
        assert_eq!(round_tripped.query, "SELECT 1");
    }

    /// The severity / code / message accessors read the S, C and M fields.
    #[test]
    fn test_error_response() {
        let fields: HashMap<char, String> = [
            ('S', "ERROR"),
            ('C', "42P01"),
            ('M', "relation does not exist"),
        ]
        .iter()
        .map(|&(code, text)| (code, text.to_string()))
        .collect();

        let response = ErrorResponse { fields };
        assert_eq!(response.severity(), Some("ERROR"));
        assert_eq!(response.code(), Some("42P01"));
        assert_eq!(response.message(), Some("relation does not exist"));
    }

    /// The row count is the last word of a multi-word command tag.
    #[test]
    fn test_command_complete() {
        let insert = CommandComplete {
            tag: String::from("INSERT 0 5"),
        };
        assert_eq!(insert.rows_affected(), Some(5));

        let select = CommandComplete {
            tag: String::from("SELECT 100"),
        };
        assert_eq!(select.rows_affected(), Some(100));
    }

    /// Status bytes map to variants and back.
    #[test]
    fn test_transaction_status() {
        let idle = TransactionStatus::from_byte(b'I');
        let in_txn = TransactionStatus::from_byte(b'T');
        let failed = TransactionStatus::from_byte(b'E');
        assert_eq!(idle, TransactionStatus::Idle);
        assert_eq!(in_txn, TransactionStatus::InTransaction);
        assert_eq!(failed, TransactionStatus::Failed);

        assert_eq!(TransactionStatus::Idle.to_byte(), b'I');
        assert_eq!(TransactionStatus::InTransaction.to_byte(), b'T');
        assert_eq!(TransactionStatus::Failed.to_byte(), b'E');
    }

    /// An encoded query frame starts with the `Q` tag and is longer than
    /// the five-byte header.
    #[test]
    fn test_protocol_codec() {
        let codec = ProtocolCodec::new();
        let message = QueryMessage {
            query: String::from("SELECT 1"),
        }
        .encode();
        let frame = codec.encode_message(&message);

        assert!(frame.len() > 5);
        assert_eq!(frame[0], b'Q');
    }

    /// An unterminated cstring must surface a protocol error, not be
    /// silently treated as the full remaining buffer (as the old
    /// incremental-push loop did).
    #[test]
    fn test_read_cstring_unterminated() {
        let mut input = BytesMut::from("not-null-terminated");
        let err = read_cstring(&mut input).expect_err("should reject unterminated cstring");
        assert!(
            matches!(err, ProxyError::Protocol(_)),
            "expected Protocol error, got {err:?}"
        );
    }

    /// Multiple cstrings back-to-back in the same buffer must parse
    /// independently and leave the tail intact for subsequent fields.
    #[test]
    fn test_read_cstring_sequence() {
        let mut input = BytesMut::new();
        input.put_slice(b"first\0second\0tail");

        let first = read_cstring(&mut input).unwrap();
        let second = read_cstring(&mut input).unwrap();

        assert_eq!(first, "first");
        assert_eq!(second, "second");
        assert_eq!(&input[..], b"tail");
    }

    /// BindMessage parameter values are `Bytes`, not `Vec<u8>`. Parse a
    /// synthetic payload and confirm the parsed values match.
    #[test]
    fn test_bind_message_param_values_are_bytes() {
        let mut wire = BytesMut::new();
        // portal, statement (both empty)
        wire.put_u8(0);
        wire.put_u8(0);
        // one param format: 0 (text)
        wire.put_u16(1);
        wire.put_i16(0);
        // two params: "hi" (2 bytes) and NULL (-1)
        wire.put_u16(2);
        wire.put_i32(2);
        wire.put_slice(b"hi");
        wire.put_i32(-1);
        // zero result formats
        wire.put_u16(0);

        let bind = BindMessage::parse(wire).expect("parse failed");
        assert_eq!(bind.param_values.len(), 2);

        let first = bind.param_values[0]
            .as_ref()
            .expect("first param must be Some");
        assert_eq!(first.as_ref(), b"hi");
        assert!(bind.param_values[1].is_none());
    }
}