use super::*;
use crate::ops::Op;
use std::fmt;
/// A graph of nodes, each holding an operator of type `O` and one fact of type `F` per
/// output slot, together with the outlets designated as model inputs and outputs.
#[derive(Clone, Debug)]
pub struct ModelImpl<F, O>
where
F: Fact + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
/// Optional label for the whole model.
pub label: Option<String>,
/// All nodes of the graph. Nodes created through `add_node` get their index here as `id`.
pub nodes: Vec<BaseNode<F, O>>,
/// Index from node name to position in `nodes`, kept up to date by `add_node` and
/// `rename_node`. Not updated if `nodes` is mutated directly.
nodes_by_name: HashMap<String, usize>,
/// Outlets designated as model inputs, in order.
pub inputs: Vec<OutletId>,
/// Outlets designated as model outputs, in order.
pub outputs: Vec<OutletId>,
/// Optional labels attached to individual outlets.
pub outlet_labels: HashMap<OutletId, String>,
}
impl<F, O> Default for ModelImpl<F, O>
where
    F: Fact + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// Builds an empty, unlabelled model: no nodes, no inputs or outputs, no outlet labels.
    fn default() -> Self {
        Self {
            label: None,
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            nodes_by_name: HashMap::new(),
            outlet_labels: HashMap::new(),
        }
    }
}
impl<F, O> ModelImpl<F, O>
where
F: Fact + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
ModelImpl<F, O>: Model,
{
pub fn add_node(
&mut self,
name: impl Into<String>,
op: impl Into<O>,
output_facts: TVec<F>,
) -> TractResult<usize> {
let op = op.into();
let name = name.into();
let id = self.nodes.len();
self.nodes_by_name.insert(name.clone(), id);
let outputs =
output_facts.into_iter().map(|fact| OutletFact { fact, successors: tvec!() }).collect();
let node = BaseNode { id, name, op, inputs: vec![], outputs };
self.nodes.push(node);
Ok(id)
}
pub fn add_edge(&mut self, outlet: OutletId, inlet: InletId) -> TractResult<()> {
if let Some(previous) = self.nodes[inlet.node].inputs.get(inlet.slot).cloned() {
self.nodes[previous.node].outputs[previous.slot]
.successors
.retain(|&mut succ| succ != inlet);
}
{
let prec = &mut self.nodes[outlet.node];
prec.outputs[outlet.slot].successors.push(inlet);
}
let succ = &mut self.nodes[inlet.node];
if inlet.slot == succ.inputs.len() {
succ.inputs.push(outlet);
} else if inlet.slot < succ.inputs.len() {
succ.inputs[inlet.slot] = outlet;
} else {
bail!("Edges must be added in order and consecutive. Trying to connect input {:?} of node {:?} ", inlet.slot, succ)
}
Ok(())
}
pub fn input_outlets(&self) -> TractResult<&[OutletId]> {
Ok(&self.inputs)
}
pub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()> {
self.inputs = inputs.to_vec();
Ok(())
}
pub fn set_input_names(
&mut self,
inputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<()> {
let mut ids = vec![];
for i in inputs.into_iter() {
let node = self.node_by_name(i.as_ref())?;
for o in 0..node.outputs.len() {
ids.push(OutletId::new(node.id, o))
}
}
self.inputs = ids;
Ok(())
}
pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
let input = self.input_outlets()?[ix];
self.outlet_fact(input)
}
pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
let input = self.input_outlets()?[ix];
self.outlet_fact_mut(input)
}
pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
let outlet = self.inputs[input];
self.set_outlet_fact(outlet, fact)
}
pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
Ok(&self.outputs)
}
pub fn auto_outputs(&mut self) -> TractResult<()> {
let outputs = self
.nodes
.iter()
.flat_map(|n| {
let id = n.id;
n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
(OutletId::new(id, ix), output_fact.successors.len())
})
})
.filter(|(_f, succs)| *succs == 0)
.map(|(f, _)| f)
.collect();
self.outputs = outputs;
Ok(())
}
pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
self.outputs = outputs.to_vec();
Ok(())
}
pub fn set_output_names(
&mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<()> {
let ids: Vec<OutletId> = outputs
.into_iter()
.map(|s| self.node_by_name(s.as_ref()).map(|n| OutletId::new(n.id, 0)))
.collect::<TractResult<_>>()?;
self.outputs = ids;
Ok(())
}
pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
let output = self.output_outlets()?[ix];
self.outlet_fact(output)
}
pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
let output = self.output_outlets()?[ix];
self.outlet_fact_mut(output)
}
pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
let outlet = self.outputs[output];
self.set_outlet_fact(outlet, fact)
}
pub fn node_names(&self) -> impl Iterator<Item = &str> {
self.nodes.iter().map(|s| &*s.name)
}
pub fn node_by_name<S: AsRef<str>>(&self, name: S) -> TractResult<&BaseNode<F, O>> {
let id: usize = self.node_id_by_name(name.as_ref())?;
Ok(&self.nodes[id])
}
pub fn node_by_name_mut(&mut self, name: &str) -> TractResult<&mut BaseNode<F, O>> {
let id: &usize =
self.nodes_by_name.get(name).ok_or_else(|| format!("Node named {} not found", name))?;
Ok(&mut self.nodes[*id])
}
pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
self.node_mut(id).name = name.to_string();
self.nodes_by_name.insert(name.to_string(), id);
Ok(())
}
pub fn node(&self, id: usize) -> &BaseNode<F, O> {
&self.nodes[id]
}
pub fn node_mut(&mut self, id: usize) -> &mut BaseNode<F, O> {
&mut self.nodes[id]
}
pub fn nodes(&self) -> &[BaseNode<F, O>] {
&*self.nodes
}
pub fn nodes_mut(&mut self) -> &mut [BaseNode<F, O>] {
&mut *self.nodes
}
pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
}
pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
}
pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
}
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
let outlets = &self.nodes[outlet.node].outputs;
outlets
.get(outlet.slot)
.map(|o| &o.fact)
.ok_or_else(|| format!("Invalid outlet reference: {:?}", outlet).into())
}
pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
let outlets = &mut self.nodes[outlet.node].outputs;
outlets
.get_mut(outlet.slot)
.map(|o| &mut o.fact)
.ok_or_else(|| format!("Invalid outlet reference: {:?}", outlet).into())
}
pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
use itertools::Itertools;
assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
Ok(unsafe {
outlets
.iter()
.map(|o| &mut *(&self.nodes[o.node].outputs[o.slot].fact as *const F as *mut F))
.collect()
})
}
pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
let outlets = &mut self.nodes[outlet.node].outputs;
if outlets.len() <= outlet.slot {
bail!("Invalid outlet refererence: {:?}", outlet)
}
outlets[outlet.slot].fact = fact;
Ok(())
}
pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
self.outlet_labels.get(&outlet).map(|s| &**s)
}
pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) {
self.outlet_labels.insert(outlet, label);
}
pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
self.outlet_labels.iter().find(|(_k, v)| &**v == label).map(|(k, _v)| *k)
}
pub fn eval_order(&self) -> TractResult<Vec<usize>> {
eval_order(&self)
}
pub fn check_edges(&self) -> TractResult<()> {
for node in self.eval_order()? {
let node = &self.nodes[node];
for (ix, input) in node.inputs.iter().enumerate() {
let prec = &self.nodes[input.node];
if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
bail!(
"Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
node.id,
ix,
prec
)
}
}
for (ix, output) in node.outputs.iter().enumerate() {
for succ in &output.successors {
if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
bail!(
"Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
node.id,
ix,
succ
)
}
}
}
}
Ok(())
}
}
impl<F, O> Model for ModelImpl<F, O>
where
    F: Fact + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The model label, if one was set.
    fn model_label(&self) -> Option<&str> {
        self.label.as_ref().map(String::as_str)
    }

    /// Resolves `name` through the name index; unknown names are an error.
    fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        match self.nodes_by_name.get(name) {
            Some(&id) => Ok(id),
            None => Err(format!("No node found for name: \"{}\"", name).into()),
        }
    }

    /// Name of node `id` (panics if out of range).
    fn node_name(&self, id: usize) -> &str {
        self.nodes[id].name.as_str()
    }

    /// Outlets wired into node `id`'s inputs (panics if out of range).
    fn node_inputs(&self, id: usize) -> &[OutletId] {
        self.nodes[id].inputs.as_slice()
    }

    /// Number of output slots of node `id` (panics if out of range).
    fn node_output_count(&self, id: usize) -> usize {
        let node = &self.nodes[id];
        node.outputs.len()
    }

    /// Total number of nodes.
    fn nodes_len(&self) -> usize {
        self.nodes.len()
    }

    /// `Debug` rendering of node `id` (panics if out of range).
    fn node_format(&self, id: usize) -> String {
        let node = &self.nodes[id];
        format!("{:?}", node)
    }

    /// Node ids in evaluation order.
    fn eval_order(&self) -> TractResult<Vec<usize>> {
        crate::model::eval_order(&self)
    }

    /// Evaluation order restricted to what is needed between the given input and output
    /// nodes.
    fn eval_order_for_io(&self, inputs: &[usize], outputs: &[usize]) -> TractResult<Vec<usize>> {
        crate::model::order::eval_order_for_nodes(&self.nodes, inputs, outputs)
    }

    /// Model input outlets.
    fn input_outlets(&self) -> &[OutletId] {
        self.inputs.as_slice()
    }

    /// Model output outlets.
    fn output_outlets(&self) -> &[OutletId] {
        self.outputs.as_slice()
    }

    /// Operator of node `id`, viewed as `dyn Op` (panics if out of range).
    fn node_op(&self, id: usize) -> &dyn Op {
        let node = &self.nodes[id];
        node.op.as_ref()
    }

    /// Fact of `outlet`, converted with `Fact::to_typed_fact`.
    fn outlet_typedfact(&self, outlet: OutletId) -> TractResult<TypedFact> {
        self.outlet_fact(outlet).and_then(|fact| fact.to_typed_fact())
    }

    /// `Debug` rendering of `outlet`'s fact; panics if the outlet is invalid.
    fn outlet_fact_format(&self, outlet: OutletId) -> String {
        let fact = self.outlet_fact(outlet).unwrap();
        format!("{:?}", fact)
    }

    /// Label attached to `id`, if any (forwards to the inherent method).
    fn outlet_label(&self, id: OutletId) -> Option<&str> {
        self.outlet_label(id)
    }

    /// Inlets fed by `outlet` (panics if out of range).
    fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        let output = &self.nodes[outlet.node].outputs[outlet.slot];
        &output.successors
    }
}