use indexmap::IndexMap;
use serde_json::Value;
use super::eval::EvalError;
use super::eval::{Method};
use super::vm::VM;
/// Errors returned by [`Graph`] query operations.
#[derive(Debug)]
pub enum GraphError {
    /// Expression evaluation failed; carries the underlying [`EvalError`].
    Eval(EvalError),
    /// A node was requested by name but is not stored in the graph.
    /// The payload is the name that was looked up.
    NodeNotFound(String),
}
impl std::fmt::Display for GraphError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GraphError::Eval(e) => write!(f, "{}", e),
GraphError::NodeNotFound(n) => write!(f, "node '{}' not found", n),
}
}
}
/// Marks [`GraphError`] as a standard error type.
///
/// Only the trait's default methods are used, so `source()` reports `None`
/// for every variant, including `Eval`.
impl std::error::Error for GraphError {}
impl From<EvalError> for GraphError { fn from(e: EvalError) -> Self { GraphError::Eval(e) } }
/// A collection of named JSON documents ("nodes") paired with the [`VM`]
/// used to evaluate query expressions against them.
pub struct Graph {
    /// Node name -> JSON value, kept in an [`IndexMap`].
    nodes: IndexMap<String, Value>,
    /// Expression engine that every query is run through.
    vm: VM,
}
impl Graph {
    /// Creates an empty graph with a VM built by `VM::new()`.
    pub fn new() -> Self {
        Self { nodes: IndexMap::new(), vm: VM::new() }
    }

    /// Creates an empty graph whose VM is built by
    /// `VM::with_capacity(compile_cap, resolution_cap)`.
    ///
    /// Both arguments are forwarded to the VM unchanged; they do not
    /// pre-size the node map.
    // NOTE(review): presumably these bound the VM's compile and resolution
    // caches -- confirm against `VM::with_capacity`.
    pub fn with_capacity(compile_cap: usize, resolution_cap: usize) -> Self {
        Self {
            nodes: IndexMap::new(),
            vm: VM::with_capacity(compile_cap, resolution_cap),
        }
    }

    /// Stores `value` under `name` and returns `self` so calls can be chained.
    ///
    /// Inserting under a name that already exists replaces the stored value.
    pub fn add_node<S: Into<String>>(&mut self, name: S, value: Value) -> &mut Self {
        self.nodes.insert(name.into(), value);
        self
    }

    /// Returns a reference to the node stored under `name`, if any.
    pub fn get_node(&self, name: &str) -> Option<&Value> {
        self.nodes.get(name)
    }

    /// Removes the node stored under `name` and returns its value, or `None`
    /// if no such node exists.
    ///
    /// Uses `shift_remove`, so the relative order of the remaining nodes is
    /// preserved.
    pub fn remove_node(&mut self, name: &str) -> Option<Value> {
        self.nodes.shift_remove(name)
    }

    /// Number of nodes currently stored.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// `true` when the graph holds no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Iterates over the node names in the node map's iteration order.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.keys().map(|s| s.as_str())
    }

    /// Evaluates `expr` against the virtual root (see [`Graph::virtual_root`]),
    /// i.e. a JSON object with one key per node.
    ///
    /// The virtual root is rebuilt on every call, which clones every node.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::Eval`] when the VM reports an error.
    pub fn query(&mut self, expr: &str) -> Result<Value, GraphError> {
        let root = self.virtual_root();
        Ok(self.vm.run_str(expr, &root)?)
    }

    /// Evaluates `expr` with the single node named `node` as the document.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::NodeNotFound`] when `node` is not in the graph,
    /// and [`GraphError::Eval`] when the VM reports an error.
    pub fn query_node(&mut self, node: &str, expr: &str) -> Result<Value, GraphError> {
        // `nodes` and `vm` are disjoint fields, so the node can be borrowed
        // immutably while the VM is borrowed mutably -- no deep clone of the
        // node value is needed.
        let value = self
            .nodes
            .get(node)
            .ok_or_else(|| GraphError::NodeNotFound(node.to_string()))?;
        Ok(self.vm.run_str(expr, value)?)
    }

    /// Registers a custom method on the underlying VM under `name`.
    pub fn register_method(&mut self, name: impl Into<String>, method: impl Method + 'static) {
        self.vm.register(name, method);
    }

    /// Shared access to the underlying VM.
    pub fn vm(&self) -> &VM {
        &self.vm
    }

    /// Exclusive access to the underlying VM.
    pub fn vm_mut(&mut self) -> &mut VM {
        &mut self.vm
    }

    /// Builds a JSON object whose keys are the node names and whose values
    /// are clones of the node values.
    pub fn virtual_root(&self) -> Value {
        let map: serde_json::Map<String, Value> = self
            .nodes
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        Value::Object(map)
    }

    /// Alias for [`Graph::query`]: evaluates `schema` against the virtual root.
    ///
    /// # Errors
    ///
    /// Same as [`Graph::query`].
    pub fn message(&mut self, schema: &str) -> Result<Value, GraphError> {
        self.query(schema)
    }
}
impl Default for Graph {
fn default() -> Self { Self::new() }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fixture: a graph with three `orders` and two `customers`.
    fn make_graph() -> Graph {
        let mut graph = Graph::new();
        graph.add_node(
            "orders",
            json!([
                {"id": 1, "customer_id": 10, "price": 9.99, "is_gratis": false},
                {"id": 2, "customer_id": 20, "price": 4.50, "is_gratis": false},
                {"id": 3, "customer_id": 10, "price": 0.00, "is_gratis": true},
            ]),
        );
        graph.add_node(
            "customers",
            json!([
                {"id": 10, "name": "Alice"},
                {"id": 20, "name": "Bob"},
            ]),
        );
        graph
    }

    /// The fixture holds exactly the two nodes it was given.
    #[test]
    fn test_node_count() {
        assert_eq!(make_graph().len(), 2);
    }

    /// The virtual root is an object keyed by node name.
    #[test]
    fn test_virtual_root_shape() {
        let root = make_graph().virtual_root();
        let fields = root.as_object().unwrap();
        for key in ["orders", "customers"].iter() {
            assert!(fields.contains_key(*key));
        }
    }

    /// `len()` on a node reached through the virtual root.
    #[test]
    fn test_query_len() {
        let mut graph = make_graph();
        let count = graph.query("$.orders.len()").unwrap();
        assert_eq!(count, json!(3));
    }

    /// Summing a field over every order.
    #[test]
    fn test_query_sum() {
        let mut graph = make_graph();
        let total = graph
            .query("$.orders.sum(price)")
            .unwrap()
            .as_f64()
            .unwrap();
        assert!((total - 14.49).abs() < 0.001);
    }

    /// Filtering out the gratis order leaves the total unchanged, since its
    /// price is zero.
    #[test]
    fn test_query_filter_sum() {
        let mut graph = make_graph();
        let total = graph
            .query("$.orders.filter(is_gratis == false).sum(price)")
            .unwrap()
            .as_f64()
            .unwrap();
        assert!((total - 14.49).abs() < 0.001);
    }

    /// Querying a single node makes that node the document root.
    #[test]
    fn test_query_node_direct() {
        let mut graph = make_graph();
        let count = graph.query_node("orders", "$.len()").unwrap();
        assert_eq!(count, json!(3));
    }

    /// An unknown node name is reported as `NodeNotFound`.
    #[test]
    fn test_query_node_not_found() {
        let mut graph = make_graph();
        let outcome = graph.query_node("missing", "$.len()");
        assert!(matches!(outcome, Err(GraphError::NodeNotFound(_))));
    }

    /// Grouping by `customer_id` yields one entry per distinct customer.
    #[test]
    fn test_group_by() {
        let mut graph = make_graph();
        let grouped = graph.query("$.orders.groupBy(customer_id)").unwrap();
        assert_eq!(grouped.as_object().unwrap().len(), 2);
    }

    /// Removing a node returns it and makes it unreachable afterwards.
    #[test]
    fn test_remove_node() {
        let mut graph = make_graph();
        let removed = graph.remove_node("orders");
        assert!(removed.is_some());
        assert_eq!(graph.len(), 1);
        assert!(graph.get_node("orders").is_none());
    }

    /// Running the same expression twice leaves a single compile-cache entry.
    #[test]
    fn test_compile_cache() {
        let mut graph = make_graph();
        for _ in 0..2 {
            graph.query("$.orders.len()").unwrap();
        }
        let stats = graph.vm().cache_stats();
        assert_eq!(stats.0, 1);
    }

    /// `message` accepts an object-shaped expression and returns its keys.
    #[test]
    fn test_message_alias() {
        let mut graph = make_graph();
        let reply = graph
            .message(r#"{"count": $.orders.len(), "names": $.customers.map(name)}"#)
            .unwrap();
        let fields = reply.as_object().unwrap();
        assert!(fields.contains_key("count"));
        assert!(fields.contains_key("names"));
    }

    /// A closure registered as a method is callable from an expression.
    #[test]
    fn test_custom_method() {
        use super::super::eval::value::Val;

        let mut graph = make_graph();
        graph.register_method("double_len", |recv: Val, _args: &[Val]| {
            // Twice the array length; any non-array receiver yields 0.
            let doubled = if let Val::Arr(items) = &recv {
                items.len() as i64 * 2
            } else {
                0
            };
            Ok(Val::Int(doubled))
        });
        let result = graph.query("$.orders.double_len()").unwrap();
        assert_eq!(result, json!(6));
    }
}