Skip to main content

fraiseql_wire/protocol/encode.rs

1//! Protocol message encoding
2
3use super::message::FrontendMessage;
4use bytes::{BufMut, BytesMut};
5use std::io;
6
7/// Encode a frontend message into bytes
8///
9/// # Errors
10///
11/// Returns `io::Error` if the message contains invalid data that cannot be encoded.
12pub fn encode_message(msg: &FrontendMessage) -> io::Result<BytesMut> {
13    let mut buf = BytesMut::new();
14
15    match msg {
16        FrontendMessage::Startup { version, params } => {
17            encode_startup(&mut buf, *version, params)?;
18        }
19        FrontendMessage::Password(password) => {
20            encode_password(&mut buf, password)?;
21        }
22        FrontendMessage::Query(query) => {
23            encode_query(&mut buf, query)?;
24        }
25        FrontendMessage::Terminate => {
26            encode_terminate(&mut buf)?;
27        }
28        FrontendMessage::SaslInitialResponse { mechanism, data } => {
29            encode_sasl_initial_response(&mut buf, mechanism, data)?;
30        }
31        FrontendMessage::SaslResponse { data } => {
32            encode_sasl_response(&mut buf, data)?;
33        }
34    }
35
36    Ok(buf)
37}
38
39fn encode_startup(buf: &mut BytesMut, version: i32, params: &[(String, String)]) -> io::Result<()> {
40    // Startup messages don't have a type byte
41    // Reserve space for length (will be filled at end)
42    let len_pos = buf.len();
43    buf.put_i32(0);
44
45    // Protocol version
46    buf.put_i32(version);
47
48    // Parameters (key-value pairs, null-terminated)
49    for (key, value) in params {
50        buf.put(key.as_bytes());
51        buf.put_u8(0);
52        buf.put(value.as_bytes());
53        buf.put_u8(0);
54    }
55
56    // Final null terminator
57    buf.put_u8(0);
58
59    // Fill in length
60    let len = buf.len() - len_pos;
61    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
62
63    Ok(())
64}
65
66fn encode_password(buf: &mut BytesMut, password: &str) -> io::Result<()> {
67    buf.put_u8(b'p');
68    let len_pos = buf.len();
69    buf.put_i32(0);
70
71    buf.put(password.as_bytes());
72    buf.put_u8(0);
73
74    let len = buf.len() - len_pos;
75    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
76
77    Ok(())
78}
79
80fn encode_query(buf: &mut BytesMut, query: &str) -> io::Result<()> {
81    buf.put_u8(b'Q');
82    let len_pos = buf.len();
83    buf.put_i32(0);
84
85    buf.put(query.as_bytes());
86    buf.put_u8(0);
87
88    let len = buf.len() - len_pos;
89    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
90
91    Ok(())
92}
93
94fn encode_terminate(buf: &mut BytesMut) -> io::Result<()> {
95    buf.put_u8(b'X');
96    buf.put_i32(4); // Length includes itself
97    Ok(())
98}
99
100fn encode_sasl_initial_response(
101    buf: &mut BytesMut,
102    mechanism: &str,
103    data: &[u8],
104) -> io::Result<()> {
105    buf.put_u8(b'p');
106    let len_pos = buf.len();
107    buf.put_i32(0);
108
109    // Mechanism name (null-terminated)
110    buf.put(mechanism.as_bytes());
111    buf.put_u8(0);
112
113    // SASL data (as length-prefixed bytes)
114    buf.put_i32(data.len() as i32);
115    buf.put_slice(data);
116
117    let len = buf.len() - len_pos;
118    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
119
120    Ok(())
121}
122
123fn encode_sasl_response(buf: &mut BytesMut, data: &[u8]) -> io::Result<()> {
124    buf.put_u8(b'p');
125    let len_pos = buf.len();
126    buf.put_i32(0);
127
128    // SASL data (length-prefixed)
129    buf.put_slice(data);
130
131    let len = buf.len() - len_pos;
132    buf[len_pos..len_pos + 4].copy_from_slice(&(len as i32).to_be_bytes());
133
134    Ok(())
135}
136
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
    use super::*;

    /// Read the big-endian `Int32` starting at byte offset `at`.
    fn read_i32(buf: &[u8], at: usize) -> i32 {
        i32::from_be_bytes([buf[at], buf[at + 1], buf[at + 2], buf[at + 3]])
    }

    #[test]
    fn test_encode_query() {
        let msg = FrontendMessage::Query("SELECT 1".to_string());
        let buf = encode_message(&msg).unwrap();

        assert_eq!(buf[0], b'Q');
        // The length counts itself but not the type byte.
        assert_eq!(read_i32(&buf, 1), (buf.len() - 1) as i32);
        assert_eq!(&buf[5..], b"SELECT 1\0");
    }

    #[test]
    fn test_encode_terminate() {
        let msg = FrontendMessage::Terminate;
        let buf = encode_message(&msg).unwrap();

        // Type byte followed by a length of 4 (the length field only).
        assert_eq!(&buf[..], &[b'X', 0, 0, 0, 4]);
    }

    #[test]
    fn test_encode_startup_layout() {
        let mut buf = BytesMut::new();
        let params = [("user".to_string(), "postgres".to_string())];
        encode_startup(&mut buf, 196_608, &params).unwrap();

        // No type byte: the length covers the whole packet.
        assert_eq!(read_i32(&buf, 0), buf.len() as i32);
        assert_eq!(read_i32(&buf, 4), 196_608);
        assert_eq!(&buf[8..], b"user\0postgres\0\0");
    }

    #[test]
    fn test_encode_sasl_initial_response_layout() {
        let mechanism = "SCRAM-SHA-256";
        let data = b"n,,n=,r=abc";
        let mut buf = BytesMut::new();
        encode_sasl_initial_response(&mut buf, mechanism, data).unwrap();

        assert_eq!(buf[0], b'p');
        assert_eq!(read_i32(&buf, 1), (buf.len() - 1) as i32);
        let mech_end = 5 + mechanism.len();
        assert_eq!(&buf[5..mech_end], mechanism.as_bytes());
        assert_eq!(buf[mech_end], 0);
        assert_eq!(read_i32(&buf, mech_end + 1), data.len() as i32);
        assert_eq!(&buf[mech_end + 5..], &data[..]);
    }

    #[test]
    fn test_encode_sasl_response_has_no_inner_length() {
        let mut buf = BytesMut::new();
        encode_sasl_response(&mut buf, b"c=biws").unwrap();

        assert_eq!(buf[0], b'p');
        assert_eq!(read_i32(&buf, 1), 4 + 6);
        assert_eq!(&buf[5..], b"c=biws");
    }
}