Skip to main content

fraiseql_wire/protocol/
encode.rs

//! Protocol message encoding: serializes frontend messages into wire bytes.
2
3use super::message::FrontendMessage;
4use bytes::{BufMut, BytesMut};
5use std::io;
6
7/// Encode a frontend message into bytes
8pub fn encode_message(msg: &FrontendMessage) -> io::Result<BytesMut> {
9    let mut buf = BytesMut::new();
10
11    match msg {
12        FrontendMessage::Startup { version, params } => {
13            encode_startup(&mut buf, *version, params)?;
14        }
15        FrontendMessage::Password(password) => {
16            encode_password(&mut buf, password)?;
17        }
18        FrontendMessage::Query(query) => {
19            encode_query(&mut buf, query)?;
20        }
21        FrontendMessage::Terminate => {
22            encode_terminate(&mut buf)?;
23        }
24        FrontendMessage::SaslInitialResponse { mechanism, data } => {
25            encode_sasl_initial_response(&mut buf, mechanism, data)?;
26        }
27        FrontendMessage::SaslResponse { data } => {
28            encode_sasl_response(&mut buf, data)?;
29        }
30    }
31
32    Ok(buf)
33}
34
35fn encode_startup(buf: &mut BytesMut, version: i32, params: &[(String, String)]) -> io::Result<()> {
36    // Startup messages don't have a type byte
37    // Reserve space for length (will be filled at end)
38    let len_pos = buf.len();
39    buf.put_i32(0);
40
41    // Protocol version
42    buf.put_i32(version);
43
44    // Parameters (key-value pairs, null-terminated)
45    for (key, value) in params {
46        buf.put(key.as_bytes());
47        buf.put_u8(0);
48        buf.put(value.as_bytes());
49        buf.put_u8(0);
50    }
51
52    // Final null terminator
53    buf.put_u8(0);
54
55    // Fill in length
56    let len = buf.len() - len_pos;
57    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
58
59    Ok(())
60}
61
62fn encode_password(buf: &mut BytesMut, password: &str) -> io::Result<()> {
63    buf.put_u8(b'p');
64    let len_pos = buf.len();
65    buf.put_i32(0);
66
67    buf.put(password.as_bytes());
68    buf.put_u8(0);
69
70    let len = buf.len() - len_pos;
71    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
72
73    Ok(())
74}
75
76fn encode_query(buf: &mut BytesMut, query: &str) -> io::Result<()> {
77    buf.put_u8(b'Q');
78    let len_pos = buf.len();
79    buf.put_i32(0);
80
81    buf.put(query.as_bytes());
82    buf.put_u8(0);
83
84    let len = buf.len() - len_pos;
85    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
86
87    Ok(())
88}
89
90fn encode_terminate(buf: &mut BytesMut) -> io::Result<()> {
91    buf.put_u8(b'X');
92    buf.put_i32(4); // Length includes itself
93    Ok(())
94}
95
96fn encode_sasl_initial_response(
97    buf: &mut BytesMut,
98    mechanism: &str,
99    data: &[u8],
100) -> io::Result<()> {
101    buf.put_u8(b'p');
102    let len_pos = buf.len();
103    buf.put_i32(0);
104
105    // Mechanism name (null-terminated)
106    buf.put(mechanism.as_bytes());
107    buf.put_u8(0);
108
109    // SASL data (as length-prefixed bytes)
110    buf.put_i32(data.len() as i32);
111    buf.put_slice(data);
112
113    let len = buf.len() - len_pos;
114    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
115
116    Ok(())
117}
118
119fn encode_sasl_response(buf: &mut BytesMut, data: &[u8]) -> io::Result<()> {
120    buf.put_u8(b'p');
121    let len_pos = buf.len();
122    buf.put_i32(0);
123
124    // SASL data (length-prefixed)
125    buf.put_slice(data);
126
127    let len = buf.len() - len_pos;
128    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
129
130    Ok(())
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes the big-endian length field that follows a message's one-byte tag.
    fn declared_len(buf: &[u8]) -> i32 {
        i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]])
    }

    #[test]
    fn test_encode_query() {
        let msg = FrontendMessage::Query("SELECT 1".to_string());
        let buf = encode_message(&msg).unwrap();

        assert_eq!(buf[0], b'Q');
        // The length counts everything except the tag byte.
        assert_eq!(declared_len(&buf), (buf.len() - 1) as i32);
        // Body is the query text followed by exactly one NUL terminator.
        assert_eq!(&buf[5..], b"SELECT 1\0");
    }

    #[test]
    fn test_encode_empty_query() {
        let msg = FrontendMessage::Query(String::new());
        let buf = encode_message(&msg).unwrap();

        // Tag, length 5 (length field + NUL), then the lone terminator.
        assert_eq!(&buf[..], &[b'Q', 0, 0, 0, 5, 0]);
    }

    #[test]
    fn test_encode_terminate() {
        let msg = FrontendMessage::Terminate;
        let buf = encode_message(&msg).unwrap();

        // Tag plus a length of 4 (the length field itself) and no body.
        assert_eq!(buf.len(), 5);
        assert_eq!(&buf[..], &[b'X', 0, 0, 0, 4]);
    }
}