use num_traits::NumCast;
use crate::{
shapes::{Dtype, Shape},
tensor::Tensor,
tensor_ops::Device,
};
use super::{ModuleVisitor, ScalarOptions, TensorCollection, TensorOptions};
/// Pairs a value `m` with an exclusive borrow of a callback object `f`.
///
/// NOTE(review): judging by the name, `m` is presumably the module (or a view of it) being
/// walked and `f` the visitor driving the walk; the code that uses this type is not in this
/// section — confirm.
#[derive(Debug)]
pub struct RecursiveWalker<'a, M, F> {
    /// The value being walked.
    pub m: M,
    /// Exclusive (`&mut`) borrow of the callback object, held for the lifetime `'a`.
    pub f: &'a mut F,
}
/// A visitor called on the tensors (and scalars) of a module whose tensors have dtype `E`
/// and live on device `D`.
///
/// Each call may optionally produce a new value. Produced tensors use the dtype
/// [`Self::E2`] and the device [`Self::D2`], which need not match `E` / `D`.
pub trait TensorVisitor<E: Dtype, D: Device<E>> {
    /// Selects the kind of view (see [`TensorViewer`]) that is passed to [`Self::visit`]
    /// and [`Self::visit_scalar`].
    type Viewer: TensorViewer;
    /// Error type returned by the visit methods.
    type Err;
    /// Dtype of the tensors returned by [`Self::visit`].
    type E2: Dtype;
    /// Device of the tensors returned by [`Self::visit`].
    type D2: Device<Self::E2>;
    /// Called for one tensor field, with its options `opts` and a view `t` of the tensor.
    ///
    /// Returns `Ok(Some(tensor))` to produce a tensor of the same shape `S`, with dtype
    /// `E2` on device `D2`. Returns `Ok(None)` to produce nothing. What the caller does in
    /// the `None` case is decided by the traversal code, not by this trait.
    ///
    /// # Errors
    /// Returns `Err(Self::Err)` if visiting fails.
    // Silences clippy's `type_complexity` lint, presumably triggered by the nested
    // `Result<Option<Tensor<..>>, ..>` return type.
    #[allow(clippy::type_complexity)]
    fn visit<S: Shape>(
        &mut self,
        opts: TensorOptions<S, E, D>,
        t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>,
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>;
    /// Called for one numeric scalar field of type `N`.
    ///
    /// The default implementation ignores the view `_h`. It always returns
    /// `Ok(Some(opts.default))`, i.e. the default value carried by `opts`.
    ///
    /// # Errors
    /// The default implementation never fails. Overrides may return `Err(Self::Err)`.
    fn visit_scalar<N: NumCast>(
        &mut self,
        opts: ScalarOptions<N>,
        _h: <Self::Viewer as TensorViewer>::View<'_, N>,
    ) -> Result<Option<N>, Self::Err> {
        Ok(Some(opts.default))
    }
}
/// Type-level strategy for how a module and its fields are viewed during traversal.
///
/// The `'static` supertrait bound means an implementing type cannot itself hold non-`'static`
/// borrows.
///
/// NOTE(review): implementors are presumably pure marker types such as `ViewTensorRef` /
/// `ViewTensorMut` — confirm against their impls.
pub trait TensorViewer: 'static {
    /// The view of a value of type `Mod` that is valid for the lifetime `'a`.
    type View<'a, Mod: 'a>
    where
        Self: 'a;
    /// Projects a view of a module onto a view of one of its fields.
    ///
    /// * `module` - view of the parent module. The returned field view borrows from it for
    ///   `'a`.
    /// * `name` - the field's name.
    /// * `get_ref` / `get_mut` - accessors from `&Mod` / `&mut Mod` to the field. Both are
    ///   provided so that an implementor can presumably use whichever matches the kind of
    ///   view it produces.
    fn view_field<'a, Mod, Field, GetRef, GetMut>(
        module: &'a mut Self::View<'_, Mod>,
        name: &str,
        get_ref: &mut GetRef,
        get_mut: &mut GetMut,
    ) -> Self::View<'a, Field>
    where
        GetRef: FnMut(&Mod) -> &Field,
        GetMut: FnMut(&mut Mod) -> &mut Field;
}
/// Description of the field(s) of module `M` that a [`ModuleVisitor`] can walk.
///
/// [`Self::visit_fields`] consumes the description and produces a [`Self::Options`] value.
/// [`Self::handle_options`] turns such a value into an optional [`Self::Output`].
///
/// NOTE(review): presumably implemented for single field descriptors (`ModuleField`,
/// `TensorField`, `ScalarField`) and for groups of them — confirm.
pub trait ModuleFields<M: TensorCollection<E, D>, E: Dtype, D: Device<E>> {
    /// Intermediate result of visiting, for output dtype `E2` on device `D2`.
    type Options<E2: Dtype, D2: Device<E2>>;
    /// Final result built from [`Self::Options`], for output dtype `E2` on device `D2`.
    type Output<E2: Dtype, D2: Device<E2>>;
    /// Consumes `self` and visits its field(s) with `visitor`.
    ///
    /// The result uses the visitor's output dtype `V::E2` and device `V::D2`.
    ///
    /// # Errors
    /// Returns the visitor's error type `V::Err`.
    fn visit_fields<V: ModuleVisitor<M, E, D>>(
        self,
        visitor: &mut V,
    ) -> Result<Self::Options<V::E2, V::D2>, V::Err>;
    /// Converts an intermediate [`Self::Options`] value into the final output, if there is one.
    ///
    /// NOTE(review): presumably returns `None` when any visited field produced no value —
    /// confirm against the implementations.
    fn handle_options<E2: Dtype, D2: Device<E2>>(
        options: Self::Options<E2, D2>,
    ) -> Option<Self::Output<E2, D2>>;
}
/// Describes a field of type `Field` inside a module `Mod`.
///
/// It bundles the field's `name` with a pair of accessor closures that project `&Mod` /
/// `&mut Mod` onto the field.
///
/// NOTE(review): presumably used for nested sub-modules, in contrast with `TensorField` and
/// `ScalarField` — confirm.
pub struct ModuleField<'a, F1, F2, Mod, Field>
where
    F1: FnMut(&Mod) -> &Field,
    F2: FnMut(&mut Mod) -> &mut Field,
{
    /// The field's name, borrowed for `'a`.
    pub(super) name: &'a str,
    /// Accessor returning a shared reference to the field.
    pub(super) get_ref: F1,
    /// Accessor returning a mutable reference to the field.
    pub(super) get_mut: F2,
    // `Mod` and `Field` appear only in the closure bounds, which do not count as uses.
    // These markers keep the compiler from rejecting them as unused parameters (E0392).
    pub(super) m: std::marker::PhantomData<Mod>,
    pub(super) f: std::marker::PhantomData<Field>,
}
/// Describes a `Tensor<S, E, D>` field inside a module `Mod`.
///
/// It carries three things:
/// - the field's `name`;
/// - accessor closures that project `&Mod` / `&mut Mod` onto the tensor;
/// - the [`TensorOptions`] that go with that tensor.
pub struct TensorField<'a, F1, F2, Mod, S, E, D>
where
    S: Shape,
    E: Dtype,
    D: Device<E>,
    F1: FnMut(&Mod) -> &Tensor<S, E, D>,
    F2: FnMut(&mut Mod) -> &mut Tensor<S, E, D>,
{
    /// The field's name, borrowed for `'a`.
    pub(super) name: &'a str,
    /// Projects a shared module reference onto the tensor.
    pub(super) get_ref: F1,
    /// Projects a mutable module reference onto the tensor.
    pub(super) get_mut: F2,
    /// Options for this tensor. Presumably handed to `TensorVisitor::visit` — confirm.
    pub(super) options: TensorOptions<S, E, D>,
    /// `Mod` otherwise shows up only in the closure bounds. This marker counts as a use, so
    /// the parameter is not rejected as unused (E0392).
    pub(super) m: std::marker::PhantomData<Mod>,
}
/// Describes a numeric scalar field of type `N` inside a module `Mod`.
///
/// It bundles three things:
/// - the field's `name`;
/// - accessor closures that project `&Mod` / `&mut Mod` onto the scalar;
/// - the [`ScalarOptions`] for that scalar.
pub struct ScalarField<'a, F1, F2, Mod, N>
where
    N: NumCast,
    F1: FnMut(&Mod) -> &N,
    F2: FnMut(&mut Mod) -> &mut N,
{
    /// The field's name, borrowed for `'a`.
    pub(super) name: &'a str,
    /// Accessor returning a shared reference to the scalar.
    pub(super) get_ref: F1,
    /// Accessor returning a mutable reference to the scalar.
    pub(super) get_mut: F2,
    /// Options for this scalar. Presumably handed to `TensorVisitor::visit_scalar` — confirm.
    pub(super) options: ScalarOptions<N>,
    /// `Mod` otherwise appears only in the closure bounds. This marker counts as a use, so the
    /// parameter is not rejected as unused (E0392).
    pub(super) m: std::marker::PhantomData<Mod>,
}
/// Uninhabited marker type. It has no variants, so no value of it can ever exist; it is
/// meaningful only at the type level.
///
/// NOTE(review): by its name, presumably the [`TensorViewer`] that views modules and tensors
/// through shared references (`&T`). The impl is not in this section — confirm.
#[derive(Debug)]
pub enum ViewTensorRef {}
/// Uninhabited marker type. It has no variants, so no value of it can ever exist; it is
/// meaningful only at the type level.
///
/// NOTE(review): by its name, presumably the [`TensorViewer`] that views modules and tensors
/// through mutable references (`&mut T`). The impl is not in this section — confirm.
#[derive(Debug)]
pub enum ViewTensorMut {}
/// Uninhabited marker type. It has no variants, so no value of it can ever exist; it is
/// meaningful only at the type level.
///
/// NOTE(review): by its name, presumably the [`TensorViewer`] whose views carry field names
/// rather than tensor data. The impl is not in this section — confirm.
#[derive(Debug)]
pub enum ViewTensorName {}