Skip to main content

reifydb_type/
params.rs

1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025 ReifyDB
3
4use std::{collections::HashMap, fmt, str::FromStr};
5
6use serde::{
7	Deserialize, Deserializer, Serialize, Serializer,
8	de::{self, Visitor},
9};
10
11use crate::{
12	fragment::Fragment,
13	value::{
14		Value,
15		blob::Blob,
16		boolean::parse::parse_bool,
17		decimal::parse::parse_decimal,
18		identity::IdentityId,
19		number::parse::{parse_float, parse_primitive_int, parse_primitive_uint},
20		ordered_f32::OrderedF32,
21		ordered_f64::OrderedF64,
22		temporal::parse::{
23			date::parse_date, datetime::parse_datetime, duration::parse_duration, time::parse_time,
24		},
25		r#type::Type,
26		uuid::parse::{parse_uuid4, parse_uuid7},
27	},
28};
29
30#[derive(Debug, Clone, Default, PartialEq, Eq)]
31pub enum Params {
32	#[default]
33	None,
34	Positional(Vec<Value>),
35	Named(HashMap<String, Value>),
36}
37
38impl Params {
39	pub fn get_positional(&self, index: usize) -> Option<&Value> {
40		match self {
41			Params::Positional(values) => values.get(index),
42			_ => None,
43		}
44	}
45
46	pub fn get_named(&self, name: &str) -> Option<&Value> {
47		match self {
48			Params::Named(map) => map.get(name),
49			_ => None,
50		}
51	}
52
53	pub fn empty() -> Params {
54		Params::None
55	}
56}
57
58impl From<()> for Params {
59	fn from(_: ()) -> Self {
60		Params::None
61	}
62}
63
64impl From<Vec<Value>> for Params {
65	fn from(values: Vec<Value>) -> Self {
66		Params::Positional(values)
67	}
68}
69
70impl From<HashMap<String, Value>> for Params {
71	fn from(map: HashMap<String, Value>) -> Self {
72		Params::Named(map)
73	}
74}
75
76impl<const N: usize> From<[Value; N]> for Params {
77	fn from(values: [Value; N]) -> Self {
78		Params::Positional(values.to_vec())
79	}
80}
81
82impl Serialize for Params {
83	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84	where
85		S: Serializer,
86	{
87		match self {
88			Params::None => serializer.serialize_none(),
89			Params::Positional(values) => values.serialize(serializer),
90			Params::Named(map) => map.serialize(serializer),
91		}
92	}
93}
94
95// Helper function to parse value from type/value format
96fn parse_typed_value(type_str: &str, value_val: &serde_json::Value) -> Result<Value, String> {
97	// Always expect string values for consistency
98	let str_val = value_val.as_str().ok_or_else(|| format!("expected string value for type {}", type_str))?;
99
100	// Parse the type string to Type enum
101	let value_type = match Type::from_str(type_str) {
102		Ok(Type::Option(_)) => return Ok(Value::none()),
103		Ok(t) => t,
104		Err(_) => return Ok(Value::none()),
105	};
106
107	// Use the appropriate parse function based on type
108	// If parsing fails, return Value::none()
109	let fragment = Fragment::internal(str_val);
110
111	let parsed_value = match value_type {
112		Type::Boolean => parse_bool(fragment).map(Value::Boolean).unwrap_or(Value::none()),
113		Type::Float4 => parse_float::<f32>(fragment)
114			.ok()
115			.and_then(|f| OrderedF32::try_from(f).ok())
116			.map(Value::Float4)
117			.unwrap_or(Value::none()),
118		Type::Float8 => parse_float::<f64>(fragment)
119			.ok()
120			.and_then(|f| OrderedF64::try_from(f).ok())
121			.map(Value::Float8)
122			.unwrap_or(Value::none()),
123		Type::Int1 => parse_primitive_int::<i8>(fragment).map(Value::Int1).unwrap_or(Value::none()),
124		Type::Int2 => parse_primitive_int::<i16>(fragment).map(Value::Int2).unwrap_or(Value::none()),
125		Type::Int4 => parse_primitive_int::<i32>(fragment).map(Value::Int4).unwrap_or(Value::none()),
126		Type::Int8 => parse_primitive_int::<i64>(fragment).map(Value::Int8).unwrap_or(Value::none()),
127		Type::Int16 => parse_primitive_int::<i128>(fragment).map(Value::Int16).unwrap_or(Value::none()),
128		Type::Utf8 => Value::Utf8(str_val.to_string()),
129		Type::Uint1 => parse_primitive_uint::<u8>(fragment).map(Value::Uint1).unwrap_or(Value::none()),
130		Type::Uint2 => parse_primitive_uint::<u16>(fragment).map(Value::Uint2).unwrap_or(Value::none()),
131		Type::Uint4 => parse_primitive_uint::<u32>(fragment).map(Value::Uint4).unwrap_or(Value::none()),
132		Type::Uint8 => parse_primitive_uint::<u64>(fragment).map(Value::Uint8).unwrap_or(Value::none()),
133		Type::Uint16 => parse_primitive_uint::<u128>(fragment).map(Value::Uint16).unwrap_or(Value::none()),
134		Type::Date => parse_date(fragment).map(Value::Date).unwrap_or(Value::none()),
135		Type::DateTime => parse_datetime(fragment).map(Value::DateTime).unwrap_or(Value::none()),
136		Type::Time => parse_time(fragment).map(Value::Time).unwrap_or(Value::none()),
137		Type::Duration => parse_duration(fragment).map(Value::Duration).unwrap_or(Value::none()),
138		Type::Uuid4 => parse_uuid4(fragment).map(Value::Uuid4).unwrap_or(Value::none()),
139		Type::Uuid7 => parse_uuid7(fragment).map(Value::Uuid7).unwrap_or(Value::none()),
140		Type::IdentityId => parse_uuid7(fragment)
141			.map(|uuid7| Value::IdentityId(IdentityId::from(uuid7)))
142			.unwrap_or(Value::none()),
143		Type::Blob => Blob::from_hex(fragment).map(Value::Blob).unwrap_or(Value::none()),
144		Type::Option(_) => Value::none(),
145		Type::Decimal => parse_decimal(fragment).map(Value::Decimal).unwrap_or(Value::none()),
146		Type::Int | Type::Uint => {
147			unimplemented!()
148		}
149		Type::Any => unreachable!("Any type cannot be used as parameter"),
150		Type::DictionaryId => Value::none(), // DictionaryId cannot be parsed from string parameters
151	};
152
153	Ok(parsed_value)
154}
155
156impl<'de> Deserialize<'de> for Params {
157	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
158	where
159		D: Deserializer<'de>,
160	{
161		struct ParamsVisitor;
162
163		impl<'de> Visitor<'de> for ParamsVisitor {
164			type Value = Params;
165
166			fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
167				formatter.write_str(
168					"null, an array for positional parameters, or an object for named parameters",
169				)
170			}
171
172			fn visit_none<E>(self) -> Result<Self::Value, E>
173			where
174				E: de::Error,
175			{
176				Ok(Params::None)
177			}
178
179			fn visit_unit<E>(self) -> Result<Self::Value, E>
180			where
181				E: de::Error,
182			{
183				Ok(Params::None)
184			}
185
186			fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
187			where
188				A: de::SeqAccess<'de>,
189			{
190				let mut values = Vec::new();
191
192				while let Some(value) = seq.next_element::<serde_json::Value>()? {
193					// Check if it's in the {"type": "Bool",
194					// "value": "true"} format
195					if let Some(obj) = value.as_object() {
196						if obj.contains_key("type") && obj.contains_key("value") {
197							let type_str = obj["type"].as_str().ok_or_else(|| {
198								de::Error::custom("type must be a string")
199							})?;
200							let value_val = &obj["value"];
201
202							let parsed_value = parse_typed_value(type_str, value_val)
203								.map_err(de::Error::custom)?;
204							values.push(parsed_value);
205							continue;
206						}
207					}
208
209					// Otherwise try to deserialize as a
210					// normal Value
211					let val = Value::deserialize(value).map_err(de::Error::custom)?;
212					values.push(val);
213				}
214
215				Ok(Params::Positional(values))
216			}
217
218			fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
219			where
220				A: de::MapAccess<'de>,
221			{
222				let mut result_map = HashMap::new();
223
224				while let Some(key) = map.next_key::<String>()? {
225					let value: serde_json::Value = map.next_value()?;
226
227					// Check if it's in the {"type": "Bool",
228					// "value": "true"} format
229					if let Some(obj) = value.as_object() {
230						if obj.contains_key("type") && obj.contains_key("value") {
231							let type_str = obj["type"].as_str().ok_or_else(|| {
232								de::Error::custom("type must be a string")
233							})?;
234							let value_val = &obj["value"];
235
236							let parsed_value = parse_typed_value(type_str, value_val)
237								.map_err(de::Error::custom)?;
238							result_map.insert(key, parsed_value);
239							continue;
240						}
241					}
242
243					// Otherwise try to deserialize as a
244					// normal Value
245					let val = Value::deserialize(value).map_err(de::Error::custom)?;
246					result_map.insert(key, val);
247				}
248
249				Ok(Params::Named(result_map))
250			}
251		}
252
253		deserializer.deserialize_any(ParamsVisitor)
254	}
255}
256
/// Builds a `Params` value from literal syntax.
///
/// - `params!()`, `params![]`, `params!{}` expand to `Params::None`.
/// - `params!{ name: value, "key": value }` expands to `Params::Named`.
/// - `params![value1, value2]` expands to `Params::Positional`.
///
/// Every value is converted with `IntoValue::into_value`. A trailing comma is
/// accepted in both non-empty forms.
///
/// `macro_rules!` does not match the delimiters used at the call site, so the
/// arms below are selected purely by the tokens between them.
#[macro_export]
macro_rules! params {
    // No tokens at all.
    () => {
        $crate::params::Params::None
    };

    // `key: value` pairs, where a key is an identifier or a literal.
    // Must be tried before the positional arm.
    ( $($key:tt : $value:expr),+ $(,)? ) => {{
        let mut named = ::std::collections::HashMap::new();
        $(
            named.insert($crate::params_key!($key), $crate::value::into::IntoValue::into_value($value));
        )+
        $crate::params::Params::Named(named)
    }};

    // A comma-separated list of expressions.
    ( $($value:expr),+ $(,)? ) => {
        $crate::params::Params::Positional(vec![
            $($crate::value::into::IntoValue::into_value($value)),+
        ])
    };
}
285
/// Turns a `params!` key token into the owned `String` used as the map key.
#[macro_export]
#[doc(hidden)]
macro_rules! params_key {
	// Identifier key: the identifier's own spelling becomes the name.
	($name:ident) => {
		::std::string::String::from(stringify!($name))
	};
	// Literal key: rendered through `ToString`.
	($text:literal) => {
		::std::string::ToString::to_string(&$text)
	};
}
296
/// Unit tests for the `params!` macro.
#[cfg(test)]
pub mod tests {
	use super::*;
	use crate::{params::Params, value::into::IntoValue};

	/// Positional form converts each expression with `IntoValue`.
	#[test]
	fn test_params_macro_positional() {
		let params = params![42, true, "hello"];
		match params {
			Params::Positional(values) => {
				assert_eq!(values.len(), 3);
				assert_eq!(values[0], Value::Int4(42));
				assert_eq!(values[1], Value::Boolean(true));
				assert_eq!(values[2], Value::Utf8("hello".to_string()));
			}
			_ => panic!("Expected positional params"),
		}
	}

	/// Named form with identifier keys.
	#[test]
	fn test_params_macro_named() {
		let params = params! {
		    name: true,
		    other: 42,
		    message: "test"
		};
		match params {
			Params::Named(map) => {
				assert_eq!(map.len(), 3);
				assert_eq!(map.get("name"), Some(&Value::Boolean(true)));
				assert_eq!(map.get("other"), Some(&Value::Int4(42)));
				assert_eq!(map.get("message"), Some(&Value::Utf8("test".to_string())));
			}
			_ => panic!("Expected named params"),
		}
	}

	/// Named form mixing string-literal and identifier keys.
	#[test]
	fn test_params_macro_named_with_strings() {
		let params = params! {
		    "string_key": 100,
		    ident_key: 200,
		    "another-key": "value"
		};
		match params {
			Params::Named(map) => {
				assert_eq!(map.len(), 3);
				assert_eq!(map.get("string_key"), Some(&Value::Int4(100)));
				assert_eq!(map.get("ident_key"), Some(&Value::Int4(200)));
				assert_eq!(map.get("another-key"), Some(&Value::Utf8("value".to_string())));
			}
			_ => panic!("Expected named params"),
		}
	}

	/// An empty invocation yields `Params::None` whatever delimiter is used.
	#[test]
	fn test_params_macro_empty() {
		let params = params!();
		assert_eq!(params, Params::None);

		let params2 = params! {};
		assert_eq!(params2, Params::None);

		let params3 = params![];
		assert_eq!(params3, Params::None);
	}

	/// Values that are already `Value`s can be mixed with plain literals.
	#[test]
	fn test_params_macro_with_values() {
		let v1 = Value::Int8(100);
		let v2 = 200i64.into_value();

		let params = params![v1, v2, 300];
		match params {
			Params::Positional(values) => {
				assert_eq!(values.len(), 3);
				assert_eq!(values[0], Value::Int8(100));
				assert_eq!(values[1], Value::Int8(200));
				assert_eq!(values[2], Value::Int4(300));
			}
			_ => panic!("Expected positional params"),
		}
	}

	/// A trailing comma is accepted in both the positional and the named form.
	#[test]
	fn test_params_macro_trailing_comma() {
		let params1 = params![1, 2, 3,];
		let params2 = params! { a: 1, b: 2, };

		match params1 {
			Params::Positional(values) => {
				assert_eq!(values.len(), 3)
			}
			_ => panic!("Expected positional params"),
		}

		match params2 {
			Params::Named(map) => assert_eq!(map.len(), 2),
			_ => panic!("Expected named params"),
		}
	}
}