1use std::{collections::HashMap, fmt, str::FromStr};
5
6use serde::{
7 Deserialize, Deserializer, Serialize, Serializer,
8 de::{self, Visitor},
9};
10
11use crate::{
12 fragment::Fragment,
13 value::{
14 Value,
15 blob::Blob,
16 boolean::parse::parse_bool,
17 decimal::parse::parse_decimal,
18 identity::IdentityId,
19 number::parse::{parse_float, parse_primitive_int, parse_primitive_uint},
20 ordered_f32::OrderedF32,
21 ordered_f64::OrderedF64,
22 temporal::parse::{
23 date::parse_date, datetime::parse_datetime, duration::parse_duration, time::parse_time,
24 },
25 r#type::Type,
26 uuid::parse::{parse_uuid4, parse_uuid7},
27 },
28};
29
30#[derive(Debug, Clone, Default, PartialEq, Eq)]
31pub enum Params {
32 #[default]
33 None,
34 Positional(Vec<Value>),
35 Named(HashMap<String, Value>),
36}
37
38impl Params {
39 pub fn get_positional(&self, index: usize) -> Option<&Value> {
40 match self {
41 Params::Positional(values) => values.get(index),
42 _ => None,
43 }
44 }
45
46 pub fn get_named(&self, name: &str) -> Option<&Value> {
47 match self {
48 Params::Named(map) => map.get(name),
49 _ => None,
50 }
51 }
52
53 pub fn empty() -> Params {
54 Params::None
55 }
56}
57
58impl From<()> for Params {
59 fn from(_: ()) -> Self {
60 Params::None
61 }
62}
63
64impl From<Vec<Value>> for Params {
65 fn from(values: Vec<Value>) -> Self {
66 Params::Positional(values)
67 }
68}
69
70impl From<HashMap<String, Value>> for Params {
71 fn from(map: HashMap<String, Value>) -> Self {
72 Params::Named(map)
73 }
74}
75
76impl<const N: usize> From<[Value; N]> for Params {
77 fn from(values: [Value; N]) -> Self {
78 Params::Positional(values.to_vec())
79 }
80}
81
82impl Serialize for Params {
83 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84 where
85 S: Serializer,
86 {
87 match self {
88 Params::None => serializer.serialize_none(),
89 Params::Positional(values) => values.serialize(serializer),
90 Params::Named(map) => map.serialize(serializer),
91 }
92 }
93}
94
95fn parse_typed_value(type_str: &str, value_val: &serde_json::Value) -> Result<Value, String> {
97 let str_val = value_val.as_str().ok_or_else(|| format!("expected string value for type {}", type_str))?;
99
100 let value_type = match Type::from_str(type_str) {
102 Ok(Type::Option(_)) => return Ok(Value::none()),
103 Ok(t) => t,
104 Err(_) => return Ok(Value::none()),
105 };
106
107 let fragment = Fragment::internal(str_val);
110
111 let parsed_value = match value_type {
112 Type::Boolean => parse_bool(fragment).map(Value::Boolean).unwrap_or(Value::none()),
113 Type::Float4 => parse_float::<f32>(fragment)
114 .ok()
115 .and_then(|f| OrderedF32::try_from(f).ok())
116 .map(Value::Float4)
117 .unwrap_or(Value::none()),
118 Type::Float8 => parse_float::<f64>(fragment)
119 .ok()
120 .and_then(|f| OrderedF64::try_from(f).ok())
121 .map(Value::Float8)
122 .unwrap_or(Value::none()),
123 Type::Int1 => parse_primitive_int::<i8>(fragment).map(Value::Int1).unwrap_or(Value::none()),
124 Type::Int2 => parse_primitive_int::<i16>(fragment).map(Value::Int2).unwrap_or(Value::none()),
125 Type::Int4 => parse_primitive_int::<i32>(fragment).map(Value::Int4).unwrap_or(Value::none()),
126 Type::Int8 => parse_primitive_int::<i64>(fragment).map(Value::Int8).unwrap_or(Value::none()),
127 Type::Int16 => parse_primitive_int::<i128>(fragment).map(Value::Int16).unwrap_or(Value::none()),
128 Type::Utf8 => Value::Utf8(str_val.to_string()),
129 Type::Uint1 => parse_primitive_uint::<u8>(fragment).map(Value::Uint1).unwrap_or(Value::none()),
130 Type::Uint2 => parse_primitive_uint::<u16>(fragment).map(Value::Uint2).unwrap_or(Value::none()),
131 Type::Uint4 => parse_primitive_uint::<u32>(fragment).map(Value::Uint4).unwrap_or(Value::none()),
132 Type::Uint8 => parse_primitive_uint::<u64>(fragment).map(Value::Uint8).unwrap_or(Value::none()),
133 Type::Uint16 => parse_primitive_uint::<u128>(fragment).map(Value::Uint16).unwrap_or(Value::none()),
134 Type::Date => parse_date(fragment).map(Value::Date).unwrap_or(Value::none()),
135 Type::DateTime => parse_datetime(fragment).map(Value::DateTime).unwrap_or(Value::none()),
136 Type::Time => parse_time(fragment).map(Value::Time).unwrap_or(Value::none()),
137 Type::Duration => parse_duration(fragment).map(Value::Duration).unwrap_or(Value::none()),
138 Type::Uuid4 => parse_uuid4(fragment).map(Value::Uuid4).unwrap_or(Value::none()),
139 Type::Uuid7 => parse_uuid7(fragment).map(Value::Uuid7).unwrap_or(Value::none()),
140 Type::IdentityId => parse_uuid7(fragment)
141 .map(|uuid7| Value::IdentityId(IdentityId::from(uuid7)))
142 .unwrap_or(Value::none()),
143 Type::Blob => Blob::from_hex(fragment).map(Value::Blob).unwrap_or(Value::none()),
144 Type::Option(_) => Value::none(),
145 Type::Decimal => parse_decimal(fragment).map(Value::Decimal).unwrap_or(Value::none()),
146 Type::Int | Type::Uint => {
147 unimplemented!()
148 }
149 Type::Any => unreachable!("Any type cannot be used as parameter"),
150 Type::DictionaryId => Value::none(), };
152
153 Ok(parsed_value)
154}
155
156impl<'de> Deserialize<'de> for Params {
157 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
158 where
159 D: Deserializer<'de>,
160 {
161 struct ParamsVisitor;
162
163 impl<'de> Visitor<'de> for ParamsVisitor {
164 type Value = Params;
165
166 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
167 formatter.write_str(
168 "null, an array for positional parameters, or an object for named parameters",
169 )
170 }
171
172 fn visit_none<E>(self) -> Result<Self::Value, E>
173 where
174 E: de::Error,
175 {
176 Ok(Params::None)
177 }
178
179 fn visit_unit<E>(self) -> Result<Self::Value, E>
180 where
181 E: de::Error,
182 {
183 Ok(Params::None)
184 }
185
186 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
187 where
188 A: de::SeqAccess<'de>,
189 {
190 let mut values = Vec::new();
191
192 while let Some(value) = seq.next_element::<serde_json::Value>()? {
193 if let Some(obj) = value.as_object() {
196 if obj.contains_key("type") && obj.contains_key("value") {
197 let type_str = obj["type"].as_str().ok_or_else(|| {
198 de::Error::custom("type must be a string")
199 })?;
200 let value_val = &obj["value"];
201
202 let parsed_value = parse_typed_value(type_str, value_val)
203 .map_err(de::Error::custom)?;
204 values.push(parsed_value);
205 continue;
206 }
207 }
208
209 let val = Value::deserialize(value).map_err(de::Error::custom)?;
212 values.push(val);
213 }
214
215 Ok(Params::Positional(values))
216 }
217
218 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
219 where
220 A: de::MapAccess<'de>,
221 {
222 let mut result_map = HashMap::new();
223
224 while let Some(key) = map.next_key::<String>()? {
225 let value: serde_json::Value = map.next_value()?;
226
227 if let Some(obj) = value.as_object() {
230 if obj.contains_key("type") && obj.contains_key("value") {
231 let type_str = obj["type"].as_str().ok_or_else(|| {
232 de::Error::custom("type must be a string")
233 })?;
234 let value_val = &obj["value"];
235
236 let parsed_value = parse_typed_value(type_str, value_val)
237 .map_err(de::Error::custom)?;
238 result_map.insert(key, parsed_value);
239 continue;
240 }
241 }
242
243 let val = Value::deserialize(value).map_err(de::Error::custom)?;
246 result_map.insert(key, val);
247 }
248
249 Ok(Params::Named(result_map))
250 }
251 }
252
253 deserializer.deserialize_any(ParamsVisitor)
254 }
255}
256
/// Builds a `Params` value from a literal list.
///
/// * `params! { key: value, "other-key": value }` produces `Params::Named`;
///   keys may be identifiers or literals.
/// * `params![a, b, c]` produces `Params::Positional`.
/// * An empty invocation (`params!()`, `params![]`, `params! {}`) produces
///   `Params::None`.
///
/// Every value expression is converted with `IntoValue::into_value`.
/// A trailing comma is accepted in both non-empty forms.
#[macro_export]
macro_rules! params {
    // Named form. Tried first: a positional list fails to match here because
    // its first token is not followed by `:`.
    { $($name:tt : $val:expr),+ $(,)? } => {{
        let mut named = ::std::collections::HashMap::new();
        $(
            named.insert(
                $crate::params_key!($name),
                $crate::value::into::IntoValue::into_value($val),
            );
        )+
        $crate::params::Params::Named(named)
    }};

    // Empty form, regardless of the delimiter used at the call site.
    [] => {
        $crate::params::Params::None
    };

    // Positional form.
    [ $($val:expr),+ $(,)? ] => {
        $crate::params::Params::Positional(vec![
            $($crate::value::into::IntoValue::into_value($val)),+
        ])
    };
}
285
/// Turns a `params!` key into an owned `String`.
///
/// Implementation detail of `params!`; not meant to be invoked directly.
#[macro_export]
#[doc(hidden)]
macro_rules! params_key {
    // A bare identifier becomes its source text.
    ($name:ident) => {
        ::std::string::String::from(stringify!($name))
    };
    // A literal is rendered with `ToString`.
    ($lit:literal) => {
        ::std::string::ToString::to_string(&$lit)
    };
}
296
/// Unit tests for the `params!` macro.
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{params::Params, value::into::IntoValue};

    /// A bracketed list yields positional parameters in order.
    #[test]
    fn test_params_macro_positional() {
        let params = params![42, true, "hello"];
        match params {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3);
                assert_eq!(values[0], Value::Int4(42));
                assert_eq!(values[1], Value::Boolean(true));
                assert_eq!(values[2], Value::Utf8("hello".to_string()));
            }
            _ => panic!("Expected positional params"),
        }
    }

    /// Identifier keys yield named parameters.
    #[test]
    fn test_params_macro_named() {
        let params = params! {
            name: true,
            other: 42,
            message: "test"
        };
        match params {
            Params::Named(map) => {
                assert_eq!(map.len(), 3);
                assert_eq!(map.get("name"), Some(&Value::Boolean(true)));
                assert_eq!(map.get("other"), Some(&Value::Int4(42)));
                assert_eq!(map.get("message"), Some(&Value::Utf8("test".to_string())));
            }
            _ => panic!("Expected named params"),
        }
    }

    /// String-literal keys and identifier keys can be mixed.
    #[test]
    fn test_params_macro_named_with_strings() {
        let params = params! {
            "string_key": 100,
            ident_key: 200,
            "another-key": "value"
        };
        match params {
            Params::Named(map) => {
                assert_eq!(map.len(), 3);
                assert_eq!(map.get("string_key"), Some(&Value::Int4(100)));
                assert_eq!(map.get("ident_key"), Some(&Value::Int4(200)));
                assert_eq!(map.get("another-key"), Some(&Value::Utf8("value".to_string())));
            }
            _ => panic!("Expected named params"),
        }
    }

    /// Every empty invocation form yields `Params::None`.
    #[test]
    fn test_params_macro_empty() {
        let params = params!();
        assert_eq!(params, Params::None);

        let params2 = params! {};
        assert_eq!(params2, Params::None);

        let params3 = params![];
        assert_eq!(params3, Params::None);
    }

    /// Existing `Value`s and plain literals can be mixed in one list.
    #[test]
    fn test_params_macro_with_values() {
        let v1 = Value::Int8(100);
        let v2 = 200i64.into_value();

        let params = params![v1, v2, 300];
        match params {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3);
                assert_eq!(values[0], Value::Int8(100));
                assert_eq!(values[1], Value::Int8(200));
                assert_eq!(values[2], Value::Int4(300));
            }
            _ => panic!("Expected positional params"),
        }
    }

    /// A trailing comma is accepted in both the positional and the named form.
    #[test]
    fn test_params_macro_trailing_comma() {
        let params1 = params![1, 2, 3,];
        // The named form needs its own trailing comma to exercise that path.
        let params2 = params! { a: 1, b: 2, };

        match params1 {
            Params::Positional(values) => {
                assert_eq!(values.len(), 3)
            }
            _ => panic!("Expected positional params"),
        }

        match params2 {
            Params::Named(map) => assert_eq!(map.len(), 2),
            _ => panic!("Expected named params"),
        }
    }
}