use std::{
marker::PhantomData,
ops::{Add, Div, Mul, Neg, Sub},
sync::{Arc, RwLock, RwLockReadGuard},
};
use crate::{
device::Dev,
graph::{BinaryOpType, Graph, GraphTensorId, Op, UnaryOpType},
DType, Shape, R1, R2, R3,
};
use super::contiguous_strides;
#[derive(Clone)]
pub struct GraphTensor<S: Shape, T: DType, D: Dev> {
id: GraphTensorId,
graph: Arc<RwLock<Graph<T>>>,
strides: Vec<usize>,
_ghost: PhantomData<(S, T, D)>,
}
impl<const B: usize, const M: usize, const K: usize, T: DType, D: Dev>
GraphTensor<R3<B, M, K>, T, D>
{
#[must_use]
pub fn matmul<const N: usize>(
self,
rhs: GraphTensor<R3<B, K, N>, T, D>,
) -> GraphTensor<R3<B, M, N>, T, D> {
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<R3<B, M, N>>(
Op::MatMul {
l_id: self.id(),
r_id: rhs.id(),
o_id: None,
k: K,
alpha: T::ZERO,
beta: T::ONE,
},
&self.strides,
&id,
);
GraphTensor {
id,
graph: self.graph.clone(),
strides: self.strides.clone(),
_ghost: PhantomData,
}
}
#[must_use]
pub fn matmul_axpby<const N: usize>(
self,
rhs: GraphTensor<R3<B, K, N>, T, D>,
out: GraphTensor<R3<B, M, N>, T, D>,
alpha: T,
beta: T,
) -> GraphTensor<R3<B, M, N>, T, D> {
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<R3<B, M, N>>(
Op::MatMul {
l_id: self.id(),
r_id: rhs.id(),
o_id: Some(out.id()),
k: K,
alpha,
beta,
},
&self.strides,
&id,
);
GraphTensor {
id,
graph: self.graph.clone(),
strides: self.strides.clone(),
_ghost: PhantomData,
}
}
}
impl<S: Shape, T: DType, D: Dev> GraphTensor<S, T, D> {
#[must_use]
pub fn fill(graph: &mut Graph<T>, v: T) -> Self {
let id = graph.next_id();
let strides = contiguous_strides(&S::shape());
graph.add_op::<S>(Op::Fill { v }, &strides, &id);
Self {
id,
graph: Arc::new(RwLock::new(graph.clone())),
strides,
_ghost: PhantomData,
}
}
#[must_use]
pub fn zeros(graph: &mut Graph<T>) -> Self {
Self::fill(graph, T::ZERO)
}
#[must_use]
pub fn ones(graph: &mut Graph<T>) -> Self {
Self::fill(graph, T::ONE)
}
#[must_use]
pub fn sqrt(self) -> GraphTensor<S, T, D> {
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<S>(
Op::UnaryOp {
v_id: self.id(),
operator: UnaryOpType::Sqrt,
},
&self.strides,
&id,
);
Self {
id,
graph: self.graph.clone(),
strides: self.strides.clone(),
_ghost: PhantomData,
}
}
#[must_use]
pub fn rand(graph: &mut Graph<T>) -> Self {
let id = graph.next_id();
let strides = contiguous_strides(&S::shape());
graph.add_op::<S>(Op::Rand, &strides, &id);
GraphTensor {
id,
graph: Arc::new(RwLock::new(graph.clone())),
strides,
_ghost: PhantomData,
}
}
#[must_use]
pub fn randn(graph: &mut Graph<T>, mean: T, std: T) -> Self {
let id = graph.next_id();
let strides = contiguous_strides(&S::shape());
graph.add_op::<S>(Op::Randn { mean, std }, &strides, &id);
GraphTensor {
id,
graph: Arc::new(RwLock::new(graph.clone())),
strides,
_ghost: PhantomData,
}
}
}
impl<S: Shape, T: DType, D: Dev> GraphTensor<S, T, D> {
pub fn graph(&self) -> RwLockReadGuard<Graph<T>> {
self.graph.read().unwrap()
}
pub fn id(&self) -> GraphTensorId {
self.id.clone()
}
}
impl<const A: usize, T: DType, D: Dev> GraphTensor<R1<A>, T, D> {
#[must_use]
pub fn arange(graph: &mut Graph<T>, start: T, stop: T) -> Self {
let id = graph.next_id();
let step = (stop.to_f64() - start.to_f64()) / (A as f64);
let strides = contiguous_strides(&[A]);
graph.add_op::<R1<A>>(
Op::Arange {
start,
step: T::from_f64(step),
stop,
},
&strides,
&id,
);
Self {
id,
graph: Arc::new(RwLock::new(graph.clone())),
strides,
_ghost: PhantomData,
}
}
}
impl<T: DType, const A: usize, const B: usize, D: Dev> GraphTensor<R2<A, B>, T, D> {
pub fn t(&self) -> GraphTensor<R2<B, A>, T, D> {
let mut new_strides = self.strides.clone();
new_strides.swap(0, 1);
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<R2<B, A>>(
Op::Permute {
v_id: self.id.clone(),
},
&new_strides,
&id,
);
GraphTensor {
id,
graph: self.graph.clone(),
strides: new_strides,
_ghost: PhantomData,
}
}
}
impl<T: DType, const A: usize, const B: usize, const C: usize, D: Dev>
GraphTensor<R3<A, B, C>, T, D>
{
pub fn t(&self) -> GraphTensor<R3<A, C, B>, T, D> {
let mut new_strides = self.strides.clone();
new_strides.swap(1, 2);
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<R3<A, C, B>>(
Op::Permute {
v_id: self.id.clone(),
},
&new_strides,
&id,
);
GraphTensor {
id,
graph: self.graph.clone(),
strides: new_strides,
_ghost: PhantomData,
}
}
}
macro_rules! graphtensor_binop {
($trait:ident, $fn_name:ident) => {
impl<S: Shape, T: DType, D: Dev> $trait for GraphTensor<S, T, D> {
type Output = GraphTensor<S, T, D>;
fn $fn_name(self, rhs: Self) -> Self::Output {
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<S>(
Op::BinaryOp {
l_id: self.id(),
r_id: rhs.id(),
operator: BinaryOpType::$trait,
},
&self.strides,
&id,
);
Self {
id,
graph: self.graph.clone(),
strides: self.strides.clone(),
_ghost: PhantomData,
}
}
}
};
}
graphtensor_binop!(Add, add);
graphtensor_binop!(Div, div);
graphtensor_binop!(Mul, mul);
graphtensor_binop!(Sub, sub);
impl<S: Shape, T: DType + Neg<Output = T>, D: Dev> Neg for GraphTensor<S, T, D> {
type Output = GraphTensor<S, T, D>;
fn neg(self) -> Self::Output {
let id = self.graph.write().unwrap().next_id();
self.graph.write().unwrap().add_op::<S>(
Op::UnaryOp {
v_id: self.id(),
operator: UnaryOpType::Neg,
},
&self.strides,
&id,
);
Self {
id,
graph: self.graph.clone(),
strides: self.strides.clone(),
_ghost: PhantomData,
}
}
}