use crate::error::Result;
use crate::intern::Sym;
use crate::prediction::Prediction;
use crate::types::Inputs;
use std::future::Future;
/// An asynchronous unit of work that consumes [`Inputs`] and yields a
/// [`Prediction`].
///
/// Implementors choose the concrete future type through
/// [`Module::ForwardFut`]. The trait requires implementors to be
/// `Send + Sync`.
pub trait Module: Send + Sync {
    /// Future returned by [`Module::forward`].
    ///
    /// It must be `Send`, may borrow from `self` and the inputs for `'a`, and
    /// resolves to `Result<Prediction<'a>>`.
    type ForwardFut<'a>: Future<Output = Result<Prediction<'a>>> + Send + 'a
    where
        Self: 'a;

    /// Starts processing `inputs` and returns the future that produces the
    /// prediction (or an error).
    fn forward<'a>(&'a self, inputs: Inputs<'a>) -> Self::ForwardFut<'a>;

    /// Name of this module.
    ///
    /// The default is whatever `std::any::type_name` reports for the
    /// implementing type; the standard library does not guarantee the exact
    /// format of that string. Implementors may override this method.
    fn name(&self) -> &str {
        use std::any::type_name;

        type_name::<Self>()
    }

    /// Identifier for this module, obtained by passing [`Module::name`] to
    /// `crate::intern::sym`.
    // NOTE(review): presumably `sym` interns the string; confirm in `crate::intern`.
    fn id(&self) -> Sym {
        let label = self.name();
        crate::intern::sym(label)
    }
}
/// A plain record holding an owned module name.
#[derive(Debug, Clone)]
pub struct BaseModule {
    /// The name supplied at construction time.
    pub name: String,
}

impl BaseModule {
    /// Builds a `BaseModule` from anything convertible into a `String`
    /// (for example `&str` or `String`).
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self { name }
    }
}
/// Pairs a value of type `F` with a `'static` name.
///
/// `Clone` is derived and is therefore available exactly when `F: Clone`,
/// which is the same bound the previous hand-written impl used.
///
/// `Debug` is implemented by hand with no bound on `F`, so the type can be
/// debug-printed even when `F` is a closure (closures do not implement
/// `Debug`). Only the name is printed; the `f` field is elided.
// NOTE(review): `F` is presumably a callable, but nothing in this block
// constrains it; the bounds live on whichever impls use the field.
#[derive(Clone)]
pub struct FnModule<F> {
    /// The wrapped value.
    f: F,
    /// Static name given at construction.
    name: &'static str,
}

impl<F> FnModule<F> {
    /// Creates an `FnModule` from a static `name` and the value `f`.
    ///
    /// This is a `const fn`, so it can be used to initialise constants and
    /// statics.
    pub const fn new(name: &'static str, f: F) -> Self {
        Self { f, name }
    }
}

impl<F> std::fmt::Debug for FnModule<F> {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `f` is deliberately skipped so no `F: Debug` bound is needed.
        out.debug_struct("FnModule")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
/// Combinator methods for sized [`Module`] implementors.
///
/// Both methods consume `self` and wrap it in a new value.
pub trait ModuleExt: Module + Sized {
    /// Bundles `self` and `next` into a [`ChainedModule`], storing `self` as
    /// the first stage and `next` as the second.
    fn then<M>(self, next: M) -> ChainedModule<Self, M>
    where
        M: Module,
    {
        let (first, second) = (self, next);
        ChainedModule { first, second }
    }

    /// Wraps `self` together with the function `f` in a [`MappedModule`].
    ///
    /// `f` must accept a [`Prediction`] of any lifetime, return an `O`, and
    /// be `Send + Sync`.
    fn map<F, O>(self, f: F) -> MappedModule<Self, F>
    where
        F: Fn(Prediction<'_>) -> O + Send + Sync,
    {
        let inner = self;
        MappedModule { inner, f }
    }
}
// Blanket impl: every sized `Module` gets the `ModuleExt` combinators
// without writing an impl of its own.
impl<M> ModuleExt for M where M: Module {}
/// Holds two values, `first` and `second`; produced by `ModuleExt::then`.
///
/// `Clone` and `Debug` are derived, so each is available exactly when both
/// `A` and `B` implement the corresponding trait. For `Clone` that is the
/// same bound the previous hand-written impl used.
#[derive(Debug, Clone)]
pub struct ChainedModule<A, B> {
    /// The value `then` was called on.
    first: A,
    /// The value passed to `then`.
    second: B,
}
/// Holds an inner value together with a function `f`; produced by
/// `ModuleExt::map`.
///
/// `Clone` is derived and is available exactly when `M: Clone` and
/// `F: Clone`, the same bounds the previous hand-written impl used.
///
/// `Debug` is implemented by hand and only requires `M: Debug`: `f` is
/// usually a closure, which does not implement `Debug`, so that field is
/// elided from the output.
#[derive(Clone)]
pub struct MappedModule<M, F> {
    /// The value `map` was called on.
    inner: M,
    /// The function passed to `map`.
    f: F,
}

impl<M, F> std::fmt::Debug for MappedModule<M, F>
where
    M: std::fmt::Debug,
{
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `f` is deliberately skipped so no `F: Debug` bound is needed.
        out.debug_struct("MappedModule")
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Inputs;

    /// Minimal `Module` whose `forward` resolves immediately with
    /// `Prediction::new()`; it relies on the default `name`/`id` methods.
    struct TestModule;

    impl Module for TestModule {
        type ForwardFut<'a> = std::future::Ready<Result<Prediction<'a>>>;

        fn forward<'a>(&'a self, _inputs: Inputs<'a>) -> Self::ForwardFut<'a> {
            let prediction = Prediction::new();
            std::future::ready(Ok(prediction))
        }
    }

    /// Awaiting `forward` on the test module yields `Ok`.
    #[tokio::test]
    async fn test_module_execution() {
        let subject = TestModule;
        let empty_inputs = Inputs::new();

        let outcome = subject.forward(empty_inputs).await;

        assert!(outcome.is_ok());
    }

    /// The default `name` mentions the implementing type.
    #[test]
    fn test_module_name() {
        let subject = TestModule;

        let reported = subject.name();

        assert!(reported.contains("TestModule"));
    }

    /// `BaseModule::new` stores the name it was given.
    #[test]
    fn test_base_module() {
        let record = BaseModule::new("my_module");

        assert_eq!(record.name, "my_module");
    }
}