mod chain;
mod ff;
mod map;
mod seq;
#[cfg(test)]
pub mod testing;
pub use chain::{Chain, ChainConfig};
pub use ff::{Activation, Linear, LinearConfig, Mlp, MlpConfig};
pub use map::BatchMap;
pub use seq::{Gru, GruConfig, Lstm, LstmConfig};
pub type GruMlpConfig = ChainConfig<GruConfig, MlpConfig>;
pub type LstmMlpConfig = ChainConfig<GruConfig, MlpConfig>;
use crate::torch::packed::PackedTensor;
use tch::{Device, Tensor};
/// A model component that owns a set of `tch` tensor variables.
pub trait Module {
    /// Returns a shallow clone of this module.
    ///
    /// NOTE(review): exactly which data is shared versus copied is defined by
    /// each implementor.
    #[must_use]
    fn shallow_clone(&self) -> Self
    where
        Self: Sized;

    /// Returns a clone of this module whose variables live on `device`.
    #[must_use]
    fn clone_to_device(&self, device: Device) -> Self
    where
        Self: Sized;

    /// Iterates over every variable tensor held by the module.
    fn variables(&self) -> Box<dyn Iterator<Item = &Tensor> + '_>;

    /// Iterates over the subset of variable tensors that are trainable.
    fn trainable_variables(&self) -> Box<dyn Iterator<Item = &Tensor> + '_>;

    /// Whether the module supports second derivatives when running under cuDNN.
    ///
    /// Returns `true` unless an implementor overrides it.
    /// NOTE(review): meaning inferred from the name — confirm with implementors.
    fn has_cudnn_second_derivatives(&self) -> bool {
        true
    }
}
/// Implements [`Module`] for a pointer-like wrapper type (`$wrapper`) around a
/// `T: Module + ?Sized`.
///
/// Variable iteration and the cuDNN capability query delegate to the wrapped
/// module. `shallow_clone` and `clone_to_device` cannot be delegated: the trait
/// requires `Self: Sized` for them, `T` may be unsized here, and a borrowed
/// wrapper has no owner for a freshly built module. They therefore panic with a
/// descriptive message.
macro_rules! impl_wrapped_module {
    ($wrapper:ty) => {
        impl<T: Module + ?Sized> Module for $wrapper {
            fn shallow_clone(&self) -> Self {
                unimplemented!("shallow_clone is not supported through a module wrapper")
            }
            fn clone_to_device(&self, _: Device) -> Self {
                unimplemented!("clone_to_device is not supported through a module wrapper")
            }
            fn variables(&self) -> Box<dyn Iterator<Item = &Tensor> + '_> {
                T::variables(self)
            }
            fn trainable_variables(&self) -> Box<dyn Iterator<Item = &Tensor> + '_> {
                T::trainable_variables(self)
            }
            fn has_cudnn_second_derivatives(&self) -> bool {
                // `self` deref-coerces to `&T`, matching the other delegating methods.
                T::has_cudnn_second_derivatives(self)
            }
        }
    };
}
// Modules behind shared references and boxes expose the wrapped module's variables.
impl_wrapped_module!(&'_ T);
impl_wrapped_module!(Box<T>);
/// Variable iteration with statically named iterator types.
///
/// The iterators borrow the module for `'a`. Unlike the boxed trait objects
/// returned by [`Module`], the concrete types are exposed as associated types.
pub trait ModuleExtras<'a> {
    /// Iterator over all variable tensors.
    type Variables: Iterator<Item = &'a Tensor>;

    /// Iterator over the trainable variable tensors.
    type TrainableVariables: Iterator<Item = &'a Tensor>;

    /// Returns an iterator over all variable tensors.
    fn variables(&'a self) -> Self::Variables;

    /// Returns an iterator over the trainable variable tensors.
    fn trainable_variables(&'a self) -> Self::TrainableVariables;
}
/// A configuration value able to construct a [`Module`].
pub trait BuildModule {
    /// The concrete module type produced by [`BuildModule::build_module`].
    type Module: Module;

    /// Constructs a module from `in_dim`, `out_dim` and the target `device`.
    ///
    /// NOTE(review): `in_dim`/`out_dim` are presumably the input and output
    /// feature sizes — confirm against implementors.
    fn build_module(&self, in_dim: usize, out_dim: usize, device: Device) -> Self::Module;
}
/// Implements [`BuildModule`] for a pointer-like wrapper around a
/// `T: BuildModule + ?Sized` configuration by delegating to the pointee.
macro_rules! impl_wrapped_build_module {
    ($ptr:ty) => {
        impl<T> BuildModule for $ptr
        where
            T: BuildModule + ?Sized,
        {
            type Module = <T as BuildModule>::Module;

            fn build_module(
                &self,
                in_dim: usize,
                out_dim: usize,
                device: Device,
            ) -> Self::Module {
                (**self).build_module(in_dim, out_dim, device)
            }
        }
    };
}
// Configurations can be used through shared references and boxes.
impl_wrapped_build_module!(&'_ T);
impl_wrapped_build_module!(Box<T>);
/// Types that are, or contain, a [`Module`] and can lend it out by reference.
pub trait AsModule {
    /// The module type being exposed.
    type Module: Module;

    /// Borrows the underlying module.
    fn as_module(&self) -> &Self::Module;

    /// Mutably borrows the underlying module.
    fn as_module_mut(&mut self) -> &mut Self::Module;

    /// Consumes `self` and pairs it with the function `f` in a [`BatchMap`].
    ///
    /// NOTE(review): how `f` is applied is defined by `BatchMap`, not here.
    fn batch_map<F>(self, f: F) -> BatchMap<Self, F>
    where
        F: Fn(Tensor) -> Tensor,
        Self: Sized,
    {
        BatchMap::new(self, f)
    }
}
/// Every [`Module`] trivially exposes itself.
impl<M> AsModule for M
where
    M: Module,
{
    type Module = M;

    fn as_module(&self) -> &M {
        self
    }

    fn as_module_mut(&mut self) -> &mut M {
        self
    }
}
/// A module that maps a single input tensor to an output tensor.
pub trait Forward {
    /// Applies the module to `input` and returns the resulting tensor.
    fn forward(&self, input: &Tensor) -> Tensor;
}
/// Implements [`Forward`] for a pointer-like wrapper around a `T: Forward + ?Sized`
/// by delegating to the pointee.
macro_rules! impl_wrapped_feed_forward_module {
    ($ptr:ty) => {
        impl<T> Forward for $ptr
        where
            T: Forward + ?Sized,
        {
            fn forward(&self, input: &Tensor) -> Tensor {
                (**self).forward(input)
            }
        }
    };
}
// Forward modules can be used through shared references and boxes.
impl_wrapped_feed_forward_module!(&'_ T);
impl_wrapped_feed_forward_module!(Box<T>);
/// A sequence module applied to a batch of sequences described by explicit lengths.
pub trait SeqSerial {
    /// Applies the module to `inputs`, which holds sequences of the given `seq_lengths`.
    ///
    /// NOTE(review): the layout of `inputs` with respect to `seq_lengths`
    /// (e.g. sequences concatenated along one dimension) is defined by
    /// implementors — confirm before relying on it.
    fn seq_serial(&self, inputs: &Tensor, seq_lengths: &[usize]) -> Tensor;
}
/// Implements [`SeqSerial`] for a pointer-like wrapper around a `T: SeqSerial + ?Sized`
/// by delegating to the pointee.
macro_rules! impl_wrapped_seq_serial {
    ($ptr:ty) => {
        impl<T> SeqSerial for $ptr
        where
            T: SeqSerial + ?Sized,
        {
            fn seq_serial(&self, inputs: &Tensor, seq_lengths: &[usize]) -> Tensor {
                (**self).seq_serial(inputs, seq_lengths)
            }
        }
    };
}
// Serial sequence modules can be used through shared references and boxes.
impl_wrapped_seq_serial!(&'_ T);
impl_wrapped_seq_serial!(Box<T>);
/// A sequence module that consumes and produces [`PackedTensor`] batches.
pub trait SeqPacked {
    /// Applies the module to the packed `inputs`, returning a packed result.
    fn seq_packed(&self, inputs: &PackedTensor) -> PackedTensor;
}
/// Implements [`SeqPacked`] for a pointer-like wrapper around a `T: SeqPacked + ?Sized`
/// by delegating to the pointee.
macro_rules! impl_wrapped_seq_packed {
    ($ptr:ty) => {
        impl<T> SeqPacked for $ptr
        where
            T: SeqPacked + ?Sized,
        {
            fn seq_packed(&self, inputs: &PackedTensor) -> PackedTensor {
                (**self).seq_packed(inputs)
            }
        }
    };
}
// Packed sequence modules can be used through shared references and boxes.
impl_wrapped_seq_packed!(&'_ T);
impl_wrapped_seq_packed!(Box<T>);
/// A sequence module evaluated one input at a time with explicit carried state.
pub trait SeqIterative {
    /// State carried from one step to the next.
    type State;

    /// Returns the state used before the first step.
    fn initial_state(&self) -> Self::State;

    /// Processes one `input`, possibly updating `state` in place, and returns the step output.
    fn step(&self, state: &mut Self::State, input: &Tensor) -> Tensor;

    /// Returns an iterator that borrows `self` and yields one output per item of `inputs`.
    ///
    /// Iteration starts from [`SeqIterative::initial_state`].
    fn iter<I>(&self, inputs: I) -> SeqIterator<&Self, I::IntoIter>
    where
        I: IntoIterator,
        I::Item: AsRef<Tensor>,
    {
        SeqIterator::new(self, IntoIterator::into_iter(inputs))
    }

    /// Like [`SeqIterative::iter`], but the returned iterator takes ownership of `self`.
    fn into_iter<I>(self, inputs: I) -> SeqIterator<Self, I::IntoIter>
    where
        I: IntoIterator,
        I::Item: AsRef<Tensor>,
        Self: Sized,
    {
        SeqIterator::new(self, IntoIterator::into_iter(inputs))
    }
}
/// Implements [`SeqIterative`] for a pointer-like wrapper around a
/// `T: SeqIterative + ?Sized`.
///
/// The wrapper reuses the pointee's state type and delegates both methods to it.
macro_rules! impl_wrapped_iterative_module {
    ($ptr:ty) => {
        impl<T> SeqIterative for $ptr
        where
            T: SeqIterative + ?Sized,
        {
            type State = <T as SeqIterative>::State;

            fn initial_state(&self) -> Self::State {
                <T as SeqIterative>::initial_state(&**self)
            }

            fn step(&self, state: &mut Self::State, input: &Tensor) -> Tensor {
                <T as SeqIterative>::step(&**self, state, input)
            }
        }
    };
}
// Iterative modules can be stepped through shared references and boxes.
impl_wrapped_iterative_module!(&'_ T);
impl_wrapped_iterative_module!(Box<T>);
/// Iterator that feeds each item of `inputs` through a [`SeqIterative`] module.
///
/// It carries the module state between steps and yields one output tensor per input.
/// Created by [`SeqIterative::iter`] and [`SeqIterative::into_iter`].
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct SeqIterator<M: SeqIterative, I> {
    /// Module applied at every step.
    module: M,
    /// State passed to the next call of `step`.
    state: M::State,
    /// Inputs not yet consumed.
    inputs: I,
}
impl<M, I> SeqIterator<M, I>
where
    M: SeqIterative,
{
    /// Creates an iterator over `inputs` starting from the module's initial state.
    fn new(module: M, inputs: I) -> Self {
        let state = module.initial_state();
        Self {
            module,
            state,
            inputs,
        }
    }
}
/// Each call to `next` consumes one input and runs exactly one module step on it.
impl<M, I> Iterator for SeqIterator<M, I>
where
    M: SeqIterative,
    I: Iterator,
    I::Item: AsRef<Tensor>,
{
    type Item = Tensor;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let input = self.inputs.next()?;
        let output = self.module.step(&mut self.state, input.as_ref());
        Some(output)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // One output per input, so the input bounds carry over unchanged.
        self.inputs.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        F: FnMut(B, Self::Item) -> B,
    {
        // Take the iterator apart so the module state can live inside the closure
        // while the inner iterator's own (possibly specialised) `fold` drives the loop.
        let Self {
            module,
            mut state,
            inputs,
        } = self;
        inputs.fold(init, move |acc, input| {
            let output = module.step(&mut state, input.as_ref());
            f(acc, output)
        })
    }
}
/// Exactly one output is produced per remaining input, so the length is the
/// input iterator's length.
impl<M, I> ExactSizeIterator for SeqIterator<M, I>
where
    I: ExactSizeIterator,
    I::Item: AsRef<Tensor>,
    M: SeqIterative,
{
    fn len(&self) -> usize {
        ExactSizeIterator::len(&self.inputs)
    }
}