use crate::{shapes::*, tensor::*, tensor_ops::*};
use super::*;
/// Unit-struct builders that carry only compile-time configuration
/// (no tensors are stored in them).
pub mod builder {
    /// Builder for a `Bias2D` module with `C` channels.
    ///
    /// It is turned into a concrete `Bias2D<C, E, D>` through its
    /// `BuildOnDevice` implementation.
    #[derive(Debug)]
    pub struct Bias2D<const C: usize>;
}
/// Builds a `Bias2D<C, E, D>` on a device from the channel-count-only builder.
///
/// The work is delegated entirely to the `BuildModule` impl of the concrete
/// `Bias2D<C, E, D>` type.
impl<const C: usize, E: Dtype, D: Device<E>> BuildOnDevice<D, E> for builder::Bias2D<C>
where
    Bias2D<C, E, D>: BuildModule<D, E>,
{
    type Built = Bias2D<C, E, D>;

    fn try_build_on_device(device: &D) -> Result<Self::Built, D::Err> {
        <Bias2D<C, E, D> as BuildModule<D, E>>::try_build(device)
    }
}
/// A learnable per-channel bias for inputs shaped `(C, H, W)` or `(B, C, H, W)`.
///
/// Holds one value per channel. The forward pass broadcasts it across the
/// remaining dimensions and adds it to the input.
#[derive(Clone, Debug)]
pub struct Bias2D<const C: usize, E, D>
where
    E: Dtype,
    D: Storage<E>,
{
    /// Bias values, one per channel (shape `(C,)`).
    pub bias: Tensor<Rank1<C>, E, D>,
}
/// Marker impl with no items: `Bias2D`'s forward passes only read the module
/// (they take `&self`).
impl<const C: usize, E, D> NonMutableModule for Bias2D<C, E, D>
where
    E: Dtype,
    D: Storage<E>,
{
}
/// Exposes `Bias2D`'s single tensor field to module visitors.
///
/// The field is registered under the name `"bias"` with
/// `TensorOptions::reset_to_zeros()`. The visitor's result is reassembled into
/// a `Bias2D`, possibly with a different dtype/device (`E2`, `D2`).
impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for Bias2D<C, E, D> {
    type To<E2: Dtype, D2: Device<E2>> = Bias2D<C, E2, D2>;

    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
        // Descriptor for the only field: its name, shared and mutable
        // accessors, and the options used when it is reset.
        let bias_field = Self::tensor(
            "bias",
            |m| &m.bias,
            |m| &mut m.bias,
            TensorOptions::reset_to_zeros(),
        );
        visitor.visit_fields(bias_field, |bias| Bias2D { bias })
    }
}
/// Adds the per-channel bias to an unbatched `(C, H, W)` input.
///
/// The `(C,)` bias is re-taped with the input's tape type `T`, broadcast to
/// the input's full shape, and then added elementwise.
impl<const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>>
    Module<Tensor<(Const<C>, H, W), E, D, T>> for Bias2D<C, E, D>
{
    type Output = Tensor<(Const<C>, H, W), E, D, T>;
    type Error = D::Err;

    fn try_forward(
        &self,
        input: Tensor<(Const<C>, H, W), E, D, T>,
    ) -> Result<Self::Output, D::Err> {
        let target_shape = *input.shape();
        let broadcast_bias = self.bias.retaped::<T>().try_broadcast_like(&target_shape)?;
        input.try_add(broadcast_bias)
    }
}
/// Adds the per-channel bias to a batched `(B, C, H, W)` input.
///
/// The `(C,)` bias is re-taped with the input's tape type `T`, broadcast over
/// the batch and spatial dimensions to the input's full shape, and then added
/// elementwise.
impl<B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>>
    Module<Tensor<(B, Const<C>, H, W), E, D, T>> for Bias2D<C, E, D>
{
    type Output = Tensor<(B, Const<C>, H, W), E, D, T>;
    type Error = D::Err;

    fn try_forward(
        &self,
        input: Tensor<(B, Const<C>, H, W), E, D, T>,
    ) -> Result<Self::Output, D::Err> {
        let target_shape = *input.shape();
        let broadcast_bias = self.bias.retaped::<T>().try_broadcast_like(&target_shape)?;
        input.try_add(broadcast_bias)
    }
}