use crate::domain::ContractSet;
use crate::object::ObjectMeta;
use crate::op::{LayerBehavior, Operator};
use crate::{Error, Result};
/// One operator application recorded in a [`SemanticGraph`].
///
/// Nodes are built by `SemanticGraph::add_op`, which fills every field from the
/// operator being added and from the graph's value table.
#[derive(Debug, Clone, PartialEq)]
pub struct SemanticNode {
/// Position of this node in the owning graph's node list (assigned as the node
/// count at the time the node was added).
pub id: usize,
/// Name reported by the operator's `name()` method.
pub op_name: String,
/// Value ids consumed by this node, in the order they were passed to `add_op`.
/// The same id may appear more than once.
pub inputs: Vec<usize>,
/// Value ids produced by this node; `output_ids[i]` identifies `outputs[i]`.
pub output_ids: Vec<usize>,
/// Output metadata returned by the operator's `infer`; a copy of each entry is
/// also stored in the graph's value table under the matching output id.
pub outputs: Vec<ObjectMeta>,
/// Contracts reported by the operator's `required_contracts()`.
/// NOTE(review): meaning of a contract is defined in `crate::domain` — confirm there.
pub required_contracts: ContractSet,
/// Contracts reported by the operator's `provided_contracts()`.
pub provided_contracts: ContractSet,
/// Per-layer behavior reported by the operator's `layer_behavior()`.
/// NOTE(review): presumably one entry per layer/output — verify against `crate::op`.
pub layer_behavior: Vec<LayerBehavior>,
}
/// An append-only graph of operator applications over typed values.
///
/// Values (graph inputs and operator outputs) share a single id space: a value
/// id is an index into the value table. Nodes are stored in insertion order and
/// a node id is an index into the node list.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SemanticGraph {
// Operator applications, in the order they were added.
nodes: Vec<SemanticNode>,
// Metadata for every value; indexed by value id.
values: Vec<ObjectMeta>,
}
impl SemanticGraph {
    /// Creates an empty graph with no values and no nodes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a graph input described by `meta` and returns its value id.
    ///
    /// Value ids are dense indices assigned in insertion order; graph inputs and
    /// operator outputs draw from the same id space.
    pub fn add_input(&mut self, meta: ObjectMeta) -> usize {
        let id = self.values.len();
        self.values.push(meta);
        id
    }

    /// Appends a node applying `op` to the values named by `inputs` and returns
    /// the value ids of its outputs, in the order produced by `op.infer`.
    ///
    /// # Errors
    ///
    /// Returns an IR error if any id in `inputs` does not name an existing value,
    /// and propagates any error from `op.infer`. Both checks run before the graph
    /// is mutated, so on error the graph is left unchanged.
    pub fn add_op<O: Operator>(&mut self, op: O, inputs: &[usize]) -> Result<Vec<usize>> {
        let input_meta = inputs
            .iter()
            .map(|&id| {
                self.values
                    .get(id)
                    .cloned()
                    .ok_or_else(|| Error::ir(format!("unknown value id {id}")))
            })
            .collect::<Result<Vec<_>>>()?;
        let outputs = op.infer(&input_meta)?;

        // Outputs take a contiguous run of fresh ids at the end of the value table.
        let first_output_id = self.values.len();
        self.values.extend(outputs.iter().cloned());
        let output_ids: Vec<usize> = (first_output_id..self.values.len()).collect();

        let node = SemanticNode {
            id: self.nodes.len(),
            op_name: op.name().to_string(),
            inputs: inputs.to_vec(),
            output_ids: output_ids.clone(),
            outputs,
            required_contracts: op.required_contracts(),
            provided_contracts: op.provided_contracts(),
            layer_behavior: op.layer_behavior(),
        };
        self.nodes.push(node);
        Ok(output_ids)
    }

    /// Returns the metadata of value `id`, or `None` if no such value exists.
    pub fn value(&self, id: usize) -> Option<&ObjectMeta> {
        self.values.get(id)
    }

    /// Returns all nodes in insertion order; `nodes()[i].id == i`.
    pub fn nodes(&self) -> &[SemanticNode] {
        &self.nodes
    }

    /// Returns the ids of every node that takes `value_id` as an input, in node
    /// order. A node listing the same value several times appears only once.
    /// An unknown `value_id` yields an empty list.
    pub fn users_of_value(&self, value_id: usize) -> Vec<usize> {
        self.consumer_ids(value_id).collect()
    }

    /// Returns the id of the only node consuming `value_id`, or `None` when the
    /// value has zero or more than one consuming node.
    pub fn single_consumer_of_value(&self, value_id: usize) -> Option<usize> {
        // Stop after at most two matches instead of collecting every consumer.
        let mut consumers = self.consumer_ids(value_id);
        match (consumers.next(), consumers.next()) {
            (Some(only), None) => Some(only),
            _ => None,
        }
    }

    /// Returns the output value ids of node `node_id`, or `None` if no such node
    /// exists.
    pub fn output_ids_of_node(&self, node_id: usize) -> Option<&[usize]> {
        self.nodes
            .get(node_id)
            .map(|node| node.output_ids.as_slice())
    }

    /// Lazily yields the ids of nodes whose inputs include `value_id`, in node order.
    fn consumer_ids(&self, value_id: usize) -> impl Iterator<Item = usize> + '_ {
        self.nodes
            .iter()
            .filter(move |node| node.inputs.contains(&value_id))
            .map(|node| node.id)
    }
}