Skip to main content

gobby_code/graph/
typed_query.rs

1use std::collections::{BTreeMap, HashMap};
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
/// A Cypher query text together with parameter values that have already been
/// rendered to Cypher literal source text.
///
/// Entries added through [`TypedQuery::insert_param`] or
/// [`TypedQuery::with_params`] have a name checked by [`validate_identifier`]
/// and a value produced by [`render_cypher_value`]. Both fields are public and
/// the type derives `Deserialize`, so entries created any other way bypass
/// that validation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TypedQuery {
    /// The Cypher query text, stored exactly as supplied.
    pub cypher: String,
    /// Parameter name -> rendered Cypher literal (for example `'O\'Reilly'` or `42`).
    ///
    /// NOTE(review): binding these to `$name` placeholders in `cypher` happens
    /// outside this file — confirm the substitution rules at the use site.
    pub params: HashMap<String, String>,
}
11
/// A parameter value that [`render_cypher_value`] converts into Cypher literal
/// source text.
///
/// Variant order may be part of the serialized form for index-based serde
/// formats, so new variants should be appended rather than inserted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TypedValue {
    /// Rendered as `null`.
    Null,
    /// Rendered as a single-quoted literal; backslashes, both quote characters
    /// and control characters are escaped.
    String(String),
    /// Rendered with its decimal `Display` text.
    Integer(i64),
    /// Must be finite. Rendered with its `Display` text, with `.0` appended
    /// when that text has no `.`, `e` or `E`.
    Float(f64),
    /// Rendered as `true` / `false`.
    Bool(bool),
    /// Rendered as `[a, b, ...]`, each element rendered recursively.
    List(Vec<TypedValue>),
    /// Rendered as `{key: value, ...}` in sorted key order; every key must
    /// satisfy [`validate_identifier`].
    Map(BTreeMap<String, TypedValue>),
}
22
/// Which kind of identifier failed [`validate_identifier`]; carried by
/// [`TypedQueryError::InvalidIdentifier`] and shown in its message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IdentifierKind {
    /// A query parameter name (displayed as "parameter name").
    ParameterName,
    /// A key of a [`TypedValue::Map`] (displayed as "map key").
    MapKey,
}
28
/// String rendering context retained for compatibility with older
/// `TypedQueryError::ControlCharacter` consumers.
///
/// Nothing in the current rendering path produces this type; it exists only
/// so the deprecated error variant keeps its original shape.
#[deprecated(
    since = "0.9.10",
    note = "string values are escaped now; this context is retained only for API compatibility"
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValueContext {
    /// A string value (displayed as "string").
    String,
}
39
/// Errors produced while validating parameter names / map keys or rendering
/// [`TypedValue`]s into Cypher literals.
// The derives below reference the deprecated `ControlCharacter` variant and
// `ValueContext`; this allow keeps them from emitting deprecation warnings.
#[allow(deprecated)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedQueryError {
    /// A parameter name or map key did not match `^[A-Za-z_][A-Za-z0-9_]*$`.
    InvalidIdentifier {
        /// Whether the rejected text was a parameter name or a map key.
        kind: IdentifierKind,
        /// The rejected text, verbatim.
        identifier: String,
    },
    /// Legacy error variant retained for API compatibility.
    ///
    /// String control characters are escaped by `render_cypher_value`, so new
    /// code should not construct or match this variant for current rendering.
    #[deprecated(
        since = "0.9.10",
        note = "string control characters are escaped instead of rejected"
    )]
    ControlCharacter {
        #[allow(deprecated)]
        context: ValueContext,
        /// The offending character's Unicode scalar value.
        codepoint: u32,
    },
    /// A [`TypedValue::Float`] was NaN or infinite.
    NonFiniteFloat {
        /// The float's `Display` text (for example `NaN`, `inf`, `-inf`).
        value: String,
    },
}
64
65impl TypedQuery {
66    pub fn new(cypher: impl Into<String>) -> Self {
67        Self {
68            cypher: cypher.into(),
69            params: HashMap::new(),
70        }
71    }
72
73    pub fn with_params<I, K>(cypher: impl Into<String>, params: I) -> Result<Self, TypedQueryError>
74    where
75        I: IntoIterator<Item = (K, TypedValue)>,
76        K: Into<String>,
77    {
78        let mut query = Self::new(cypher);
79        for (name, value) in params {
80            query.insert_param(name, value)?;
81        }
82        Ok(query)
83    }
84
85    pub fn insert_param(
86        &mut self,
87        name: impl Into<String>,
88        value: TypedValue,
89    ) -> Result<(), TypedQueryError> {
90        let name = name.into();
91        validate_identifier(&name, IdentifierKind::ParameterName)?;
92        let rendered = render_cypher_value(&value)?;
93        self.params.insert(name, rendered);
94        Ok(())
95    }
96}
97
98pub fn cypher_string_literal(s: &str) -> String {
99    format!("'{}'", escape_string_contents(s))
100}
101
102pub fn render_cypher_value(value: &TypedValue) -> Result<String, TypedQueryError> {
103    match value {
104        TypedValue::Null => Ok("null".to_string()),
105        TypedValue::String(value) => render_string_literal(value),
106        TypedValue::Integer(value) => Ok(value.to_string()),
107        TypedValue::Float(value) => render_float(*value),
108        TypedValue::Bool(value) => Ok(value.to_string()),
109        TypedValue::List(values) => values
110            .iter()
111            .map(render_cypher_value)
112            .collect::<Result<Vec<_>, _>>()
113            .map(|values| format!("[{}]", values.join(", "))),
114        TypedValue::Map(values) => values
115            .iter()
116            .map(|(key, value)| {
117                validate_identifier(key, IdentifierKind::MapKey)?;
118                Ok(format!("{key}: {}", render_cypher_value(value)?))
119            })
120            .collect::<Result<Vec<_>, _>>()
121            .map(|values| format!("{{{}}}", values.join(", "))),
122    }
123}
124
125pub fn string_params(values: &[(&str, &str)]) -> HashMap<String, String> {
126    values
127        .iter()
128        .map(|(key, value)| ((*key).to_string(), cypher_string_literal(value)))
129        .collect()
130}
131
/// Clamps a caller-supplied result limit into `1..=max`.
///
/// A limit of zero is raised to one. When `max` is zero the upper bound wins
/// and `0` is returned. The previous `limit.clamp(1, max)` panicked in that
/// case, because `Ord::clamp` asserts `min <= max`.
pub fn clamp_limit(limit: usize, max: usize) -> usize {
    limit.max(1).min(max)
}
135
/// Caps a caller-supplied result offset at `max`; smaller offsets pass through.
pub fn clamp_offset(offset: usize, max: usize) -> usize {
    std::cmp::min(offset, max)
}
139
140pub fn id_list_literal(ids: &[String]) -> String {
141    ids.iter()
142        .map(|id| cypher_string_literal(id))
143        .collect::<Vec<_>>()
144        .join(", ")
145}
146
147pub fn validate_identifier(identifier: &str, kind: IdentifierKind) -> Result<(), TypedQueryError> {
148    let mut chars = identifier.chars();
149    let Some(first) = chars.next() else {
150        return Err(TypedQueryError::InvalidIdentifier {
151            kind,
152            identifier: identifier.to_string(),
153        });
154    };
155
156    if !(first == '_' || first.is_ascii_alphabetic())
157        || !chars.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
158    {
159        return Err(TypedQueryError::InvalidIdentifier {
160            kind,
161            identifier: identifier.to_string(),
162        });
163    }
164
165    Ok(())
166}
167
168fn render_string_literal(value: &str) -> Result<String, TypedQueryError> {
169    Ok(cypher_string_literal(value))
170}
171
/// Escapes `value` for use between the quotes of a Cypher string literal.
///
/// Backslash, `'`, `"`, newline, carriage return, tab, backspace and form
/// feed get their short escapes. Any other control character becomes
/// `\uXXXX` with uppercase hex. Everything else is copied unchanged.
fn escape_string_contents(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut out, ch| {
            let short = match ch {
                '\\' => Some("\\\\"),
                '\'' => Some("\\'"),
                '"' => Some("\\\""),
                '\n' => Some("\\n"),
                '\r' => Some("\\r"),
                '\t' => Some("\\t"),
                '\u{0008}' => Some("\\b"),
                '\u{000C}' => Some("\\f"),
                _ => None,
            };
            match short {
                Some(sequence) => out.push_str(sequence),
                None if ch.is_control() => {
                    out.push_str(&format!("\\u{:04X}", u32::from(ch)));
                }
                None => out.push(ch),
            }
            out
        })
}
190
191fn render_float(value: f64) -> Result<String, TypedQueryError> {
192    if !value.is_finite() {
193        return Err(TypedQueryError::NonFiniteFloat {
194            value: value.to_string(),
195        });
196    }
197
198    let mut rendered = value.to_string();
199    if !rendered.contains('.') && !rendered.contains('e') && !rendered.contains('E') {
200        rendered.push_str(".0");
201    }
202    Ok(rendered)
203}
204
205impl fmt::Display for IdentifierKind {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        match self {
208            Self::ParameterName => f.write_str("parameter name"),
209            Self::MapKey => f.write_str("map key"),
210        }
211    }
212}
213
214#[allow(deprecated)]
215impl fmt::Display for ValueContext {
216    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        match self {
218            Self::String => f.write_str("string"),
219        }
220    }
221}
222
223#[allow(deprecated)]
224impl fmt::Display for TypedQueryError {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        match self {
227            Self::InvalidIdentifier { kind, identifier } => write!(
228                f,
229                "invalid {kind} `{identifier}`; expected ^[A-Za-z_][A-Za-z0-9_]*$"
230            ),
231            Self::ControlCharacter { context, codepoint } => write!(
232                f,
233                "control character U+{codepoint:04X} is not allowed in {context} value"
234            ),
235            Self::NonFiniteFloat { value } => {
236                write!(f, "non-finite float `{value}` is not allowed")
237            }
238        }
239    }
240}
241
242impl std::error::Error for TypedQueryError {}
243
// Unit tests for literal rendering, escaping, identifier validation and the
// small query-building helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    #[test]
    fn typed_params_render_nested_safe_cypher_literals() {
        let mut props = BTreeMap::new();
        props.insert("enabled".to_string(), TypedValue::Bool(true));
        props.insert(
            "label".to_string(),
            TypedValue::String("caf\u{00e9} \"quote\" and 'single' \\ slash".to_string()),
        );
        props.insert(
            "nested".to_string(),
            TypedValue::List(vec![
                TypedValue::Integer(1),
                TypedValue::Float(2.25),
                TypedValue::Bool(false),
            ]),
        );

        let query = TypedQuery::with_params(
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props",
            [
                (
                    "name",
                    TypedValue::String("O'Reilly \\ path \u{2603}".to_string()),
                ),
                ("count", TypedValue::Integer(42)),
                ("ratio", TypedValue::Float(1.5)),
                ("whole", TypedValue::Float(1.0)),
                ("enabled", TypedValue::Bool(true)),
                (
                    "items",
                    TypedValue::List(vec![
                        TypedValue::String("a".to_string()),
                        TypedValue::Integer(-7),
                        TypedValue::Bool(false),
                    ]),
                ),
                ("props", TypedValue::Map(props)),
            ],
        )
        .expect("valid typed params should render");

        assert_eq!(
            query.cypher,
            "RETURN $name, $count, $ratio, $whole, $enabled, $items, $props"
        );
        assert_eq!(
            query.params.get("name").map(String::as_str),
            Some("'O\\'Reilly \\\\ path \u{2603}'")
        );
        assert_eq!(query.params.get("count").map(String::as_str), Some("42"));
        assert_eq!(query.params.get("ratio").map(String::as_str), Some("1.5"));
        assert_eq!(query.params.get("whole").map(String::as_str), Some("1.0"));
        assert_eq!(
            query.params.get("enabled").map(String::as_str),
            Some("true")
        );
        assert_eq!(
            query.params.get("items").map(String::as_str),
            Some("['a', -7, false]")
        );
        assert_eq!(
            query.params.get("props").map(String::as_str),
            Some(
                "{enabled: true, label: 'caf\u{00e9} \\\"quote\\\" and \\'single\\' \\\\ slash', nested: [1, 2.25, false]}"
            )
        );
    }

    #[test]
    fn string_literals_escape_both_quote_delimiters() {
        let rendered = render_cypher_value(&TypedValue::String("a 'single' and \"double\"".into()))
            .expect("valid string should render");

        assert_eq!(rendered, "'a \\'single\\' and \\\"double\\\"'");
    }

    #[test]
    fn string_literals_escape_control_characters() {
        let rendered = render_cypher_value(&TypedValue::String(
            "line\ncarriage\rtab\tbackspace\u{0008}form\u{000C}escape\u{001B}".into(),
        ))
        .expect("control characters should render as escaped literals");

        assert_eq!(
            rendered,
            "'line\\ncarriage\\rtab\\tbackspace\\bform\\fescape\\u001B'"
        );
    }

    #[test]
    fn string_literals_escape_delete_and_c1_controls() {
        // DEL and C1 controls have no short escape and fall back to \uXXXX.
        assert_eq!(cypher_string_literal("\u{007F}\u{0085}"), "'\\u007F\\u0085'");
        assert_eq!(cypher_string_literal(""), "''");
    }

    #[test]
    fn nested_string_values_escape_control_characters() {
        let mut props = BTreeMap::new();
        props.insert(
            "items".to_string(),
            TypedValue::List(vec![TypedValue::String("line\nitem".to_string())]),
        );
        props.insert(
            "label".to_string(),
            TypedValue::String("tab\tvalue".to_string()),
        );

        let rendered =
            render_cypher_value(&TypedValue::Map(props)).expect("nested strings should render");

        assert_eq!(rendered, "{items: ['line\\nitem'], label: 'tab\\tvalue'}");
    }

    #[test]
    fn scalar_and_empty_collection_values_render() {
        let render = |value: TypedValue| render_cypher_value(&value).expect("value should render");

        assert_eq!(render(TypedValue::Null), "null");
        assert_eq!(render(TypedValue::Integer(i64::MIN)), "-9223372036854775808");
        assert_eq!(render(TypedValue::Float(-2.0)), "-2.0");
        assert_eq!(render(TypedValue::Float(0.1)), "0.1");
        assert_eq!(render(TypedValue::List(Vec::new())), "[]");
        assert_eq!(render(TypedValue::Map(BTreeMap::new())), "{}");
    }

    #[test]
    fn invalid_identifiers_return_typed_errors() {
        let param_error =
            TypedQuery::with_params("RETURN $bad", [("bad-name", TypedValue::Bool(true))])
                .expect_err("invalid parameter name should fail");
        assert_eq!(
            param_error,
            TypedQueryError::InvalidIdentifier {
                kind: IdentifierKind::ParameterName,
                identifier: "bad-name".to_string(),
            }
        );

        let mut props = BTreeMap::new();
        props.insert("bad.key".to_string(), TypedValue::Integer(1));
        let map_error =
            render_cypher_value(&TypedValue::Map(props)).expect_err("invalid map key should fail");
        assert_eq!(
            map_error,
            TypedQueryError::InvalidIdentifier {
                kind: IdentifierKind::MapKey,
                identifier: "bad.key".to_string(),
            }
        );
    }

    #[test]
    fn identifiers_follow_the_documented_pattern() {
        for valid in ["a", "_", "_a1", "Name_2"] {
            assert_eq!(
                validate_identifier(valid, IdentifierKind::ParameterName),
                Ok(())
            );
        }
        for invalid in ["", "1a", "a-b", "a b", "caf\u{00e9}"] {
            assert_eq!(
                validate_identifier(invalid, IdentifierKind::MapKey),
                Err(TypedQueryError::InvalidIdentifier {
                    kind: IdentifierKind::MapKey,
                    identifier: invalid.to_string(),
                })
            );
        }
    }

    #[test]
    fn id_lists_and_string_params_are_escaped() {
        assert_eq!(id_list_literal(&[]), "");
        assert_eq!(
            id_list_literal(&["a".to_string(), "b'c".to_string()]),
            "'a', 'b\\'c'"
        );

        let params = string_params(&[("name", "O'Reilly"), ("empty", "")]);
        assert_eq!(params.len(), 2);
        assert_eq!(
            params.get("name").map(String::as_str),
            Some("'O\\'Reilly'")
        );
        assert_eq!(params.get("empty").map(String::as_str), Some("''"));
    }

    #[test]
    fn clamp_helpers_bound_their_inputs() {
        assert_eq!(clamp_limit(0, 10), 1);
        assert_eq!(clamp_limit(5, 10), 5);
        assert_eq!(clamp_limit(50, 10), 10);
        assert_eq!(clamp_offset(5, 10), 5);
        assert_eq!(clamp_offset(50, 10), 10);
        assert_eq!(clamp_offset(0, 0), 0);
    }

    #[test]
    fn unsafe_values_return_typed_errors() {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let error = render_cypher_value(&TypedValue::Float(value))
                .expect_err("non-finite float should fail");
            assert!(matches!(error, TypedQueryError::NonFiniteFloat { .. }));
        }
    }

    #[test]
    fn errors_render_human_readable_messages() {
        let identifier_error = TypedQueryError::InvalidIdentifier {
            kind: IdentifierKind::ParameterName,
            identifier: "bad-name".to_string(),
        };
        assert_eq!(
            identifier_error.to_string(),
            "invalid parameter name `bad-name`; expected ^[A-Za-z_][A-Za-z0-9_]*$"
        );

        let float_error = render_cypher_value(&TypedValue::Float(f64::NEG_INFINITY))
            .expect_err("non-finite float should fail");
        assert_eq!(
            float_error.to_string(),
            "non-finite float `-inf` is not allowed"
        );
    }
}
386                .expect_err("non-finite float should fail");
387            assert!(matches!(error, TypedQueryError::NonFiniteFloat { .. }));
388        }
389    }
390}