1use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
8pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
21 if data.len() < 5 {
22 return Err(io::Error::new(
23 io::ErrorKind::UnexpectedEof,
24 "incomplete message header",
25 ));
26 }
27
28 let tag = data[0];
29 let len_i32 = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
30
31 if len_i32 < 4 {
34 return Err(io::Error::new(
35 io::ErrorKind::InvalidData,
36 "message length too small",
37 ));
38 }
39
40 let len = len_i32 as usize;
41
42 if data.len() < len + 1 {
43 return Err(io::Error::new(
44 io::ErrorKind::UnexpectedEof,
45 "incomplete message body",
46 ));
47 }
48
49 let msg_start = 5;
51 let msg_end = len + 1;
52 let msg_data = &data[msg_start..msg_end];
53
54 let msg = match tag {
55 tags::AUTHENTICATION => decode_authentication(msg_data)?,
56 tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
57 tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
58 tags::DATA_ROW => decode_data_row(msg_data)?,
59 tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
60 tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
61 tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
62 tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
63 tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
64 _ => {
65 return Err(io::Error::new(
66 io::ErrorKind::InvalidData,
67 format!("unknown message tag: {}", tag),
68 ))
69 }
70 };
71
72 Ok((msg, len + 1))
73}
74
75fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
76 if data.len() < 4 {
77 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"));
78 }
79 let auth_type = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
80
81 let auth_msg = match auth_type {
82 auth::OK => AuthenticationMessage::Ok,
83 auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
84 auth::MD5_PASSWORD => {
85 if data.len() < 8 {
86 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"));
87 }
88 let mut salt = [0u8; 4];
89 salt.copy_from_slice(&data[4..8]);
90 AuthenticationMessage::Md5Password { salt }
91 }
92 auth::SASL => {
93 let mut mechanisms = Vec::new();
95 let remaining = &data[4..];
96 let mut offset = 0;
97 loop {
98 if offset >= remaining.len() {
99 break;
100 }
101 match remaining[offset..].iter().position(|&b| b == 0) {
102 Some(end) => {
103 let mechanism =
104 String::from_utf8_lossy(&remaining[offset..offset + end]).to_string();
105 if mechanism.is_empty() {
106 break;
107 }
108 mechanisms.push(mechanism);
109 offset += end + 1;
110 }
111 None => break,
112 }
113 }
114 AuthenticationMessage::Sasl { mechanisms }
115 }
116 auth::SASL_CONTINUE => {
117 let data_vec = data[4..].to_vec();
119 AuthenticationMessage::SaslContinue { data: data_vec }
120 }
121 auth::SASL_FINAL => {
122 let data_vec = data[4..].to_vec();
124 AuthenticationMessage::SaslFinal { data: data_vec }
125 }
126 _ => {
127 return Err(io::Error::new(
128 io::ErrorKind::Unsupported,
129 format!("unsupported auth type: {}", auth_type),
130 ))
131 }
132 };
133
134 Ok(BackendMessage::Authentication(auth_msg))
135}
136
137fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
138 if data.len() < 8 {
139 return Err(io::Error::new(
140 io::ErrorKind::UnexpectedEof,
141 "backend key data",
142 ));
143 }
144 let process_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
145 let secret_key = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
146 Ok(BackendMessage::BackendKeyData {
147 process_id,
148 secret_key,
149 })
150}
151
152fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
153 let end = data.iter().position(|&b| b == 0).ok_or_else(|| {
154 io::Error::new(
155 io::ErrorKind::InvalidData,
156 "missing null terminator in string",
157 )
158 })?;
159 let tag = String::from_utf8_lossy(&data[..end]).to_string();
160 Ok(BackendMessage::CommandComplete(tag))
161}
162
163fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
164 if data.len() < 2 {
165 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
166 }
167 let field_count_i16 = i16::from_be_bytes([data[0], data[1]]);
168 if field_count_i16 < 0 {
169 return Err(io::Error::new(
170 io::ErrorKind::InvalidData,
171 "negative field count",
172 ));
173 }
174 let field_count = field_count_i16 as usize;
175 let mut fields = Vec::with_capacity(field_count);
176 let mut offset = 2;
177
178 for _ in 0..field_count {
179 if offset + 4 > data.len() {
180 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field length"));
181 }
182 let field_len = i32::from_be_bytes([
183 data[offset],
184 data[offset + 1],
185 data[offset + 2],
186 data[offset + 3],
187 ]);
188 offset += 4;
189
190 let field = if field_len == -1 {
191 None
192 } else if field_len < 0 {
193 return Err(io::Error::new(
194 io::ErrorKind::InvalidData,
195 "negative field length",
196 ));
197 } else {
198 let len = field_len as usize;
199 if offset + len > data.len() {
200 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field data"));
201 }
202 let field_bytes = Bytes::copy_from_slice(&data[offset..offset + len]);
203 offset += len;
204 Some(field_bytes)
205 };
206 fields.push(field);
207 }
208
209 Ok(BackendMessage::DataRow(fields))
210}
211
212fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
213 let fields = decode_error_fields(data)?;
214 Ok(BackendMessage::ErrorResponse(fields))
215}
216
217fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
218 let fields = decode_error_fields(data)?;
219 Ok(BackendMessage::NoticeResponse(fields))
220}
221
222fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
223 let mut fields = ErrorFields::default();
224 let mut offset = 0;
225
226 loop {
227 if offset >= data.len() {
228 break;
229 }
230 let field_type = data[offset];
231 offset += 1;
232 if field_type == 0 {
233 break;
234 }
235
236 let end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
237 io::Error::new(
238 io::ErrorKind::InvalidData,
239 "missing null terminator in error field",
240 )
241 })?;
242 let value = String::from_utf8_lossy(&data[offset..offset + end]).to_string();
243 offset += end + 1;
244
245 match field_type {
246 b'S' => fields.severity = Some(value),
247 b'C' => fields.code = Some(value),
248 b'M' => fields.message = Some(value),
249 b'D' => fields.detail = Some(value),
250 b'H' => fields.hint = Some(value),
251 b'P' => fields.position = Some(value),
252 _ => {} }
254 }
255
256 Ok(fields)
257}
258
259fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
260 let mut offset = 0;
261
262 let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
263 io::Error::new(
264 io::ErrorKind::InvalidData,
265 "missing null terminator in parameter name",
266 )
267 })?;
268 let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
269 offset += name_end + 1;
270
271 if offset >= data.len() {
272 return Err(io::Error::new(
273 io::ErrorKind::UnexpectedEof,
274 "parameter value",
275 ));
276 }
277 let value_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
278 io::Error::new(
279 io::ErrorKind::InvalidData,
280 "missing null terminator in parameter value",
281 )
282 })?;
283 let value = String::from_utf8_lossy(&data[offset..offset + value_end]).to_string();
284
285 Ok(BackendMessage::ParameterStatus { name, value })
286}
287
288fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
289 if data.is_empty() {
290 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"));
291 }
292 let status = data[0];
293 Ok(BackendMessage::ReadyForQuery { status })
294}
295
296fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
297 if data.len() < 2 {
298 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
299 }
300 let field_count_i16 = i16::from_be_bytes([data[0], data[1]]);
301 if field_count_i16 < 0 {
302 return Err(io::Error::new(
303 io::ErrorKind::InvalidData,
304 "negative field count",
305 ));
306 }
307 let field_count = field_count_i16 as usize;
308 let mut fields = Vec::with_capacity(field_count);
309 let mut offset = 2;
310
311 for _ in 0..field_count {
312 let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
314 io::Error::new(
315 io::ErrorKind::InvalidData,
316 "missing null terminator in field name",
317 )
318 })?;
319 let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
320 offset += name_end + 1;
321
322 if offset + 18 > data.len() {
324 return Err(io::Error::new(
325 io::ErrorKind::UnexpectedEof,
326 "field descriptor",
327 ));
328 }
329 let table_oid = i32::from_be_bytes([
330 data[offset],
331 data[offset + 1],
332 data[offset + 2],
333 data[offset + 3],
334 ]);
335 offset += 4;
336 let column_attr = i16::from_be_bytes([data[offset], data[offset + 1]]);
337 offset += 2;
338 let type_oid = i32::from_be_bytes([
339 data[offset],
340 data[offset + 1],
341 data[offset + 2],
342 data[offset + 3],
343 ]) as u32;
344 offset += 4;
345 let type_size = i16::from_be_bytes([data[offset], data[offset + 1]]);
346 offset += 2;
347 let type_modifier = i32::from_be_bytes([
348 data[offset],
349 data[offset + 1],
350 data[offset + 2],
351 data[offset + 3],
352 ]);
353 offset += 4;
354 let format_code = i16::from_be_bytes([data[offset], data[offset + 1]]);
355 offset += 2;
356
357 fields.push(FieldDescription {
358 name,
359 table_oid,
360 column_attr,
361 type_oid,
362 type_size,
363 type_modifier,
364 format_code,
365 });
366 }
367
368 Ok(BackendMessage::RowDescription(fields))
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `decode_message` on `bytes` and returns the error kind, panicking on success.
    fn decode_err_kind(bytes: &[u8]) -> io::ErrorKind {
        let mut data = BytesMut::from(bytes);
        match decode_message(&mut data) {
            Ok(_) => panic!("expected decode_message to fail"),
            Err(err) => err.kind(),
        }
    }

    #[test]
    fn test_decode_authentication_ok() {
        let mut data = BytesMut::from(
            &[
                b'R', 0, 0, 0, 8, // tag + length (length field itself + 4-byte code)
                0, 0, 0, 0, // authentication request code 0
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Ok) => {}
            _ => panic!("expected Authentication::Ok"),
        }
        assert_eq!(consumed, 9);
    }

    #[test]
    fn test_decode_ready_for_query() {
        let mut data = BytesMut::from(
            &[
                b'Z', 0, 0, 0, 5, // tag + length
                b'I', // transaction status
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::ReadyForQuery { status } => assert_eq!(status, b'I'),
            _ => panic!("expected ReadyForQuery"),
        }
        assert_eq!(consumed, 6);
    }

    #[test]
    fn test_decode_does_not_advance_buffer() {
        // A complete ReadyForQuery followed by the first bytes of another message.
        let mut data = BytesMut::from(&[b'Z', 0, 0, 0, 5, b'I', b'Z', 0][..]);

        let (_, consumed) = decode_message(&mut data).unwrap();
        assert_eq!(consumed, 6);
        // The caller is responsible for discarding `consumed` bytes.
        assert_eq!(data.len(), 8);
    }

    #[test]
    fn test_incomplete_header_is_unexpected_eof() {
        assert_eq!(decode_err_kind(&[b'Z', 0, 0]), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_length_below_four_is_invalid() {
        assert_eq!(decode_err_kind(&[b'Z', 0, 0, 0, 3]), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_incomplete_body_is_unexpected_eof() {
        // Declares a 1-byte payload that has not arrived yet.
        assert_eq!(decode_err_kind(&[b'Z', 0, 0, 0, 5]), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_unknown_tag_is_invalid() {
        assert_eq!(decode_err_kind(&[b'!', 0, 0, 0, 4]), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_decode_data_row_with_null() {
        let mut data = BytesMut::from(
            &[
                b'D', 0, 0, 0, 16, // tag + length
                0, 2, // two columns
                0xff, 0xff, 0xff, 0xff, // column 1: length -1 => NULL
                0, 0, 0, 2, b'h', b'i', // column 2: "hi"
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::DataRow(fields) => {
                assert_eq!(fields.len(), 2);
                assert!(fields[0].is_none());
                assert_eq!(fields[1].as_deref(), Some(&b"hi"[..]));
            }
            _ => panic!("expected DataRow"),
        }
        assert_eq!(consumed, 17);
    }
}