use alloc::vec;

use burn_std::Shape;

use crate::tensor::FloatTensor;
use crate::{Backend, TensorMetadata};

/// Applies a linear transformation to the input tensor:
/// `output = x @ weight` plus the optional `bias`.
///
/// Shapes: `x` is `[..., d_input]`, `weight` is `[d_input, d_output]`,
/// `bias` is `[d_output]`, and the output is `[..., d_output]`.
pub(crate) fn linear<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
) -> FloatTensor<B> {
    let x_ndims = x.shape().num_dims();

    // Match the weight's rank to the input's so the matmul broadcasts it
    // over the leading (batch) dimensions.
    let weight = unsqueeze_leading::<B>(weight, x_ndims);
    let output = B::float_matmul(x, weight);

    match bias {
        Some(bias) => {
            // The 1D bias likewise broadcasts over all leading dimensions.
            let bias = unsqueeze_leading::<B>(bias, x_ndims);
            B::float_add(output, bias)
        }
        None => output,
    }
}
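
// A worked shape example for `linear` (illustrative, assuming a rank-3
// input): `x: [batch, seq, d_input]` is multiplied by the weight unsqueezed
// to `[1, d_input, d_output]`, yielding `[batch, seq, d_output]`; the bias,
// unsqueezed to `[1, 1, d_output]`, then broadcasts over the first two
// dimensions of the result.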

/// Prepends singleton dimensions to `tensor` until it has `target_ndims`
/// dimensions; tensors that already have enough dimensions are returned
/// unchanged.
fn unsqueeze_leading<B: Backend>(tensor: FloatTensor<B>, target_ndims: usize) -> FloatTensor<B> {
    let shape = tensor.shape();
    let ndims = shape.num_dims();

    if ndims >= target_ndims {
        return tensor;
    }

    // Build the new shape as `[1, ..., 1, dims...]`.
    let mut new_dims = vec![1usize; target_ndims - ndims];
    for i in 0..ndims {
        new_dims.push(shape[i]);
    }

    B::float_reshape(tensor, Shape::from(new_dims))
}
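
// Example (illustrative): with `target_ndims == 4`, `unsqueeze_leading` maps
// a `[d_input, d_output]` weight to `[1, 1, d_input, d_output]`, which is
// exactly the rank needed to broadcast against a `[b0, b1, m, d_input]` input.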

/// Computes the gradient of `linear` with respect to the input:
/// `x_grad = output_grad @ weight^T`.
pub(crate) fn linear_x_backward<B: Backend>(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    // Transpose the weight from `[d_input, d_output]` to `[d_output, d_input]`.
    let weight = B::float_swap_dims(weight, 0, 1);

    let grad_ndims = output_grad.shape().num_dims();
    let weight = unsqueeze_leading::<B>(weight, grad_ndims);

    B::float_matmul(output_grad, weight)
}
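
// Derivation note (sketch): for `y = x @ W`, the chain rule gives
// `dL/dx = dL/dy @ W^T`; the transposed, broadcast matmul above computes
// exactly that for every leading batch dimension of `output_grad`.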

/// Computes the gradient of `linear` with respect to the weight:
/// `weight_grad = x^T @ output_grad`, reduced over all leading (batch)
/// dimensions so the result matches the weight's `[d_input, d_output]` shape.
pub(crate) fn linear_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    // Transpose the last two dimensions of `x` before the matmul.
    let ndims = x.shape().num_dims();
    let x = B::float_swap_dims(x, ndims - 2, ndims - 1);
    let mut grad = B::float_matmul(x, output_grad);

    let ndims = grad.shape().num_dims();
    if ndims > 2 {
        // Reduce the broadcast (batch) dimensions; `float_sum_dim` keeps each
        // reduced dimension with size 1, so drop them with a final reshape.
        for dim in 0..ndims - 2 {
            grad = B::float_sum_dim(grad, dim);
        }
        let shape = grad.shape();
        let d0 = shape[ndims - 2];
        let d1 = shape[ndims - 1];
        B::float_reshape(grad, Shape::new([d0, d1]))
    } else {
        grad
    }
}
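
// Derivation note (sketch): for `y = x @ W` with `W` shared across all batch
// elements, `dL/dW = sum_batch(x^T @ dL/dy)`; the per-batch matmul followed
// by the sum over leading dimensions implements exactly that reduction.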

/// Computes the gradient of `linear` with respect to the bias by summing
/// `output_grad` over every dimension except the last, yielding `[d_output]`.
pub(crate) fn linear_bias_backward<B: Backend>(output_grad: FloatTensor<B>) -> FloatTensor<B> {
    let ndims = output_grad.shape().num_dims();

    let mut grad = output_grad;
    for dim in 0..ndims - 1 {
        grad = B::float_sum_dim(grad, dim);
    }

    // The summed dimensions remain with size 1; flatten them away.
    let shape = grad.shape();
    let d_output = shape[ndims - 1];
    B::float_reshape(grad, Shape::new([d_output]))
}
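
// A minimal, self-contained sanity check of the shape bookkeeping above.
// This is a sketch: it mirrors `unsqueeze_leading`'s dim arithmetic on plain
// vectors instead of driving a real backend, and `unsqueeze_leading_dims` is
// a hypothetical helper that exists only for these tests.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    /// Mirrors the dim arithmetic of `unsqueeze_leading` on raw dim vectors.
    fn unsqueeze_leading_dims(dims: &[usize], target_ndims: usize) -> Vec<usize> {
        if dims.len() >= target_ndims {
            return dims.to_vec();
        }
        let mut out = vec![1usize; target_ndims - dims.len()];
        out.extend_from_slice(dims);
        out
    }

    #[test]
    fn weight_broadcasts_against_rank_3_input() {
        // [d_input, d_output] -> [1, d_input, d_output]
        assert_eq!(unsqueeze_leading_dims(&[8, 16], 3), vec![1, 8, 16]);
    }

    #[test]
    fn bias_broadcasts_against_rank_3_input() {
        // [d_output] -> [1, 1, d_output]
        assert_eq!(unsqueeze_leading_dims(&[16], 3), vec![1, 1, 16]);
    }

    #[test]
    fn higher_rank_tensors_are_left_unchanged() {
        assert_eq!(unsqueeze_leading_dims(&[2, 3, 4], 3), vec![2, 3, 4]);
    }
}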