Skip to main content

qail_pg/protocol/wire/
frontend.rs

//! FrontendMessage encoder — client-to-server wire format.
2
3use super::types::*;
4
5impl FrontendMessage {
6    #[inline]
7    fn has_nul(s: &str) -> bool {
8        s.as_bytes().contains(&0)
9    }
10
11    #[inline]
12    fn content_len_to_wire_len(content_len: usize) -> Result<i32, FrontendEncodeError> {
13        let total = content_len
14            .checked_add(4)
15            .ok_or(FrontendEncodeError::MessageTooLarge(usize::MAX))?;
16        i32::try_from(total).map_err(|_| FrontendEncodeError::MessageTooLarge(total))
17    }
18
19    /// Fallible encoder that returns explicit reason on invalid input.
20    pub fn encode_checked(&self) -> Result<Vec<u8>, FrontendEncodeError> {
21        match self {
22            FrontendMessage::Startup {
23                user,
24                database,
25                startup_params,
26            } => {
27                if Self::has_nul(user) {
28                    return Err(FrontendEncodeError::InteriorNul("user"));
29                }
30                if Self::has_nul(database) {
31                    return Err(FrontendEncodeError::InteriorNul("database"));
32                }
33                let mut seen_startup_keys = std::collections::HashSet::new();
34                let mut buf = Vec::new();
35                buf.extend_from_slice(&196608i32.to_be_bytes());
36                buf.extend_from_slice(b"user\0");
37                buf.extend_from_slice(user.as_bytes());
38                buf.push(0);
39                buf.extend_from_slice(b"database\0");
40                buf.extend_from_slice(database.as_bytes());
41                buf.push(0);
42                for (key, value) in startup_params {
43                    let key_trimmed = key.trim();
44                    if key_trimmed.is_empty() {
45                        return Err(FrontendEncodeError::InvalidStartupParam(
46                            "key must not be empty".to_string(),
47                        ));
48                    }
49                    let key_lc = key_trimmed.to_ascii_lowercase();
50                    if key_lc == "user" || key_lc == "database" {
51                        return Err(FrontendEncodeError::InvalidStartupParam(format!(
52                            "reserved key '{}'",
53                            key_trimmed
54                        )));
55                    }
56                    if !seen_startup_keys.insert(key_lc) {
57                        return Err(FrontendEncodeError::InvalidStartupParam(format!(
58                            "duplicate key '{}'",
59                            key_trimmed
60                        )));
61                    }
62                    if Self::has_nul(key) {
63                        return Err(FrontendEncodeError::InteriorNul("startup_param_key"));
64                    }
65                    if Self::has_nul(value) {
66                        return Err(FrontendEncodeError::InteriorNul("startup_param_value"));
67                    }
68                    buf.extend_from_slice(key.as_bytes());
69                    buf.push(0);
70                    buf.extend_from_slice(value.as_bytes());
71                    buf.push(0);
72                }
73                buf.push(0);
74
75                let len = Self::content_len_to_wire_len(buf.len())?;
76                let mut result = len.to_be_bytes().to_vec();
77                result.extend(buf);
78                Ok(result)
79            }
80            FrontendMessage::Query(sql) => {
81                if Self::has_nul(sql) {
82                    return Err(FrontendEncodeError::InteriorNul("sql"));
83                }
84                let mut buf = Vec::new();
85                buf.push(b'Q');
86                let mut content = Vec::with_capacity(sql.len() + 1);
87                content.extend_from_slice(sql.as_bytes());
88                content.push(0);
89                let len = Self::content_len_to_wire_len(content.len())?;
90                buf.extend_from_slice(&len.to_be_bytes());
91                buf.extend_from_slice(&content);
92                Ok(buf)
93            }
94            FrontendMessage::Terminate => Ok(vec![b'X', 0, 0, 0, 4]),
95            FrontendMessage::SASLInitialResponse { mechanism, data } => {
96                if Self::has_nul(mechanism) {
97                    return Err(FrontendEncodeError::InteriorNul("mechanism"));
98                }
99                if data.len() > i32::MAX as usize {
100                    return Err(FrontendEncodeError::MessageTooLarge(data.len()));
101                }
102                let mut buf = Vec::new();
103                buf.push(b'p');
104
105                let mut content = Vec::new();
106                content.extend_from_slice(mechanism.as_bytes());
107                content.push(0);
108                let data_len = i32::try_from(data.len())
109                    .map_err(|_| FrontendEncodeError::MessageTooLarge(data.len()))?;
110                content.extend_from_slice(&data_len.to_be_bytes());
111                content.extend_from_slice(data);
112
113                let len = Self::content_len_to_wire_len(content.len())?;
114                buf.extend_from_slice(&len.to_be_bytes());
115                buf.extend_from_slice(&content);
116                Ok(buf)
117            }
118            FrontendMessage::SASLResponse(data) | FrontendMessage::GSSResponse(data) => {
119                if data.len() > i32::MAX as usize {
120                    return Err(FrontendEncodeError::MessageTooLarge(data.len()));
121                }
122                let mut buf = Vec::new();
123                buf.push(b'p');
124                let len = Self::content_len_to_wire_len(data.len())?;
125                buf.extend_from_slice(&len.to_be_bytes());
126                buf.extend_from_slice(data);
127                Ok(buf)
128            }
129            FrontendMessage::PasswordMessage(password) => {
130                if Self::has_nul(password) {
131                    return Err(FrontendEncodeError::InteriorNul("password"));
132                }
133                let mut buf = Vec::new();
134                buf.push(b'p');
135                let mut content = Vec::with_capacity(password.len() + 1);
136                content.extend_from_slice(password.as_bytes());
137                content.push(0);
138                let len = Self::content_len_to_wire_len(content.len())?;
139                buf.extend_from_slice(&len.to_be_bytes());
140                buf.extend_from_slice(&content);
141                Ok(buf)
142            }
143            FrontendMessage::Parse {
144                name,
145                query,
146                param_types,
147            } => {
148                if Self::has_nul(name) {
149                    return Err(FrontendEncodeError::InteriorNul("name"));
150                }
151                if Self::has_nul(query) {
152                    return Err(FrontendEncodeError::InteriorNul("query"));
153                }
154                if param_types.len() > i16::MAX as usize {
155                    return Err(FrontendEncodeError::TooManyParams(param_types.len()));
156                }
157                let mut buf = Vec::new();
158                buf.push(b'P');
159
160                let mut content = Vec::new();
161                content.extend_from_slice(name.as_bytes());
162                content.push(0);
163                content.extend_from_slice(query.as_bytes());
164                content.push(0);
165                let param_count = i16::try_from(param_types.len())
166                    .map_err(|_| FrontendEncodeError::TooManyParams(param_types.len()))?;
167                content.extend_from_slice(&param_count.to_be_bytes());
168                for oid in param_types {
169                    content.extend_from_slice(&oid.to_be_bytes());
170                }
171
172                let len = Self::content_len_to_wire_len(content.len())?;
173                buf.extend_from_slice(&len.to_be_bytes());
174                buf.extend_from_slice(&content);
175                Ok(buf)
176            }
177            FrontendMessage::Bind {
178                portal,
179                statement,
180                params,
181            } => {
182                if Self::has_nul(portal) {
183                    return Err(FrontendEncodeError::InteriorNul("portal"));
184                }
185                if Self::has_nul(statement) {
186                    return Err(FrontendEncodeError::InteriorNul("statement"));
187                }
188                if params.len() > i16::MAX as usize {
189                    return Err(FrontendEncodeError::TooManyParams(params.len()));
190                }
191                if let Some(too_large) = params
192                    .iter()
193                    .flatten()
194                    .find(|p| p.len() > i32::MAX as usize)
195                {
196                    return Err(FrontendEncodeError::MessageTooLarge(too_large.len()));
197                }
198
199                let mut buf = Vec::new();
200                buf.push(b'B');
201
202                let mut content = Vec::new();
203                content.extend_from_slice(portal.as_bytes());
204                content.push(0);
205                content.extend_from_slice(statement.as_bytes());
206                content.push(0);
207                content.extend_from_slice(&0i16.to_be_bytes());
208                let param_count = i16::try_from(params.len())
209                    .map_err(|_| FrontendEncodeError::TooManyParams(params.len()))?;
210                content.extend_from_slice(&param_count.to_be_bytes());
211                for param in params {
212                    match param {
213                        Some(data) => {
214                            let data_len = i32::try_from(data.len())
215                                .map_err(|_| FrontendEncodeError::MessageTooLarge(data.len()))?;
216                            content.extend_from_slice(&data_len.to_be_bytes());
217                            content.extend_from_slice(data);
218                        }
219                        None => content.extend_from_slice(&(-1i32).to_be_bytes()),
220                    }
221                }
222                content.extend_from_slice(&0i16.to_be_bytes());
223
224                let len = Self::content_len_to_wire_len(content.len())?;
225                buf.extend_from_slice(&len.to_be_bytes());
226                buf.extend_from_slice(&content);
227                Ok(buf)
228            }
229            FrontendMessage::Execute { portal, max_rows } => {
230                if Self::has_nul(portal) {
231                    return Err(FrontendEncodeError::InteriorNul("portal"));
232                }
233                if *max_rows < 0 {
234                    return Err(FrontendEncodeError::InvalidMaxRows(*max_rows));
235                }
236                let mut buf = Vec::new();
237                buf.push(b'E');
238                let mut content = Vec::new();
239                content.extend_from_slice(portal.as_bytes());
240                content.push(0);
241                content.extend_from_slice(&max_rows.to_be_bytes());
242                let len = Self::content_len_to_wire_len(content.len())?;
243                buf.extend_from_slice(&len.to_be_bytes());
244                buf.extend_from_slice(&content);
245                Ok(buf)
246            }
247            FrontendMessage::Sync => Ok(vec![b'S', 0, 0, 0, 4]),
248            FrontendMessage::CopyFail(msg) => {
249                if Self::has_nul(msg) {
250                    return Err(FrontendEncodeError::InteriorNul("copy_fail"));
251                }
252                let mut buf = Vec::new();
253                buf.push(b'f');
254                let mut content = Vec::with_capacity(msg.len() + 1);
255                content.extend_from_slice(msg.as_bytes());
256                content.push(0);
257                let len = Self::content_len_to_wire_len(content.len())?;
258                buf.extend_from_slice(&len.to_be_bytes());
259                buf.extend_from_slice(&content);
260                Ok(buf)
261            }
262            FrontendMessage::Close { is_portal, name } => {
263                if Self::has_nul(name) {
264                    return Err(FrontendEncodeError::InteriorNul("name"));
265                }
266                let mut buf = Vec::new();
267                buf.push(b'C');
268                let type_byte = if *is_portal { b'P' } else { b'S' };
269                let mut content = vec![type_byte];
270                content.extend_from_slice(name.as_bytes());
271                content.push(0);
272                let len = Self::content_len_to_wire_len(content.len())?;
273                buf.extend_from_slice(&len.to_be_bytes());
274                buf.extend_from_slice(&content);
275                Ok(buf)
276            }
277        }
278    }
279}