/// Implements `Differentiable` for a type that carries no trainable tensors.
///
/// Every token passed to this macro is handed, unchanged, to the hidden
/// `__non_differentiable!` helper behind a leading `begin` marker. That helper
/// splits the input into an optional leading `<...>` generic parameter list,
/// the type path, and an optional trailing `where` clause, and then emits an
/// impl whose `Tensors` and `Gradient` associated types are both `()`.
///
/// The generated impl names `HashMap` and `Tensor` without qualification, so
/// both must be in scope wherever this macro is invoked.
///
/// # Examples
///
/// ```ignore
/// non_differentiable!(MyConfig);
/// non_differentiable!(<T> Wrapper<T> where T: Clone);
/// ```
#[macro_export]
macro_rules! non_differentiable {
    ($($tokens:tt)+) => {
        $crate::__non_differentiable! { begin $($tokens)+ }
    };
}
// Implementation detail of `non_differentiable!`.
//
// This is a token-munching state machine; every step consumes one token and
// re-invokes the macro, so the input length is bounded by the recursion limit.
// The first token of each invocation names the current state:
//
// * `begin`    - decide whether the input opens with a generic parameter list.
// * `generics` - accumulate the generic parameters. The second parenthesised
//                group is a stack holding one `<` per currently open *nested*
//                angle bracket (the list's own opening `<` is not on it).
// * `path`     - accumulate the type path until `where` or the end of input.
// * `impl`     - emit the final `impl` block.
//
// `macro_rules!` hands `>>` and `<<` over as single glued tokens, so the
// `generics` state has dedicated rules for them; without those rules the
// bracket stack goes out of sync on inputs such as `<T: Into<Vec<u8>>> Foo<T>`.
//
// The generated impl names `HashMap` and `Tensor` unqualified, so both must be
// in scope at the call site.
#[doc(hidden)]
#[macro_export]
macro_rules! __non_differentiable {
    // ---- begin -----------------------------------------------------------
    // A leading `<` opens the generic parameter list of the impl.
    (begin < $($rest:tt)*) => {
        $crate::__non_differentiable!(generics () () $($rest)*);
    };
    // Anything else is already the first token of the type path.
    (begin $first:tt $($rest:tt)*) => {
        $crate::__non_differentiable!(path () ($first) $($rest)*);
    };

    // ---- generics --------------------------------------------------------
    // `>` with no nested bracket open closes the parameter list itself.
    (generics ($($generics:tt)*) () > $($rest:tt)*) => {
        $crate::__non_differentiable!(path ($($generics)*) () $($rest)*);
    };
    // Glued `>>` with exactly one nested bracket open: the first `>` closes
    // the nested bracket (kept in the parameters), the second closes the list.
    (generics ($($generics:tt)*) (<) >> $($rest:tt)*) => {
        $crate::__non_differentiable!(path ($($generics)* >) () $($rest)*);
    };
    // Glued `>>` with two or more nested brackets open: close two of them.
    (generics ($($generics:tt)*) (< < $($brackets:tt)*) >> $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* >>) ($($brackets)*) $($rest)*);
    };
    // Glued `<<` (e.g. `Foo<<T as Tr>::Out>`) opens two nested brackets.
    (generics ($($generics:tt)*) ($($brackets:tt)*) << $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* <<) ($($brackets)* < <) $($rest)*);
    };
    // `<` opens one nested bracket.
    (generics ($($generics:tt)*) ($($brackets:tt)*) < $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* <) ($($brackets)* <) $($rest)*);
    };
    // `>` closes the innermost nested bracket.
    (generics ($($generics:tt)*) (< $($brackets:tt)*) > $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* >) ($($brackets)*) $($rest)*);
    };
    // Any other token belongs to the parameter list verbatim.
    (generics ($($generics:tt)*) ($($brackets:tt)*) $first:tt $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* $first) ($($brackets)*) $($rest)*);
    };

    // ---- path ------------------------------------------------------------
    // `where` ends the type path; everything after it is the bound list.
    (path ($($generics:tt)*) ($($path:tt)*) where $($rest:tt)*) => {
        $crate::__non_differentiable!(impl ($($generics)*) ($($path)*) ($($rest)*));
    };
    // End of input without a `where` clause.
    (path ($($generics:tt)*) ($($path:tt)*)) => {
        $crate::__non_differentiable!(impl ($($generics)*) ($($path)*) ());
    };
    // Any other token extends the type path.
    (path ($($generics:tt)*) ($($path:tt)*) $first:tt $($rest:tt)*) => {
        $crate::__non_differentiable!(path ($($generics)*) ($($path)* $first) $($rest)*);
    };

    // ---- impl ------------------------------------------------------------
    // Emit the impl. Both associated types are `()` and every method body is
    // empty, so the type contributes nothing to gradient computation.
    (impl ($($generics:tt)*) ($($path:tt)*) ($($bound:tt)*)) => {
        impl<$($generics)*> $crate::Differentiable for $($path)* where $($bound)* {
            type Tensors = ();
            type Gradient = ();
            fn tensors(&self) -> Self::Tensors {}
            fn grad(_: &Self::Tensors, _: &HashMap<usize, Tensor>) -> Self::Gradient {}
            fn grad_map(_: &Self::Tensors, _: Self::Gradient, _: &mut HashMap<usize, Tensor>) {}
        }
    };
}
/// Implements `Differentiable` and `DifferentiableModule` for a module type.
///
/// The argument is any type (a bare identifier or a path such as
/// `models::Net`). The generated impl uses `HashMap<usize, Tensor>` for both
/// `Tensors` and `Gradient`, and obtains the tensor map from
/// `Module::parameters`, so the type must implement `Module`.
///
/// `HashMap` and `Tensor` are named unqualified in the expansion and must be
/// in scope at the call site. `Tensor` must be `Clone` (used by `grad`).
///
/// # Panics
///
/// The generated `grad` and `grad_map` panic when a key of `tensors` has no
/// corresponding entry in the gradient map they are given.
#[macro_export]
macro_rules! differentiable_module {
    ($m:ty) => {
        impl $crate::Differentiable for $m {
            type Tensors = HashMap<usize, Tensor>;
            type Gradient = HashMap<usize, Tensor>;

            // The module's parameters, as reported by `Module::parameters`.
            fn tensors(&self) -> Self::Tensors {
                $crate::Module::parameters(self)
            }

            // Copies, out of the borrowed `grad_map`, the entry for every key
            // of `tensors`.
            fn grad(tensors: &Self::Tensors, grad_map: &HashMap<usize, Tensor>) -> Self::Gradient {
                tensors
                    .keys()
                    .map(|id| {
                        let gradient = grad_map
                            .get(id)
                            .expect("differentiable_module: grad_map has no entry for a module tensor")
                            .clone();
                        (*id, gradient)
                    })
                    .collect()
            }

            // Writes the entry for every key of `tensors` into `out`. `grad`
            // is owned here, so entries are moved out rather than cloned.
            fn grad_map(
                tensors: &Self::Tensors,
                mut grad: Self::Gradient,
                out: &mut HashMap<usize, Tensor>,
            ) {
                for id in tensors.keys() {
                    let gradient = grad
                        .remove(id)
                        .expect("differentiable_module: gradient has no entry for a module tensor");
                    out.insert(*id, gradient);
                }
            }
        }

        impl $crate::DifferentiableModule for $m {}
    };
}