// scry_protocol/param.rs
//! Parameter value types for prepared statement parameters.

3use serde::{Deserialize, Serialize};
4
5/// Represents a typed parameter value from PostgreSQL's Bind message.
6///
7/// Covers core PostgreSQL types with an `Unknown` escape hatch for
8/// extension types or unrecognized OIDs.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[serde(tag = "type", content = "value")]
11pub enum ParamValue {
12    /// SQL NULL
13    Null,
14
15    /// Boolean (OID 16)
16    Bool(bool),
17
18    /// 16-bit integer (OID 21)
19    Int16(i16),
20
21    /// 32-bit integer (OID 23)
22    Int32(i32),
23
24    /// 64-bit integer (OID 20)
25    Int64(i64),
26
27    /// 32-bit float (OID 700)
28    Float32(f32),
29
30    /// 64-bit float (OID 701)
31    Float64(f64),
32
33    /// Arbitrary precision numeric as string (OID 1700)
34    Numeric(String),
35
36    /// Text string (OID 25, 1043 varchar, etc.)
37    Text(String),
38
39    /// Binary data (OID 17 bytea)
40    #[serde(with = "base64_serde")]
41    Bytes(Vec<u8>),
42
43    /// Date as days since 2000-01-01 (OID 1082)
44    Date(i32),
45
46    /// Time as microseconds since midnight (OID 1083)
47    Time(i64),
48
49    /// Timestamp as microseconds since 2000-01-01 (OID 1114)
50    Timestamp(i64),
51
52    /// Timestamp with timezone as microseconds since 2000-01-01 UTC (OID 1184)
53    TimestampTz(i64),
54
55    /// Interval (OID 1186)
56    Interval {
57        months: i32,
58        days: i32,
59        microseconds: i64,
60    },
61
62    /// UUID as 16 bytes (OID 2950)
63    #[serde(with = "uuid_serde")]
64    Uuid([u8; 16]),
65
66    /// JSON/JSONB as string (OID 114, 3802)
67    Json(String),
68
69    /// Array of values (OID varies)
70    Array {
71        elements: Vec<ParamValue>,
72        dimensions: Vec<i32>,
73    },
74
75    /// Range type (OID varies)
76    Range {
77        lower: Option<Box<ParamValue>>,
78        upper: Option<Box<ParamValue>>,
79        lower_inc: bool,
80        upper_inc: bool,
81    },
82
83    /// Composite/record type (OID varies)
84    Composite { fields: Vec<ParamValue> },
85
86    /// Unknown or extension type - escape hatch
87    Unknown {
88        oid: u32,
89        #[serde(with = "base64_serde")]
90        data: Vec<u8>,
91    },
92}
93
94/// Base64 serialization for byte arrays
95mod base64_serde {
96    use serde::{Deserialize, Deserializer, Serializer};
97
98    pub fn serialize<S>(bytes: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
99    where
100        S: Serializer,
101    {
102        use base64::Engine;
103        let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
104        serializer.serialize_str(&encoded)
105    }
106
107    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
108    where
109        D: Deserializer<'de>,
110    {
111        use base64::Engine;
112        let s = String::deserialize(deserializer)?;
113        base64::engine::general_purpose::STANDARD
114            .decode(&s)
115            .map_err(serde::de::Error::custom)
116    }
117}
118
119/// UUID serialization as hex string
120mod uuid_serde {
121    use serde::{Deserialize, Deserializer, Serializer};
122
123    pub fn serialize<S>(bytes: &[u8; 16], serializer: S) -> Result<S::Ok, S::Error>
124    where
125        S: Serializer,
126    {
127        let hex = format!(
128            "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
129            bytes[0], bytes[1], bytes[2], bytes[3],
130            bytes[4], bytes[5],
131            bytes[6], bytes[7],
132            bytes[8], bytes[9],
133            bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15]
134        );
135        serializer.serialize_str(&hex)
136    }
137
138    pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 16], D::Error>
139    where
140        D: Deserializer<'de>,
141    {
142        let s = String::deserialize(deserializer)?;
143        let hex: String = s.chars().filter(|c| c.is_ascii_hexdigit()).collect();
144        if hex.len() != 32 {
145            return Err(serde::de::Error::custom("UUID must be 32 hex chars"));
146        }
147        let mut bytes = [0u8; 16];
148        for i in 0..16 {
149            bytes[i] = u8::from_str_radix(&hex[i * 2..i * 2 + 2], 16)
150                .map_err(serde::de::Error::custom)?;
151        }
152        Ok(bytes)
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `val` to JSON, parses it back, asserts the decoded value
    /// equals `val`, and returns the JSON text for format-specific checks.
    fn json_roundtrip(val: &ParamValue) -> String {
        let encoded = serde_json::to_string(val).unwrap();
        let decoded: ParamValue = serde_json::from_str(&encoded).unwrap();
        assert_eq!(&decoded, val);
        encoded
    }

    #[test]
    fn test_param_value_null_roundtrip() {
        json_roundtrip(&ParamValue::Null);
    }

    #[test]
    fn test_param_value_int32_roundtrip() {
        let encoded = json_roundtrip(&ParamValue::Int32(42));
        assert!(encoded.contains("\"type\":\"Int32\""));
    }

    #[test]
    fn test_param_value_text_roundtrip() {
        json_roundtrip(&ParamValue::Text(String::from("hello world")));
    }

    #[test]
    fn test_param_value_bytes_base64() {
        let encoded = json_roundtrip(&ParamValue::Bytes(vec![1, 2, 3]));
        // "AQID" is the base64 encoding of [1, 2, 3].
        assert!(encoded.contains("AQID"));
    }

    #[test]
    fn test_param_value_uuid_hex() {
        let raw: [u8; 16] = [
            0x55, 0x06, 0x7d, 0xc5, 0xb9, 0x1c, 0x40, 0x78,
            0x90, 0x5b, 0x8a, 0x7f, 0xdd, 0x00, 0x83, 0x0c,
        ];
        let encoded = json_roundtrip(&ParamValue::Uuid(raw));
        assert!(encoded.contains("55067dc5-b91c-4078-905b-8a7fdd00830c"));
    }

    #[test]
    fn test_param_value_array_roundtrip() {
        json_roundtrip(&ParamValue::Array {
            elements: (1..=2).map(ParamValue::Int32).collect(),
            dimensions: vec![2],
        });
    }

    #[test]
    fn test_param_value_unknown_roundtrip() {
        json_roundtrip(&ParamValue::Unknown {
            oid: 12345,
            data: vec![0xDE, 0xAD, 0xBE, 0xEF],
        });
    }
}