Skip to main content

heliosdb_proxy/
protocol.rs

//! Protocol Handling
//!
//! Wire protocol parsing and serialization for the HeliosDB proxy.
4
5use crate::{ProxyError, Result};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use std::collections::HashMap;
8
/// Protocol message types.
///
/// Covers both client-to-server and server-to-client messages. Several
/// server-side tags share a byte with a client-side tag, so
/// [`MessageType::from_tag`] alone cannot always recover the exact variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// Startup message from client (untagged on the wire)
    Startup,
    /// SSL request (untagged on the wire)
    SSLRequest,
    /// Authentication request
    AuthRequest,
    /// Password message
    Password,
    /// Query message
    Query,
    /// Parse message (prepared statement)
    Parse,
    /// Bind message
    Bind,
    /// Describe message
    Describe,
    /// Execute message
    Execute,
    /// Sync message
    Sync,
    /// Flush message
    Flush,
    /// Close message
    Close,
    /// Terminate message
    Terminate,
    /// Copy data
    CopyData,
    /// Copy done
    CopyDone,
    /// Copy fail
    CopyFail,
    /// Function call (deprecated)
    FunctionCall,
    /// Backend key data
    BackendKeyData,
    /// Parameter status
    ParameterStatus,
    /// Ready for query
    ReadyForQuery,
    /// Row description
    RowDescription,
    /// Data row
    DataRow,
    /// Command complete
    CommandComplete,
    /// Empty query response
    EmptyQueryResponse,
    /// Error response
    ErrorResponse,
    /// Notice response
    NoticeResponse,
    /// Notification response
    NotificationResponse,
    /// Parse complete
    ParseComplete,
    /// Bind complete
    BindComplete,
    /// Close complete
    CloseComplete,
    /// Portal suspended
    PortalSuspended,
    /// No data
    NoData,
    /// Parameter description
    ParameterDescription,
    /// Message whose tag this module does not model. The raw tag byte is
    /// kept so the message can be re-encoded unchanged.
    Unknown(u8),
}

impl MessageType {
    /// Get message type from tag byte.
    ///
    /// The mapping ignores message direction. Server-side tags that collide
    /// with client-side ones (`D`, `E`, `C`, `S`) resolve to the client-side
    /// variants (`Describe`, `Execute`, `Close`, `Sync`). Callers that know
    /// the direction must disambiguate at the call site. Unrecognized tags
    /// become [`MessageType::Unknown`].
    pub fn from_tag(tag: u8) -> Self {
        match tag {
            b'Q' => MessageType::Query,
            b'P' => MessageType::Parse,
            b'B' => MessageType::Bind,
            b'D' => MessageType::Describe,
            b'E' => MessageType::Execute,
            b'S' => MessageType::Sync,
            b'H' => MessageType::Flush,
            b'C' => MessageType::Close,
            b'X' => MessageType::Terminate,
            b'd' => MessageType::CopyData,
            b'c' => MessageType::CopyDone,
            b'f' => MessageType::CopyFail,
            b'F' => MessageType::FunctionCall,
            b'p' => MessageType::Password,
            b'K' => MessageType::BackendKeyData,
            b'R' => MessageType::AuthRequest,
            b'Z' => MessageType::ReadyForQuery,
            b'T' => MessageType::RowDescription,
            b'I' => MessageType::EmptyQueryResponse,
            b'N' => MessageType::NoticeResponse,
            b'A' => MessageType::NotificationResponse,
            b'1' => MessageType::ParseComplete,
            b'2' => MessageType::BindComplete,
            b'3' => MessageType::CloseComplete,
            b's' => MessageType::PortalSuspended,
            b'n' => MessageType::NoData,
            b't' => MessageType::ParameterDescription,
            _ => MessageType::Unknown(tag),
        }
    }

    /// Get tag byte for message type.
    ///
    /// Returns `None` only for the untagged startup-phase messages
    /// (`Startup`, `SSLRequest`). `Unknown(tag)` yields its original tag.
    /// As a result, `from_tag(b).to_tag() == Some(b)` for every byte `b`,
    /// and a decoded message re-encodes with the same tag.
    pub fn to_tag(&self) -> Option<u8> {
        // Exhaustive on purpose: adding a variant must force a decision here.
        let tag = match self {
            MessageType::Startup | MessageType::SSLRequest => return None,
            MessageType::AuthRequest => b'R',
            MessageType::Query => b'Q',
            MessageType::Parse => b'P',
            MessageType::Bind => b'B',
            MessageType::Describe => b'D',
            MessageType::Execute => b'E',
            MessageType::Sync => b'S',
            MessageType::Flush => b'H',
            MessageType::Close => b'C',
            MessageType::Terminate => b'X',
            MessageType::CopyData => b'd',
            MessageType::CopyDone => b'c',
            MessageType::CopyFail => b'f',
            MessageType::FunctionCall => b'F',
            MessageType::Password => b'p',
            MessageType::BackendKeyData => b'K',
            MessageType::ParameterStatus => b'S',
            MessageType::ReadyForQuery => b'Z',
            MessageType::RowDescription => b'T',
            MessageType::DataRow => b'D',
            MessageType::CommandComplete => b'C',
            MessageType::EmptyQueryResponse => b'I',
            MessageType::ErrorResponse => b'E',
            MessageType::NoticeResponse => b'N',
            MessageType::NotificationResponse => b'A',
            MessageType::ParseComplete => b'1',
            MessageType::BindComplete => b'2',
            MessageType::CloseComplete => b'3',
            MessageType::PortalSuspended => b's',
            MessageType::NoData => b'n',
            MessageType::ParameterDescription => b't',
            MessageType::Unknown(tag) => *tag,
        };
        Some(tag)
    }
}
158
159/// A protocol message
160#[derive(Debug, Clone)]
161pub struct Message {
162    /// Message type
163    pub msg_type: MessageType,
164    /// Message payload
165    pub payload: BytesMut,
166}
167
168impl Message {
169    /// Create a new message
170    pub fn new(msg_type: MessageType, payload: BytesMut) -> Self {
171        Self { msg_type, payload }
172    }
173
174    /// Create an empty message
175    pub fn empty(msg_type: MessageType) -> Self {
176        Self {
177            msg_type,
178            payload: BytesMut::new(),
179        }
180    }
181
182    /// Encode message to bytes
183    pub fn encode(&self) -> BytesMut {
184        let mut buf = BytesMut::new();
185
186        if let Some(tag) = self.msg_type.to_tag() {
187            buf.put_u8(tag);
188        }
189
190        // Length includes itself (4 bytes)
191        let len = self.payload.len() as u32 + 4;
192        buf.put_u32(len);
193        buf.extend_from_slice(&self.payload);
194
195        buf
196    }
197}
198
199/// Protocol codec for framing messages
200pub struct ProtocolCodec {
201    /// Maximum message size
202    max_message_size: usize,
203}
204
205impl Default for ProtocolCodec {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
211impl ProtocolCodec {
212    /// Create a new codec
213    pub fn new() -> Self {
214        Self {
215            max_message_size: 100 * 1024 * 1024, // 100MB max
216        }
217    }
218
219    /// Create codec with custom max message size
220    pub fn with_max_size(max_message_size: usize) -> Self {
221        Self { max_message_size }
222    }
223
224    /// Decode a startup message (no tag byte)
225    pub fn decode_startup(&self, src: &mut BytesMut) -> Result<Option<StartupMessage>> {
226        if src.len() < 4 {
227            return Ok(None);
228        }
229
230        let len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
231
232        if len > self.max_message_size {
233            return Err(ProxyError::Protocol(format!(
234                "Message too large: {} bytes",
235                len
236            )));
237        }
238
239        if src.len() < len {
240            return Ok(None);
241        }
242
243        src.advance(4);
244        let protocol_version = src.get_u32();
245
246        // Check for SSL request
247        if protocol_version == 80877103 {
248            return Ok(Some(StartupMessage::SSLRequest));
249        }
250
251        // Check for cancel request
252        if protocol_version == 80877102 {
253            let pid = src.get_u32();
254            let key = src.get_u32();
255            return Ok(Some(StartupMessage::CancelRequest { pid, key }));
256        }
257
258        // Parse parameters
259        let mut params = HashMap::new();
260        let remaining = len - 8; // Already read length and version
261        let mut param_bytes = src.split_to(remaining);
262
263        while param_bytes.has_remaining() {
264            let key = read_cstring(&mut param_bytes)?;
265            if key.is_empty() {
266                break;
267            }
268            let value = read_cstring(&mut param_bytes)?;
269            params.insert(key, value);
270        }
271
272        Ok(Some(StartupMessage::Startup {
273            protocol_version,
274            params,
275        }))
276    }
277
278    /// Decode a regular message (with tag byte)
279    pub fn decode_message(&self, src: &mut BytesMut) -> Result<Option<Message>> {
280        if src.len() < 5 {
281            return Ok(None);
282        }
283
284        let tag = src[0];
285        let len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
286
287        if len > self.max_message_size {
288            return Err(ProxyError::Protocol(format!(
289                "Message too large: {} bytes",
290                len
291            )));
292        }
293
294        // Length includes itself, so total message is 1 (tag) + len
295        let total_len = 1 + len;
296        if src.len() < total_len {
297            return Ok(None);
298        }
299
300        src.advance(5); // Skip tag and length
301        let payload = src.split_to(len - 4); // Length includes the 4-byte length field
302
303        let msg_type = MessageType::from_tag(tag);
304        Ok(Some(Message::new(msg_type, payload)))
305    }
306
307    /// Encode a message
308    pub fn encode_message(&self, msg: &Message) -> BytesMut {
309        msg.encode()
310    }
311}
312
/// Messages that can open a connection. None of them carries a tag byte on
/// the wire; see `ProtocolCodec::decode_startup`.
#[derive(Debug, Clone)]
pub enum StartupMessage {
    /// Regular startup packet
    Startup {
        /// Protocol version number exactly as sent by the client
        protocol_version: u32,
        /// Startup parameters, key to value
        params: HashMap<String, String>,
    },
    /// Request to negotiate SSL
    SSLRequest,
    /// Cancel request
    CancelRequest {
        /// Process id sent with the cancel request
        pid: u32,
        /// Secret key sent with the cancel request
        key: u32,
    },
}
326
327/// Read a null-terminated string from the buffer.
328///
329/// Scans for the null terminator in a single pass (no per-byte `get_u8`
330/// loop, no Vec growth), then hands the exact-size byte slice to `String`.
331/// On `BytesMut`, `split_to` is O(1), and `BytesMut -> Vec<u8>` is
332/// zero-copy when (as here) the split-off buffer has a single owner.
333fn read_cstring(buf: &mut BytesMut) -> Result<String> {
334    let end = buf
335        .iter()
336        .position(|&b| b == 0)
337        .ok_or_else(|| ProxyError::Protocol(
338            "unterminated cstring in protocol buffer".to_string(),
339        ))?;
340
341    let bytes = buf.split_to(end);
342    buf.advance(1); // consume the null terminator
343
344    String::from_utf8(bytes.into())
345        .map_err(|e| ProxyError::Protocol(format!("Invalid UTF-8 in cstring: {}", e)))
346}
347
348/// Write a null-terminated string to buffer
349fn write_cstring(buf: &mut BytesMut, s: &str) {
350    buf.extend_from_slice(s.as_bytes());
351    buf.put_u8(0);
352}
353
354/// Query message payload
355#[derive(Debug, Clone)]
356pub struct QueryMessage {
357    pub query: String,
358}
359
360impl QueryMessage {
361    /// Parse from message payload
362    pub fn parse(mut payload: BytesMut) -> Result<Self> {
363        let query = read_cstring(&mut payload)?;
364        Ok(Self { query })
365    }
366
367    /// Encode to message
368    pub fn encode(&self) -> Message {
369        let mut payload = BytesMut::new();
370        write_cstring(&mut payload, &self.query);
371        Message::new(MessageType::Query, payload)
372    }
373}
374
375/// Parse message payload (prepared statement)
376#[derive(Debug, Clone)]
377pub struct ParseMessage {
378    pub name: String,
379    pub query: String,
380    pub param_types: Vec<u32>,
381}
382
383impl ParseMessage {
384    /// Parse from message payload
385    pub fn parse(mut payload: BytesMut) -> Result<Self> {
386        let name = read_cstring(&mut payload)?;
387        let query = read_cstring(&mut payload)?;
388
389        let num_params = payload.get_u16() as usize;
390        let mut param_types = Vec::with_capacity(num_params);
391
392        for _ in 0..num_params {
393            param_types.push(payload.get_u32());
394        }
395
396        Ok(Self {
397            name,
398            query,
399            param_types,
400        })
401    }
402
403    /// Encode to message
404    pub fn encode(&self) -> Message {
405        let mut payload = BytesMut::new();
406        write_cstring(&mut payload, &self.name);
407        write_cstring(&mut payload, &self.query);
408        payload.put_u16(self.param_types.len() as u16);
409        for &t in &self.param_types {
410            payload.put_u32(t);
411        }
412        Message::new(MessageType::Parse, payload)
413    }
414}
415
416/// Bind message payload
417///
418/// `param_values` uses [`bytes::Bytes`] so parameter values are held by
419/// reference into the original protocol buffer — no per-parameter `Vec`
420/// allocation during parse.
421#[derive(Debug, Clone)]
422pub struct BindMessage {
423    pub portal: String,
424    pub statement: String,
425    pub param_formats: Vec<i16>,
426    pub param_values: Vec<Option<Bytes>>,
427    pub result_formats: Vec<i16>,
428}
429
430impl BindMessage {
431    /// Parse from message payload
432    pub fn parse(mut payload: BytesMut) -> Result<Self> {
433        let portal = read_cstring(&mut payload)?;
434        let statement = read_cstring(&mut payload)?;
435
436        // Parameter formats
437        let num_formats = payload.get_u16() as usize;
438        let mut param_formats = Vec::with_capacity(num_formats);
439        for _ in 0..num_formats {
440            param_formats.push(payload.get_i16());
441        }
442
443        // Parameter values — zero-copy: `split_to` slices the Arc'd buffer
444        // and `freeze()` turns the split-off `BytesMut` into a shared
445        // `Bytes` without allocating.
446        let num_values = payload.get_u16() as usize;
447        let mut param_values = Vec::with_capacity(num_values);
448        for _ in 0..num_values {
449            let len = payload.get_i32();
450            if len == -1 {
451                param_values.push(None);
452            } else {
453                let value = payload.split_to(len as usize).freeze();
454                param_values.push(Some(value));
455            }
456        }
457
458        // Result formats
459        let num_result_formats = payload.get_u16() as usize;
460        let mut result_formats = Vec::with_capacity(num_result_formats);
461        for _ in 0..num_result_formats {
462            result_formats.push(payload.get_i16());
463        }
464
465        Ok(Self {
466            portal,
467            statement,
468            param_formats,
469            param_values,
470            result_formats,
471        })
472    }
473}
474
475/// Execute message payload
476#[derive(Debug, Clone)]
477pub struct ExecuteMessage {
478    pub portal: String,
479    pub max_rows: i32,
480}
481
482impl ExecuteMessage {
483    /// Parse from message payload
484    pub fn parse(mut payload: BytesMut) -> Result<Self> {
485        let portal = read_cstring(&mut payload)?;
486        let max_rows = payload.get_i32();
487        Ok(Self { portal, max_rows })
488    }
489
490    /// Encode to message
491    pub fn encode(&self) -> Message {
492        let mut payload = BytesMut::new();
493        write_cstring(&mut payload, &self.portal);
494        payload.put_i32(self.max_rows);
495        Message::new(MessageType::Execute, payload)
496    }
497}
498
499/// Error response message
500#[derive(Debug, Clone)]
501pub struct ErrorResponse {
502    pub fields: HashMap<char, String>,
503}
504
505impl ErrorResponse {
506    /// Parse from message payload
507    pub fn parse(mut payload: BytesMut) -> Result<Self> {
508        let mut fields = HashMap::new();
509
510        while payload.has_remaining() {
511            let code = payload.get_u8();
512            if code == 0 {
513                break;
514            }
515            let value = read_cstring(&mut payload)?;
516            fields.insert(code as char, value);
517        }
518
519        Ok(Self { fields })
520    }
521
522    /// Get severity
523    pub fn severity(&self) -> Option<&str> {
524        self.fields.get(&'S').map(|s| s.as_str())
525    }
526
527    /// Get error code
528    pub fn code(&self) -> Option<&str> {
529        self.fields.get(&'C').map(|s| s.as_str())
530    }
531
532    /// Get message
533    pub fn message(&self) -> Option<&str> {
534        self.fields.get(&'M').map(|s| s.as_str())
535    }
536
537    /// Encode to message
538    pub fn encode(&self) -> Message {
539        let mut payload = BytesMut::new();
540        for (&code, value) in &self.fields {
541            payload.put_u8(code as u8);
542            write_cstring(&mut payload, value);
543        }
544        payload.put_u8(0);
545        Message::new(MessageType::ErrorResponse, payload)
546    }
547}
548
/// Transaction status indicator carried by a ReadyForQuery message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Not in a transaction block (status byte `I`)
    Idle,
    /// Inside a transaction block (status byte `T`)
    InTransaction,
    /// Inside a failed transaction block (status byte `E`)
    Failed,
}

impl TransactionStatus {
    /// Decode the status byte.
    ///
    /// `T` and `E` map to their states. Every other byte, including
    /// unrecognized ones, falls back to [`TransactionStatus::Idle`].
    pub fn from_byte(b: u8) -> Self {
        if b == b'T' {
            Self::InTransaction
        } else if b == b'E' {
            Self::Failed
        } else {
            Self::Idle
        }
    }

    /// Encode as the status byte.
    pub fn to_byte(&self) -> u8 {
        match *self {
            Self::Idle => b'I',
            Self::InTransaction => b'T',
            Self::Failed => b'E',
        }
    }
}
580
581/// Command complete message
582#[derive(Debug, Clone)]
583pub struct CommandComplete {
584    pub tag: String,
585}
586
587impl CommandComplete {
588    /// Parse from message payload
589    pub fn parse(mut payload: BytesMut) -> Result<Self> {
590        let tag = read_cstring(&mut payload)?;
591        Ok(Self { tag })
592    }
593
594    /// Encode to message
595    pub fn encode(&self) -> Message {
596        let mut payload = BytesMut::new();
597        write_cstring(&mut payload, &self.tag);
598        Message::new(MessageType::CommandComplete, payload)
599    }
600
601    /// Get rows affected for INSERT/UPDATE/DELETE
602    pub fn rows_affected(&self) -> Option<u64> {
603        let parts: Vec<&str> = self.tag.split_whitespace().collect();
604        if parts.len() >= 2 {
605            parts.last()?.parse().ok()
606        } else {
607            None
608        }
609    }
610}
611
612/// Authentication request types
613#[derive(Debug, Clone)]
614pub enum AuthRequest {
615    /// Authentication OK
616    Ok,
617    /// Cleartext password
618    CleartextPassword,
619    /// MD5 password
620    Md5Password { salt: [u8; 4] },
621    /// SASL
622    SASL { mechanisms: Vec<String> },
623    /// SASL continue
624    SASLContinue { data: Vec<u8> },
625    /// SASL final
626    SASLFinal { data: Vec<u8> },
627    /// Unknown
628    Unknown(i32),
629}
630
631impl AuthRequest {
632    /// Parse from message payload
633    pub fn parse(mut payload: BytesMut) -> Result<Self> {
634        let auth_type = payload.get_i32();
635
636        Ok(match auth_type {
637            0 => AuthRequest::Ok,
638            3 => AuthRequest::CleartextPassword,
639            5 => {
640                let mut salt = [0u8; 4];
641                payload.copy_to_slice(&mut salt);
642                AuthRequest::Md5Password { salt }
643            }
644            10 => {
645                let mut mechanisms = Vec::new();
646                loop {
647                    let mech = read_cstring(&mut payload)?;
648                    if mech.is_empty() {
649                        break;
650                    }
651                    mechanisms.push(mech);
652                }
653                AuthRequest::SASL { mechanisms }
654            }
655            11 => {
656                let data = payload.to_vec();
657                AuthRequest::SASLContinue { data }
658            }
659            12 => {
660                let data = payload.to_vec();
661                AuthRequest::SASLFinal { data }
662            }
663            _ => AuthRequest::Unknown(auth_type),
664        })
665    }
666
667    /// Encode to message
668    pub fn encode(&self) -> Message {
669        let mut payload = BytesMut::new();
670
671        match self {
672            AuthRequest::Ok => {
673                payload.put_i32(0);
674            }
675            AuthRequest::CleartextPassword => {
676                payload.put_i32(3);
677            }
678            AuthRequest::Md5Password { salt } => {
679                payload.put_i32(5);
680                payload.extend_from_slice(salt);
681            }
682            AuthRequest::SASL { mechanisms } => {
683                payload.put_i32(10);
684                for mech in mechanisms {
685                    write_cstring(&mut payload, mech);
686                }
687                payload.put_u8(0);
688            }
689            AuthRequest::SASLContinue { data } => {
690                payload.put_i32(11);
691                payload.extend_from_slice(data);
692            }
693            AuthRequest::SASLFinal { data } => {
694                payload.put_i32(12);
695                payload.extend_from_slice(data);
696            }
697            AuthRequest::Unknown(t) => {
698                payload.put_i32(*t);
699            }
700        }
701
702        Message::new(MessageType::AuthRequest, payload)
703    }
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type_round_trip() {
        // These tags do not collide with server-side tags, so decoding the
        // encoded tag must give back the same variant.
        let cases = [
            MessageType::Query,
            MessageType::Parse,
            MessageType::Bind,
            MessageType::Execute,
            MessageType::Sync,
        ];
        for expected in cases.iter() {
            if let Some(tag) = expected.to_tag() {
                assert_eq!(&MessageType::from_tag(tag), expected);
            }
        }
    }

    #[test]
    fn test_query_message() {
        let encoded = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        assert_eq!(encoded.msg_type, MessageType::Query);

        let round_tripped = QueryMessage::parse(encoded.payload).unwrap();
        assert_eq!(round_tripped.query, "SELECT 1");
    }

    #[test]
    fn test_error_response() {
        let fields: HashMap<char, String> = [
            ('S', "ERROR"),
            ('C', "42P01"),
            ('M', "relation does not exist"),
        ]
        .iter()
        .map(|&(code, text)| (code, text.to_string()))
        .collect();

        let response = ErrorResponse { fields };
        assert_eq!(response.severity(), Some("ERROR"));
        assert_eq!(response.code(), Some("42P01"));
        assert_eq!(response.message(), Some("relation does not exist"));
    }

    #[test]
    fn test_command_complete() {
        let insert = CommandComplete {
            tag: "INSERT 0 5".to_string(),
        };
        assert_eq!(insert.rows_affected(), Some(5));

        let select = CommandComplete {
            tag: "SELECT 100".to_string(),
        };
        assert_eq!(select.rows_affected(), Some(100));
    }

    #[test]
    fn test_transaction_status() {
        let table = [
            (b'I', TransactionStatus::Idle),
            (b'T', TransactionStatus::InTransaction),
            (b'E', TransactionStatus::Failed),
        ];
        for &(byte, status) in table.iter() {
            assert_eq!(TransactionStatus::from_byte(byte), status);
            assert_eq!(status.to_byte(), byte);
        }
    }

    #[test]
    fn test_protocol_codec() {
        let codec = ProtocolCodec::new();
        let msg = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        let wire = codec.encode_message(&msg);

        assert!(wire.len() > 5);
        assert_eq!(wire[0], b'Q');
    }

    /// A buffer with no NUL terminator must produce a protocol error, not
    /// be treated as one string spanning the whole buffer.
    #[test]
    fn test_read_cstring_unterminated() {
        let mut buf = BytesMut::from("not-null-terminated");
        let err = read_cstring(&mut buf).expect_err("should reject unterminated cstring");
        assert!(
            matches!(err, ProxyError::Protocol(_)),
            "expected Protocol error, got {err:?}"
        );
    }

    /// Back-to-back cstrings parse independently, and whatever follows the
    /// last terminator stays in the buffer for the next field.
    #[test]
    fn test_read_cstring_sequence() {
        let mut buf = BytesMut::from(&b"first\0second\0tail"[..]);
        let first = read_cstring(&mut buf).unwrap();
        let second = read_cstring(&mut buf).unwrap();
        assert_eq!(first, "first");
        assert_eq!(second, "second");
        assert_eq!(&buf[..], b"tail");
    }

    /// Parse a hand-built Bind payload and check that the `Bytes`
    /// parameter values (and the NULL marker) come back intact.
    #[test]
    fn test_bind_message_param_values_are_bytes() {
        let mut payload = BytesMut::new();
        payload.put_slice(&[0, 0]); // empty portal name, empty statement name
        payload.put_u16(1); // one parameter format code...
        payload.put_i16(0); // ...text
        payload.put_u16(2); // two parameter values:
        payload.put_i32(2); //   "hi" (2 bytes)
        payload.put_slice(b"hi");
        payload.put_i32(-1); //   NULL
        payload.put_u16(0); // no result format codes

        let bind = BindMessage::parse(payload).expect("parse failed");
        assert_eq!(bind.param_values.len(), 2);
        let first = bind.param_values[0]
            .as_ref()
            .expect("first param must be Some");
        assert_eq!(&first[..], b"hi");
        assert!(bind.param_values[1].is_none());
    }
}