// fraiseql_wire/protocol/decode/mod.rs

use std::io;

use bytes::{Bytes, BytesMut};

use super::constants::{auth, tags};
use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};

/// Forward-only reader over a borrowed byte slice, used by the decoders below
/// to pull big-endian integers and null-terminated strings out of message bodies.
///
/// Reads never panic on short input. A failed read returns an `io::Error`
/// (`UnexpectedEof` for too few bytes, `InvalidData` for a missing null
/// terminator) and leaves the cursor exactly where it was.
struct Cursor<'a> {
    /// The bytes being read; never modified.
    data: &'a [u8],
    /// Index of the next unread byte; only advanced after a successful read,
    /// so it never exceeds `data.len()`.
    offset: usize,
}

impl<'a> Cursor<'a> {
    /// Creates a cursor positioned at the first byte of `data`.
    const fn new(data: &'a [u8]) -> Self {
        Self { data, offset: 0 }
    }

    /// Returns every byte not yet consumed (empty once the cursor is exhausted).
    fn remaining(&self) -> &'a [u8] {
        self.data.get(self.offset..).unwrap_or(&[])
    }

    /// Returns `true` when no unread bytes are left.
    const fn is_empty(&self) -> bool {
        self.offset >= self.data.len()
    }

    /// Consumes exactly `n` bytes, or fails with `UnexpectedEof` (carrying the
    /// message `what`) without moving the cursor.
    ///
    /// `n` can be caller-supplied (e.g. a length read off the wire), so the end
    /// index is computed with `checked_add`. A plain `offset + n` would
    /// overflow-panic in debug builds instead of reporting a short read.
    fn take(&mut self, n: usize, what: &'static str) -> io::Result<&'a [u8]> {
        let data = self.data;
        let start = self.offset;
        let slice = start
            .checked_add(n)
            .and_then(|end| data.get(start..end))
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, what))?;
        // Cannot overflow: `checked_add` succeeded above.
        self.offset = start + n;
        Ok(slice)
    }

    /// Consumes exactly `N` bytes and returns them as a fixed-size array.
    fn take_array<const N: usize>(&mut self, what: &'static str) -> io::Result<[u8; N]> {
        let bytes = self.take(N, what)?;
        Ok(bytes
            .try_into()
            .expect("take(N) always yields exactly N bytes"))
    }

    /// Reads one byte.
    fn read_u8(&mut self) -> io::Result<u8> {
        let [byte] = self.take_array::<1>("byte")?;
        Ok(byte)
    }

    /// Reads a big-endian `i16`.
    fn read_i16_be(&mut self) -> io::Result<i16> {
        self.take_array::<2>("i16").map(i16::from_be_bytes)
    }

    /// Reads a big-endian `i32`.
    fn read_i32_be(&mut self) -> io::Result<i32> {
        self.take_array::<4>("i32").map(i32::from_be_bytes)
    }

    /// Reads the next `n` bytes as a slice borrowed from the underlying data.
    fn read_slice(&mut self, n: usize) -> io::Result<&'a [u8]> {
        self.take(n, "slice")
    }

    /// Returns the bytes before the next null byte and consumes them together
    /// with the null itself.
    ///
    /// # Errors
    /// `InvalidData` if no null byte remains. The cursor is not moved.
    fn read_until_null(&mut self) -> io::Result<&'a [u8]> {
        let end = self.position_of_null().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "missing null terminator in string",
            )
        })?;
        let (bytes, _) = self.remaining().split_at(end);
        self.offset += end + 1;
        Ok(bytes)
    }

    /// Returns the index of the next null byte, relative to the unread tail.
    fn position_of_null(&self) -> Option<usize> {
        self.remaining().iter().position(|&b| b == 0)
    }
}

/// Largest column count accepted in a `DataRow` or `RowDescription` message;
/// a larger count is rejected with `InvalidData`.
pub(crate) const MAX_FIELD_COUNT: usize = 2048;

/// Largest single string, in bytes (not counting its null terminator),
/// accepted inside an `ErrorResponse` / `NoticeResponse` body.
pub(crate) const MAX_ERROR_FIELD_BYTES: usize = 64 * 1024;

/// Maximum number of SASL mechanism names kept from an authentication
/// message; names beyond this count are dropped.
pub(crate) const MAX_SASL_MECHANISMS: usize = 32;

/// Largest parameter name, in bytes (not counting its null terminator),
/// accepted in a `ParameterStatus` message.
pub(crate) const MAX_PARAMETER_NAME_BYTES: usize = 256;

124pub(crate) const MAX_PARAMETER_VALUE_BYTES: usize = 64 * 1024; pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
148 if data.len() < 5 {
149 return Err(io::Error::new(
150 io::ErrorKind::UnexpectedEof,
151 "incomplete message header",
152 ));
153 }
154
155 let mut header = Cursor::new(data);
156 let tag = header.read_u8()?;
157 let len_i32 = header.read_i32_be()?;
158
159 if len_i32 < 4 {
162 return Err(io::Error::new(
163 io::ErrorKind::InvalidData,
164 "message length too small",
165 ));
166 }
167
168 let len = len_i32 as usize;
169
170 if data.len() < len + 1 {
171 return Err(io::Error::new(
172 io::ErrorKind::UnexpectedEof,
173 "incomplete message body",
174 ));
175 }
176
177 let msg_start = 5;
179 let msg_end = len + 1;
180 let msg_data = data
181 .get(msg_start..msg_end)
182 .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "message body slice"))?;
183
184 let msg = match tag {
185 tags::AUTHENTICATION => decode_authentication(msg_data)?,
186 tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
187 tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
188 tags::DATA_ROW => decode_data_row(msg_data)?,
189 tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
190 tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
191 tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
192 tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
193 tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
194 _ => {
195 return Err(io::Error::new(
196 io::ErrorKind::InvalidData,
197 format!("unknown message tag: {}", tag),
198 ))
199 }
200 };
201
202 Ok((msg, len + 1))
203}
204
205fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
206 let mut cur = Cursor::new(data);
207 let auth_type = cur
208 .read_i32_be()
209 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"))?;
210
211 let auth_msg = match auth_type {
212 auth::OK => AuthenticationMessage::Ok,
213 auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
214 auth::MD5_PASSWORD => {
215 let salt_slice = cur
216 .read_slice(4)
217 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"))?;
218 let salt: [u8; 4] = salt_slice
219 .try_into()
220 .expect("slice of length 4 always converts to [u8; 4]");
222 AuthenticationMessage::Md5Password { salt }
223 }
224 auth::SASL => {
225 let mut mechanisms = Vec::new();
227 loop {
228 if cur.is_empty() {
229 break;
230 }
231 let Some(end) = cur.position_of_null() else {
232 break;
233 };
234 let mech_bytes = cur.read_slice(end).unwrap_or(&[]);
235 let mechanism = String::from_utf8_lossy(mech_bytes).to_string();
236 let _ = cur.read_u8();
238 if mechanism.is_empty() {
239 break;
240 }
241 if mechanisms.len() >= MAX_SASL_MECHANISMS {
242 break;
243 }
244 mechanisms.push(mechanism);
245 }
246 AuthenticationMessage::Sasl { mechanisms }
247 }
248 auth::SASL_CONTINUE => {
249 let data_vec = cur.remaining().to_vec();
251 AuthenticationMessage::SaslContinue { data: data_vec }
252 }
253 auth::SASL_FINAL => {
254 let data_vec = cur.remaining().to_vec();
256 AuthenticationMessage::SaslFinal { data: data_vec }
257 }
258 _ => {
259 return Err(io::Error::new(
260 io::ErrorKind::Unsupported,
261 format!("unsupported auth type: {}", auth_type),
262 ))
263 }
264 };
265
266 Ok(BackendMessage::Authentication(auth_msg))
267}
268
269fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
270 let mut cur = Cursor::new(data);
271 let process_id = cur
272 .read_i32_be()
273 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "backend key data"))?;
274 let secret_key = cur
275 .read_i32_be()
276 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "backend key data"))?;
277 Ok(BackendMessage::BackendKeyData {
278 process_id,
279 secret_key,
280 })
281}
282
283fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
284 let mut cur = Cursor::new(data);
285 let tag_bytes = cur.read_until_null()?;
286 let tag = String::from_utf8_lossy(tag_bytes).to_string();
287 Ok(BackendMessage::CommandComplete(tag))
288}
289
290fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
291 let mut cur = Cursor::new(data);
292 let field_count_i16 = cur
293 .read_i16_be()
294 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field count"))?;
295 if field_count_i16 < 0 {
296 return Err(io::Error::new(
297 io::ErrorKind::InvalidData,
298 "negative field count",
299 ));
300 }
301 let field_count = field_count_i16 as usize;
302 if field_count > MAX_FIELD_COUNT {
303 return Err(io::Error::new(
304 io::ErrorKind::InvalidData,
305 format!("DataRow field count {field_count} exceeds maximum {MAX_FIELD_COUNT}"),
306 ));
307 }
308 let mut fields = Vec::with_capacity(field_count);
309
310 for _ in 0..field_count {
311 let field_len = cur
312 .read_i32_be()
313 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field length"))?;
314
315 let field = if field_len == -1 {
316 None
317 } else if field_len < 0 {
318 return Err(io::Error::new(
319 io::ErrorKind::InvalidData,
320 "negative field length",
321 ));
322 } else {
323 let len = field_len as usize;
324 let field_slice = cur
325 .read_slice(len)
326 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field data"))?;
327 Some(Bytes::copy_from_slice(field_slice))
328 };
329 fields.push(field);
330 }
331
332 Ok(BackendMessage::DataRow(fields))
333}
334
335fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
336 let fields = decode_error_fields(data)?;
337 Ok(BackendMessage::ErrorResponse(fields))
338}
339
340fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
341 let fields = decode_error_fields(data)?;
342 Ok(BackendMessage::NoticeResponse(fields))
343}
344
345fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
346 let mut fields = ErrorFields::default();
347 let mut cur = Cursor::new(data);
348
349 loop {
350 if cur.is_empty() {
351 break;
352 }
353 let field_type = cur.read_u8()?;
354 if field_type == 0 {
355 break;
356 }
357
358 let end = cur.position_of_null().ok_or_else(|| {
359 io::Error::new(
360 io::ErrorKind::InvalidData,
361 "missing null terminator in error field",
362 )
363 })?;
364 if end > MAX_ERROR_FIELD_BYTES {
365 return Err(io::Error::new(
366 io::ErrorKind::InvalidData,
367 format!("Error field too large ({end} bytes, max {MAX_ERROR_FIELD_BYTES})"),
368 ));
369 }
370 let value_bytes = cur.read_slice(end).unwrap_or(&[]);
371 let value = String::from_utf8_lossy(value_bytes).to_string();
372 let _ = cur.read_u8();
374
375 match field_type {
376 b'S' => fields.severity = Some(value),
377 b'C' => fields.code = Some(value),
378 b'M' => fields.message = Some(value),
379 b'D' => fields.detail = Some(value),
380 b'H' => fields.hint = Some(value),
381 b'P' => fields.position = Some(value),
382 _ => {} }
384 }
385
386 Ok(fields)
387}
388
389fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
390 let mut cur = Cursor::new(data);
391
392 let name_end = cur.position_of_null().ok_or_else(|| {
393 io::Error::new(
394 io::ErrorKind::InvalidData,
395 "missing null terminator in parameter name",
396 )
397 })?;
398 if name_end > MAX_PARAMETER_NAME_BYTES {
399 return Err(io::Error::new(
400 io::ErrorKind::InvalidData,
401 format!("Parameter name too long ({name_end} bytes, max {MAX_PARAMETER_NAME_BYTES})"),
402 ));
403 }
404 let name_bytes = cur.read_slice(name_end).unwrap_or(&[]);
405 let name = String::from_utf8_lossy(name_bytes).to_string();
406 let _ = cur.read_u8();
408
409 if cur.is_empty() {
410 return Err(io::Error::new(
411 io::ErrorKind::UnexpectedEof,
412 "parameter value",
413 ));
414 }
415 let value_end = cur.position_of_null().ok_or_else(|| {
416 io::Error::new(
417 io::ErrorKind::InvalidData,
418 "missing null terminator in parameter value",
419 )
420 })?;
421 if value_end > MAX_PARAMETER_VALUE_BYTES {
422 return Err(io::Error::new(
423 io::ErrorKind::InvalidData,
424 format!(
425 "Parameter value too long ({value_end} bytes, max {MAX_PARAMETER_VALUE_BYTES})"
426 ),
427 ));
428 }
429 let value_bytes = cur.read_slice(value_end).unwrap_or(&[]);
430 let value = String::from_utf8_lossy(value_bytes).to_string();
431
432 Ok(BackendMessage::ParameterStatus { name, value })
433}
434
435fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
436 let status = *data
437 .first()
438 .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"))?;
439 Ok(BackendMessage::ReadyForQuery { status })
440}
441
442fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
443 let mut cur = Cursor::new(data);
444 let field_count_i16 = cur
445 .read_i16_be()
446 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field count"))?;
447 if field_count_i16 < 0 {
448 return Err(io::Error::new(
449 io::ErrorKind::InvalidData,
450 "negative field count",
451 ));
452 }
453 let field_count = field_count_i16 as usize;
454 if field_count > MAX_FIELD_COUNT {
455 return Err(io::Error::new(
456 io::ErrorKind::InvalidData,
457 format!("RowDescription field count {field_count} exceeds maximum {MAX_FIELD_COUNT}"),
458 ));
459 }
460 let mut fields = Vec::with_capacity(field_count);
461
462 for _ in 0..field_count {
463 let name_end = cur.position_of_null().ok_or_else(|| {
465 io::Error::new(
466 io::ErrorKind::InvalidData,
467 "missing null terminator in field name",
468 )
469 })?;
470 let name_bytes = cur.read_slice(name_end).unwrap_or(&[]);
471 let name = String::from_utf8_lossy(name_bytes).to_string();
472 let _ = cur.read_u8();
474
475 let table_oid = cur
477 .read_i32_be()
478 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
479 let column_attr = cur
480 .read_i16_be()
481 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
482 let type_oid = cur
483 .read_i32_be()
484 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?
485 as u32;
486 let type_size = cur
487 .read_i16_be()
488 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
489 let type_modifier = cur
490 .read_i32_be()
491 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
492 let format_code = cur
493 .read_i16_be()
494 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
495
496 fields.push(FieldDescription {
497 name,
498 table_oid,
499 column_attr,
500 type_oid,
501 type_size,
502 type_modifier,
503 format_code,
504 });
505 }
506
507 Ok(BackendMessage::RowDescription(fields))
508}
509
// Unit tests are declared out of line (`tests.rs` or `tests/mod.rs` next to
// this module) and compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;