Skip to main content

gobby_code/graph/
typed_query.rs

1use std::collections::{BTreeMap, HashMap};
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7pub struct TypedQuery {
8    pub cypher: String,
9    pub params: HashMap<String, String>,
10}
11
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13pub enum TypedValue {
14    Null,
15    String(String),
16    Integer(i64),
17    Float(f64),
18    Bool(bool),
19    List(Vec<TypedValue>),
20    Map(BTreeMap<String, TypedValue>),
21}
22
/// Which kind of identifier was being validated when a check failed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IdentifierKind {
    /// The name of a query parameter.
    ParameterName,
    /// A key inside a map value.
    MapKey,
}
28
/// The kind of value an error refers to; carried by
/// `TypedQueryError::ControlCharacter`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValueContext {
    /// A string value.
    String,
}
33
34#[derive(Debug, Clone, PartialEq, Eq)]
35pub enum TypedQueryError {
36    InvalidIdentifier {
37        kind: IdentifierKind,
38        identifier: String,
39    },
40    ControlCharacter {
41        context: ValueContext,
42        codepoint: u32,
43    },
44    NonFiniteFloat {
45        value: String,
46    },
47}
48
49impl TypedQuery {
50    pub fn new(cypher: impl Into<String>) -> Self {
51        Self {
52            cypher: cypher.into(),
53            params: HashMap::new(),
54        }
55    }
56
57    pub fn with_params<I, K>(cypher: impl Into<String>, params: I) -> Result<Self, TypedQueryError>
58    where
59        I: IntoIterator<Item = (K, TypedValue)>,
60        K: Into<String>,
61    {
62        let mut query = Self::new(cypher);
63        for (name, value) in params {
64            query.insert_param(name, value)?;
65        }
66        Ok(query)
67    }
68
69    pub fn insert_param(
70        &mut self,
71        name: impl Into<String>,
72        value: TypedValue,
73    ) -> Result<(), TypedQueryError> {
74        let name = name.into();
75        validate_identifier(&name, IdentifierKind::ParameterName)?;
76        let rendered = render_cypher_value(&value)?;
77        self.params.insert(name, rendered);
78        Ok(())
79    }
80}
81
82pub fn cypher_string_literal(s: &str) -> String {
83    format!("'{}'", escape_string_contents(s))
84}
85
86pub fn render_cypher_value(value: &TypedValue) -> Result<String, TypedQueryError> {
87    match value {
88        TypedValue::Null => Ok("null".to_string()),
89        TypedValue::String(value) => render_string_literal(value),
90        TypedValue::Integer(value) => Ok(value.to_string()),
91        TypedValue::Float(value) => render_float(*value),
92        TypedValue::Bool(value) => Ok(value.to_string()),
93        TypedValue::List(values) => values
94            .iter()
95            .map(render_cypher_value)
96            .collect::<Result<Vec<_>, _>>()
97            .map(|values| format!("[{}]", values.join(", "))),
98        TypedValue::Map(values) => values
99            .iter()
100            .map(|(key, value)| {
101                validate_identifier(key, IdentifierKind::MapKey)?;
102                Ok(format!("{key}: {}", render_cypher_value(value)?))
103            })
104            .collect::<Result<Vec<_>, _>>()
105            .map(|values| format!("{{{}}}", values.join(", "))),
106    }
107}
108
109pub fn string_params(values: &[(&str, &str)]) -> HashMap<String, String> {
110    values
111        .iter()
112        .map(|(key, value)| ((*key).to_string(), cypher_string_literal(value)))
113        .collect()
114}
115
/// Clamps `limit` into the inclusive range `1..=max`.
///
/// The lower bound wins when `max` is 0, so the result is always at least 1.
/// `limit.clamp(1, max)` is deliberately avoided: `Ord::clamp` panics when
/// its minimum exceeds its maximum, which happens for `max == 0`.
pub fn clamp_limit(limit: usize, max: usize) -> usize {
    limit.min(max).max(1)
}
119
/// Returns `offset`, capped so that it never exceeds `max`.
pub fn clamp_offset(offset: usize, max: usize) -> usize {
    if offset > max {
        max
    } else {
        offset
    }
}
123
124pub fn id_list_literal(ids: &[String]) -> String {
125    ids.iter()
126        .map(|id| cypher_string_literal(id))
127        .collect::<Vec<_>>()
128        .join(", ")
129}
130
131pub fn validate_identifier(identifier: &str, kind: IdentifierKind) -> Result<(), TypedQueryError> {
132    let mut chars = identifier.chars();
133    let Some(first) = chars.next() else {
134        return Err(TypedQueryError::InvalidIdentifier {
135            kind,
136            identifier: identifier.to_string(),
137        });
138    };
139
140    if !(first == '_' || first.is_ascii_alphabetic())
141        || !chars.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
142    {
143        return Err(TypedQueryError::InvalidIdentifier {
144            kind,
145            identifier: identifier.to_string(),
146        });
147    }
148
149    Ok(())
150}
151
152fn render_string_literal(value: &str) -> Result<String, TypedQueryError> {
153    Ok(cypher_string_literal(value))
154}
155
/// Escapes `value` for placement inside a quoted string literal.
///
/// Backslash, both quote characters, newline, carriage return, tab,
/// backspace, and form feed get two-character backslash escapes. Any other
/// control character is written as `\uXXXX` with at least four uppercase hex
/// digits. Every remaining character is copied unchanged.
fn escape_string_contents(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        // Characters that have a dedicated short escape sequence.
        let named_escape = match ch {
            '\\' => Some("\\\\"),
            '\'' => Some("\\'"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            '\u{0008}' => Some("\\b"),
            '\u{000C}' => Some("\\f"),
            _ => None,
        };

        if let Some(sequence) = named_escape {
            out.push_str(sequence);
        } else if ch.is_control() {
            out.push_str(&format!("\\u{:04X}", u32::from(ch)));
        } else {
            out.push(ch);
        }
    }
    out
}
174
175fn render_float(value: f64) -> Result<String, TypedQueryError> {
176    if !value.is_finite() {
177        return Err(TypedQueryError::NonFiniteFloat {
178            value: value.to_string(),
179        });
180    }
181
182    let mut rendered = value.to_string();
183    if !rendered.contains('.') && !rendered.contains('e') && !rendered.contains('E') {
184        rendered.push_str(".0");
185    }
186    Ok(rendered)
187}
188
189impl fmt::Display for IdentifierKind {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        match self {
192            Self::ParameterName => f.write_str("parameter name"),
193            Self::MapKey => f.write_str("map key"),
194        }
195    }
196}
197
198impl fmt::Display for ValueContext {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        match self {
201            Self::String => f.write_str("string"),
202        }
203    }
204}
205
206impl fmt::Display for TypedQueryError {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        match self {
209            Self::InvalidIdentifier { kind, identifier } => write!(
210                f,
211                "invalid {kind} `{identifier}`; expected ^[A-Za-z_][A-Za-z0-9_]*$"
212            ),
213            Self::ControlCharacter { context, codepoint } => write!(
214                f,
215                "control character U+{codepoint:04X} is not allowed in {context} value"
216            ),
217            Self::NonFiniteFloat { value } => {
218                write!(f, "non-finite float `{value}` is not allowed")
219            }
220        }
221    }
222}
223
224impl std::error::Error for TypedQueryError {}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Looks up a rendered parameter by name as a borrowed `&str`.
    fn rendered_param<'q>(query: &'q TypedQuery, name: &str) -> Option<&'q str> {
        query.params.get(name).map(String::as_str)
    }

    /// Every value kind, including nested lists and maps, renders to the
    /// expected literal text while the query text is left untouched.
    #[test]
    fn typed_params_render_nested_safe_cypher_literals() {
        let props = BTreeMap::from([
            ("enabled".to_string(), TypedValue::Bool(true)),
            (
                "label".to_string(),
                TypedValue::String("caf\u{00e9} \"quote\" and 'single' \\ slash".to_string()),
            ),
            (
                "nested".to_string(),
                TypedValue::List(vec![
                    TypedValue::Integer(1),
                    TypedValue::Float(2.25),
                    TypedValue::Bool(false),
                ]),
            ),
        ]);

        let items = TypedValue::List(vec![
            TypedValue::String("a".to_string()),
            TypedValue::Integer(-7),
            TypedValue::Bool(false),
        ]);

        let query = TypedQuery::with_params(
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props",
            [
                (
                    "name",
                    TypedValue::String("O'Reilly \\ path \u{2603}".to_string()),
                ),
                ("count", TypedValue::Integer(42)),
                ("ratio", TypedValue::Float(1.5)),
                ("whole", TypedValue::Float(1.0)),
                ("enabled", TypedValue::Bool(true)),
                ("items", items),
                ("props", TypedValue::Map(props)),
            ],
        )
        .expect("valid typed params should render");

        assert_eq!(
            query.cypher,
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props"
        );

        let expected = [
            ("name", "'O\\'Reilly \\\\ path \u{2603}'"),
            ("count", "42"),
            ("ratio", "1.5"),
            ("whole", "1.0"),
            ("enabled", "true"),
            ("items", "['a', -7, false]"),
            (
                "props",
                "{enabled: true, label: 'caf\u{00e9} \\\"quote\\\" and \\'single\\' \\\\ slash', nested: [1, 2.25, false]}",
            ),
        ];
        for (name, literal) in expected {
            assert_eq!(
                rendered_param(&query, name),
                Some(literal),
                "parameter `{name}`"
            );
        }
    }

    /// Both single and double quotes are backslash-escaped inside a literal.
    #[test]
    fn string_literals_escape_both_quote_delimiters() {
        let input = TypedValue::String("a 'single' and \"double\"".into());

        let rendered = render_cypher_value(&input).expect("valid string should render");

        assert_eq!(rendered, "'a \\'single\\' and \\\"double\\\"'");
    }

    /// Control characters render as escape sequences instead of failing.
    #[test]
    fn string_literals_escape_control_characters() {
        let input = TypedValue::String(
            "line\ncarriage\rtab\tbackspace\u{0008}form\u{000C}escape\u{001B}".into(),
        );

        let rendered = render_cypher_value(&input)
            .expect("control characters should render as escaped literals");

        assert_eq!(
            rendered,
            "'line\\ncarriage\\rtab\\tbackspace\\bform\\fescape\\u001B'"
        );
    }

    /// Escaping also applies to strings nested inside lists and maps.
    #[test]
    fn nested_string_values_escape_control_characters() {
        let props = BTreeMap::from([
            (
                "items".to_string(),
                TypedValue::List(vec![TypedValue::String("line\nitem".to_string())]),
            ),
            (
                "label".to_string(),
                TypedValue::String("tab\tvalue".to_string()),
            ),
        ]);

        let rendered =
            render_cypher_value(&TypedValue::Map(props)).expect("nested strings should render");

        assert_eq!(rendered, "{items: ['line\\nitem'], label: 'tab\\tvalue'}");
    }

    /// Bad parameter names and bad map keys are reported with their kind
    /// and the offending identifier.
    #[test]
    fn invalid_identifiers_return_typed_errors() {
        let expected_param_error = TypedQueryError::InvalidIdentifier {
            kind: IdentifierKind::ParameterName,
            identifier: "bad-name".to_string(),
        };
        let param_error =
            TypedQuery::with_params("RETURN $bad", [("bad-name", TypedValue::Bool(true))])
                .expect_err("invalid parameter name should fail");
        assert_eq!(param_error, expected_param_error);

        let expected_map_error = TypedQueryError::InvalidIdentifier {
            kind: IdentifierKind::MapKey,
            identifier: "bad.key".to_string(),
        };
        let props = BTreeMap::from([("bad.key".to_string(), TypedValue::Integer(1))]);
        let map_error =
            render_cypher_value(&TypedValue::Map(props)).expect_err("invalid map key should fail");
        assert_eq!(map_error, expected_map_error);
    }

    /// NaN and both infinities are rejected rather than rendered.
    #[test]
    fn unsafe_values_return_typed_errors() {
        let non_finite = [f64::NAN, f64::INFINITY, f64::NEG_INFINITY];
        for value in non_finite {
            let error = render_cypher_value(&TypedValue::Float(value))
                .expect_err("non-finite float should fail");
            assert!(matches!(error, TypedQueryError::NonFiniteFloat { .. }));
        }
    }
}
372}