use std::collections::HashMap;
use serde_json::Value;
use crate::context::{Error, Path, PathResult};
/// A set of named JSON values ("nodes") that can be queried individually or
/// together through path expressions.
#[derive(Debug, Clone)]
pub struct Graph {
    /// Node values keyed by node name.
    nodes: HashMap<String, Value>,
}
impl Graph {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
}
}
pub fn add_node<S: Into<String>>(&mut self, name: S, value: Value) -> &mut Self {
self.nodes.insert(name.into(), value);
self
}
pub fn get_node(&self, name: &str) -> Option<&Value> {
self.nodes.get(name)
}
pub fn remove_node(&mut self, name: &str) -> Option<Value> {
self.nodes.remove(name)
}
pub fn node_names(&self) -> Vec<&str> {
self.nodes.keys().map(|s| s.as_str()).collect()
}
pub fn query<S: Into<String>>(&self, expr: S) -> Result<PathResult, Error> {
let root = self.virtual_root();
Path::collect(root, expr)
}
pub fn query_node<S: Into<String>>(&self, node: S, expr: S) -> Result<PathResult, Error> {
let name = node.into();
match self.nodes.get(&name) {
Some(value) => Path::collect(value.clone(), expr),
None => Err(Error::Eval(format!("node '{}' not found in graph", name))),
}
}
pub fn virtual_root(&self) -> Value {
let mut map = serde_json::Map::new();
for (name, value) in &self.nodes {
map.insert(name.clone(), value.clone());
}
Value::Object(map)
}
pub fn message<S: Into<String>>(&self, schema: S) -> Result<PathResult, Error> {
self.query(schema)
}
}
impl Default for Graph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a two-node fixture: three `orders` and two `customers`.
    fn make_graph() -> Graph {
        let mut g = Graph::new();
        g.add_node(
            "orders",
            json!([
                {"id": 1, "customer_id": 10, "price": 9.99, "is_gratis": false},
                {"id": 2, "customer_id": 20, "price": 4.50, "is_gratis": false},
                {"id": 3, "customer_id": 10, "price": 0.00, "is_gratis": true},
            ]),
        )
        .add_node(
            "customers",
            json!([
                {"id": 10, "name": "Alice"},
                {"id": 20, "name": "Bob"},
            ]),
        );
        g
    }

    #[test]
    fn test_node_count() {
        let g = make_graph();
        assert_eq!(g.node_names().len(), 2);
    }

    #[test]
    fn test_simple_query() {
        let g = make_graph();
        let mut r = g.query(">/orders/#len").expect("query failed");
        let count: i64 = r.from_index(0).unwrap();
        assert_eq!(count, 3);
    }

    #[test]
    fn test_sum_across_node() {
        let g = make_graph();
        let mut r = g
            .query(">/orders/#filter('is_gratis' == false)/..price/#sum")
            .expect("query failed");
        let total: f64 = r.from_index(0).unwrap();
        // Non-gratis orders: 9.99 + 4.50.
        assert!((total - 14.49).abs() < 0.001);
    }

    #[test]
    fn test_message_schema() {
        let g = make_graph();
        let schema = r#">{"total": >/orders/#filter('is_gratis' == false)/..price/#sum, "count": >/orders/#len}"#;
        let r = g.message(schema).expect("message failed");
        assert_eq!(r.0.len(), 1);
        let obj = r.0[0].as_object().unwrap();
        assert!(obj.contains_key("total"));
        assert!(obj.contains_key("count"));
    }

    #[test]
    fn test_group_by_across_node() {
        let g = make_graph();
        let mut r = g
            .query(">/orders/#group_by('customer_id')")
            .expect("query failed");
        let grouped = r.from_index::<serde_json::Value>(0).unwrap();
        let obj = grouped.as_object().unwrap();
        // Orders belong to customers 10 and 20.
        assert_eq!(obj.len(), 2);
    }

    #[test]
    fn test_query_node() {
        let g = make_graph();
        let mut r = g.query_node("customers", ">/[0]/name").expect("query_node failed");
        let name: String = r.from_index(0).unwrap();
        assert_eq!(name, "Alice");
    }

    #[test]
    fn test_query_node_missing() {
        let g = make_graph();
        let result = g.query_node("suppliers", ">/[0]/name");
        assert!(matches!(&result, Err(Error::Eval(msg)) if msg.contains("suppliers")));
    }

    #[test]
    fn test_virtual_root_shape() {
        let g = make_graph();
        let root = g.virtual_root();
        let obj = root.as_object().unwrap();
        assert_eq!(obj.len(), 2);
        assert!(obj.contains_key("orders"));
        assert!(obj.contains_key("customers"));
    }

    #[test]
    fn test_get_and_remove_node() {
        let mut g = make_graph();
        assert!(g.get_node("customers").is_some());
        let removed = g.remove_node("customers").expect("customers node should exist");
        assert_eq!(
            removed,
            json!([
                {"id": 10, "name": "Alice"},
                {"id": 20, "name": "Bob"},
            ])
        );
        assert!(g.get_node("customers").is_none());
        assert!(g.remove_node("customers").is_none());
        assert_eq!(g.node_names().len(), 1);
    }

    #[test]
    fn test_add_node_replaces_existing() {
        let mut g = make_graph();
        g.add_node("customers", json!([]));
        assert_eq!(g.node_names().len(), 2);
        assert_eq!(g.get_node("customers"), Some(&json!([])));
    }
}