use super::{State, StateNamed};
use crate::optim::Optimizer;
use crate::tensor::backend::{ADBackend, Backend};
pub use burn_derive::Module;
/// Core trait for neural network modules built on a tensor [`Backend`].
///
/// A module can report and change the device(s) it lives on, export and import its
/// parameters as a [`State`], and, when its backend is an [`ADBackend`], have its
/// parameters updated by an [`Optimizer`]. A derive macro of the same name is
/// re-exported from `burn_derive`.
pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
    /// The tensor backend this module is built on.
    type Backend: Backend;
    /// Returns the devices associated with this module
    /// (presumably the devices holding its parameters — confirm against implementations).
    fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;
    /// Moves this module to `device` (presumably all of its parameters — confirm).
    fn to_device(&mut self, device: <Self::Backend as Backend>::Device);
    /// Loads `state` into this module, replacing its current parameters.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadingError`] when `state` cannot be applied to this module
    /// (e.g. presumably missing or mismatched entries — the exact conditions are
    /// implementation-defined).
    fn load(&mut self, state: &State<<Self::Backend as Backend>::Elem>)
        -> Result<(), LoadingError>;
    /// Returns a [`State`] snapshot of this module, with elements of the backend's
    /// `Elem` type; presumably the counterpart of [`Module::load`].
    fn state(&self) -> State<<Self::Backend as Backend>::Elem>;
    /// Detaches this module's tensors in place
    /// (presumably from the autodiff graph — confirm against implementations).
    fn detach(&mut self);
    /// Returns the number of parameters in this module
    /// (NOTE(review): whether this counts tensors or scalar elements is not visible here).
    fn num_params(&self) -> usize;
    /// Updates this module's parameters using the gradients `grads` and the optimizer
    /// `optim`.
    ///
    /// Only callable when the backend supports automatic differentiation
    /// (`Self::Backend: ADBackend`).
    fn update_params<O: Optimizer<Backend = Self::Backend>>(
        &mut self,
        grads: &<Self::Backend as ADBackend>::Gradients,
        optim: &mut O,
    ) where
        Self::Backend: ADBackend;
    /// Restores optimizer state for this module's parameters: reads from `state_optim`
    /// and writes into `optim`. The module itself is not modified (`&self`).
    fn load_optim_state<O: Optimizer<Backend = Self::Backend>>(
        &self,
        optim: &mut O,
        state_optim: &StateNamed<<Self::Backend as Backend>::Elem>,
    ) where
        Self::Backend: ADBackend;
    /// Exports optimizer state for this module's parameters: reads from `optim` and
    /// records it into `state_optim`. Neither the module nor the optimizer is modified.
    fn register_optim_state<O: Optimizer<Backend = Self::Backend>>(
        &self,
        optim: &O,
        state_optim: &mut StateNamed<<Self::Backend as Backend>::Elem>,
    ) where
        Self::Backend: ADBackend;
}
/// A [`Module`] whose backend supports automatic differentiation.
///
/// It is paired with `InnerModule`: the corresponding module built on the autodiff
/// backend's inner backend, which [`ADModule::inner`] produces.
// NOTE(review): `Send + Sync + Debug + Display` are already required by the `Module`
// supertrait; the repetition here is redundant but harmless.
pub trait ADModule:
    Module<Backend = Self::ADBackend> + Send + Sync + std::fmt::Debug + std::fmt::Display
{
    /// The autodiff backend; it is also this module's `Module::Backend`.
    type ADBackend: ADBackend;
    /// The counterpart module built on `<Self::ADBackend as ADBackend>::InnerBackend`.
    type InnerModule: Module<Backend = <Self::ADBackend as ADBackend>::InnerBackend>;
    /// Returns an owned inner-backend version of this module.
    ///
    /// Takes `&self`, so the result is a new value (presumably a converted copy of the
    /// parameters — confirm against implementations).
    fn inner(&self) -> Self::InnerModule;
}
/// A forward pass that maps an input of type `In` to an output of type `Out`.
///
/// Because `In` and `Out` are generic parameters, one type can implement this trait
/// several times for different input/output pairs.
pub trait Forward<In, Out> {
    /// Runs the forward pass on `input`, borrowing `self` immutably.
    fn forward(&self, input: In) -> Out;
}
/// Error type returned by [`Module::load`].
///
/// `#[derive(new)]` (presumably from the `derive_new` crate) generates a
/// `LoadingError::new(message)` constructor taking the message `String`.
#[derive(new, Debug)]
pub struct LoadingError {
    /// Human-readable description of the failure; included in the `Display` output.
    message: String,
}
impl std::fmt::Display for LoadingError {
    /// Formats the error as `Loading error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!` formats directly into the formatter, avoiding the temporary
        // `String` that `format!(..)` + `write_str` would allocate.
        write!(f, "Loading error: {}", self.message)
    }
}

/// Lets `LoadingError` be used with `?` and as a `Box<dyn std::error::Error>`.
/// It wraps no underlying cause, so the default `source()` (which returns `None`) is kept.
impl std::error::Error for LoadingError {}