Skip to main content

graphd/
values.rs

1use base64::Engine;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5/// Internal ID for nodes and relationships in the graph database.
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7pub struct InternalId {
8    pub table: u64,
9    pub offset: u64,
10}
11
12/// A value returned from a Cypher query, encoded per the Strana protocol spec.
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14#[serde(untagged)]
15pub enum GraphValue {
16    Null,
17    Bool(bool),
18    Int(i64),
19    Float(f64),
20    String(String),
21    List(Vec<GraphValue>),
22    Map(HashMap<String, GraphValue>),
23    Tagged(TaggedValue),
24}
25
26/// Graph-specific types that use a `$type` discriminator.
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
28#[serde(tag = "$type", rename_all = "lowercase")]
29pub enum TaggedValue {
30    Node {
31        id: InternalId,
32        label: String,
33        properties: HashMap<String, GraphValue>,
34    },
35    Rel {
36        id: InternalId,
37        label: String,
38        src: InternalId,
39        dst: InternalId,
40        properties: HashMap<String, GraphValue>,
41    },
42    Path {
43        nodes: Vec<GraphValue>,
44        rels: Vec<GraphValue>,
45    },
46    Union {
47        tag: String,
48        value: Box<GraphValue>,
49    },
50}
51
52/// Convert a LadybugDB Value into a Strana GraphValue.
53pub fn from_lbug_value(value: &lbug::Value) -> GraphValue {
54    match value {
55        lbug::Value::Null(_) => GraphValue::Null,
56        lbug::Value::Bool(b) => GraphValue::Bool(*b),
57        lbug::Value::Int8(n) => GraphValue::Int(*n as i64),
58        lbug::Value::Int16(n) => GraphValue::Int(*n as i64),
59        lbug::Value::Int32(n) => GraphValue::Int(*n as i64),
60        lbug::Value::Int64(n) => GraphValue::Int(*n),
61        lbug::Value::Int128(n) => GraphValue::String(n.to_string()),
62        lbug::Value::UInt8(n) => GraphValue::Int(*n as i64),
63        lbug::Value::UInt16(n) => GraphValue::Int(*n as i64),
64        lbug::Value::UInt32(n) => GraphValue::Int(*n as i64),
65        lbug::Value::UInt64(n) => {
66            if *n <= i64::MAX as u64 {
67                GraphValue::Int(*n as i64)
68            } else {
69                GraphValue::String(n.to_string())
70            }
71        }
72        lbug::Value::Float(f) => GraphValue::Float(*f as f64),
73        lbug::Value::Double(f) => GraphValue::Float(*f),
74        lbug::Value::Decimal(d) => GraphValue::String(d.to_string()),
75        lbug::Value::String(s) => GraphValue::String(s.clone()),
76        lbug::Value::Blob(b) => {
77            GraphValue::String(base64::engine::general_purpose::STANDARD.encode(b))
78        }
79        lbug::Value::UUID(u) => GraphValue::String(u.to_string()),
80        lbug::Value::Date(d) => GraphValue::String(d.to_string()),
81        lbug::Value::Timestamp(t)
82        | lbug::Value::TimestampTz(t)
83        | lbug::Value::TimestampNs(t)
84        | lbug::Value::TimestampMs(t)
85        | lbug::Value::TimestampSec(t) => GraphValue::String(t.to_string()),
86        lbug::Value::Interval(d) => GraphValue::String(format!("{d:?}")),
87        lbug::Value::List(_, items) | lbug::Value::Array(_, items) => {
88            GraphValue::List(items.iter().map(from_lbug_value).collect())
89        }
90        lbug::Value::Map(_, entries) => {
91            let mut map = HashMap::new();
92            for (k, v) in entries {
93                let key = match k {
94                    lbug::Value::String(s) => s.clone(),
95                    other => format!("{other:?}"),
96                };
97                map.insert(key, from_lbug_value(v));
98            }
99            GraphValue::Map(map)
100        }
101        lbug::Value::Struct(fields) => {
102            let mut map = HashMap::new();
103            for (key, val) in fields {
104                map.insert(key.clone(), from_lbug_value(val));
105            }
106            GraphValue::Map(map)
107        }
108        lbug::Value::Node(node) => {
109            let id = InternalId {
110                table: node.get_node_id().table_id,
111                offset: node.get_node_id().offset,
112            };
113            let mut properties = HashMap::new();
114            for (key, val) in node.get_properties() {
115                properties.insert(key.clone(), from_lbug_value(val));
116            }
117            GraphValue::Tagged(TaggedValue::Node {
118                id,
119                label: node.get_label_name().clone(),
120                properties,
121            })
122        }
123        lbug::Value::Rel(rel) => {
124            let src = InternalId {
125                table: rel.get_src_node().table_id,
126                offset: rel.get_src_node().offset,
127            };
128            let dst = InternalId {
129                table: rel.get_dst_node().table_id,
130                offset: rel.get_dst_node().offset,
131            };
132            let mut properties = HashMap::new();
133            for (key, val) in rel.get_properties() {
134                properties.insert(key.clone(), from_lbug_value(val));
135            }
136            GraphValue::Tagged(TaggedValue::Rel {
137                id: InternalId { table: 0, offset: 0 }, // RelVal doesn't expose its own ID
138                label: rel.get_label_name().clone(),
139                src,
140                dst,
141                properties,
142            })
143        }
144        lbug::Value::RecursiveRel { nodes, rels } => {
145            GraphValue::Tagged(TaggedValue::Path {
146                nodes: nodes.iter().map(|n| from_lbug_value(&lbug::Value::Node(n.clone()))).collect(),
147                rels: rels.iter().map(|r| from_lbug_value(&lbug::Value::Rel(r.clone()))).collect(),
148            })
149        }
150        lbug::Value::InternalID(id) => GraphValue::Map(HashMap::from([
151            ("table".to_string(), GraphValue::Int(id.table_id as i64)),
152            ("offset".to_string(), GraphValue::Int(id.offset as i64)),
153        ])),
154        lbug::Value::Union { value, .. } => from_lbug_value(value),
155    }
156}
157
158/// Convert a JSON value to a LadybugDB value for parameter binding.
159pub fn to_lbug_value(json: &serde_json::Value) -> Result<lbug::Value, String> {
160    match json {
161        serde_json::Value::Null => Ok(lbug::Value::Null(lbug::LogicalType::Any)),
162        serde_json::Value::Bool(b) => Ok(lbug::Value::Bool(*b)),
163        serde_json::Value::Number(n) => {
164            if let Some(i) = n.as_i64() {
165                Ok(lbug::Value::Int64(i))
166            } else if let Some(u) = n.as_u64() {
167                Ok(lbug::Value::UInt64(u))
168            } else if let Some(f) = n.as_f64() {
169                Ok(lbug::Value::Double(f))
170            } else {
171                Err("Unsupported number type".into())
172            }
173        }
174        serde_json::Value::String(s) => Ok(lbug::Value::String(s.clone())),
175        serde_json::Value::Array(_) => {
176            Err("Arrays not supported as query parameters".into())
177        }
178        serde_json::Value::Object(_) => {
179            Err("Objects not supported as query parameters".into())
180        }
181    }
182}
183
184/// Convert a JSON params object to a Vec of (name, lbug::Value) pairs.
185pub fn json_params_to_lbug(
186    params: &serde_json::Value,
187) -> Result<Vec<(String, lbug::Value)>, String> {
188    let obj = params
189        .as_object()
190        .ok_or_else(|| "params must be a JSON object".to_string())?;
191    obj.iter()
192        .map(|(k, v)| Ok((k.clone(), to_lbug_value(v)?)))
193        .collect()
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_serialization() {
        assert_eq!(serde_json::to_string(&GraphValue::Null).unwrap(), "null");
        assert_eq!(serde_json::to_string(&GraphValue::Bool(true)).unwrap(), "true");
        assert_eq!(serde_json::to_string(&GraphValue::Int(42)).unwrap(), "42");
        // 2.5 is exactly representable. A literal like 3.14 trips clippy's
        // deny-by-default `approx_constant` lint (it looks like a truncated PI).
        assert_eq!(serde_json::to_string(&GraphValue::Float(2.5)).unwrap(), "2.5");
        assert_eq!(
            serde_json::to_string(&GraphValue::String("hello".into())).unwrap(),
            "\"hello\""
        );
    }

    #[test]
    fn test_node_serialization() {
        let node = GraphValue::Tagged(TaggedValue::Node {
            id: InternalId { table: 0, offset: 5 },
            label: "Person".into(),
            properties: HashMap::from([
                ("name".into(), GraphValue::String("Alice".into())),
                ("age".into(), GraphValue::Int(30)),
            ]),
        });
        let json: serde_json::Value = serde_json::to_value(&node).unwrap();
        assert_eq!(json["$type"], "node");
        assert_eq!(json["label"], "Person");
        assert_eq!(json["id"]["table"], 0);
        assert_eq!(json["id"]["offset"], 5);
        assert_eq!(json["properties"]["name"], "Alice");
        assert_eq!(json["properties"]["age"], 30);
    }

    #[test]
    fn test_rel_serialization() {
        let rel = GraphValue::Tagged(TaggedValue::Rel {
            id: InternalId { table: 2, offset: 10 },
            label: "KNOWS".into(),
            src: InternalId { table: 0, offset: 5 },
            dst: InternalId { table: 0, offset: 8 },
            properties: HashMap::new(),
        });
        let json: serde_json::Value = serde_json::to_value(&rel).unwrap();
        assert_eq!(json["$type"], "rel");
        assert_eq!(json["label"], "KNOWS");
        assert_eq!(json["src"]["offset"], 5);
        assert_eq!(json["dst"]["offset"], 8);
    }

    #[test]
    fn test_path_serialization() {
        let path = GraphValue::Tagged(TaggedValue::Path {
            nodes: vec![],
            rels: vec![],
        });
        let json: serde_json::Value = serde_json::to_value(&path).unwrap();
        assert_eq!(json["$type"], "path");
        assert_eq!(json["nodes"], serde_json::json!([]));
        assert_eq!(json["rels"], serde_json::json!([]));
    }

    #[test]
    fn test_list_serialization() {
        let list = GraphValue::List(vec![GraphValue::Int(1), GraphValue::Int(2)]);
        assert_eq!(serde_json::to_string(&list).unwrap(), "[1,2]");
    }

    #[test]
    fn test_map_serialization() {
        let map = GraphValue::Map(HashMap::from([("k".to_string(), GraphValue::Null)]));
        assert_eq!(serde_json::to_string(&map).unwrap(), r#"{"k":null}"#);
    }

    #[test]
    fn test_from_lbug_scalars() {
        assert_eq!(
            from_lbug_value(&lbug::Value::Null(lbug::LogicalType::Any)),
            GraphValue::Null
        );
        assert_eq!(from_lbug_value(&lbug::Value::Int64(-5)), GraphValue::Int(-5));
        assert_eq!(from_lbug_value(&lbug::Value::UInt64(7)), GraphValue::Int(7));
        // u64 values beyond i64::MAX must not wrap; they are emitted as strings.
        assert_eq!(
            from_lbug_value(&lbug::Value::UInt64(u64::MAX)),
            GraphValue::String(u64::MAX.to_string())
        );
    }

    #[test]
    fn test_to_lbug_null() {
        let v = to_lbug_value(&serde_json::Value::Null).unwrap();
        assert!(matches!(v, lbug::Value::Null(_)));
    }

    #[test]
    fn test_to_lbug_bool() {
        assert!(matches!(
            to_lbug_value(&serde_json::json!(true)).unwrap(),
            lbug::Value::Bool(true)
        ));
        assert!(matches!(
            to_lbug_value(&serde_json::json!(false)).unwrap(),
            lbug::Value::Bool(false)
        ));
    }

    #[test]
    fn test_to_lbug_int() {
        match to_lbug_value(&serde_json::json!(42)).unwrap() {
            lbug::Value::Int64(n) => assert_eq!(n, 42),
            other => panic!("expected Int64, got {other:?}"),
        }
        match to_lbug_value(&serde_json::json!(-7)).unwrap() {
            lbug::Value::Int64(n) => assert_eq!(n, -7),
            other => panic!("expected Int64, got {other:?}"),
        }
    }

    #[test]
    fn test_to_lbug_large_unsigned() {
        match to_lbug_value(&serde_json::json!(u64::MAX)).unwrap() {
            lbug::Value::UInt64(n) => assert_eq!(n, u64::MAX),
            other => panic!("expected UInt64, got {other:?}"),
        }
    }

    #[test]
    fn test_to_lbug_float() {
        match to_lbug_value(&serde_json::json!(2.5)).unwrap() {
            lbug::Value::Double(f) => assert!((f - 2.5).abs() < f64::EPSILON),
            other => panic!("expected Double, got {other:?}"),
        }
    }

    #[test]
    fn test_to_lbug_string() {
        match to_lbug_value(&serde_json::json!("hello")).unwrap() {
            lbug::Value::String(s) => assert_eq!(s, "hello"),
            other => panic!("expected String, got {other:?}"),
        }
    }

    #[test]
    fn test_to_lbug_array_error() {
        assert!(to_lbug_value(&serde_json::json!([1, 2])).is_err());
    }

    #[test]
    fn test_to_lbug_object_error() {
        assert!(to_lbug_value(&serde_json::json!({"a": 1})).is_err());
    }

    #[test]
    fn test_json_params_to_lbug() {
        let params = serde_json::json!({"name": "Alice", "age": 30});
        let result = json_params_to_lbug(&params).unwrap();
        assert_eq!(result.len(), 2);
        // Check both params exist (order not guaranteed from JSON object)
        let names: Vec<&str> = result.iter().map(|(k, _)| k.as_str()).collect();
        assert!(names.contains(&"name"));
        assert!(names.contains(&"age"));
    }

    #[test]
    fn test_json_params_empty_object() {
        assert!(json_params_to_lbug(&serde_json::json!({})).unwrap().is_empty());
    }

    #[test]
    fn test_json_params_rejects_unsupported_value() {
        assert!(json_params_to_lbug(&serde_json::json!({"xs": [1, 2]})).is_err());
    }

    #[test]
    fn test_json_params_not_object() {
        assert!(json_params_to_lbug(&serde_json::json!("not an object")).is_err());
    }
}