use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::tensor::Tensor;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PaddingMode {
#[default]
Valid,
Same,
Custom(usize, usize, usize, usize),
}
impl PaddingMode {
pub fn uniform(padding: usize) -> Self {
PaddingMode::Custom(padding, padding, padding, padding)
}
pub fn conv1d(left: usize, right: usize) -> Self {
PaddingMode::Custom(left, right, 0, 0)
}
pub fn conv2d(top: usize, bottom: usize, left: usize, right: usize) -> Self {
PaddingMode::Custom(top, bottom, left, right)
}
pub fn name(&self) -> &'static str {
match self {
PaddingMode::Valid => "valid",
PaddingMode::Same => "same",
PaddingMode::Custom(..) => "custom",
}
}
}
pub trait ConvOps<R: Runtime> {
fn conv1d(
&self,
input: &Tensor<R>,
weight: &Tensor<R>,
bias: Option<&Tensor<R>>,
stride: usize,
padding: PaddingMode,
dilation: usize,
groups: usize,
) -> Result<Tensor<R>> {
let _ = (input, weight, bias, stride, padding, dilation, groups);
Err(Error::NotImplemented {
feature: "ConvOps::conv1d",
})
}
fn conv2d(
&self,
input: &Tensor<R>,
weight: &Tensor<R>,
bias: Option<&Tensor<R>>,
stride: (usize, usize),
padding: PaddingMode,
dilation: (usize, usize),
groups: usize,
) -> Result<Tensor<R>> {
let _ = (input, weight, bias, stride, padding, dilation, groups);
Err(Error::NotImplemented {
feature: "ConvOps::conv2d",
})
}
fn depthwise_conv2d(
&self,
input: &Tensor<R>,
weight: &Tensor<R>,
bias: Option<&Tensor<R>>,
stride: (usize, usize),
padding: PaddingMode,
dilation: (usize, usize),
) -> Result<Tensor<R>> {
let _ = (input, weight, bias, stride, padding, dilation);
Err(Error::NotImplemented {
feature: "ConvOps::depthwise_conv2d",
})
}
}