use std::collections::HashMap;
use panproto_gat::Name;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::metadata::Node;
use crate::value::{FieldPresence, Value};
use crate::wtype::{WInstance, build_env_from_extra_fields, value_to_expr_literal};
/// A declarative query over the nodes of a [`WInstance`], run by [`execute`].
///
/// Nodes anchored at `anchor` are selected, optionally navigated along `path`,
/// and filtered by `predicate`; `project`, `group_by` and `limit` then shape the
/// returned [`QueryMatch`]es.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct InstanceQuery {
/// Anchor name selecting the starting nodes: every node whose `anchor` equals
/// this value. Required when deserializing (it carries no `#[serde(default)]`).
pub anchor: Name,
/// Optional boolean expression evaluated once per (navigated) node; a node is
/// kept only when evaluation yields `Bool(true)` — errors and any other result
/// drop it. Variables in scope are produced by `build_node_env`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub predicate: Option<panproto_expr::Expr>,
/// Optional extra-field name; matches are sorted by this field's value so equal
/// values end up adjacent. No aggregation is performed. Matches lacking the
/// field sort before those that have it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub group_by: Option<String>,
/// Optional whitelist of extra-field names copied into each match's `fields`;
/// `None` copies every extra field. Names a node does not have are skipped.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub project: Option<Vec<String>>,
/// Optional maximum number of matches to return.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub limit: Option<usize>,
/// Edge kinds to follow, in order, starting from the anchored nodes; the query
/// then applies to the nodes reached (whose anchor may differ from `anchor`).
/// Empty means no navigation, and the field is omitted when serializing.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub path: Vec<Name>,
}
/// One node returned by [`execute`].
#[derive(Debug, Clone)]
pub struct QueryMatch {
/// Id of the matched node.
pub node_id: u32,
/// The matched node's own anchor. After `path` navigation this can differ
/// from the query's `anchor`.
pub anchor: Name,
/// Clone of the matched node's `value`.
pub value: Option<FieldPresence>,
/// The node's extra fields, restricted to the query's `project` list when set.
pub fields: FxHashMap<String, Value>,
}
/// Executes `query` against `instance` and returns the matching nodes.
///
/// Pipeline:
/// 1. select every node whose anchor equals `query.anchor`, in ascending id order;
/// 2. if `query.path` is non-empty, follow it with `navigate_path`;
/// 3. keep only nodes for which `query.predicate` evaluates to `Bool(true)`
///    (evaluation errors and non-boolean results reject the node);
/// 4. build one [`QueryMatch`] per node, projecting its extra fields to
///    `query.project` when set;
/// 5. if `query.group_by` is set, stable-sort by that field so equal values are
///    adjacent (matches lacking the field come first);
/// 6. cap the result at `query.limit`. The cap is applied *after* step 5, so a
///    limited grouped query returns the first rows of the grouped order.
///
/// `_schema` is currently unused.
#[must_use]
pub fn execute(
    query: &InstanceQuery,
    instance: &WInstance,
    _schema: &panproto_schema::Schema,
) -> Vec<QueryMatch> {
    let eval_config = panproto_expr::EvalConfig::default();

    // The node map's iteration order is not something to rely on, so sort the
    // candidates: this makes `limit` and `group_by` tie-breaking reproducible.
    let mut ids: Vec<u32> = instance
        .nodes
        .iter()
        .filter(|(_, n)| n.anchor == query.anchor)
        .map(|(id, _)| *id)
        .collect();
    ids.sort_unstable();

    if !query.path.is_empty() {
        ids = navigate_path(instance, &ids, &query.path);
    }

    if let Some(pred) = &query.predicate {
        // Best-effort filter: a node missing from the map, an evaluation error,
        // or a non-`Bool(true)` result all simply exclude the node.
        ids.retain(|&id| {
            let Some(node) = instance.nodes.get(&id) else {
                return false;
            };
            let env = build_node_env(node, instance);
            matches!(
                crate::instance_env::eval_with_instance(
                    pred,
                    &env,
                    &eval_config,
                    instance,
                    Some(id),
                ),
                Ok(panproto_expr::Literal::Bool(true))
            )
        });
    }

    // Lazily materialise matches so an ungrouped limited query clones only the
    // fields of the rows it actually returns.
    let rows = ids.into_iter().filter_map(|id| {
        let node = instance.nodes.get(&id)?;
        Some(QueryMatch {
            node_id: id,
            anchor: node.anchor.clone(),
            value: node.value.clone(),
            fields: project_fields(&node.extra_fields, query.project.as_ref()),
        })
    });

    let Some(group_key) = &query.group_by else {
        return rows.take(query.limit.unwrap_or(usize::MAX)).collect();
    };

    let mut results: Vec<QueryMatch> = rows.collect();
    // `sort_by_cached_key` is stable, so equal keys keep the id-derived order,
    // and each key is computed once rather than on every comparison.
    results.sort_by_cached_key(|m| m.fields.get(group_key).map(value_sort_key));
    if let Some(limit) = query.limit {
        results.truncate(limit);
    }
    results
}
fn navigate_path(instance: &WInstance, start_nodes: &[u32], path: &[Name]) -> Vec<u32> {
let mut current = start_nodes.to_vec();
for edge_kind in path {
let mut next = Vec::new();
for &node_id in ¤t {
for &(src, tgt, ref edge) in &instance.arcs {
if src == node_id && edge.kind == *edge_kind {
next.push(tgt);
}
}
}
current = next;
}
current
}
#[must_use]
pub fn build_node_env(node: &Node, instance: &WInstance) -> panproto_expr::Env {
let mut env = build_env_from_extra_fields(&node.extra_fields);
env = env.extend(
std::sync::Arc::from("_anchor"),
panproto_expr::Literal::Str(node.anchor.as_ref().into()),
);
env = env.extend(
std::sync::Arc::from("_id"),
panproto_expr::Literal::Int(i64::from(node.id)),
);
if let Some(FieldPresence::Present(ref v)) = node.value {
env = env.extend(std::sync::Arc::from("_value"), value_to_expr_literal(v));
}
let children_count = instance
.arcs
.iter()
.filter(|(src, _, _)| *src == node.id)
.count();
#[allow(clippy::cast_possible_wrap)]
{
env = env.extend(
std::sync::Arc::from("_children_count"),
panproto_expr::Literal::Int(children_count as i64),
);
}
env
}
/// Returns the key used to order matches when `group_by` is set.
///
/// The key is `(variant_rank, text)`. The rank keeps values of different
/// variants apart. With a bare string key, `Str("1")`, `Token("1")`, `Int(1)`
/// and `Float(1.0)` all rendered as `"1"`, and `Null` collided with `Str("")`.
/// A stable sort could then interleave them inside a single "group".
///
/// Within a variant, values compare by their textual form, so equal values are
/// always adjacent. Numbers therefore order lexicographically (`10` before `9`).
/// That is sufficient for grouping but is not a numeric sort.
fn value_sort_key(v: &Value) -> (u8, String) {
    match v {
        Value::Null => (0, String::new()),
        Value::Bool(b) => (1, b.to_string()),
        Value::Int(i) => (2, i.to_string()),
        Value::Float(f) => (3, f.to_string()),
        Value::Str(s) => (4, s.clone()),
        Value::Token(t) => (5, t.clone()),
        // Any other variant falls back to its `Debug` rendering.
        _ => (6, format!("{v:?}")),
    }
}
fn project_fields(
fields: &HashMap<String, Value>,
project: Option<&Vec<String>>,
) -> FxHashMap<String, Value> {
project.map_or_else(
|| fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
|keys| {
let mut result = FxHashMap::default();
for key in keys {
if let Some(val) = fields.get(key) {
result.insert(key.clone(), val.clone());
}
}
result
},
)
}
// Unit tests for `execute`, run against small hand-built instances.
#[cfg(test)]
// `unwrap` is used in fixture setup; `as u32` casts apply to small loop indices.
#[allow(clippy::unwrap_used, clippy::cast_possible_truncation)]
mod tests {
use super::*;
use panproto_schema::{Edge, Protocol, SchemaBuilder};
// Schema: document --layers--> layer --annotations--> annotation.
// `execute` ignores the schema, but its signature requires one.
fn make_test_schema() -> panproto_schema::Schema {
let protocol = Protocol::default();
SchemaBuilder::new(&protocol)
.vertex("document", "record", None)
.unwrap()
.vertex("layer", "record", None)
.unwrap()
.vertex("annotation", "record", None)
.unwrap()
.edge("document", "layer", "layers", None)
.unwrap()
.edge("layer", "annotation", "annotations", None)
.unwrap()
.build()
.unwrap()
}
// Instance: document 0 -> layer 1 -> annotations 2 and 3.
// Annotation 2: label "ingredient", confidence 0.9.
// Annotation 3: label "step", confidence 0.4.
fn make_test_instance() -> WInstance {
let mut nodes = HashMap::new();
nodes.insert(0, Node::new(0, "document"));
// NOTE: despite its name, `ann1` is the layer node (anchor "layer").
let mut ann1 = Node::new(1, "layer");
ann1.extra_fields
.insert("kind".into(), Value::Str("span".into()));
nodes.insert(1, ann1);
let mut ann2 = Node::new(2, "annotation");
ann2.extra_fields
.insert("label".into(), Value::Str("ingredient".into()));
ann2.extra_fields
.insert("confidence".into(), Value::Float(0.9));
nodes.insert(2, ann2);
let mut ann3 = Node::new(3, "annotation");
ann3.extra_fields
.insert("label".into(), Value::Str("step".into()));
ann3.extra_fields
.insert("confidence".into(), Value::Float(0.4));
nodes.insert(3, ann3);
let edge_layer = Edge {
src: Name::from("document"),
tgt: Name::from("layer"),
kind: Name::from("layers"),
name: None,
};
let edge_ann = Edge {
src: Name::from("layer"),
tgt: Name::from("annotation"),
kind: Name::from("annotations"),
name: None,
};
let arcs = vec![
(0, 1, edge_layer),
(1, 2, edge_ann.clone()),
(1, 3, edge_ann),
];
WInstance::new(nodes, arcs, vec![], 0, Name::from("document"))
}
// Anchor selection alone returns both annotation nodes.
#[test]
fn query_by_anchor() {
let inst = make_test_instance();
let query = InstanceQuery {
anchor: Name::from("annotation"),
..Default::default()
};
let results = execute(&query, &inst, &make_test_schema());
assert_eq!(results.len(), 2);
}
// `label == "ingredient"` keeps only annotation 2.
#[test]
fn query_with_predicate() {
let inst = make_test_instance();
let query = InstanceQuery {
anchor: Name::from("annotation"),
predicate: Some(panproto_expr::Expr::Builtin(
panproto_expr::BuiltinOp::Eq,
vec![
panproto_expr::Expr::Var("label".into()),
panproto_expr::Expr::Lit(panproto_expr::Literal::Str("ingredient".into())),
],
)),
..Default::default()
};
let results = execute(&query, &inst, &make_test_schema());
assert_eq!(results.len(), 1);
assert_eq!(
results[0].fields.get("label"),
Some(&Value::Str("ingredient".into()))
);
}
// Starting at the document, following `layers` then `annotations`
// reaches both annotations.
#[test]
fn query_with_path_navigation() {
let inst = make_test_instance();
let query = InstanceQuery {
anchor: Name::from("document"),
path: vec![Name::from("layers"), Name::from("annotations")],
..Default::default()
};
let results = execute(&query, &inst, &make_test_schema());
assert_eq!(results.len(), 2);
}
// `limit` caps the number of returned matches.
#[test]
fn query_with_limit() {
let inst = make_test_instance();
let query = InstanceQuery {
anchor: Name::from("annotation"),
limit: Some(1),
..Default::default()
};
let results = execute(&query, &inst, &make_test_schema());
assert_eq!(results.len(), 1);
}
// Projection keeps only the listed extra fields.
#[test]
fn query_with_projection() {
let inst = make_test_instance();
let query = InstanceQuery {
anchor: Name::from("annotation"),
project: Some(vec!["label".into()]),
..Default::default()
};
let results = execute(&query, &inst, &make_test_schema());
assert_eq!(results.len(), 2);
for r in &results {
assert!(r.fields.contains_key("label"));
assert!(!r.fields.contains_key("confidence"));
}
}
// An anchor no node carries yields no matches.
#[test]
fn query_no_match() {
let inst = make_test_instance();
let query = InstanceQuery {
anchor: Name::from("nonexistent"),
..Default::default()
};
let results = execute(&query, &inst, &make_test_schema());
assert!(results.is_empty());
}
// Five annotations (ids 2..=6) with repeated categories. `group_by`
// must return them ordered by category, with equal categories adjacent.
#[test]
fn query_with_group_by() {
let mut nodes = HashMap::new();
nodes.insert(0, Node::new(0, "document"));
let mut layer = Node::new(1, "layer");
layer
.extra_fields
.insert("kind".into(), Value::Str("span".into()));
nodes.insert(1, layer);
let categories = ["vegetable", "fruit", "fruit", "vegetable", "grain"];
for (i, cat) in categories.iter().enumerate() {
let id = (i as u32) + 2;
let mut ann = Node::new(id, "annotation");
ann.extra_fields
.insert("category".into(), Value::Str((*cat).into()));
ann.extra_fields
.insert("label".into(), Value::Str(format!("item_{i}")));
nodes.insert(id, ann);
}
let edge_layer = Edge {
src: Name::from("document"),
tgt: Name::from("layer"),
kind: Name::from("layers"),
name: None,
};
let mut arcs = vec![(0, 1, edge_layer)];
for i in 0..categories.len() {
let id = (i as u32) + 2;
arcs.push((
1,
id,
Edge {
src: Name::from("layer"),
tgt: Name::from("annotation"),
kind: Name::from("annotations"),
name: None,
},
));
}
let inst = WInstance::new(nodes, arcs, vec![], 0, Name::from("document"));
let query = InstanceQuery {
anchor: Name::from("annotation"),
group_by: Some("category".into()),
..Default::default()
};
let results = execute(&query, &inst, &make_test_schema());
assert_eq!(results.len(), 5);
let categories_out: Vec<&str> = results
.iter()
.filter_map(|r| match r.fields.get("category") {
Some(Value::Str(s)) => Some(s.as_str()),
_ => None,
})
.collect();
assert_eq!(
categories_out,
vec!["fruit", "fruit", "grain", "vegetable", "vegetable"]
);
}
}