1use std::{collections::HashMap, fmt, str::FromStr};
5
6use serde::{
7 Deserialize, Deserializer, Serialize, Serializer,
8 de::{self, Visitor},
9};
10
11use crate::{
12 Blob, BorrowedFragment, OrderedF32, OrderedF64, RowNumber, Type, Value, parse_bool, parse_date, parse_datetime,
13 parse_decimal, parse_duration, parse_float, parse_time, parse_uuid4, parse_uuid7,
14 value::{
15 IdentityId,
16 number::{parse_primitive_int, parse_primitive_uint},
17 },
18};
19
20#[derive(Debug, Clone, Default, PartialEq, Eq)]
21pub enum Params {
22 #[default]
23 None,
24 Positional(Vec<Value>),
25 Named(HashMap<String, Value>),
26}
27
28impl Params {
29 pub fn get_positional(&self, index: usize) -> Option<&Value> {
30 match self {
31 Params::Positional(values) => values.get(index),
32 _ => None,
33 }
34 }
35
36 pub fn get_named(&self, name: &str) -> Option<&Value> {
37 match self {
38 Params::Named(map) => map.get(name),
39 _ => None,
40 }
41 }
42
43 pub fn empty() -> Params {
44 Params::None
45 }
46}
47
48impl From<()> for Params {
49 fn from(_: ()) -> Self {
50 Params::None
51 }
52}
53
54impl From<Vec<Value>> for Params {
55 fn from(values: Vec<Value>) -> Self {
56 Params::Positional(values)
57 }
58}
59
60impl From<HashMap<String, Value>> for Params {
61 fn from(map: HashMap<String, Value>) -> Self {
62 Params::Named(map)
63 }
64}
65
66impl<const N: usize> From<[Value; N]> for Params {
67 fn from(values: [Value; N]) -> Self {
68 Params::Positional(values.to_vec())
69 }
70}
71
72impl Serialize for Params {
73 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74 where
75 S: Serializer,
76 {
77 match self {
78 Params::None => serializer.serialize_none(),
79 Params::Positional(values) => values.serialize(serializer),
80 Params::Named(map) => map.serialize(serializer),
81 }
82 }
83}
84
85fn parse_typed_value(type_str: &str, value_val: &serde_json::Value) -> Result<Value, String> {
87 let str_val = value_val.as_str().ok_or_else(|| format!("expected string value for type {}", type_str))?;
89
90 let value_type = match Type::from_str(type_str) {
92 Ok(Type::Undefined) => return Ok(Value::Undefined),
93 Ok(t) => t,
94 Err(_) => return Ok(Value::Undefined),
95 };
96
97 let fragment = BorrowedFragment::new_internal(str_val);
100
101 let parsed_value = match value_type {
102 Type::Boolean => parse_bool(fragment).map(Value::Boolean).unwrap_or(Value::Undefined),
103 Type::Float4 => parse_float::<f32>(fragment)
104 .ok()
105 .and_then(|f| OrderedF32::try_from(f).ok())
106 .map(Value::Float4)
107 .unwrap_or(Value::Undefined),
108 Type::Float8 => parse_float::<f64>(fragment)
109 .ok()
110 .and_then(|f| OrderedF64::try_from(f).ok())
111 .map(Value::Float8)
112 .unwrap_or(Value::Undefined),
113 Type::Int1 => parse_primitive_int::<i8>(fragment).map(Value::Int1).unwrap_or(Value::Undefined),
114 Type::Int2 => parse_primitive_int::<i16>(fragment).map(Value::Int2).unwrap_or(Value::Undefined),
115 Type::Int4 => parse_primitive_int::<i32>(fragment).map(Value::Int4).unwrap_or(Value::Undefined),
116 Type::Int8 => parse_primitive_int::<i64>(fragment).map(Value::Int8).unwrap_or(Value::Undefined),
117 Type::Int16 => parse_primitive_int::<i128>(fragment).map(Value::Int16).unwrap_or(Value::Undefined),
118 Type::Utf8 => Value::Utf8(str_val.to_string()),
119 Type::Uint1 => parse_primitive_uint::<u8>(fragment).map(Value::Uint1).unwrap_or(Value::Undefined),
120 Type::Uint2 => parse_primitive_uint::<u16>(fragment).map(Value::Uint2).unwrap_or(Value::Undefined),
121 Type::Uint4 => parse_primitive_uint::<u32>(fragment).map(Value::Uint4).unwrap_or(Value::Undefined),
122 Type::Uint8 => parse_primitive_uint::<u64>(fragment).map(Value::Uint8).unwrap_or(Value::Undefined),
123 Type::Uint16 => parse_primitive_uint::<u128>(fragment).map(Value::Uint16).unwrap_or(Value::Undefined),
124 Type::Date => parse_date(fragment).map(Value::Date).unwrap_or(Value::Undefined),
125 Type::DateTime => parse_datetime(fragment).map(Value::DateTime).unwrap_or(Value::Undefined),
126 Type::Time => parse_time(fragment).map(Value::Time).unwrap_or(Value::Undefined),
127 Type::Duration => parse_duration(fragment).map(Value::Duration).unwrap_or(Value::Undefined),
128 Type::RowNumber => parse_primitive_uint::<u64>(fragment)
129 .map(|id| Value::RowNumber(RowNumber::from(id)))
130 .unwrap_or(Value::Undefined),
131 Type::Uuid4 => parse_uuid4(fragment).map(Value::Uuid4).unwrap_or(Value::Undefined),
132 Type::Uuid7 => parse_uuid7(fragment).map(Value::Uuid7).unwrap_or(Value::Undefined),
133 Type::IdentityId => parse_uuid7(fragment)
134 .map(|uuid7| Value::IdentityId(IdentityId::from(uuid7)))
135 .unwrap_or(Value::Undefined),
136 Type::Blob => Blob::from_hex(fragment).map(Value::Blob).unwrap_or(Value::Undefined),
137 Type::Undefined => Value::Undefined,
138 Type::Decimal => parse_decimal(fragment).map(Value::Decimal).unwrap_or(Value::Undefined),
139 Type::Int | Type::Uint => {
140 unimplemented!()
141 }
142 Type::Any => unreachable!("Any type cannot be used as parameter"),
143 };
144
145 Ok(parsed_value)
146}
147
148impl<'de> Deserialize<'de> for Params {
149 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
150 where
151 D: Deserializer<'de>,
152 {
153 struct ParamsVisitor;
154
155 impl<'de> Visitor<'de> for ParamsVisitor {
156 type Value = Params;
157
158 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
159 formatter.write_str(
160 "null, an array for positional parameters, or an object for named parameters",
161 )
162 }
163
164 fn visit_none<E>(self) -> Result<Self::Value, E>
165 where
166 E: de::Error,
167 {
168 Ok(Params::None)
169 }
170
171 fn visit_unit<E>(self) -> Result<Self::Value, E>
172 where
173 E: de::Error,
174 {
175 Ok(Params::None)
176 }
177
178 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
179 where
180 A: de::SeqAccess<'de>,
181 {
182 let mut values = Vec::new();
183
184 while let Some(value) = seq.next_element::<serde_json::Value>()? {
185 if let Some(obj) = value.as_object() {
188 if obj.contains_key("type") && obj.contains_key("value") {
189 let type_str = obj["type"].as_str().ok_or_else(|| {
190 de::Error::custom("type must be a string")
191 })?;
192 let value_val = &obj["value"];
193
194 let parsed_value = parse_typed_value(type_str, value_val)
195 .map_err(de::Error::custom)?;
196 values.push(parsed_value);
197 continue;
198 }
199 }
200
201 let val = Value::deserialize(value).map_err(de::Error::custom)?;
204 values.push(val);
205 }
206
207 Ok(Params::Positional(values))
208 }
209
210 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
211 where
212 A: de::MapAccess<'de>,
213 {
214 let mut result_map = HashMap::new();
215
216 while let Some(key) = map.next_key::<String>()? {
217 let value: serde_json::Value = map.next_value()?;
218
219 if let Some(obj) = value.as_object() {
222 if obj.contains_key("type") && obj.contains_key("value") {
223 let type_str = obj["type"].as_str().ok_or_else(|| {
224 de::Error::custom("type must be a string")
225 })?;
226 let value_val = &obj["value"];
227
228 let parsed_value = parse_typed_value(type_str, value_val)
229 .map_err(de::Error::custom)?;
230 result_map.insert(key, parsed_value);
231 continue;
232 }
233 }
234
235 let val = Value::deserialize(value).map_err(de::Error::custom)?;
238 result_map.insert(key, val);
239 }
240
241 Ok(Params::Named(result_map))
242 }
243 }
244
245 deserializer.deserialize_any(ParamsVisitor)
246 }
247}
248
/// Builds a `Params` value inline.
///
/// Forms:
/// - `params!()`, `params![]` or `params! {}` yields `Params::None`.
/// - `params![a, b, ...]` yields `Params::Positional`. Each expression is
///   converted with `IntoValue::into_value`.
/// - `params! { key: value, ... }` yields `Params::Named`. A key is either an
///   identifier, used verbatim as the name, or a literal such as `"my-key"`.
///
/// A trailing comma is accepted in both non-empty forms. The bracket style of
/// the invocation does not select the form; the contents do.
#[macro_export]
macro_rules! params {
    // The outer delimiter of an invocation is never part of the match, so this
    // one arm covers `params!()`, `params![]` and `params! {}` alike.
    () => {
        $crate::Params::None
    };

    { $($key:tt : $value:expr),+ $(,)? } => {{
        let mut map = ::std::collections::HashMap::new();
        $(
            map.insert($crate::params_key!($key), $crate::IntoValue::into_value($value));
        )*
        $crate::Params::Named(map)
    }};

    [ $($value:expr),+ $(,)? ] => {{
        let values = ::std::vec![$($crate::IntoValue::into_value($value)),*];
        $crate::Params::Positional(values)
    }};
}
287
/// Turns a `params!` key token into an owned `String`.
///
/// - An identifier key is spelled out as written.
/// - A literal key is rendered through its `Display` implementation.
#[macro_export]
#[doc(hidden)]
macro_rules! params_key {
    ($key:ident) => {
        ::std::string::String::from(::std::stringify!($key))
    };
    ($key:literal) => {
        ::std::string::ToString::to_string(&$key)
    };
}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IntoValue;

    #[test]
    fn test_params_macro_positional() {
        let params = params![42, true, "hello"];
        match params {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3);
                assert_eq!(values[0], Value::Int4(42));
                assert_eq!(values[1], Value::Boolean(true));
                assert_eq!(values[2], Value::Utf8("hello".to_string()));
            }
            _ => panic!("Expected positional params"),
        }
    }

    #[test]
    fn test_params_macro_named() {
        let params = params! {
            name: true,
            other: 42,
            message: "test"
        };
        match params {
            Params::Named(map) => {
                assert_eq!(map.len(), 3);
                assert_eq!(map.get("name"), Some(&Value::Boolean(true)));
                assert_eq!(map.get("other"), Some(&Value::Int4(42)));
                assert_eq!(map.get("message"), Some(&Value::Utf8("test".to_string())));
            }
            _ => panic!("Expected named params"),
        }
    }

    #[test]
    fn test_params_macro_named_with_strings() {
        let params = params! {
            "string_key": 100,
            ident_key: 200,
            "another-key": "value"
        };
        match params {
            Params::Named(map) => {
                assert_eq!(map.len(), 3);
                assert_eq!(map.get("string_key"), Some(&Value::Int4(100)));
                assert_eq!(map.get("ident_key"), Some(&Value::Int4(200)));
                assert_eq!(map.get("another-key"), Some(&Value::Utf8("value".to_string())));
            }
            _ => panic!("Expected named params"),
        }
    }

    #[test]
    fn test_params_macro_empty() {
        let params = params!();
        assert_eq!(params, Params::None);

        let params2 = params! {};
        assert_eq!(params2, Params::None);

        let params3 = params![];
        assert_eq!(params3, Params::None);
    }

    #[test]
    fn test_params_macro_with_values() {
        let v1 = Value::Int8(100);
        let v2 = 200i64.into_value();

        let params = params![v1, v2, 300];
        match params {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3);
                assert_eq!(values[0], Value::Int8(100));
                assert_eq!(values[1], Value::Int8(200));
                assert_eq!(values[2], Value::Int4(300));
            }
            _ => panic!("Expected positional params"),
        }
    }

    #[test]
    fn test_params_macro_trailing_comma() {
        let params1 = params![1, 2, 3,];
        // The named form must also accept a trailing comma.
        let params2 = params! { a: 1, b: 2, };

        match params1 {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3)
            }
            _ => panic!("Expected positional params"),
        }

        match params2 {
            Params::Named(map) => assert_eq!(map.len(), 2),
            _ => panic!("Expected named params"),
        }
    }

    #[test]
    fn test_params_accessors() {
        let positional = Params::from(vec![Value::Int4(1), Value::Boolean(false)]);
        assert_eq!(positional.get_positional(0), Some(&Value::Int4(1)));
        assert_eq!(positional.get_positional(1), Some(&Value::Boolean(false)));
        assert_eq!(positional.get_positional(2), None);
        assert_eq!(positional.get_named("a"), None);

        let named = params! { a: true };
        assert_eq!(named.get_named("a"), Some(&Value::Boolean(true)));
        assert_eq!(named.get_named("b"), None);
        assert_eq!(named.get_positional(0), None);

        assert_eq!(Params::empty(), Params::None);
        assert_eq!(Params::from(()), Params::None);
        assert_eq!(Params::None.get_positional(0), None);
        assert_eq!(Params::None.get_named("a"), None);
    }

    #[test]
    fn test_params_from_array() {
        let params = Params::from([Value::Int4(7), Value::Boolean(true)]);
        assert_eq!(
            params,
            Params::Positional(vec![Value::Int4(7), Value::Boolean(true)])
        );
    }

    #[test]
    fn test_params_deserialize_containers() {
        assert_eq!(serde_json::from_str::<Params>("null").unwrap(), Params::None);
        assert_eq!(
            serde_json::from_str::<Params>("[]").unwrap(),
            Params::Positional(vec![])
        );
        assert_eq!(
            serde_json::from_str::<Params>("{}").unwrap(),
            Params::Named(HashMap::new())
        );
    }
}