use crate::manual::tensors::{Tensor, WithGrad};
// `Box` is used by both heap-backed variants below; the `dyntensor` feature
// is assumed to imply `alloc` in Cargo.toml so this import stays in scope.
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(not(feature = "alloc"))]
use box_closure::{Align32, OpaqueFn};
#[cfg(feature = "dyntensor")]
use tensor_optim::TensorOps;
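/// Matrix multiplication with a reverse-mode backward pass, for dynamically
/// shaped tensors (`dyntensor` feature).
///
/// Returns the forward product `A · B` together with a boxed closure that
/// maps the upstream gradient `dL/dC` to the pair `(dL/dA, dL/dB)` via
/// `dL/dA = dL/dC · Bᵀ` and `dL/dB = Aᵀ · dL/dC`.
///
/// A minimal usage sketch; `Tensor::from_shape_vec` is an illustrative
/// constructor, not necessarily this crate's actual API:
///
/// ```rust,ignore
/// let a = WithGrad::from(Tensor::from_shape_vec(&[2, 3], vec![1.0; 6]));
/// let b = WithGrad::from(Tensor::from_shape_vec(&[3, 4], vec![1.0; 12]));
/// let (out, back) = matmul(&a, &b);  // forward: a 2×4 product
/// let upstream = out.clone();        // any dL/dC with the product's shape
/// let (grad_a, grad_b) = back(upstream);
/// ```
///
/// # Panics
///
/// Panics if either operand is not 2-D or if the inner dimensions disagree.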
#[must_use]
#[cfg(feature = "dyntensor")]
pub fn matmul<
    'a,
    T: Copy
        + Default
        + core::ops::Add<Output = T>
        + core::ops::Mul<Output = T>
        + core::ops::AddAssign
        + 'a,
>(
    a: &WithGrad<Tensor<T>>,
    b: &WithGrad<Tensor<T>>,
) -> (
    Tensor<T>,
    Box<dyn Fn(Tensor<T>) -> (Tensor<T>, Tensor<T>) + 'a>,
) {
    assert_eq!(a.get_value().shape().len(), 2, "`A` must be 2D for matmul");
    assert_eq!(b.get_value().shape().len(), 2, "`B` must be 2D for matmul");
    assert_eq!(
        a.get_value().shape()[1],
        b.get_value().shape()[0],
        "inner dimensions must match for matmul",
    );
    // Clone the operands so the returned closure owns its own copies and
    // does not borrow from `a` or `b`.
    let a_val = a.get_value().clone();
    let b_val = b.get_value().clone();
    let out = a_val.matmul(&b_val);
    let back = move |grad: Tensor<T>| {
        // For C = A · B: dL/dA = dL/dC · Bᵀ and dL/dB = Aᵀ · dL/dC.
        let b_t = b_val.transpose();
        let a_t = a_val.transpose();
        let grad_wrt_a = grad.matmul(&b_t);
        let grad_wrt_b = a_t.matmul(&grad);
        (grad_wrt_a, grad_wrt_b)
    };
    (out, Box::new(back))
}
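/// Matrix multiplication with a reverse-mode backward pass, for const-generic
/// tensors when an allocator is available (`alloc` without `dyntensor`).
///
/// `A`, `B`, and `OUT` are the flattened element counts of the operands and
/// the product; their compatibility is not asserted here and is assumed to be
/// enforced by `Tensor::matmul`'s own signature. The backward closure is
/// boxed and `'static`, since it captures clones of both operands.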
#[must_use]
#[cfg(all(feature = "alloc", not(feature = "dyntensor")))]
pub fn matmul<'a, T, const A: usize, const B: usize, const OUT: usize>(
    a: &'a WithGrad<Tensor<T, A, 2>>,
    b: &'a WithGrad<Tensor<T, B, 2>>,
) -> (
    Tensor<T, OUT, 2>,
    Box<dyn Fn(Tensor<T, OUT, 2>) -> (Tensor<T, A, 2>, Tensor<T, B, 2>)>,
)
where
    T: Copy
        + Default
        + core::ops::Add<Output = T>
        + core::ops::Mul<Output = T>
        + core::ops::AddAssign
        + 'static,
{
    // Clone the operands so the `'static` backward closure owns them.
    let a_val = a.get_value().clone();
    let b_val = b.get_value().clone();
    let out = a_val.matmul(&b_val);
    let back = move |grad: Tensor<T, OUT, 2>| {
        // For C = A · B: dL/dA = dL/dC · Bᵀ and dL/dB = Aᵀ · dL/dC.
        let b_t = b_val.transpose();
        let a_t = a_val.transpose();
        let grad_wrt_a = grad.matmul(&b_t);
        let grad_wrt_b = a_t.matmul(&grad);
        (grad_wrt_a, grad_wrt_b)
    };
    (out, Box::new(back))
}
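/// Matrix multiplication with a reverse-mode backward pass, for const-generic
/// tensors on targets without an allocator (`alloc` disabled).
///
/// Same math as the `alloc` variant, but the backward closure is returned as
/// an `OpaqueFn` backed by fixed `Align32<256>` storage, presumably held
/// inline by `box_closure` so that no heap allocation is required; the
/// captured tensor clones must fit in that buffer.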
#[must_use]
#[cfg(not(feature = "alloc"))]
pub fn matmul<'a, T, const A: usize, const B: usize, const OUT: usize>(
    a: &'a WithGrad<Tensor<T, A, 2>>,
    b: &'a WithGrad<Tensor<T, B, 2>>,
) -> (
    Tensor<T, OUT, 2>,
    OpaqueFn<'a, Tensor<T, OUT, 2>, (Tensor<T, A, 2>, Tensor<T, B, 2>), Align32<256>>,
)
where
    T: Copy
        + Default
        + core::ops::Add<Output = T>
        + core::ops::Mul<Output = T>
        + core::ops::AddAssign,
{
    // Clone the operands so the backward closure owns them; the closure and
    // its captures live in the `OpaqueFn`'s fixed-size storage.
    let a_val = a.get_value().clone();
    let b_val = b.get_value().clone();
    let out = a_val.matmul(&b_val);
    let back = move |grad: Tensor<T, OUT, 2>| {
        // For C = A · B: dL/dA = dL/dC · Bᵀ and dL/dB = Aᵀ · dL/dC.
        let b_t = b_val.transpose();
        let a_t = a_val.transpose();
        let grad_wrt_a = grad.matmul(&b_t);
        let grad_wrt_b = a_t.matmul(&grad);
        (grad_wrt_a, grad_wrt_b)
    };
    (out, OpaqueFn::new(back))
}