1use indexmap::IndexMap;
10use serde_json::Value;
11
12use super::eval::EvalError;
13use super::eval::{Method};
14use super::vm::VM;
15
16#[derive(Debug)]
19pub enum GraphError {
20 Eval(EvalError),
21 NodeNotFound(String),
22}
23
24impl std::fmt::Display for GraphError {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 match self {
27 GraphError::Eval(e) => write!(f, "{}", e),
28 GraphError::NodeNotFound(n) => write!(f, "node '{}' not found", n),
29 }
30 }
31}
32impl std::error::Error for GraphError {}
33impl From<EvalError> for GraphError { fn from(e: EvalError) -> Self { GraphError::Eval(e) } }
34
35pub struct Graph {
44 nodes: IndexMap<String, Value>,
45 vm: VM,
46}
47
48impl Graph {
49 pub fn new() -> Self {
51 Self { nodes: IndexMap::new(), vm: VM::new() }
52 }
53
54 pub fn with_capacity(compile_cap: usize, resolution_cap: usize) -> Self {
58 Self {
59 nodes: IndexMap::new(),
60 vm: VM::with_capacity(compile_cap, resolution_cap),
61 }
62 }
63
64 pub fn add_node<S: Into<String>>(&mut self, name: S, value: Value) -> &mut Self {
68 self.nodes.insert(name.into(), value);
69 self
70 }
71
72 pub fn get_node(&self, name: &str) -> Option<&Value> {
74 self.nodes.get(name)
75 }
76
77 pub fn remove_node(&mut self, name: &str) -> Option<Value> {
79 self.nodes.shift_remove(name)
80 }
81
82 pub fn len(&self) -> usize { self.nodes.len() }
84
85 pub fn is_empty(&self) -> bool { self.nodes.is_empty() }
87
88 pub fn node_names(&self) -> impl Iterator<Item = &str> {
90 self.nodes.keys().map(|s| s.as_str())
91 }
92
93 pub fn query(&mut self, expr: &str) -> Result<Value, GraphError> {
100 let root = self.virtual_root();
101 Ok(self.vm.run_str(expr, &root)?)
102 }
103
104 pub fn query_node(&mut self, node: &str, expr: &str) -> Result<Value, GraphError> {
106 let value = self.nodes.get(node)
107 .ok_or_else(|| GraphError::NodeNotFound(node.to_string()))?
108 .clone();
109 Ok(self.vm.run_str(expr, &value)?)
110 }
111
112 pub fn register_method(&mut self, name: impl Into<String>, method: impl Method + 'static) {
114 self.vm.register(name, method);
115 }
116
117 pub fn vm(&self) -> &VM { &self.vm }
121
122 pub fn vm_mut(&mut self) -> &mut VM { &mut self.vm }
124
125 pub fn virtual_root(&self) -> Value {
129 let map: serde_json::Map<String, Value> = self.nodes.iter()
130 .map(|(k, v)| (k.clone(), v.clone()))
131 .collect();
132 Value::Object(map)
133 }
134
135 pub fn message(&mut self, schema: &str) -> Result<Value, GraphError> {
141 self.query(schema)
142 }
143}
144
145impl Default for Graph {
146 fn default() -> Self { Self::new() }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fixture with two nodes:
    /// - `orders`: 3 rows, one of them gratis with price 0.00;
    /// - `customers`: 2 rows.
    fn make_graph() -> Graph {
        let mut g = Graph::new();
        g.add_node("orders", json!([
            {"id": 1, "customer_id": 10, "price": 9.99, "is_gratis": false},
            {"id": 2, "customer_id": 20, "price": 4.50, "is_gratis": false},
            {"id": 3, "customer_id": 10, "price": 0.00, "is_gratis": true},
        ]))
        .add_node("customers", json!([
            {"id": 10, "name": "Alice"},
            {"id": 20, "name": "Bob"},
        ]));
        g
    }

    #[test]
    fn test_node_count() {
        let g = make_graph();
        assert_eq!(g.len(), 2);
    }

    #[test]
    fn test_node_names_keep_insertion_order() {
        let g = make_graph();
        let names: Vec<&str> = g.node_names().collect();
        assert_eq!(names, ["orders", "customers"]);
    }

    #[test]
    fn test_add_node_replaces_existing() {
        let mut g = make_graph();
        g.add_node("orders", json!([]));
        assert_eq!(g.len(), 2);
        assert_eq!(g.get_node("orders"), Some(&json!([])));
    }

    #[test]
    fn test_virtual_root_shape() {
        let g = make_graph();
        let root = g.virtual_root();
        let obj = root.as_object().unwrap();
        assert!(obj.contains_key("orders"));
        assert!(obj.contains_key("customers"));
    }

    #[test]
    fn test_query_len() {
        let mut g = make_graph();
        let r = g.query("$.orders.len()").unwrap();
        assert_eq!(r, json!(3));
    }

    #[test]
    fn test_query_sum() {
        let mut g = make_graph();
        let r = g.query("$.orders.sum(price)").unwrap();
        let total = r.as_f64().unwrap();
        assert!((total - 14.49).abs() < 0.001);
    }

    #[test]
    fn test_query_filter_sum() {
        let mut g = make_graph();
        // The only order this filter drops is priced 0.00, so the sum alone cannot tell a
        // working filter from a no-op. Check how many rows survive as well.
        let kept = g.query("$.orders.filter(is_gratis == false).len()").unwrap();
        assert_eq!(kept, json!(2));
        let r = g.query("$.orders.filter(is_gratis == false).sum(price)").unwrap();
        let total = r.as_f64().unwrap();
        assert!((total - 14.49).abs() < 0.001);
    }

    #[test]
    fn test_query_node_direct() {
        let mut g = make_graph();
        let r = g.query_node("orders", "$.len()").unwrap();
        assert_eq!(r, json!(3));
    }

    #[test]
    fn test_query_node_not_found() {
        let mut g = make_graph();
        let result = g.query_node("missing", "$.len()");
        assert!(matches!(result, Err(GraphError::NodeNotFound(_))));
        let err = result.unwrap_err();
        assert_eq!(err.to_string(), "node 'missing' not found");
    }

    #[test]
    fn test_group_by() {
        let mut g = make_graph();
        let r = g.query("$.orders.groupBy(customer_id)").unwrap();
        let obj = r.as_object().unwrap();
        // customer_id takes two distinct values (10 and 20) in the fixture.
        assert_eq!(obj.len(), 2);
    }

    #[test]
    fn test_remove_node() {
        let mut g = make_graph();
        assert!(g.remove_node("orders").is_some());
        assert_eq!(g.len(), 1);
        assert!(g.get_node("orders").is_none());
        // Removing an absent node is a no-op that reports `None`.
        assert!(g.remove_node("orders").is_none());
        assert_eq!(g.len(), 1);
    }

    #[test]
    fn test_compile_cache() {
        let mut g = make_graph();
        g.query("$.orders.len()").unwrap();
        // Querying the same expression string again should not add a second cache entry.
        g.query("$.orders.len()").unwrap();
        let (cache_size, _) = g.vm().cache_stats();
        assert_eq!(cache_size, 1);
    }

    #[test]
    fn test_message_alias() {
        let mut g = make_graph();
        let r = g
            .message(r#"{"count": $.orders.len(), "names": $.customers.map(name)}"#)
            .unwrap();
        let obj = r.as_object().unwrap();
        assert!(obj.contains_key("count"));
        assert!(obj.contains_key("names"));
    }

    #[test]
    fn test_custom_method() {
        use super::super::eval::value::Val;

        let mut g = make_graph();
        // Returns twice the receiver's length for arrays, 0 for anything else.
        g.register_method("double_len", |recv: Val, _args: &[Val]| {
            let n = match &recv {
                Val::Arr(a) => a.len() as i64 * 2,
                _ => 0,
            };
            Ok(Val::Int(n))
        });
        let r = g.query("$.orders.double_len()").unwrap();
        assert_eq!(r, json!(6));
    }
}
270}