use super::*;
use crate::internal::*;
use crate::ops::Op;
use tract_itertools::Itertools;
use std::fmt;
use std::fmt::{Debug, Display};
use std::hash::Hash;
/// A graph node: an operator (`op`) together with the outlets it reads
/// (`inputs`) and the outlets it exposes (`outputs`).
///
/// `Hash` is derived through `educe` (see `#[educe(Hash)]`) rather than std's
/// `#[derive(Hash)]`; fields are hashed in declaration order.
#[derive(Debug, Clone, Educe)]
#[educe(Hash)]
pub struct Node<F: Fact + Hash, O: Hash> {
/// Node identifier, rendered as `#<id>` by the `Display` impl.
/// NOTE(review): presumably the node's index in its owning model — confirm.
pub id: usize,
/// Human-readable node name, rendered in quotes by the `Display` impl.
pub name: String,
/// Upstream outlets consumed by this node (presumably ordered by input slot).
pub inputs: Vec<OutletId>,
/// The node's operator. Marked `serde(skip)` when the `serialize` feature is on.
/// NOTE(review): no serde derive is visible on this struct; confirm the
/// attribute is still meaningful.
#[cfg_attr(feature = "serialize", serde(skip))]
pub op: O,
/// The node's outputs, each an `Outlet` carrying a fact and its successors
/// (presumably indexed by output slot).
pub outputs: TVec<Outlet<F>>,
}
impl<F: Fact + Hash, O: Hash + std::fmt::Display> fmt::Display for Node<F, O> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "#{} \"{}\" {}", self.id, self.name, self.op)
}
}
/// Operator accessors for nodes whose operator can be viewed as a `dyn Op`.
impl<F, NodeOp> Node<F, NodeOp>
where
    F: Fact + Hash,
    NodeOp: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Hash,
{
    /// Returns this node's operator as a `&dyn Op` trait object.
    pub fn op(&self) -> &dyn Op {
        self.op.as_ref()
    }

    /// Downcasts the operator to the concrete type `O`.
    ///
    /// Returns `None` if the operator is not an `O`.
    pub fn op_as<O: Op>(&self) -> Option<&O> {
        self.op().downcast_ref::<O>()
    }

    /// Mutable counterpart of [`Node::op_as`].
    ///
    /// Returns `None` if the operator is not an `O`.
    pub fn op_as_mut<O: Op>(&mut self) -> Option<&mut O> {
        self.op.as_mut().downcast_mut::<O>()
    }

    /// Returns `true` if the operator is of concrete type `O`.
    pub fn op_is<O: Op>(&self) -> bool {
        self.op_as::<O>().is_some()
    }

    /// Returns `true` when `other` reads exactly the same input outlets and
    /// its operator is considered equivalent by `Op::same_as`.
    ///
    /// The node's `id`, `name`, and `outputs` are not compared.
    pub fn same_as(&self, other: &Node<F, NodeOp>) -> bool {
        self.inputs == other.inputs && self.op().same_as(other.op())
    }
}
/// One output of a node: the fact attached to it and the inlets that consume it.
///
/// `Hash` is derived through `educe` (see `#[educe(Hash)]`).
#[derive(Clone, Default, Educe)]
#[educe(Hash)]
pub struct Outlet<F: Fact + Hash> {
/// The `F: Fact` attached to this output.
pub fact: F,
/// Downstream inlets (`node`, `slot`) wired to this outlet.
pub successors: TVec<InletId>,
}
impl<F: Fact + Hash> fmt::Debug for Outlet<F> {
    /// Prints the fact's `Debug` form, then one space, then each successor
    /// inlet in `Debug` form separated by single spaces. The space after the
    /// fact is emitted even when there are no successors.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `Itertools::format` renders lazily, so no intermediate `String` is built.
        let successors = self.successors.iter().format(" ");
        write!(f, "{:?} {:?}", self.fact, successors)
    }
}
/// Address of a node output: output slot `slot` of node `node`.
///
/// The derived `Ord` compares `node` first, then `slot`, following field
/// declaration order.
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, new)]
pub struct OutletId {
/// The producing node.
pub node: usize,
/// The output slot on that node.
pub slot: usize,
}
impl fmt::Debug for OutletId {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "{}/{}>", self.node, self.slot)
}
}
impl From<usize> for OutletId {
fn from(node: usize) -> OutletId {
OutletId::new(node, 0)
}
}
impl From<(usize, usize)> for OutletId {
fn from(pair: (usize, usize)) -> OutletId {
OutletId::new(pair.0, pair.1)
}
}
/// Address of a node input: input slot `slot` of node `node`.
///
/// The derived `Ord` compares `node` first, then `slot`, following field
/// declaration order.
#[derive(Clone, Copy, PartialEq, Eq, Hash, new, Ord, PartialOrd)]
pub struct InletId {
/// The consuming node.
pub node: usize,
/// The input slot on that node.
pub slot: usize,
}
impl fmt::Debug for InletId {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, ">{}/{}", self.node, self.slot)
}
}