Skip to main content

vibesql_server/protocol/
messages.rs

1use bytes::{Buf, BufMut, BytesMut};
2use std::collections::HashMap;
3use std::io;
4use thiserror::Error;
5
/// Wire protocol configuration for selective column updates.
///
/// Sent by clients to override server-level selective update thresholds
/// on a per-subscription basis. Every field is optional: `None` means the
/// client did not supply that particular override.
///
/// On the wire (0xF0 Subscribe, after the optional filter) this is one flags
/// byte — bit 0 `enabled` (u8, non-zero = true), bit 1 `min_changed_columns`
/// (u16), bit 2 `max_changed_columns_ratio` (f64) — followed by the values
/// whose bits are set, in that order. A flags byte of 0 means "no config".
#[derive(Debug, Clone, PartialEq)]
pub struct SelectiveUpdatesConfig {
    /// Enable/disable selective updates for this subscription.
    pub enabled: Option<bool>,
    /// Minimum columns that must change to use selective update.
    /// If fewer columns change, send full row instead.
    pub min_changed_columns: Option<usize>,
    /// Maximum ratio of changed columns before falling back to full row.
    /// E.g., 0.5 means if >50% of columns changed, send full row instead.
    pub max_changed_columns_ratio: Option<f64>,
}
21
/// PostgreSQL protocol errors raised while framing and decoding messages.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// An underlying I/O operation failed (converted via `From<io::Error>`).
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),

    /// The message type byte is not one the decoder recognizes.
    #[error("Invalid message type: {0}")]
    InvalidMessageType(u8),

    /// The message body was too short to decode.
    #[error("Message too short")]
    MessageTooShort,

    /// The declared length is impossible (e.g. smaller than the length field
    /// itself); carries the raw value received from the peer.
    #[error("Invalid message length: {0}")]
    InvalidMessageLength(i32),

    /// A string was missing its NUL terminator or was not valid UTF-8.
    #[error("Invalid string encoding")]
    InvalidString,

    /// Presumably: a well-formed message received at the wrong point in the
    /// conversation.
    // NOTE(review): not constructed in this module, hence `allow(dead_code)`;
    // confirm whether other modules use it before removing.
    #[error("Unexpected message: {0}")]
    #[allow(dead_code)]
    UnexpectedMessage(String),
}
44
/// Subscription update type for SubscriptionData message.
///
/// `#[repr(u8)]` pins the discriminants, and the encoder writes them as a
/// single byte (`as u8`), so these numeric values are part of the wire format
/// and must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SubscriptionUpdateType {
    /// Wire value 0; by its name, presumably a complete snapshot of the rows.
    Full = 0,
    /// Wire value 1; presumably rows that were inserted.
    DeltaInsert = 1,
    /// Wire value 2; presumably rows that were updated.
    DeltaUpdate = 2,
    /// Wire value 3; presumably rows that were deleted.
    DeltaDelete = 3,
    /// Selective column update - only changed columns are sent (wire value 4).
    /// Used with SubscriptionPartialData message, whose encoder always writes
    /// this value.
    SelectiveUpdate = 4,
}
57
/// A partial row update containing only changed columns
///
/// Used for selective column updates to reduce bandwidth when only
/// a few columns change in a wide table.
///
/// Invariant maintained by [`PartialRowUpdate::new`]: `values` holds exactly
/// one entry per set bit in `column_mask`, ordered by ascending column index
/// (the order the encoder writes them and a receiver must assume).
#[derive(Debug, Clone, PartialEq)]
pub struct PartialRowUpdate {
    /// Total number of columns in the full row (for bitmap sizing)
    pub total_columns: u16,
    /// Bitmap indicating which columns are present (1 bit per column)
    /// Bit 0 = column 0, Bit 1 = column 1, etc.
    /// A set bit means the column value is included in `values`
    pub column_mask: Vec<u8>,
    /// Values for columns with set bits in column_mask, in column order
    /// None = NULL, Some(bytes) = value data
    pub values: Vec<Option<Vec<u8>>>,
}

impl PartialRowUpdate {
    /// Create a new partial row update.
    ///
    /// # Arguments
    /// * `total_columns` - Total number of columns in the full row
    /// * `present_columns` - Indices of columns that are present in this update
    /// * `values` - Values for the present columns, in same order as present_columns
    ///
    /// The `(index, value)` pairs are reordered by column index, so `values`
    /// always follows bitmap order even when `present_columns` is unsorted.
    /// Indices `>= total_columns` are dropped together with their value. If an
    /// index repeats, only its first value is kept. Together these guarantee the
    /// number of values equals the number of set bits. With mismatched input
    /// lengths (a caller bug, caught by the debug assertion), surplus entries
    /// on the longer side are ignored.
    pub fn new(total_columns: u16, present_columns: &[u16], values: Vec<Option<Vec<u8>>>) -> Self {
        debug_assert_eq!(present_columns.len(), values.len());

        let mut pairs: Vec<(u16, Option<Vec<u8>>)> = present_columns
            .iter()
            .copied()
            .zip(values)
            .filter(|pair| pair.0 < total_columns)
            .collect();
        // Stable sort keeps duplicates in input order, so `dedup_by_key`
        // (which retains the first of each run) keeps the first value given.
        pairs.sort_by_key(|pair| pair.0);
        pairs.dedup_by_key(|pair| pair.0);

        let mut column_mask = vec![0u8; usize::from(total_columns).div_ceil(8)];
        let mut ordered_values = Vec::with_capacity(pairs.len());
        for (col_idx, value) in pairs {
            let col_idx = usize::from(col_idx);
            column_mask[col_idx / 8] |= 1 << (col_idx % 8);
            ordered_values.push(value);
        }

        Self { total_columns, column_mask, values: ordered_values }
    }

    /// Check if a column is present in this update.
    ///
    /// Returns `false` for indices outside `total_columns` or beyond the mask.
    pub fn is_column_present(&self, col_idx: u16) -> bool {
        if col_idx >= self.total_columns {
            return false;
        }
        let col_idx = usize::from(col_idx);
        self.column_mask
            .get(col_idx / 8)
            .is_some_and(|byte| byte & (1 << (col_idx % 8)) != 0)
    }

    /// Get the number of present columns (set bits in `column_mask`).
    pub fn present_column_count(&self) -> usize {
        self.column_mask.iter().map(|b| b.count_ones() as usize).sum()
    }
}
119
/// Backend message types (server -> client)
///
/// Serialized by `BackendMessage::encode`. The tag noted on each variant is the
/// first byte that method writes. It is followed by a big-endian `i32` length
/// that counts itself and the payload but not the tag.
#[derive(Debug, Clone, PartialEq)]
pub enum BackendMessage {
    /// Authentication succeeded (tag 'R', auth code 0).
    AuthenticationOk,
    /// Request a cleartext password (tag 'R', auth code 3).
    #[allow(dead_code)]
    AuthenticationCleartextPassword,
    /// Request an MD5 password using the 4-byte `salt` (tag 'R', auth code 5).
    #[allow(dead_code)]
    AuthenticationMD5Password { salt: [u8; 4] },

    /// Parameter status (tag 'S'): name and value, both NUL-terminated.
    ParameterStatus { name: String, value: String },

    /// Backend key data for cancellation (tag 'K').
    BackendKeyData { process_id: i32, secret_key: i32 },

    /// Ready for query (tag 'Z'), carrying one transaction-status byte.
    ReadyForQuery { status: TransactionStatus },

    /// Row description - result set schema (tag 'T').
    RowDescription { fields: Vec<FieldDescription> },

    /// Data row (tag 'D'); `None` entries are written as length -1 (NULL).
    DataRow { values: Vec<Option<Vec<u8>>> },

    /// Command complete (tag 'C') with a NUL-terminated command tag.
    CommandComplete { tag: String },

    /// Error response (tag 'E'); keys are one-byte field codes, values are
    /// written as NUL-terminated strings.
    ErrorResponse { fields: HashMap<u8, String> },

    /// Notice response (tag 'N'); same layout as `ErrorResponse`.
    #[allow(dead_code)]
    NoticeResponse { fields: HashMap<u8, String> },

    /// Empty query response (tag 'I'), no payload.
    EmptyQueryResponse,

    /// Subscription data (0xF2) - query result update.
    ///
    /// After the ID, update-type byte and `i32` row count, each row is an `i16`
    /// column count followed by its values (`None` written as length -1).
    SubscriptionData {
        subscription_id: [u8; 16],
        update_type: SubscriptionUpdateType,
        rows: Vec<Vec<Option<Vec<u8>>>>,
    },

    /// Subscription error (0xF3) - subscription error notification with a
    /// NUL-terminated message.
    SubscriptionError { subscription_id: [u8; 16], message: String },

    /// Subscription acknowledgment (0xF4) - confirms subscription registration
    /// Sent immediately after a subscription is registered, before initial data
    SubscriptionAck {
        subscription_id: [u8; 16],
        /// Number of table dependencies the subscription monitors
        table_count: u16,
    },

    /// Subscription partial data (0xF7) - selective column update
    ///
    /// Used for sending only changed columns in row updates, reducing bandwidth
    /// for wide tables where only a few columns change frequently.
    ///
    /// Wire format:
    /// - 1 byte: Message type (0xF7)
    /// - 4 bytes: Length (big-endian)
    /// - 16 bytes: Subscription ID
    /// - 1 byte: Update type (always SelectiveUpdate = 4)
    /// - 4 bytes: Row count (big-endian)
    /// - For each row:
    ///   - 2 bytes: Total column count (big-endian)
    ///   - N bytes: Column presence bitmap (ceil(total_columns / 8) bytes)
    ///   - For each present column (bit=1):
    ///     - 4 bytes: Value length (-1 for NULL)
    ///     - M bytes: Value data (if length >= 0)
    SubscriptionPartialData {
        subscription_id: [u8; 16],
        /// Partial row updates with column bitmaps
        rows: Vec<PartialRowUpdate>,
    },
}
199
/// Transaction state reported to the client in `ReadyForQuery`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Idle: not inside a transaction block.
    Idle,
    /// Inside a transaction block.
    #[allow(dead_code)]
    InTransaction,
    /// Inside a transaction block that has failed.
    #[allow(dead_code)]
    FailedTransaction,
}

impl TransactionStatus {
    /// Single-byte status indicator written on the wire:
    /// `b'I'` idle, `b'T'` in a transaction, `b'E'` failed transaction.
    pub fn as_byte(&self) -> u8 {
        match *self {
            Self::FailedTransaction => b'E',
            Self::InTransaction => b'T',
            Self::Idle => b'I',
        }
    }
}
222
/// Field description for row data: one column entry of a RowDescription
/// ('T') message.
///
/// `BackendMessage::encode` writes these fields in declaration order, with
/// `name` as a NUL-terminated string, i.e. 18 fixed bytes per field after it.
#[derive(Debug, Clone, PartialEq)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the table the column comes from (written as `i32`).
    pub table_oid: i32,
    /// Attribute number of the column within that table (written as `i16`).
    pub column_attr_number: i16,
    /// OID of the column's data type (written as `i32`).
    pub data_type_oid: i32,
    /// Size of the data type (written as `i16`).
    pub data_type_size: i16,
    /// Type modifier (written as `i32`).
    pub type_modifier: i32,
    pub format_code: i16, // 0 = text, 1 = binary
}
234
/// Frontend message types (client -> server)
///
/// Tagged messages are produced by `FrontendMessage::decode`. `Startup` and
/// `SSLRequest` come from `FrontendMessage::decode_startup`, which handles the
/// untagged startup-phase frame.
#[derive(Debug, Clone, PartialEq)]
pub enum FrontendMessage {
    /// Startup message: protocol version plus key/value parameters, each
    /// sent as a NUL-terminated string.
    Startup { protocol_version: i32, params: HashMap<String, String> },

    /// Password message ('p'): NUL-terminated password string.
    Password { password: String },

    /// Query message ('Q'): NUL-terminated query string.
    Query { query: String },

    /// Terminate message ('X'), no payload.
    Terminate,

    /// SSL request: a startup-phase frame whose version field is 80877103.
    SSLRequest,

    /// Subscribe message (0xF0) - subscribe to query
    ///
    /// `params` are the bound parameter values; `None` is SQL NULL (sent as a
    /// negative length).
    /// The optional filter is a SQL WHERE clause expression applied to subscription updates.
    /// The optional selective_updates_config allows clients to override server-level selective update settings.
    /// Both optional fields are protocol extensions that older clients omit.
    Subscribe {
        query: String,
        params: Vec<Option<Vec<u8>>>,
        filter: Option<String>,
        selective_updates_config: Option<SelectiveUpdatesConfig>,
    },

    /// Unsubscribe message (0xF1) - cancel subscription
    Unsubscribe { subscription_id: [u8; 16] },

    /// Pause subscription message (0xF5) - temporarily pause updates
    SubscriptionPause { subscription_id: [u8; 16] },

    /// Resume subscription message (0xF6) - resume paused subscription
    SubscriptionResume { subscription_id: [u8; 16] },
}
272
273impl BackendMessage {
274    /// Encode a backend message to bytes
275    pub fn encode(&self, buf: &mut BytesMut) {
276        match self {
277            BackendMessage::AuthenticationOk => {
278                buf.put_u8(b'R'); // Authentication
279                buf.put_i32(8); // Length including self
280                buf.put_i32(0); // AuthenticationOk
281            }
282
283            BackendMessage::AuthenticationCleartextPassword => {
284                buf.put_u8(b'R');
285                buf.put_i32(8);
286                buf.put_i32(3); // AuthenticationCleartextPassword
287            }
288
289            BackendMessage::AuthenticationMD5Password { salt } => {
290                buf.put_u8(b'R');
291                buf.put_i32(12);
292                buf.put_i32(5); // AuthenticationMD5Password
293                buf.put_slice(salt);
294            }
295
296            BackendMessage::ParameterStatus { name, value } => {
297                buf.put_u8(b'S'); // ParameterStatus
298                let len = 4 + name.len() + 1 + value.len() + 1;
299                buf.put_i32(len as i32);
300                put_cstring(buf, name);
301                put_cstring(buf, value);
302            }
303
304            BackendMessage::BackendKeyData { process_id, secret_key } => {
305                buf.put_u8(b'K'); // BackendKeyData
306                buf.put_i32(12);
307                buf.put_i32(*process_id);
308                buf.put_i32(*secret_key);
309            }
310
311            BackendMessage::ReadyForQuery { status } => {
312                buf.put_u8(b'Z'); // ReadyForQuery
313                buf.put_i32(5);
314                buf.put_u8(status.as_byte());
315            }
316
317            BackendMessage::RowDescription { fields } => {
318                buf.put_u8(b'T'); // RowDescription
319
320                // Calculate total length
321                let mut len = 4 + 2; // length + field count
322                for field in fields {
323                    len += field.name.len() + 1 + 18; // name + null + 6 i32/i16 fields
324                }
325
326                buf.put_i32(len as i32);
327                buf.put_i16(fields.len() as i16);
328
329                for field in fields {
330                    put_cstring(buf, &field.name);
331                    buf.put_i32(field.table_oid);
332                    buf.put_i16(field.column_attr_number);
333                    buf.put_i32(field.data_type_oid);
334                    buf.put_i16(field.data_type_size);
335                    buf.put_i32(field.type_modifier);
336                    buf.put_i16(field.format_code);
337                }
338            }
339
340            BackendMessage::DataRow { values } => {
341                buf.put_u8(b'D'); // DataRow
342
343                // Calculate total length
344                let mut len = 4 + 2; // length + field count
345                for value in values {
346                    len += 4; // length field
347                    if let Some(v) = value {
348                        len += v.len();
349                    }
350                }
351
352                buf.put_i32(len as i32);
353                buf.put_i16(values.len() as i16);
354
355                for value in values {
356                    match value {
357                        Some(v) => {
358                            buf.put_i32(v.len() as i32);
359                            buf.put_slice(v);
360                        }
361                        None => {
362                            buf.put_i32(-1); // NULL value
363                        }
364                    }
365                }
366            }
367
368            BackendMessage::CommandComplete { tag } => {
369                buf.put_u8(b'C'); // CommandComplete
370                let len = 4 + tag.len() + 1;
371                buf.put_i32(len as i32);
372                put_cstring(buf, tag);
373            }
374
375            BackendMessage::ErrorResponse { fields } => {
376                buf.put_u8(b'E'); // ErrorResponse
377                encode_notice_or_error(buf, fields);
378            }
379
380            BackendMessage::NoticeResponse { fields } => {
381                buf.put_u8(b'N'); // NoticeResponse
382                encode_notice_or_error(buf, fields);
383            }
384
385            BackendMessage::EmptyQueryResponse => {
386                buf.put_u8(b'I'); // EmptyQueryResponse
387                buf.put_i32(4);
388            }
389
390            BackendMessage::SubscriptionData { subscription_id, update_type, rows } => {
391                buf.put_u8(0xF2); // SubscriptionData
392
393                // Calculate total length
394                let mut len = 4 + 16 + 1 + 4; // length + subscription_id + update_type + row count
395                for row in rows {
396                    len += 2; // column count
397                    for value in row {
398                        len += 4; // value length
399                        if let Some(v) = value {
400                            len += v.len();
401                        }
402                    }
403                }
404
405                buf.put_i32(len as i32);
406                buf.put_slice(subscription_id);
407                buf.put_u8(*update_type as u8);
408                buf.put_i32(rows.len() as i32);
409
410                for row in rows {
411                    buf.put_i16(row.len() as i16);
412                    for value in row {
413                        match value {
414                            Some(v) => {
415                                buf.put_i32(v.len() as i32);
416                                buf.put_slice(v);
417                            }
418                            None => {
419                                buf.put_i32(-1); // NULL value
420                            }
421                        }
422                    }
423                }
424            }
425
426            BackendMessage::SubscriptionError { subscription_id, message } => {
427                buf.put_u8(0xF3); // SubscriptionError
428
429                let msg_bytes = message.as_bytes();
430                let len = 4 + 16 + msg_bytes.len() + 1; // length + subscription_id + message + null terminator
431
432                buf.put_i32(len as i32);
433                buf.put_slice(subscription_id);
434                put_cstring(buf, message);
435            }
436
437BackendMessage::SubscriptionAck { subscription_id, table_count } => {
438                buf.put_u8(0xF4); // SubscriptionAck
439
440                let len: i32 = 4 + 16 + 2; // length + subscription_id + table_count
441
442                buf.put_i32(len);
443                buf.put_slice(subscription_id);
444                buf.put_u16(*table_count);
445            }
446
447            BackendMessage::SubscriptionPartialData { subscription_id, rows } => {
448                buf.put_u8(0xF7); // SubscriptionPartialData
449
450                // Calculate total length
451                // 4 (length field) + 16 (subscription_id) + 1 (update_type) + 4 (row count)
452                let mut len = 4 + 16 + 1 + 4;
453                for row in rows {
454                    // 2 (total_columns) + bitmap bytes + values
455                    len += 2;
456                    len += row.column_mask.len();
457                    for value in &row.values {
458                        len += 4; // value length field
459                        if let Some(v) = value {
460                            len += v.len();
461                        }
462                    }
463                }
464
465                buf.put_i32(len as i32);
466                buf.put_slice(subscription_id);
467                buf.put_u8(SubscriptionUpdateType::SelectiveUpdate as u8);
468                buf.put_i32(rows.len() as i32);
469
470                for row in rows {
471                    buf.put_i16(row.total_columns as i16);
472                    buf.put_slice(&row.column_mask);
473                    for value in &row.values {
474                        match value {
475                            Some(v) => {
476                                buf.put_i32(v.len() as i32);
477                                buf.put_slice(v);
478                            }
479                            None => {
480                                buf.put_i32(-1); // NULL value
481                            }
482                        }
483                    }
484                }
485            }
486        }
487    }
488}
489
490impl FrontendMessage {
491    /// Decode a frontend message from bytes
492    pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, ProtocolError> {
493        // Check if we have enough bytes for the header (1 byte type + 4 bytes length)
494        if buf.len() < 5 {
495            return Ok(None);
496        }
497
498        // Peek at message type
499        let msg_type = buf[0];
500
501        // Get message length (excluding type byte, including length field itself)
502        let len_i32 = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]);
503
504        // Validate length - must be at least 4 (includes the length field itself)
505        // and must be positive to avoid overflow when casting to usize
506        if len_i32 < 4 {
507            return Err(ProtocolError::InvalidMessageLength(len_i32));
508        }
509
510        let len = len_i32 as usize;
511
512        // Check if we have the full message (use saturating_add to avoid overflow)
513        let total_len = 1usize.saturating_add(len);
514        if buf.len() < total_len {
515            return Ok(None);
516        }
517
518        // Consume the message type
519        buf.advance(1);
520
521        // Decode based on message type
522        match msg_type {
523            b'Q' => {
524                // Query message
525                buf.advance(4); // length
526                let query = read_cstring(buf)?;
527                Ok(Some(FrontendMessage::Query { query }))
528            }
529
530            b'p' => {
531                // Password message
532                buf.advance(4); // length
533                let password = read_cstring(buf)?;
534                Ok(Some(FrontendMessage::Password { password }))
535            }
536
537            b'X' => {
538                // Terminate message
539                buf.advance(4); // length
540                Ok(Some(FrontendMessage::Terminate))
541            }
542
543            0xF0 => {
544                // Subscribe message
545                buf.advance(4); // length
546                let query = read_cstring(buf)?;
547                let param_count = buf.get_i16() as usize;
548                let mut params = Vec::with_capacity(param_count);
549
550                for _ in 0..param_count {
551                    let param_len = buf.get_i32();
552                    if param_len < 0 {
553                        params.push(None);
554                    } else {
555                        let mut param = vec![0u8; param_len as usize];
556                        buf.copy_to_slice(&mut param);
557                        params.push(Some(param));
558                    }
559                }
560
561                // Read optional filter expression (protocol extension)
562                // If there's data remaining, read the filter length
563                let filter = if buf.remaining() >= 2 {
564                    let filter_len = buf.get_i16();
565                    if filter_len > 0 {
566                        let filter_len = filter_len as usize;
567                        if buf.remaining() >= filter_len {
568                            let mut filter_bytes = vec![0u8; filter_len];
569                            buf.copy_to_slice(&mut filter_bytes);
570                            Some(
571                                String::from_utf8(filter_bytes)
572                                    .map_err(|_| ProtocolError::InvalidString)?,
573                            )
574                        } else {
575                            None // Not enough data for filter
576                        }
577                    } else {
578                        None // No filter (length = 0 or negative)
579                    }
580                } else {
581                    None // No filter field present (backward compatibility)
582                };
583
584                // Read optional selective updates configuration (protocol extension)
585                // Format: 1 byte flags + optional values
586                // Bit 0: enabled flag present
587                // Bit 1: min_changed_columns present
588                // Bit 2: max_changed_columns_ratio present
589                let selective_updates_config = if buf.remaining() >= 1 {
590                    let config_flags = buf.get_u8();
591                    if config_flags != 0 {
592                        let mut config = SelectiveUpdatesConfig {
593                            enabled: None,
594                            min_changed_columns: None,
595                            max_changed_columns_ratio: None,
596                        };
597
598                        // Read enabled flag if present
599                        if (config_flags & 0x01) != 0 && buf.remaining() >= 1 {
600                            config.enabled = Some(buf.get_u8() != 0);
601                        }
602
603                        // Read min_changed_columns if present
604                        if (config_flags & 0x02) != 0 && buf.remaining() >= 2 {
605                            config.min_changed_columns = Some(buf.get_u16() as usize);
606                        }
607
608                        // Read max_changed_columns_ratio if present
609                        if (config_flags & 0x04) != 0 && buf.remaining() >= 8 {
610                            config.max_changed_columns_ratio = Some(buf.get_f64());
611                        }
612
613                        Some(config)
614                    } else {
615                        None // config_flags = 0 means no config
616                    }
617                } else {
618                    None // No config field present (backward compatibility)
619                };
620
621                Ok(Some(FrontendMessage::Subscribe { query, params, filter, selective_updates_config }))
622            }
623
624            0xF1 => {
625                // Unsubscribe message
626                buf.advance(4); // length
627                let mut subscription_id = [0u8; 16];
628                buf.copy_to_slice(&mut subscription_id);
629                Ok(Some(FrontendMessage::Unsubscribe { subscription_id }))
630            }
631
632            0xF5 => {
633                // SubscriptionPause message
634                buf.advance(4); // length
635                let mut subscription_id = [0u8; 16];
636                buf.copy_to_slice(&mut subscription_id);
637                Ok(Some(FrontendMessage::SubscriptionPause { subscription_id }))
638            }
639
640            0xF6 => {
641                // SubscriptionResume message
642                buf.advance(4); // length
643                let mut subscription_id = [0u8; 16];
644                buf.copy_to_slice(&mut subscription_id);
645                Ok(Some(FrontendMessage::SubscriptionResume { subscription_id }))
646            }
647
648            _ => Err(ProtocolError::InvalidMessageType(msg_type)),
649        }
650    }
651
652    /// Decode startup message (special case - no message type byte)
653    pub fn decode_startup(buf: &mut BytesMut) -> Result<Option<Self>, ProtocolError> {
654        if buf.len() < 4 {
655            return Ok(None);
656        }
657
658        let len_i32 = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
659
660        // Validate length - startup message must be at least 8 bytes
661        // (4 bytes length + 4 bytes protocol version)
662        if len_i32 < 8 {
663            return Err(ProtocolError::InvalidMessageLength(len_i32));
664        }
665
666        let len = len_i32 as usize;
667
668        if buf.len() < len {
669            return Ok(None);
670        }
671
672        buf.advance(4); // length
673
674        let protocol_version = buf.get_i32();
675
676        // Special case: SSL request (exactly 8 bytes total)
677        if protocol_version == 80877103 {
678            return Ok(Some(FrontendMessage::SSLRequest));
679        }
680
681        // Read parameters - limit iterations to prevent infinite loops
682        let mut params = HashMap::new();
683        let max_params = 100; // Reasonable limit for startup parameters
684        for _ in 0..max_params {
685            // Check if we have data remaining for another string
686            if buf.is_empty() {
687                break;
688            }
689            let key = read_cstring(buf)?;
690            if key.is_empty() {
691                break;
692            }
693            let value = read_cstring(buf)?;
694            params.insert(key, value);
695        }
696
697        Ok(Some(FrontendMessage::Startup { protocol_version, params }))
698    }
699}
700
701/// Write a null-terminated C string
702fn put_cstring(buf: &mut BytesMut, s: &str) {
703    buf.put_slice(s.as_bytes());
704    buf.put_u8(0);
705}
706
707/// Read a null-terminated C string
708fn read_cstring(buf: &mut BytesMut) -> Result<String, ProtocolError> {
709    let null_pos = buf.iter().position(|&b| b == 0).ok_or(ProtocolError::InvalidString)?;
710
711    let bytes = buf.split_to(null_pos);
712    buf.advance(1); // skip null byte
713
714    String::from_utf8(bytes.to_vec()).map_err(|_| ProtocolError::InvalidString)
715}
716
717/// Encode error or notice response fields
718fn encode_notice_or_error(buf: &mut BytesMut, fields: &HashMap<u8, String>) {
719    // Calculate length
720    let mut len = 4 + 1; // length field + terminator
721    for value in fields.values() {
722        len += 1 + value.len() + 1; // field type + value + null
723    }
724
725    buf.put_i32(len as i32);
726
727    // Write fields
728    for (&field_type, value) in fields {
729        buf.put_u8(field_type);
730        put_cstring(buf, value);
731    }
732
733    // Terminator
734    buf.put_u8(0);
735}
736
737#[cfg(test)]
738mod tests {
739    use super::*;
740
741    #[test]
742    fn test_authentication_ok_encoding() {
743        let mut buf = BytesMut::new();
744        BackendMessage::AuthenticationOk.encode(&mut buf);
745
746        assert_eq!(buf[0], b'R');
747        assert_eq!(&buf[1..5], &[0, 0, 0, 8]);
748        assert_eq!(&buf[5..9], &[0, 0, 0, 0]);
749    }
750
751    #[test]
752    fn test_ready_for_query_encoding() {
753        let mut buf = BytesMut::new();
754        BackendMessage::ReadyForQuery { status: TransactionStatus::Idle }.encode(&mut buf);
755
756        assert_eq!(buf[0], b'Z');
757        assert_eq!(&buf[1..5], &[0, 0, 0, 5]);
758        assert_eq!(buf[5], b'I');
759    }
760
761    #[test]
762    fn test_query_decoding() {
763        let mut buf = BytesMut::new();
764        buf.put_u8(b'Q'); // Query message type
765        buf.put_i32(13); // Length (4 bytes length field + 9 bytes "SELECT 1\0")
766        buf.put_slice(b"SELECT 1\0");
767
768        let msg = FrontendMessage::decode(&mut buf).unwrap();
769        assert!(matches!(
770            msg,
771            Some(FrontendMessage::Query { query }) if query == "SELECT 1"
772        ));
773    }
774
775    #[test]
776    fn test_subscribe_message_parsing() {
777        let mut buf = BytesMut::new();
778        buf.put_u8(0xF0); // Subscribe
779        let mut content = BytesMut::new();
780        content.put_slice(b"SELECT * FROM users\0");
781        content.put_i16(0); // No params
782
783        buf.put_i32((4 + content.len()) as i32);
784        buf.extend(content);
785
786        let msg = FrontendMessage::decode(&mut buf).unwrap();
787        assert!(matches!(
788            msg,
789            Some(FrontendMessage::Subscribe { query, params, filter, .. })
790            if query == "SELECT * FROM users" && params.is_empty() && filter.is_none()
791        ));
792    }
793
794    #[test]
795    fn test_subscribe_with_parameters() {
796        let mut buf = BytesMut::new();
797        buf.put_u8(0xF0); // Subscribe
798        let mut content = BytesMut::new();
799        content.put_slice(b"SELECT * FROM users WHERE id = $1\0");
800        content.put_i16(1); // 1 param
801        content.put_i32(5); // param length
802        content.put_slice(b"12345");
803
804        buf.put_i32((4 + content.len()) as i32);
805        buf.extend(content);
806
807        let msg = FrontendMessage::decode(&mut buf).unwrap();
808        assert!(matches!(
809            msg,
810            Some(FrontendMessage::Subscribe { query, params, filter, .. })
811            if query == "SELECT * FROM users WHERE id = $1" && params.len() == 1 && filter.is_none()
812        ));
813    }
814
815    #[test]
816    fn test_subscribe_with_filter() {
817        let mut buf = BytesMut::new();
818        buf.put_u8(0xF0); // Subscribe
819        let mut content = BytesMut::new();
820        content.put_slice(b"SELECT * FROM users\0");
821        content.put_i16(0); // No params
822        let filter_str = "status = 'active'";
823        content.put_i16(filter_str.len() as i16); // Filter length
824        content.put_slice(filter_str.as_bytes()); // Filter expression
825
826        buf.put_i32((4 + content.len()) as i32);
827        buf.extend(content);
828
829        let msg = FrontendMessage::decode(&mut buf).unwrap();
830        match msg {
831            Some(FrontendMessage::Subscribe { query, params, filter, .. }) => {
832                assert_eq!(query, "SELECT * FROM users");
833                assert!(params.is_empty());
834                assert_eq!(filter, Some("status = 'active'".to_string()));
835            }
836            _ => panic!("Expected Subscribe message"),
837        }
838    }
839
#[test]
fn test_subscribe_with_empty_filter() {
    // Query, no parameters, and a zero-length filter, which decodes as "no filter".
    let mut body = BytesMut::new();
    body.put_slice(b"SELECT * FROM users\0");
    body.put_i16(0); // parameter count
    body.put_i16(0); // filter length

    let mut frame = BytesMut::with_capacity(5 + body.len());
    frame.put_u8(0xF0); // Subscribe
    frame.put_i32(body.len() as i32 + 4);
    frame.extend_from_slice(&body);

    let decoded = FrontendMessage::decode(&mut frame).unwrap();
    let is_expected = matches!(
        decoded,
        Some(FrontendMessage::Subscribe { query, params, filter, .. })
            if query == "SELECT * FROM users" && params.is_empty() && filter.is_none()
    );
    assert!(is_expected);
}
859
#[test]
fn test_unsubscribe_message_parsing() {
    let id: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];

    // 0xF1 (Unsubscribe), length 20 = 4 (length field) + 16 (subscription UUID).
    let mut buf = BytesMut::new();
    buf.put_u8(0xF1);
    buf.put_i32(20);
    buf.extend_from_slice(&id);

    let decoded = FrontendMessage::decode(&mut buf).unwrap();
    assert!(matches!(
        decoded,
        Some(FrontendMessage::Unsubscribe { subscription_id }) if subscription_id == id
    ));
}
874
#[test]
fn test_subscription_data_encoding() {
    let subscription_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    let row = vec![Some(b"value1".to_vec()), Some(b"value2".to_vec())];

    let mut out = BytesMut::new();
    let message = BackendMessage::SubscriptionData {
        subscription_id,
        update_type: SubscriptionUpdateType::Full,
        rows: vec![row],
    };
    message.encode(&mut out);

    // Layout: [0] type byte, [1..5] length, [5..21] subscription UUID.
    assert_eq!(out[0], 0xF2);
    assert_eq!(&out[5..21], &subscription_id[..]);
}
892
#[test]
fn test_subscription_error_encoding() {
    let subscription_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    let mut out = BytesMut::new();

    let error = BackendMessage::SubscriptionError {
        subscription_id,
        message: String::from("Query error"),
    };
    error.encode(&mut out);

    // Type byte 0xF3, then the 4-byte length, then the 16-byte subscription id.
    assert_eq!(out[0], 0xF3);
    assert_eq!(&out[5..21], &subscription_id[..]);
}
908
#[test]
fn test_subscription_ack_encoding() {
    let subscription_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    let mut out = BytesMut::new();
    let ack = BackendMessage::SubscriptionAck { subscription_id, table_count: 3 };
    ack.encode(&mut out);

    // Expected layout:
    //   [0]      message type 0xF4
    //   [1..5]   length = 4 (itself) + 16 (id) + 2 (table count) = 22, big-endian
    //   [5..21]  subscription id
    //   [21..23] table count as a big-endian u16
    assert_eq!(out[0], 0xF4);
    assert_eq!(&out[1..5], &22i32.to_be_bytes());
    assert_eq!(&out[5..21], &subscription_id[..]);
    assert_eq!(&out[21..23], &3u16.to_be_bytes());
}
925
#[test]
fn test_subscription_pause_parsing() {
    let id: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];

    // 0xF5 (SubscriptionPause); the length counts itself plus the UUID.
    let mut buf = BytesMut::with_capacity(21);
    buf.put_u8(0xF5);
    buf.put_i32(4 + id.len() as i32);
    buf.put_slice(&id);

    let parsed = FrontendMessage::decode(&mut buf).unwrap();
    let ok = matches!(
        parsed,
        Some(FrontendMessage::SubscriptionPause { subscription_id }) if subscription_id == id
    );
    assert!(ok);
}
940
#[test]
fn test_subscription_resume_parsing() {
    let id: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];

    // Frame: 0xF6 (SubscriptionResume), length = 4 (itself) + 16 (UUID) = 20.
    let mut frame = BytesMut::new();
    frame.put_u8(0xF6);
    frame.put_i32(20);
    frame.extend_from_slice(&id);

    match FrontendMessage::decode(&mut frame).unwrap() {
        Some(FrontendMessage::SubscriptionResume { subscription_id }) => {
            assert!(subscription_id == id)
        }
        _ => panic!(),
    }
}
955
#[test]
fn test_subscription_partial_data_encoding() {
    let subscription_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];

    // A 4-column row in which only columns 0 and 2 carry values.
    let values = vec![Some(b"id1".to_vec()), Some(b"value".to_vec())];
    let row = PartialRowUpdate::new(4, &[0, 2], values);

    let mut out = BytesMut::new();
    let message = BackendMessage::SubscriptionPartialData { subscription_id, rows: vec![row] };
    message.encode(&mut out);

    // Fixed header: type byte 0xF7, 4-byte length, 16-byte subscription id.
    assert_eq!(out[0], 0xF7);
    assert_eq!(&out[5..21], &subscription_id[..]);

    // Update type tag: 4 = SelectiveUpdate.
    assert_eq!(out[21], 4);

    // Big-endian row count (i32), followed by total column count (i16).
    let row_count = i32::from_be_bytes([out[22], out[23], out[24], out[25]]);
    let total_columns = i16::from_be_bytes([out[26], out[27]]);
    assert_eq!((row_count, total_columns), (1, 4));

    // One-byte presence bitmap with the LSB for column 0: columns 0 and 2 -> 0b0101.
    assert_eq!(out[28], 0b0000_0101);
}
995
#[test]
fn test_subscription_partial_data_encoding_with_null() {
    let subscription_id = [0u8; 16];

    // Three-column row: column 0 = "1", column 1 = NULL, column 2 absent.
    let row = PartialRowUpdate::new(3, &[0, 1], vec![Some(b"1".to_vec()), None]);

    let mut out = BytesMut::new();
    let message = BackendMessage::SubscriptionPartialData { subscription_id, rows: vec![row] };
    message.encode(&mut out);

    assert_eq!(out[0], 0xF7);

    // Byte offset of the second value's length field:
    //   1 (type) + 4 (length) + 16 (id) + 1 (update type) + 4 (row count)
    //   + 2 (total columns) + 1 (bitmap for 3 columns)
    //   + 4 (first value length) + 1 (first value bytes "1")
    let null_pos: usize = [1, 4, 16, 1, 4, 2, 1, 4, 1].iter().sum();
    let null_len = i32::from_be_bytes([
        out[null_pos],
        out[null_pos + 1],
        out[null_pos + 2],
        out[null_pos + 3],
    ]);
    // A value length of -1 marks the NULL column.
    assert_eq!(null_len, -1);
}
1028
#[test]
fn test_partial_row_update_new() {
    // 16 columns need a two-byte presence bitmap.
    let values: Vec<_> = ["a", "b", "c"].iter().map(|s| Some(s.as_bytes().to_vec())).collect();
    let partial = PartialRowUpdate::new(16, &[0, 8, 15], values);

    assert_eq!(partial.total_columns, 16);
    assert_eq!(partial.column_mask.len(), 2); // ceil(16 / 8)

    // Bits are LSB-first within each byte: byte 0 bit 0 -> column 0;
    // byte 1 bit 0 -> column 8, byte 1 bit 7 -> column 15.
    assert_eq!(partial.column_mask[0], 0x01);
    assert_eq!(partial.column_mask[1], 0x81);

    for &col in &[0, 8, 15] {
        assert!(partial.is_column_present(col));
    }
    // Column 1 is in range but absent; column 16 is past the end.
    for &col in &[1, 16] {
        assert!(!partial.is_column_present(col));
    }
}
1052
1053    // =====================================================================
1054    // Malformed Message Handling Tests
1055    // Tests for security-relevant handling of invalid wire protocol messages
1056    // =====================================================================
1057
1058    mod malformed_message_tests {
1059        use super::*;
1060
1061        // -----------------------------------------------------------------
1062        // Truncated Message Tests
1063        // -----------------------------------------------------------------
1064
#[test]
fn test_truncated_message_empty_buffer() {
    // With no bytes at all, the decoder must ask for more input rather than fail.
    let mut empty = BytesMut::new();
    let decoded = FrontendMessage::decode(&mut empty);
    assert!(matches!(decoded, Ok(None)));
}
1073
#[test]
fn test_truncated_message_only_type_byte() {
    // A lone type byte ('Q') with no 4-byte length after it is incomplete.
    let mut buf = BytesMut::from(&b"Q"[..]);
    assert!(matches!(FrontendMessage::decode(&mut buf), Ok(None)));
}
1082
#[test]
fn test_truncated_message_partial_length() {
    let mut buf = BytesMut::new();
    buf.put_u8(b'Q');
    // Only 2 of the 4 length bytes are present, so the header is incomplete
    // and the decoder should report "need more data" (Ok(None)).
    buf.put_u8(0);
    buf.put_u8(0);
    let result = FrontendMessage::decode(&mut buf);
    assert!(result.is_ok());
    assert!(result.unwrap().is_none());
}
1093
#[test]
fn test_truncated_message_incomplete_body() {
    // The header promises 100 bytes, but only "SELECT" (6 bytes) follows.
    let mut buf = BytesMut::new();
    buf.put_u8(b'Q');
    buf.put_i32(100);
    buf.extend_from_slice(b"SELECT");
    match FrontendMessage::decode(&mut buf) {
        Ok(decoded) => assert!(decoded.is_none()),
        Err(_) => panic!(),
    }
}
1104
#[test]
fn test_truncated_startup_empty_buffer() {
    // No startup bytes yet: the decoder should wait for more data.
    let mut nothing = BytesMut::with_capacity(0);
    let outcome = FrontendMessage::decode_startup(&mut nothing);
    assert!(matches!(outcome, Ok(None)));
}
1112
#[test]
fn test_truncated_startup_partial_length() {
    // Two zero bytes: half of the 4-byte startup length prefix.
    let mut buf = BytesMut::from(&[0u8, 0][..]);
    let outcome = FrontendMessage::decode_startup(&mut buf);
    assert!(matches!(outcome, Ok(None)));
}
1122
#[test]
fn test_truncated_startup_incomplete_body() {
    const PROTOCOL_3_0: i32 = 196_608; // 0x0003_0000
    let mut buf = BytesMut::new();
    buf.put_i32(50); // declared total length
    buf.put_i32(PROTOCOL_3_0);
    buf.put_slice(b"user\0"); // far short of the declared 50 bytes
    let outcome = FrontendMessage::decode_startup(&mut buf);
    assert!(matches!(outcome, Ok(None)));
}
1133
1134        // -----------------------------------------------------------------
1135        // Invalid Message Type Tests
1136        // -----------------------------------------------------------------
1137
#[test]
fn test_invalid_message_type_byte() {
    // 0xFF is not a recognised frontend message tag.
    let mut frame = BytesMut::with_capacity(5);
    frame.put_u8(0xFF);
    frame.put_i32(4); // the length covers only itself
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(outcome, Err(ProtocolError::InvalidMessageType(0xFF))));
}
1146
#[test]
fn test_invalid_message_type_zero() {
    // A NUL tag byte followed by a minimal length field.
    let mut frame = BytesMut::new();
    frame.put_u8(0);
    frame.put_i32(4);
    match FrontendMessage::decode(&mut frame) {
        Err(ProtocolError::InvalidMessageType(tag)) => assert_eq!(tag, 0x00),
        _ => panic!(),
    }
}
1155
#[test]
fn test_invalid_message_type_lowercase_q() {
    // Tags are case-sensitive: Query is 'Q', so 'q' must be rejected even
    // when an otherwise well-formed query body follows it.
    let query = b"SELECT 1\0";
    let mut frame = BytesMut::new();
    frame.put_u8(b'q');
    frame.put_i32(4 + query.len() as i32); // 13
    frame.put_slice(query);
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(outcome, Err(ProtocolError::InvalidMessageType(b'q'))));
}
1166
#[test]
fn test_invalid_message_type_numeric() {
    // An ASCII digit used as the message tag.
    let mut frame = BytesMut::new();
    frame.put_u8(b'1');
    frame.put_i32(4);
    match FrontendMessage::decode(&mut frame) {
        Err(ProtocolError::InvalidMessageType(tag)) => assert_eq!(tag, b'1'),
        _ => panic!(),
    }
}
1175
1176        // -----------------------------------------------------------------
1177        // Length Field Mismatch Tests
1178        // -----------------------------------------------------------------
1179
#[test]
fn test_length_zero() {
    // The length field counts its own 4 bytes, so 0 is malformed.
    let mut frame = BytesMut::new();
    frame.put_u8(b'X'); // Terminate
    frame.put_i32(0);
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(outcome, Err(ProtocolError::InvalidMessageLength(0))));
}
1189
#[test]
fn test_length_negative() {
    // A negative length must produce an error rather than a panic.
    let mut frame = BytesMut::new();
    frame.put_u8(b'X');
    frame.put_i32(-1);
    match FrontendMessage::decode(&mut frame) {
        Err(ProtocolError::InvalidMessageLength(len)) => assert_eq!(len, -1),
        _ => panic!(),
    }
}
1199
#[test]
fn test_length_too_small() {
    // 3 is one short of the minimum, which is the 4-byte length field itself.
    let mut frame = BytesMut::with_capacity(5);
    frame.put_u8(b'X');
    frame.put_i32(3);
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(outcome, Err(ProtocolError::InvalidMessageLength(3))));
}
1208
#[test]
fn test_startup_length_too_small() {
    // A startup packet needs at least 8 bytes (length + protocol version);
    // a declared length of 4 covers only the length field.
    let mut buf = BytesMut::new();
    buf.put_i32(4);
    match FrontendMessage::decode_startup(&mut buf) {
        Err(ProtocolError::InvalidMessageLength(len)) => assert_eq!(len, 4),
        _ => panic!(),
    }
}
1217
#[test]
fn test_startup_length_negative() {
    let mut buf = BytesMut::with_capacity(4);
    buf.put_i32(-1); // negative declared length
    let outcome = FrontendMessage::decode_startup(&mut buf);
    assert!(matches!(outcome, Err(ProtocolError::InvalidMessageLength(-1))));
}
1225
1226        // -----------------------------------------------------------------
1227        // Invalid UTF-8 Tests
1228        // -----------------------------------------------------------------
1229
#[test]
fn test_invalid_utf8_in_query() {
    // 0xFF and 0xFE never occur in UTF-8, and 0x80 is a stray continuation byte.
    let payload: &[u8] = &[0xFF, 0xFE, 0x80, 0x00];
    let mut frame = BytesMut::new();
    frame.put_u8(b'Q');
    frame.put_i32(4 + payload.len() as i32); // 8
    frame.put_slice(payload);
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(outcome, Err(ProtocolError::InvalidString)));
}
1240
#[test]
fn test_invalid_utf8_continuation_byte() {
    // 0x80 is a continuation byte with no leading byte before it.
    let mut frame = BytesMut::new();
    frame.put_u8(b'Q');
    frame.put_i32(4 + 2);
    frame.extend_from_slice(&[0x80, 0x00]); // stray continuation byte + NUL
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(outcome, Err(ProtocolError::InvalidString)));
}
1251
#[test]
fn test_invalid_utf8_overlong_encoding() {
    // 0xC0 0x80 is the overlong two-byte form of U+0000, which strict UTF-8 forbids.
    let body: &[u8] = &[0xC0, 0x80, 0x00];
    let mut frame = BytesMut::new();
    frame.put_u8(b'Q');
    frame.put_i32(4 + body.len() as i32); // 7
    frame.put_slice(body);
    match FrontendMessage::decode(&mut frame) {
        Err(ProtocolError::InvalidString) => {}
        _ => panic!(),
    }
}
1262
#[test]
fn test_invalid_utf8_in_password() {
    let mut buf = BytesMut::new();
    buf.put_u8(b'p'); // Password message
    buf.put_i32(8); // 4 (length) + 4 body bytes
    // Body: 0xFE 0xFF (bytes that never occur in UTF-8), a NUL, then one more NUL.
    // The NUL-terminated string is therefore the two invalid bytes, not an
    // empty string; the final NUL is extra padding inside the declared frame.
    buf.put_slice(&[0xFE, 0xFF, 0x00]);
    buf.put_u8(0);
    let result = FrontendMessage::decode(&mut buf);
    // NOTE(review): this accepts any Ok(..) as well as InvalidString.
    // `test_invalid_utf8_in_query` gets InvalidString for similar bytes, so the
    // password path is presumably meant to reject them as well. Consider
    // tightening this to `Err(ProtocolError::InvalidString)` once confirmed.
    assert!(result.is_ok() || matches!(result, Err(ProtocolError::InvalidString)));
}
1274
#[test]
fn test_invalid_utf8_in_startup_user() {
    // A startup packet whose "user" parameter value is not valid UTF-8.
    let mut body = BytesMut::new();
    body.put_i32(196608); // protocol version 3.0
    body.put_slice(b"user\0"); // parameter name
    body.put_slice(&[0xFF, 0xFE, 0x00]); // invalid UTF-8 value, NUL-terminated
    body.put_u8(0); // an empty name terminates the parameter list

    let mut buf = BytesMut::new();
    buf.put_i32(4 + body.len() as i32); // 4 + 4 + 5 + 3 + 1 = 17
    buf.put_slice(&body);

    let outcome = FrontendMessage::decode_startup(&mut buf);
    assert!(matches!(outcome, Err(ProtocolError::InvalidString)));
}
1290
1291        // -----------------------------------------------------------------
1292        // Missing Null Terminator Tests
1293        // -----------------------------------------------------------------
1294
#[test]
fn test_query_missing_null_terminator() {
    // The declared length covers exactly "SELECT 1", with no trailing NUL.
    let text = b"SELECT 1";
    let mut frame = BytesMut::new();
    frame.put_u8(b'Q');
    frame.put_i32(text.len() as i32 + 4); // 12
    frame.put_slice(text);
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(outcome, Err(ProtocolError::InvalidString)));
}
1304
#[test]
fn test_startup_missing_final_null() {
    // A well-formed startup packet ends its parameter list with an empty name
    // (a lone NUL). This one stops right after "user\0test\0". The decoder is
    // expected to stop when it runs out of bytes and still return a Startup
    // message; the parameter list may be incomplete.
    let params = b"user\0test\0";
    let mut buf = BytesMut::new();
    buf.put_i32(8 + params.len() as i32); // 4 (length) + 4 (version) + 10 = 18
    buf.put_i32(196608); // protocol version 3.0
    buf.put_slice(params);
    let outcome = FrontendMessage::decode_startup(&mut buf);
    assert!(matches!(outcome, Ok(Some(FrontendMessage::Startup { .. }))));
}
1320
1321        // -----------------------------------------------------------------
1322        // Zero-Length Message Tests
1323        // -----------------------------------------------------------------
1324
#[test]
fn test_terminate_minimal() {
    // Terminate ('X') carries no payload: tag + length 4 is the whole frame.
    let mut frame = BytesMut::new();
    frame.put_u8(b'X');
    frame.put_i32(4);
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(outcome, Ok(Some(FrontendMessage::Terminate))));
}
1335
#[test]
fn test_query_empty_string() {
    // The body is just the NUL terminator, i.e. an empty query string.
    let mut frame = BytesMut::new();
    frame.put_u8(b'Q');
    frame.put_i32(5);
    frame.put_u8(0);
    match FrontendMessage::decode(&mut frame) {
        Ok(Some(FrontendMessage::Query { query })) => assert_eq!(query, ""),
        _ => panic!(),
    }
}
1349
1350        // -----------------------------------------------------------------
1351        // SSL Request Tests
1352        // -----------------------------------------------------------------
1353
#[test]
fn test_ssl_request_detection() {
    // SSLRequest: length 8 followed by the magic code 80877103 (1234 << 16 | 5679).
    const SSL_REQUEST_CODE: i32 = (1234 << 16) | 5679;
    let mut buf = BytesMut::new();
    buf.put_i32(8);
    buf.put_i32(SSL_REQUEST_CODE);
    let outcome = FrontendMessage::decode_startup(&mut buf);
    assert!(matches!(outcome, Ok(Some(FrontendMessage::SSLRequest))));
}
1363
1364        // -----------------------------------------------------------------
1365        // Valid Protocol Version Tests
1366        // -----------------------------------------------------------------
1367
#[test]
fn test_startup_protocol_version_3_0() {
    // user=pg, followed by the empty-key terminator.
    let param_bytes = b"user\0pg\0\0";
    let mut buf = BytesMut::new();
    buf.put_i32(8 + param_bytes.len() as i32); // 17
    buf.put_i32(196608); // protocol version 3.0 (0x00030000)
    buf.put_slice(param_bytes);

    let decoded = FrontendMessage::decode_startup(&mut buf).unwrap();
    match decoded {
        Some(FrontendMessage::Startup { protocol_version, params }) => {
            assert!(protocol_version == 196608);
            assert_eq!(params.get("user").map(String::as_str), Some("pg"));
        }
        _ => panic!(),
    }
}
1384
1385        // -----------------------------------------------------------------
1386        // Buffer Consumption Tests
1387        // -----------------------------------------------------------------
1388
#[test]
fn test_buffer_properly_consumed_after_query() {
    // Two back-to-back Query frames in one buffer. Decoding the first must
    // leave the second intact for the next call.
    let mut buf = BytesMut::new();
    for text in &[&b"test1\0"[..], &b"test2\0"[..]] {
        buf.put_u8(b'Q');
        buf.put_i32(4 + text.len() as i32); // 10
        buf.put_slice(text);
    }

    for expected in &["test1", "test2"] {
        let decoded = FrontendMessage::decode(&mut buf).unwrap();
        assert!(matches!(
            decoded,
            Some(FrontendMessage::Query { query }) if query == *expected
        ));
    }
}
1413
#[test]
fn test_buffer_not_consumed_on_incomplete() {
    // The header announces 100 bytes, but nothing follows it.
    let mut buf = BytesMut::new();
    buf.put_u8(b'Q');
    buf.put_i32(100);
    let bytes_before = buf.len();

    let outcome = FrontendMessage::decode(&mut buf);
    assert!(matches!(outcome, Ok(None)));
    // An incomplete frame must be left in place for the next read.
    assert_eq!(buf.len(), bytes_before);
}
1426
1427        // -----------------------------------------------------------------
1428        // Edge Cases for Large Messages
1429        // -----------------------------------------------------------------
1430
#[test]
fn test_very_large_declared_length() {
    // An i32::MAX length with only a few body bytes is merely incomplete.
    let mut buf = BytesMut::new();
    buf.put_u8(b'Q');
    buf.put_i32(i32::MAX);
    buf.extend_from_slice(b"small\0");
    let outcome = FrontendMessage::decode(&mut buf);
    assert!(matches!(outcome, Ok(None)));
}
1442
1443        // -----------------------------------------------------------------
1444        // Password Message Tests
1445        // -----------------------------------------------------------------
1446
#[test]
fn test_password_message_valid() {
    // "secret\0" (7 bytes) plus two padding bytes fills the declared 9-byte body;
    // only the text before the first NUL is expected as the password.
    let mut body = BytesMut::new();
    body.put_slice(b"secret\0");
    body.put_slice(&[0, 0]);

    let mut frame = BytesMut::new();
    frame.put_u8(b'p');
    frame.put_i32(4 + body.len() as i32); // 13
    frame.put_slice(&body);

    match FrontendMessage::decode(&mut frame) {
        Ok(Some(FrontendMessage::Password { password })) => assert_eq!(password, "secret"),
        _ => panic!(),
    }
}
1462
#[test]
fn test_password_message_empty() {
    // Bytes: 'p', big-endian length 5 (0,0,0,5), then a lone NUL (empty password).
    let mut frame = BytesMut::from(&b"p\0\0\0\x05\0"[..]);
    let outcome = FrontendMessage::decode(&mut frame);
    assert!(matches!(
        outcome,
        Ok(Some(FrontendMessage::Password { password })) if password.is_empty()
    ));
}
1476
1477        // -----------------------------------------------------------------
1478        // SelectiveUpdatesConfig Tests
1479        // -----------------------------------------------------------------
1480
#[test]
fn test_subscribe_with_selective_updates_config_full() {
    // A Subscribe (0xF0) that carries all three selective-update overrides.
    let mut body = BytesMut::new();
    body.put_slice(b"SELECT * FROM test\0"); // query
    body.put_i16(0); // parameter count
    body.put_i16(0); // filter length (no filter)

    // Selective updates config: a flags byte, then one field per set bit in
    // bit order: 0x01 enabled (u8), 0x02 min_changed_columns (u16),
    // 0x04 max_changed_columns_ratio (f64).
    body.put_u8(0x07); // all three flags set (0b111)
    body.put_u8(1); // enabled = true
    body.put_u16(5); // min_changed_columns = 5
    body.put_f64(0.75); // max_changed_columns_ratio = 0.75

    // The length field counts its own 4 bytes plus the body.
    let mut buf = BytesMut::new();
    buf.put_u8(0xF0);
    buf.put_i32((4 + body.len()) as i32);
    buf.put_slice(&body);

    let msg = FrontendMessage::decode(&mut buf).expect("decode should succeed");
    if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = msg {
        let config = selective_updates_config
            .expect("selective updates config should be present");
        assert_eq!(config.enabled, Some(true));
        assert_eq!(config.min_changed_columns, Some(5));
        assert_eq!(config.max_changed_columns_ratio, Some(0.75));
    } else {
        panic!("Expected Subscribe message");
    }
}
1525
#[test]
fn test_subscribe_with_partial_selective_config_enabled_only() {
    // Only the `enabled` flag (0x01) is set, so only its u8 value follows the flags.
    let mut body = BytesMut::new();
    body.put_slice(b"SELECT * FROM test\0"); // query
    body.put_i16(0); // no params
    body.put_i16(0); // no filter
    body.put_u8(0x01); // config flags: enabled only (0b001)
    body.put_u8(1); // enabled = true

    let mut buf = BytesMut::new();
    buf.put_u8(0xF0); // Subscribe message type
    buf.put_i32((4 + body.len()) as i32);
    buf.put_slice(&body);

    let msg = FrontendMessage::decode(&mut buf).expect("decode should succeed");
    if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = msg {
        let config = selective_updates_config
            .expect("selective updates config should be present");
        assert_eq!(config.enabled, Some(true));
        // Fields whose flag bit is clear must stay unset.
        assert_eq!(config.min_changed_columns, None);
        assert_eq!(config.max_changed_columns_ratio, None);
    } else {
        panic!("Expected Subscribe message with config");
    }
}
1556
#[test]
fn test_subscribe_with_partial_selective_config_min_columns_only() {
    // Only the min_changed_columns flag (0x02) is set, so a single u16 follows the flags.
    let mut body = BytesMut::new();
    body.put_slice(b"SELECT * FROM test\0"); // query
    body.put_i16(0); // no params
    body.put_i16(0); // no filter
    body.put_u8(0x02); // config flags: min_changed_columns only (0b010)
    body.put_u16(10); // min_changed_columns = 10

    let mut buf = BytesMut::new();
    buf.put_u8(0xF0); // Subscribe message type
    buf.put_i32((4 + body.len()) as i32);
    buf.put_slice(&body);

    let msg = FrontendMessage::decode(&mut buf).expect("decode should succeed");
    if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = msg {
        let config = selective_updates_config
            .expect("selective updates config should be present");
        assert_eq!(config.enabled, None);
        assert_eq!(config.min_changed_columns, Some(10));
        assert_eq!(config.max_changed_columns_ratio, None);
    } else {
        panic!("Expected Subscribe message with config");
    }
}
1587
#[test]
fn test_subscribe_with_partial_selective_config_max_ratio_only() {
    // Only the max_changed_columns_ratio flag (0x04) is set, so a single f64 follows.
    let mut body = BytesMut::new();
    body.put_slice(b"SELECT * FROM test\0"); // query
    body.put_i16(0); // no params
    body.put_i16(0); // no filter
    body.put_u8(0x04); // config flags: max_changed_columns_ratio only (0b100)
    body.put_f64(0.5); // max_changed_columns_ratio = 0.5

    let mut buf = BytesMut::new();
    buf.put_u8(0xF0); // Subscribe message type
    buf.put_i32((4 + body.len()) as i32);
    buf.put_slice(&body);

    let msg = FrontendMessage::decode(&mut buf).expect("decode should succeed");
    if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = msg {
        let config = selective_updates_config
            .expect("selective updates config should be present");
        assert_eq!(config.enabled, None);
        assert_eq!(config.min_changed_columns, None);
        assert_eq!(config.max_changed_columns_ratio, Some(0.5));
    } else {
        panic!("Expected Subscribe message with config");
    }
}
1618
1619        #[test]
1620        fn test_subscribe_with_selective_config_zero_flags() {
1621            // Test that config_flags = 0 results in None config
1622            let mut buf = BytesMut::new();
1623            buf.put_u8(0xF0); // Subscribe message type
1624            
1625            let mut body = BytesMut::new();
1626            body.put_slice(b"SELECT * FROM test\0");
1627            body.put_i16(0); // no params
1628            body.put_i16(0); // no filter
1629            body.put_u8(0x00); // config_flags = 0 (no config)
1630            
1631            buf.put_i32((4 + body.len()) as i32);
1632            buf.put_slice(&body);
1633            
1634            let result = FrontendMessage::decode(&mut buf);
1635            assert!(result.is_ok());
1636            
1637            if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1638                assert!(selective_updates_config.is_none());
1639            } else {
1640                panic!("Expected Subscribe message");
1641            }
1642        }
1643
1644        #[test]
1645        fn test_subscribe_without_selective_config_field() {
1646            // Test backward compatibility: Subscribe without config field present
1647            let mut buf = BytesMut::new();
1648            buf.put_u8(0xF0); // Subscribe message type
1649            
1650            let mut body = BytesMut::new();
1651            body.put_slice(b"SELECT * FROM test\0");
1652            body.put_i16(0); // no params
1653            body.put_i16(0); // no filter
1654            // No config field at all
1655            
1656            buf.put_i32((4 + body.len()) as i32);
1657            buf.put_slice(&body);
1658            
1659            let result = FrontendMessage::decode(&mut buf);
1660            assert!(result.is_ok());
1661            
1662            if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1663                assert!(selective_updates_config.is_none());
1664            } else {
1665                panic!("Expected Subscribe message");
1666            }
1667        }
1668
1669        #[test]
1670        fn test_subscribe_with_config_disabled() {
1671            // Test parsing with enabled = false
1672            let mut buf = BytesMut::new();
1673            buf.put_u8(0xF0); // Subscribe message type
1674            
1675            let mut body = BytesMut::new();
1676            body.put_slice(b"SELECT * FROM test\0");
1677            body.put_i16(0); // no params
1678            body.put_i16(0); // no filter
1679            
1680            body.put_u8(0x01); // enabled flag set
1681            body.put_u8(0); // enabled = false
1682            
1683            buf.put_i32((4 + body.len()) as i32);
1684            buf.put_slice(&body);
1685            
1686            let result = FrontendMessage::decode(&mut buf);
1687            assert!(result.is_ok());
1688            
1689            if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1690                assert!(selective_updates_config.is_some());
1691                let config = selective_updates_config.unwrap();
1692                assert_eq!(config.enabled, Some(false));
1693            } else {
1694                panic!("Expected Subscribe message with config");
1695            }
1696        }
1697
1698        #[test]
1699        fn test_subscribe_with_combined_flags() {
1700            // Test parsing with enabled and min_changed_columns flags
1701            let mut buf = BytesMut::new();
1702            buf.put_u8(0xF0); // Subscribe message type
1703            
1704            let mut body = BytesMut::new();
1705            body.put_slice(b"SELECT * FROM test\0");
1706            body.put_i16(0); // no params
1707            body.put_i16(0); // no filter
1708            
1709            body.put_u8(0x03); // enabled and min_changed_columns flags (0b011)
1710            body.put_u8(1); // enabled = true
1711            body.put_u16(3); // min_changed_columns = 3
1712            
1713            buf.put_i32((4 + body.len()) as i32);
1714            buf.put_slice(&body);
1715            
1716            let result = FrontendMessage::decode(&mut buf);
1717            assert!(result.is_ok());
1718            
1719            if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1720                assert!(selective_updates_config.is_some());
1721                let config = selective_updates_config.unwrap();
1722                assert_eq!(config.enabled, Some(true));
1723                assert_eq!(config.min_changed_columns, Some(3));
1724                assert_eq!(config.max_changed_columns_ratio, None);
1725            } else {
1726                panic!("Expected Subscribe message with config");
1727            }
1728        }
1729    }
1730}