#![allow(clippy::type_complexity)]
use num_traits::NumCast;
use crate::{
shapes::{ConstShape, Dtype, Shape},
tensor::{OneFillStorage, Tensor, ZeroFillStorage},
tensor_ops::Device,
};
use super::{ModuleField, ModuleFields, ScalarField, TensorField};
/// A type whose tensors (and possibly scalars and nested collections) can be walked by a
/// [`ModuleVisitor`].
///
/// `E` and `D` are the dtype and device of the tensors the collection currently holds. The
/// provided `module`, `tensor` and `scalar` helpers only build field descriptors; they do not
/// visit anything themselves.
pub trait TensorCollection<E: Dtype, D: Device<E>>: Sized {
    /// This collection re-typed to dtype `E2` on device `D2`
    /// (e.g. `Tensor<S, E, D>` maps to `Tensor<S, E2, D2>`).
    type To<E2: Dtype, D2: Device<E2>>;

    /// Drives `visitor` over the fields of `Self`.
    ///
    /// Returns `Ok(Some(..))` when the visit yields a re-typed collection
    /// (`Self::To<V::E2, V::D2>`), `Ok(None)` when it does not, and `Err` with the visitor's
    /// error type otherwise. Whether `Some` or `None` is produced is up to the visitor.
    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>;

    /// Describes a nested sub-collection field called `name`.
    ///
    /// `get_ref` / `get_mut` project `&Self` / `&mut Self` onto the field. The returned
    /// [`ModuleField`] just stores them alongside `name`.
    fn module<F1, F2, Field>(
        name: &str,
        get_ref: F1,
        get_mut: F2,
    ) -> ModuleField<F1, F2, Self, Field>
    where
        F1: FnMut(&Self) -> &Field,
        F2: FnMut(&mut Self) -> &mut Field,
        Field: TensorCollection<E, D>,
    {
        ModuleField {
            name,
            get_ref,
            get_mut,
            // Type-level markers; presumably `PhantomData` (TODO confirm in ModuleField).
            m: Default::default(),
            f: Default::default(),
        }
    }

    /// Describes a tensor field called `name` with the given per-tensor `options`.
    ///
    /// `get_ref` / `get_mut` project `Self` onto the `Tensor<S, E, D>` field. The returned
    /// [`TensorField`] just stores the accessors, the name and the options.
    fn tensor<F1, F2, S>(
        name: &str,
        get_ref: F1,
        get_mut: F2,
        options: TensorOptions<S, E, D>,
    ) -> TensorField<F1, F2, Self, S, E, D>
    where
        F1: FnMut(&Self) -> &Tensor<S, E, D>,
        F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>,
        S: Shape,
    {
        TensorField {
            name,
            get_ref,
            get_mut,
            options,
            // Type-level marker; presumably `PhantomData` (TODO confirm in TensorField).
            m: Default::default(),
        }
    }

    /// Describes a numeric scalar field (any `N: NumCast`) called `name`.
    ///
    /// `get_ref` / `get_mut` project `Self` onto the `N` field. The returned [`ScalarField`]
    /// just stores the accessors, the name and the options.
    fn scalar<F1, F2, N>(
        name: &str,
        get_ref: F1,
        get_mut: F2,
        options: ScalarOptions<N>,
    ) -> ScalarField<F1, F2, Self, N>
    where
        F1: FnMut(&Self) -> &N,
        F2: FnMut(&mut Self) -> &mut N,
        N: NumCast,
    {
        ScalarField {
            name,
            get_ref,
            get_mut,
            options,
            // Type-level marker; presumably `PhantomData` (TODO confirm in ScalarField).
            m: Default::default(),
        }
    }
}
/// A visitor over the fields of a [`TensorCollection`] `T` whose tensors have dtype `E` on
/// device `D`.
///
/// Each visit may optionally produce a re-typed value using dtype [`Self::E2`] on device
/// [`Self::D2`].
pub trait ModuleVisitor<T: TensorCollection<E, D>, E: Dtype, D: Device<E>>: Sized {
    /// Error type returned by every `visit_*` method.
    type Err;
    /// Dtype of the values produced by this visitor.
    type E2: Dtype;
    /// Device of the values produced by this visitor.
    type D2: Device<Self::E2>;

    /// Visits a nested sub-collection field called `name`, reachable from `T` through
    /// `get_refs` / `get_muts`.
    ///
    /// May return the field re-typed as `Field::To<Self::E2, Self::D2>`.
    fn visit_module<Field, GetRef, GetMut>(
        &mut self,
        name: &str,
        get_refs: GetRef,
        get_muts: GetMut,
    ) -> Result<Option<Field::To<Self::E2, Self::D2>>, Self::Err>
    where
        GetRef: FnMut(&T) -> &Field,
        GetMut: FnMut(&mut T) -> &mut Field,
        Field: TensorCollection<E, D>;

    /// Visits a tensor field called `name`, reachable from `T` through `get_refs` / `get_muts`.
    ///
    /// `opts` carries that field's [`TensorOptions`]. May return the tensor re-typed to
    /// `Tensor<S, Self::E2, Self::D2>`.
    fn visit_tensor<S: Shape, GetRef, GetMut>(
        &mut self,
        name: &str,
        get_refs: GetRef,
        get_muts: GetMut,
        opts: TensorOptions<S, E, D>,
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>
    where
        GetRef: FnMut(&T) -> &Tensor<S, E, D>,
        GetMut: FnMut(&mut T) -> &mut Tensor<S, E, D>;

    /// Visits a numeric scalar field called `name`, reachable from `T` through
    /// `get_refs` / `get_muts`.
    ///
    /// `opts` carries that field's [`ScalarOptions`]. May return a value of the same type `N`.
    fn visit_scalar<N, GetRef, GetMut>(
        &mut self,
        name: &str,
        get_refs: GetRef,
        get_muts: GetMut,
        opts: ScalarOptions<N>,
    ) -> Result<Option<N>, Self::Err>
    where
        N: NumCast,
        GetRef: FnMut(&T) -> &N,
        GetMut: FnMut(&mut T) -> &mut N;

    /// Visits a group of field descriptors `fields`.
    ///
    /// `builder` turns the per-field outputs (`M::Output<Self::E2, Self::D2>`) into the
    /// re-typed collection `T::To<Self::E2, Self::D2>`.
    fn visit_fields<M: ModuleFields<T, E, D>>(
        &mut self,
        fields: M,
        builder: impl FnOnce(M::Output<Self::E2, Self::D2>) -> T::To<Self::E2, Self::D2>,
    ) -> Result<Option<T::To<Self::E2, Self::D2>>, Self::Err>;
}
impl<S: ConstShape, E: Dtype, D: Device<E>> TensorCollection<E, D> for Tensor<S, E, D> {
type To<E2: Dtype, D2: Device<E2>> = Tensor<S, E2, D2>;
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
visitor.visit_tensor(
"",
|s| s,
|s| s,
TensorOptions {
do_gradient_update: true,
reset: |_| Ok(()),
shape: Default::default(),
},
)
}
}
/// Per-tensor settings passed to [`ModuleVisitor::visit_tensor`] alongside a tensor field.
///
/// Because of `#[non_exhaustive]`, code outside this crate cannot build it with a struct
/// literal and must go through the constructors (`reset_to_zeros`, `reset_to_ones`,
/// `reset_with`, `detached`).
#[non_exhaustive]
pub struct TensorOptions<S: Shape, E: Dtype, D: Device<E>> {
    /// `false` only for options built with `detached`.
    /// Presumably controls whether gradient updates are applied to this tensor (TODO confirm
    /// at the consumer).
    pub do_gradient_update: bool,
    /// Plain function pointer applied to the tensor in place to reset it. It returns the
    /// device's error type.
    pub reset: fn(&'_ mut Tensor<S, E, D>) -> Result<(), D::Err>,
    /// Shape value of the tensor. The constructors fill it with `S::default()`.
    pub shape: S,
}
impl<S: Shape, E: Dtype, D: Device<E>> TensorOptions<S, E, D> {
    /// Options whose reset fills the tensor with zeros (via `try_fill_with_zeros`), with
    /// `do_gradient_update` set to `true`.
    pub fn reset_to_zeros() -> Self
    where
        D: ZeroFillStorage<E>,
        S: ConstShape,
    {
        Self::reset_with(|t| t.try_fill_with_zeros())
    }

    /// Options whose reset fills the tensor with ones (via `try_fill_with_ones`), with
    /// `do_gradient_update` set to `true`.
    pub fn reset_to_ones() -> Self
    where
        D: OneFillStorage<E>,
        S: ConstShape,
    {
        Self::reset_with(|t| t.try_fill_with_ones())
    }

    /// Options using the caller-supplied `reset` function, with `do_gradient_update` set to
    /// `true` and `shape` set to `S::default()`.
    pub fn reset_with(reset: fn(&mut Tensor<S, E, D>) -> Result<(), D::Err>) -> Self
    where
        S: ConstShape,
    {
        Self {
            shape: S::default(),
            reset,
            do_gradient_update: true,
        }
    }

    /// Same as [`TensorOptions::reset_with`] except that `do_gradient_update` is `false`.
    pub fn detached(reset: fn(&mut Tensor<S, E, D>) -> Result<(), D::Err>) -> Self
    where
        S: ConstShape,
    {
        Self {
            do_gradient_update: false,
            ..Self::reset_with(reset)
        }
    }
}
/// Per-scalar settings passed to [`ModuleVisitor::visit_scalar`] alongside a scalar field.
///
/// Because of `#[non_exhaustive]`, code outside this crate cannot build it with a struct
/// literal and must use `ScalarOptions::from_default`.
#[non_exhaustive]
pub struct ScalarOptions<N: NumCast> {
    /// The scalar's default value. Presumably used when the field is reset (TODO confirm at
    /// the consumer).
    pub default: N,
}
impl<N: NumCast> ScalarOptions<N> {
pub fn from_default(default: N) -> Self {
Self { default }
}
}