use crate::nn::tensors::Tensor;
use crate::nn::tensors::WithGrad;
use crate::nn::TensorFloat;
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(not(feature = "alloc"))]
use box_closure::{Align8, OpaqueFn};
/// Matrix-multiplies the values of two gradient-tracked dynamic tensors.
///
/// Returns the forward product together with a boxed backward function that
/// maps an upstream gradient to `(grad_wrt_a, grad_wrt_b)`, computed as
/// `grad * b^T` and `a^T * grad` respectively.
// NOTE(review): this arm previously required only `dyntensor`, but its `Box`
// return type is imported only under `alloc` (see the `use alloc::boxed::Box`
// gate above), and with `dyntensor` enabled while `alloc` is disabled it
// would also collide with the `not(alloc)` arm of `matmul` below. Requiring
// both features keeps the three arms mutually exclusive; if `dyntensor`
// already implies `alloc` in Cargo.toml this change is a no-op.
#[must_use]
#[cfg(all(feature = "alloc", feature = "dyntensor"))]
pub fn matmul<'a>(
    a: &'a WithGrad<Tensor<TensorFloat>>,
    b: &'a WithGrad<Tensor<TensorFloat>>,
) -> (
    Tensor<TensorFloat>,
    Box<dyn Fn(Tensor<TensorFloat>) -> (Tensor<TensorFloat>, Tensor<TensorFloat>) + 'a>,
) {
    let a_val = a.get_value();
    let b_val = b.get_value();
    let out = a_val.matmul(b_val);
    // The transposes are taken lazily inside the closure so no extra work is
    // done unless the backward pass actually runs.
    let back = move |grad: Tensor<TensorFloat>| {
        let b_t = b_val.transpose();
        let a_t = a_val.transpose();
        let grad_wrt_a = grad.matmul(&b_t);
        let grad_wrt_b = a_t.matmul(&grad);
        (grad_wrt_a, grad_wrt_b)
    };
    (out, Box::new(back))
}
/// Matrix multiplication for statically-sized gradient-tracked tensors.
///
/// Produces the forward product of `a` and `b` plus a boxed backward pass
/// that, given the upstream gradient, yields the gradients with respect to
/// `a` and `b` (as `grad * b^T` and `a^T * grad`).
#[must_use]
#[cfg(all(feature = "alloc", not(feature = "dyntensor")))]
pub fn matmul<'a, const A: usize, const B: usize, const OUT: usize, const D: usize>(
    a: &'a WithGrad<Tensor<TensorFloat, A, D>>,
    b: &'a WithGrad<Tensor<TensorFloat, B, D>>,
) -> (
    Tensor<TensorFloat, OUT, D>,
    Box<
        dyn Fn(
                Tensor<TensorFloat, OUT, D>,
            ) -> (Tensor<TensorFloat, A, D>, Tensor<TensorFloat, B, D>)
            + 'a,
    >,
) {
    let lhs = a.get_value();
    let rhs = b.get_value();
    let product = lhs.matmul(rhs);
    // Backward pass; transposes are computed only when the closure is invoked.
    let backward = move |upstream: Tensor<TensorFloat, OUT, D>| {
        let rhs_t = rhs.transpose();
        let lhs_t = lhs.transpose();
        (upstream.matmul(&rhs_t), lhs_t.matmul(&upstream))
    };
    (product, Box::new(backward))
}
/// Matrix multiplication for statically-sized gradient-tracked tensors,
/// allocation-free variant.
///
/// Identical in contract to the `alloc` arm, but the backward pass is stored
/// in an inline `OpaqueFn` (64-byte, 8-byte-aligned buffer) instead of a
/// heap-allocated `Box`. The closure maps the upstream gradient to the pair
/// `(grad * b^T, a^T * grad)`.
#[must_use]
#[cfg(not(feature = "alloc"))]
pub fn matmul<'a, const A: usize, const B: usize, const OUT: usize, const D: usize>(
    a: &'a WithGrad<Tensor<TensorFloat, A, D>>,
    b: &'a WithGrad<Tensor<TensorFloat, B, D>>,
) -> (
    Tensor<TensorFloat, OUT, D>,
    OpaqueFn<
        'a,
        Tensor<TensorFloat, OUT, D>,
        (Tensor<TensorFloat, A, D>, Tensor<TensorFloat, B, D>),
        Align8<64>,
    >,
) {
    let lhs = a.get_value();
    let rhs = b.get_value();
    let product = lhs.matmul(rhs);
    // Backward pass; transposes are computed only when the closure is invoked.
    let backward = move |upstream: Tensor<TensorFloat, OUT, D>| {
        let rhs_t = rhs.transpose();
        let lhs_t = lhs.transpose();
        (upstream.matmul(&rhs_t), lhs_t.matmul(&upstream))
    };
    (product, OpaqueFn::new(backward))
}