acme_tensor/ops/
op.rs

/*
    Appellation: kinds <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::ops::kinds::reshape::*;
6use crate::shape::{Axis, Shape};
7use crate::tensor::TensorBase;
8use acme::prelude::{BinaryOp, UnaryOp};
9use num::Complex;
10
11pub type BoxTensor<T = f64> = Box<TensorBase<T>>;
12
13#[derive(Clone, Debug, Eq, Hash, PartialEq)]
14#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
15#[non_exhaustive]
16pub enum TensorExpr<A, B = A> {
17    Binary(BoxTensor<A>, BoxTensor<B>, BinaryOp),
18    BinaryScalar(BoxTensor<A>, B, BinaryOp),
19    Unary(BoxTensor<A>, UnaryOp),
20    Matmul(BoxTensor<A>, BoxTensor<B>),
21    Sigmoid(BoxTensor<A>),
22    Shape(ReshapeExpr<A>),
23}
24
25impl<A, B> TensorExpr<A, B> {
26    pub fn binary(lhs: TensorBase<A>, rhs: TensorBase<B>, op: BinaryOp) -> Self {
27        Self::Binary(lhs.boxed(), rhs.boxed(), op)
28    }
29
30    pub fn binary_scalar(lhs: TensorBase<A>, rhs: B, op: BinaryOp) -> Self {
31        Self::BinaryScalar(Box::new(lhs), rhs, op)
32    }
33
34    pub fn binary_scalar_c(
35        lhs: TensorBase<A>,
36        rhs: Complex<A>,
37        op: BinaryOp,
38    ) -> TensorExpr<A, Complex<A>> {
39        TensorExpr::BinaryScalar(Box::new(lhs), rhs, op)
40    }
41
42    pub fn broadcast(tensor: TensorBase<A>, shape: Shape) -> Self {
43        Self::shape(ReshapeExpr::broadcast(tensor, shape))
44    }
45
46    pub fn matmul(lhs: TensorBase<A>, rhs: TensorBase<B>) -> Self {
47        Self::Matmul(Box::new(lhs), Box::new(rhs))
48    }
49
50    pub fn reshape(tensor: TensorBase<A>, shape: Shape) -> Self {
51        Self::shape(ReshapeExpr::reshape(tensor, shape))
52    }
53
54    pub fn shape(expr: ReshapeExpr<A>) -> Self {
55        Self::Shape(expr)
56    }
57
58    pub fn sigmoid(tensor: TensorBase<A>) -> Self {
59        Self::Sigmoid(Box::new(tensor))
60    }
61
62    pub fn swap_axes(tensor: TensorBase<A>, swap: Axis, with: Axis) -> Self {
63        Self::shape(ReshapeExpr::swap_axes(tensor, swap, with))
64    }
65
66    pub fn transpose(tensor: TensorBase<A>) -> Self {
67        Self::Shape(ReshapeExpr::transpose(tensor))
68    }
69
70    pub fn unary(tensor: TensorBase<A>, op: UnaryOp) -> Self {
71        Self::Unary(Box::new(tensor), op)
72    }
73
74    pub fn lhs(self) -> Option<TensorBase<A>> {
75        match self {
76            Self::Binary(lhs, _, _) => Some(*lhs),
77            Self::BinaryScalar(lhs, _, _) => Some(*lhs),
78            Self::Unary(lhs, _) => Some(*lhs),
79            Self::Matmul(lhs, _) => Some(*lhs),
80            _ => None,
81        }
82    }
83
84    pub fn rhs(self) -> Option<TensorBase<B>> {
85        match self {
86            Self::Binary(_, rhs, _) => Some(*rhs),
87            Self::BinaryScalar(_, scalar, _) => Some(TensorBase::from_scalar(scalar)),
88            Self::Matmul(_, rhs) => Some(*rhs),
89            _ => None,
90        }
91    }
92    pub fn view(&self) -> TensorExpr<&A, &B> {
93        match self {
94            TensorExpr::Binary(lhs, rhs, op) => TensorExpr::binary(lhs.view(), rhs.view(), *op),
95            TensorExpr::BinaryScalar(lhs, rhs, op) => {
96                TensorExpr::binary_scalar(lhs.view(), rhs, *op)
97            }
98            TensorExpr::Unary(tensor, op) => TensorExpr::unary(tensor.view(), *op),
99            TensorExpr::Matmul(lhs, rhs) => TensorExpr::matmul(lhs.view(), rhs.view()),
100            TensorExpr::Sigmoid(tensor) => TensorExpr::sigmoid(tensor.view()),
101            TensorExpr::Shape(inner) => TensorExpr::Shape(inner.view()),
102        }
103    }
104    pub fn view_mut(&mut self) -> TensorExpr<&mut A, &mut B> {
105        match self {
106            TensorExpr::Binary(lhs, rhs, op) => {
107                TensorExpr::binary(lhs.view_mut(), rhs.view_mut(), *op)
108            }
109
110            TensorExpr::BinaryScalar(lhs, rhs, op) => {
111                TensorExpr::binary_scalar(lhs.view_mut(), rhs, *op)
112            }
113            TensorExpr::Unary(tensor, op) => TensorExpr::unary(tensor.view_mut(), *op),
114            TensorExpr::Matmul(lhs, rhs) => TensorExpr::matmul(lhs.view_mut(), rhs.view_mut()),
115            TensorExpr::Sigmoid(tensor) => TensorExpr::sigmoid(tensor.view_mut()),
116            TensorExpr::Shape(inner) => TensorExpr::Shape(inner.view_mut()),
117        }
118    }
119}