use std::borrow::Cow;
use std::fmt;
use std::sync::Arc;
use rten_tensor::prelude::*;
use rten_tensor::{ArcTensor, DynLayout, TensorView};
use super::NodeId;
use crate::constant_storage::ArcTensorView;
use crate::operator::Operator;
use crate::value::{DataType, ValueType, ValueView};
/// A node, which is exactly one of an operator, a constant or a value.
#[derive(Debug)]
pub enum Node {
/// An operator node; see [`OperatorNode`].
Operator(OperatorNode),
/// A constant node; see [`Constant`].
Constant(Constant),
/// A value node; see [`ValueNode`].
Value(ValueNode),
}
impl Node {
    /// Returns the name of the wrapped node, if it has one.
    pub fn name(&self) -> Option<&str> {
        match self {
            Self::Operator(op) => op.name(),
            Self::Constant(konst) => konst.name(),
            Self::Value(value) => value.name(),
        }
    }

    /// Returns the shape associated with this node.
    ///
    /// * Operators have no shape, so `None` is returned.
    /// * Constants report each size of their layout as a
    ///   [`Dimension::Fixed`], in a newly allocated vector.
    /// * Values return a borrow of their stored shape, if one is set.
    pub fn shape(&self) -> Option<Cow<'_, [Dimension]>> {
        match self {
            Self::Value(value) => value.shape(),
            Self::Constant(konst) => {
                let dims: Vec<Dimension> = konst
                    .shape()
                    .iter()
                    .map(|&size| Dimension::Fixed(size))
                    .collect();
                Some(Cow::Owned(dims))
            }
            Self::Operator(_) => None,
        }
    }

    /// Returns the type associated with this node.
    ///
    /// Operators yield `None`, constants yield `ValueType::Tensor` of their
    /// element type, and values yield their stored type, if one is set.
    pub fn dtype(&self) -> Option<ValueType> {
        match self {
            Self::Operator(_) => None,
            Self::Constant(konst) => Some(ValueType::Tensor(konst.dtype())),
            Self::Value(value) => value.dtype(),
        }
    }

    /// Returns the inner [`OperatorNode`] if this node is an operator.
    pub fn as_operator(&self) -> Option<&OperatorNode> {
        if let Self::Operator(op) = self {
            Some(op)
        } else {
            None
        }
    }

    /// Returns the inner [`Constant`] if this node is a constant.
    pub fn as_constant(&self) -> Option<&Constant> {
        if let Self::Constant(konst) = self {
            Some(konst)
        } else {
            None
        }
    }
}
/// One dimension of a shape.
///
/// Equality is total for both payload types (`usize` and `String`), so the
/// type derives `Eq` and `Hash` in addition to `PartialEq`. This lets
/// dimensions be stored in hashed collections such as `HashSet` or used as
/// `HashMap` keys.
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum Dimension {
    /// A dimension with a fixed size.
    Fixed(usize),
    /// A dimension identified by a symbolic name.
    Symbolic(String),
}
impl From<usize> for Dimension {
fn from(val: usize) -> Dimension {
Dimension::Fixed(val)
}
}
impl From<String> for Dimension {
fn from(name: String) -> Dimension {
Dimension::Symbolic(name)
}
}
impl<'a> From<&'a str> for Dimension {
fn from(name: &'a str) -> Dimension {
Dimension::Symbolic(name.into())
}
}
impl fmt::Debug for Dimension {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Fixed(size) => write!(f, "{}", size),
Self::Symbolic(name) => write!(f, "\"{}\"", name),
}
}
}
/// An operator together with its name, the IDs of its input and output nodes
/// and the capture names of any subgraphs it contains.
#[derive(Debug)]
pub struct OperatorNode {
/// Optional name of this node.
name: Option<String>,
/// Input node IDs, one slot per input. A slot may be `None`
/// (presumably an omitted optional input — confirm against callers).
inputs: Box<[Option<NodeId>]>,
/// Output node IDs, one slot per output. A slot may be `None`
/// (presumably an unused output — confirm against callers).
outputs: Box<[Option<NodeId>]>,
/// The operator itself, held in an `Arc` so that `clone_operator` can hand
/// out another handle to it.
operator: Arc<dyn Operator + Send + Sync>,
/// Capture names collected from the operator's subgraphs in `new`. Empty
/// when the operator's `as_subgraph_op()` returns `None`.
capture_names: Box<[String]>,
}
impl OperatorNode {
    /// Creates an operator node.
    ///
    /// `input_ids` and `output_ids` are copied into the node. If `operator`
    /// reports itself as a subgraph operator (via `as_subgraph_op`), the
    /// capture names of all of its subgraphs are collected, in subgraph
    /// order, and stored alongside it.
    pub fn new(
        name: Option<&str>,
        input_ids: &[Option<NodeId>],
        output_ids: &[Option<NodeId>],
        operator: Arc<dyn Operator + Send + Sync>,
    ) -> Self {
        let mut capture_names: Vec<String> = Vec::new();
        if let Some(subgraph_op) = operator.as_subgraph_op() {
            for subgraph in subgraph_op.subgraphs() {
                for capture in subgraph.capture_names().iter() {
                    capture_names.push(capture.to_string());
                }
            }
        }

        Self {
            name: name.map(str::to_owned),
            inputs: Box::from(input_ids),
            outputs: Box::from(output_ids),
            operator,
            capture_names: capture_names.into_boxed_slice(),
        }
    }

    /// Returns the node's name, if it has one.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Returns the IDs of the operator's inputs.
    pub fn input_ids(&self) -> &[Option<NodeId>] {
        &self.inputs[..]
    }

    /// Returns the IDs of the operator's outputs.
    pub fn output_ids(&self) -> &[Option<NodeId>] {
        &self.outputs[..]
    }

    /// Iterates over the capture names gathered from the operator's
    /// subgraphs when the node was created.
    pub fn capture_names(&self) -> impl Iterator<Item = &str> {
        self.capture_names.iter().map(String::as_str)
    }

    /// Returns a reference to the operator.
    pub fn operator(&self) -> &dyn Operator {
        &*self.operator
    }

    /// Returns a new `Arc` handle to the operator.
    pub fn clone_operator(&self) -> Arc<dyn Operator + Send + Sync> {
        Arc::clone(&self.operator)
    }

    /// Replaces every input slot that refers to `old_id` with `new_id`.
    ///
    /// Empty slots and slots referring to other nodes are left untouched.
    pub(super) fn replace_input(&mut self, old_id: NodeId, new_id: NodeId) {
        let target = Some(old_id);
        self.inputs
            .iter_mut()
            .filter(|slot| **slot == target)
            .for_each(|slot| *slot = Some(new_id));
    }
}
/// A value described by an optional name, shape and type.
///
/// The shape and type may be absent when the node is created and can be set
/// afterwards via `update_shape` and `update_type`.
#[derive(Debug)]
pub struct ValueNode {
/// Optional name of this node.
name: Option<String>,
/// Dimensions of the value, if known.
shape: Option<Vec<Dimension>>,
/// Type of the value, if known.
dtype: Option<ValueType>,
}
impl ValueNode {
pub fn new(
name: Option<&str>,
shape: Option<Vec<Dimension>>,
dtype: Option<ValueType>,
) -> Self {
ValueNode {
name: name.map(|s| s.to_owned()),
shape,
dtype,
}
}
pub fn ndim(&self) -> Option<usize> {
self.shape.as_ref().map(|s| s.len())
}
pub fn shape(&self) -> Option<Cow<'_, [Dimension]>> {
self.shape.as_deref().map(Cow::Borrowed)
}
pub fn dtype(&self) -> Option<ValueType> {
self.dtype
}
pub fn update_shape(&mut self, shape: Vec<Dimension>) {
self.shape = Some(shape);
}
pub fn update_type(&mut self, dtype: ValueType) {
self.dtype = Some(dtype);
}
fn name(&self) -> Option<&str> {
self.name.as_deref()
}
}
/// A constant, wrapping a [`ConstantNode`] of one of the supported element
/// types.
#[derive(Debug)]
pub enum Constant {
/// Constant with `f32` elements.
Float(ConstantNode<f32>),
/// Constant with `i32` elements.
Int32(ConstantNode<i32>),
/// Constant with `i8` elements.
Int8(ConstantNode<i8>),
/// Constant with `u8` elements.
UInt8(ConstantNode<u8>),
}
impl Constant {
pub fn new<T>(name: Option<&str>, tensor: impl Into<ConstantNodeData<T>>) -> Self
where
Self: From<ConstantNode<T>>,
{
ConstantNode::new(name, tensor.into()).into()
}
pub fn name(&self) -> Option<&str> {
match self {
Constant::Float(f) => f.name.as_deref(),
Constant::Int32(i) => i.name.as_deref(),
Constant::Int8(i) => i.name.as_deref(),
Constant::UInt8(i) => i.name.as_deref(),
}
}
pub fn shape(&self) -> &[usize] {
self.layout().shape()
}
pub fn ndim(&self) -> usize {
self.layout().ndim()
}
pub fn clone_ref(&self) -> Option<Constant> {
match self {
Constant::Float(f) => f.clone_ref().map(Constant::Float),
Constant::Int32(i) => i.clone_ref().map(Constant::Int32),
Constant::Int8(i) => i.clone_ref().map(Constant::Int8),
Constant::UInt8(i) => i.clone_ref().map(Constant::UInt8),
}
}
pub fn layout(&self) -> &DynLayout {
match self {
Constant::Float(f) => f.layout(),
Constant::Int32(i) => i.layout(),
Constant::Int8(i) => i.layout(),
Constant::UInt8(i) => i.layout(),
}
}
pub fn as_view(&self) -> ValueView<'_> {
match self {
Constant::Float(f) => ValueView::FloatTensor(f.view()),
Constant::Int32(i) => ValueView::Int32Tensor(i.view()),
Constant::Int8(i) => ValueView::Int8Tensor(i.view()),
Constant::UInt8(i) => ValueView::UInt8Tensor(i.view()),
}
}
fn dtype(&self) -> DataType {
match self {
Constant::Float(_) => DataType::Float,
Constant::Int32(_) => DataType::Int32,
Constant::Int8(_) => DataType::Int8,
Constant::UInt8(_) => DataType::UInt8,
}
}
}
/// A constant with an optional name whose data has elements of type `T`.
#[derive(Debug)]
pub struct ConstantNode<T> {
/// Optional name of this constant.
name: Option<String>,
/// The constant's data; see [`ConstantNodeData`] for the storage variants.
data: ConstantNodeData<T>,
}
impl<T> ConstantNode<T> {
    /// Creates a constant node from an optional name and its data.
    pub fn new(name: Option<&str>, data: ConstantNodeData<T>) -> Self {
        Self {
            name: name.map(str::to_owned),
            data,
        }
    }

    /// Returns a view of the constant's data.
    pub fn view(&self) -> TensorView<'_, T> {
        match self.data {
            ConstantNodeData::Arc(ref tensor) => tensor.view(),
            ConstantNodeData::ArcSlice(ref slice) => slice.view(),
        }
    }

    /// Creates another node with the same name whose data comes from
    /// `ConstantNodeData::clone_ref`.
    ///
    /// Returns `None` if the data's `clone_ref` returns `None`; the name is
    /// only cloned when it succeeds.
    fn clone_ref(&self) -> Option<ConstantNode<T>> {
        self.data.clone_ref().map(|data| ConstantNode {
            name: self.name.clone(),
            data,
        })
    }

    /// Returns the layout of the constant's data.
    fn layout(&self) -> &DynLayout {
        match self.data {
            ConstantNodeData::Arc(ref tensor) => tensor.layout(),
            ConstantNodeData::ArcSlice(ref slice) => slice.layout(),
        }
    }
}
macro_rules! impl_constant_node {
($scalar_type:ty, $variant:ident) => {
impl From<ConstantNode<$scalar_type>> for Constant {
fn from(node: ConstantNode<$scalar_type>) -> Constant {
Constant::$variant(node)
}
}
};
}
impl_constant_node!(f32, Float);
impl_constant_node!(i32, Int32);
impl_constant_node!(i8, Int8);
impl_constant_node!(u8, UInt8);
/// Storage for the data of a [`ConstantNode`].
///
/// NOTE(review): both wrapped types are presumably reference-counted, given
/// their names and that `clone_ref` clones them without a `T: Clone` bound —
/// confirm against their definitions.
#[derive(Debug)]
pub enum ConstantNodeData<T> {
/// Data held as an `ArcTensorView` (from `crate::constant_storage`).
ArcSlice(ArcTensorView<T>),
/// Data held as an `ArcTensor` (from `rten_tensor`).
Arc(ArcTensor<T>),
}
impl<T> ConstantNodeData<T> {
fn clone_ref(&self) -> Option<ConstantNodeData<T>> {
match self {
ConstantNodeData::ArcSlice(view) => Some(ConstantNodeData::ArcSlice(view.clone())),
ConstantNodeData::Arc(view) => Some(ConstantNodeData::Arc(view.clone())),
}
}
}
impl<T> From<ArcTensorView<T>> for ConstantNodeData<T> {
fn from(val: ArcTensorView<T>) -> ConstantNodeData<T> {
ConstantNodeData::ArcSlice(val)
}
}
impl<T> From<ArcTensor<T>> for ConstantNodeData<T> {
fn from(val: ArcTensor<T>) -> ConstantNodeData<T> {
ConstantNodeData::Arc(val)
}
}
/// Typed access to a constant's data for element type `T`.
///
/// In this file the trait is implemented for [`Constant`] once per supported
/// element type; each method returns `None` when the constant holds a
/// different element type.
pub trait TypedConstant<T> {
/// Returns a view of the data if it has element type `T`.
fn as_typed_view(&self) -> Option<TensorView<'_, T>>;
/// Returns the constant's value as a scalar, if available.
///
/// NOTE(review): the implementations in this file delegate to the view's
/// `item()` method, which presumably yields a value only for
/// single-element tensors — confirm against `rten_tensor`.
fn as_scalar(&self) -> Option<T>;
/// Returns the constant's data as a slice, if available.
///
/// The implementations in this file return `Some` only when the view has
/// exactly one dimension and the view's `data()` method returns `Some`.
fn as_vector(&self) -> Option<&[T]>;
}
/// Implements `TypedConstant<$elem>` for `Constant`, where `$variant` is the
/// `Constant` variant that stores elements of type `$elem`.
macro_rules! impl_typed_constant {
    ($elem:ty, $variant:ident) => {
        impl TypedConstant<$elem> for Constant {
            /// Returns a view of the data if this is the matching variant.
            fn as_typed_view(&self) -> Option<TensorView<'_, $elem>> {
                if let Constant::$variant(node) = self {
                    Some(node.view())
                } else {
                    None
                }
            }

            /// Returns a copy of the element yielded by the view's `item()`.
            fn as_scalar(&self) -> Option<$elem> {
                let view = TypedConstant::<$elem>::as_typed_view(self)?;
                view.item().copied()
            }

            /// Returns the view's data when the view has exactly one
            /// dimension and `data()` yields a slice.
            fn as_vector(&self) -> Option<&[$elem]> {
                let view = TypedConstant::<$elem>::as_typed_view(self)?;
                if view.ndim() != 1 {
                    return None;
                }
                view.data()
            }
        }
    };
}

impl_typed_constant!(f32, Float);
impl_typed_constant!(i32, Int32);
impl_typed_constant!(i8, Int8);
impl_typed_constant!(u8, UInt8);