Skip to main content

gobby_code/graph/
typed_query.rs

1use std::collections::{BTreeMap, HashMap};
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7pub struct TypedQuery {
8    pub cypher: String,
9    pub params: HashMap<String, String>,
10}
11
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13pub enum TypedValue {
14    Null,
15    String(String),
16    Integer(i64),
17    Float(f64),
18    Bool(bool),
19    List(Vec<TypedValue>),
20    Map(BTreeMap<String, TypedValue>),
21}
22
/// Which kind of name failed identifier validation; carried in
/// `TypedQueryError::InvalidIdentifier` to label the error.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IdentifierKind {
    /// A query parameter name (the `name` in `$name`).
    ParameterName,
    /// A key inside a `TypedValue::Map`.
    MapKey,
}
28
29#[derive(Debug, Clone, PartialEq, Eq)]
30pub enum TypedQueryError {
31    InvalidIdentifier {
32        kind: IdentifierKind,
33        identifier: String,
34    },
35    NonFiniteFloat {
36        value: String,
37    },
38}
39
40impl TypedQuery {
41    pub fn new(cypher: impl Into<String>) -> Self {
42        Self {
43            cypher: cypher.into(),
44            params: HashMap::new(),
45        }
46    }
47
48    pub fn with_params<I, K>(cypher: impl Into<String>, params: I) -> Result<Self, TypedQueryError>
49    where
50        I: IntoIterator<Item = (K, TypedValue)>,
51        K: Into<String>,
52    {
53        let mut query = Self::new(cypher);
54        for (name, value) in params {
55            query.insert_param(name, value)?;
56        }
57        Ok(query)
58    }
59
60    pub fn insert_param(
61        &mut self,
62        name: impl Into<String>,
63        value: TypedValue,
64    ) -> Result<(), TypedQueryError> {
65        let name = name.into();
66        validate_identifier(&name, IdentifierKind::ParameterName)?;
67        let rendered = render_cypher_value(&value)?;
68        self.params.insert(name, rendered);
69        Ok(())
70    }
71}
72
73pub fn cypher_string_literal(s: &str) -> String {
74    format!("'{}'", escape_string_contents(s))
75}
76
77pub fn render_cypher_value(value: &TypedValue) -> Result<String, TypedQueryError> {
78    match value {
79        TypedValue::Null => Ok("null".to_string()),
80        TypedValue::String(value) => render_string_literal(value),
81        TypedValue::Integer(value) => Ok(value.to_string()),
82        TypedValue::Float(value) => render_float(*value),
83        TypedValue::Bool(value) => Ok(value.to_string()),
84        TypedValue::List(values) => values
85            .iter()
86            .map(render_cypher_value)
87            .collect::<Result<Vec<_>, _>>()
88            .map(|values| format!("[{}]", values.join(", "))),
89        TypedValue::Map(values) => values
90            .iter()
91            .map(|(key, value)| {
92                validate_identifier(key, IdentifierKind::MapKey)?;
93                Ok(format!("{key}: {}", render_cypher_value(value)?))
94            })
95            .collect::<Result<Vec<_>, _>>()
96            .map(|values| format!("{{{}}}", values.join(", "))),
97    }
98}
99
100pub fn string_params(values: &[(&str, &str)]) -> HashMap<String, String> {
101    values
102        .iter()
103        .map(|(key, value)| ((*key).to_string(), cypher_string_literal(value)))
104        .collect()
105}
106
/// Clamps a caller-supplied `limit` into the range `1..=max`.
///
/// The result is never smaller than 1. If `max` is 0, the upper bound is
/// treated as 1 rather than panicking, because `usize::clamp` panics when its
/// lower bound exceeds its upper bound.
pub fn clamp_limit(limit: usize, max: usize) -> usize {
    limit.clamp(1, max.max(1))
}
110
/// Caps `offset` at `max`. Values at or below `max` are returned unchanged.
pub fn clamp_offset(offset: usize, max: usize) -> usize {
    if offset > max {
        max
    } else {
        offset
    }
}
114
115pub fn id_list_literal(ids: &[String]) -> String {
116    ids.iter()
117        .map(|id| cypher_string_literal(id))
118        .collect::<Vec<_>>()
119        .join(", ")
120}
121
122pub fn validate_identifier(identifier: &str, kind: IdentifierKind) -> Result<(), TypedQueryError> {
123    let mut chars = identifier.chars();
124    let Some(first) = chars.next() else {
125        return Err(TypedQueryError::InvalidIdentifier {
126            kind,
127            identifier: identifier.to_string(),
128        });
129    };
130
131    if !(first == '_' || first.is_ascii_alphabetic())
132        || !chars.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
133    {
134        return Err(TypedQueryError::InvalidIdentifier {
135            kind,
136            identifier: identifier.to_string(),
137        });
138    }
139
140    Ok(())
141}
142
143fn render_string_literal(value: &str) -> Result<String, TypedQueryError> {
144    Ok(cypher_string_literal(value))
145}
146
/// Escapes `value` so it can be placed between quotes in a Cypher string literal.
///
/// - Backslash and both quote characters are backslash-escaped.
/// - `\n`, `\r`, `\t`, backspace (`\b`) and form feed (`\f`) use their short escapes.
/// - Any other control character becomes an uppercase `\uXXXX` escape.
/// - Every other character, including non-ASCII text, is copied unchanged.
fn escape_string_contents(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        let short_escape = match ch {
            '\\' => Some("\\\\"),
            '\'' => Some("\\'"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            '\u{0008}' => Some("\\b"),
            '\u{000C}' => Some("\\f"),
            _ => None,
        };
        match short_escape {
            Some(escape) => out.push_str(escape),
            None if ch.is_control() => {
                out.push_str(&format!("\\u{:04X}", u32::from(ch)));
            }
            None => out.push(ch),
        }
    }
    out
}
165
166fn render_float(value: f64) -> Result<String, TypedQueryError> {
167    if !value.is_finite() {
168        return Err(TypedQueryError::NonFiniteFloat {
169            value: value.to_string(),
170        });
171    }
172
173    let mut rendered = value.to_string();
174    if !rendered.contains('.') && !rendered.contains('e') && !rendered.contains('E') {
175        rendered.push_str(".0");
176    }
177    Ok(rendered)
178}
179
180impl fmt::Display for IdentifierKind {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        match self {
183            Self::ParameterName => f.write_str("parameter name"),
184            Self::MapKey => f.write_str("map key"),
185        }
186    }
187}
188
189impl fmt::Display for TypedQueryError {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        match self {
192            Self::InvalidIdentifier { kind, identifier } => write!(
193                f,
194                "invalid {kind} `{identifier}`; expected ^[A-Za-z_][A-Za-z0-9_]*$"
195            ),
196            Self::NonFiniteFloat { value } => {
197                write!(f, "non-finite float `{value}` is not allowed")
198            }
199        }
200    }
201}
202
// Marks `TypedQueryError` as a standard error. The derived `Debug` and the
// `Display` impl in this file satisfy the trait's supertraits. No `source()` is
// provided because neither variant wraps another error.
impl std::error::Error for TypedQueryError {}
204
// Unit tests for parameter rendering, string escaping and typed error reporting.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    // End-to-end: every `TypedValue` variant, including a nested map and list,
    // rendered through `TypedQuery::with_params`.
    #[test]
    fn typed_params_render_nested_safe_cypher_literals() {
        let mut props = BTreeMap::new();
        props.insert("enabled".to_string(), TypedValue::Bool(true));
        props.insert(
            "label".to_string(),
            TypedValue::String("caf\u{00e9} \"quote\" and 'single' \\ slash".to_string()),
        );
        props.insert(
            "nested".to_string(),
            TypedValue::List(vec![
                TypedValue::Integer(1),
                TypedValue::Float(2.25),
                TypedValue::Bool(false),
            ]),
        );

        let query = TypedQuery::with_params(
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props",
            [
                (
                    "name",
                    TypedValue::String("O'Reilly \\ path \u{2603}".to_string()),
                ),
                ("count", TypedValue::Integer(42)),
                ("ratio", TypedValue::Float(1.5)),
                ("whole", TypedValue::Float(1.0)),
                ("enabled", TypedValue::Bool(true)),
                (
                    "items",
                    TypedValue::List(vec![
                        TypedValue::String("a".to_string()),
                        TypedValue::Integer(-7),
                        TypedValue::Bool(false),
                    ]),
                ),
                ("props", TypedValue::Map(props)),
            ],
        )
        .expect("valid typed params should render");

        // The Cypher text is stored verbatim; only `params` holds rendered literals.
        assert_eq!(
            query.cypher,
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props"
        );
        // Single quote and backslash are escaped; non-ASCII characters pass through.
        assert_eq!(
            query.params.get("name").map(String::as_str),
            Some("'O\\'Reilly \\\\ path \u{2603}'")
        );
        assert_eq!(query.params.get("count").map(String::as_str), Some("42"));
        assert_eq!(query.params.get("ratio").map(String::as_str), Some("1.5"));
        // A whole-number float keeps a `.0` suffix so it stays a float literal.
        assert_eq!(query.params.get("whole").map(String::as_str), Some("1.0"));
        assert_eq!(
            query.params.get("enabled").map(String::as_str),
            Some("true")
        );
        assert_eq!(
            query.params.get("items").map(String::as_str),
            Some("['a', -7, false]")
        );
        // Map entries appear in the BTreeMap's sorted key order.
        assert_eq!(
            query.params.get("props").map(String::as_str),
            Some(
                "{enabled: true, label: 'caf\u{00e9} \\\"quote\\\" and \\'single\\' \\\\ slash', nested: [1, 2.25, false]}"
            )
        );
    }

    // Both `'` and `"` are backslash-escaped inside the single-quoted literal.
    #[test]
    fn string_literals_escape_both_quote_delimiters() {
        let rendered = render_cypher_value(&TypedValue::String("a 'single' and \"double\"".into()))
            .expect("valid string should render");

        assert_eq!(rendered, "'a \\'single\\' and \\\"double\\\"'");
    }

    // Short escapes for common control characters; `\uXXXX` for the rest (ESC here).
    #[test]
    fn string_literals_escape_control_characters() {
        let rendered = render_cypher_value(&TypedValue::String(
            "line\ncarriage\rtab\tbackspace\u{0008}form\u{000C}escape\u{001B}".into(),
        ))
        .expect("control characters should render as escaped literals");

        assert_eq!(
            rendered,
            "'line\\ncarriage\\rtab\\tbackspace\\bform\\fescape\\u001B'"
        );
    }

    // Escaping also applies to strings nested inside lists and maps.
    #[test]
    fn nested_string_values_escape_control_characters() {
        let mut props = BTreeMap::new();
        props.insert(
            "items".to_string(),
            TypedValue::List(vec![TypedValue::String("line\nitem".to_string())]),
        );
        props.insert(
            "label".to_string(),
            TypedValue::String("tab\tvalue".to_string()),
        );

        let rendered =
            render_cypher_value(&TypedValue::Map(props)).expect("nested strings should render");

        assert_eq!(rendered, "{items: ['line\\nitem'], label: 'tab\\tvalue'}");
    }

    // Bad parameter names and bad map keys each yield `InvalidIdentifier`
    // carrying the matching `IdentifierKind`.
    #[test]
    fn invalid_identifiers_return_typed_errors() {
        let param_error =
            TypedQuery::with_params("RETURN $bad", [("bad-name", TypedValue::Bool(true))])
                .expect_err("invalid parameter name should fail");
        assert_eq!(
            param_error,
            TypedQueryError::InvalidIdentifier {
                kind: IdentifierKind::ParameterName,
                identifier: "bad-name".to_string(),
            }
        );

        let mut props = BTreeMap::new();
        props.insert("bad.key".to_string(), TypedValue::Integer(1));
        let map_error =
            render_cypher_value(&TypedValue::Map(props)).expect_err("invalid map key should fail");
        assert_eq!(
            map_error,
            TypedQueryError::InvalidIdentifier {
                kind: IdentifierKind::MapKey,
                identifier: "bad.key".to_string(),
            }
        );
    }

    // NaN and both infinities are rejected rather than rendered.
    #[test]
    fn unsafe_values_return_typed_errors() {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let error = render_cypher_value(&TypedValue::Float(value))
                .expect_err("non-finite float should fail");
            assert!(matches!(error, TypedQueryError::NonFiniteFloat { .. }));
        }
    }
}