reifydb_type/
params.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the MIT, see license.md file
3
4use std::{collections::HashMap, fmt, str::FromStr};
5
6use serde::{
7	Deserialize, Deserializer, Serialize, Serializer,
8	de::{self, Visitor},
9};
10
11use crate::{
12	Blob, Fragment, OrderedF32, OrderedF64, Type, Value, parse_bool, parse_date, parse_datetime, parse_decimal,
13	parse_duration, parse_float, parse_time, parse_uuid4, parse_uuid7,
14	value::{
15		IdentityId,
16		number::{parse_primitive_int, parse_primitive_uint},
17	},
18};
19
20#[derive(Debug, Clone, Default, PartialEq, Eq)]
21pub enum Params {
22	#[default]
23	None,
24	Positional(Vec<Value>),
25	Named(HashMap<String, Value>),
26}
27
28impl Params {
29	pub fn get_positional(&self, index: usize) -> Option<&Value> {
30		match self {
31			Params::Positional(values) => values.get(index),
32			_ => None,
33		}
34	}
35
36	pub fn get_named(&self, name: &str) -> Option<&Value> {
37		match self {
38			Params::Named(map) => map.get(name),
39			_ => None,
40		}
41	}
42
43	pub fn empty() -> Params {
44		Params::None
45	}
46}
47
48impl From<()> for Params {
49	fn from(_: ()) -> Self {
50		Params::None
51	}
52}
53
54impl From<Vec<Value>> for Params {
55	fn from(values: Vec<Value>) -> Self {
56		Params::Positional(values)
57	}
58}
59
60impl From<HashMap<String, Value>> for Params {
61	fn from(map: HashMap<String, Value>) -> Self {
62		Params::Named(map)
63	}
64}
65
66impl<const N: usize> From<[Value; N]> for Params {
67	fn from(values: [Value; N]) -> Self {
68		Params::Positional(values.to_vec())
69	}
70}
71
72impl Serialize for Params {
73	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74	where
75		S: Serializer,
76	{
77		match self {
78			Params::None => serializer.serialize_none(),
79			Params::Positional(values) => values.serialize(serializer),
80			Params::Named(map) => map.serialize(serializer),
81		}
82	}
83}
84
85// Helper function to parse value from type/value format
86fn parse_typed_value(type_str: &str, value_val: &serde_json::Value) -> Result<Value, String> {
87	// Always expect string values for consistency
88	let str_val = value_val.as_str().ok_or_else(|| format!("expected string value for type {}", type_str))?;
89
90	// Parse the type string to Type enum
91	let value_type = match Type::from_str(type_str) {
92		Ok(Type::Undefined) => return Ok(Value::Undefined),
93		Ok(t) => t,
94		Err(_) => return Ok(Value::Undefined),
95	};
96
97	// Use the appropriate parse function based on type
98	// If parsing fails, return Value::Undefined
99	let fragment = Fragment::internal(str_val);
100
101	let parsed_value = match value_type {
102		Type::Boolean => parse_bool(fragment).map(Value::Boolean).unwrap_or(Value::Undefined),
103		Type::Float4 => parse_float::<f32>(fragment)
104			.ok()
105			.and_then(|f| OrderedF32::try_from(f).ok())
106			.map(Value::Float4)
107			.unwrap_or(Value::Undefined),
108		Type::Float8 => parse_float::<f64>(fragment)
109			.ok()
110			.and_then(|f| OrderedF64::try_from(f).ok())
111			.map(Value::Float8)
112			.unwrap_or(Value::Undefined),
113		Type::Int1 => parse_primitive_int::<i8>(fragment).map(Value::Int1).unwrap_or(Value::Undefined),
114		Type::Int2 => parse_primitive_int::<i16>(fragment).map(Value::Int2).unwrap_or(Value::Undefined),
115		Type::Int4 => parse_primitive_int::<i32>(fragment).map(Value::Int4).unwrap_or(Value::Undefined),
116		Type::Int8 => parse_primitive_int::<i64>(fragment).map(Value::Int8).unwrap_or(Value::Undefined),
117		Type::Int16 => parse_primitive_int::<i128>(fragment).map(Value::Int16).unwrap_or(Value::Undefined),
118		Type::Utf8 => Value::Utf8(str_val.to_string()),
119		Type::Uint1 => parse_primitive_uint::<u8>(fragment).map(Value::Uint1).unwrap_or(Value::Undefined),
120		Type::Uint2 => parse_primitive_uint::<u16>(fragment).map(Value::Uint2).unwrap_or(Value::Undefined),
121		Type::Uint4 => parse_primitive_uint::<u32>(fragment).map(Value::Uint4).unwrap_or(Value::Undefined),
122		Type::Uint8 => parse_primitive_uint::<u64>(fragment).map(Value::Uint8).unwrap_or(Value::Undefined),
123		Type::Uint16 => parse_primitive_uint::<u128>(fragment).map(Value::Uint16).unwrap_or(Value::Undefined),
124		Type::Date => parse_date(fragment).map(Value::Date).unwrap_or(Value::Undefined),
125		Type::DateTime => parse_datetime(fragment).map(Value::DateTime).unwrap_or(Value::Undefined),
126		Type::Time => parse_time(fragment).map(Value::Time).unwrap_or(Value::Undefined),
127		Type::Duration => parse_duration(fragment).map(Value::Duration).unwrap_or(Value::Undefined),
128		Type::Uuid4 => parse_uuid4(fragment).map(Value::Uuid4).unwrap_or(Value::Undefined),
129		Type::Uuid7 => parse_uuid7(fragment).map(Value::Uuid7).unwrap_or(Value::Undefined),
130		Type::IdentityId => parse_uuid7(fragment)
131			.map(|uuid7| Value::IdentityId(IdentityId::from(uuid7)))
132			.unwrap_or(Value::Undefined),
133		Type::Blob => Blob::from_hex(fragment).map(Value::Blob).unwrap_or(Value::Undefined),
134		Type::Undefined => Value::Undefined,
135		Type::Decimal => parse_decimal(fragment).map(Value::Decimal).unwrap_or(Value::Undefined),
136		Type::Int | Type::Uint => {
137			unimplemented!()
138		}
139		Type::Any => unreachable!("Any type cannot be used as parameter"),
140	};
141
142	Ok(parsed_value)
143}
144
145impl<'de> Deserialize<'de> for Params {
146	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
147	where
148		D: Deserializer<'de>,
149	{
150		struct ParamsVisitor;
151
152		impl<'de> Visitor<'de> for ParamsVisitor {
153			type Value = Params;
154
155			fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
156				formatter.write_str(
157					"null, an array for positional parameters, or an object for named parameters",
158				)
159			}
160
161			fn visit_none<E>(self) -> Result<Self::Value, E>
162			where
163				E: de::Error,
164			{
165				Ok(Params::None)
166			}
167
168			fn visit_unit<E>(self) -> Result<Self::Value, E>
169			where
170				E: de::Error,
171			{
172				Ok(Params::None)
173			}
174
175			fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
176			where
177				A: de::SeqAccess<'de>,
178			{
179				let mut values = Vec::new();
180
181				while let Some(value) = seq.next_element::<serde_json::Value>()? {
182					// Check if it's in the {"type": "Bool",
183					// "value": "true"} format
184					if let Some(obj) = value.as_object() {
185						if obj.contains_key("type") && obj.contains_key("value") {
186							let type_str = obj["type"].as_str().ok_or_else(|| {
187								de::Error::custom("type must be a string")
188							})?;
189							let value_val = &obj["value"];
190
191							let parsed_value = parse_typed_value(type_str, value_val)
192								.map_err(de::Error::custom)?;
193							values.push(parsed_value);
194							continue;
195						}
196					}
197
198					// Otherwise try to deserialize as a
199					// normal Value
200					let val = Value::deserialize(value).map_err(de::Error::custom)?;
201					values.push(val);
202				}
203
204				Ok(Params::Positional(values))
205			}
206
207			fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
208			where
209				A: de::MapAccess<'de>,
210			{
211				let mut result_map = HashMap::new();
212
213				while let Some(key) = map.next_key::<String>()? {
214					let value: serde_json::Value = map.next_value()?;
215
216					// Check if it's in the {"type": "Bool",
217					// "value": "true"} format
218					if let Some(obj) = value.as_object() {
219						if obj.contains_key("type") && obj.contains_key("value") {
220							let type_str = obj["type"].as_str().ok_or_else(|| {
221								de::Error::custom("type must be a string")
222							})?;
223							let value_val = &obj["value"];
224
225							let parsed_value = parse_typed_value(type_str, value_val)
226								.map_err(de::Error::custom)?;
227							result_map.insert(key, parsed_value);
228							continue;
229						}
230					}
231
232					// Otherwise try to deserialize as a
233					// normal Value
234					let val = Value::deserialize(value).map_err(de::Error::custom)?;
235					result_map.insert(key, val);
236				}
237
238				Ok(Params::Named(result_map))
239			}
240		}
241
242		deserializer.deserialize_any(ParamsVisitor)
243	}
244}
245
/// Builds a `Params` value.
///
/// - `params!()`, `params!{}` or `params![]` produce `Params::None`.
/// - `params!{ name: value, "key": value }` produces `Params::Named`.
/// - `params![value1, value2]` produces `Params::Positional`.
///
/// Every value is converted with `IntoValue::into_value`. A trailing comma is
/// accepted in both non-empty forms.
#[macro_export]
macro_rules! params {
	// No arguments. `macro_rules!` ignores which outer delimiter an invocation
	// uses, so this single arm covers `params!()`, `params!{}` and `params![]`.
	() => {
		$crate::Params::None
	};

	// Named parameters with mixed keys: params!{ name: value, "key": value }
	( $($name:tt : $val:expr),+ $(,)? ) => {
		$crate::Params::Named(::std::collections::HashMap::from([
			$( ($crate::params_key!($name), $crate::IntoValue::into_value($val)) ),+
		]))
	};

	// Positional parameters: params![value1, value2, ...]
	( $($val:expr),+ $(,)? ) => {
		$crate::Params::Positional(vec![
			$( $crate::IntoValue::into_value($val) ),+
		])
	};
}
284
/// Turns a `params!` key token into an owned `String`: an identifier is
/// stringified, a literal is converted with `ToString`.
#[macro_export]
#[doc(hidden)]
macro_rules! params_key {
	($name:ident) => {
		::std::string::String::from(stringify!($name))
	};
	($name:literal) => {
		::std::string::ToString::to_string(&$name)
	};
}
295
#[cfg(test)]
mod tests {
	use super::*;
	use crate::IntoValue;

	/// Positional form converts each argument and keeps the order.
	#[test]
	fn test_params_macro_positional() {
		let params = params![42, true, "hello"];
		assert_eq!(
			params,
			Params::Positional(vec![Value::Int4(42), Value::Boolean(true), Value::Utf8("hello".to_string())])
		);
	}

	/// Identifier keys are stringified.
	#[test]
	fn test_params_macro_named() {
		let params = params! {
			name: true,
			other: 42,
			message: "test"
		};
		let expected = HashMap::from([
			("name".to_string(), Value::Boolean(true)),
			("other".to_string(), Value::Int4(42)),
			("message".to_string(), Value::Utf8("test".to_string())),
		]);
		assert_eq!(params, Params::Named(expected));
	}

	/// String-literal keys and identifier keys can be mixed.
	#[test]
	fn test_params_macro_named_with_strings() {
		let params = params! {
			"string_key": 100,
			ident_key: 200,
			"another-key": "value"
		};
		let expected = HashMap::from([
			("string_key".to_string(), Value::Int4(100)),
			("ident_key".to_string(), Value::Int4(200)),
			("another-key".to_string(), Value::Utf8("value".to_string())),
		]);
		assert_eq!(params, Params::Named(expected));
	}

	/// Every empty invocation form yields `Params::None`.
	#[test]
	fn test_params_macro_empty() {
		let params = params!();
		assert_eq!(params, Params::None);

		let params2 = params! {};
		assert_eq!(params2, Params::None);

		let params3 = params![];
		assert_eq!(params3, Params::None);
	}

	/// Already-built `Value`s pass through next to plain literals.
	#[test]
	fn test_params_macro_with_values() {
		let v1 = Value::Int8(100);
		let v2 = 200i64.into_value();

		let params = params![v1, v2, 300];
		assert_eq!(params, Params::Positional(vec![Value::Int8(100), Value::Int8(200), Value::Int4(300)]));
	}

	/// A trailing comma is accepted in both the positional and the named form.
	#[test]
	fn test_params_macro_trailing_comma() {
		let params1 = params![1, 2, 3,];
		// The trailing comma after `b: 2` is the point of this case.
		let params2 = params! { a: 1, b: 2, };

		match params1 {
			Params::Positional(values) => assert_eq!(values.len(), 3),
			_ => panic!("Expected positional params"),
		}

		match params2 {
			Params::Named(map) => assert_eq!(map.len(), 2),
			_ => panic!("Expected named params"),
		}
	}
}
397}