reifydb_type/
params.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the MIT, see license.md file
3
4use std::{collections::HashMap, fmt, str::FromStr};
5
6use serde::{
7	Deserialize, Deserializer, Serialize, Serializer,
8	de::{self, Visitor},
9};
10
11use crate::{
12	Blob, BorrowedFragment, OrderedF32, OrderedF64, RowNumber, Type, Value, parse_bool, parse_date, parse_datetime,
13	parse_decimal, parse_duration, parse_float, parse_time, parse_uuid4, parse_uuid7,
14	value::{
15		IdentityId,
16		number::{parse_primitive_int, parse_primitive_uint},
17	},
18};
19
20#[derive(Debug, Clone, Default, PartialEq, Eq)]
21pub enum Params {
22	#[default]
23	None,
24	Positional(Vec<Value>),
25	Named(HashMap<String, Value>),
26}
27
28impl Params {
29	pub fn get_positional(&self, index: usize) -> Option<&Value> {
30		match self {
31			Params::Positional(values) => values.get(index),
32			_ => None,
33		}
34	}
35
36	pub fn get_named(&self, name: &str) -> Option<&Value> {
37		match self {
38			Params::Named(map) => map.get(name),
39			_ => None,
40		}
41	}
42
43	pub fn empty() -> Params {
44		Params::None
45	}
46}
47
48impl From<()> for Params {
49	fn from(_: ()) -> Self {
50		Params::None
51	}
52}
53
54impl From<Vec<Value>> for Params {
55	fn from(values: Vec<Value>) -> Self {
56		Params::Positional(values)
57	}
58}
59
60impl From<HashMap<String, Value>> for Params {
61	fn from(map: HashMap<String, Value>) -> Self {
62		Params::Named(map)
63	}
64}
65
66impl<const N: usize> From<[Value; N]> for Params {
67	fn from(values: [Value; N]) -> Self {
68		Params::Positional(values.to_vec())
69	}
70}
71
72impl Serialize for Params {
73	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74	where
75		S: Serializer,
76	{
77		match self {
78			Params::None => serializer.serialize_none(),
79			Params::Positional(values) => values.serialize(serializer),
80			Params::Named(map) => map.serialize(serializer),
81		}
82	}
83}
84
85// Helper function to parse value from type/value format
86fn parse_typed_value(type_str: &str, value_val: &serde_json::Value) -> Result<Value, String> {
87	// Always expect string values for consistency
88	let str_val = value_val.as_str().ok_or_else(|| format!("expected string value for type {}", type_str))?;
89
90	// Parse the type string to Type enum
91	let value_type = match Type::from_str(type_str) {
92		Ok(Type::Undefined) => return Ok(Value::Undefined),
93		Ok(t) => t,
94		Err(_) => return Ok(Value::Undefined),
95	};
96
97	// Use the appropriate parse function based on type
98	// If parsing fails, return Value::Undefined
99	let fragment = BorrowedFragment::new_internal(str_val);
100
101	let parsed_value = match value_type {
102		Type::Boolean => parse_bool(fragment).map(Value::Boolean).unwrap_or(Value::Undefined),
103		Type::Float4 => parse_float::<f32>(fragment)
104			.ok()
105			.and_then(|f| OrderedF32::try_from(f).ok())
106			.map(Value::Float4)
107			.unwrap_or(Value::Undefined),
108		Type::Float8 => parse_float::<f64>(fragment)
109			.ok()
110			.and_then(|f| OrderedF64::try_from(f).ok())
111			.map(Value::Float8)
112			.unwrap_or(Value::Undefined),
113		Type::Int1 => parse_primitive_int::<i8>(fragment).map(Value::Int1).unwrap_or(Value::Undefined),
114		Type::Int2 => parse_primitive_int::<i16>(fragment).map(Value::Int2).unwrap_or(Value::Undefined),
115		Type::Int4 => parse_primitive_int::<i32>(fragment).map(Value::Int4).unwrap_or(Value::Undefined),
116		Type::Int8 => parse_primitive_int::<i64>(fragment).map(Value::Int8).unwrap_or(Value::Undefined),
117		Type::Int16 => parse_primitive_int::<i128>(fragment).map(Value::Int16).unwrap_or(Value::Undefined),
118		Type::Utf8 => Value::Utf8(str_val.to_string()),
119		Type::Uint1 => parse_primitive_uint::<u8>(fragment).map(Value::Uint1).unwrap_or(Value::Undefined),
120		Type::Uint2 => parse_primitive_uint::<u16>(fragment).map(Value::Uint2).unwrap_or(Value::Undefined),
121		Type::Uint4 => parse_primitive_uint::<u32>(fragment).map(Value::Uint4).unwrap_or(Value::Undefined),
122		Type::Uint8 => parse_primitive_uint::<u64>(fragment).map(Value::Uint8).unwrap_or(Value::Undefined),
123		Type::Uint16 => parse_primitive_uint::<u128>(fragment).map(Value::Uint16).unwrap_or(Value::Undefined),
124		Type::Date => parse_date(fragment).map(Value::Date).unwrap_or(Value::Undefined),
125		Type::DateTime => parse_datetime(fragment).map(Value::DateTime).unwrap_or(Value::Undefined),
126		Type::Time => parse_time(fragment).map(Value::Time).unwrap_or(Value::Undefined),
127		Type::Duration => parse_duration(fragment).map(Value::Duration).unwrap_or(Value::Undefined),
128		Type::RowNumber => parse_primitive_uint::<u64>(fragment)
129			.map(|id| Value::RowNumber(RowNumber::from(id)))
130			.unwrap_or(Value::Undefined),
131		Type::Uuid4 => parse_uuid4(fragment).map(Value::Uuid4).unwrap_or(Value::Undefined),
132		Type::Uuid7 => parse_uuid7(fragment).map(Value::Uuid7).unwrap_or(Value::Undefined),
133		Type::IdentityId => parse_uuid7(fragment)
134			.map(|uuid7| Value::IdentityId(IdentityId::from(uuid7)))
135			.unwrap_or(Value::Undefined),
136		Type::Blob => Blob::from_hex(fragment).map(Value::Blob).unwrap_or(Value::Undefined),
137		Type::Undefined => Value::Undefined,
138		Type::Decimal => parse_decimal(fragment).map(Value::Decimal).unwrap_or(Value::Undefined),
139		Type::Int | Type::Uint => {
140			unimplemented!()
141		}
142		Type::Any => unreachable!("Any type cannot be used as parameter"),
143	};
144
145	Ok(parsed_value)
146}
147
148impl<'de> Deserialize<'de> for Params {
149	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
150	where
151		D: Deserializer<'de>,
152	{
153		struct ParamsVisitor;
154
155		impl<'de> Visitor<'de> for ParamsVisitor {
156			type Value = Params;
157
158			fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
159				formatter.write_str(
160					"null, an array for positional parameters, or an object for named parameters",
161				)
162			}
163
164			fn visit_none<E>(self) -> Result<Self::Value, E>
165			where
166				E: de::Error,
167			{
168				Ok(Params::None)
169			}
170
171			fn visit_unit<E>(self) -> Result<Self::Value, E>
172			where
173				E: de::Error,
174			{
175				Ok(Params::None)
176			}
177
178			fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
179			where
180				A: de::SeqAccess<'de>,
181			{
182				let mut values = Vec::new();
183
184				while let Some(value) = seq.next_element::<serde_json::Value>()? {
185					// Check if it's in the {"type": "Bool",
186					// "value": "true"} format
187					if let Some(obj) = value.as_object() {
188						if obj.contains_key("type") && obj.contains_key("value") {
189							let type_str = obj["type"].as_str().ok_or_else(|| {
190								de::Error::custom("type must be a string")
191							})?;
192							let value_val = &obj["value"];
193
194							let parsed_value = parse_typed_value(type_str, value_val)
195								.map_err(de::Error::custom)?;
196							values.push(parsed_value);
197							continue;
198						}
199					}
200
201					// Otherwise try to deserialize as a
202					// normal Value
203					let val = Value::deserialize(value).map_err(de::Error::custom)?;
204					values.push(val);
205				}
206
207				Ok(Params::Positional(values))
208			}
209
210			fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
211			where
212				A: de::MapAccess<'de>,
213			{
214				let mut result_map = HashMap::new();
215
216				while let Some(key) = map.next_key::<String>()? {
217					let value: serde_json::Value = map.next_value()?;
218
219					// Check if it's in the {"type": "Bool",
220					// "value": "true"} format
221					if let Some(obj) = value.as_object() {
222						if obj.contains_key("type") && obj.contains_key("value") {
223							let type_str = obj["type"].as_str().ok_or_else(|| {
224								de::Error::custom("type must be a string")
225							})?;
226							let value_val = &obj["value"];
227
228							let parsed_value = parse_typed_value(type_str, value_val)
229								.map_err(de::Error::custom)?;
230							result_map.insert(key, parsed_value);
231							continue;
232						}
233					}
234
235					// Otherwise try to deserialize as a
236					// normal Value
237					let val = Value::deserialize(value).map_err(de::Error::custom)?;
238					result_map.insert(key, val);
239				}
240
241				Ok(Params::Named(result_map))
242			}
243		}
244
245		deserializer.deserialize_any(ParamsVisitor)
246	}
247}
248
/// Builds a `Params` value.
///
/// - `params!()`, `params![]` or `params!{}` produce `Params::None`.
/// - `params!{ name: value, "key": value }` produces `Params::Named`. Keys may be
///   identifiers or literals; each is turned into a `String` by `params_key!`.
/// - `params![a, b, c]` produces `Params::Positional`.
///
/// Every value is converted with `IntoValue::into_value`, and a trailing comma is
/// accepted. `macro_rules!` ignores the delimiters used at the call site, so the
/// form is selected by the `key: value` syntax, not by brackets versus braces.
#[macro_export]
macro_rules! params {
	// Nothing inside the delimiters, whichever delimiters were used.
	() => {
		$crate::Params::None
	};

	// `key: value` pairs; a repeated key keeps its last value.
	($($key:tt : $value:expr),+ $(,)?) => {
		$crate::Params::Named(::std::collections::HashMap::from([
			$(($crate::params_key!($key), $crate::IntoValue::into_value($value))),+
		]))
	};

	// Bare values, kept in order.
	($($value:expr),+ $(,)?) => {
		$crate::Params::Positional(::std::vec![
			$($crate::IntoValue::into_value($value)),+
		])
	};
}
287
/// Turns a `params!` key token into an owned `String`.
///
/// An identifier is stringified verbatim (`name` becomes `"name"`). A literal is
/// converted through `ToString` (`"another-key"` becomes `"another-key"`).
#[macro_export]
#[doc(hidden)]
macro_rules! params_key {
	($key:ident) => {
		::std::string::String::from(::std::stringify!($key))
	};
	($key:literal) => {
		::std::string::ToString::to_string(&$key)
	};
}
298
#[cfg(test)]
mod tests {
	use super::*;
	use crate::IntoValue;

	#[test]
	fn test_params_macro_positional() {
		let params = params![42, true, "hello"];
		match params {
			Params::Positional(values) => {
				assert_eq!(values.len(), 3);
				assert_eq!(values[0], Value::Int4(42));
				assert_eq!(values[1], Value::Boolean(true));
				assert_eq!(values[2], Value::Utf8("hello".to_string()));
			}
			_ => panic!("Expected positional params"),
		}
	}

	#[test]
	fn test_params_macro_named() {
		let params = params! {
		    name: true,
		    other: 42,
		    message: "test"
		};
		match params {
			Params::Named(map) => {
				assert_eq!(map.len(), 3);
				assert_eq!(map.get("name"), Some(&Value::Boolean(true)));
				assert_eq!(map.get("other"), Some(&Value::Int4(42)));
				assert_eq!(map.get("message"), Some(&Value::Utf8("test".to_string())));
			}
			_ => panic!("Expected named params"),
		}
	}

	#[test]
	fn test_params_macro_named_with_strings() {
		let params = params! {
		    "string_key": 100,
		    ident_key: 200,
		    "another-key": "value"
		};
		match params {
			Params::Named(map) => {
				assert_eq!(map.len(), 3);
				assert_eq!(map.get("string_key"), Some(&Value::Int4(100)));
				assert_eq!(map.get("ident_key"), Some(&Value::Int4(200)));
				assert_eq!(map.get("another-key"), Some(&Value::Utf8("value".to_string())));
			}
			_ => panic!("Expected named params"),
		}
	}

	#[test]
	fn test_params_macro_empty() {
		let params = params!();
		assert_eq!(params, Params::None);

		let params2 = params! {};
		assert_eq!(params2, Params::None);

		let params3 = params![];
		assert_eq!(params3, Params::None);
	}

	#[test]
	fn test_params_macro_with_values() {
		let v1 = Value::Int8(100);
		let v2 = 200i64.into_value();

		let params = params![v1, v2, 300];
		match params {
			Params::Positional(values) => {
				assert_eq!(values.len(), 3);
				assert_eq!(values[0], Value::Int8(100));
				assert_eq!(values[1], Value::Int8(200));
				assert_eq!(values[2], Value::Int4(300));
			}
			_ => panic!("Expected positional params"),
		}
	}

	#[test]
	fn test_params_macro_trailing_comma() {
		let params1 = params![1, 2, 3,];
		// The named form needs its own trailing comma to exercise `$(,)?` in that arm.
		let params2 = params! { a: 1, b: 2, };

		match params1 {
			Params::Positional(values) => {
				assert_eq!(values.len(), 3)
			}
			_ => panic!("Expected positional params"),
		}

		match params2 {
			Params::Named(map) => assert_eq!(map.len(), 2),
			_ => panic!("Expected named params"),
		}
	}

	#[test]
	fn test_params_from_array() {
		let params = Params::from([Value::Int4(1), Value::Boolean(false)]);
		assert_eq!(params.get_positional(0), Some(&Value::Int4(1)));
		assert_eq!(params.get_positional(1), Some(&Value::Boolean(false)));
		assert_eq!(params.get_positional(2), None);
		// Named lookups never succeed on positional parameters.
		assert_eq!(params.get_named("a"), None);
	}

	#[test]
	fn test_params_get_named() {
		let params = params! { a: 1 };
		assert_eq!(params.get_named("a"), Some(&Value::Int4(1)));
		assert_eq!(params.get_named("b"), None);
		// Positional lookups never succeed on named parameters.
		assert_eq!(params.get_positional(0), None);
	}

	#[test]
	fn test_params_none_serde() {
		assert_eq!(serde_json::to_string(&Params::None).unwrap(), "null");
		let params: Params = serde_json::from_str("null").unwrap();
		assert_eq!(params, Params::None);
	}
}