1use std::fmt;
7
/// Feature bit associated with parameterised queries.
///
/// NOTE(review): presumably a bit in a RedWire feature-negotiation mask; its consumer is
/// not visible in this part of the file — confirm before relying on that meaning.
pub const FEATURE_PARAMS: u32 = 0x0000_0001;

/// Upper bound on the number of parameters that `encode_query_with_params` will write
/// and `decode_query_with_params` will accept (2^16).
pub const MAX_PARAM_COUNT: usize = 1 << 16;
10
// Wire tags. Every encoded parameter value starts with one of these bytes, which selects
// the `ParamValue` variant and the payload layout that follows it.

/// `ParamValue::Null`: tag only, no payload.
const TAG_NULL: u8 = 0x00;
/// `ParamValue::Bool`: one payload byte, `0` or `1`.
const TAG_BOOL: u8 = 0x01;
/// `ParamValue::Int`: 8-byte little-endian `i64`.
const TAG_INT: u8 = 0x02;
/// `ParamValue::Float`: 8-byte little-endian `f64`.
const TAG_FLOAT: u8 = 0x03;
/// `ParamValue::Text`: `u32` LE byte length, then UTF-8 bytes.
const TAG_TEXT: u8 = 0x04;
/// `ParamValue::Bytes`: `u32` LE byte length, then raw bytes.
const TAG_BYTES: u8 = 0x05;
/// `ParamValue::Vector`: `u32` LE element count, then each `f32` little-endian.
const TAG_VECTOR: u8 = 0x06;
/// `ParamValue::Json`: `u32` LE byte length, then raw bytes.
const TAG_JSON: u8 = 0x07;
/// `ParamValue::Timestamp`: 8-byte little-endian `i64`.
const TAG_TIMESTAMP: u8 = 0x08;
/// `ParamValue::Uuid`: 16 raw bytes.
const TAG_UUID: u8 = 0x09;
21
/// One bound query parameter as carried by the RedWire v1 parameter codec.
///
/// On the wire each variant is a one-byte tag followed by a variant-specific payload;
/// see `encode_value` / `decode_value` for the exact layout.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    /// A null parameter; encoded as the tag byte alone.
    Null,
    /// Boolean; encoded as a single byte, `0` or `1`.
    Bool(bool),
    /// Signed 64-bit integer; 8 bytes little-endian.
    Int(i64),
    /// 64-bit float; 8 bytes little-endian.
    Float(f64),
    /// UTF-8 text; `u32` little-endian byte length followed by the bytes.
    Text(String),
    /// Opaque binary data; `u32` little-endian byte length followed by the bytes.
    Bytes(Vec<u8>),
    /// Float vector; `u32` little-endian element count followed by each `f32` little-endian.
    Vector(Vec<f32>),
    /// JSON document held as raw bytes and length-prefixed like `Bytes`.
    /// The codec copies these bytes verbatim and does not check that they are valid JSON.
    Json(Vec<u8>),
    /// Timestamp stored as a signed 64-bit integer, 8 bytes little-endian.
    /// NOTE(review): epoch and unit are not defined by this codec — confirm with callers.
    Timestamp(i64),
    /// UUID as 16 raw bytes, written verbatim.
    Uuid([u8; 16]),
}
35
/// Errors produced while encoding or decoding RedWire v1 query-with-params payloads.
///
/// Every payload is a small `Copy` value (a static field name or an integer), so the
/// error is `Copy` and `Hash` as well as `Eq`, and callers can pass it around or collect
/// it without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParamCodecError {
    /// A length or count for the named field does not fit the codec's size fields.
    /// Encoding raises this for values longer than a `u32` prefix can express; decoding
    /// raises it for `"vector"` when the element count times 4 overflows `usize`.
    LengthOverflow(&'static str),
    /// The parameter count exceeds `MAX_PARAM_COUNT`.
    ParamCountOverLimit(u32),
    /// The payload ended before the named field could be read in full.
    Truncated(&'static str),
    /// The named field's bytes are not valid UTF-8.
    InvalidUtf8(&'static str),
    /// A bool payload byte other than `0` or `1`.
    InvalidBool(u8),
    /// A value tag byte that matches no known parameter type.
    UnknownTag(u8),
    /// This many bytes were left over after the last parameter was decoded.
    TrailingBytes(usize),
}
46
47impl fmt::Display for ParamCodecError {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 match self {
50 Self::LengthOverflow(field) => write!(f, "{field} is too large for RedWire v1"),
51 Self::ParamCountOverLimit(count) => {
52 write!(f, "param_count {count} exceeds RedWire v1 limit")
53 }
54 Self::Truncated(field) => write!(f, "truncated {field}"),
55 Self::InvalidUtf8(field) => write!(f, "{field} must be valid UTF-8"),
56 Self::InvalidBool(byte) => write!(f, "invalid bool payload byte {byte}"),
57 Self::UnknownTag(tag) => write!(f, "unknown parameter value tag 0x{tag:02x}"),
58 Self::TrailingBytes(count) => write!(f, "{count} trailing bytes after payload"),
59 }
60 }
61}
62
63impl std::error::Error for ParamCodecError {}
64
65pub fn encode_query_with_params(
66 sql: &str,
67 params: &[ParamValue],
68) -> Result<Vec<u8>, ParamCodecError> {
69 if sql.len() > u32::MAX as usize {
70 return Err(ParamCodecError::LengthOverflow("sql"));
71 }
72 if params.len() > u32::MAX as usize {
73 return Err(ParamCodecError::LengthOverflow("params"));
74 }
75 if params.len() > MAX_PARAM_COUNT {
76 return Err(ParamCodecError::ParamCountOverLimit(params.len() as u32));
77 }
78
79 let mut out = Vec::new();
80 out.extend_from_slice(&(sql.len() as u32).to_le_bytes());
81 out.extend_from_slice(sql.as_bytes());
82 out.extend_from_slice(&(params.len() as u32).to_le_bytes());
83 for value in params {
84 encode_value(value, &mut out)?;
85 }
86 Ok(out)
87}
88
89pub fn decode_query_with_params(
90 payload: &[u8],
91) -> Result<(String, Vec<ParamValue>), ParamCodecError> {
92 let mut pos = 0;
93 let sql_len = read_u32(payload, &mut pos, "sql_len")? as usize;
94 let sql_bytes = read_bytes(payload, &mut pos, sql_len, "sql")?;
95 let sql = std::str::from_utf8(sql_bytes)
96 .map_err(|_| ParamCodecError::InvalidUtf8("sql"))?
97 .to_string();
98 let param_count = read_u32(payload, &mut pos, "param_count")?;
99 if param_count as usize > MAX_PARAM_COUNT {
100 return Err(ParamCodecError::ParamCountOverLimit(param_count));
101 }
102 let mut params = Vec::with_capacity(param_count as usize);
103 for _ in 0..param_count {
104 params.push(decode_value(payload, &mut pos)?);
105 }
106 if pos != payload.len() {
107 return Err(ParamCodecError::TrailingBytes(payload.len() - pos));
108 }
109 Ok((sql, params))
110}
111
112pub fn encode_value(value: &ParamValue, out: &mut Vec<u8>) -> Result<(), ParamCodecError> {
113 match value {
114 ParamValue::Null => out.push(TAG_NULL),
115 ParamValue::Bool(value) => {
116 out.push(TAG_BOOL);
117 out.push(u8::from(*value));
118 }
119 ParamValue::Int(value) => {
120 out.push(TAG_INT);
121 out.extend_from_slice(&value.to_le_bytes());
122 }
123 ParamValue::Float(value) => {
124 out.push(TAG_FLOAT);
125 out.extend_from_slice(&value.to_le_bytes());
126 }
127 ParamValue::Text(value) => {
128 out.push(TAG_TEXT);
129 write_len_prefixed(value.as_bytes(), out, "text")?;
130 }
131 ParamValue::Bytes(value) => {
132 out.push(TAG_BYTES);
133 write_len_prefixed(value, out, "bytes")?;
134 }
135 ParamValue::Vector(values) => {
136 out.push(TAG_VECTOR);
137 if values.len() > u32::MAX as usize {
138 return Err(ParamCodecError::LengthOverflow("vector"));
139 }
140 out.extend_from_slice(&(values.len() as u32).to_le_bytes());
141 for value in values {
142 out.extend_from_slice(&value.to_le_bytes());
143 }
144 }
145 ParamValue::Json(value) => {
146 out.push(TAG_JSON);
147 write_len_prefixed(value, out, "json")?;
148 }
149 ParamValue::Timestamp(value) => {
150 out.push(TAG_TIMESTAMP);
151 out.extend_from_slice(&value.to_le_bytes());
152 }
153 ParamValue::Uuid(value) => {
154 out.push(TAG_UUID);
155 out.extend_from_slice(value);
156 }
157 }
158 Ok(())
159}
160
161pub fn decode_value(payload: &[u8], pos: &mut usize) -> Result<ParamValue, ParamCodecError> {
162 let tag = *read_bytes(payload, pos, 1, "value tag")?
163 .first()
164 .expect("read one byte");
165 match tag {
166 TAG_NULL => Ok(ParamValue::Null),
167 TAG_BOOL => {
168 let value = read_bytes(payload, pos, 1, "bool")?[0];
169 match value {
170 0 => Ok(ParamValue::Bool(false)),
171 1 => Ok(ParamValue::Bool(true)),
172 other => Err(ParamCodecError::InvalidBool(other)),
173 }
174 }
175 TAG_INT => Ok(ParamValue::Int(read_i64(payload, pos, "int")?)),
176 TAG_FLOAT => Ok(ParamValue::Float(f64::from_le_bytes(read_array(
177 payload, pos, "float",
178 )?))),
179 TAG_TEXT => {
180 let len = read_u32(payload, pos, "text_len")? as usize;
181 let bytes = read_bytes(payload, pos, len, "text")?;
182 let text = std::str::from_utf8(bytes)
183 .map_err(|_| ParamCodecError::InvalidUtf8("text"))?
184 .to_string();
185 Ok(ParamValue::Text(text))
186 }
187 TAG_BYTES => {
188 let len = read_u32(payload, pos, "bytes_len")? as usize;
189 Ok(ParamValue::Bytes(
190 read_bytes(payload, pos, len, "bytes")?.to_vec(),
191 ))
192 }
193 TAG_VECTOR => {
194 let len = read_u32(payload, pos, "vector_len")? as usize;
195 let byte_len = len
196 .checked_mul(4)
197 .ok_or(ParamCodecError::LengthOverflow("vector"))?;
198 ensure_remaining(payload, *pos, byte_len, "vector")?;
199 let mut values = Vec::with_capacity(len);
200 for _ in 0..len {
201 values.push(f32::from_le_bytes(read_array(payload, pos, "vector")?));
202 }
203 Ok(ParamValue::Vector(values))
204 }
205 TAG_JSON => {
206 let len = read_u32(payload, pos, "json_len")? as usize;
207 Ok(ParamValue::Json(
208 read_bytes(payload, pos, len, "json")?.to_vec(),
209 ))
210 }
211 TAG_TIMESTAMP => Ok(ParamValue::Timestamp(read_i64(payload, pos, "timestamp")?)),
212 TAG_UUID => Ok(ParamValue::Uuid(read_array(payload, pos, "uuid")?)),
213 other => Err(ParamCodecError::UnknownTag(other)),
214 }
215}
216
217fn write_len_prefixed(
218 value: &[u8],
219 out: &mut Vec<u8>,
220 field: &'static str,
221) -> Result<(), ParamCodecError> {
222 if value.len() > u32::MAX as usize {
223 return Err(ParamCodecError::LengthOverflow(field));
224 }
225 out.extend_from_slice(&(value.len() as u32).to_le_bytes());
226 out.extend_from_slice(value);
227 Ok(())
228}
229
230fn read_u32(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<u32, ParamCodecError> {
231 Ok(u32::from_le_bytes(read_array(payload, pos, field)?))
232}
233
234fn read_i64(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<i64, ParamCodecError> {
235 Ok(i64::from_le_bytes(read_array(payload, pos, field)?))
236}
237
238fn read_array<const N: usize>(
239 payload: &[u8],
240 pos: &mut usize,
241 field: &'static str,
242) -> Result<[u8; N], ParamCodecError> {
243 let bytes = read_bytes(payload, pos, N, field)?;
244 let mut out = [0u8; N];
245 out.copy_from_slice(bytes);
246 Ok(out)
247}
248
249fn read_bytes<'a>(
250 payload: &'a [u8],
251 pos: &mut usize,
252 len: usize,
253 field: &'static str,
254) -> Result<&'a [u8], ParamCodecError> {
255 let end = pos
256 .checked_add(len)
257 .ok_or(ParamCodecError::Truncated(field))?;
258 if end > payload.len() {
259 return Err(ParamCodecError::Truncated(field));
260 }
261 let bytes = &payload[*pos..end];
262 *pos = end;
263 Ok(bytes)
264}
265
266fn ensure_remaining(
267 payload: &[u8],
268 pos: usize,
269 len: usize,
270 field: &'static str,
271) -> Result<(), ParamCodecError> {
272 let end = pos
273 .checked_add(len)
274 .ok_or(ParamCodecError::Truncated(field))?;
275 if end > payload.len() {
276 return Err(ParamCodecError::Truncated(field));
277 }
278 Ok(())
279}