qail_pg/protocol/wire/
frontend.rs1use super::types::*;
4
5impl FrontendMessage {
6 #[inline]
7 fn has_nul(s: &str) -> bool {
8 s.as_bytes().contains(&0)
9 }
10
11 #[inline]
12 fn content_len_to_wire_len(content_len: usize) -> Result<i32, FrontendEncodeError> {
13 let total = content_len
14 .checked_add(4)
15 .ok_or(FrontendEncodeError::MessageTooLarge(usize::MAX))?;
16 i32::try_from(total).map_err(|_| FrontendEncodeError::MessageTooLarge(total))
17 }
18
19 pub fn encode_checked(&self) -> Result<Vec<u8>, FrontendEncodeError> {
21 match self {
22 FrontendMessage::Startup {
23 user,
24 database,
25 startup_params,
26 } => {
27 if Self::has_nul(user) {
28 return Err(FrontendEncodeError::InteriorNul("user"));
29 }
30 if Self::has_nul(database) {
31 return Err(FrontendEncodeError::InteriorNul("database"));
32 }
33 let mut seen_startup_keys = std::collections::HashSet::new();
34 let mut buf = Vec::new();
35 buf.extend_from_slice(&196608i32.to_be_bytes());
36 buf.extend_from_slice(b"user\0");
37 buf.extend_from_slice(user.as_bytes());
38 buf.push(0);
39 buf.extend_from_slice(b"database\0");
40 buf.extend_from_slice(database.as_bytes());
41 buf.push(0);
42 for (key, value) in startup_params {
43 let key_trimmed = key.trim();
44 if key_trimmed.is_empty() {
45 return Err(FrontendEncodeError::InvalidStartupParam(
46 "key must not be empty".to_string(),
47 ));
48 }
49 let key_lc = key_trimmed.to_ascii_lowercase();
50 if key_lc == "user" || key_lc == "database" {
51 return Err(FrontendEncodeError::InvalidStartupParam(format!(
52 "reserved key '{}'",
53 key_trimmed
54 )));
55 }
56 if !seen_startup_keys.insert(key_lc) {
57 return Err(FrontendEncodeError::InvalidStartupParam(format!(
58 "duplicate key '{}'",
59 key_trimmed
60 )));
61 }
62 if Self::has_nul(key) {
63 return Err(FrontendEncodeError::InteriorNul("startup_param_key"));
64 }
65 if Self::has_nul(value) {
66 return Err(FrontendEncodeError::InteriorNul("startup_param_value"));
67 }
68 buf.extend_from_slice(key.as_bytes());
69 buf.push(0);
70 buf.extend_from_slice(value.as_bytes());
71 buf.push(0);
72 }
73 buf.push(0);
74
75 let len = Self::content_len_to_wire_len(buf.len())?;
76 let mut result = len.to_be_bytes().to_vec();
77 result.extend(buf);
78 Ok(result)
79 }
80 FrontendMessage::Query(sql) => {
81 if Self::has_nul(sql) {
82 return Err(FrontendEncodeError::InteriorNul("sql"));
83 }
84 let mut buf = Vec::new();
85 buf.push(b'Q');
86 let mut content = Vec::with_capacity(sql.len() + 1);
87 content.extend_from_slice(sql.as_bytes());
88 content.push(0);
89 let len = Self::content_len_to_wire_len(content.len())?;
90 buf.extend_from_slice(&len.to_be_bytes());
91 buf.extend_from_slice(&content);
92 Ok(buf)
93 }
94 FrontendMessage::Terminate => Ok(vec![b'X', 0, 0, 0, 4]),
95 FrontendMessage::SASLInitialResponse { mechanism, data } => {
96 if Self::has_nul(mechanism) {
97 return Err(FrontendEncodeError::InteriorNul("mechanism"));
98 }
99 if data.len() > i32::MAX as usize {
100 return Err(FrontendEncodeError::MessageTooLarge(data.len()));
101 }
102 let mut buf = Vec::new();
103 buf.push(b'p');
104
105 let mut content = Vec::new();
106 content.extend_from_slice(mechanism.as_bytes());
107 content.push(0);
108 let data_len = i32::try_from(data.len())
109 .map_err(|_| FrontendEncodeError::MessageTooLarge(data.len()))?;
110 content.extend_from_slice(&data_len.to_be_bytes());
111 content.extend_from_slice(data);
112
113 let len = Self::content_len_to_wire_len(content.len())?;
114 buf.extend_from_slice(&len.to_be_bytes());
115 buf.extend_from_slice(&content);
116 Ok(buf)
117 }
118 FrontendMessage::SASLResponse(data) | FrontendMessage::GSSResponse(data) => {
119 if data.len() > i32::MAX as usize {
120 return Err(FrontendEncodeError::MessageTooLarge(data.len()));
121 }
122 let mut buf = Vec::new();
123 buf.push(b'p');
124 let len = Self::content_len_to_wire_len(data.len())?;
125 buf.extend_from_slice(&len.to_be_bytes());
126 buf.extend_from_slice(data);
127 Ok(buf)
128 }
129 FrontendMessage::PasswordMessage(password) => {
130 if Self::has_nul(password) {
131 return Err(FrontendEncodeError::InteriorNul("password"));
132 }
133 let mut buf = Vec::new();
134 buf.push(b'p');
135 let mut content = Vec::with_capacity(password.len() + 1);
136 content.extend_from_slice(password.as_bytes());
137 content.push(0);
138 let len = Self::content_len_to_wire_len(content.len())?;
139 buf.extend_from_slice(&len.to_be_bytes());
140 buf.extend_from_slice(&content);
141 Ok(buf)
142 }
143 FrontendMessage::Parse {
144 name,
145 query,
146 param_types,
147 } => {
148 if Self::has_nul(name) {
149 return Err(FrontendEncodeError::InteriorNul("name"));
150 }
151 if Self::has_nul(query) {
152 return Err(FrontendEncodeError::InteriorNul("query"));
153 }
154 if param_types.len() > i16::MAX as usize {
155 return Err(FrontendEncodeError::TooManyParams(param_types.len()));
156 }
157 let mut buf = Vec::new();
158 buf.push(b'P');
159
160 let mut content = Vec::new();
161 content.extend_from_slice(name.as_bytes());
162 content.push(0);
163 content.extend_from_slice(query.as_bytes());
164 content.push(0);
165 let param_count = i16::try_from(param_types.len())
166 .map_err(|_| FrontendEncodeError::TooManyParams(param_types.len()))?;
167 content.extend_from_slice(¶m_count.to_be_bytes());
168 for oid in param_types {
169 content.extend_from_slice(&oid.to_be_bytes());
170 }
171
172 let len = Self::content_len_to_wire_len(content.len())?;
173 buf.extend_from_slice(&len.to_be_bytes());
174 buf.extend_from_slice(&content);
175 Ok(buf)
176 }
177 FrontendMessage::Bind {
178 portal,
179 statement,
180 params,
181 } => {
182 if Self::has_nul(portal) {
183 return Err(FrontendEncodeError::InteriorNul("portal"));
184 }
185 if Self::has_nul(statement) {
186 return Err(FrontendEncodeError::InteriorNul("statement"));
187 }
188 if params.len() > i16::MAX as usize {
189 return Err(FrontendEncodeError::TooManyParams(params.len()));
190 }
191 if let Some(too_large) = params
192 .iter()
193 .flatten()
194 .find(|p| p.len() > i32::MAX as usize)
195 {
196 return Err(FrontendEncodeError::MessageTooLarge(too_large.len()));
197 }
198
199 let mut buf = Vec::new();
200 buf.push(b'B');
201
202 let mut content = Vec::new();
203 content.extend_from_slice(portal.as_bytes());
204 content.push(0);
205 content.extend_from_slice(statement.as_bytes());
206 content.push(0);
207 content.extend_from_slice(&0i16.to_be_bytes());
208 let param_count = i16::try_from(params.len())
209 .map_err(|_| FrontendEncodeError::TooManyParams(params.len()))?;
210 content.extend_from_slice(¶m_count.to_be_bytes());
211 for param in params {
212 match param {
213 Some(data) => {
214 let data_len = i32::try_from(data.len())
215 .map_err(|_| FrontendEncodeError::MessageTooLarge(data.len()))?;
216 content.extend_from_slice(&data_len.to_be_bytes());
217 content.extend_from_slice(data);
218 }
219 None => content.extend_from_slice(&(-1i32).to_be_bytes()),
220 }
221 }
222 content.extend_from_slice(&0i16.to_be_bytes());
223
224 let len = Self::content_len_to_wire_len(content.len())?;
225 buf.extend_from_slice(&len.to_be_bytes());
226 buf.extend_from_slice(&content);
227 Ok(buf)
228 }
229 FrontendMessage::Execute { portal, max_rows } => {
230 if Self::has_nul(portal) {
231 return Err(FrontendEncodeError::InteriorNul("portal"));
232 }
233 if *max_rows < 0 {
234 return Err(FrontendEncodeError::InvalidMaxRows(*max_rows));
235 }
236 let mut buf = Vec::new();
237 buf.push(b'E');
238 let mut content = Vec::new();
239 content.extend_from_slice(portal.as_bytes());
240 content.push(0);
241 content.extend_from_slice(&max_rows.to_be_bytes());
242 let len = Self::content_len_to_wire_len(content.len())?;
243 buf.extend_from_slice(&len.to_be_bytes());
244 buf.extend_from_slice(&content);
245 Ok(buf)
246 }
247 FrontendMessage::Sync => Ok(vec![b'S', 0, 0, 0, 4]),
248 FrontendMessage::CopyFail(msg) => {
249 if Self::has_nul(msg) {
250 return Err(FrontendEncodeError::InteriorNul("copy_fail"));
251 }
252 let mut buf = Vec::new();
253 buf.push(b'f');
254 let mut content = Vec::with_capacity(msg.len() + 1);
255 content.extend_from_slice(msg.as_bytes());
256 content.push(0);
257 let len = Self::content_len_to_wire_len(content.len())?;
258 buf.extend_from_slice(&len.to_be_bytes());
259 buf.extend_from_slice(&content);
260 Ok(buf)
261 }
262 FrontendMessage::Close { is_portal, name } => {
263 if Self::has_nul(name) {
264 return Err(FrontendEncodeError::InteriorNul("name"));
265 }
266 let mut buf = Vec::new();
267 buf.push(b'C');
268 let type_byte = if *is_portal { b'P' } else { b'S' };
269 let mut content = vec![type_byte];
270 content.extend_from_slice(name.as_bytes());
271 content.push(0);
272 let len = Self::content_len_to_wire_len(content.len())?;
273 buf.extend_from_slice(&len.to_be_bytes());
274 buf.extend_from_slice(&content);
275 Ok(buf)
276 }
277 }
278 }
279}