use super::{Param, ParamId, Quantizer};
use crate::{
record::Record,
tensor::backend::{AutodiffBackend, Backend},
};
use alloc::{string::String, vec::Vec};
pub use burn_derive::Module;
use burn_tensor::{Bool, Int, Tensor, ops::Device};
/// List of backend devices, as accumulated by [`Module::collect_devices`] and returned by
/// [`Module::devices`].
pub type Devices<B> = Vec<Device<B>>;
/// Internal helper that builds a throwaway [`ModuleMapper`] or [`ModuleVisitor`] around a
/// closure, so simple per-parameter operations can be written inline.
///
/// Two forms are supported:
///
/// - `module!(map = module, ops = closure)`: maps every float parameter through `closure`
///   (`Tensor<B, D> -> Tensor<B, D>`). Each parameter keeps its id and mapper. The rebuilt
///   module is returned.
/// - `module!(visit_float = module, ops = closure, state = Type, init = closure)`: creates a
///   `Type` value by calling `init`. It then calls `closure(&tensor, &mut state)` for every
///   float parameter and returns the final state.
///
/// The `ops` closure may name the backend `B` and the rank `D`. This works because it is
/// expanded inside the generated generic `map_float` / `visit_float` methods, where those
/// generic parameters are declared.
macro_rules! module {
    // Float-parameter mapper: consume each float param, transform its tensor, re-wrap it.
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                param: Param<Tensor<B, D>>,
            ) -> Param<Tensor<B, D>> {
                let (id, tensor, mapper) = param.consume();
                let func = $item;
                let tensor = func(tensor);
                Param::from_mapped_value(id, tensor, mapper)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    // Float-parameter visitor: fold every float param into a caller-provided state value.
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
                let func = $item;
                // Reborrow the stored `&mut` so the closure receives `&mut $state_ty`
                // directly instead of relying on deref coercion from `&mut &mut $state_ty`.
                func(&param.val(), &mut *self.state)
            }
        }
        // `$init` is an expression fragment (typically a closure literal) that is invoked
        // immediately. That triggers clippy's `redundant_closure_call`.
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}
/// Trait implemented by neural-network building blocks whose tensors live on backend `B`.
///
/// A derive macro of the same name (`burn_derive::Module`) is re-exported next to this trait.
/// The required methods cover three areas:
///
/// - device placement;
/// - parameter traversal (`visit` / `map`);
/// - conversion to and from [`Module::Record`].
///
/// The provided methods are built on top of these.
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
    /// Serializable state of the module. It is produced by [`Module::into_record`] and
    /// consumed by [`Module::load_record`].
    type Record: Record<B>;

    /// Extends `devices` with the devices used by this module and returns the accumulated list.
    ///
    /// NOTE(review): whether duplicates are filtered is left to each implementation — confirm.
    fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;

    /// Returns the devices reported by [`Module::collect_devices`], starting from an empty list.
    fn devices(&self) -> Devices<B> {
        let empty: Devices<B> = Vec::new();
        self.collect_devices(empty)
    }

    /// Returns the module placed on `device`.
    ///
    /// NOTE(review): how this differs from [`Module::to_device`] is implementation-defined.
    /// It is presumably related to autodiff graph handling — confirm before relying on it.
    fn fork(self, device: &B::Device) -> Self;

    /// Returns the module with its tensors moved to `device`.
    fn to_device(self, device: &B::Device) -> Self;

    /// Returns the module with `set_require_grad(false)` applied to every float parameter.
    ///
    /// Only `map_float` is overridden by the generated mapper. Int and bool parameters
    /// therefore go through the default [`ModuleMapper`] methods.
    fn no_grad(self) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
        )
    }

    /// Counts the scalar elements across all float parameters.
    ///
    /// Int and bool parameters are not counted, because the generated visitor only
    /// overrides `visit_float`.
    fn num_params(&self) -> usize {
        module!(
            visit_float = self,
            ops = |tensor: &Tensor<B, D>, total: &mut usize| {
                *total += tensor.shape().num_elements()
            },
            state = usize,
            init = || 0_usize
        )
    }

    /// Walks the module's parameters, handing them to `visitor`.
    fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);

    /// Rebuilds the module by passing its parameters through `mapper`.
    fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;

    /// Returns the module with its state replaced by `record`.
    fn load_record(self, record: Self::Record) -> Self;

    /// Consumes the module and returns its state as a record.
    fn into_record(self) -> Self::Record;

    /// Converts the module into its record and writes it to `file_path` with `recorder`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`crate::record::RecorderError`] `recorder` reports.
    #[cfg(feature = "std")]
    fn save_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
    ) -> Result<(), crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let state = Self::into_record(self);
        let target: std::path::PathBuf = file_path.into();
        recorder.record(state, target)
    }

    /// Loads a record from `file_path` onto `device` with `recorder`, then applies it
    /// through [`Module::load_record`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`crate::record::RecorderError`] `recorder` reports while loading.
    #[cfg(feature = "std")]
    fn load_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
        device: &B::Device,
    ) -> Result<Self, crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let source: std::path::PathBuf = file_path.into();
        recorder
            .load(source, device)
            .map(|loaded: Self::Record| self.load_record(loaded))
    }

    /// Applies `quantizer` to the module as a [`ModuleMapper`], via [`Module::map`].
    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
        <Self as Module<B>>::map(self, quantizer)
    }
}
/// Read-only traversal hooks over a module's parameters, used by [`Module::visit`].
///
/// Every method has a no-op default, so implementors override only the hooks they need.
// The no-op defaults intentionally ignore their arguments; one trait-level allow replaces a
// per-method attribute.
#[allow(unused_variables)]
pub trait ModuleVisitor<B: Backend> {
    /// Hook for a float tensor parameter. The default does nothing.
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}

    /// Hook for an integer tensor parameter. The default does nothing.
    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}

    /// Hook for a boolean tensor parameter. The default does nothing.
    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}

    /// Hook for entering a nested module identified by `name` and `container_type`.
    /// The default does nothing.
    ///
    /// NOTE(review): the exact strings passed are chosen by the `Module` implementations —
    /// confirm against the derive macro.
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Hook for leaving a nested module. It is the counterpart of `enter_module`.
    /// The default does nothing.
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Path-aware hook for a float tensor. `path` locates it and `id` identifies the
    /// parameter. The default does nothing.
    ///
    /// NOTE(review): `path` is presumably the stack of `enter_module` names — confirm.
    fn visit_float_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D>,
    ) {}

    /// Path-aware hook for an integer tensor. The default does nothing.
    fn visit_int_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Int>,
    ) {}

    /// Path-aware hook for a boolean tensor. The default does nothing.
    fn visit_bool_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Bool>,
    ) {}
}
/// Transforming traversal hooks used by [`Module::map`]. Each `map_*` method takes a
/// parameter by value and returns its replacement.
///
/// Each default `map_*` method consumes the parameter into its `(id, value, mapper)` parts.
/// It then re-wraps those same parts with `Param::from_mapped_value` without touching the
/// value.
// `enter_module` / `exit_module` defaults are no-ops that ignore their arguments; one
// trait-level allow covers them.
#[allow(unused_variables)]
pub trait ModuleMapper<B: Backend> {
    /// Hook for entering a nested module. The default does nothing.
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Hook for leaving a nested module. The default does nothing.
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Maps a float tensor parameter. The default re-wraps it from its own parts.
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (param_id, value, value_mapper) = param.consume();
        Param::from_mapped_value(param_id, value, value_mapper)
    }

    /// Maps an integer tensor parameter. The default re-wraps it from its own parts.
    fn map_int<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Int>>,
    ) -> Param<Tensor<B, D, Int>> {
        let (param_id, value, value_mapper) = param.consume();
        Param::from_mapped_value(param_id, value, value_mapper)
    }

    /// Maps a boolean tensor parameter. The default re-wraps it from its own parts.
    fn map_bool<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Bool>>,
    ) -> Param<Tensor<B, D, Bool>> {
        let (param_id, value, value_mapper) = param.consume();
        Param::from_mapped_value(param_id, value, value_mapper)
    }
}
/// A [`Module`] on an autodiff backend that can produce a module on the backend's inner
/// backend (`B::InnerBackend`).
pub trait AutodiffModule<B>: Module<B> + Send + core::fmt::Debug
where
    B: AutodiffBackend,
{
    /// Module type used on `B::InnerBackend`.
    type InnerModule: Module<B::InnerBackend>;

    /// Returns the inner-backend counterpart of this module.
    ///
    /// NOTE(review): presumably used for validation/inference without gradient tracking —
    /// confirm against implementations.
    fn valid(&self) -> Self::InnerModule;
}