use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor,
};
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{
backend::{AutodiffBackend, Backend},
ops::Device,
};
use core::fmt::Debug;
impl<T, B> Module<B> for Option<T>
where
    T: Module<B> + Debug + Send + Clone,
    B: Backend,
{
    /// An absent module is paired with an absent record.
    type Record = Option<T::Record>;

    /// Visits the wrapped module when present; `None` contributes nothing.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        self.iter().for_each(|inner| inner.visit(visitor));
    }

    /// Applies `mapper` to the wrapped module; `None` is returned unchanged.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        // `?` short-circuits to `None` when there is no module to map.
        Some(self?.map(mapper))
    }

    /// Loads `record` into the wrapped module.
    ///
    /// A module with zero parameters is returned as-is and the record is ignored.
    /// Otherwise the result is `Some` only when both the module and the record are present.
    fn load_record(self, record: Self::Record) -> Self {
        if self.num_params() == 0 {
            return self;
        }

        let (inner, inner_record) = self.zip(record)?;
        Some(inner.load_record(inner_record))
    }

    /// Converts the wrapped module into its record; `None` stays `None`.
    fn into_record(self) -> Self::Record {
        Some(self?.into_record())
    }

    /// Moves the wrapped module (if any) to `device`.
    fn to_device(self, device: &Device<B>) -> Self {
        Some(self?.to_device(device))
    }

    /// Forks the wrapped module (if any) onto `device`.
    fn fork(self, device: &Device<B>) -> Self {
        Some(self?.fork(device))
    }

    /// Adds the devices used by the wrapped module (if any) to `devices`.
    fn collect_devices(&self, devices: Vec<B::Device>) -> Vec<B::Device> {
        match self {
            Some(inner) => inner.collect_devices(devices),
            None => devices,
        }
    }
}
/// Displays the wrapped module, or the literal `None` when absent.
impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
    fn content(&self, content: Content) -> Option<Content> {
        let content = if let Some(module) = self {
            content.add_single(module)
        } else {
            content.add_single("None")
        };
        content.optional()
    }
}

impl<T> ModuleDisplay for Option<T> where T: ModuleDisplay {}
impl<T, B> AutodiffModule<B> for Option<T>
where
    T: AutodiffModule<B> + Debug + Send + Clone,
    B: AutodiffBackend,
{
    type InnerModule = Option<T::InnerModule>;

    /// Returns the inner-module form of the wrapped module, or `None` when absent.
    fn valid(&self) -> Self::InnerModule {
        let module = self.as_ref()?;
        Some(module.valid())
    }
}
impl<T, B> Module<B> for Vec<T>
where
    T: Module<B> + Debug + Send + Clone,
    B: Backend,
{
    /// One record per element, in the same order as the modules.
    type Record = Vec<T::Record>;

    /// Sum of the parameter counts of every element.
    fn num_params(&self) -> usize {
        self.iter().map(|module| module.num_params()).sum()
    }

    /// Visits every element in order.
    ///
    /// Each element is scoped under its index with container type `"Vec"`.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        for (index, module) in self.iter().enumerate() {
            let index_str = index.to_string();
            visitor.enter_module(&index_str, "Vec");
            module.visit(visitor);
            visitor.exit_module(&index_str, "Vec");
        }
    }

    /// Maps every element in order.
    ///
    /// Each element is scoped under its index with container type `"Vec"`.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        self.into_iter()
            .enumerate()
            .map(|(index, module)| {
                let index_str = index.to_string();
                mapper.enter_module(&index_str, "Vec");
                let mapped = module.map(mapper);
                mapper.exit_module(&index_str, "Vec");
                mapped
            })
            .collect()
    }

    /// Converts every element into its record, preserving order.
    fn into_record(self) -> Self::Record {
        self.into_iter().map(Module::into_record).collect()
    }

    /// Loads each record into the element at the same position.
    ///
    /// # Panics
    ///
    /// Panics if `record` does not contain exactly one entry per element.
    fn load_record(self, record: Self::Record) -> Self {
        assert_eq!(
            self.len(),
            record.len(),
            "[Load Record Error] The vec record does not have the same length as the module.\n\
             Make sure your module initialization is compatible with the record being loaded.\n",
        );

        self.into_iter()
            .zip(record)
            .map(|(module, record)| module.load_record(record))
            .collect()
    }

    /// Moves every element to `device`.
    fn to_device(self, device: &Device<B>) -> Self {
        self.into_iter()
            .map(|module| module.to_device(device))
            .collect()
    }

    /// Forks every element onto `device`.
    fn fork(self, device: &Device<B>) -> Self {
        self.into_iter().map(|module| module.fork(device)).collect()
    }

    /// Threads `devices` through every element in order, each adding the devices it uses.
    fn collect_devices(&self, devices: Vec<B::Device>) -> Vec<B::Device> {
        self.iter()
            .fold(devices, |acc, module| module.collect_devices(acc))
    }
}
/// Displays each element under its index, with a `Vec<0..len>` header.
impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
    fn content(&self, content: Content) -> Option<Content> {
        let mut content = content;
        for (position, element) in self.iter().enumerate() {
            let name = format!("{position}");
            content = content.add(&name, element);
        }

        let header = format!("Vec<0..{}>", self.len());
        content.set_top_level_type(header.as_str()).optional()
    }
}

impl<T> ModuleDisplay for Vec<T> where T: ModuleDisplay {}
impl<T, B> AutodiffModule<B> for Vec<T>
where
    T: AutodiffModule<B> + Debug + Send + Clone,
    B: AutodiffBackend,
{
    type InnerModule = Vec<T::InnerModule>;

    /// Returns the inner-module form of every element, preserving order.
    fn valid(&self) -> Self::InnerModule {
        let mut inner = Vec::with_capacity(self.len());
        inner.extend(self.iter().map(|module| module.valid()));
        inner
    }
}
impl<const N: usize, T, B> Module<B> for [T; N]
where
    T: Module<B> + Debug + Send + Clone,
    B: Backend,
{
    /// One record per element, in the same order as the modules.
    type Record = [T::Record; N];

    /// Threads `devices` through every element in order, each adding the devices it uses.
    fn collect_devices(&self, devices: Vec<B::Device>) -> Vec<B::Device> {
        self.iter()
            .fold(devices, |acc, module| module.collect_devices(acc))
    }

    /// Sum of the parameter counts of every element.
    fn num_params(&self) -> usize {
        self.iter().map(|module| module.num_params()).sum()
    }

    /// Visits every element in order.
    ///
    /// Each element is scoped under its index with container type `"Array"`.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        for (index, module) in self.iter().enumerate() {
            let index_str = index.to_string();
            visitor.enter_module(&index_str, "Array");
            module.visit(visitor);
            visitor.exit_module(&index_str, "Array");
        }
    }

    /// Maps every element in order.
    ///
    /// Each element is scoped under its index with container type `"Array"`.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        // `self.map` resolves to the inherent `[T; N]::map`, which takes precedence over
        // `Module::map`. It applies the closure to the elements in order, so a running
        // counter gives each element's index without an intermediate `Vec`.
        let mut index = 0usize;
        self.map(|module| {
            let index_str = index.to_string();
            mapper.enter_module(&index_str, "Array");
            let mapped = module.map(mapper);
            mapper.exit_module(&index_str, "Array");
            index += 1;
            mapped
        })
    }

    /// Loads each record into the element at the same position.
    fn load_record(self, record: Self::Record) -> Self {
        // Both arrays have length `N` by type, so the record iterator never runs short.
        let mut records = IntoIterator::into_iter(record);
        self.map(|module| {
            let record = records
                .next()
                .expect("record array has the same length `N` as the module array");
            module.load_record(record)
        })
    }

    /// Converts every element into its record, preserving order.
    fn into_record(self) -> Self::Record {
        self.map(Module::into_record)
    }

    /// Moves every element to `device`.
    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|module| module.to_device(device))
    }

    /// Forks every element onto `device`.
    fn fork(self, device: &Device<B>) -> Self {
        self.map(|module| module.fork(device))
    }
}
/// Displays each element under its index, with a `[0..N]` header.
impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
    fn content(&self, content: Content) -> Option<Content> {
        let header = format!("[0..{}]", self.len());

        let mut filled = content;
        for (slot, element) in self.iter().enumerate() {
            filled = filled.add(&format!("{slot}"), element);
        }

        filled.set_top_level_type(header.as_str()).optional()
    }
}

impl<const N: usize, T> ModuleDisplay for [T; N] where T: ModuleDisplay {}
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
where
    T: AutodiffModule<B> + Debug + Send + Clone,
    T::InnerModule: Debug,
    B: AutodiffBackend,
{
    type InnerModule = [T::InnerModule; N];

    /// Returns the inner-module form of every element, preserving order.
    fn valid(&self) -> Self::InnerModule {
        // `valid` only needs `&T`, so build the result directly from borrowed elements
        // instead of cloning the whole array first.
        core::array::from_fn(|index| self[index].valid())
    }
}
/// Implements `Module`, `AutodiffModule`, `ModuleDisplayDefault` and `ModuleDisplay` for a tuple.
///
/// The first bracket group lists the element type parameters. The second lists the matching
/// tuple indices, positionally. When visited or mapped, each element is scoped under its index
/// (as a string) with container type `"Tuple"`.
macro_rules! impl_module_tuple {
    ([$($l:ident),*][$($i:tt),*]) => {
        impl<B, $($l,)*> Module<B> for ($($l,)*)
        where
            B: Backend,
            $($l: Module<B> + Debug + Send + Clone,)*
        {
            // The trailing comma keeps `Record` a tuple type for every arity (including a
            // single element). This matches `Self`/`InnerModule` and the `record.$i` accesses
            // in `load_record`.
            type Record = ($($l::Record,)*);

            fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
                $(devices = self.$i.collect_devices(devices);)*
                devices
            }

            fn fork(self, device: &Device<B>) -> Self {
                ($(self.$i.fork(device),)*)
            }

            fn to_device(self, device: &Device<B>) -> Self {
                ($(self.$i.to_device(device),)*)
            }

            fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
                $(
                    let index_str = $i.to_string();
                    visitor.enter_module(&index_str, "Tuple");
                    self.$i.visit(visitor);
                    visitor.exit_module(&index_str, "Tuple");
                )*
            }

            fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
                ($(
                    {
                        let index_str = $i.to_string();
                        mapper.enter_module(&index_str, "Tuple");
                        let mapped = self.$i.map(mapper);
                        mapper.exit_module(&index_str, "Tuple");
                        mapped
                    },
                )*)
            }

            fn load_record(self, record: Self::Record) -> Self {
                ($(self.$i.load_record(record.$i),)*)
            }

            fn into_record(self) -> Self::Record {
                ($(self.$i.into_record(),)*)
            }
        }

        impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
        where
            B: AutodiffBackend,
            $($l: AutodiffModule<B> + Debug + Send + Clone,)*
        {
            type InnerModule = ($($l::InnerModule,)*);

            fn valid(&self) -> Self::InnerModule {
                ($(self.$i.valid(),)*)
            }
        }

        impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
        where
            $($l: ModuleDisplay,)*
        {
            fn content(&self, content: Content) -> Option<Content> {
                // Each element is shown under its index. The header is the stringified list
                // of the macro's type-parameter names.
                let content = content
                    $(.add(&format!("{}", $i), &self.$i))*
                    .set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
                content.optional()
            }
        }

        impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
    };
}
// Tuple module support for arities 2 through 10: element type parameters, then their indices.
impl_module_tuple! { [L0, L1][0, 1] }
impl_module_tuple! { [L0, L1, L2][0, 1, 2] }
impl_module_tuple! { [L0, L1, L2, L3][0, 1, 2, 3] }
impl_module_tuple! { [L0, L1, L2, L3, L4][0, 1, 2, 3, 4] }
impl_module_tuple! { [L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5] }
impl_module_tuple! { [L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6] }
impl_module_tuple! { [L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7] }
impl_module_tuple! { [L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8] }
impl_module_tuple! { [L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9] }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    // `Some(42)` has no parameters. `Option::load_record` therefore returns it unchanged,
    // whatever record is supplied.
    #[test]
    fn dont_override_constant_module_when_loading_record() {
        let original = Some(42);
        let saved = Module::<TestBackend>::into_record(original);
        let restored = Module::<TestBackend>::load_record(original, saved);
        assert_eq!(restored, original);
    }

    #[test]
    fn dont_override_constant_module_when_loading_none_record() {
        let original = Some(42);
        let restored = Module::<TestBackend>::load_record(original, None);
        assert_eq!(restored, original);
    }
}