// qail_pg/protocol/wire/frontend.rs
use super::types::*;
5impl FrontendMessage {
6 #[inline]
7 fn has_nul(s: &str) -> bool {
8 s.as_bytes().contains(&0)
9 }
10
11 #[inline]
12 fn content_len_to_wire_len(content_len: usize) -> Result<i32, FrontendEncodeError> {
13 let total = content_len
14 .checked_add(4)
15 .ok_or(FrontendEncodeError::MessageTooLarge(usize::MAX))?;
16 i32::try_from(total).map_err(|_| FrontendEncodeError::MessageTooLarge(total))
17 }
18
19 pub fn encode_checked(&self) -> Result<Vec<u8>, FrontendEncodeError> {
21 match self {
22 FrontendMessage::Startup {
23 user,
24 database,
25 protocol_version,
26 startup_params,
27 } => {
28 if Self::has_nul(user) {
29 return Err(FrontendEncodeError::InteriorNul("user"));
30 }
31 if Self::has_nul(database) {
32 return Err(FrontendEncodeError::InteriorNul("database"));
33 }
34 let mut seen_startup_keys = std::collections::HashSet::new();
35 let mut buf = Vec::new();
36 buf.extend_from_slice(&protocol_version.to_be_bytes());
37 buf.extend_from_slice(b"user\0");
38 buf.extend_from_slice(user.as_bytes());
39 buf.push(0);
40 buf.extend_from_slice(b"database\0");
41 buf.extend_from_slice(database.as_bytes());
42 buf.push(0);
43 for (key, value) in startup_params {
44 let key_trimmed = key.trim();
45 if key_trimmed.is_empty() {
46 return Err(FrontendEncodeError::InvalidStartupParam(
47 "key must not be empty".to_string(),
48 ));
49 }
50 let key_lc = key_trimmed.to_ascii_lowercase();
51 if key_lc == "user" || key_lc == "database" {
52 return Err(FrontendEncodeError::InvalidStartupParam(format!(
53 "reserved key '{}'",
54 key_trimmed
55 )));
56 }
57 if !seen_startup_keys.insert(key_lc) {
58 return Err(FrontendEncodeError::InvalidStartupParam(format!(
59 "duplicate key '{}'",
60 key_trimmed
61 )));
62 }
63 if Self::has_nul(key) {
64 return Err(FrontendEncodeError::InteriorNul("startup_param_key"));
65 }
66 if Self::has_nul(value) {
67 return Err(FrontendEncodeError::InteriorNul("startup_param_value"));
68 }
69 buf.extend_from_slice(key.as_bytes());
70 buf.push(0);
71 buf.extend_from_slice(value.as_bytes());
72 buf.push(0);
73 }
74 buf.push(0);
75
76 let len = Self::content_len_to_wire_len(buf.len())?;
77 let mut result = len.to_be_bytes().to_vec();
78 result.extend(buf);
79 Ok(result)
80 }
81 FrontendMessage::Query(sql) => {
82 if Self::has_nul(sql) {
83 return Err(FrontendEncodeError::InteriorNul("sql"));
84 }
85 let mut buf = Vec::new();
86 buf.push(b'Q');
87 let mut content = Vec::with_capacity(sql.len() + 1);
88 content.extend_from_slice(sql.as_bytes());
89 content.push(0);
90 let len = Self::content_len_to_wire_len(content.len())?;
91 buf.extend_from_slice(&len.to_be_bytes());
92 buf.extend_from_slice(&content);
93 Ok(buf)
94 }
95 FrontendMessage::Terminate => Ok(vec![b'X', 0, 0, 0, 4]),
96 FrontendMessage::SASLInitialResponse { mechanism, data } => {
97 if Self::has_nul(mechanism) {
98 return Err(FrontendEncodeError::InteriorNul("mechanism"));
99 }
100 if data.len() > i32::MAX as usize {
101 return Err(FrontendEncodeError::MessageTooLarge(data.len()));
102 }
103 let mut buf = Vec::new();
104 buf.push(b'p');
105
106 let mut content = Vec::new();
107 content.extend_from_slice(mechanism.as_bytes());
108 content.push(0);
109 let data_len = i32::try_from(data.len())
110 .map_err(|_| FrontendEncodeError::MessageTooLarge(data.len()))?;
111 content.extend_from_slice(&data_len.to_be_bytes());
112 content.extend_from_slice(data);
113
114 let len = Self::content_len_to_wire_len(content.len())?;
115 buf.extend_from_slice(&len.to_be_bytes());
116 buf.extend_from_slice(&content);
117 Ok(buf)
118 }
119 FrontendMessage::SASLResponse(data) | FrontendMessage::GSSResponse(data) => {
120 if data.len() > i32::MAX as usize {
121 return Err(FrontendEncodeError::MessageTooLarge(data.len()));
122 }
123 let mut buf = Vec::new();
124 buf.push(b'p');
125 let len = Self::content_len_to_wire_len(data.len())?;
126 buf.extend_from_slice(&len.to_be_bytes());
127 buf.extend_from_slice(data);
128 Ok(buf)
129 }
130 FrontendMessage::PasswordMessage(password) => {
131 if Self::has_nul(password) {
132 return Err(FrontendEncodeError::InteriorNul("password"));
133 }
134 let mut buf = Vec::new();
135 buf.push(b'p');
136 let mut content = Vec::with_capacity(password.len() + 1);
137 content.extend_from_slice(password.as_bytes());
138 content.push(0);
139 let len = Self::content_len_to_wire_len(content.len())?;
140 buf.extend_from_slice(&len.to_be_bytes());
141 buf.extend_from_slice(&content);
142 Ok(buf)
143 }
144 FrontendMessage::Parse {
145 name,
146 query,
147 param_types,
148 } => {
149 if Self::has_nul(name) {
150 return Err(FrontendEncodeError::InteriorNul("name"));
151 }
152 if Self::has_nul(query) {
153 return Err(FrontendEncodeError::InteriorNul("query"));
154 }
155 if param_types.len() > i16::MAX as usize {
156 return Err(FrontendEncodeError::TooManyParams(param_types.len()));
157 }
158 let mut buf = Vec::new();
159 buf.push(b'P');
160
161 let mut content = Vec::new();
162 content.extend_from_slice(name.as_bytes());
163 content.push(0);
164 content.extend_from_slice(query.as_bytes());
165 content.push(0);
166 let param_count = i16::try_from(param_types.len())
167 .map_err(|_| FrontendEncodeError::TooManyParams(param_types.len()))?;
168 content.extend_from_slice(¶m_count.to_be_bytes());
169 for oid in param_types {
170 content.extend_from_slice(&oid.to_be_bytes());
171 }
172
173 let len = Self::content_len_to_wire_len(content.len())?;
174 buf.extend_from_slice(&len.to_be_bytes());
175 buf.extend_from_slice(&content);
176 Ok(buf)
177 }
178 FrontendMessage::Bind {
179 portal,
180 statement,
181 params,
182 } => {
183 if Self::has_nul(portal) {
184 return Err(FrontendEncodeError::InteriorNul("portal"));
185 }
186 if Self::has_nul(statement) {
187 return Err(FrontendEncodeError::InteriorNul("statement"));
188 }
189 if params.len() > i16::MAX as usize {
190 return Err(FrontendEncodeError::TooManyParams(params.len()));
191 }
192 if let Some(too_large) = params
193 .iter()
194 .flatten()
195 .find(|p| p.len() > i32::MAX as usize)
196 {
197 return Err(FrontendEncodeError::MessageTooLarge(too_large.len()));
198 }
199
200 let mut buf = Vec::new();
201 buf.push(b'B');
202
203 let mut content = Vec::new();
204 content.extend_from_slice(portal.as_bytes());
205 content.push(0);
206 content.extend_from_slice(statement.as_bytes());
207 content.push(0);
208 content.extend_from_slice(&0i16.to_be_bytes());
209 let param_count = i16::try_from(params.len())
210 .map_err(|_| FrontendEncodeError::TooManyParams(params.len()))?;
211 content.extend_from_slice(¶m_count.to_be_bytes());
212 for param in params {
213 match param {
214 Some(data) => {
215 let data_len = i32::try_from(data.len())
216 .map_err(|_| FrontendEncodeError::MessageTooLarge(data.len()))?;
217 content.extend_from_slice(&data_len.to_be_bytes());
218 content.extend_from_slice(data);
219 }
220 None => content.extend_from_slice(&(-1i32).to_be_bytes()),
221 }
222 }
223 content.extend_from_slice(&0i16.to_be_bytes());
224
225 let len = Self::content_len_to_wire_len(content.len())?;
226 buf.extend_from_slice(&len.to_be_bytes());
227 buf.extend_from_slice(&content);
228 Ok(buf)
229 }
230 FrontendMessage::Execute { portal, max_rows } => {
231 if Self::has_nul(portal) {
232 return Err(FrontendEncodeError::InteriorNul("portal"));
233 }
234 if *max_rows < 0 {
235 return Err(FrontendEncodeError::InvalidMaxRows(*max_rows));
236 }
237 let mut buf = Vec::new();
238 buf.push(b'E');
239 let mut content = Vec::new();
240 content.extend_from_slice(portal.as_bytes());
241 content.push(0);
242 content.extend_from_slice(&max_rows.to_be_bytes());
243 let len = Self::content_len_to_wire_len(content.len())?;
244 buf.extend_from_slice(&len.to_be_bytes());
245 buf.extend_from_slice(&content);
246 Ok(buf)
247 }
248 FrontendMessage::Sync => Ok(vec![b'S', 0, 0, 0, 4]),
249 FrontendMessage::CopyFail(msg) => {
250 if Self::has_nul(msg) {
251 return Err(FrontendEncodeError::InteriorNul("copy_fail"));
252 }
253 let mut buf = Vec::new();
254 buf.push(b'f');
255 let mut content = Vec::with_capacity(msg.len() + 1);
256 content.extend_from_slice(msg.as_bytes());
257 content.push(0);
258 let len = Self::content_len_to_wire_len(content.len())?;
259 buf.extend_from_slice(&len.to_be_bytes());
260 buf.extend_from_slice(&content);
261 Ok(buf)
262 }
263 FrontendMessage::Close { is_portal, name } => {
264 if Self::has_nul(name) {
265 return Err(FrontendEncodeError::InteriorNul("name"));
266 }
267 let mut buf = Vec::new();
268 buf.push(b'C');
269 let type_byte = if *is_portal { b'P' } else { b'S' };
270 let mut content = vec![type_byte];
271 content.extend_from_slice(name.as_bytes());
272 content.push(0);
273 let len = Self::content_len_to_wire_len(content.len())?;
274 buf.extend_from_slice(&len.to_be_bytes());
275 buf.extend_from_slice(&content);
276 Ok(buf)
277 }
278 }
279 }
280}