use crate::{client::FusionClient, stream::execution::Operation, Fusion, FusionBackend};
use burn_tensor::{
ops::{
conv::{
calculate_conv_output_size, calculate_conv_transpose_output_size,
calculate_pool_output_size,
},
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions,
MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices,
ModuleOps,
},
repr::*,
Element,
};
use std::marker::PhantomData;
/// Declares a local operation type and wires it into the fusion runtime.
///
/// * `$ops` — name of the generated struct; it stores the operation description in `desc`
///   and carries the backend type through a `PhantomData<B>`.
/// * `$description` — type of the description stored in the struct.
/// * `$run` — closure invoked with the owned description and the handle container when the
///   operation's `execute` is called.
macro_rules! make_ops {
    ($ops:ident, $description:ty, $run:expr) => {
        #[derive(new)]
        struct $ops<B: FusionBackend> {
            desc: $description,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $ops<B> {
            fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
                // Move the description out of the boxed operation so the closure owns it.
                let description = self.desc;
                // The closure literal is passed through the macro and called immediately,
                // which clippy would otherwise report as a redundant closure call.
                #[allow(clippy::redundant_closure_call)]
                $run(description, handles)
            }
        }
    };
}
impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
fn conv1d(
x: FloatTensor<Self, 3>,
weight: FloatTensor<Self, 3>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<1>,
) -> FloatTensor<Self, 3> {
make_ops!(
Conv1dOps,
Conv1dDescription,
|desc: Conv1dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&desc.x);
let weight = handles.get_float_tensor::<B, 3>(&desc.weight);
let bias = desc
.bias
.as_ref()
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
let output = B::conv1d(x, weight, bias, desc.options.into());
handles.register_float_tensor::<B, 3>(&desc.out.id, output);
}
);
let size = calculate_conv_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.dilation[0],
x.shape[2],
);
let stream_1 = x.stream;
let stream_2 = weight.stream;
let stream_3 = bias.as_ref().map(|b| b.stream);
let shape = vec![x.shape[0], weight.shape[0], size];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let description = Conv1dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
let streams = match stream_3 {
Some(stream_3) => vec![stream_1, stream_2, stream_3],
None => vec![stream_1, stream_2],
};
out.client.clone().register(
streams,
OperationDescription::Module(ModuleOperationDescription::Conv1d(description.clone())),
Conv1dOps::<B>::new(description),
);
out
}
fn conv2d(
x: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<2>,
) -> FloatTensor<Self, 4> {
make_ops!(
Conv2dOps,
Conv2dDescription,
|args: Conv2dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let weight = handles.get_float_tensor::<B, 4>(&args.weight);
let bias = args
.bias
.as_ref()
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
let output = B::conv2d(x, weight, bias, args.options.clone().into());
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let size_0 = calculate_conv_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.dilation[1],
x.shape[3],
);
let stream_1 = x.stream;
let stream_2 = weight.stream;
let stream_3 = bias.as_ref().map(|b| b.stream);
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = Conv2dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
let streams = match stream_3 {
Some(stream_3) => vec![stream_1, stream_2, stream_3],
None => vec![stream_1, stream_2],
};
out.client.register(
streams,
OperationDescription::Module(ModuleOperationDescription::Conv2d(desc.clone())),
Conv2dOps::<B>::new(desc),
);
out
}
fn conv3d(
x: FloatTensor<Self, 5>,
weight: FloatTensor<Self, 5>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<3>,
) -> FloatTensor<Self, 5> {
make_ops!(
Conv3dOps,
Conv3dDescription,
|args: Conv3dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 5>(&args.x);
let weight = handles.get_float_tensor::<B, 5>(&args.weight);
let bias = args
.bias
.as_ref()
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
let output = B::conv3d(x, weight, bias, args.options.clone().into());
handles.register_float_tensor::<B, 5>(&args.out.id, output);
}
);
let size_0 = calculate_conv_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.dilation[1],
x.shape[3],
);
let size_2 = calculate_conv_output_size(
weight.shape[4],
options.stride[2],
options.padding[2],
options.dilation[2],
x.shape[4],
);
let stream_1 = x.stream;
let stream_2 = weight.stream;
let stream_3 = bias.as_ref().map(|b| b.stream);
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1, size_2];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = Conv3dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
let streams = match stream_3 {
Some(stream_3) => vec![stream_1, stream_2, stream_3],
None => vec![stream_1, stream_2],
};
out.client.register(
streams,
OperationDescription::Module(ModuleOperationDescription::Conv3d(desc.clone())),
Conv3dOps::<B>::new(desc),
);
out
}
fn conv_transpose1d(
x: FloatTensor<Self, 3>,
weight: FloatTensor<Self, 3>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<Self, 3> {
make_ops!(
ConvTranspose1dOps,
ConvTranspose1dDescription,
|args: ConvTranspose1dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&args.x);
let weight = handles.get_float_tensor::<B, 3>(&args.weight);
let bias = args
.bias
.as_ref()
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into());
handles.register_float_tensor::<B, 3>(&args.out.id, output);
}
);
let size = calculate_conv_transpose_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.padding_out[0],
options.dilation[0],
x.shape[2],
);
let stream_1 = x.stream;
let stream_2 = weight.stream;
let stream_3 = bias.as_ref().map(|b| b.stream);
let shape = vec![x.shape[0], weight.shape[1] * options.groups, size];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = ConvTranspose1dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
let streams = match stream_3 {
Some(stream_3) => vec![stream_1, stream_2, stream_3],
None => vec![stream_1, stream_2],
};
out.client.register(
streams,
OperationDescription::Module(ModuleOperationDescription::ConvTranspose1d(desc.clone())),
ConvTranspose1dOps::<B>::new(desc),
);
out
}
fn conv_transpose2d(
x: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<Self, 4> {
make_ops!(
ConvTranspose2dOps,
ConvTranspose2dDescription,
|args: ConvTranspose2dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let weight = handles.get_float_tensor::<B, 4>(&args.weight);
let bias = args
.bias
.as_ref()
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into());
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let size_0 = calculate_conv_transpose_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.padding_out[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_transpose_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.padding_out[1],
options.dilation[1],
x.shape[3],
);
let stream_1 = x.stream;
let stream_2 = weight.stream;
let stream_3 = bias.as_ref().map(|b| b.stream);
let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = ConvTranspose2dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
let streams = match stream_3 {
Some(stream_3) => vec![stream_1, stream_2, stream_3],
None => vec![stream_1, stream_2],
};
out.client.register(
streams,
OperationDescription::Module(ModuleOperationDescription::ConvTranspose2d(desc.clone())),
ConvTranspose2dOps::<B>::new(desc),
);
out
}
fn conv_transpose3d(
x: FloatTensor<Self, 5>,
weight: FloatTensor<Self, 5>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<3>,
) -> FloatTensor<Self, 5> {
make_ops!(
ConvTranspose3dOps,
ConvTranspose3dDescription,
|args: ConvTranspose3dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 5>(&args.x);
let weight = handles.get_float_tensor::<B, 5>(&args.weight);
let bias = args
.bias
.as_ref()
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into());
handles.register_float_tensor::<B, 5>(&args.out.id, output);
}
);
let size_0 = calculate_conv_transpose_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.padding_out[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_transpose_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.padding_out[1],
options.dilation[1],
x.shape[3],
);
let size_2 = calculate_conv_transpose_output_size(
weight.shape[4],
options.stride[2],
options.padding[2],
options.padding_out[2],
options.dilation[2],
x.shape[4],
);
let stream_1 = x.stream;
let stream_2 = weight.stream;
let stream_3 = bias.as_ref().map(|b| b.stream);
let shape = vec![
x.shape[0],
weight.shape[1] * options.groups,
size_0,
size_1,
size_2,
];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = ConvTranspose3dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
let streams = match stream_3 {
Some(stream_3) => vec![stream_1, stream_2, stream_3],
None => vec![stream_1, stream_2],
};
out.client.register(
streams,
OperationDescription::Module(ModuleOperationDescription::ConvTranspose3d(desc.clone())),
ConvTranspose3dOps::<B>::new(desc),
);
out
}
fn avg_pool1d(
x: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
) -> FloatTensor<Self, 3> {
make_ops!(
AvgPool1dOps,
AvgPool1dDescription,
|args: AvgPool1dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&args.x);
let output = B::avg_pool1d(
x,
args.kernel_size,
args.stride,
args.padding,
args.count_include_pad,
);
handles.register_float_tensor::<B, 3>(&args.out.id, output);
}
);
let stream = x.stream;
let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]);
let shape = vec![x.shape[0], x.shape[1], size];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = AvgPool1dDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
count_include_pad,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::AvgPool1d(desc.clone())),
AvgPool1dOps::<B>::new(desc),
);
out
}
fn avg_pool2d(
x: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
) -> FloatTensor<Self, 4> {
make_ops!(
AvgPool2dOps,
AvgPool2dDescription,
|args: AvgPool2dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let output = B::avg_pool2d(
x,
args.kernel_size,
args.stride,
args.padding,
args.count_include_pad,
);
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let size_0 =
calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]);
let size_1 =
calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]);
let stream = x.stream;
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = AvgPool2dDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
count_include_pad,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::AvgPool2d(desc.clone())),
AvgPool2dOps::<B>::new(desc),
);
out
}
fn avg_pool1d_backward(
x: FloatTensor<Self, 3>,
grad: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
) -> FloatTensor<Self, 3> {
make_ops!(
AvgPool1dBackwardOps,
AvgPool1dBackwardDescription,
|args: AvgPool1dBackwardDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&args.x);
let grad = handles.get_float_tensor::<B, 3>(&args.grad);
let output = B::avg_pool1d_backward(
x,
grad,
args.kernel_size,
args.stride,
args.padding,
args.count_include_pad,
);
handles.register_float_tensor::<B, 3>(&args.out.id, output);
}
);
let stream_1 = x.stream;
let stream_2 = grad.stream;
let out = x
.client
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
let desc = AvgPool1dBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
kernel_size,
stride,
padding,
count_include_pad,
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(ModuleOperationDescription::AvgPool1dBackward(
desc.clone(),
)),
AvgPool1dBackwardOps::<B>::new(desc),
);
out
}
fn avg_pool2d_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
) -> FloatTensor<Self, 4> {
make_ops!(
AvgPool2dBackwardOps,
AvgPool2dBackwardDescription,
|args: AvgPool2dBackwardDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let grad = handles.get_float_tensor::<B, 4>(&args.grad);
let output = B::avg_pool2d_backward(
x,
grad,
args.kernel_size,
args.stride,
args.padding,
args.count_include_pad,
);
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let stream_1 = x.stream;
let stream_2 = grad.stream;
let out = x
.client
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
let desc = AvgPool2dBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
kernel_size,
stride,
padding,
count_include_pad,
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(ModuleOperationDescription::AvgPool2dBackward(
desc.clone(),
)),
AvgPool2dBackwardOps::<B>::new(desc),
);
out
}
fn max_pool1d(
x: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> FloatTensor<Self, 3> {
make_ops!(
MaxPool1dOps,
MaxPool1dDescription,
|args: MaxPool1dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&args.x);
let output = B::max_pool1d(
x,
args.kernel_size,
args.stride,
args.padding,
args.dilation,
);
handles.register_float_tensor::<B, 3>(&args.out.id, output);
}
);
let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
let stream = x.stream;
let shape = vec![x.shape[0], x.shape[1], size];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = MaxPool1dDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::MaxPool1d(desc.clone())),
MaxPool1dOps::<B>::new(desc),
);
out
}
fn max_pool2d(
x: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> FloatTensor<Self, 4> {
make_ops!(
MaxPool2dOps,
MaxPool2dDescription,
|args: MaxPool2dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let output = B::max_pool2d(
x,
args.kernel_size,
args.stride,
args.padding,
args.dilation,
);
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
x.shape[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
x.shape[3],
);
let stream = x.stream;
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = MaxPool2dDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::MaxPool2d(desc.clone())),
MaxPool2dOps::<B>::new(desc),
);
out
}
fn max_pool1d_with_indices(
x: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<Self> {
make_ops!(
MaxPool1dWithIndicesOps,
MaxPool1dWithIndicesDescription,
|args: MaxPool1dWithIndicesDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&args.x);
let output = B::max_pool1d_with_indices(
x,
args.kernel_size,
args.stride,
args.padding,
args.dilation,
);
handles.register_float_tensor::<B, 3>(&args.out.id, output.output);
handles.register_int_tensor::<B, 3>(&args.out_indices.id, output.indices);
}
);
let stream = x.stream;
let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
let shape = vec![x.shape[0], x.shape[1], size];
let out = x
.client
.tensor_uninitialized(shape.clone(), B::FloatElem::dtype());
let out_indices = x.client.tensor_uninitialized(shape, B::IntElem::dtype());
let desc = MaxPool1dWithIndicesDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
out_indices: out_indices.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndices(
desc.clone(),
)),
MaxPool1dWithIndicesOps::<B>::new(desc),
);
MaxPool1dWithIndices::new(out, out_indices)
}
fn max_pool2d_with_indices(
x: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<Self> {
make_ops!(
MaxPool2dWithIndicesOps,
MaxPool2dWithIndicesDescription,
|args: MaxPool2dWithIndicesDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let output = B::max_pool2d_with_indices(
x,
args.kernel_size,
args.stride,
args.padding,
args.dilation,
);
handles.register_float_tensor::<B, 4>(&args.out.id, output.output);
handles.register_int_tensor::<B, 4>(&args.out_indices.id, output.indices);
}
);
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
x.shape[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
x.shape[3],
);
let stream = x.stream;
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let out = x
.client
.tensor_uninitialized(shape.clone(), B::FloatElem::dtype());
let out_indices = x.client.tensor_uninitialized(shape, B::IntElem::dtype());
let desc = MaxPool2dWithIndicesDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
out_indices: out_indices.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndices(
desc.clone(),
)),
MaxPool2dWithIndicesOps::<B>::new(desc),
);
MaxPool2dWithIndices::new(out, out_indices)
}
fn max_pool1d_with_indices_backward(
x: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
output_grad: FloatTensor<Self, 3>,
indices: IntTensor<Self, 3>,
) -> MaxPool1dBackward<Self> {
make_ops!(
MaxPool1dWithIndicesBackwardOps,
MaxPool1dWithIndicesBackwardDescription,
|args: MaxPool1dWithIndicesBackwardDescription,
handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&args.x);
let grad = handles.get_float_tensor::<B, 3>(&args.grad);
let indices = handles.get_int_tensor::<B, 3>(&args.indices);
let output = B::max_pool1d_with_indices_backward(
x,
args.kernel_size,
args.stride,
args.padding,
args.dilation,
grad,
indices,
);
handles.register_float_tensor::<B, 3>(&args.out.id, output.x_grad);
}
);
let stream_1 = x.stream;
let stream_2 = output_grad.stream;
let stream_3 = indices.stream;
let out = x
.client
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
let desc = MaxPool1dWithIndicesBackwardDescription {
x: x.into_description(),
grad: output_grad.into_description(),
indices: indices.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndicesBackward(
desc.clone(),
)),
MaxPool1dWithIndicesBackwardOps::<B>::new(desc),
);
MaxPool1dBackward::new(out)
}
fn max_pool2d_with_indices_backward(
x: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: FloatTensor<Self, 4>,
indices: IntTensor<Self, 4>,
) -> MaxPool2dBackward<Self> {
make_ops!(
MaxPool2dWithIndicesBackwardOps,
MaxPool2dWithIndicesBackwardDescription,
|args: MaxPool2dWithIndicesBackwardDescription,
handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let grad = handles.get_float_tensor::<B, 4>(&args.grad);
let indices = handles.get_int_tensor::<B, 4>(&args.indices);
let output = B::max_pool2d_with_indices_backward(
x,
args.kernel_size,
args.stride,
args.padding,
args.dilation,
grad,
indices,
);
handles.register_float_tensor::<B, 4>(&args.out.id, output.x_grad);
}
);
let stream_1 = x.stream;
let stream_2 = output_grad.stream;
let stream_3 = indices.stream;
let out = x
.client
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
let desc = MaxPool2dWithIndicesBackwardDescription {
x: x.into_description(),
grad: output_grad.into_description(),
indices: indices.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndicesBackward(
desc.clone(),
)),
MaxPool2dWithIndicesBackwardOps::<B>::new(desc),
);
MaxPool2dBackward::new(out)
}
fn adaptive_avg_pool1d(x: FloatTensor<Self, 3>, output_size: usize) -> FloatTensor<Self, 3> {
make_ops!(
AdaptiveAvgPool1dOps,
AdaptiveAvgPool1dDescription,
|args: AdaptiveAvgPool1dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&args.x);
let output = B::adaptive_avg_pool1d(x, args.output_size);
handles.register_float_tensor::<B, 3>(&args.out.id, output);
}
);
let stream = x.stream;
let shape = vec![x.shape[0], x.shape[1], output_size];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = AdaptiveAvgPool1dDescription {
x: x.into_description(),
output_size,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1d(
desc.clone(),
)),
AdaptiveAvgPool1dOps::<B>::new(desc),
);
out
}
fn adaptive_avg_pool2d(
x: FloatTensor<Self, 4>,
output_size: [usize; 2],
) -> FloatTensor<Self, 4> {
make_ops!(
AdaptiveAvgPool2dOps,
AdaptiveAvgPool2dDescription,
|args: AdaptiveAvgPool2dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let output = B::adaptive_avg_pool2d(x, args.output_size);
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let stream = x.stream;
let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = AdaptiveAvgPool2dDescription {
x: x.into_description(),
output_size,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2d(
desc.clone(),
)),
AdaptiveAvgPool2dOps::<B>::new(desc),
);
out
}
fn adaptive_avg_pool1d_backward(
x: FloatTensor<Self, 3>,
grad: FloatTensor<Self, 3>,
) -> FloatTensor<Self, 3> {
make_ops!(
AdaptiveAvgPool1dBackwardOps,
AdaptiveAvgPool1dBackwardDescription,
|args: AdaptiveAvgPool1dBackwardDescription,
handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 3>(&args.x);
let grad = handles.get_float_tensor::<B, 3>(&args.grad);
let output = B::adaptive_avg_pool1d_backward(x, grad);
handles.register_float_tensor::<B, 3>(&args.out.id, output);
}
);
let stream_1 = x.stream;
let stream_2 = grad.stream;
let out = x
.client
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
let desc = AdaptiveAvgPool1dBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1dBackward(
desc.clone(),
)),
AdaptiveAvgPool1dBackwardOps::<B>::new(desc),
);
out
}
fn adaptive_avg_pool2d_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
) -> FloatTensor<Self, 4> {
make_ops!(
AdaptiveAvgPool2dBackwardOps,
AdaptiveAvgPool2dBackwardDescription,
|args: AdaptiveAvgPool2dBackwardDescription,
handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let grad = handles.get_float_tensor::<B, 4>(&args.grad);
let output = B::adaptive_avg_pool2d_backward(x, grad);
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let stream_1 = x.stream;
let stream_2 = grad.stream;
let out = x
.client
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
let desc = AdaptiveAvgPool2dBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2dBackward(
desc.clone(),
)),
AdaptiveAvgPool2dBackwardOps::<B>::new(desc),
);
out
}
fn interpolate(
x: FloatTensor<Self, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self, 4> {
make_ops!(
InterpolateOps,
InterpolateDescription,
|args: InterpolateDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let output = B::interpolate(x, args.output_size, args.options.clone().into());
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let stream = x.stream;
let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = InterpolateDescription {
x: x.into_description(),
output_size,
options: options.into(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Module(ModuleOperationDescription::Interpolate(desc.clone())),
InterpolateOps::<B>::new(desc),
);
out
}
fn interpolate_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self, 4> {
make_ops!(
InterpolateBackwardOps,
InterpolateBackwardDescription,
|args: InterpolateBackwardDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let grad = handles.get_float_tensor::<B, 4>(&args.grad);
let output =
B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let stream_1 = x.stream;
let stream_2 = grad.stream;
let out = x
.client
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
let desc = InterpolateBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
output_size,
options: options.into(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(ModuleOperationDescription::InterpolateBackward(
desc.clone(),
)),
InterpolateBackwardOps::<B>::new(desc),
);
out
}
}