use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use crate::{Module, ModuleConfig, Parameter};
/// Builder interface whose [`ModuleBuilder::build`] produces a value of type `T`.
///
/// Every setter takes the builder by value and hands it back, so calls can be
/// chained before the final `build`.
pub trait ModuleBuilder<T> {
    /// Consumes the builder and returns the finished `T`.
    ///
    /// # Errors
    /// Whatever error the implementor reports while building.
    fn build(self) -> Result<T>;

    /// Chaining setter that receives the training flag.
    fn training(self, training: bool) -> Self;

    /// Chaining setter that receives the target [`DeviceType`].
    fn device(self, device: DeviceType) -> Self;

    /// Chaining setter that hands `config_fn` a mutable [`ModuleConfig`] to adjust.
    fn config<F: FnOnce(&mut ModuleConfig)>(self, config_fn: F) -> Self;
}
/// Combinators for wiring modules together.
///
/// Each method consumes `self` and wraps it (plus any extra argument) in one of
/// the wrapper types defined alongside this trait.
pub trait ModuleComposition {
    /// Chains `other` after `self`, producing a [`ComposedModule`].
    fn then<Other>(self, other: Other) -> ComposedModule<Self, Other>
    where
        Self: Sized + 'static,
        Other: Module + 'static;

    /// Pairs `self` with `other`, producing a [`ParallelModule`].
    fn parallel<Other>(self, other: Other) -> ParallelModule<Self, Other>
    where
        Self: Sized + 'static,
        Other: Module + 'static;

    /// Wraps `self` in a [`ResidualModule`].
    fn residual(self) -> ResidualModule<Self>
    where
        Self: Sized + 'static;

    /// Wraps `self` in a [`ConditionalModule`] gated by `condition_fn`.
    fn conditional<F>(self, condition_fn: F) -> ConditionalModule<Self, F>
    where
        Self: Sized + 'static,
        F: Fn() -> bool + Send + Sync;
}
/// Sequential composition: `second` is applied to the output of `first`.
///
/// Built by [`ModuleComposition::then`].
pub struct ComposedModule<First, Second> {
    first: First,
    second: Second,
}

impl<First: Module, Second: Module> Module for ComposedModule<First, Second> {
    /// Runs `first` on `input`, then feeds its result into `second`.
    ///
    /// # Errors
    /// Returns the first error raised by either sub-module.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let intermediate = self.first.forward(input)?;
        self.second.forward(&intermediate)
    }

    /// Collects both sub-modules' parameters, namespaced as `first.<name>` and
    /// `second.<name>` (same scheme as `ParallelModule`).
    ///
    /// Both sides must be prefixed. If `first`'s names were left bare, nested
    /// chains such as `a.then(b).then(c)` would produce `second.<name>` for both
    /// `b` (inner) and `c` (outer), and `HashMap::insert` would silently drop one
    /// of the colliding parameters.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let first = self.first.parameters();
        let second = self.second.parameters();
        let mut params = HashMap::with_capacity(first.len() + second.len());
        params.extend(
            first
                .into_iter()
                .map(|(name, param)| (format!("first.{name}"), param)),
        );
        params.extend(
            second
                .into_iter()
                .map(|(name, param)| (format!("second.{name}"), param)),
        );
        params
    }

    /// `true` only when both sub-modules report training mode.
    fn training(&self) -> bool {
        self.first.training() && self.second.training()
    }

    /// Propagates the training flag to both sub-modules.
    fn set_training(&mut self, training: bool) {
        self.first.set_training(training);
        self.second.set_training(training);
    }
}
/// Two modules evaluated side by side on the same input.
///
/// Built by [`ModuleComposition::parallel`].
pub struct ParallelModule<First, Second> {
    first: First,
    second: Second,
}

impl<First: Module, Second: Module> Module for ParallelModule<First, Second> {
    /// Evaluates both branches on `input` and returns the output of `first`.
    ///
    /// NOTE(review): `second`'s output is computed, and any error from it is
    /// propagated, but the value is then discarded. No merge (sum/concat) is
    /// performed. Confirm whether this placeholder behaviour is intended.
    ///
    /// # Errors
    /// Returns the first error raised by either branch.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let kept = self.first.forward(input)?;
        let _discarded = self.second.forward(input)?;
        Ok(kept)
    }

    /// Parameters of both branches, keyed `first.<name>` and `second.<name>`.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.first
            .parameters()
            .into_iter()
            .map(|(name, param)| (format!("first.{name}"), param))
            .chain(
                self.second
                    .parameters()
                    .into_iter()
                    .map(|(name, param)| (format!("second.{name}"), param)),
            )
            .collect()
    }

    /// `true` only when both branches report training mode.
    fn training(&self) -> bool {
        self.first.training() && self.second.training()
    }

    /// Pushes the training flag down to both branches.
    fn set_training(&mut self, training: bool) {
        self.first.set_training(training);
        self.second.set_training(training);
    }
}
/// Wrapper produced by [`ModuleComposition::residual`].
///
/// NOTE(review): despite the name, no skip connection is applied. `forward`
/// returns the wrapped module's output unchanged, not `module(x) + x`.
/// Implementing a true residual needs a tensor-addition op (and matching
/// shapes). TODO: confirm the `Tensor` arithmetic API and add the skip.
pub struct ResidualModule<M> {
    module: M,
}

impl<M: Module> Module for ResidualModule<M> {
    /// Delegates directly to the wrapped module (see NOTE on the struct).
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        self.module.forward(input)
    }

    /// The wrapped module's parameters, unprefixed.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.module.parameters()
    }

    /// Mirrors the wrapped module's training flag.
    fn training(&self) -> bool {
        self.module.training()
    }

    /// Forwards the training flag to the wrapped module.
    fn set_training(&mut self, training: bool) {
        self.module.set_training(training);
    }
}
/// Runs `module` only when `condition_fn` returns `true`; otherwise it acts as
/// an identity.
///
/// Built by [`ModuleComposition::conditional`].
pub struct ConditionalModule<M, F> {
    module: M,
    condition_fn: F,
}

impl<M: Module, F: Fn() -> bool + Send + Sync> Module for ConditionalModule<M, F> {
    /// Calls `condition_fn` once per forward pass.
    ///
    /// - When it returns `false`, returns a clone of `input`.
    /// - When it returns `true`, runs the wrapped module on `input`.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let enabled = (self.condition_fn)();
        if !enabled {
            return Ok(input.clone());
        }
        self.module.forward(input)
    }

    /// The wrapped module's parameters, reported regardless of the condition.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.module.parameters()
    }

    /// Mirrors the wrapped module's training flag.
    fn training(&self) -> bool {
        self.module.training()
    }

    /// Forwards the training flag to the wrapped module.
    fn set_training(&mut self, training: bool) {
        self.module.set_training(training);
    }
}
/// Blanket implementation: every `'static` [`Module`] gets the composition
/// combinators for free.
impl<T> ModuleComposition for T
where
    T: Module + 'static,
{
    fn then<Other: Module + 'static>(self, other: Other) -> ComposedModule<Self, Other> {
        let (first, second) = (self, other);
        ComposedModule { first, second }
    }

    fn parallel<Other: Module + 'static>(self, other: Other) -> ParallelModule<Self, Other> {
        let (first, second) = (self, other);
        ParallelModule { first, second }
    }

    fn residual(self) -> ResidualModule<Self> {
        let module = self;
        ResidualModule { module }
    }

    fn conditional<F>(self, condition_fn: F) -> ConditionalModule<Self, F>
    where
        F: Fn() -> bool + Send + Sync,
    {
        let module = self;
        ConditionalModule {
            module,
            condition_fn,
        }
    }
}