use super::edge::ProvenanceEdge;
use super::node::{NodeId, ProvenanceNode};
use crate::monitor::inference::path::DecisionPath;
use crate::monitor::inference::trace::DecisionTrace;
use std::collections::HashMap;
/// Directed provenance graph.
///
/// Nodes are stored by [`NodeId`]. Edges are kept in insertion order in a flat list.
/// Two adjacency maps let edges be found by either endpoint without scanning the list.
pub struct ProvenanceGraph {
    /// Id that the next call to `add_node` will assign; incremented after each insertion.
    next_id: NodeId,
    /// Node payloads keyed by the id returned from `add_node`.
    nodes: HashMap<NodeId, ProvenanceNode>,
    /// Every edge, in the order it was added. The `usize` entries in `forward`/`backward`
    /// are positions in this list.
    edges: Vec<ProvenanceEdge>,
    /// Node id -> positions in `edges` of the edges whose `from` is that node.
    forward: HashMap<NodeId, Vec<usize>>,
    /// Node id -> positions in `edges` of the edges whose `to` is that node.
    backward: HashMap<NodeId, Vec<usize>>,
}
impl ProvenanceGraph {
    /// Creates an empty graph with no nodes or edges. The first node added gets id `0`.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            edges: Vec::new(),
            forward: HashMap::new(),
            backward: HashMap::new(),
            next_id: 0,
        }
    }

    /// Stores `node` under a freshly allocated id and returns that id.
    ///
    /// Ids come from an internal counter that is incremented on every call.
    pub fn add_node(&mut self, node: ProvenanceNode) -> NodeId {
        let id = self.next_id;
        self.next_id += 1;
        self.nodes.insert(id, node);
        id
    }

    /// Appends `edge` and indexes it under both of its endpoints.
    ///
    /// The endpoints are not checked against the node table. An edge may reference an
    /// id that `add_node` never returned, and such an edge is still reachable through
    /// the adjacency queries.
    pub fn add_edge(&mut self, edge: ProvenanceEdge) {
        // The edge will occupy this position once it is pushed below.
        let edge_idx = self.edges.len();
        self.forward.entry(edge.from).or_default().push(edge_idx);
        self.backward.entry(edge.to).or_default().push(edge_idx);
        self.edges.push(edge);
    }

    /// Returns the node stored under `id`, or `None` if no such node exists.
    pub fn get_node(&self, id: NodeId) -> Option<&ProvenanceNode> {
        self.nodes.get(&id)
    }

    /// Borrows the full node table, keyed by id.
    pub fn nodes(&self) -> &HashMap<NodeId, ProvenanceNode> {
        &self.nodes
    }

    /// Borrows all edges, in the order they were added.
    pub fn edges(&self) -> &[ProvenanceEdge] {
        &self.edges
    }

    /// Iterates the edges recorded for `id` in `index` (either `forward` or `backward`).
    ///
    /// Yields nothing when `id` has no entry in `index`.
    fn indexed_edges<'a>(
        &'a self,
        index: &'a HashMap<NodeId, Vec<usize>>,
        id: NodeId,
    ) -> impl Iterator<Item = &'a ProvenanceEdge> + 'a {
        index
            .get(&id)
            .map(Vec::as_slice)
            .unwrap_or_default()
            .iter()
            // Every stored index came from `add_edge`, which records it as `edges.len()`
            // just before pushing, so it is always in bounds.
            .map(move |&i| &self.edges[i])
    }

    /// Edges whose `to` endpoint is `id`, in insertion order. Empty if there are none.
    pub fn incoming_edges(&self, id: NodeId) -> Vec<&ProvenanceEdge> {
        self.indexed_edges(&self.backward, id).collect()
    }

    /// Edges whose `from` endpoint is `id`, in insertion order. Empty if there are none.
    pub fn outgoing_edges(&self, id: NodeId) -> Vec<&ProvenanceEdge> {
        self.indexed_edges(&self.forward, id).collect()
    }

    /// `from` endpoints of every incoming edge of `id`, one entry per edge.
    ///
    /// Duplicates are kept if several edges share a source.
    pub fn predecessors(&self, id: NodeId) -> Vec<NodeId> {
        // Map straight from the index; no intermediate Vec<&ProvenanceEdge>.
        self.indexed_edges(&self.backward, id).map(|e| e.from).collect()
    }

    /// `to` endpoints of every outgoing edge of `id`, one entry per edge.
    ///
    /// Duplicates are kept if several edges share a target.
    pub fn successors(&self, id: NodeId) -> Vec<NodeId> {
        self.indexed_edges(&self.forward, id).map(|e| e.to).collect()
    }

    /// Number of nodes currently stored.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges added so far.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Records a model inference as a `ProvenanceNode::Inference` node and returns its id.
    ///
    /// The node copies `model_id` and `model_version` into owned strings. It takes its
    /// confidence from `trace.confidence()` and its output from `trace.output`.
    pub fn add_inference<P: DecisionPath>(
        &mut self,
        trace: &DecisionTrace<P>,
        model_id: &str,
        model_version: &str,
    ) -> NodeId {
        self.add_node(ProvenanceNode::Inference {
            model_id: model_id.to_string(),
            model_version: model_version.to_string(),
            confidence: trace.confidence(),
            output: trace.output,
        })
    }
}
impl Default for ProvenanceGraph {
fn default() -> Self {
Self::new()
}
}