Skip to main content

we_trust_postgres/
codec.rs

1use crate::Result;
2use crate::message::Message;
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5use yykv_types::DsError;
6
7#[derive(Debug, Clone)]
8pub enum BackendMessage {
9    AuthenticationOk,
10    AuthenticationCleartextPassword,
11    AuthenticationMD5Password { salt: [u8; 4] },
12    ParameterStatus { name: String, value: String },
13    BackendKeyData { process_id: i32, secret_key: i32 },
14    ReadyForQuery { status: u8 },
15    RowDescription { fields: Vec<FieldDescription> },
16    DataRow { values: Vec<Option<Bytes>> },
17    CommandComplete { tag: String },
18    ErrorResponse { fields: Vec<(u8, String)> },
19    NoticeResponse { fields: Vec<(u8, String)> },
20    ParseComplete,
21    BindComplete,
22    NoData,
23    ParameterDescription { ids: Vec<u32> },
24    CloseComplete,
25}
26
/// One column entry of a `RowDescription` (`T`) message, in wire order.
#[derive(Clone, Debug)]
pub struct FieldDescription {
    /// Column name (a null-terminated string on the wire).
    pub name: String,
    /// Int32 table identifier.
    pub table_oid: i32,
    /// Int16 column attribute number.
    pub column_id: i16,
    /// Int32 data type identifier.
    pub type_oid: i32,
    /// Int16 data type size.
    pub type_size: i16,
    /// Int32 type modifier.
    pub type_modifier: i32,
    /// Int16 format code (0 = text, 1 = binary in the v3 protocol).
    pub format_code: i16,
}
37
/// Client-side codec: decodes [`BackendMessage`]s and encodes frontend messages.
///
/// This is a unit struct and keeps no state between frames.
#[derive(Debug, Clone, Default)]
pub struct PgCodec;

impl PgCodec {
    /// Creates a new client-side codec.
    pub fn new() -> Self {
        Self
    }
}

/// Server-side codec: decodes frontend messages and encodes [`BackendMessage`]s.
///
/// Like `PgCodec`, this is a unit struct with no per-connection state. It derives
/// the same traits as `PgCodec` so the two codecs can be used interchangeably in
/// generic code that needs `Debug` or `Clone`.
#[derive(Debug, Clone, Default)]
pub struct PgServerCodec;

impl PgServerCodec {
    /// Creates a new server-side codec.
    pub fn new() -> Self {
        Self
    }
}
60
61impl Decoder for PgServerCodec {
62    type Item = Message;
63    type Error = DsError;
64
65    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
66        if src.is_empty() {
67            return Ok(None);
68        }
69
70        // Handle Startup/SSLRequest (no tag byte)
71        if src.len() >= 4 {
72            let len = i32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
73            if len == 8 && src.len() >= 8 {
74                let code = i32::from_be_bytes([src[4], src[5], src[6], src[7]]);
75                if code == 80877103 {
76                    src.advance(8);
77                    return Ok(Some(Message::SslRequest));
78                }
79            }
80
81            // Startup Message: length (i32) + protocol (i32) + parameters (strings)
82            if len > 8 && src.len() >= len {
83                let protocol = i32::from_be_bytes([src[4], src[5], src[6], src[7]]);
84                if protocol == 196608 {
85                    // 3.0
86                    let mut body = src.split_to(len);
87                    body.advance(8);
88                    let mut params = Vec::new();
89                    while body[0] != 0 {
90                        let name = PgCodec::read_string(&mut body)?;
91                        let value = PgCodec::read_string(&mut body)?;
92                        params.push((name, value));
93                    }
94                    return Ok(Some(Message::Startup { params }));
95                }
96            }
97        }
98
99        if src.len() < 5 {
100            return Ok(None);
101        }
102
103        let tag = src[0];
104        let len = i32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
105
106        if src.len() < len + 1 {
107            return Ok(None);
108        }
109
110        let mut body = src.split_to(len + 1);
111        body.advance(5);
112
113        match tag {
114            b'p' => {
115                let pass = PgCodec::read_string(&mut body)?;
116                Ok(Some(Message::Password(pass)))
117            }
118            b'Q' => {
119                let query = PgCodec::read_string(&mut body)?;
120                Ok(Some(Message::Query(query)))
121            }
122            b'P' => {
123                let name = PgCodec::read_string(&mut body)?;
124                let query = PgCodec::read_string(&mut body)?;
125                let num_params = body.get_i16();
126                let mut param_types = Vec::with_capacity(num_params as usize);
127                for _ in 0..num_params {
128                    param_types.push(body.get_u32());
129                }
130                Ok(Some(Message::Parse {
131                    name,
132                    query,
133                    param_types,
134                }))
135            }
136            b'B' => {
137                let portal = PgCodec::read_string(&mut body)?;
138                let statement = PgCodec::read_string(&mut body)?;
139
140                let num_format_codes = body.get_i16();
141                let mut format_codes = Vec::with_capacity(num_format_codes as usize);
142                for _ in 0..num_format_codes {
143                    format_codes.push(body.get_i16());
144                }
145
146                let num_params = body.get_i16();
147                let mut params = Vec::with_capacity(num_params as usize);
148                for i in 0..num_params {
149                    let len = body.get_i32();
150                    if len == -1 {
151                        params.push(yykv_types::DsValue::Null);
152                    } else {
153                        let mut val_bytes = vec![0u8; len as usize];
154                        body.copy_to_slice(&mut val_bytes);
155
156                        let format = if format_codes.len() == 1 {
157                            format_codes[0]
158                        } else if format_codes.len() > i as usize {
159                            format_codes[i as usize]
160                        } else {
161                            0 // default to text
162                        };
163
164                        if format == 0 {
165                            // Text
166                            let s = String::from_utf8_lossy(&val_bytes).to_string();
167                            params.push(yykv_types::DsValue::Text(s));
168                        } else {
169                            // Binary
170                            params.push(yykv_types::DsValue::Bytes(val_bytes.into()));
171                        }
172                    }
173                }
174                Ok(Some(Message::Bind {
175                    portal,
176                    statement,
177                    params,
178                }))
179            }
180            b'D' => {
181                let target_type = body.get_u8();
182                let name = PgCodec::read_string(&mut body)?;
183                Ok(Some(Message::Describe { target_type, name }))
184            }
185            b'C' => {
186                let target_type = body.get_u8();
187                let name = PgCodec::read_string(&mut body)?;
188                Ok(Some(Message::Close { target_type, name }))
189            }
190            b'E' => {
191                let portal = PgCodec::read_string(&mut body)?;
192                let max_rows = body.get_i32();
193                Ok(Some(Message::Execute { portal, max_rows }))
194            }
195            b'S' => Ok(Some(Message::Sync)),
196            b'H' => Ok(Some(Message::Flush)),
197            b'X' => Ok(Some(Message::Terminate)),
198            _ => {
199                // Skip unknown messages for now to avoid hanging
200                Ok(None)
201            }
202        }
203    }
204}
205
206impl Encoder<BackendMessage> for PgServerCodec {
207    type Error = DsError;
208
209    fn encode(&mut self, item: BackendMessage, dst: &mut BytesMut) -> Result<()> {
210        item.encode(dst);
211        Ok(())
212    }
213}
214
215impl BackendMessage {
216    pub fn encode(&self, dst: &mut BytesMut) {
217        match self {
218            BackendMessage::AuthenticationOk => {
219                dst.put_u8(b'R');
220                dst.put_i32(8);
221                dst.put_i32(0);
222            }
223            BackendMessage::ReadyForQuery { status } => {
224                dst.put_u8(b'Z');
225                dst.put_i32(5);
226                dst.put_u8(*status);
227            }
228            BackendMessage::CommandComplete { tag } => {
229                dst.put_u8(b'C');
230                let len = tag.len() + 1 + 4;
231                dst.put_i32(len as i32);
232                dst.put_slice(tag.as_bytes());
233                dst.put_u8(0);
234            }
235            BackendMessage::ParameterStatus { name, value } => {
236                dst.put_u8(b'S');
237                let len = name.len() + value.len() + 2 + 4;
238                dst.put_i32(len as i32);
239                dst.put_slice(name.as_bytes());
240                dst.put_u8(0);
241                dst.put_slice(value.as_bytes());
242                dst.put_u8(0);
243            }
244            BackendMessage::RowDescription { fields } => {
245                dst.put_u8(b'T');
246                let mut payload = BytesMut::new();
247                payload.put_i16(fields.len() as i16);
248                for field in fields {
249                    payload.put_slice(field.name.as_bytes());
250                    payload.put_u8(0);
251                    payload.put_i32(field.table_oid);
252                    payload.put_i16(field.column_id);
253                    payload.put_i32(field.type_oid);
254                    payload.put_i16(field.type_size);
255                    payload.put_i32(field.type_modifier);
256                    payload.put_i16(field.format_code);
257                }
258                dst.put_i32(payload.len() as i32 + 4);
259                dst.extend_from_slice(&payload);
260            }
261            BackendMessage::DataRow { values } => {
262                dst.put_u8(b'D');
263                let mut payload = BytesMut::new();
264                payload.put_i16(values.len() as i16);
265                for val in values {
266                    match val {
267                        Some(v) => {
268                            payload.put_i32(v.len() as i32);
269                            payload.put_slice(v);
270                        }
271                        None => {
272                            payload.put_i32(-1);
273                        }
274                    }
275                }
276                dst.put_i32(payload.len() as i32 + 4);
277                dst.extend_from_slice(&payload);
278            }
279            BackendMessage::ErrorResponse { fields } => {
280                dst.put_u8(b'E');
281                let mut payload = BytesMut::new();
282                for (tag, msg) in fields {
283                    payload.put_u8(*tag);
284                    payload.put_slice(msg.as_bytes());
285                    payload.put_u8(0);
286                }
287                payload.put_u8(0);
288                dst.put_i32(payload.len() as i32 + 4);
289                dst.extend_from_slice(&payload);
290            }
291            BackendMessage::ParameterDescription { ids } => {
292                dst.put_u8(b't');
293                let mut payload = BytesMut::new();
294                payload.put_i16(ids.len() as i16);
295                for &id in ids {
296                    payload.put_u32(id);
297                }
298                dst.put_i32(payload.len() as i32 + 4);
299                dst.extend_from_slice(&payload);
300            }
301            BackendMessage::CloseComplete => {
302                dst.put_u8(b'3');
303                dst.put_i32(4);
304            }
305            BackendMessage::ParseComplete => {
306                dst.put_u8(b'1');
307                dst.put_i32(4);
308            }
309            BackendMessage::BindComplete => {
310                dst.put_u8(b'2');
311                dst.put_i32(4);
312            }
313            BackendMessage::NoData => {
314                dst.put_u8(b'n');
315                dst.put_i32(4);
316            }
317            _ => {
318                // TODO: Implement other backend messages
319            }
320        }
321    }
322}
323
324impl Decoder for PgCodec {
325    type Item = BackendMessage;
326    type Error = DsError;
327
328    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
329        if src.len() < 5 {
330            return Ok(None);
331        }
332
333        let tag = src[0];
334        let len = i32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
335
336        if src.len() < len + 1 {
337            return Ok(None);
338        }
339
340        let mut body = src.split_to(len + 1);
341        body.advance(5); // Skip tag and length
342
343        let msg = match tag {
344            b'R' => Self::decode_authentication(&mut body)?,
345            b'S' => {
346                let name = Self::read_string(&mut body)?;
347                let value = Self::read_string(&mut body)?;
348                BackendMessage::ParameterStatus { name, value }
349            }
350            b'K' => {
351                let process_id = body.get_i32();
352                let secret_key = body.get_i32();
353                BackendMessage::BackendKeyData {
354                    process_id,
355                    secret_key,
356                }
357            }
358            b'Z' => {
359                let status = body.get_u8();
360                BackendMessage::ReadyForQuery { status }
361            }
362            b'T' => Self::decode_row_description(&mut body)?,
363            b'D' => Self::decode_data_row(&mut body)?,
364            b'C' => {
365                let tag = Self::read_string(&mut body)?;
366                BackendMessage::CommandComplete { tag }
367            }
368            b'E' => BackendMessage::ErrorResponse {
369                fields: Self::decode_error_notice(&mut body)?,
370            },
371            b'N' => BackendMessage::NoticeResponse {
372                fields: Self::decode_error_notice(&mut body)?,
373            },
374            b'1' => BackendMessage::ParseComplete,
375            b'2' => BackendMessage::BindComplete,
376            b'n' => BackendMessage::NoData,
377            b't' => {
378                let count = body.get_i16();
379                let mut ids = Vec::with_capacity(count as usize);
380                for _ in 0..count {
381                    ids.push(body.get_u32());
382                }
383                BackendMessage::ParameterDescription { ids }
384            }
385            b'3' => BackendMessage::CloseComplete,
386            _ => {
387                return Err(DsError::protocol(format!(
388                    "Unknown backend tag: {}",
389                    tag as char
390                )));
391            }
392        };
393
394        Ok(Some(msg))
395    }
396}
397
398impl Encoder<Message> for PgCodec {
399    type Error = DsError;
400
401    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<()> {
402        item.encode(dst);
403        Ok(())
404    }
405}
406
407impl PgCodec {
408    fn decode_authentication(body: &mut BytesMut) -> Result<BackendMessage> {
409        let auth_type = body.get_i32();
410        match auth_type {
411            0 => Ok(BackendMessage::AuthenticationOk),
412            3 => Ok(BackendMessage::AuthenticationCleartextPassword),
413            5 => {
414                let mut salt = [0u8; 4];
415                body.copy_to_slice(&mut salt);
416                Ok(BackendMessage::AuthenticationMD5Password { salt })
417            }
418            _ => Err(DsError::protocol(format!(
419                "Unsupported authentication type: {}",
420                auth_type
421            ))),
422        }
423    }
424
425    fn decode_row_description(body: &mut BytesMut) -> Result<BackendMessage> {
426        let count = body.get_i16();
427        let mut fields = Vec::with_capacity(count as usize);
428        for _ in 0..count {
429            fields.push(FieldDescription {
430                name: Self::read_string(body)?,
431                table_oid: body.get_i32(),
432                column_id: body.get_i16(),
433                type_oid: body.get_i32(),
434                type_size: body.get_i16(),
435                type_modifier: body.get_i32(),
436                format_code: body.get_i16(),
437            });
438        }
439        Ok(BackendMessage::RowDescription { fields })
440    }
441
442    fn decode_data_row(body: &mut BytesMut) -> Result<BackendMessage> {
443        let count = body.get_i16();
444        let mut values = Vec::with_capacity(count as usize);
445        for _ in 0..count {
446            let len = body.get_i32();
447            if len == -1 {
448                values.push(None);
449            } else {
450                values.push(Some(body.split_to(len as usize).freeze()));
451            }
452        }
453        Ok(BackendMessage::DataRow { values })
454    }
455
456    fn decode_error_notice(body: &mut BytesMut) -> Result<Vec<(u8, String)>> {
457        let mut fields = Vec::new();
458        loop {
459            let tag = body.get_u8();
460            if tag == 0 {
461                break;
462            }
463            fields.push((tag, Self::read_string(body)?));
464        }
465        Ok(fields)
466    }
467
468    fn read_string(body: &mut BytesMut) -> Result<String> {
469        let pos = body
470            .iter()
471            .position(|&b| b == 0)
472            .ok_or_else(|| DsError::protocol("String not null-terminated"))?;
473        let s = String::from_utf8_lossy(&body[..pos]).into_owned();
474        body.advance(pos + 1);
475        Ok(s)
476    }
477}