Skip to main content

fraiseql_wire/protocol/
decode.rs

1//! Protocol message decoding
2
3use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
/// Upper bound on the number of fields in one `DataRow` or `RowDescription` message.
///
/// The wire format carries the count as an `i16`. The cap stops an attacker-supplied
/// message from requesting a large `Vec::with_capacity` before any field has been
/// validated. PostgreSQL limits tables to 1600 columns (and query target lists to
/// 1664 entries), so 2048 leaves comfortable headroom for legitimate results.
const MAX_FIELD_COUNT: usize = 2_048;

/// Upper bound, in bytes, on one field (severity, message, detail, ...) of an
/// `ErrorResponse` or `NoticeResponse`.
///
/// 64 KiB is far more than any human-readable diagnostic needs. A larger field is
/// rejected before it is converted with `String::from_utf8_lossy`, so a malicious
/// server cannot force a huge string allocation through a single field.
const MAX_ERROR_FIELD_BYTES: usize = 65_536; // 64 KiB

/// Upper bound on how many SASL mechanism names are kept from an Authentication
/// message.
///
/// Servers normally advertise one or two (e.g. `SCRAM-SHA-256`). The cap keeps a
/// rogue server from growing the mechanism `Vec<String>` until memory runs out.
const MAX_SASL_MECHANISMS: usize = 32;

/// Upper bound, in bytes, on a `ParameterStatus` name such as `"server_version"`.
///
/// Parameter names are short identifiers, so 256 bytes leaves ample headroom.
const MAX_PARAMETER_NAME_BYTES: usize = 256;

/// Upper bound, in bytes, on a `ParameterStatus` value.
///
/// Large enough for realistic values (long `TimeZone` strings and the like), while
/// still stopping a malicious server from inflating memory with one oversized value.
const MAX_PARAMETER_VALUE_BYTES: usize = 65_536; // 64 KiB
38
39/// Decode a backend message from `BytesMut` without cloning
40///
41/// This version decodes in-place from a mutable `BytesMut` buffer and returns
42/// the number of bytes consumed. The caller must advance the buffer after calling this.
43///
44/// # Returns
45/// `Ok((msg, consumed))` - Message and number of bytes consumed
46/// `Err(e)` - IO error if message is incomplete or invalid
47///
48/// # Performance
49/// This version avoids the expensive `buf.clone().freeze()` call by working directly
50/// with references, reducing allocations and copies in the hot path.
51///
52/// # Errors
53///
54/// Returns `io::Error` with `UnexpectedEof` if the buffer is too small for a complete message.
55/// Returns `io::Error` with `InvalidData` if the message length or content is malformed.
56pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
57    if data.len() < 5 {
58        return Err(io::Error::new(
59            io::ErrorKind::UnexpectedEof,
60            "incomplete message header",
61        ));
62    }
63
64    let tag = data[0];
65    let len_i32 = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
66
67    // PostgreSQL message length includes the 4 length bytes but not the tag byte.
68    // Minimum valid length is 4 (just the length field itself).
69    if len_i32 < 4 {
70        return Err(io::Error::new(
71            io::ErrorKind::InvalidData,
72            "message length too small",
73        ));
74    }
75
76    let len = len_i32 as usize;
77
78    if data.len() < len + 1 {
79        return Err(io::Error::new(
80            io::ErrorKind::UnexpectedEof,
81            "incomplete message body",
82        ));
83    }
84
85    // Create a temporary slice starting after the tag and length
86    let msg_start = 5;
87    let msg_end = len + 1;
88    let msg_data = &data[msg_start..msg_end];
89
90    let msg = match tag {
91        tags::AUTHENTICATION => decode_authentication(msg_data)?,
92        tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
93        tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
94        tags::DATA_ROW => decode_data_row(msg_data)?,
95        tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
96        tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
97        tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
98        tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
99        tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
100        _ => {
101            return Err(io::Error::new(
102                io::ErrorKind::InvalidData,
103                format!("unknown message tag: {}", tag),
104            ))
105        }
106    };
107
108    Ok((msg, len + 1))
109}
110
111fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
112    if data.len() < 4 {
113        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"));
114    }
115    let auth_type = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
116
117    let auth_msg = match auth_type {
118        auth::OK => AuthenticationMessage::Ok,
119        auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
120        auth::MD5_PASSWORD => {
121            if data.len() < 8 {
122                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"));
123            }
124            let mut salt = [0u8; 4];
125            salt.copy_from_slice(&data[4..8]);
126            AuthenticationMessage::Md5Password { salt }
127        }
128        auth::SASL => {
129            // SASL: read mechanism list (null-terminated strings)
130            let mut mechanisms = Vec::new();
131            let remaining = &data[4..];
132            let mut offset = 0;
133            loop {
134                if offset >= remaining.len() {
135                    break;
136                }
137                match remaining[offset..].iter().position(|&b| b == 0) {
138                    Some(end) => {
139                        let mechanism =
140                            String::from_utf8_lossy(&remaining[offset..offset + end]).to_string();
141                        if mechanism.is_empty() {
142                            break;
143                        }
144                        if mechanisms.len() >= MAX_SASL_MECHANISMS {
145                            break;
146                        }
147                        mechanisms.push(mechanism);
148                        offset += end + 1;
149                    }
150                    None => break,
151                }
152            }
153            AuthenticationMessage::Sasl { mechanisms }
154        }
155        auth::SASL_CONTINUE => {
156            // SASL continue: read remaining data as bytes
157            let data_vec = data[4..].to_vec();
158            AuthenticationMessage::SaslContinue { data: data_vec }
159        }
160        auth::SASL_FINAL => {
161            // SASL final: read remaining data as bytes
162            let data_vec = data[4..].to_vec();
163            AuthenticationMessage::SaslFinal { data: data_vec }
164        }
165        _ => {
166            return Err(io::Error::new(
167                io::ErrorKind::Unsupported,
168                format!("unsupported auth type: {}", auth_type),
169            ))
170        }
171    };
172
173    Ok(BackendMessage::Authentication(auth_msg))
174}
175
176fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
177    if data.len() < 8 {
178        return Err(io::Error::new(
179            io::ErrorKind::UnexpectedEof,
180            "backend key data",
181        ));
182    }
183    let process_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
184    let secret_key = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
185    Ok(BackendMessage::BackendKeyData {
186        process_id,
187        secret_key,
188    })
189}
190
191fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
192    let end = data.iter().position(|&b| b == 0).ok_or_else(|| {
193        io::Error::new(
194            io::ErrorKind::InvalidData,
195            "missing null terminator in string",
196        )
197    })?;
198    let tag = String::from_utf8_lossy(&data[..end]).to_string();
199    Ok(BackendMessage::CommandComplete(tag))
200}
201
202fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
203    if data.len() < 2 {
204        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
205    }
206    let field_count_i16 = i16::from_be_bytes([data[0], data[1]]);
207    if field_count_i16 < 0 {
208        return Err(io::Error::new(
209            io::ErrorKind::InvalidData,
210            "negative field count",
211        ));
212    }
213    let field_count = field_count_i16 as usize;
214    if field_count > MAX_FIELD_COUNT {
215        return Err(io::Error::new(
216            io::ErrorKind::InvalidData,
217            format!("DataRow field count {field_count} exceeds maximum {MAX_FIELD_COUNT}"),
218        ));
219    }
220    let mut fields = Vec::with_capacity(field_count);
221    let mut offset = 2;
222
223    for _ in 0..field_count {
224        if offset + 4 > data.len() {
225            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field length"));
226        }
227        let field_len = i32::from_be_bytes([
228            data[offset],
229            data[offset + 1],
230            data[offset + 2],
231            data[offset + 3],
232        ]);
233        offset += 4;
234
235        let field = if field_len == -1 {
236            None
237        } else if field_len < 0 {
238            return Err(io::Error::new(
239                io::ErrorKind::InvalidData,
240                "negative field length",
241            ));
242        } else {
243            let len = field_len as usize;
244            if offset + len > data.len() {
245                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field data"));
246            }
247            let field_bytes = Bytes::copy_from_slice(&data[offset..offset + len]);
248            offset += len;
249            Some(field_bytes)
250        };
251        fields.push(field);
252    }
253
254    Ok(BackendMessage::DataRow(fields))
255}
256
257fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
258    let fields = decode_error_fields(data)?;
259    Ok(BackendMessage::ErrorResponse(fields))
260}
261
262fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
263    let fields = decode_error_fields(data)?;
264    Ok(BackendMessage::NoticeResponse(fields))
265}
266
267fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
268    let mut fields = ErrorFields::default();
269    let mut offset = 0;
270
271    loop {
272        if offset >= data.len() {
273            break;
274        }
275        let field_type = data[offset];
276        offset += 1;
277        if field_type == 0 {
278            break;
279        }
280
281        let end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
282            io::Error::new(
283                io::ErrorKind::InvalidData,
284                "missing null terminator in error field",
285            )
286        })?;
287        if end > MAX_ERROR_FIELD_BYTES {
288            return Err(io::Error::new(
289                io::ErrorKind::InvalidData,
290                format!("Error field too large ({end} bytes, max {MAX_ERROR_FIELD_BYTES})"),
291            ));
292        }
293        let value = String::from_utf8_lossy(&data[offset..offset + end]).to_string();
294        offset += end + 1;
295
296        match field_type {
297            b'S' => fields.severity = Some(value),
298            b'C' => fields.code = Some(value),
299            b'M' => fields.message = Some(value),
300            b'D' => fields.detail = Some(value),
301            b'H' => fields.hint = Some(value),
302            b'P' => fields.position = Some(value),
303            _ => {} // Ignore unknown fields
304        }
305    }
306
307    Ok(fields)
308}
309
310fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
311    let mut offset = 0;
312
313    let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
314        io::Error::new(
315            io::ErrorKind::InvalidData,
316            "missing null terminator in parameter name",
317        )
318    })?;
319    if name_end > MAX_PARAMETER_NAME_BYTES {
320        return Err(io::Error::new(
321            io::ErrorKind::InvalidData,
322            format!("Parameter name too long ({name_end} bytes, max {MAX_PARAMETER_NAME_BYTES})"),
323        ));
324    }
325    let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
326    offset += name_end + 1;
327
328    if offset >= data.len() {
329        return Err(io::Error::new(
330            io::ErrorKind::UnexpectedEof,
331            "parameter value",
332        ));
333    }
334    let value_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
335        io::Error::new(
336            io::ErrorKind::InvalidData,
337            "missing null terminator in parameter value",
338        )
339    })?;
340    if value_end > MAX_PARAMETER_VALUE_BYTES {
341        return Err(io::Error::new(
342            io::ErrorKind::InvalidData,
343            format!(
344                "Parameter value too long ({value_end} bytes, max {MAX_PARAMETER_VALUE_BYTES})"
345            ),
346        ));
347    }
348    let value = String::from_utf8_lossy(&data[offset..offset + value_end]).to_string();
349
350    Ok(BackendMessage::ParameterStatus { name, value })
351}
352
353fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
354    if data.is_empty() {
355        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"));
356    }
357    let status = data[0];
358    Ok(BackendMessage::ReadyForQuery { status })
359}
360
361fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
362    if data.len() < 2 {
363        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
364    }
365    let field_count_i16 = i16::from_be_bytes([data[0], data[1]]);
366    if field_count_i16 < 0 {
367        return Err(io::Error::new(
368            io::ErrorKind::InvalidData,
369            "negative field count",
370        ));
371    }
372    let field_count = field_count_i16 as usize;
373    if field_count > MAX_FIELD_COUNT {
374        return Err(io::Error::new(
375            io::ErrorKind::InvalidData,
376            format!("RowDescription field count {field_count} exceeds maximum {MAX_FIELD_COUNT}"),
377        ));
378    }
379    let mut fields = Vec::with_capacity(field_count);
380    let mut offset = 2;
381
382    for _ in 0..field_count {
383        // Read name (null-terminated string)
384        let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
385            io::Error::new(
386                io::ErrorKind::InvalidData,
387                "missing null terminator in field name",
388            )
389        })?;
390        let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
391        offset += name_end + 1;
392
393        // Read field descriptor (26 bytes: 4+2+4+2+4+2)
394        if offset + 18 > data.len() {
395            return Err(io::Error::new(
396                io::ErrorKind::UnexpectedEof,
397                "field descriptor",
398            ));
399        }
400        let table_oid = i32::from_be_bytes([
401            data[offset],
402            data[offset + 1],
403            data[offset + 2],
404            data[offset + 3],
405        ]);
406        offset += 4;
407        let column_attr = i16::from_be_bytes([data[offset], data[offset + 1]]);
408        offset += 2;
409        let type_oid = i32::from_be_bytes([
410            data[offset],
411            data[offset + 1],
412            data[offset + 2],
413            data[offset + 3],
414        ]) as u32;
415        offset += 4;
416        let type_size = i16::from_be_bytes([data[offset], data[offset + 1]]);
417        offset += 2;
418        let type_modifier = i32::from_be_bytes([
419            data[offset],
420            data[offset + 1],
421            data[offset + 2],
422            data[offset + 3],
423        ]);
424        offset += 4;
425        let format_code = i16::from_be_bytes([data[offset], data[offset + 1]]);
426        offset += 2;
427
428        fields.push(FieldDescription {
429            name,
430            table_oid,
431            column_attr,
432            type_oid,
433            type_size,
434            type_modifier,
435            format_code,
436        });
437    }
438
439    Ok(BackendMessage::RowDescription(fields))
440}
441
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
    use super::*;

    // ── Framing tests ─────────────────────────────────────────────────────────

    #[test]
    fn test_decode_authentication_ok() {
        let mut data = BytesMut::from(
            &[
                b'R', // Authentication
                0, 0, 0, 8, // Length = 8
                0, 0, 0, 0, // Auth OK
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Ok) => {}
            _ => panic!("expected Authentication::Ok"),
        }
        assert_eq!(consumed, 9); // 1 tag + 4 len + 4 auth type
    }

    #[test]
    fn test_decode_ready_for_query() {
        let mut data = BytesMut::from(
            &[
                b'Z', // ReadyForQuery
                0, 0, 0, 5, // Length = 5
                b'I', // Idle
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::ReadyForQuery { status } => assert_eq!(status, b'I'),
            _ => panic!("expected ReadyForQuery"),
        }
        assert_eq!(consumed, 6); // 1 tag + 4 len + 1 status
    }

    #[test]
    fn test_incomplete_header_and_body_report_unexpected_eof() {
        // Only 3 of the 5 header bytes are present.
        let mut short_header = BytesMut::from(&[b'Z', 0, 0][..]);
        let err = decode_message(&mut short_header).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);

        // Header declares a 1-byte body (length 5) but the body byte is missing.
        let mut short_body = BytesMut::from(&[b'Z', 0, 0, 0, 5][..]);
        let err = decode_message(&mut short_body).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_length_below_minimum_is_rejected() {
        // The length field counts its own 4 bytes, so anything below 4 is malformed.
        for &len in &[-1i32, 0, 3] {
            let mut data = BytesMut::new();
            data.extend_from_slice(b"Z");
            data.extend_from_slice(&len.to_be_bytes());
            let err = decode_message(&mut data).unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::InvalidData, "length {len}");
        }
    }

    #[test]
    fn test_unknown_tag_is_rejected() {
        let mut data = BytesMut::from(&[b'?', 0, 0, 0, 4][..]);
        let err = decode_message(&mut data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    // ── Field-count guard tests ────────────────────────────────────────────────

    fn make_data_row_with_count(count: i16) -> BytesMut {
        // DataRow: tag 'D', length (4 bytes), field_count (2 bytes), then `count` null fields.
        // Each null field is represented by length -1 (i32: 0xFF FF FF FF).
        let body_len: u32 = 2 + 4 * u32::from(count.unsigned_abs());
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"D");
        buf.extend_from_slice(&(body_len + 4).to_be_bytes()); // length includes itself
        buf.extend_from_slice(&count.to_be_bytes());
        for _ in 0..count {
            buf.extend_from_slice(&(-1i32).to_be_bytes()); // NULL field
        }
        buf
    }

    fn make_row_description_with_count(count: i16) -> BytesMut {
        // RowDescription: tag 'T', length, field_count, then `count` minimal field descriptors.
        // Each descriptor: name (1 null byte) + 18 bytes of OID/size info = 19 bytes.
        let body_len: u32 = 2 + 19 * u32::from(count.unsigned_abs());
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"T");
        buf.extend_from_slice(&(body_len + 4).to_be_bytes());
        buf.extend_from_slice(&count.to_be_bytes());
        for _ in 0..count {
            buf.extend_from_slice(&[0u8]); // empty name (null terminator)
            buf.extend_from_slice(&[0u8; 18]); // table_oid(4) + col_attr(2) + type_oid(4) + type_size(2) + type_mod(4) + format(2)
        }
        buf
    }

    #[test]
    fn test_data_row_zero_fields_accepted() {
        let mut buf = make_data_row_with_count(0);
        let result = decode_message(&mut buf);
        assert!(result.is_ok(), "zero-field DataRow must be accepted");
        let (msg, _) = result.unwrap();
        assert!(matches!(msg, BackendMessage::DataRow(fields) if fields.is_empty()));
    }

    #[test]
    fn test_data_row_values_and_nulls() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"D");
        // length = 4 (self) + 2 (count) + 4 + 2 ("hi") + 4 (NULL) = 16
        buf.extend_from_slice(&16u32.to_be_bytes());
        buf.extend_from_slice(&2i16.to_be_bytes());
        buf.extend_from_slice(&2i32.to_be_bytes());
        buf.extend_from_slice(b"hi");
        buf.extend_from_slice(&(-1i32).to_be_bytes());

        let (msg, consumed) = decode_message(&mut buf).unwrap();
        assert_eq!(consumed, 17); // 1 tag + 16 declared
        match msg {
            BackendMessage::DataRow(fields) => {
                assert_eq!(fields.len(), 2);
                assert_eq!(fields[0].as_deref(), Some(&b"hi"[..]));
                assert!(fields[1].is_none(), "length -1 must decode as NULL");
            }
            _ => panic!("expected DataRow"),
        }
    }

    #[test]
    fn test_data_row_field_count_exceeds_max_is_rejected() {
        // MAX_FIELD_COUNT + 1 = 2049 fields — must trigger the guard before
        // any field data is read.
        let count: i16 = (MAX_FIELD_COUNT + 1) as i16; // 2049
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"D");
        // body = 2 (count) + 4 (padding); length field includes itself: 2+4+4 = 10
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(&count.to_be_bytes());
        buf.extend_from_slice(&[0u8; 4]);

        let result = decode_message(&mut buf);
        assert!(result.is_err(), "DataRow with 2049 fields must be rejected");
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(msg.contains("2048"), "error must mention the limit: {msg}");
    }

    #[test]
    fn test_row_description_field_count_exceeds_max_is_rejected() {
        let count: i16 = (MAX_FIELD_COUNT + 1) as i16; // 2049
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"T");
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(&count.to_be_bytes());
        buf.extend_from_slice(&[0u8; 4]);

        let result = decode_message(&mut buf);
        assert!(
            result.is_err(),
            "RowDescription with 2049 fields must be rejected"
        );
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(msg.contains("2048"), "error must mention the limit: {msg}");
    }

    #[test]
    fn test_row_description_small_field_count_accepted() {
        let mut buf = make_row_description_with_count(3);
        let result = decode_message(&mut buf);
        assert!(
            result.is_ok(),
            "3-field RowDescription must be accepted: {result:?}"
        );
        let (msg, _) = result.unwrap();
        assert!(matches!(msg, BackendMessage::RowDescription(fields) if fields.len() == 3));
    }

    // ── Error-field size cap tests (S21-H1) ───────────────────────────────────

    fn make_error_response(field_type: u8, field_value: &[u8]) -> BytesMut {
        // ErrorResponse: tag 'E', length (4 bytes), then fields.
        // Each field: 1-byte type + value bytes + null terminator, then a final null byte.
        let body_len = 1 + field_value.len() + 1 + 1; // type + value + null + terminator
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"E");
        buf.extend_from_slice(&(body_len as u32 + 4).to_be_bytes());
        buf.extend_from_slice(&[field_type]);
        buf.extend_from_slice(field_value);
        buf.extend_from_slice(&[0]); // null terminator for field value
        buf.extend_from_slice(&[0]); // terminating null byte
        buf
    }

    #[test]
    fn error_field_within_limit_is_accepted() {
        let value = vec![b'x'; 1024]; // 1 KiB — well within 64 KiB limit
        let mut buf = make_error_response(b'M', &value);
        let result = decode_message(&mut buf);
        assert!(
            result.is_ok(),
            "small error field must be accepted: {result:?}"
        );
    }

    #[test]
    fn error_message_field_is_mapped_by_type_code() {
        let mut buf = make_error_response(b'M', b"boom");
        let (msg, _) = decode_message(&mut buf).unwrap();
        match msg {
            BackendMessage::ErrorResponse(fields) => {
                assert_eq!(fields.message.as_deref(), Some("boom"));
            }
            _ => panic!("expected ErrorResponse"),
        }
    }

    #[test]
    fn error_field_exceeding_limit_is_rejected() {
        let value = vec![b'x'; MAX_ERROR_FIELD_BYTES + 1]; // one byte over the cap
        let mut buf = make_error_response(b'M', &value);
        let result = decode_message(&mut buf);
        assert!(result.is_err(), "oversized error field must be rejected");
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("too large") || msg.contains("65536"),
            "error must mention size limit: {msg}"
        );
    }

    // ── SASL mechanism cap tests (S21-H2) ─────────────────────────────────────

    fn make_sasl_auth(mechanisms: &[&str]) -> BytesMut {
        // Authentication SASL: tag 'R', length, auth type (SASL), mechanism list.
        let mut mechanism_bytes: Vec<u8> = Vec::new();
        for m in mechanisms {
            mechanism_bytes.extend_from_slice(m.as_bytes());
            mechanism_bytes.push(0);
        }
        mechanism_bytes.push(0); // empty name terminates the list
        let body_len = 4 + mechanism_bytes.len(); // auth type (4) + mechanisms
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"R");
        buf.extend_from_slice(&(body_len as u32 + 4).to_be_bytes());
        buf.extend_from_slice(&auth::SASL.to_be_bytes()); // SASL auth type
        buf.extend_from_slice(&mechanism_bytes);
        buf
    }

    /// Extract the mechanism list from a decoded SASL Authentication message,
    /// failing the test on any other outcome.
    fn expect_sasl_mechanisms(result: io::Result<(BackendMessage, usize)>) -> Vec<String> {
        match result {
            Ok((BackendMessage::Authentication(AuthenticationMessage::Sasl { mechanisms }), _)) => {
                mechanisms
            }
            Ok(_) => panic!("expected Authentication::Sasl"),
            Err(e) => panic!("SASL message must parse successfully: {e}"),
        }
    }

    #[test]
    fn sasl_mechanisms_within_limit_are_accepted() {
        let mechanisms: Vec<&str> = (0..MAX_SASL_MECHANISMS).map(|_| "SCRAM-SHA-256").collect();
        let mut buf = make_sasl_auth(&mechanisms);
        let parsed = expect_sasl_mechanisms(decode_message(&mut buf));
        assert_eq!(
            parsed.len(),
            MAX_SASL_MECHANISMS,
            "SASL with {MAX_SASL_MECHANISMS} mechanisms must keep all of them"
        );
        assert!(parsed.iter().all(|m| m == "SCRAM-SHA-256"));
    }

    #[test]
    fn sasl_mechanisms_exceeding_limit_are_truncated_not_rejected() {
        // The guard stops collecting rather than erroring; verify the message still
        // parses and keeps exactly MAX_SASL_MECHANISMS entries.
        let mechanisms: Vec<&str> = (0..MAX_SASL_MECHANISMS + 5)
            .map(|_| "SCRAM-SHA-256")
            .collect();
        let mut buf = make_sasl_auth(&mechanisms);
        let parsed = expect_sasl_mechanisms(decode_message(&mut buf));
        assert_eq!(
            parsed.len(),
            MAX_SASL_MECHANISMS,
            "parsed mechanisms must be truncated to the cap"
        );
    }

    // ── Parameter name/value cap tests (S21-H3) ───────────────────────────────

    fn make_parameter_status(name: &[u8], value: &[u8]) -> BytesMut {
        let body_len = name.len() + 1 + value.len() + 1; // name + null + value + null
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"S");
        buf.extend_from_slice(&(body_len as u32 + 4).to_be_bytes());
        buf.extend_from_slice(name);
        buf.extend_from_slice(&[0]);
        buf.extend_from_slice(value);
        buf.extend_from_slice(&[0]);
        buf
    }

    #[test]
    fn parameter_status_normal_is_accepted() {
        let mut buf = make_parameter_status(b"server_version", b"16.0");
        let result = decode_message(&mut buf);
        assert!(
            result.is_ok(),
            "normal ParameterStatus must be accepted: {result:?}"
        );
    }

    #[test]
    fn parameter_name_exceeding_limit_is_rejected() {
        let long_name = vec![b'a'; MAX_PARAMETER_NAME_BYTES + 1];
        let mut buf = make_parameter_status(&long_name, b"value");
        let result = decode_message(&mut buf);
        assert!(result.is_err(), "oversized parameter name must be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("too long") || msg.contains("256"),
            "error must mention the name limit: {msg}"
        );
    }

    #[test]
    fn parameter_value_exceeding_limit_is_rejected() {
        let long_value = vec![b'v'; MAX_PARAMETER_VALUE_BYTES + 1];
        let mut buf = make_parameter_status(b"timezone", &long_value);
        let result = decode_message(&mut buf);
        assert!(
            result.is_err(),
            "oversized parameter value must be rejected"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("too long") || msg.contains("65536"),
            "error must mention the value limit: {msg}"
        );
    }
}