1use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
8pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
21 if data.len() < 5 {
22 return Err(io::Error::new(
23 io::ErrorKind::UnexpectedEof,
24 "incomplete message header",
25 ));
26 }
27
28 let tag = data[0];
29 let len = i32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
30
31 if data.len() < len + 1 {
32 return Err(io::Error::new(
33 io::ErrorKind::UnexpectedEof,
34 "incomplete message body",
35 ));
36 }
37
38 let msg_start = 5;
40 let msg_end = len + 1;
41 let msg_data = &data[msg_start..msg_end];
42
43 let msg = match tag {
44 tags::AUTHENTICATION => decode_authentication(msg_data)?,
45 tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
46 tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
47 tags::DATA_ROW => decode_data_row(msg_data)?,
48 tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
49 tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
50 tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
51 tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
52 tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
53 _ => {
54 return Err(io::Error::new(
55 io::ErrorKind::InvalidData,
56 format!("unknown message tag: {}", tag),
57 ))
58 }
59 };
60
61 Ok((msg, len + 1))
62}
63
64fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
65 if data.len() < 4 {
66 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"));
67 }
68 let auth_type = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
69
70 let auth_msg = match auth_type {
71 auth::OK => AuthenticationMessage::Ok,
72 auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
73 auth::MD5_PASSWORD => {
74 if data.len() < 8 {
75 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"));
76 }
77 let mut salt = [0u8; 4];
78 salt.copy_from_slice(&data[4..8]);
79 AuthenticationMessage::Md5Password { salt }
80 }
81 auth::SASL => {
82 let mut mechanisms = Vec::new();
84 let remaining = &data[4..];
85 let mut offset = 0;
86 loop {
87 if offset >= remaining.len() {
88 break;
89 }
90 match remaining[offset..].iter().position(|&b| b == 0) {
91 Some(end) => {
92 let mechanism =
93 String::from_utf8_lossy(&remaining[offset..offset + end]).to_string();
94 if mechanism.is_empty() {
95 break;
96 }
97 mechanisms.push(mechanism);
98 offset += end + 1;
99 }
100 None => break,
101 }
102 }
103 AuthenticationMessage::Sasl { mechanisms }
104 }
105 auth::SASL_CONTINUE => {
106 let data_vec = data[4..].to_vec();
108 AuthenticationMessage::SaslContinue { data: data_vec }
109 }
110 auth::SASL_FINAL => {
111 let data_vec = data[4..].to_vec();
113 AuthenticationMessage::SaslFinal { data: data_vec }
114 }
115 _ => {
116 return Err(io::Error::new(
117 io::ErrorKind::Unsupported,
118 format!("unsupported auth type: {}", auth_type),
119 ))
120 }
121 };
122
123 Ok(BackendMessage::Authentication(auth_msg))
124}
125
126fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
127 if data.len() < 8 {
128 return Err(io::Error::new(
129 io::ErrorKind::UnexpectedEof,
130 "backend key data",
131 ));
132 }
133 let process_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
134 let secret_key = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
135 Ok(BackendMessage::BackendKeyData {
136 process_id,
137 secret_key,
138 })
139}
140
141fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
142 let end = data.iter().position(|&b| b == 0).ok_or_else(|| {
143 io::Error::new(
144 io::ErrorKind::InvalidData,
145 "missing null terminator in string",
146 )
147 })?;
148 let tag = String::from_utf8_lossy(&data[..end]).to_string();
149 Ok(BackendMessage::CommandComplete(tag))
150}
151
152fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
153 if data.len() < 2 {
154 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
155 }
156 let field_count = i16::from_be_bytes([data[0], data[1]]) as usize;
157 let mut fields = Vec::with_capacity(field_count);
158 let mut offset = 2;
159
160 for _ in 0..field_count {
161 if offset + 4 > data.len() {
162 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field length"));
163 }
164 let field_len = i32::from_be_bytes([
165 data[offset],
166 data[offset + 1],
167 data[offset + 2],
168 data[offset + 3],
169 ]);
170 offset += 4;
171
172 let field = if field_len == -1 {
173 None
174 } else {
175 let len = field_len as usize;
176 if offset + len > data.len() {
177 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field data"));
178 }
179 let field_bytes = Bytes::copy_from_slice(&data[offset..offset + len]);
180 offset += len;
181 Some(field_bytes)
182 };
183 fields.push(field);
184 }
185
186 Ok(BackendMessage::DataRow(fields))
187}
188
189fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
190 let fields = decode_error_fields(data)?;
191 Ok(BackendMessage::ErrorResponse(fields))
192}
193
194fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
195 let fields = decode_error_fields(data)?;
196 Ok(BackendMessage::NoticeResponse(fields))
197}
198
199fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
200 let mut fields = ErrorFields::default();
201 let mut offset = 0;
202
203 loop {
204 if offset >= data.len() {
205 break;
206 }
207 let field_type = data[offset];
208 offset += 1;
209 if field_type == 0 {
210 break;
211 }
212
213 let end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
214 io::Error::new(
215 io::ErrorKind::InvalidData,
216 "missing null terminator in error field",
217 )
218 })?;
219 let value = String::from_utf8_lossy(&data[offset..offset + end]).to_string();
220 offset += end + 1;
221
222 match field_type {
223 b'S' => fields.severity = Some(value),
224 b'C' => fields.code = Some(value),
225 b'M' => fields.message = Some(value),
226 b'D' => fields.detail = Some(value),
227 b'H' => fields.hint = Some(value),
228 b'P' => fields.position = Some(value),
229 _ => {} }
231 }
232
233 Ok(fields)
234}
235
236fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
237 let mut offset = 0;
238
239 let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
240 io::Error::new(
241 io::ErrorKind::InvalidData,
242 "missing null terminator in parameter name",
243 )
244 })?;
245 let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
246 offset += name_end + 1;
247
248 if offset >= data.len() {
249 return Err(io::Error::new(
250 io::ErrorKind::UnexpectedEof,
251 "parameter value",
252 ));
253 }
254 let value_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
255 io::Error::new(
256 io::ErrorKind::InvalidData,
257 "missing null terminator in parameter value",
258 )
259 })?;
260 let value = String::from_utf8_lossy(&data[offset..offset + value_end]).to_string();
261
262 Ok(BackendMessage::ParameterStatus { name, value })
263}
264
265fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
266 if data.is_empty() {
267 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"));
268 }
269 let status = data[0];
270 Ok(BackendMessage::ReadyForQuery { status })
271}
272
273fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
274 if data.len() < 2 {
275 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
276 }
277 let field_count = i16::from_be_bytes([data[0], data[1]]) as usize;
278 let mut fields = Vec::with_capacity(field_count);
279 let mut offset = 2;
280
281 for _ in 0..field_count {
282 let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
284 io::Error::new(
285 io::ErrorKind::InvalidData,
286 "missing null terminator in field name",
287 )
288 })?;
289 let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
290 offset += name_end + 1;
291
292 if offset + 18 > data.len() {
294 return Err(io::Error::new(
295 io::ErrorKind::UnexpectedEof,
296 "field descriptor",
297 ));
298 }
299 let table_oid = i32::from_be_bytes([
300 data[offset],
301 data[offset + 1],
302 data[offset + 2],
303 data[offset + 3],
304 ]);
305 offset += 4;
306 let column_attr = i16::from_be_bytes([data[offset], data[offset + 1]]);
307 offset += 2;
308 let type_oid = i32::from_be_bytes([
309 data[offset],
310 data[offset + 1],
311 data[offset + 2],
312 data[offset + 3],
313 ]) as u32;
314 offset += 4;
315 let type_size = i16::from_be_bytes([data[offset], data[offset + 1]]);
316 offset += 2;
317 let type_modifier = i32::from_be_bytes([
318 data[offset],
319 data[offset + 1],
320 data[offset + 2],
321 data[offset + 3],
322 ]);
323 offset += 4;
324 let format_code = i16::from_be_bytes([data[offset], data[offset + 1]]);
325 offset += 2;
326
327 fields.push(FieldDescription {
328 name,
329 table_oid,
330 column_attr,
331 type_oid,
332 type_size,
333 type_modifier,
334 format_code,
335 });
336 }
337
338 Ok(BackendMessage::RowDescription(fields))
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `decode_message` over one complete frame, panicking on any error.
    fn decode_frame(frame: &[u8]) -> (BackendMessage, usize) {
        let mut buf = BytesMut::from(frame);
        decode_message(&mut buf).unwrap()
    }

    #[test]
    fn test_decode_authentication_ok() {
        // Tag 'R', length 8 (length field + 4-byte auth code), auth code 0.
        let frame = [b'R', 0, 0, 0, 8, 0, 0, 0, 0];

        let (msg, consumed) = decode_frame(&frame);
        assert!(
            matches!(msg, BackendMessage::Authentication(AuthenticationMessage::Ok)),
            "expected Authentication::Ok"
        );
        // The whole frame, tag included, is reported as consumed.
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn test_decode_ready_for_query() {
        // Tag 'Z', length 5 (length field + 1 status byte), status 'I'.
        let frame = [b'Z', 0, 0, 0, 5, b'I'];

        let (msg, consumed) = decode_frame(&frame);
        if let BackendMessage::ReadyForQuery { status } = msg {
            assert_eq!(status, b'I');
        } else {
            panic!("expected ReadyForQuery");
        }
        assert_eq!(consumed, frame.len());
    }
}