use std::any::Any;
use std::borrow::Cow;
use std::convert::Infallible;
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display};
use rten_gemm::PackedBMatrix;
use rten_tensor::errors::DimensionError;
use rten_tensor::{Layout, Storage, TensorBase};
use smallvec::SmallVec;
use crate::BufferPool;
use crate::graph::{CaptureEnv, Graph, RunError, RunOptions};
use crate::infer_shapes::InferShapes;
use crate::timing::Profiler;
use crate::value::{DataType, DataTypeOf, TryFromValueError, Value, ValueType, ValueView};
use crate::weight_cache::WeightCache;
/// An operator input that has been prepacked ahead of time (see
/// [`Operator::prepack`]).
///
/// Each variant wraps a packed `B` matrix from `rten_gemm` with a particular
/// element type.
pub enum PrepackedInput {
    /// Packed matrix with `f32` elements.
    FloatBMatrix(PackedBMatrix<f32>),
    /// Packed matrix with `i8` elements.
    Int8BMatrix(PackedBMatrix<i8>),
}

impl PrepackedInput {
    /// Data type of the packed matrix's elements.
    fn dtype(&self) -> DataType {
        match *self {
            PrepackedInput::FloatBMatrix(..) => DataType::Float,
            PrepackedInput::Int8BMatrix(..) => DataType::Int8,
        }
    }
}
/// Implement conversions between `PackedBMatrix<$type>` and the
/// `PrepackedInput::$variant` variant that wraps it.
///
/// Generates:
/// - `From<PackedBMatrix<$type>> for PrepackedInput`, which wraps the matrix
///   in `PrepackedInput::$variant`.
/// - `TryFrom<&PrepackedInput> for &PackedBMatrix<$type>`, which borrows the
///   wrapped matrix if the input is the `$variant` variant. Otherwise it
///   returns `TryFromValueError::WrongType` with the stored and requested
///   element types.
macro_rules! impl_prepacked_input_conversions {
    ($type:ty, $variant:ident) => {
        impl From<PackedBMatrix<$type>> for PrepackedInput {
            fn from(value: PackedBMatrix<$type>) -> Self {
                PrepackedInput::$variant(value)
            }
        }
        impl<'a> TryFrom<&'a PrepackedInput> for &'a PackedBMatrix<$type> {
            type Error = TryFromValueError;
            fn try_from(ppi: &'a PrepackedInput) -> Result<Self, Self::Error> {
                match ppi {
                    PrepackedInput::$variant(packed) => Ok(packed),
                    // Any other variant holds a different element type. Report
                    // both element types, wrapped as tensor value types.
                    _ => Err(TryFromValueError::WrongType {
                        actual: ValueType::Tensor(ppi.dtype()),
                        expected: ValueType::Tensor(<$type as DataTypeOf>::dtype_of()),
                    }),
                }
            }
        }
    };
}
// Conversions between `PackedBMatrix<f32>` and `PrepackedInput::FloatBMatrix`.
impl_prepacked_input_conversions!(f32, FloatBMatrix);
// Conversions between `PackedBMatrix<i8>` and `PrepackedInput::Int8BMatrix`.
impl_prepacked_input_conversions!(i8, Int8BMatrix);
pub trait IntoOpResult {
fn into_op_result(self) -> Result<OutputList, OpError>;
}
impl IntoOpResult for Result<Value, OpError> {
fn into_op_result(self) -> Result<OutputList, OpError> {
self.map(|out| [out].into())
}
}
impl IntoOpResult for Value {
fn into_op_result(self) -> Result<OutputList, OpError> {
Ok([self].into())
}
}
impl<S: Storage, L: Layout> IntoOpResult for TensorBase<S, L>
where
Value: From<TensorBase<S, L>>,
{
fn into_op_result(self) -> Result<OutputList, OpError> {
let output: Value = self.into();
Ok([output].into())
}
}
impl<S: Storage, L: Layout> IntoOpResult for Result<TensorBase<S, L>, OpError>
where
Value: From<TensorBase<S, L>>,
{
fn into_op_result(self) -> Result<OutputList, OpError> {
self.map(|tensor| [tensor.into()].into())
}
}
impl<T> IntoOpResult for Result<Vec<T>, OpError>
where
Value: From<T>,
{
fn into_op_result(self) -> Result<OutputList, OpError> {
self.map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
}
}
/// Error type returned by [`Operator::run`] and by the typed input accessors
/// of [`InputList`].
#[derive(Eq, PartialEq, Debug)]
pub enum OpError {
    /// Converting a value to the required type failed.
    CastFailed(TryFromValueError),
    /// Converting the input at `index` to the required type failed.
    InputCastFailed {
        index: usize,
        error: TryFromValueError,
    },
    /// An input has a type the operator does not support.
    UnsupportedType,
    /// The input shapes are incompatible. The string describes why.
    IncompatibleInputShapes(&'static str),
    /// One or more required inputs were not provided.
    MissingInputs,
    /// An input or attribute has an invalid value. The string describes why.
    InvalidValue(&'static str),
    /// An input or attribute value is not supported. The string describes why.
    UnsupportedValue(&'static str),
}
impl OpError {
pub fn with_input_index(self, index: usize) -> OpError {
match self {
Self::CastFailed(error) => OpError::InputCastFailed { index, error },
Self::InputCastFailed { error, .. } => OpError::InputCastFailed { index, error },
other => other,
}
}
}
/// A dimension mismatch is reported as a failed value conversion.
impl From<DimensionError> for OpError {
    fn from(err: DimensionError) -> OpError {
        let cast_err: TryFromValueError = err.into();
        OpError::CastFailed(cast_err)
    }
}

/// A failed value conversion without an associated input index.
impl From<TryFromValueError> for OpError {
    fn from(err: TryFromValueError) -> OpError {
        Self::CastFailed(err)
    }
}

/// Allows `?` on infallible results. This conversion can never run.
impl From<Infallible> for OpError {
    fn from(never: Infallible) -> OpError {
        match never {}
    }
}
/// Human-readable error messages. Variants carrying details append them to a
/// fixed prefix.
impl Display for OpError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OpError::CastFailed(err) => write!(f, "{}", err),
            OpError::InputCastFailed { index, error } => {
                write!(f, "conversion error for input {}: {}", index, error)
            }
            OpError::UnsupportedType => f.write_str("unsupported input type"),
            OpError::IncompatibleInputShapes(details) => {
                write!(f, "incompatible input shapes: {}", details)
            }
            OpError::MissingInputs => f.write_str("required inputs were missing"),
            OpError::InvalidValue(details) => {
                write!(f, "input or attribute has invalid value: {}", details)
            }
            OpError::UnsupportedValue(details) => {
                write!(f, "unsupported input or attribute value: {}", details)
            }
        }
    }
}

impl Error for OpError {}
/// Check the rank of a tensor and convert it to a view with a static number
/// of dimensions.
///
/// - `static_dims!(tensor, N, "dim names")` and `static_dims!(tensor, N)`
///   evaluate to `Ok(tensor.nd_view::<N>())` if `tensor.ndim() == N`.
///   Otherwise they evaluate to `Err(OpError::InvalidValue(..))`. The error
///   message is assembled at compile time (`concat!`/`stringify!`) from the
///   variable name, `N` and, in the first form, the dimension names.
/// - `static_dims!(tensor?, N)` takes an `Option`. It evaluates to `None` if
///   the option is `None`, otherwise to `Some` of the two-argument form's
///   result.
macro_rules! static_dims {
    ($tensor:ident, $ndim:literal, $dim_names:literal) => {{
        use rten_tensor::prelude::*;
        if $tensor.ndim() != $ndim {
            Err(OpError::InvalidValue(concat!(
                stringify!($tensor),
                " must have ",
                stringify!($ndim),
                " dims (",
                $dim_names,
                ")"
            )))
        } else {
            Ok($tensor.nd_view::<$ndim>())
        }
    }};
    ($tensor:ident, $ndim:literal) => {{
        use rten_tensor::prelude::*;
        if $tensor.ndim() != $ndim {
            Err(OpError::InvalidValue(concat!(
                stringify!($tensor),
                " must have ",
                stringify!($ndim),
                " dims"
            )))
        } else {
            Ok($tensor.nd_view::<$ndim>())
        }
    }};
    ($tensor:ident?, $ndim: expr) => {
        // Shadow `$tensor` with the inner reference so the forwarded
        // invocation, and hence its error message, uses the same name.
        if let Some($tensor) = $tensor.as_ref() {
            Some(static_dims!($tensor, $ndim))
        } else {
            None
        }
    };
}
// Make the macro usable elsewhere in the crate by path.
pub(crate) use static_dims;
/// Context for a single operator invocation.
///
/// Bundles the inputs for the call with the buffer pool the operator can use.
/// It also carries optional metadata that the caller may set: the expected
/// number of outputs and a name.
pub struct OpRunContext<'a, 'i> {
    pool: &'a BufferPool,
    inputs: &'a InputList<'i>,
    n_outputs: Option<u32>,
    name: Option<&'a str>,
}

impl<'a, 'i> OpRunContext<'a, 'i> {
    /// Create a context for running an operator with `inputs`, using `pool`
    /// for buffers. The output count and name start unset.
    pub fn new(pool: &'a BufferPool, inputs: &'a InputList<'i>) -> Self {
        OpRunContext {
            pool,
            inputs,
            n_outputs: None,
            name: None,
        }
    }

    /// Return a copy of this context with the inputs replaced by `inputs`.
    ///
    /// The pool, output count and name are carried over unchanged.
    pub fn with_new_inputs<'b, 'il>(&self, inputs: &'b InputList<'il>) -> OpRunContext<'b, 'il>
    where
        'a: 'b,
    {
        OpRunContext { inputs, ..*self }
    }

    /// Buffer pool for this invocation.
    ///
    /// The reference has the context's `'a` lifetime rather than the lifetime
    /// of the `&self` borrow, so it may outlive that borrow.
    pub fn pool(&self) -> &'a BufferPool {
        self.pool
    }

    /// Inputs for this invocation.
    ///
    /// Like [`pool`](Self::pool), the reference has lifetime `'a`, not the
    /// lifetime of the `&self` borrow.
    pub fn inputs(&self) -> &'a InputList<'i> {
        self.inputs
    }

    /// Record the number of outputs the caller expects.
    pub fn set_num_outputs(&mut self, n: u32) {
        self.n_outputs = Some(n);
    }

    /// Number of outputs set via [`set_num_outputs`](Self::set_num_outputs),
    /// or `None` if it was never set.
    pub fn num_outputs(&self) -> Option<u32> {
        self.n_outputs
    }

    /// Set or clear the name associated with this invocation.
    pub fn set_name(&mut self, name: Option<&'a str>) {
        self.name = name;
    }

    /// Name set via [`set_name`](Self::set_name), if any.
    ///
    /// Returned with lifetime `'a`, so it is not tied to the `&self` borrow.
    pub fn name(&self) -> Option<&'a str> {
        self.name
    }
}
/// Describes how the type of one operator output is determined.
///
/// NOTE(review): the variant descriptions below are inferred from their
/// names. The code that interprets them is not in this file, so confirm there.
#[derive(Copy, Clone)]
pub enum OutputType {
    /// The output has this fixed type.
    Fixed(ValueType),
    /// The output has the same type as the input at this index.
    CopyFromInput(u32),
    /// The output has the element type of the sequence input at this index.
    ElementTypeOfInputSequence(u32),
    /// The output is a sequence whose element type is the type of the input
    /// at this index.
    SequenceWithElementTypeOfInput(u32),
}
/// Output type descriptions returned by [`Operator::output_types`].
pub type OutputTypeList = SmallVec<[OutputType; 1]>;
/// Context passed to [`Operator::output_types`].
pub struct OutputTypesContext {
    /// Number of outputs (presumably the number the node is expected to
    /// produce — confirm with callers).
    pub num_outputs: usize,
}
/// Values produced by [`Operator::run`]. Stores one value inline without
/// heap allocation.
pub type OutputList = SmallVec<[Value; 1]>;
/// An operation that consumes a list of inputs and produces a list of outputs.
///
/// Implementors must provide `name`, `run`, `max_inputs` and `output_types`.
/// The remaining methods have defaults for an operator without optional
/// capabilities: it cannot run in place, is not commutative, is deterministic,
/// has no prepackable inputs, and is neither a subgraph operator nor
/// implements shape inference.
pub trait Operator: Any + Debug {
    /// Return a name identifying this operator.
    fn name(&self) -> &str;
    /// Execute the operator using the inputs and buffer pool from `ctx`.
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError>;
    /// Maximum number of inputs the operator accepts.
    /// NOTE(review): presumably `None` means no fixed limit — confirm.
    fn max_inputs(&self) -> Option<usize>;
    /// Describe the types of the operator's outputs.
    /// NOTE(review): presumably `None` means the types cannot be described
    /// up front — confirm with callers.
    fn output_types(&self, ctx: &OutputTypesContext) -> Option<OutputTypeList>;
    /// Whether this operator can be executed via `run_in_place`. Defaults to
    /// `false`, consistent with the default `run_in_place`, which always fails.
    fn can_run_in_place(&self) -> bool {
        false
    }
    /// Whether the operator is commutative in its inputs. Defaults to `false`.
    fn is_commutative(&self) -> bool {
        false
    }
    /// Whether the operator is deterministic. Defaults to `true`.
    fn is_deterministic(&self) -> bool {
        true
    }
    /// Execute the operator, taking `input` by value so that an implementation
    /// may reuse it for the output. Other inputs are available from `ctx`
    /// (assumed — confirm how callers split inputs).
    ///
    /// The default returns `OpError::InvalidValue`.
    fn run_in_place(
        &self,
        #[allow(unused)] input: Value,
        #[allow(unused)] ctx: &OpRunContext,
    ) -> Result<Value, OpError> {
        Err(OpError::InvalidValue("In-place execution not supported"))
    }
    /// Indices of inputs that this operator can prepack via `prepack`.
    /// The default is an empty list.
    fn prepack_inputs(&self) -> SmallVec<[usize; 1]> {
        SmallVec::new()
    }
    /// Prepack `input`, the value for input `index`, or return `None` if it
    /// cannot be prepacked. The default always returns `None`.
    fn prepack(
        &self,
        #[allow(unused)] index: usize,
        #[allow(unused)] input: ValueView,
    ) -> Option<PrepackedInput> {
        None
    }
    /// Return this operator as a [`SubgraphOperator`], if it is one.
    /// The default returns `None`.
    fn as_subgraph_op(&self) -> Option<&dyn SubgraphOperator> {
        None
    }
    /// Return this operator's [`InferShapes`] implementation, if any.
    /// The default returns `None`.
    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        None
    }
}
impl dyn Operator {
    /// Downcast this operator to the concrete type `T`.
    ///
    /// Returns `None` if the operator is of a different type.
    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
        let any: &dyn Any = self;
        any.downcast_ref::<T>()
    }
}

impl dyn Operator + Send + Sync {
    /// Downcast this operator to the concrete type `T`.
    ///
    /// Returns `None` if the operator is of a different type.
    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
        let any: &dyn Any = self;
        any.downcast_ref::<T>()
    }
}
/// An operator that contains subgraphs and can execute them.
pub trait SubgraphOperator: Operator {
    /// Return the subgraphs contained in this operator.
    fn subgraphs(&self) -> SmallVec<[&Graph; 2]>;
    /// Run the operator, executing its subgraphs as needed.
    ///
    /// NOTE(review): `captures`, `weight_cache`, `profiler` and `run_opts`
    /// are presumably passed through to subgraph execution — confirm in the
    /// implementations.
    fn run_subgraph<'a>(
        &'a self,
        ctx: &OpRunContext,
        #[allow(unused)] captures: CaptureEnv,
        #[allow(unused)] weight_cache: Option<&[WeightCache]>,
        #[allow(unused)] profiler: Option<&mut Profiler<'a>>,
        #[allow(unused)] run_opts: Option<RunOptions>,
    ) -> Result<OutputList, RunError>;
}
/// Test-only conveniences for running an operator on ad-hoc inputs.
#[cfg(test)]
pub trait OperatorExt: Operator {
    /// Run the operator on `inputs` with a fresh buffer pool and convert the
    /// first output to `O`.
    ///
    /// # Panics
    ///
    /// Panics if the operator returns no outputs.
    fn run_simple<'a, I: Into<InputList<'a>>, O: TryFrom<Value>>(
        &self,
        inputs: I,
    ) -> Result<O, OpError>
    where
        OpError: From<<O as TryFrom<Value>>::Error>,
    {
        let pool = BufferPool::new();
        let input_list: InputList<'a> = inputs.into();
        let ctx = OpRunContext::new(&pool, &input_list);
        let first_output = self.run(&ctx)?.remove(0);
        O::try_from(first_output).map_err(OpError::from)
    }

    /// Run the operator in place on `mut_input` and convert the result to
    /// `O`. The remaining `inputs` are made available through the context.
    fn run_simple_in_place<'a, M: Into<Value>, I: Into<InputList<'a>>, O: TryFrom<Value>>(
        &self,
        mut_input: M,
        inputs: I,
    ) -> Result<O, OpError>
    where
        OpError: From<<O as TryFrom<Value>>::Error>,
    {
        let pool = BufferPool::new();
        let input_list: InputList<'a> = inputs.into();
        let ctx = OpRunContext::new(&pool, &input_list);
        let output = self.run_in_place(mut_input.into(), &ctx)?;
        O::try_from(output).map_err(OpError::from)
    }
}

#[cfg(test)]
impl<O: ?Sized + Operator> OperatorExt for O {}
/// Positional list of operator inputs.
///
/// Each slot holds `Some(value)` for a provided input or `None` for an
/// omitted one. Slots are stored in a [`Cow`]. A list created with
/// [`InputList::from_optional`] therefore borrows the caller's slice until a
/// mutation (`push`, `push_optional`, or a successful `get_mut`) forces an
/// owned copy.
#[derive(Clone)]
pub struct InputList<'a> {
    /// Input slots. `None` marks an omitted input.
    inputs: Cow<'a, [Option<ValueView<'a>>]>,

    /// Optional lookup used by [`InputList::get_prepacked`] to find a
    /// prepacked version of the input at an index.
    get_prepacked: Option<&'a dyn Fn(usize) -> Option<&'a PrepackedInput>>,

    /// When true, the index reported in [`OpError::InputCastFailed`] is one
    /// greater than the index passed to `get_as` / `require_as`.
    first_input_omitted: bool,
}

impl<'a> InputList<'a> {
    /// Create an empty list.
    pub fn new() -> InputList<'a> {
        InputList {
            inputs: Cow::Owned(vec![]),
            get_prepacked: None,
            first_input_omitted: false,
        }
    }

    /// Set whether the operator's first input has been omitted from this list.
    ///
    /// When `omitted` is true, the index reported in
    /// [`OpError::InputCastFailed`] errors from [`get_as`](Self::get_as) and
    /// [`require_as`](Self::require_as) is incremented by one.
    pub fn with_first_input_omitted(mut self, omitted: bool) -> Self {
        self.first_input_omitted = omitted;
        self
    }

    /// Number of input slots, including omitted (`None`) ones.
    pub fn len(&self) -> usize {
        self.inputs.len()
    }

    /// Return true if the list has no input slots.
    pub fn is_empty(&self) -> bool {
        self.inputs.is_empty()
    }

    /// Append a provided input.
    pub fn push<I: Into<ValueView<'a>>>(&mut self, inp: I) {
        self.inputs.to_mut().push(Some(inp.into()))
    }

    /// Append an input slot, which is omitted if `inp` is `None`.
    pub fn push_optional<I: Into<ValueView<'a>>>(&mut self, inp: Option<I>) {
        self.inputs.to_mut().push(inp.map(|inp| inp.into()))
    }

    /// Create a list in which every value of `inputs` is a provided input.
    ///
    /// The values are cloned into an owned buffer.
    pub fn from(inputs: &[ValueView<'a>]) -> InputList<'a> {
        InputList {
            inputs: inputs.iter().cloned().map(Some).collect(),
            get_prepacked: None,
            first_input_omitted: false,
        }
    }

    /// Create a list that borrows `inputs` without copying it.
    pub fn from_optional(inputs: &'a [Option<ValueView<'a>>]) -> InputList<'a> {
        InputList {
            inputs: Cow::Borrowed(inputs),
            get_prepacked: None,
            first_input_omitted: false,
        }
    }

    /// Attach the lookup used by [`get_prepacked`](Self::get_prepacked).
    pub fn with_prepacked(
        mut self,
        lookup: &'a dyn Fn(usize) -> Option<&'a PrepackedInput>,
    ) -> Self {
        self.get_prepacked = Some(lookup);
        self
    }

    /// Return the input at `index`.
    ///
    /// Returns `None` if the index is out of range or the input was omitted.
    pub fn get(&self, index: usize) -> Option<ValueView<'a>> {
        self.inputs.get(index).cloned().flatten()
    }

    /// Return the prepacked version of the input at `index`.
    ///
    /// Returns `None` if no lookup was attached with
    /// [`with_prepacked`](Self::with_prepacked), or if the lookup returns none.
    pub fn get_prepacked(&self, index: usize) -> Option<&'a PrepackedInput> {
        self.get_prepacked.and_then(|gp| gp(index))
    }

    /// Return a mutable reference to the input at `index`.
    ///
    /// Returns `None` if the index is out of range or the input was omitted.
    /// If the list borrows its inputs, a successful lookup first copies them
    /// into an owned buffer so they can be mutated.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut ValueView<'a>> {
        // Check the slot before `to_mut`, so that a lookup which returns
        // `None` does not needlessly copy a borrowed list.
        if !matches!(self.inputs.get(index), Some(Some(_))) {
            return None;
        }
        self.inputs.to_mut().get_mut(index)?.as_mut()
    }

    /// Return the input at `index` converted to `T`, or `Ok(None)` if the
    /// input is absent.
    ///
    /// # Errors
    ///
    /// Returns [`OpError::InputCastFailed`] if the input is present but cannot
    /// be converted to `T`.
    pub fn get_as<T>(&self, index: usize) -> Result<Option<T>, OpError>
    where
        T: TryFrom<ValueView<'a>, Error = TryFromValueError>,
    {
        self.get(index)
            .map(|input| self.cast_input(index, input))
            .transpose()
    }

    /// Return the input at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`OpError::MissingInputs`] if the input is absent.
    pub fn require(&self, index: usize) -> Result<ValueView<'a>, OpError> {
        self.get(index).ok_or(OpError::MissingInputs)
    }

    /// Return the input at `index` converted to `T`.
    ///
    /// # Errors
    ///
    /// Returns [`OpError::MissingInputs`] if the input is absent, or
    /// [`OpError::InputCastFailed`] if it cannot be converted to `T`.
    pub fn require_as<T>(&self, index: usize) -> Result<T, OpError>
    where
        T: TryFrom<ValueView<'a>, Error = TryFromValueError>,
    {
        let input = self.require(index)?;
        self.cast_input(index, input)
    }

    /// Iterate over the input slots, yielding `None` for omitted inputs.
    pub fn iter<'b>(&'b self) -> impl Iterator<Item = Option<ValueView<'a>>> + 'b {
        self.inputs.iter().cloned()
    }

    /// Convert `input`, the value at `index`, to `T`. A failure is reported
    /// as [`OpError::InputCastFailed`] with the adjusted input index.
    fn cast_input<T>(&self, index: usize, input: ValueView<'a>) -> Result<T, OpError>
    where
        T: TryFrom<ValueView<'a>, Error = TryFromValueError>,
    {
        input.try_into().map_err(|error| OpError::InputCastFailed {
            index: self.to_real_index(index),
            error,
        })
    }

    /// Map `index` in this list to the index reported in errors, accounting
    /// for an omitted first input.
    fn to_real_index(&self, index: usize) -> usize {
        if self.first_input_omitted {
            index + 1
        } else {
            index
        }
    }
}
/// The default list is empty, the same as [`InputList::new`].
impl Default for InputList<'_> {
    fn default() -> Self {
        InputList::new()
    }
}

/// Create a list containing a single provided input.
impl<'a, I: Into<ValueView<'a>>> From<I> for InputList<'a> {
    fn from(val: I) -> InputList<'a> {
        let view: ValueView<'a> = val.into();
        InputList::from(std::slice::from_ref(&view))
    }
}

/// The unit value converts to an empty list.
impl<'a> From<()> for InputList<'a> {
    fn from(_: ()) -> InputList<'a> {
        InputList::new()
    }
}

/// A 1-tuple converts to a list with one provided input.
impl<'a, I1: Into<ValueView<'a>>> From<(I1,)> for InputList<'a> {
    fn from((first,): (I1,)) -> InputList<'a> {
        let views: [ValueView<'a>; 1] = [first.into()];
        InputList::from(&views)
    }
}

/// A 2-tuple converts to a list with two provided inputs, in order.
impl<'a, I1: Into<ValueView<'a>>, I2: Into<ValueView<'a>>> From<(I1, I2)> for InputList<'a> {
    fn from((first, second): (I1, I2)) -> InputList<'a> {
        let views: [ValueView<'a>; 2] = [first.into(), second.into()];
        InputList::from(&views)
    }
}

/// A 3-tuple converts to a list with three provided inputs, in order.
impl<'a, I1: Into<ValueView<'a>>, I2: Into<ValueView<'a>>, I3: Into<ValueView<'a>>>
    From<(I1, I2, I3)> for InputList<'a>
{
    fn from((first, second, third): (I1, I2, I3)) -> InputList<'a> {
        let views: [ValueView<'a>; 3] = [first.into(), second.into(), third.into()];
        InputList::from(&views)
    }
}
impl<'a> Extend<ValueView<'a>> for InputList<'a> {
fn extend<T>(&mut self, iter: T)
where
T: IntoIterator<Item = ValueView<'a>>,
{
for item in iter {
self.push(item);
}
}
}
impl<'a> Extend<Option<ValueView<'a>>> for InputList<'a> {
fn extend<T>(&mut self, iter: T)
where
T: IntoIterator<Item = Option<ValueView<'a>>>,
{
for item in iter {
self.push_optional(item);
}
}
}
impl<'a, A> FromIterator<A> for InputList<'a>
where
InputList<'a>: Extend<A>,
{
fn from_iter<T>(iter: T) -> Self
where
T: IntoIterator<Item = A>,
{
let mut list = InputList::new();
list.extend(iter);
list
}
}
#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::{Tensor, TensorView};

    use crate::operator::{InputList, OpError, Operator};
    use crate::ops::{Add, Sub};

    #[test]
    fn test_input_list_first_input_omitted() {
        let tensor = Tensor::<f32>::zeros(&[2, 2]);

        // Requesting an i32 view of an f32 tensor fails. The reported index
        // should be shifted by one only when the first input was omitted.
        for (omitted, expected_index) in [(false, 0), (true, 1)] {
            let inputs =
                InputList::from(&[tensor.view().into()]).with_first_input_omitted(omitted);
            let err = inputs.require_as::<TensorView<i32>>(0).err().unwrap();
            assert!(matches!(
                err,
                OpError::InputCastFailed { index, .. } if index == expected_index
            ));
        }
    }

    #[test]
    fn test_downcast_operator() {
        let add: &dyn Operator = &Add {};
        let sub: &dyn Operator = &Sub {};

        assert!(add.downcast_ref::<Add>().is_some());
        assert!(sub.downcast_ref::<Sub>().is_some());
    }
}