1use std::{collections::HashMap, fmt, str::FromStr};
5
6use serde::{
7 Deserialize, Deserializer, Serialize, Serializer,
8 de::{self, Visitor},
9};
10
11use crate::{
12 Blob, Fragment, OrderedF32, OrderedF64, Type, Value, parse_bool, parse_date, parse_datetime, parse_decimal,
13 parse_duration, parse_float, parse_time, parse_uuid4, parse_uuid7,
14 value::{
15 IdentityId,
16 number::{parse_primitive_int, parse_primitive_uint},
17 },
18};
19
20#[derive(Debug, Clone, Default, PartialEq, Eq)]
21pub enum Params {
22 #[default]
23 None,
24 Positional(Vec<Value>),
25 Named(HashMap<String, Value>),
26}
27
28impl Params {
29 pub fn get_positional(&self, index: usize) -> Option<&Value> {
30 match self {
31 Params::Positional(values) => values.get(index),
32 _ => None,
33 }
34 }
35
36 pub fn get_named(&self, name: &str) -> Option<&Value> {
37 match self {
38 Params::Named(map) => map.get(name),
39 _ => None,
40 }
41 }
42
43 pub fn empty() -> Params {
44 Params::None
45 }
46}
47
48impl From<()> for Params {
49 fn from(_: ()) -> Self {
50 Params::None
51 }
52}
53
54impl From<Vec<Value>> for Params {
55 fn from(values: Vec<Value>) -> Self {
56 Params::Positional(values)
57 }
58}
59
60impl From<HashMap<String, Value>> for Params {
61 fn from(map: HashMap<String, Value>) -> Self {
62 Params::Named(map)
63 }
64}
65
66impl<const N: usize> From<[Value; N]> for Params {
67 fn from(values: [Value; N]) -> Self {
68 Params::Positional(values.to_vec())
69 }
70}
71
72impl Serialize for Params {
73 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74 where
75 S: Serializer,
76 {
77 match self {
78 Params::None => serializer.serialize_none(),
79 Params::Positional(values) => values.serialize(serializer),
80 Params::Named(map) => map.serialize(serializer),
81 }
82 }
83}
84
85fn parse_typed_value(type_str: &str, value_val: &serde_json::Value) -> Result<Value, String> {
87 let str_val = value_val.as_str().ok_or_else(|| format!("expected string value for type {}", type_str))?;
89
90 let value_type = match Type::from_str(type_str) {
92 Ok(Type::Undefined) => return Ok(Value::Undefined),
93 Ok(t) => t,
94 Err(_) => return Ok(Value::Undefined),
95 };
96
97 let fragment = Fragment::internal(str_val);
100
101 let parsed_value = match value_type {
102 Type::Boolean => parse_bool(fragment).map(Value::Boolean).unwrap_or(Value::Undefined),
103 Type::Float4 => parse_float::<f32>(fragment)
104 .ok()
105 .and_then(|f| OrderedF32::try_from(f).ok())
106 .map(Value::Float4)
107 .unwrap_or(Value::Undefined),
108 Type::Float8 => parse_float::<f64>(fragment)
109 .ok()
110 .and_then(|f| OrderedF64::try_from(f).ok())
111 .map(Value::Float8)
112 .unwrap_or(Value::Undefined),
113 Type::Int1 => parse_primitive_int::<i8>(fragment).map(Value::Int1).unwrap_or(Value::Undefined),
114 Type::Int2 => parse_primitive_int::<i16>(fragment).map(Value::Int2).unwrap_or(Value::Undefined),
115 Type::Int4 => parse_primitive_int::<i32>(fragment).map(Value::Int4).unwrap_or(Value::Undefined),
116 Type::Int8 => parse_primitive_int::<i64>(fragment).map(Value::Int8).unwrap_or(Value::Undefined),
117 Type::Int16 => parse_primitive_int::<i128>(fragment).map(Value::Int16).unwrap_or(Value::Undefined),
118 Type::Utf8 => Value::Utf8(str_val.to_string()),
119 Type::Uint1 => parse_primitive_uint::<u8>(fragment).map(Value::Uint1).unwrap_or(Value::Undefined),
120 Type::Uint2 => parse_primitive_uint::<u16>(fragment).map(Value::Uint2).unwrap_or(Value::Undefined),
121 Type::Uint4 => parse_primitive_uint::<u32>(fragment).map(Value::Uint4).unwrap_or(Value::Undefined),
122 Type::Uint8 => parse_primitive_uint::<u64>(fragment).map(Value::Uint8).unwrap_or(Value::Undefined),
123 Type::Uint16 => parse_primitive_uint::<u128>(fragment).map(Value::Uint16).unwrap_or(Value::Undefined),
124 Type::Date => parse_date(fragment).map(Value::Date).unwrap_or(Value::Undefined),
125 Type::DateTime => parse_datetime(fragment).map(Value::DateTime).unwrap_or(Value::Undefined),
126 Type::Time => parse_time(fragment).map(Value::Time).unwrap_or(Value::Undefined),
127 Type::Duration => parse_duration(fragment).map(Value::Duration).unwrap_or(Value::Undefined),
128 Type::Uuid4 => parse_uuid4(fragment).map(Value::Uuid4).unwrap_or(Value::Undefined),
129 Type::Uuid7 => parse_uuid7(fragment).map(Value::Uuid7).unwrap_or(Value::Undefined),
130 Type::IdentityId => parse_uuid7(fragment)
131 .map(|uuid7| Value::IdentityId(IdentityId::from(uuid7)))
132 .unwrap_or(Value::Undefined),
133 Type::Blob => Blob::from_hex(fragment).map(Value::Blob).unwrap_or(Value::Undefined),
134 Type::Undefined => Value::Undefined,
135 Type::Decimal => parse_decimal(fragment).map(Value::Decimal).unwrap_or(Value::Undefined),
136 Type::Int | Type::Uint => {
137 unimplemented!()
138 }
139 Type::Any => unreachable!("Any type cannot be used as parameter"),
140 };
141
142 Ok(parsed_value)
143}
144
145impl<'de> Deserialize<'de> for Params {
146 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
147 where
148 D: Deserializer<'de>,
149 {
150 struct ParamsVisitor;
151
152 impl<'de> Visitor<'de> for ParamsVisitor {
153 type Value = Params;
154
155 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
156 formatter.write_str(
157 "null, an array for positional parameters, or an object for named parameters",
158 )
159 }
160
161 fn visit_none<E>(self) -> Result<Self::Value, E>
162 where
163 E: de::Error,
164 {
165 Ok(Params::None)
166 }
167
168 fn visit_unit<E>(self) -> Result<Self::Value, E>
169 where
170 E: de::Error,
171 {
172 Ok(Params::None)
173 }
174
175 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
176 where
177 A: de::SeqAccess<'de>,
178 {
179 let mut values = Vec::new();
180
181 while let Some(value) = seq.next_element::<serde_json::Value>()? {
182 if let Some(obj) = value.as_object() {
185 if obj.contains_key("type") && obj.contains_key("value") {
186 let type_str = obj["type"].as_str().ok_or_else(|| {
187 de::Error::custom("type must be a string")
188 })?;
189 let value_val = &obj["value"];
190
191 let parsed_value = parse_typed_value(type_str, value_val)
192 .map_err(de::Error::custom)?;
193 values.push(parsed_value);
194 continue;
195 }
196 }
197
198 let val = Value::deserialize(value).map_err(de::Error::custom)?;
201 values.push(val);
202 }
203
204 Ok(Params::Positional(values))
205 }
206
207 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
208 where
209 A: de::MapAccess<'de>,
210 {
211 let mut result_map = HashMap::new();
212
213 while let Some(key) = map.next_key::<String>()? {
214 let value: serde_json::Value = map.next_value()?;
215
216 if let Some(obj) = value.as_object() {
219 if obj.contains_key("type") && obj.contains_key("value") {
220 let type_str = obj["type"].as_str().ok_or_else(|| {
221 de::Error::custom("type must be a string")
222 })?;
223 let value_val = &obj["value"];
224
225 let parsed_value = parse_typed_value(type_str, value_val)
226 .map_err(de::Error::custom)?;
227 result_map.insert(key, parsed_value);
228 continue;
229 }
230 }
231
232 let val = Value::deserialize(value).map_err(de::Error::custom)?;
235 result_map.insert(key, val);
236 }
237
238 Ok(Params::Named(result_map))
239 }
240 }
241
242 deserializer.deserialize_any(ParamsVisitor)
243 }
244}
245
/// Builds a `Params` value from a literal list of values.
///
/// * `params!()`, `params![]` and `params! {}` all yield `Params::None`.
/// * `params! { key: value, ... }` yields `Params::Named`. Each key is either
///   an identifier or a literal and is turned into a `String` by `params_key!`.
/// * `params![value, ...]` yields `Params::Positional`, preserving order.
///
/// Every value is converted with `IntoValue::into_value`, and a trailing comma
/// is accepted. `macro_rules!` ignores the delimiter an invocation uses, so the
/// shape of the arguments selects the variant, not the bracket style.
/// `key: value` pairs are named and plain expressions are positional.
#[macro_export]
macro_rules! params {
    // Matches every empty invocation, whatever delimiter is used.
    () => {
        $crate::Params::None
    };

    { $($key:tt : $value:expr),+ $(,)? } => {
        {
            let mut map = ::std::collections::HashMap::new();
            $(
                map.insert($crate::params_key!($key), $crate::IntoValue::into_value($value));
            )*
            $crate::Params::Named(map)
        }
    };

    [ $($value:expr),+ $(,)? ] => {
        {
            let values = vec![
                $($crate::IntoValue::into_value($value)),*
            ];
            $crate::Params::Positional(values)
        }
    };
}
284
/// Converts a `params!` key token into an owned `String`.
///
/// An identifier is stringified, with any raw-identifier `r#` prefix removed so
/// that `r#type` yields `"type"`. A literal is converted with `to_string`.
#[macro_export]
#[doc(hidden)]
macro_rules! params_key {
    ($key:ident) => {{
        // `stringify!` keeps the `r#` of a raw identifier; the key is the bare name.
        let raw = stringify!($key);
        raw.strip_prefix("r#").unwrap_or(raw).to_string()
    }};
    ($key:literal) => {
        $key.to_string()
    };
}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IntoValue;

    #[test]
    fn test_params_macro_positional() {
        let params = params![42, true, "hello"];
        match params {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3);
                assert_eq!(values[0], Value::Int4(42));
                assert_eq!(values[1], Value::Boolean(true));
                assert_eq!(values[2], Value::Utf8("hello".to_string()));
            }
            _ => panic!("Expected positional params"),
        }
    }

    #[test]
    fn test_params_macro_named() {
        let params = params! {
            name: true,
            other: 42,
            message: "test"
        };
        match params {
            Params::Named(map) => {
                assert_eq!(map.len(), 3);
                assert_eq!(map.get("name"), Some(&Value::Boolean(true)));
                assert_eq!(map.get("other"), Some(&Value::Int4(42)));
                assert_eq!(map.get("message"), Some(&Value::Utf8("test".to_string())));
            }
            _ => panic!("Expected named params"),
        }
    }

    #[test]
    fn test_params_macro_named_with_strings() {
        let params = params! {
            "string_key": 100,
            ident_key: 200,
            "another-key": "value"
        };
        match params {
            Params::Named(map) => {
                assert_eq!(map.len(), 3);
                assert_eq!(map.get("string_key"), Some(&Value::Int4(100)));
                assert_eq!(map.get("ident_key"), Some(&Value::Int4(200)));
                assert_eq!(map.get("another-key"), Some(&Value::Utf8("value".to_string())));
            }
            _ => panic!("Expected named params"),
        }
    }

    #[test]
    fn test_params_macro_empty() {
        let params = params!();
        assert_eq!(params, Params::None);

        let params2 = params! {};
        assert_eq!(params2, Params::None);

        let params3 = params![];
        assert_eq!(params3, Params::None);
    }

    #[test]
    fn test_params_macro_with_values() {
        let v1 = Value::Int8(100);
        let v2 = 200i64.into_value();

        let params = params![v1, v2, 300];
        match params {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3);
                assert_eq!(values[0], Value::Int8(100));
                assert_eq!(values[1], Value::Int8(200));
                assert_eq!(values[2], Value::Int4(300));
            }
            _ => panic!("Expected positional params"),
        }
    }

    #[test]
    fn test_params_macro_trailing_comma() {
        // Exercise the trailing comma on both the positional and the named form.
        let params1 = params![1, 2, 3,];
        let params2 = params! { a: 1, b: 2, };

        match params1 {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3)
            }
            _ => panic!("Expected positional params"),
        }

        match params2 {
            Params::Named(map) => assert_eq!(map.len(), 2),
            _ => panic!("Expected named params"),
        }
    }

    #[test]
    fn test_params_deserialize_container_shapes() {
        // null maps to None; an empty array or object keeps its variant.
        assert_eq!(serde_json::from_str::<Params>("null").unwrap(), Params::None);
        assert_eq!(
            serde_json::from_str::<Params>("[]").unwrap(),
            Params::Positional(Vec::new())
        );
        assert_eq!(
            serde_json::from_str::<Params>("{}").unwrap(),
            Params::Named(std::collections::HashMap::new())
        );
    }
}