1use std::collections::{BTreeMap, HashMap};
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7pub struct TypedQuery {
8 pub cypher: String,
9 pub params: HashMap<String, String>,
10}
11
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13pub enum TypedValue {
14 Null,
15 String(String),
16 Integer(i64),
17 Float(f64),
18 Bool(bool),
19 List(Vec<TypedValue>),
20 Map(BTreeMap<String, TypedValue>),
21}
22
/// Identifies which sort of name failed identifier validation. This is
/// reported inside `TypedQueryError::InvalidIdentifier`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IdentifierKind {
    /// The name of a query parameter (the part written after `$`).
    ParameterName,
    /// A key of a `TypedValue::Map`.
    MapKey,
}
28
/// Identifies what kind of value contained a disallowed character. This is
/// reported inside `TypedQueryError::ControlCharacter`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValueContext {
    /// A string value.
    String,
}
33
34#[derive(Debug, Clone, PartialEq, Eq)]
35pub enum TypedQueryError {
36 InvalidIdentifier {
37 kind: IdentifierKind,
38 identifier: String,
39 },
40 ControlCharacter {
41 context: ValueContext,
42 codepoint: u32,
43 },
44 NonFiniteFloat {
45 value: String,
46 },
47}
48
49impl TypedQuery {
50 pub fn new(cypher: impl Into<String>) -> Self {
51 Self {
52 cypher: cypher.into(),
53 params: HashMap::new(),
54 }
55 }
56
57 pub fn with_params<I, K>(cypher: impl Into<String>, params: I) -> Result<Self, TypedQueryError>
58 where
59 I: IntoIterator<Item = (K, TypedValue)>,
60 K: Into<String>,
61 {
62 let mut query = Self::new(cypher);
63 for (name, value) in params {
64 query.insert_param(name, value)?;
65 }
66 Ok(query)
67 }
68
69 pub fn insert_param(
70 &mut self,
71 name: impl Into<String>,
72 value: TypedValue,
73 ) -> Result<(), TypedQueryError> {
74 let name = name.into();
75 validate_identifier(&name, IdentifierKind::ParameterName)?;
76 let rendered = render_cypher_value(&value)?;
77 self.params.insert(name, rendered);
78 Ok(())
79 }
80}
81
82pub fn cypher_string_literal(s: &str) -> String {
83 format!("'{}'", escape_string_contents(s))
84}
85
86pub fn render_cypher_value(value: &TypedValue) -> Result<String, TypedQueryError> {
87 match value {
88 TypedValue::Null => Ok("null".to_string()),
89 TypedValue::String(value) => render_string_literal(value),
90 TypedValue::Integer(value) => Ok(value.to_string()),
91 TypedValue::Float(value) => render_float(*value),
92 TypedValue::Bool(value) => Ok(value.to_string()),
93 TypedValue::List(values) => values
94 .iter()
95 .map(render_cypher_value)
96 .collect::<Result<Vec<_>, _>>()
97 .map(|values| format!("[{}]", values.join(", "))),
98 TypedValue::Map(values) => values
99 .iter()
100 .map(|(key, value)| {
101 validate_identifier(key, IdentifierKind::MapKey)?;
102 Ok(format!("{key}: {}", render_cypher_value(value)?))
103 })
104 .collect::<Result<Vec<_>, _>>()
105 .map(|values| format!("{{{}}}", values.join(", "))),
106 }
107}
108
109pub fn string_params(values: &[(&str, &str)]) -> HashMap<String, String> {
110 values
111 .iter()
112 .map(|(key, value)| ((*key).to_string(), cypher_string_literal(value)))
113 .collect()
114}
115
/// Clamps a requested result limit into `1..=max`.
///
/// A `limit` of `0` is raised to `1`, and anything above `max` is lowered to
/// `max`. When `max` is `0` the range is empty. In that case the cap wins and
/// `0` is returned, instead of panicking as `usize::clamp(1, 0)` would
/// (`Ord::clamp` asserts `min <= max`).
pub fn clamp_limit(limit: usize, max: usize) -> usize {
    // Apply the floor first and the ceiling last, so `max` is never exceeded.
    limit.max(1).min(max)
}
119
/// Caps a requested result offset at `max`. Any offset up to and including
/// `max` passes through unchanged.
pub fn clamp_offset(offset: usize, max: usize) -> usize {
    if offset > max {
        max
    } else {
        offset
    }
}
123
124pub fn id_list_literal(ids: &[String]) -> String {
125 ids.iter()
126 .map(|id| cypher_string_literal(id))
127 .collect::<Vec<_>>()
128 .join(", ")
129}
130
131pub fn validate_identifier(identifier: &str, kind: IdentifierKind) -> Result<(), TypedQueryError> {
132 let mut chars = identifier.chars();
133 let Some(first) = chars.next() else {
134 return Err(TypedQueryError::InvalidIdentifier {
135 kind,
136 identifier: identifier.to_string(),
137 });
138 };
139
140 if !(first == '_' || first.is_ascii_alphabetic())
141 || !chars.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
142 {
143 return Err(TypedQueryError::InvalidIdentifier {
144 kind,
145 identifier: identifier.to_string(),
146 });
147 }
148
149 Ok(())
150}
151
152fn render_string_literal(value: &str) -> Result<String, TypedQueryError> {
153 Ok(cypher_string_literal(value))
154}
155
/// Escapes `value` for use between the quotes of a Cypher string literal.
///
/// Escapes applied:
/// * Backslash, both quote characters, newline, carriage return, tab,
///   backspace and form feed get their two-character escape.
/// * Any other control character (per `char::is_control`) becomes `\uXXXX`
///   with upper-case hex.
/// * Everything else, including non-ASCII text, is copied unchanged.
fn escape_string_contents(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for c in value.chars() {
        let fixed: Option<&str> = match c {
            '\\' => Some("\\\\"),
            '\'' => Some("\\'"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            '\u{0008}' => Some("\\b"),
            '\u{000C}' => Some("\\f"),
            _ => None,
        };
        match fixed {
            Some(sequence) => out.push_str(sequence),
            None if c.is_control() => out.push_str(&format!("\\u{:04X}", u32::from(c))),
            None => out.push(c),
        }
    }
    out
}
174
175fn render_float(value: f64) -> Result<String, TypedQueryError> {
176 if !value.is_finite() {
177 return Err(TypedQueryError::NonFiniteFloat {
178 value: value.to_string(),
179 });
180 }
181
182 let mut rendered = value.to_string();
183 if !rendered.contains('.') && !rendered.contains('e') && !rendered.contains('E') {
184 rendered.push_str(".0");
185 }
186 Ok(rendered)
187}
188
189impl fmt::Display for IdentifierKind {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 match self {
192 Self::ParameterName => f.write_str("parameter name"),
193 Self::MapKey => f.write_str("map key"),
194 }
195 }
196}
197
198impl fmt::Display for ValueContext {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 match self {
201 Self::String => f.write_str("string"),
202 }
203 }
204}
205
206impl fmt::Display for TypedQueryError {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 match self {
209 Self::InvalidIdentifier { kind, identifier } => write!(
210 f,
211 "invalid {kind} `{identifier}`; expected ^[A-Za-z_][A-Za-z0-9_]*$"
212 ),
213 Self::ControlCharacter { context, codepoint } => write!(
214 f,
215 "control character U+{codepoint:04X} is not allowed in {context} value"
216 ),
217 Self::NonFiniteFloat { value } => {
218 write!(f, "non-finite float `{value}` is not allowed")
219 }
220 }
221 }
222}
223
224impl std::error::Error for TypedQueryError {}
225
// Unit tests that pin the exact Cypher literal text produced by this module.
// Expectations are Rust string literals, so each backslash in the rendered
// Cypher appears doubled (`\\`) below, and each escaped double quote appears
// as `\\\"`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    // End-to-end check through `TypedQuery::with_params`:
    // * every `TypedValue` variant is rendered, including a nested map that
    //   contains a list;
    // * each rendered parameter is compared verbatim;
    // * the Cypher text must pass through unchanged.
    #[test]
    fn typed_params_render_nested_safe_cypher_literals() {
        let mut props = BTreeMap::new();
        props.insert("enabled".to_string(), TypedValue::Bool(true));
        props.insert(
            "label".to_string(),
            TypedValue::String("caf\u{00e9} \"quote\" and 'single' \\ slash".to_string()),
        );
        props.insert(
            "nested".to_string(),
            TypedValue::List(vec![
                TypedValue::Integer(1),
                TypedValue::Float(2.25),
                TypedValue::Bool(false),
            ]),
        );

        let query = TypedQuery::with_params(
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props",
            [
                (
                    "name",
                    TypedValue::String("O'Reilly \\ path \u{2603}".to_string()),
                ),
                ("count", TypedValue::Integer(42)),
                ("ratio", TypedValue::Float(1.5)),
                // A whole-number float must keep a `.0` suffix (asserted below).
                ("whole", TypedValue::Float(1.0)),
                ("enabled", TypedValue::Bool(true)),
                (
                    "items",
                    TypedValue::List(vec![
                        TypedValue::String("a".to_string()),
                        TypedValue::Integer(-7),
                        TypedValue::Bool(false),
                    ]),
                ),
                ("props", TypedValue::Map(props)),
            ],
        )
        .expect("valid typed params should render");

        assert_eq!(
            query.cypher,
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props"
        );
        // Quotes and backslashes are escaped; the non-ASCII U+2603 is emitted as-is.
        assert_eq!(
            query.params.get("name").map(String::as_str),
            Some("'O\\'Reilly \\\\ path \u{2603}'")
        );
        assert_eq!(query.params.get("count").map(String::as_str), Some("42"));
        assert_eq!(query.params.get("ratio").map(String::as_str), Some("1.5"));
        assert_eq!(query.params.get("whole").map(String::as_str), Some("1.0"));
        assert_eq!(
            query.params.get("enabled").map(String::as_str),
            Some("true")
        );
        assert_eq!(
            query.params.get("items").map(String::as_str),
            Some("['a', -7, false]")
        );
        // Map entries appear in BTreeMap (sorted key) order with unquoted keys.
        assert_eq!(
            query.params.get("props").map(String::as_str),
            Some(
                "{enabled: true, label: 'caf\u{00e9} \\\"quote\\\" and \\'single\\' \\\\ slash', nested: [1, 2.25, false]}"
            )
        );
    }

    // Both quote characters are escaped, even though only `'` delimits the literal.
    #[test]
    fn string_literals_escape_both_quote_delimiters() {
        let rendered = render_cypher_value(&TypedValue::String("a 'single' and \"double\"".into()))
            .expect("valid string should render");

        assert_eq!(rendered, "'a \\'single\\' and \\\"double\\\"'");
    }

    // Named escapes are used for \n \r \t \b \f. Other control characters
    // (here ESC, U+001B) fall back to `\uXXXX`.
    #[test]
    fn string_literals_escape_control_characters() {
        let rendered = render_cypher_value(&TypedValue::String(
            "line\ncarriage\rtab\tbackspace\u{0008}form\u{000C}escape\u{001B}".into(),
        ))
        .expect("control characters should render as escaped literals");

        assert_eq!(
            rendered,
            "'line\\ncarriage\\rtab\\tbackspace\\bform\\fescape\\u001B'"
        );
    }

    // Escaping also applies to strings nested inside lists and maps.
    #[test]
    fn nested_string_values_escape_control_characters() {
        let mut props = BTreeMap::new();
        props.insert(
            "items".to_string(),
            TypedValue::List(vec![TypedValue::String("line\nitem".to_string())]),
        );
        props.insert(
            "label".to_string(),
            TypedValue::String("tab\tvalue".to_string()),
        );

        let rendered =
            render_cypher_value(&TypedValue::Map(props)).expect("nested strings should render");

        assert_eq!(rendered, "{items: ['line\\nitem'], label: 'tab\\tvalue'}");
    }

    // Bad parameter names and bad map keys are reported with the matching
    // `IdentifierKind` and the offending text.
    #[test]
    fn invalid_identifiers_return_typed_errors() {
        let param_error =
            TypedQuery::with_params("RETURN $bad", [("bad-name", TypedValue::Bool(true))])
                .expect_err("invalid parameter name should fail");
        assert_eq!(
            param_error,
            TypedQueryError::InvalidIdentifier {
                kind: IdentifierKind::ParameterName,
                identifier: "bad-name".to_string(),
            }
        );

        let mut props = BTreeMap::new();
        props.insert("bad.key".to_string(), TypedValue::Integer(1));
        let map_error =
            render_cypher_value(&TypedValue::Map(props)).expect_err("invalid map key should fail");
        assert_eq!(
            map_error,
            TypedQueryError::InvalidIdentifier {
                kind: IdentifierKind::MapKey,
                identifier: "bad.key".to_string(),
            }
        );
    }

    // NaN and both infinities are rejected rather than rendered.
    #[test]
    fn unsafe_values_return_typed_errors() {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let error = render_cypher_value(&TypedValue::Float(value))
                .expect_err("non-finite float should fail");
            assert!(matches!(error, TypedQueryError::NonFiniteFloat { .. }));
        }
    }
}