Skip to main content

reddb_wire/
query_with_params.rs

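//! RedWire `QueryWithParams` payload codec.
//!
//! Payload layout v1:
//! `u32 sql_len` + UTF-8 SQL + `u32 param_count` + encoded values.
//!
//! Every length prefix and fixed-width number is little-endian. Each encoded
//! value is a one-byte tag followed by a tag-specific body.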
5
6use std::fmt;
7
/// Feature bit for parameterised queries (presumably advertised during the
/// RedWire handshake — confirm against the protocol spec).
pub const FEATURE_PARAMS: u32 = 1 << 0;
/// Largest `param_count` accepted when encoding or decoding a
/// `QueryWithParams` payload (2^16).
pub const MAX_PARAM_COUNT: usize = 1 << 16;
10
// One-byte wire tags that prefix every encoded `ParamValue`, numbered in the
// same order as the enum's variants.
/// Tag for `ParamValue::Null`.
const TAG_NULL: u8 = 0;
/// Tag for `ParamValue::Bool`.
const TAG_BOOL: u8 = 1;
/// Tag for `ParamValue::Int`.
const TAG_INT: u8 = 2;
/// Tag for `ParamValue::Float`.
const TAG_FLOAT: u8 = 3;
/// Tag for `ParamValue::Text`.
const TAG_TEXT: u8 = 4;
/// Tag for `ParamValue::Bytes`.
const TAG_BYTES: u8 = 5;
/// Tag for `ParamValue::Vector`.
const TAG_VECTOR: u8 = 6;
/// Tag for `ParamValue::Json`.
const TAG_JSON: u8 = 7;
/// Tag for `ParamValue::Timestamp`.
const TAG_TIMESTAMP: u8 = 8;
/// Tag for `ParamValue::Uuid`.
const TAG_UUID: u8 = 9;
21
/// A single bound parameter value carried in a `QueryWithParams` payload.
///
/// On the wire each variant is a one-byte tag followed by a variant-specific
/// body; all multi-byte numbers and length prefixes are little-endian.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    /// SQL `NULL`; encoded as the tag byte alone.
    Null,
    /// Boolean; encoded as one byte, `0` or `1` (any other byte is rejected
    /// when decoding).
    Bool(bool),
    /// Signed 64-bit integer, 8 bytes little-endian.
    Int(i64),
    /// 64-bit float, its 8-byte little-endian bit pattern.
    Float(f64),
    /// Text with a `u32` byte-length prefix; decoding rejects invalid UTF-8.
    Text(String),
    /// Opaque bytes with a `u32` length prefix.
    Bytes(Vec<u8>),
    /// `f32` elements with a `u32` element-count prefix, each element
    /// 4 bytes little-endian.
    Vector(Vec<f32>),
    /// Raw JSON bytes with a `u32` length prefix. The codec does not check
    /// that the bytes are well-formed JSON.
    Json(Vec<u8>),
    /// Timestamp encoded as a little-endian `i64`.
    /// NOTE(review): unit and epoch are not defined by this codec — confirm
    /// against the protocol spec.
    Timestamp(i64),
    /// UUID as 16 raw bytes, written verbatim.
    Uuid([u8; 16]),
}
35
/// Errors produced while encoding or decoding a `QueryWithParams` payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParamCodecError {
    /// The named length does not fit the `u32` prefix used by v1, or a
    /// vector's byte length overflowed `usize` while decoding.
    LengthOverflow(&'static str),
    /// The parameter count exceeds `MAX_PARAM_COUNT`; carries that count.
    ParamCountOverLimit(u32),
    /// The payload is too short for the named field (including the case
    /// where the requested end offset overflows `usize`).
    Truncated(&'static str),
    /// The named field (e.g. `"sql"` or `"text"`) is not valid UTF-8.
    InvalidUtf8(&'static str),
    /// A bool body byte other than `0` or `1`; carries the offending byte.
    InvalidBool(u8),
    /// A value tag this codec does not recognise; carries the tag byte.
    UnknownTag(u8),
    /// Bytes left over after all declared parameters were decoded; carries
    /// how many.
    TrailingBytes(usize),
}
46
47impl fmt::Display for ParamCodecError {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        match self {
50            Self::LengthOverflow(field) => write!(f, "{field} is too large for RedWire v1"),
51            Self::ParamCountOverLimit(count) => {
52                write!(f, "param_count {count} exceeds RedWire v1 limit")
53            }
54            Self::Truncated(field) => write!(f, "truncated {field}"),
55            Self::InvalidUtf8(field) => write!(f, "{field} must be valid UTF-8"),
56            Self::InvalidBool(byte) => write!(f, "invalid bool payload byte {byte}"),
57            Self::UnknownTag(tag) => write!(f, "unknown parameter value tag 0x{tag:02x}"),
58            Self::TrailingBytes(count) => write!(f, "{count} trailing bytes after payload"),
59        }
60    }
61}
62
63impl std::error::Error for ParamCodecError {}
64
65pub fn encode_query_with_params(
66    sql: &str,
67    params: &[ParamValue],
68) -> Result<Vec<u8>, ParamCodecError> {
69    if sql.len() > u32::MAX as usize {
70        return Err(ParamCodecError::LengthOverflow("sql"));
71    }
72    if params.len() > u32::MAX as usize {
73        return Err(ParamCodecError::LengthOverflow("params"));
74    }
75    if params.len() > MAX_PARAM_COUNT {
76        return Err(ParamCodecError::ParamCountOverLimit(params.len() as u32));
77    }
78
79    let mut out = Vec::new();
80    out.extend_from_slice(&(sql.len() as u32).to_le_bytes());
81    out.extend_from_slice(sql.as_bytes());
82    out.extend_from_slice(&(params.len() as u32).to_le_bytes());
83    for value in params {
84        encode_value(value, &mut out)?;
85    }
86    Ok(out)
87}
88
89pub fn decode_query_with_params(
90    payload: &[u8],
91) -> Result<(String, Vec<ParamValue>), ParamCodecError> {
92    let mut pos = 0;
93    let sql_len = read_u32(payload, &mut pos, "sql_len")? as usize;
94    let sql_bytes = read_bytes(payload, &mut pos, sql_len, "sql")?;
95    let sql = std::str::from_utf8(sql_bytes)
96        .map_err(|_| ParamCodecError::InvalidUtf8("sql"))?
97        .to_string();
98    let param_count = read_u32(payload, &mut pos, "param_count")?;
99    if param_count as usize > MAX_PARAM_COUNT {
100        return Err(ParamCodecError::ParamCountOverLimit(param_count));
101    }
102    let mut params = Vec::with_capacity(param_count as usize);
103    for _ in 0..param_count {
104        params.push(decode_value(payload, &mut pos)?);
105    }
106    if pos != payload.len() {
107        return Err(ParamCodecError::TrailingBytes(payload.len() - pos));
108    }
109    Ok((sql, params))
110}
111
112pub fn encode_value(value: &ParamValue, out: &mut Vec<u8>) -> Result<(), ParamCodecError> {
113    match value {
114        ParamValue::Null => out.push(TAG_NULL),
115        ParamValue::Bool(value) => {
116            out.push(TAG_BOOL);
117            out.push(u8::from(*value));
118        }
119        ParamValue::Int(value) => {
120            out.push(TAG_INT);
121            out.extend_from_slice(&value.to_le_bytes());
122        }
123        ParamValue::Float(value) => {
124            out.push(TAG_FLOAT);
125            out.extend_from_slice(&value.to_le_bytes());
126        }
127        ParamValue::Text(value) => {
128            out.push(TAG_TEXT);
129            write_len_prefixed(value.as_bytes(), out, "text")?;
130        }
131        ParamValue::Bytes(value) => {
132            out.push(TAG_BYTES);
133            write_len_prefixed(value, out, "bytes")?;
134        }
135        ParamValue::Vector(values) => {
136            out.push(TAG_VECTOR);
137            if values.len() > u32::MAX as usize {
138                return Err(ParamCodecError::LengthOverflow("vector"));
139            }
140            out.extend_from_slice(&(values.len() as u32).to_le_bytes());
141            for value in values {
142                out.extend_from_slice(&value.to_le_bytes());
143            }
144        }
145        ParamValue::Json(value) => {
146            out.push(TAG_JSON);
147            write_len_prefixed(value, out, "json")?;
148        }
149        ParamValue::Timestamp(value) => {
150            out.push(TAG_TIMESTAMP);
151            out.extend_from_slice(&value.to_le_bytes());
152        }
153        ParamValue::Uuid(value) => {
154            out.push(TAG_UUID);
155            out.extend_from_slice(value);
156        }
157    }
158    Ok(())
159}
160
161pub fn decode_value(payload: &[u8], pos: &mut usize) -> Result<ParamValue, ParamCodecError> {
162    let tag = *read_bytes(payload, pos, 1, "value tag")?
163        .first()
164        .expect("read one byte");
165    match tag {
166        TAG_NULL => Ok(ParamValue::Null),
167        TAG_BOOL => {
168            let value = read_bytes(payload, pos, 1, "bool")?[0];
169            match value {
170                0 => Ok(ParamValue::Bool(false)),
171                1 => Ok(ParamValue::Bool(true)),
172                other => Err(ParamCodecError::InvalidBool(other)),
173            }
174        }
175        TAG_INT => Ok(ParamValue::Int(read_i64(payload, pos, "int")?)),
176        TAG_FLOAT => Ok(ParamValue::Float(f64::from_le_bytes(read_array(
177            payload, pos, "float",
178        )?))),
179        TAG_TEXT => {
180            let len = read_u32(payload, pos, "text_len")? as usize;
181            let bytes = read_bytes(payload, pos, len, "text")?;
182            let text = std::str::from_utf8(bytes)
183                .map_err(|_| ParamCodecError::InvalidUtf8("text"))?
184                .to_string();
185            Ok(ParamValue::Text(text))
186        }
187        TAG_BYTES => {
188            let len = read_u32(payload, pos, "bytes_len")? as usize;
189            Ok(ParamValue::Bytes(
190                read_bytes(payload, pos, len, "bytes")?.to_vec(),
191            ))
192        }
193        TAG_VECTOR => {
194            let len = read_u32(payload, pos, "vector_len")? as usize;
195            let byte_len = len
196                .checked_mul(4)
197                .ok_or(ParamCodecError::LengthOverflow("vector"))?;
198            ensure_remaining(payload, *pos, byte_len, "vector")?;
199            let mut values = Vec::with_capacity(len);
200            for _ in 0..len {
201                values.push(f32::from_le_bytes(read_array(payload, pos, "vector")?));
202            }
203            Ok(ParamValue::Vector(values))
204        }
205        TAG_JSON => {
206            let len = read_u32(payload, pos, "json_len")? as usize;
207            Ok(ParamValue::Json(
208                read_bytes(payload, pos, len, "json")?.to_vec(),
209            ))
210        }
211        TAG_TIMESTAMP => Ok(ParamValue::Timestamp(read_i64(payload, pos, "timestamp")?)),
212        TAG_UUID => Ok(ParamValue::Uuid(read_array(payload, pos, "uuid")?)),
213        other => Err(ParamCodecError::UnknownTag(other)),
214    }
215}
216
217fn write_len_prefixed(
218    value: &[u8],
219    out: &mut Vec<u8>,
220    field: &'static str,
221) -> Result<(), ParamCodecError> {
222    if value.len() > u32::MAX as usize {
223        return Err(ParamCodecError::LengthOverflow(field));
224    }
225    out.extend_from_slice(&(value.len() as u32).to_le_bytes());
226    out.extend_from_slice(value);
227    Ok(())
228}
229
230fn read_u32(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<u32, ParamCodecError> {
231    Ok(u32::from_le_bytes(read_array(payload, pos, field)?))
232}
233
234fn read_i64(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<i64, ParamCodecError> {
235    Ok(i64::from_le_bytes(read_array(payload, pos, field)?))
236}
237
238fn read_array<const N: usize>(
239    payload: &[u8],
240    pos: &mut usize,
241    field: &'static str,
242) -> Result<[u8; N], ParamCodecError> {
243    let bytes = read_bytes(payload, pos, N, field)?;
244    let mut out = [0u8; N];
245    out.copy_from_slice(bytes);
246    Ok(out)
247}
248
249fn read_bytes<'a>(
250    payload: &'a [u8],
251    pos: &mut usize,
252    len: usize,
253    field: &'static str,
254) -> Result<&'a [u8], ParamCodecError> {
255    let end = pos
256        .checked_add(len)
257        .ok_or(ParamCodecError::Truncated(field))?;
258    if end > payload.len() {
259        return Err(ParamCodecError::Truncated(field));
260    }
261    let bytes = &payload[*pos..end];
262    *pos = end;
263    Ok(bytes)
264}
265
266fn ensure_remaining(
267    payload: &[u8],
268    pos: usize,
269    len: usize,
270    field: &'static str,
271) -> Result<(), ParamCodecError> {
272    let end = pos
273        .checked_add(len)
274        .ok_or(ParamCodecError::Truncated(field))?;
275    if end > payload.len() {
276        return Err(ParamCodecError::Truncated(field));
277    }
278    Ok(())
279}