Skip to main content

fraiseql_wire/protocol/
encode.rs

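//! Protocol message encoding: serializes frontend (client-to-server) messages into their
//! PostgreSQL wire-protocol byte layout.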
2
3use super::message::FrontendMessage;
4use bytes::{BufMut, BytesMut};
5use std::io;
6
7/// Encode a frontend message into bytes
8pub fn encode_message(msg: &FrontendMessage) -> io::Result<BytesMut> {
9    let mut buf = BytesMut::new();
10
11    match msg {
12        FrontendMessage::Startup { version, params } => {
13            encode_startup(&mut buf, *version, params)?;
14        }
15        FrontendMessage::Password(password) => {
16            encode_password(&mut buf, password)?;
17        }
18        FrontendMessage::Query(query) => {
19            encode_query(&mut buf, query)?;
20        }
21        FrontendMessage::Terminate => {
22            encode_terminate(&mut buf)?;
23        }
24        FrontendMessage::SaslInitialResponse { mechanism, data } => {
25            encode_sasl_initial_response(&mut buf, mechanism, data)?;
26        }
27        FrontendMessage::SaslResponse { data } => {
28            encode_sasl_response(&mut buf, data)?;
29        }
30        FrontendMessage::SslRequest => {
31            encode_ssl_request(&mut buf)?;
32        }
33    }
34
35    Ok(buf)
36}
37
38fn encode_startup(buf: &mut BytesMut, version: i32, params: &[(String, String)]) -> io::Result<()> {
39    // Startup messages don't have a type byte
40    // Reserve space for length (will be filled at end)
41    let len_pos = buf.len();
42    buf.put_i32(0);
43
44    // Protocol version
45    buf.put_i32(version);
46
47    // Parameters (key-value pairs, null-terminated)
48    for (key, value) in params {
49        buf.put(key.as_bytes());
50        buf.put_u8(0);
51        buf.put(value.as_bytes());
52        buf.put_u8(0);
53    }
54
55    // Final null terminator
56    buf.put_u8(0);
57
58    // Fill in length
59    let len = buf.len() - len_pos;
60    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
61
62    Ok(())
63}
64
65fn encode_password(buf: &mut BytesMut, password: &str) -> io::Result<()> {
66    buf.put_u8(b'p');
67    let len_pos = buf.len();
68    buf.put_i32(0);
69
70    buf.put(password.as_bytes());
71    buf.put_u8(0);
72
73    let len = buf.len() - len_pos;
74    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
75
76    Ok(())
77}
78
79fn encode_query(buf: &mut BytesMut, query: &str) -> io::Result<()> {
80    buf.put_u8(b'Q');
81    let len_pos = buf.len();
82    buf.put_i32(0);
83
84    buf.put(query.as_bytes());
85    buf.put_u8(0);
86
87    let len = buf.len() - len_pos;
88    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
89
90    Ok(())
91}
92
93fn encode_terminate(buf: &mut BytesMut) -> io::Result<()> {
94    buf.put_u8(b'X');
95    buf.put_i32(4); // Length includes itself
96    Ok(())
97}
98
99fn encode_sasl_initial_response(
100    buf: &mut BytesMut,
101    mechanism: &str,
102    data: &[u8],
103) -> io::Result<()> {
104    buf.put_u8(b'p');
105    let len_pos = buf.len();
106    buf.put_i32(0);
107
108    // Mechanism name (null-terminated)
109    buf.put(mechanism.as_bytes());
110    buf.put_u8(0);
111
112    // SASL data (as length-prefixed bytes)
113    buf.put_i32(data.len() as i32);
114    buf.put_slice(data);
115
116    let len = buf.len() - len_pos;
117    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
118
119    Ok(())
120}
121
122fn encode_ssl_request(buf: &mut BytesMut) -> io::Result<()> {
123    buf.put_i32(8); // Length (includes itself)
124    buf.put_i32(super::constants::SSL_REQUEST_CODE);
125    Ok(())
126}
127
128fn encode_sasl_response(buf: &mut BytesMut, data: &[u8]) -> io::Result<()> {
129    buf.put_u8(b'p');
130    let len_pos = buf.len();
131    buf.put_i32(0);
132
133    // SASL data (length-prefixed)
134    buf.put_slice(data);
135
136    let len = buf.len() - len_pos;
137    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
138
139    Ok(())
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Read the big-endian `Int32` that starts at byte offset `at`.
    fn be_i32(buf: &[u8], at: usize) -> i32 {
        i32::from_be_bytes([buf[at], buf[at + 1], buf[at + 2], buf[at + 3]])
    }

    #[test]
    fn test_encode_query() {
        let msg = FrontendMessage::Query("SELECT 1".to_string());
        let buf = encode_message(&msg).unwrap();

        assert_eq!(buf[0], b'Q');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]);
        assert_eq!(len, (buf.len() - 1) as i32);
    }

    #[test]
    fn test_encode_terminate() {
        let msg = FrontendMessage::Terminate;
        let buf = encode_message(&msg).unwrap();

        assert_eq!(buf[0], b'X');
        assert_eq!(buf.len(), 5);
    }

    #[test]
    fn test_encode_ssl_request() {
        let msg = FrontendMessage::SslRequest;
        let buf = encode_message(&msg).unwrap();

        // SSLRequest is exactly 8 bytes: 4-byte length (8) + 4-byte code (80877103)
        assert_eq!(buf.len(), 8);
        // Length = 8 (big-endian)
        assert_eq!(&buf[0..4], &[0x00, 0x00, 0x00, 0x08]);
        // SSL request code = 80877103 = 0x04D2162F
        assert_eq!(&buf[4..8], &[0x04, 0xD2, 0x16, 0x2F]);
    }

    #[test]
    fn test_encode_password() {
        let msg = FrontendMessage::Password("secret".to_string());
        let buf = encode_message(&msg).unwrap();

        // 'p' + Int32 length (4 + 6 + 1 = 11) + "secret\0"
        assert_eq!(&buf[..], &b"p\0\0\0\x0bsecret\0"[..]);
    }

    #[test]
    fn test_encode_startup() {
        let msg = FrontendMessage::Startup {
            version: 0x0003_0000,
            params: vec![("user".to_string(), "alice".to_string())],
        };
        let buf = encode_message(&msg).unwrap();

        // No type byte: length (4) + version (4) + "user\0alice\0" (11) + final NUL (1)
        assert_eq!(buf.len(), 20);
        assert_eq!(be_i32(&buf, 0), 20);
        assert_eq!(be_i32(&buf, 4), 0x0003_0000);
        assert_eq!(&buf[8..], &b"user\0alice\0\0"[..]);
    }

    #[test]
    fn test_encode_sasl_initial_response() {
        let msg = FrontendMessage::SaslInitialResponse {
            mechanism: "SCRAM-SHA-256".to_string(),
            data: b"xyz".to_vec(),
        };
        let buf = encode_message(&msg).unwrap();

        // 'p' + length (4 + 14 + 4 + 3 = 25) + "SCRAM-SHA-256\0" + Int32 3 + "xyz"
        assert_eq!(&buf[..], &b"p\0\0\0\x19SCRAM-SHA-256\0\0\0\0\x03xyz"[..]);
    }

    #[test]
    fn test_encode_sasl_response() {
        let msg = FrontendMessage::SaslResponse {
            data: b"abc".to_vec(),
        };
        let buf = encode_message(&msg).unwrap();

        // 'p' + length (4 + 3 = 7) + raw data, no inner length prefix
        assert_eq!(&buf[..], &b"p\0\0\0\x07abc"[..]);
    }
}