Skip to main content

pg_srv/
decoding.rs

1//! Decoding values from the Protocol representation
2
3use crate::{
4    protocol::{ErrorCode, ErrorResponse, Format},
5    ProtocolError,
6};
7use byteorder::{BigEndian, ByteOrder};
8use std::backtrace::Backtrace;
9
10/// This trait explains how to decode values from the protocol
11/// It's used in the Bind message
12pub trait FromProtocolValue {
13    // Converts native type to raw value in specific format
14    fn from_protocol(raw: &[u8], format: Format) -> Result<Self, ProtocolError>
15    where
16        Self: Sized,
17    {
18        match format {
19            Format::Text => Self::from_text(raw),
20            Format::Binary => Self::from_binary(raw),
21        }
22    }
23
24    /// Decodes raw value to native type in text format
25    fn from_text(raw: &[u8]) -> Result<Self, ProtocolError>
26    where
27        Self: Sized;
28
29    /// Decodes raw value to native type in binary format
30    fn from_binary(raw: &[u8]) -> Result<Self, ProtocolError>
31    where
32        Self: Sized;
33}
34
35impl FromProtocolValue for String {
36    fn from_text(raw: &[u8]) -> Result<Self, ProtocolError> {
37        std::str::from_utf8(raw)
38            .map(|s| s.to_string())
39            .map_err(|err| ProtocolError::ErrorResponse {
40                source: ErrorResponse::error(ErrorCode::ProtocolViolation, err.to_string()),
41                backtrace: Backtrace::capture(),
42            })
43    }
44
45    fn from_binary(raw: &[u8]) -> Result<Self, ProtocolError> {
46        std::str::from_utf8(raw)
47            .map(|s| s.to_string())
48            .map_err(|err| ProtocolError::ErrorResponse {
49                source: ErrorResponse::error(ErrorCode::ProtocolViolation, err.to_string()),
50                backtrace: Backtrace::capture(),
51            })
52    }
53}
54
55impl FromProtocolValue for i64 {
56    fn from_text(raw: &[u8]) -> Result<Self, ProtocolError> {
57        let as_str = std::str::from_utf8(raw).map_err(|err| ProtocolError::ErrorResponse {
58            source: ErrorResponse::error(ErrorCode::ProtocolViolation, err.to_string()),
59            backtrace: Backtrace::capture(),
60        })?;
61
62        as_str
63            .parse::<i64>()
64            .map_err(|err| ProtocolError::ErrorResponse {
65                source: ErrorResponse::error(ErrorCode::ProtocolViolation, err.to_string()),
66                backtrace: Backtrace::capture(),
67            })
68    }
69
70    fn from_binary(raw: &[u8]) -> Result<Self, ProtocolError> {
71        Ok(BigEndian::read_i64(raw))
72    }
73}
74
75impl FromProtocolValue for bool {
76    fn from_text(raw: &[u8]) -> Result<Self, ProtocolError> {
77        match raw[0] {
78            b't' => Ok(true),
79            b'f' => Ok(false),
80            other => Err(ProtocolError::ErrorResponse {
81                source: ErrorResponse::error(
82                    ErrorCode::ProtocolViolation,
83                    format!("Unable to decode bool from text, actual: {}", other),
84                ),
85                backtrace: Backtrace::capture(),
86            }),
87        }
88    }
89
90    fn from_binary(raw: &[u8]) -> Result<Self, ProtocolError> {
91        match raw[0] {
92            1 => Ok(true),
93            0 => Ok(false),
94            other => Err(ProtocolError::ErrorResponse {
95                source: ErrorResponse::error(
96                    ErrorCode::ProtocolViolation,
97                    format!("Unable to decode bool from binary, actual: {}", other),
98                ),
99                backtrace: Backtrace::capture(),
100            }),
101        }
102    }
103}
104
105impl FromProtocolValue for f64 {
106    fn from_text(raw: &[u8]) -> Result<Self, ProtocolError> {
107        let as_str = std::str::from_utf8(raw).map_err(|err| ProtocolError::ErrorResponse {
108            source: ErrorResponse::error(ErrorCode::ProtocolViolation, err.to_string()),
109            backtrace: Backtrace::capture(),
110        })?;
111
112        as_str
113            .parse::<f64>()
114            .map_err(|err| ProtocolError::ErrorResponse {
115                source: ErrorResponse::error(ErrorCode::ProtocolViolation, err.to_string()),
116                backtrace: Backtrace::capture(),
117            })
118    }
119
120    fn from_binary(raw: &[u8]) -> Result<Self, ProtocolError> {
121        Ok(BigEndian::read_f64(raw))
122    }
123}
124
#[cfg(test)]
mod tests {
    use crate::*;

    use crate::protocol::Format;
    use crate::values::timestamp::TimestampValue;
    use bytes::BytesMut;
    #[cfg(feature = "with-chrono")]
    use chrono::NaiveDate;

    /// Encodes `value` in `format`, drops the 4-byte length prefix in front of the
    /// payload, and asserts that decoding the payload yields `value` again.
    fn assert_test_decode<T: ToProtocolValue + FromProtocolValue + std::cmp::PartialEq>(
        value: T,
        format: Format,
    ) -> Result<(), ProtocolError> {
        let mut buf = BytesMut::new();
        value.to_protocol(&mut buf, format)?;

        let payload = &buf.as_ref()[4..];
        let decoded = T::from_protocol(payload, format)?;
        assert_eq!(value, decoded);

        Ok(())
    }

    /// Round-trips every sample value through `format`; shared by the text and
    /// binary tests so both formats always cover the same values.
    fn assert_all_decoders(format: Format) -> Result<(), ProtocolError> {
        assert_test_decode("test".to_string(), format)?;
        assert_test_decode(true, format)?;
        assert_test_decode(false, format)?;
        assert_test_decode(1_i64, format)?;
        assert_test_decode(100_i64, format)?;
        assert_test_decode(std::f64::consts::PI, format)?;
        assert_test_decode(-std::f64::consts::E, format)?;
        assert_test_decode(0.0_f64, format)?;
        assert_test_decode(TimestampValue::new(1650890322000000000, None), format)?;
        assert_test_decode(TimestampValue::new(0, None), format)?;
        assert_test_decode(TimestampValue::new(1234567890123456000, None), format)?;

        #[cfg(feature = "with-chrono")]
        {
            assert_test_decode(NaiveDate::from_ymd_opt(2025, 8, 8).unwrap(), format)?;
            assert_test_decode(NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), format)?;
            assert_test_decode(NaiveDate::from_ymd_opt(1999, 12, 31).unwrap(), format)?;
        }

        Ok(())
    }

    #[test]
    fn test_text_decoders() -> Result<(), ProtocolError> {
        assert_all_decoders(Format::Text)
    }

    #[test]
    fn test_binary_decoders() -> Result<(), ProtocolError> {
        assert_all_decoders(Format::Binary)
    }
}