use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::conv_common::{validate_conv1d, validate_conv2d, validate_depthwise_conv2d};
use crate::ops::{ConvOps, PaddingMode};
#[cfg(target_arch = "x86_64")]
use crate::runtime::cpu::kernels::simd::conv as simd_conv;
use crate::runtime::cpu::{CpuClient, CpuRuntime, helpers::ensure_contiguous, kernels};
use crate::tensor::Tensor;
macro_rules! dispatch_float_dtype {
($dtype:expr, $T:ident => $body:block, $op:expr) => {
match $dtype {
DType::F32 => {
type $T = f32;
$body
}
DType::F64 => {
type $T = f64;
$body
}
#[cfg(feature = "f16")]
DType::F16 => {
type $T = half::f16;
$body
}
#[cfg(feature = "f16")]
DType::BF16 => {
type $T = half::bf16;
$body
}
#[cfg(feature = "fp8")]
DType::FP8E4M3 => {
type $T = crate::dtype::FP8E4M3;
$body
}
#[cfg(feature = "fp8")]
DType::FP8E5M2 => {
type $T = crate::dtype::FP8E5M2;
$body
}
_ => {
return Err(Error::UnsupportedDType {
dtype: $dtype,
op: $op,
})
}
}
};
}
macro_rules! dispatch_conv {
($dtype:expr, $conv_name:ident, $input_ptr:expr, $weight_ptr:expr, $bias_ptr:expr, $output_ptr:expr, $params:expr) => {
paste::paste! {
match $dtype {
#[cfg(target_arch = "x86_64")]
DType::F32 => unsafe {
simd_conv::[<$conv_name _f32>](
$input_ptr as *const f32,
$weight_ptr as *const f32,
$bias_ptr.map(|p| p as *const f32),
$output_ptr as *mut f32,
$params,
);
},
#[cfg(target_arch = "x86_64")]
DType::F64 => unsafe {
simd_conv::[<$conv_name _f64>](
$input_ptr as *const f64,
$weight_ptr as *const f64,
$bias_ptr.map(|p| p as *const f64),
$output_ptr as *mut f64,
$params,
);
},
_ => {
dispatch_float_dtype!($dtype, T => {
unsafe {
kernels::[<$conv_name _kernel>]::<T>(
$input_ptr as *const T,
$weight_ptr as *const T,
$bias_ptr.map(|p| p as *const T),
$output_ptr as *mut T,
$params,
);
}
}, stringify!($conv_name));
}
}
}
};
}
impl ConvOps<CpuRuntime> for CpuClient {
fn conv1d(
&self,
input: &Tensor<CpuRuntime>,
weight: &Tensor<CpuRuntime>,
bias: Option<&Tensor<CpuRuntime>>,
stride: usize,
padding: PaddingMode,
dilation: usize,
groups: usize,
) -> Result<Tensor<CpuRuntime>> {
let dtype = input.dtype();
let params = validate_conv1d(
input.shape(),
weight.shape(),
bias.map(|b| b.shape()),
stride,
padding,
dilation,
groups,
dtype,
weight.dtype(),
bias.map(|b| b.dtype()),
)?;
if params.output_length == 0 || params.batch == 0 {
return Ok(Tensor::<CpuRuntime>::empty(
&[params.batch, params.c_out, params.output_length],
dtype,
&self.device,
));
}
let input = ensure_contiguous(input)?;
let weight = ensure_contiguous(weight)?;
let bias = bias.map(ensure_contiguous).transpose()?;
let output = Tensor::<CpuRuntime>::empty(
&[params.batch, params.c_out, params.output_length],
dtype,
&self.device,
);
let input_ptr = input.ptr();
let weight_ptr = weight.ptr();
let bias_ptr = bias.as_ref().map(|b| b.ptr());
let output_ptr = output.ptr();
dispatch_conv!(
dtype, conv1d, input_ptr, weight_ptr, bias_ptr, output_ptr, params
);
Ok(output)
}
fn conv2d(
&self,
input: &Tensor<CpuRuntime>,
weight: &Tensor<CpuRuntime>,
bias: Option<&Tensor<CpuRuntime>>,
stride: (usize, usize),
padding: PaddingMode,
dilation: (usize, usize),
groups: usize,
) -> Result<Tensor<CpuRuntime>> {
let dtype = input.dtype();
let params = validate_conv2d(
input.shape(),
weight.shape(),
bias.map(|b| b.shape()),
stride,
padding,
dilation,
groups,
dtype,
weight.dtype(),
bias.map(|b| b.dtype()),
)?;
if params.output_h == 0 || params.output_w == 0 || params.batch == 0 {
return Ok(Tensor::<CpuRuntime>::empty(
&[params.batch, params.c_out, params.output_h, params.output_w],
dtype,
&self.device,
));
}
let input = ensure_contiguous(input)?;
let weight = ensure_contiguous(weight)?;
let bias = bias.map(ensure_contiguous).transpose()?;
let output = Tensor::<CpuRuntime>::empty(
&[params.batch, params.c_out, params.output_h, params.output_w],
dtype,
&self.device,
);
let input_ptr = input.ptr();
let weight_ptr = weight.ptr();
let bias_ptr = bias.as_ref().map(|b| b.ptr());
let output_ptr = output.ptr();
dispatch_conv!(
dtype, conv2d, input_ptr, weight_ptr, bias_ptr, output_ptr, params
);
Ok(output)
}
fn depthwise_conv2d(
&self,
input: &Tensor<CpuRuntime>,
weight: &Tensor<CpuRuntime>,
bias: Option<&Tensor<CpuRuntime>>,
stride: (usize, usize),
padding: PaddingMode,
dilation: (usize, usize),
) -> Result<Tensor<CpuRuntime>> {
let dtype = input.dtype();
let params = validate_depthwise_conv2d(
input.shape(),
weight.shape(),
bias.map(|b| b.shape()),
stride,
padding,
dilation,
dtype,
weight.dtype(),
bias.map(|b| b.dtype()),
)?;
if params.output_h == 0 || params.output_w == 0 || params.batch == 0 {
return Ok(Tensor::<CpuRuntime>::empty(
&[params.batch, params.c_out, params.output_h, params.output_w],
dtype,
&self.device,
));
}
let input = ensure_contiguous(input)?;
let weight = ensure_contiguous(weight)?;
let bias = bias.map(ensure_contiguous).transpose()?;
let output = Tensor::<CpuRuntime>::empty(
&[params.batch, params.c_out, params.output_h, params.output_w],
dtype,
&self.device,
);
let input_ptr = input.ptr();
let weight_ptr = weight.ptr();
let bias_ptr = bias.as_ref().map(|b| b.ptr());
let output_ptr = output.ptr();
dispatch_conv!(
dtype,
depthwise_conv2d,
input_ptr,
weight_ptr,
bias_ptr,
output_ptr,
params
);
Ok(output)
}
fn conv_transpose1d(
&self,
input: &Tensor<CpuRuntime>,
weight: &Tensor<CpuRuntime>,
bias: Option<&Tensor<CpuRuntime>>,
stride: usize,
padding: PaddingMode,
output_padding: usize,
dilation: usize,
groups: usize,
) -> Result<Tensor<CpuRuntime>> {
conv_transpose1d_cpu(
&self.device,
input,
weight,
bias,
stride,
padding,
output_padding,
dilation,
groups,
)
}
}
fn conv_transpose1d_cpu(
device: &<CpuRuntime as crate::runtime::Runtime>::Device,
input: &Tensor<CpuRuntime>,
weight: &Tensor<CpuRuntime>,
bias: Option<&Tensor<CpuRuntime>>,
stride: usize,
padding: PaddingMode,
output_padding: usize,
dilation: usize,
groups: usize,
) -> Result<Tensor<CpuRuntime>> {
use crate::ops::conv_common;
let op = "conv_transpose1d";
let dtype = input.dtype();
conv_common::validate_3d_tensor(input.shape(), "input", op)?;
conv_common::validate_3d_tensor(weight.shape(), "weight", op)?;
conv_common::validate_float_dtype(dtype, op)?;
conv_common::validate_same_dtype(dtype, weight.dtype(), op)?;
conv_common::validate_positive(stride, "stride", op)?;
conv_common::validate_positive(dilation, "dilation", op)?;
conv_common::validate_positive(groups, "groups", op)?;
if stride > 0 && output_padding >= stride {
return Err(Error::InvalidArgument {
arg: "output_padding",
reason: format!("must be < stride ({stride}), got {output_padding}"),
});
}
let (b, c_in, l_in) = (input.shape()[0], input.shape()[1], input.shape()[2]);
let (w_c_in, c_out_per_group, kernel) =
(weight.shape()[0], weight.shape()[1], weight.shape()[2]);
if w_c_in != c_in {
return Err(Error::InvalidArgument {
arg: "weight",
reason: format!(
"expected weight shape [{c_in}, C_out/groups, K], got {:?}",
weight.shape()
),
});
}
if c_in % groups != 0 {
return Err(Error::InvalidArgument {
arg: "groups",
reason: format!("C_in ({c_in}) not divisible by groups ({groups})"),
});
}
let c_out = c_out_per_group * groups;
if let Some(b_tensor) = bias {
conv_common::validate_1d_tensor(b_tensor.shape(), "bias", op)?;
conv_common::validate_same_dtype(dtype, b_tensor.dtype(), op)?;
conv_common::validate_bias_length(b_tensor.shape()[0], c_out, op)?;
}
let (pad_left, pad_right) = match padding {
PaddingMode::Valid => (0usize, 0usize),
PaddingMode::Same => {
let desired = l_in * stride;
let unpadded = (l_in - 1) * stride + dilation * (kernel - 1) + output_padding + 1;
if unpadded < desired {
return Err(Error::InvalidArgument {
arg: "padding",
reason: "PaddingMode::Same cannot shrink output below target length".into(),
});
}
let total = unpadded - desired;
(total / 2, total - total / 2)
}
PaddingMode::Custom(left, right, _, _) => (left, right),
};
let unpadded_len = (l_in - 1) * stride + dilation * (kernel - 1) + output_padding + 1;
if pad_left + pad_right > unpadded_len {
return Err(Error::InvalidArgument {
arg: "padding",
reason: format!(
"total padding ({}) exceeds raw output length ({unpadded_len})",
pad_left + pad_right
),
});
}
let l_out = unpadded_len - pad_left - pad_right;
if b == 0 || l_out == 0 {
return Ok(Tensor::<CpuRuntime>::empty(
&[b, c_out, l_out],
dtype,
device,
));
}
let input = ensure_contiguous(input)?;
let weight = ensure_contiguous(weight)?;
let bias = bias.map(ensure_contiguous).transpose()?;
let output = Tensor::<CpuRuntime>::empty(&[b, c_out, l_out], dtype, device);
let input_ptr = input.ptr();
let weight_ptr = weight.ptr();
let bias_ptr = bias.as_ref().map(|b| b.ptr());
let output_ptr = output.ptr();
dispatch_float_dtype!(dtype, T => {
unsafe {
conv_transpose1d_scalar::<T>(
input_ptr as *const T,
weight_ptr as *const T,
bias_ptr.map(|p| p as *const T),
output_ptr as *mut T,
b, c_in, c_out, l_in, l_out, kernel,
stride, dilation, pad_left, groups,
);
}
}, op);
Ok(output)
}
#[allow(clippy::too_many_arguments)]
unsafe fn conv_transpose1d_scalar<T>(
input: *const T,
weight: *const T,
bias: Option<*const T>,
output: *mut T,
batch: usize,
c_in: usize,
c_out: usize,
l_in: usize,
l_out: usize,
kernel: usize,
stride: usize,
dilation: usize,
pad_left: usize,
groups: usize,
) where
T: Copy + Default + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
let c_in_per_group = c_in / groups;
let c_out_per_group = c_out / groups;
unsafe {
for i in 0..(batch * c_out * l_out) {
*output.add(i) = T::default();
}
}
let pad_left_i = pad_left as isize;
for b in 0..batch {
for g in 0..groups {
for ci in 0..c_in_per_group {
let c_in_abs = g * c_in_per_group + ci;
let input_row_base = (b * c_in + c_in_abs) * l_in;
for l in 0..l_in {
let x = unsafe { *input.add(input_row_base + l) };
let base_out_pos = l as isize * stride as isize - pad_left_i;
for co in 0..c_out_per_group {
let c_out_abs = g * c_out_per_group + co;
let weight_row_base = (c_in_abs * c_out_per_group + co) * kernel;
let output_row_base = (b * c_out + c_out_abs) * l_out;
for k in 0..kernel {
let out_pos = base_out_pos + (k * dilation) as isize;
if out_pos < 0 || out_pos >= l_out as isize {
continue;
}
let w = unsafe { *weight.add(weight_row_base + k) };
let out_idx = output_row_base + out_pos as usize;
unsafe {
let cur = *output.add(out_idx);
*output.add(out_idx) = cur + w * x;
}
}
}
}
}
}
}
if let Some(bias_ptr) = bias {
for b in 0..batch {
for co in 0..c_out {
let bias_val = unsafe { *bias_ptr.add(co) };
let base = (b * c_out + co) * l_out;
for l in 0..l_out {
unsafe {
let cur = *output.add(base + l);
*output.add(base + l) = cur + bias_val;
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ops::RandomOps;
use crate::runtime::Runtime;
use crate::runtime::cpu::CpuDevice;
fn setup() -> (CpuDevice, CpuClient) {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
(device, client)
}
#[test]
fn test_conv1d_basic() {
let (device, client) = setup();
let input =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device);
let output = client
.conv1d(&input, &weight, None, 1, PaddingMode::Valid, 1, 1)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 3]);
let data: Vec<f32> = output.to_vec();
assert!((data[0] - 6.0).abs() < 1e-5); assert!((data[1] - 9.0).abs() < 1e-5); assert!((data[2] - 12.0).abs() < 1e-5); }
#[test]
fn test_conv1d_same_padding() {
let (device, client) = setup();
let input =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device);
let output = client
.conv1d(&input, &weight, None, 1, PaddingMode::Same, 1, 1)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 5]);
}
#[test]
fn test_conv1d_with_bias() {
let (device, client) = setup();
let input =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device);
let bias = Tensor::<CpuRuntime>::from_slice(&[10.0f32], &[1], &device);
let output = client
.conv1d(&input, &weight, Some(&bias), 1, PaddingMode::Valid, 1, 1)
.unwrap();
let data: Vec<f32> = output.to_vec();
assert!((data[0] - 16.0).abs() < 1e-5); assert!((data[1] - 19.0).abs() < 1e-5); assert!((data[2] - 22.0).abs() < 1e-5); }
#[test]
fn test_conv1d_stride() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
&[1, 1, 7],
&device,
);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device);
let output = client
.conv1d(&input, &weight, None, 2, PaddingMode::Valid, 1, 1)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 3]);
}
#[test]
fn test_conv2d_basic() {
let (device, client) = setup();
#[rustfmt::skip]
let input_data = [
1.0f32, 2.0, 3.0,
4.0, 5.0, 6.0,
7.0, 8.0, 9.0,
];
let input = Tensor::<CpuRuntime>::from_slice(&input_data, &[1, 1, 3, 3], &device);
let weight =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[1, 1, 2, 2], &device);
let output = client
.conv2d(&input, &weight, None, (1, 1), PaddingMode::Valid, (1, 1), 1)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 2, 2]);
let data: Vec<f32> = output.to_vec();
assert!((data[0] - 12.0).abs() < 1e-5); assert!((data[1] - 16.0).abs() < 1e-5); assert!((data[2] - 24.0).abs() < 1e-5); assert!((data[3] - 28.0).abs() < 1e-5); }
#[test]
fn test_conv2d_same_padding() {
let (device, client) = setup();
#[rustfmt::skip]
let input_data = [
1.0f32, 2.0, 3.0,
4.0, 5.0, 6.0,
7.0, 8.0, 9.0,
];
let input = Tensor::<CpuRuntime>::from_slice(&input_data, &[1, 1, 3, 3], &device);
let weight =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[1, 1, 2, 2], &device);
let output = client
.conv2d(&input, &weight, None, (1, 1), PaddingMode::Same, (1, 1), 1)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 3, 3]);
}
#[test]
fn test_conv2d_multi_channel() {
let (_device, client) = setup();
let input = client.randn(&[1, 2, 4, 4], DType::F32).unwrap();
let weight = client.randn(&[3, 2, 3, 3], DType::F32).unwrap();
let output = client
.conv2d(&input, &weight, None, (1, 1), PaddingMode::Valid, (1, 1), 1)
.unwrap();
assert_eq!(output.shape(), &[1, 3, 2, 2]);
}
#[test]
fn test_conv2d_grouped() {
let (_device, client) = setup();
let input = client.randn(&[1, 4, 4, 4], DType::F32).unwrap();
let weight = client.randn(&[6, 2, 3, 3], DType::F32).unwrap();
let output = client
.conv2d(
&input,
&weight,
None,
(1, 1),
PaddingMode::Valid,
(1, 1),
2, )
.unwrap();
assert_eq!(output.shape(), &[1, 6, 2, 2]);
}
#[test]
fn test_depthwise_conv2d() {
let (device, client) = setup();
#[rustfmt::skip]
let input_data = [
1.0f32, 2.0, 3.0,
4.0, 5.0, 6.0,
7.0, 8.0, 9.0,
9.0, 8.0, 7.0,
6.0, 5.0, 4.0,
3.0, 2.0, 1.0,
];
let input = Tensor::<CpuRuntime>::from_slice(&input_data, &[1, 2, 3, 3], &device);
let weight = Tensor::<CpuRuntime>::from_slice(
&[
1.0f32, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, ],
&[2, 1, 2, 2],
&device,
);
let output = client
.depthwise_conv2d(&input, &weight, None, (1, 1), PaddingMode::Valid, (1, 1))
.unwrap();
assert_eq!(output.shape(), &[1, 2, 2, 2]);
let data: Vec<f32> = output.to_vec();
assert!((data[0] - 12.0).abs() < 1e-5);
assert!((data[1] - 16.0).abs() < 1e-5);
assert!((data[2] - 24.0).abs() < 1e-5);
assert!((data[3] - 28.0).abs() < 1e-5);
assert!((data[4] - 56.0).abs() < 1e-5);
assert!((data[5] - 48.0).abs() < 1e-5);
assert!((data[6] - 32.0).abs() < 1e-5);
assert!((data[7] - 24.0).abs() < 1e-5);
}
#[test]
fn test_depthwise_conv2d_same_padding() {
let (_device, client) = setup();
let input = client.randn(&[2, 32, 28, 28], DType::F32).unwrap();
let weight = client.randn(&[32, 1, 3, 3], DType::F32).unwrap();
let output = client
.depthwise_conv2d(&input, &weight, None, (1, 1), PaddingMode::Same, (1, 1))
.unwrap();
assert_eq!(output.shape(), &[2, 32, 28, 28]);
}
#[test]
fn test_conv2d_dilation() {
let (_device, client) = setup();
let input = client.randn(&[1, 1, 5, 5], DType::F32).unwrap();
let weight = client.randn(&[1, 1, 3, 3], DType::F32).unwrap();
let output = client
.conv2d(
&input,
&weight,
None,
(1, 1),
PaddingMode::Valid,
(2, 2), 1,
)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 1, 1]);
}
#[test]
fn test_conv1d_invalid_dimensions() {
let (_device, client) = setup();
let input = client.randn(&[3, 10], DType::F32).unwrap();
let weight = client.randn(&[16, 3, 3], DType::F32).unwrap();
let result = client.conv1d(&input, &weight, None, 1, PaddingMode::Valid, 1, 1);
assert!(result.is_err());
}
#[test]
fn test_conv2d_invalid_groups() {
let (_device, client) = setup();
let input = client.randn(&[1, 5, 8, 8], DType::F32).unwrap();
let weight = client.randn(&[10, 2, 3, 3], DType::F32).unwrap();
let result = client.conv2d(
&input,
&weight,
None,
(1, 1),
PaddingMode::Valid,
(1, 1),
2, );
assert!(result.is_err());
}
#[test]
fn test_conv2d_batch() {
let (_device, client) = setup();
let input = client.randn(&[4, 3, 8, 8], DType::F32).unwrap();
let weight = client.randn(&[16, 3, 3, 3], DType::F32).unwrap();
let output = client
.conv2d(&input, &weight, None, (1, 1), PaddingMode::Valid, (1, 1), 1)
.unwrap();
assert_eq!(output.shape(), &[4, 16, 6, 6]);
}
#[test]
fn test_conv2d_f64() {
let (_device, client) = setup();
let input = client.randn(&[1, 1, 4, 4], DType::F64).unwrap();
let weight = client.randn(&[1, 1, 2, 2], DType::F64).unwrap();
let output = client
.conv2d(&input, &weight, None, (1, 1), PaddingMode::Valid, (1, 1), 1)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 3, 3]);
assert_eq!(output.dtype(), DType::F64);
}
#[test]
fn conv_transpose1d_stride1_valid_roundtrip() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 1, 3], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device);
let out = client
.conv_transpose1d(&input, &weight, None, 1, PaddingMode::Valid, 0, 1, 1)
.unwrap();
assert_eq!(out.shape(), &[1, 1, 5]);
let d: Vec<f32> = out.to_vec();
assert_eq!(d, vec![1.0, 3.0, 6.0, 5.0, 3.0]);
}
#[test]
fn conv_transpose1d_stride2_upsamples() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[1, 1, 2], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device);
let out = client
.conv_transpose1d(&input, &weight, None, 2, PaddingMode::Valid, 0, 1, 1)
.unwrap();
assert_eq!(out.shape(), &[1, 1, 5]);
let d: Vec<f32> = out.to_vec();
assert_eq!(d, vec![1.0, 1.0, 3.0, 2.0, 2.0]);
}
#[test]
fn conv_transpose1d_bias_is_added_per_channel() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[1, 1, 2], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[1, 2, 1], &device);
let bias = Tensor::<CpuRuntime>::from_slice(&[10.0f32, 100.0], &[2], &device);
let out = client
.conv_transpose1d(&input, &weight, Some(&bias), 1, PaddingMode::Valid, 0, 1, 1)
.unwrap();
assert_eq!(out.shape(), &[1, 2, 2]);
let d: Vec<f32> = out.to_vec();
assert_eq!(d, vec![11.0, 11.0, 101.0, 101.0]);
}
#[test]
fn conv_transpose1d_custom_padding_crops_output() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 1, 3], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device);
let out = client
.conv_transpose1d(
&input,
&weight,
None,
1,
PaddingMode::Custom(1, 1, 0, 0),
0,
1,
1,
)
.unwrap();
assert_eq!(out.shape(), &[1, 1, 3]);
let d: Vec<f32> = out.to_vec();
assert_eq!(d, vec![3.0, 6.0, 5.0]);
}
#[test]
fn conv_transpose1d_multichannel_accumulates_c_in() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[1, 2, 2], &device);
let weight =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 2.0, 2.0], &[2, 1, 2], &device);
let out = client
.conv_transpose1d(&input, &weight, None, 1, PaddingMode::Valid, 0, 1, 1)
.unwrap();
assert_eq!(out.shape(), &[1, 1, 3]);
let d: Vec<f32> = out.to_vec();
assert_eq!(d, vec![1.0, 3.0, 2.0]);
}
#[test]
fn conv_transpose1d_f64_output_dtype() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0], &[1, 1, 2], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 1.0], &[1, 1, 2], &device);
let out = client
.conv_transpose1d(&input, &weight, None, 1, PaddingMode::Valid, 0, 1, 1)
.unwrap();
assert_eq!(out.dtype(), DType::F64);
assert_eq!(out.shape(), &[1, 1, 3]);
let d: Vec<f64> = out.to_vec();
assert_eq!(d, vec![1.0, 3.0, 2.0]);
}
#[test]
fn conv_transpose1d_rejects_bad_group_division() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(&[0.0f32; 3], &[1, 3, 1], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[0.0f32; 6], &[3, 2, 1], &device);
let res = client.conv_transpose1d(&input, &weight, None, 1, PaddingMode::Valid, 0, 1, 2);
assert!(res.is_err());
}
#[test]
fn conv_transpose1d_rejects_output_padding_ge_stride() {
let (device, client) = setup();
let input = Tensor::<CpuRuntime>::from_slice(&[0.0f32; 2], &[1, 1, 2], &device);
let weight = Tensor::<CpuRuntime>::from_slice(&[0.0f32; 2], &[1, 1, 2], &device);
let res = client.conv_transpose1d(&input, &weight, None, 2, PaddingMode::Valid, 2, 1, 1);
assert!(res.is_err());
}
}