1use std::collections::{BTreeMap, HashMap};
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
/// A Cypher statement paired with its parameters.
///
/// When populated through [`TypedQuery::insert_param`] or [`TypedQuery::with_params`],
/// every `params` value is the *rendered Cypher literal* for that parameter (for example
/// `'O\'Reilly'` or `[1, 2.25]`), not the raw value, and every name has been checked
/// against `^[A-Za-z_][A-Za-z0-9_]*$`. Both fields are public, so mutating them directly
/// bypasses those checks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TypedQuery {
    /// Cypher statement text, stored exactly as supplied.
    pub cypher: String,
    /// Parameter name -> rendered Cypher literal text.
    pub params: HashMap<String, String>,
}
11
/// A parameter value that [`render_cypher_value`] turns into Cypher literal text.
///
/// `List` and `Map` are rendered recursively. `Map` is backed by a `BTreeMap`, so its
/// entries are rendered in sorted key order. Only `PartialEq` is derived because `f64`
/// is not `Eq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TypedValue {
    /// Rendered as `null`.
    Null,
    /// Rendered as a single-quoted string literal with quotes, backslashes and control
    /// characters escaped.
    String(String),
    /// Rendered in decimal.
    Integer(i64),
    /// Rendered so the text always reads as a float (e.g. `1.0`, not `1`); NaN and
    /// infinities are rejected with `TypedQueryError::NonFiniteFloat`.
    Float(f64),
    /// Rendered as `true` / `false`.
    Bool(bool),
    /// Rendered as `[a, b, ...]`.
    List(Vec<TypedValue>),
    /// Rendered as `{key: value, ...}`; every key must be a valid identifier.
    Map(BTreeMap<String, TypedValue>),
}
22
/// Which kind of name failed identifier validation; carried by
/// [`TypedQueryError::InvalidIdentifier`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IdentifierKind {
    /// A query parameter name passed to [`TypedQuery::insert_param`].
    ParameterName,
    /// A key of a [`TypedValue::Map`].
    MapKey,
}
28
/// The kind of value a `TypedQueryError::ControlCharacter` was reported for.
///
/// Deprecated together with that variant: nothing in this file constructs it any more,
/// because control characters in strings are now escaped rather than rejected.
#[deprecated(
    since = "0.9.10",
    note = "string values are escaped now; this context is retained only for API compatibility"
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValueContext {
    /// A string value.
    String,
}
39
/// Errors produced while validating parameter names or rendering [`TypedValue`]s.
// Silences deprecation warnings caused by keeping the deprecated `ControlCharacter`
// variant and its `ValueContext` field for API compatibility.
#[allow(deprecated)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedQueryError {
    /// A parameter name or map key did not match `^[A-Za-z_][A-Za-z0-9_]*$`.
    InvalidIdentifier {
        /// Whether the rejected text was a parameter name or a map key.
        kind: IdentifierKind,
        /// The rejected text, exactly as supplied.
        identifier: String,
    },
    /// Formerly reported for control characters inside string values. Nothing in this
    /// file produces it any more; string control characters are escaped instead.
    #[deprecated(
        since = "0.9.10",
        note = "string control characters are escaped instead of rejected"
    )]
    ControlCharacter {
        /// The kind of value that contained the character.
        #[allow(deprecated)]
        context: ValueContext,
        /// Unicode scalar value of the offending character.
        codepoint: u32,
    },
    /// A `TypedValue::Float` was NaN or infinite.
    NonFiniteFloat {
        /// The float's `Display` text (e.g. `NaN`, `inf`, `-inf`).
        value: String,
    },
}
64
65impl TypedQuery {
66 pub fn new(cypher: impl Into<String>) -> Self {
67 Self {
68 cypher: cypher.into(),
69 params: HashMap::new(),
70 }
71 }
72
73 pub fn with_params<I, K>(cypher: impl Into<String>, params: I) -> Result<Self, TypedQueryError>
74 where
75 I: IntoIterator<Item = (K, TypedValue)>,
76 K: Into<String>,
77 {
78 let mut query = Self::new(cypher);
79 for (name, value) in params {
80 query.insert_param(name, value)?;
81 }
82 Ok(query)
83 }
84
85 pub fn insert_param(
86 &mut self,
87 name: impl Into<String>,
88 value: TypedValue,
89 ) -> Result<(), TypedQueryError> {
90 let name = name.into();
91 validate_identifier(&name, IdentifierKind::ParameterName)?;
92 let rendered = render_cypher_value(&value)?;
93 self.params.insert(name, rendered);
94 Ok(())
95 }
96}
97
98pub fn cypher_string_literal(s: &str) -> String {
99 format!("'{}'", escape_string_contents(s))
100}
101
102pub fn render_cypher_value(value: &TypedValue) -> Result<String, TypedQueryError> {
103 match value {
104 TypedValue::Null => Ok("null".to_string()),
105 TypedValue::String(value) => render_string_literal(value),
106 TypedValue::Integer(value) => Ok(value.to_string()),
107 TypedValue::Float(value) => render_float(*value),
108 TypedValue::Bool(value) => Ok(value.to_string()),
109 TypedValue::List(values) => values
110 .iter()
111 .map(render_cypher_value)
112 .collect::<Result<Vec<_>, _>>()
113 .map(|values| format!("[{}]", values.join(", "))),
114 TypedValue::Map(values) => values
115 .iter()
116 .map(|(key, value)| {
117 validate_identifier(key, IdentifierKind::MapKey)?;
118 Ok(format!("{key}: {}", render_cypher_value(value)?))
119 })
120 .collect::<Result<Vec<_>, _>>()
121 .map(|values| format!("{{{}}}", values.join(", "))),
122 }
123}
124
125pub fn string_params(values: &[(&str, &str)]) -> HashMap<String, String> {
126 values
127 .iter()
128 .map(|(key, value)| ((*key).to_string(), cypher_string_literal(value)))
129 .collect()
130}
131
/// Clamps a requested result limit into `1..=max`.
///
/// The result is never 0. If `max` is 0, it is treated as 1 and the function returns 1.
/// Without that adjustment, `usize::clamp` would panic, because it asserts that its lower
/// bound does not exceed its upper bound.
pub fn clamp_limit(limit: usize, max: usize) -> usize {
    limit.clamp(1, max.max(1))
}
135
/// Caps a requested result offset at `max`; offsets at or below `max` pass through.
pub fn clamp_offset(offset: usize, max: usize) -> usize {
    if offset > max {
        max
    } else {
        offset
    }
}
139
140pub fn id_list_literal(ids: &[String]) -> String {
141 ids.iter()
142 .map(|id| cypher_string_literal(id))
143 .collect::<Vec<_>>()
144 .join(", ")
145}
146
147pub fn validate_identifier(identifier: &str, kind: IdentifierKind) -> Result<(), TypedQueryError> {
148 let mut chars = identifier.chars();
149 let Some(first) = chars.next() else {
150 return Err(TypedQueryError::InvalidIdentifier {
151 kind,
152 identifier: identifier.to_string(),
153 });
154 };
155
156 if !(first == '_' || first.is_ascii_alphabetic())
157 || !chars.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
158 {
159 return Err(TypedQueryError::InvalidIdentifier {
160 kind,
161 identifier: identifier.to_string(),
162 });
163 }
164
165 Ok(())
166}
167
168fn render_string_literal(value: &str) -> Result<String, TypedQueryError> {
169 Ok(cypher_string_literal(value))
170}
171
/// Escapes `value` so it can sit between single quotes in a Cypher string literal.
///
/// Backslash, both quote characters, `\n`, `\r`, `\t`, backspace and form feed get their
/// short backslash escapes. Any other character for which `char::is_control` is true
/// becomes a `\uXXXX` escape with uppercase hex. All remaining characters are copied
/// through unchanged.
fn escape_string_contents(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        let short_escape = match ch {
            '\\' => Some("\\\\"),
            '\'' => Some("\\'"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            '\u{0008}' => Some("\\b"),
            '\u{000C}' => Some("\\f"),
            _ => None,
        };
        match short_escape {
            Some(sequence) => out.push_str(sequence),
            None if ch.is_control() => out.push_str(&format!("\\u{:04X}", u32::from(ch))),
            None => out.push(ch),
        }
    }
    out
}
190
191fn render_float(value: f64) -> Result<String, TypedQueryError> {
192 if !value.is_finite() {
193 return Err(TypedQueryError::NonFiniteFloat {
194 value: value.to_string(),
195 });
196 }
197
198 let mut rendered = value.to_string();
199 if !rendered.contains('.') && !rendered.contains('e') && !rendered.contains('E') {
200 rendered.push_str(".0");
201 }
202 Ok(rendered)
203}
204
205impl fmt::Display for IdentifierKind {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 match self {
208 Self::ParameterName => f.write_str("parameter name"),
209 Self::MapKey => f.write_str("map key"),
210 }
211 }
212}
213
214#[allow(deprecated)]
215impl fmt::Display for ValueContext {
216 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217 match self {
218 Self::String => f.write_str("string"),
219 }
220 }
221}
222
223#[allow(deprecated)]
224impl fmt::Display for TypedQueryError {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 match self {
227 Self::InvalidIdentifier { kind, identifier } => write!(
228 f,
229 "invalid {kind} `{identifier}`; expected ^[A-Za-z_][A-Za-z0-9_]*$"
230 ),
231 Self::ControlCharacter { context, codepoint } => write!(
232 f,
233 "control character U+{codepoint:04X} is not allowed in {context} value"
234 ),
235 Self::NonFiniteFloat { value } => {
236 write!(f, "non-finite float `{value}` is not allowed")
237 }
238 }
239 }
240}
241
242impl std::error::Error for TypedQueryError {}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    #[test]
    fn typed_params_render_nested_safe_cypher_literals() {
        let mut props = BTreeMap::new();
        props.insert("enabled".to_string(), TypedValue::Bool(true));
        props.insert(
            "label".to_string(),
            TypedValue::String("caf\u{00e9} \"quote\" and 'single' \\ slash".to_string()),
        );
        props.insert(
            "nested".to_string(),
            TypedValue::List(vec![
                TypedValue::Integer(1),
                TypedValue::Float(2.25),
                TypedValue::Bool(false),
            ]),
        );

        let query = TypedQuery::with_params(
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props",
            [
                (
                    "name",
                    TypedValue::String("O'Reilly \\ path \u{2603}".to_string()),
                ),
                ("count", TypedValue::Integer(42)),
                ("ratio", TypedValue::Float(1.5)),
                ("whole", TypedValue::Float(1.0)),
                ("enabled", TypedValue::Bool(true)),
                (
                    "items",
                    TypedValue::List(vec![
                        TypedValue::String("a".to_string()),
                        TypedValue::Integer(-7),
                        TypedValue::Bool(false),
                    ]),
                ),
                ("props", TypedValue::Map(props)),
            ],
        )
        .expect("valid typed params should render");

        assert_eq!(
            query.cypher,
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props"
        );
        assert_eq!(
            query.params.get("name").map(String::as_str),
            Some("'O\\'Reilly \\\\ path \u{2603}'")
        );
        assert_eq!(query.params.get("count").map(String::as_str), Some("42"));
        assert_eq!(query.params.get("ratio").map(String::as_str), Some("1.5"));
        assert_eq!(query.params.get("whole").map(String::as_str), Some("1.0"));
        assert_eq!(
            query.params.get("enabled").map(String::as_str),
            Some("true")
        );
        assert_eq!(
            query.params.get("items").map(String::as_str),
            Some("['a', -7, false]")
        );
        assert_eq!(
            query.params.get("props").map(String::as_str),
            Some(
                "{enabled: true, label: 'caf\u{00e9} \\\"quote\\\" and \\'single\\' \\\\ slash', nested: [1, 2.25, false]}"
            )
        );
    }

    #[test]
    fn string_literals_escape_both_quote_delimiters() {
        let rendered = render_cypher_value(&TypedValue::String("a 'single' and \"double\"".into()))
            .expect("valid string should render");

        assert_eq!(rendered, "'a \\'single\\' and \\\"double\\\"'");
    }

    #[test]
    fn string_literals_escape_control_characters() {
        let rendered = render_cypher_value(&TypedValue::String(
            "line\ncarriage\rtab\tbackspace\u{0008}form\u{000C}escape\u{001B}".into(),
        ))
        .expect("control characters should render as escaped literals");

        assert_eq!(
            rendered,
            "'line\\ncarriage\\rtab\\tbackspace\\bform\\fescape\\u001B'"
        );
    }

    #[test]
    fn nested_string_values_escape_control_characters() {
        let mut props = BTreeMap::new();
        props.insert(
            "items".to_string(),
            TypedValue::List(vec![TypedValue::String("line\nitem".to_string())]),
        );
        props.insert(
            "label".to_string(),
            TypedValue::String("tab\tvalue".to_string()),
        );

        let rendered =
            render_cypher_value(&TypedValue::Map(props)).expect("nested strings should render");

        assert_eq!(rendered, "{items: ['line\\nitem'], label: 'tab\\tvalue'}");
    }

    #[test]
    fn invalid_identifiers_return_typed_errors() {
        let param_error =
            TypedQuery::with_params("RETURN $bad", [("bad-name", TypedValue::Bool(true))])
                .expect_err("invalid parameter name should fail");
        assert_eq!(
            param_error,
            TypedQueryError::InvalidIdentifier {
                kind: IdentifierKind::ParameterName,
                identifier: "bad-name".to_string(),
            }
        );

        let mut props = BTreeMap::new();
        props.insert("bad.key".to_string(), TypedValue::Integer(1));
        let map_error =
            render_cypher_value(&TypedValue::Map(props)).expect_err("invalid map key should fail");
        assert_eq!(
            map_error,
            TypedQueryError::InvalidIdentifier {
                kind: IdentifierKind::MapKey,
                identifier: "bad.key".to_string(),
            }
        );
    }

    #[test]
    fn unsafe_values_return_typed_errors() {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let error = render_cypher_value(&TypedValue::Float(value))
                .expect_err("non-finite float should fail");
            assert!(matches!(error, TypedQueryError::NonFiniteFloat { .. }));
        }
    }

    // Boundary cases of `^[A-Za-z_][A-Za-z0-9_]*$`: empty, leading digit, separators,
    // and non-ASCII letters are rejected; underscores and digits after the head are fine.
    #[test]
    fn identifier_validation_rejects_boundary_cases() {
        for bad in ["", "1abc", "a b", "a-b", "caf\u{00e9}", "$a"] {
            let error = TypedQuery::with_params("RETURN 1", [(bad, TypedValue::Null)])
                .expect_err("invalid parameter name should fail");
            assert_eq!(
                error,
                TypedQueryError::InvalidIdentifier {
                    kind: IdentifierKind::ParameterName,
                    identifier: bad.to_string(),
                }
            );
        }

        let query = TypedQuery::with_params(
            "RETURN $_a1, $Name_2",
            [("_a1", TypedValue::Null), ("Name_2", TypedValue::Integer(0))],
        )
        .expect("valid parameter names should be accepted");
        assert_eq!(query.params.get("_a1").map(String::as_str), Some("null"));
        assert_eq!(query.params.get("Name_2").map(String::as_str), Some("0"));
    }

    // A failed insert must not clobber or add entries.
    #[test]
    fn failed_insert_leaves_params_unchanged() {
        let mut query = TypedQuery::new("RETURN $x");
        query
            .insert_param("x", TypedValue::Integer(1))
            .expect("valid param should insert");

        let error = query
            .insert_param("x", TypedValue::Float(f64::INFINITY))
            .expect_err("non-finite float should fail");
        assert!(matches!(error, TypedQueryError::NonFiniteFloat { .. }));

        query
            .insert_param("bad name", TypedValue::Integer(2))
            .expect_err("invalid name should fail");

        assert_eq!(query.params.len(), 1);
        assert_eq!(query.params.get("x").map(String::as_str), Some("1"));
    }

    // Whole-valued floats must keep a fractional part so Cypher does not read an integer.
    #[test]
    fn floats_always_render_with_a_fractional_part() {
        for (value, expected) in [
            (-3.0, "-3.0"),
            (0.1, "0.1"),
            (1e20, "100000000000000000000.0"),
        ] {
            let rendered = render_cypher_value(&TypedValue::Float(value))
                .expect("finite float should render");
            assert_eq!(rendered, expected);
        }
    }

    #[test]
    fn errors_render_readable_messages() {
        let error =
            TypedQuery::with_params("RETURN $bad", [("bad-name", TypedValue::Bool(true))])
                .expect_err("invalid parameter name should fail");
        assert_eq!(
            error.to_string(),
            "invalid parameter name `bad-name`; expected ^[A-Za-z_][A-Za-z0-9_]*$"
        );

        let error =
            render_cypher_value(&TypedValue::Float(f64::NAN)).expect_err("NaN should fail");
        assert_eq!(error.to_string(), "non-finite float `NaN` is not allowed");
    }
}