#[cfg(test)]
use super::{assert_almost_equals, new_backward_input, new_input, new_tensor};
use super::{
expect_tensor, expect_tensor_mut, push_mat_mat_gradient, Backward, Data, DotDim, Forward,
Gradient, Overwrite, Tensor,
};
use ndarray::{linalg::general_mat_mul, Ix2};
use std::{
cell::{Cell, Ref, RefCell, RefMut},
fmt::{Debug, Display},
rc::Rc,
};
pub struct MatrixMatrixMulT<Lhs, Rhs>
where
Lhs: Data<Dim = Ix2>,
Rhs: Data<Dim = Ix2>,
{
left: Rc<Lhs>,
right: Rc<Rhs>,
data: RefCell<Tensor<Ix2>>,
computed: Cell<bool>,
}
impl<Lhs, Rhs> MatrixMatrixMulT<Lhs, Rhs>
where
Lhs: Data<Dim = Ix2>,
Rhs: Data<Dim = Ix2>,
{
pub fn new(left: Rc<Lhs>, right: Rc<Rhs>) -> Self {
let shape = DotDim::shape(left.data().raw_dim(), right.data().t().raw_dim());
let data = RefCell::new(Tensor::zeros((shape[0], shape[1])));
Self {
left,
right,
data,
computed: Cell::new(false),
}
}
}
impl<Lhs, Rhs> Data for MatrixMatrixMulT<Lhs, Rhs>
where
Lhs: Data<Dim = Ix2>,
Rhs: Data<Dim = Ix2>,
{
type Dim = Ix2;
fn data(&self) -> Ref<Tensor<Self::Dim>> {
self.data.borrow()
}
fn data_mut(&self) -> RefMut<Tensor<Self::Dim>> {
self.data.borrow_mut()
}
}
impl<Lhs, Rhs> Forward for MatrixMatrixMulT<Lhs, Rhs>
where
Lhs: Data<Dim = Ix2>,
Rhs: Data<Dim = Ix2>,
{
fn forward(&self) {
if self.was_computed() {
return;
}
self.computed.set(true);
general_mat_mul(
1.0,
&*self.left.data(),
&self.right.data().t(),
0.0,
&mut *self.data.borrow_mut(),
);
}
fn was_computed(&self) -> bool {
self.computed.get()
}
fn reset_computation(&self) {
self.computed.set(false);
}
}
impl<Lhs, Rhs> Debug for MatrixMatrixMulT<Lhs, Rhs>
where
Lhs: Data<Dim = Ix2>,
Rhs: Data<Dim = Ix2>,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MatrixMatrixMulT")
.field("data", &self.data.borrow())
.field("computed", &self.computed.get())
.finish()
}
}
impl<Lhs, Rhs> Display for MatrixMatrixMulT<Lhs, Rhs>
where
Lhs: Data<Dim = Ix2>,
Rhs: Data<Dim = Ix2>,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(f, "{}", &self.data.borrow())
}
}
pub struct MatrixMatrixMulTBackward<LhsD, LhsG, RhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
gradient: RefCell<Option<Tensor<Ix2>>>,
shape: Ix2,
overwrite: Cell<bool>,
left_data: Rc<LhsD>,
left_grad: Rc<LhsG>,
right_data: Rc<RhsD>,
right_grad: Rc<RhsG>,
}
impl<LhsD, LhsG, RhsD, RhsG> MatrixMatrixMulTBackward<LhsD, LhsG, RhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
pub fn new(
left_data: Rc<LhsD>,
left_grad: Rc<LhsG>,
right_data: Rc<RhsD>,
right_grad: Rc<RhsG>,
) -> Self {
let shape = DotDim::shape(
left_grad.gradient().raw_dim(),
right_grad.gradient().t().raw_dim(),
);
Self {
gradient: RefCell::new(Some(Tensor::zeros(shape))),
shape,
overwrite: Cell::new(true),
left_data,
left_grad,
right_data,
right_grad,
}
}
}
impl<LhsD, LhsG, RhsD, RhsG> Gradient for MatrixMatrixMulTBackward<LhsD, LhsG, RhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
type Dim = Ix2;
fn gradient(&self) -> Ref<Tensor<Self::Dim>> {
expect_tensor(&self.gradient)
}
fn gradient_mut(&self) -> RefMut<Tensor<Self::Dim>> {
expect_tensor_mut(&self.gradient)
}
}
impl<LhsD, LhsG, RhsD, RhsG> Overwrite for MatrixMatrixMulTBackward<LhsD, LhsG, RhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn can_overwrite(&self) -> bool {
self.overwrite.get()
}
fn set_overwrite(&self, state: bool) {
self.overwrite.set(state);
}
}
impl<LhsD, LhsG, RhsD, RhsG> Backward for MatrixMatrixMulTBackward<LhsD, LhsG, RhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn backward(&self) {
let gradient = self.gradient();
push_mat_mat_gradient(&*self.left_grad, &gradient, &self.right_data.data());
push_mat_mat_gradient(&*self.right_grad, &gradient.t(), &self.left_data.data());
}
fn no_grad(&self) {
*self.gradient.borrow_mut() = None;
}
fn with_grad(&self) {
*self.gradient.borrow_mut() = Some(Tensor::zeros(self.shape));
}
}
impl<LhsD, LhsG, RhsD, RhsG> Debug for MatrixMatrixMulTBackward<LhsD, LhsG, RhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MatrixMatrixMulTBackward")
.field("gradient", &self.gradient.borrow())
.field("overwrite", &self.overwrite.get())
.finish()
}
}
impl<LhsD, LhsG, RhsD, RhsG> Display for MatrixMatrixMulTBackward<LhsD, LhsG, RhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match &*self.gradient.borrow() {
Some(gradient) => write!(f, "{}", &gradient),
None => write!(f, "None"),
}
}
}
pub struct MatrixMatrixMulTBackwardLeft<LhsG, RhsD>
where
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
{
gradient: RefCell<Option<Tensor<Ix2>>>,
shape: Ix2,
overwrite: Cell<bool>,
left_grad: Rc<LhsG>,
right_data: Rc<RhsD>,
}
impl<LhsG, RhsD> MatrixMatrixMulTBackwardLeft<LhsG, RhsD>
where
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
{
pub fn new(left_grad: Rc<LhsG>, right_data: Rc<RhsD>) -> Self {
let shape = DotDim::shape(
left_grad.gradient().raw_dim(),
right_data.data().t().raw_dim(),
);
Self {
gradient: RefCell::new(Some(Tensor::zeros(shape))),
shape,
overwrite: Cell::new(true),
left_grad,
right_data,
}
}
}
impl<LhsG, RhsD> Gradient for MatrixMatrixMulTBackwardLeft<LhsG, RhsD>
where
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
{
type Dim = Ix2;
fn gradient(&self) -> Ref<Tensor<Self::Dim>> {
expect_tensor(&self.gradient)
}
fn gradient_mut(&self) -> RefMut<Tensor<Self::Dim>> {
expect_tensor_mut(&self.gradient)
}
}
impl<LhsG, RhsD> Overwrite for MatrixMatrixMulTBackwardLeft<LhsG, RhsD>
where
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn can_overwrite(&self) -> bool {
self.overwrite.get()
}
fn set_overwrite(&self, state: bool) {
self.overwrite.set(state);
}
}
impl<LhsG, RhsD> Backward for MatrixMatrixMulTBackwardLeft<LhsG, RhsD>
where
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn backward(&self) {
push_mat_mat_gradient(&*self.left_grad, &self.gradient(), &self.right_data.data());
}
fn no_grad(&self) {
*self.gradient.borrow_mut() = None;
}
fn with_grad(&self) {
*self.gradient.borrow_mut() = Some(Tensor::zeros(self.shape));
}
}
impl<LhsG, RhsD> Debug for MatrixMatrixMulTBackwardLeft<LhsG, RhsD>
where
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MatrixMatrixMulTBackwardLeft")
.field("gradient", &self.gradient.borrow())
.field("overwrite", &self.overwrite.get())
.finish()
}
}
impl<LhsG, RhsD> Display for MatrixMatrixMulTBackwardLeft<LhsG, RhsD>
where
RhsD: Data<Dim = Ix2>,
LhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match &*self.gradient.borrow() {
Some(gradient) => write!(f, "{}", &gradient),
None => write!(f, "None"),
}
}
}
pub struct MatrixMatrixMulTBackwardRight<LhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
gradient: RefCell<Option<Tensor<Ix2>>>,
shape: Ix2,
overwrite: Cell<bool>,
left_data: Rc<LhsD>,
right_grad: Rc<RhsG>,
}
impl<LhsD, RhsG> MatrixMatrixMulTBackwardRight<LhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
pub fn new(left_data: Rc<LhsD>, right_grad: Rc<RhsG>) -> Self {
let shape = DotDim::shape(
left_data.data().raw_dim(),
right_grad.gradient().t().raw_dim(),
);
Self {
gradient: RefCell::new(Some(Tensor::zeros(shape))),
shape,
overwrite: Cell::new(true),
left_data,
right_grad,
}
}
}
impl<LhsD, RhsG> Gradient for MatrixMatrixMulTBackwardRight<LhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
type Dim = Ix2;
fn gradient(&self) -> Ref<Tensor<Self::Dim>> {
expect_tensor(&self.gradient)
}
fn gradient_mut(&self) -> RefMut<Tensor<Self::Dim>> {
expect_tensor_mut(&self.gradient)
}
}
impl<LhsD, RhsG> Overwrite for MatrixMatrixMulTBackwardRight<LhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn can_overwrite(&self) -> bool {
self.overwrite.get()
}
fn set_overwrite(&self, state: bool) {
self.overwrite.set(state);
}
}
impl<LhsD, RhsG> Backward for MatrixMatrixMulTBackwardRight<LhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn backward(&self) {
push_mat_mat_gradient(
&*self.right_grad,
&self.gradient().t(),
&self.left_data.data(),
);
}
fn no_grad(&self) {
*self.gradient.borrow_mut() = None;
}
fn with_grad(&self) {
*self.gradient.borrow_mut() = Some(Tensor::zeros(self.shape));
}
}
impl<LhsD, RhsG> Debug for MatrixMatrixMulTBackwardRight<LhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MatrixMatrixMulTBackwardRight")
.field("gradient", &self.gradient.borrow())
.field("overwrite", &self.overwrite.get())
.finish()
}
}
impl<LhsD, RhsG> Display for MatrixMatrixMulTBackwardRight<LhsD, RhsG>
where
LhsD: Data<Dim = Ix2>,
RhsG: Gradient<Dim = Ix2> + Overwrite,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match &*self.gradient.borrow() {
Some(gradient) => write!(f, "{}", &gradient),
None => write!(f, "None"),
}
}
}
#[cfg(test)]
mod test;