use crate::autograd::Var;
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{BinaryOps, ConvOps, PaddingMode, ReduceOps, ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use std::sync::Arc;
use super::conv_common::compute_padding;
/// Differentiable 1-D convolution over `Var` operands.
///
/// Runs the forward convolution through `client.conv1d` and, when any
/// operand is tracked, attaches a `Conv1dBackward` node so gradients can
/// later flow to `input`, `weight`, and (optionally) `bias`.
pub fn var_conv1d<R, C>(
    input: &Var<R>,
    weight: &Var<R>,
    bias: Option<&Var<R>>,
    stride: usize,
    padding: PaddingMode,
    dilation: usize,
    groups: usize,
    client: &C,
) -> Result<Var<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + ConvOps<R> + TensorOps<R> + ReduceOps<R> + BinaryOps<R> + ScalarOps<R>,
    R::Client: ConvOps<R> + TensorOps<R> + ReduceOps<R> + BinaryOps<R> + ScalarOps<R>,
{
    // Forward pass.
    let bias_tensor = bias.map(|b| b.tensor());
    let output = client.conv1d(
        input.tensor(),
        weight.tensor(),
        bias_tensor,
        stride,
        padding,
        dilation,
        groups,
    )?;

    // A backward node is only needed when at least one operand is tracked.
    let tracked = input.requires_grad()
        || weight.requires_grad()
        || bias.is_some_and(|b| b.requires_grad());
    if !tracked {
        return Ok(Var::new(output, false));
    }

    // Save the forward tensors and hyper-parameters the backward pass needs.
    let grad_fn = Conv1dBackward::<R>::new(
        input.id(),
        weight.id(),
        bias.map(|b| b.id()),
        input.tensor().clone(),
        weight.tensor().clone(),
        input.tensor().shape().to_vec(),
        stride,
        padding,
        dilation,
        groups,
        input.grad_fn().cloned(),
        weight.grad_fn().cloned(),
        bias.and_then(|b| b.grad_fn().cloned()),
    );
    Ok(Var::from_op(output, Arc::new(grad_fn)))
}
/// Backward node for `var_conv1d`.
///
/// Saves the forward-pass input and weight tensors (each is needed to
/// compute the other's gradient) together with the convolution
/// hyper-parameters, and keeps the upstream grad-fns so the autograd
/// graph can be walked past this node.
pub struct Conv1dBackward<R: Runtime> {
    // Ids of the tracked operands: [input, weight], plus bias when present.
    input_ids: Vec<crate::tensor::TensorId>,
    // Forward input; consumed by the weight-gradient computation.
    saved_input: crate::tensor::Tensor<R>,
    // Forward weight; consumed by the input-gradient computation.
    saved_weight: crate::tensor::Tensor<R>,
    // Forward input shape, indexed as [batch, c_in, length] downstream.
    input_shape: Vec<usize>,
    stride: usize,
    padding: PaddingMode,
    dilation: usize,
    groups: usize,
    // Upstream grad-fns of each operand (None for graph leaves).
    input_grad_fn: Option<Arc<dyn crate::autograd::GradFn<R>>>,
    weight_grad_fn: Option<Arc<dyn crate::autograd::GradFn<R>>>,
    bias_grad_fn: Option<Arc<dyn crate::autograd::GradFn<R>>>,
}
impl<R: Runtime> Conv1dBackward<R> {
    /// Builds a backward node from the operand ids, the saved forward
    /// tensors, the convolution hyper-parameters, and the upstream
    /// grad-fns captured at forward time.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        input_id: crate::tensor::TensorId,
        weight_id: crate::tensor::TensorId,
        bias_id: Option<crate::tensor::TensorId>,
        input: crate::tensor::Tensor<R>,
        weight: crate::tensor::Tensor<R>,
        input_shape: Vec<usize>,
        stride: usize,
        padding: PaddingMode,
        dilation: usize,
        groups: usize,
        input_grad_fn: Option<Arc<dyn crate::autograd::GradFn<R>>>,
        weight_grad_fn: Option<Arc<dyn crate::autograd::GradFn<R>>>,
        bias_grad_fn: Option<Arc<dyn crate::autograd::GradFn<R>>>,
    ) -> Self {
        // Operand order is fixed: input, weight, then bias only if tracked.
        let input_ids: Vec<_> = [Some(input_id), Some(weight_id), bias_id]
            .into_iter()
            .flatten()
            .collect();
        Self {
            input_ids,
            saved_input: input,
            saved_weight: weight,
            input_shape,
            stride,
            padding,
            dilation,
            groups,
            input_grad_fn,
            weight_grad_fn,
            bias_grad_fn,
        }
    }
}
/// Gradient of conv1d with respect to its input.
///
/// Scatter-style accumulation: for every kernel tap and output position,
/// the covered input column receives `grad_output_col @ weight_tap^T`
/// for its channel group. Taps that land in the zero-padding region
/// contribute nothing.
fn conv1d_input_backward<R, C>(
    client: &C,
    grad_output: &crate::tensor::Tensor<R>,
    weight: &crate::tensor::Tensor<R>,
    input_shape: &[usize],
    stride: usize,
    padding: PaddingMode,
    dilation: usize,
    groups: usize,
) -> Result<crate::tensor::Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + BinaryOps<R> + ReduceOps<R> + ScalarOps<R>,
{
    // input_shape is [batch, c_in, length]; weight is
    // [c_out, c_in_per_group, kernel_size].
    let batch = input_shape[0];
    let _c_in = input_shape[1];
    let input_len = input_shape[2];
    let c_out = weight.shape()[0];
    let c_in_per_group = weight.shape()[1];
    let kernel_size = weight.shape()[2];
    let output_len = grad_output.shape()[2];
    let c_out_per_group = c_out / groups;
    let (pad_left, _) = compute_padding(padding, kernel_size, dilation);

    let mut d_input = crate::tensor::Tensor::<R>::zeros(
        input_shape,
        grad_output.dtype(),
        grad_output.device(),
    );
    for tap in 0..kernel_size {
        // Weight slice at this tap, squeezed to (c_out, c_in_per_group).
        let w_tap = weight.narrow(2, tap, 1)?;
        let w_tap = w_tap.squeeze(Some(2));
        for out_pos in 0..output_len {
            let padded_pos = out_pos * stride + tap * dilation;
            if padded_pos < pad_left || padded_pos >= pad_left + input_len {
                continue; // tap falls in the padding region
            }
            let in_pos = padded_pos - pad_left;
            // Output-grad column squeezed to (batch, c_out).
            let go_col = grad_output.narrow(2, out_pos, 1)?;
            let go_col = go_col.squeeze(Some(2));
            for grp in 0..groups {
                let ci0 = grp * c_in_per_group;
                let co0 = grp * c_out_per_group;
                let go_g = go_col.narrow(1, co0, c_out_per_group)?;
                let w_g = w_tap.narrow(0, co0, c_out_per_group)?;
                // (batch, c_out_pg) x (c_out_pg, c_in_pg) -> (batch, c_in_pg)
                let contrib = client.matmul(&go_g, &w_g.transpose(0, 1)?)?;
                let contrib = contrib.reshape(&[batch, c_in_per_group, 1])?;
                // Accumulate into the group's slice of input column in_pos.
                let column = d_input.narrow(2, in_pos, 1)?;
                let grp_slice = column.narrow(1, ci0, c_in_per_group)?;
                let accumulated = client.add(&grp_slice, &contrib)?;
                let column = client.slice_assign(&column, &accumulated, 1, ci0)?;
                d_input = client.slice_assign(&d_input, &column, 2, in_pos)?;
            }
        }
    }
    Ok(d_input)
}
/// Gradient of conv1d with respect to its weight.
///
/// For every (output position, kernel tap) pair that overlaps real input
/// (not padding), accumulates `grad_output_col^T @ input_col` into the
/// tap's slice of the weight gradient, one channel group at a time.
fn conv1d_weight_backward<R, C>(
    client: &C,
    grad_output: &crate::tensor::Tensor<R>,
    input: &crate::tensor::Tensor<R>,
    weight_shape: &[usize],
    stride: usize,
    padding: PaddingMode,
    dilation: usize,
    groups: usize,
) -> Result<crate::tensor::Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + BinaryOps<R> + ReduceOps<R> + ScalarOps<R>,
{
    // weight_shape is [c_out, c_in_per_group, kernel_size];
    // input is [batch, c_in, length].
    let input_len = input.shape()[2];
    let c_out = weight_shape[0];
    let c_in_per_group = weight_shape[1];
    let kernel_size = weight_shape[2];
    let output_len = grad_output.shape()[2];
    let c_out_per_group = c_out / groups;
    let (pad_left, _) = compute_padding(padding, kernel_size, dilation);

    let mut d_weight = crate::tensor::Tensor::<R>::zeros(
        weight_shape,
        grad_output.dtype(),
        grad_output.device(),
    );
    for out_pos in 0..output_len {
        for tap in 0..kernel_size {
            let padded_pos = out_pos * stride + tap * dilation;
            if padded_pos < pad_left || padded_pos >= pad_left + input_len {
                continue; // tap falls in the padding region
            }
            let in_pos = padded_pos - pad_left;
            // Input and output-grad columns squeezed to 2-D (batch, channels).
            let x_col = input.narrow(2, in_pos, 1)?;
            let x_col = x_col.squeeze(Some(2));
            let go_col = grad_output.narrow(2, out_pos, 1)?;
            let go_col = go_col.squeeze(Some(2));
            for grp in 0..groups {
                let ci0 = grp * c_in_per_group;
                let co0 = grp * c_out_per_group;
                let x_g = x_col.narrow(1, ci0, c_in_per_group)?;
                let go_g = go_col.narrow(1, co0, c_out_per_group)?;
                // (c_out_pg, batch) x (batch, c_in_pg) -> (c_out_pg, c_in_pg)
                let contrib = client.matmul(&go_g.transpose(0, 1)?, &x_g)?;
                let contrib = contrib.reshape(&[c_out_per_group, c_in_per_group, 1])?;
                // Accumulate into the group's slice of the tap's gradient.
                let tap_slice = d_weight.narrow(2, tap, 1)?;
                let grp_slice = tap_slice.narrow(0, co0, c_out_per_group)?;
                let accumulated = client.add(&grp_slice, &contrib)?;
                let tap_slice = client.slice_assign(&tap_slice, &accumulated, 0, co0)?;
                d_weight = client.slice_assign(&d_weight, &tap_slice, 2, tap)?;
            }
        }
    }
    Ok(d_weight)
}
impl<R: Runtime<DType = DType>> crate::autograd::GradFn<R> for Conv1dBackward<R>
where
    R::Client: ConvOps<R> + TensorOps<R> + ReduceOps<R> + BinaryOps<R> + ScalarOps<R>,
{
    /// Computes gradients in operand order: [d_input, d_weight, d_bias].
    ///
    /// NOTE(review): the returned Vec always has three slots (the bias slot
    /// is `None` when no bias was tracked) while `inputs()` then reports
    /// only two ids — confirm the autograd engine tolerates the trailing
    /// `None`.
    fn backward(
        &self,
        grad_output: &crate::tensor::Tensor<R>,
    ) -> Result<Vec<Option<crate::tensor::Tensor<R>>>> {
        let client = R::default_client(grad_output.device());
        // d_input: scatter grad_output against the saved weight.
        let d_input = conv1d_input_backward::<R, _>(
            &client,
            grad_output,
            &self.saved_weight,
            &self.input_shape,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )?;
        // d_weight: correlate grad_output with the saved input.
        let d_weight = conv1d_weight_backward::<R, _>(
            &client,
            grad_output,
            &self.saved_input,
            self.saved_weight.shape(),
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )?;
        // A third id exists only when a bias was tracked (see `new`); its
        // gradient is grad_output reduced over dims 0 and 2 (batch, length).
        let d_bias = if self.input_ids.len() > 2 {
            let summed = client.sum(grad_output, &[0, 2], false)?;
            Some(summed)
        } else {
            None
        };
        Ok(vec![Some(d_input), Some(d_weight), d_bias])
    }
    /// Var-level wrapper around `backward`. The produced Vars are fresh
    /// tracked leaves (`requires_grad = true`) — they are not connected to
    /// any upstream graph.
    fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
    where
        R::Client: RuntimeClient<R>
            + ConvOps<R>
            + TensorOps<R>
            + ReduceOps<R>
            + BinaryOps<R>
            + ScalarOps<R>,
    {
        let grads = self.backward(grad_output.tensor())?;
        Ok(grads
            .into_iter()
            .map(|g| g.map(|t| Var::new(t, true)))
            .collect())
    }
    /// Ids of the tracked operands, in the same order `backward` emits
    /// gradients (bias id present only when a bias was tracked).
    fn inputs(&self) -> &[crate::tensor::TensorId] {
        &self.input_ids
    }
    /// Upstream grad-fns, parallel to `inputs()` (2 or 3 entries).
    fn input_grad_fns(&self) -> Vec<Option<Arc<dyn crate::autograd::GradFn<R>>>> {
        let mut fns = vec![self.input_grad_fn.clone(), self.weight_grad_fn.clone()];
        if self.input_ids.len() > 2 {
            fns.push(self.bias_grad_fn.clone());
        }
        fns
    }
    /// NOTE(review): only `saved_input` is reported here even though the
    /// node also retains `saved_weight` — confirm against the `GradFn`
    /// contract whether saved-tensor reporting is meant to be exhaustive.
    fn saved_tensors(&self) -> &[crate::tensor::Tensor<R>] {
        std::slice::from_ref(&self.saved_input)
    }
    fn name(&self) -> &'static str {
        "Conv1dBackward"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::backward;
    use crate::runtime::cpu::{CpuDevice, CpuRuntime};
    use crate::tensor::Tensor;

    /// A 1x1 kernel with weight 2.0 simply doubles every sample.
    #[test]
    fn test_var_conv1d_forward() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 1, 3], &device),
            false,
        );
        let w = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1, 1, 1], &device),
            false,
        );
        let y = var_conv1d(&x, &w, None, 1, PaddingMode::Valid, 1, 1, &client).unwrap();
        let got: Vec<f32> = y.tensor().to_vec();
        assert_eq!(got, vec![2.0, 4.0, 6.0]);
    }

    /// For loss = sum(conv(x, w)) with a 1x1 kernel: d/dx is the kernel
    /// value at every position, d/dw is the sum of the input.
    #[test]
    fn test_var_conv1d_backward_input() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 1, 3], &device),
            true,
        );
        let w = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1, 1, 1], &device),
            true,
        );
        let y = var_conv1d(&x, &w, None, 1, PaddingMode::Valid, 1, 1, &client).unwrap();
        let loss = crate::autograd::var_sum(&y, &[], false, &client).unwrap();
        let grads = backward(&loss, &client).unwrap();
        let dx: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
        assert_eq!(dx, vec![2.0, 2.0, 2.0]);
        let dw: Vec<f32> = grads.get(w.id()).unwrap().to_vec();
        assert!((dw[0] - 6.0).abs() < 1e-5);
    }

    /// The bias gradient is grad_output summed over batch and length;
    /// here two output positions with unit upstream grad give 2.0.
    #[test]
    fn test_var_conv1d_backward_with_bias() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[1, 1, 2], &device),
            true,
        );
        let w = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1, 1, 1], &device),
            true,
        );
        let b = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[10.0f32], &[1], &device),
            true,
        );
        let y = var_conv1d(
            &x,
            &w,
            Some(&b),
            1,
            PaddingMode::Valid,
            1,
            1,
            &client,
        )
        .unwrap();
        let loss = crate::autograd::var_sum(&y, &[], false, &client).unwrap();
        let grads = backward(&loss, &client).unwrap();
        let db: Vec<f32> = grads.get(b.id()).unwrap().to_vec();
        assert!((db[0] - 2.0).abs() < 1e-5);
    }

    /// Width-3 box kernel: the forward pass is a sliding-window sum, and
    /// the input gradient counts how many windows cover each position
    /// (1, 2, 3, 2, 1 for length 5, kernel 3, valid padding).
    #[test]
    fn test_var_conv1d_kernel3() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5], &device),
            true,
        );
        let w = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device),
            true,
        );
        let y = var_conv1d(&x, &w, None, 1, PaddingMode::Valid, 1, 1, &client).unwrap();
        let got: Vec<f32> = y.tensor().to_vec();
        assert_eq!(got, vec![6.0, 9.0, 12.0]);
        let loss = crate::autograd::var_sum(&y, &[], false, &client).unwrap();
        let grads = backward(&loss, &client).unwrap();
        let dx: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
        assert_eq!(dx, vec![1.0, 2.0, 3.0, 2.0, 1.0]);
    }
}