use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor};
use alloc::vec::Vec;
use burn_tensor::backend::{ADBackend, Backend};
use core::fmt::Debug;
impl<T, B> Module<B> for Option<T>
where
    T: Module<B> + Debug + Send + Sync + Clone,
    B: Backend,
{
    type Record = Option<T::Record>;
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        if let Some(module) = self {
            module.visit(visitor)
        }
    }
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        self.map(|module| module.map(mapper))
    }
    fn load_record(self, record: Self::Record) -> Self {
        self.zip(record)
            .map(|(module, record)| module.load_record(record))
    }
    fn into_record(self) -> Self::Record {
        self.map(Module::into_record)
    }
}
/// An optional autodiff module; its inner module is the optional inner
/// module of the wrapped value.
impl<T, B> ADModule<B> for Option<T>
where
    T: ADModule<B> + Debug + Send + Sync + Clone,
    B: ADBackend,
{
    type InnerModule = Option<T::InnerModule>;

    /// Returns `valid()` of the wrapped module, or `None` when there is none.
    fn valid(&self) -> Self::InnerModule {
        let borrowed: Option<&T> = self.as_ref();
        borrowed.map(|inner| inner.valid())
    }
}
/// A growable list of modules: every operation is applied to each element in
/// order and the per-element results are gathered back into a `Vec`.
impl<T, B> Module<B> for Vec<T>
where
    T: Module<B> + Debug + Send + Sync + Clone,
    B: Backend,
{
    type Record = Vec<T::Record>;

    /// Returns the sum of `num_params()` over all elements.
    fn num_params(&self) -> usize {
        self.iter().map(|module| module.num_params()).sum()
    }

    /// Visits every element in order.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        for module in self {
            module.visit(visitor);
        }
    }

    /// Applies `mapper` to every element in order.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        self.into_iter().map(|module| module.map(mapper)).collect()
    }

    /// Converts every element into its record, preserving order.
    fn into_record(self) -> Self::Record {
        self.into_iter().map(Module::into_record).collect()
    }

    /// Loads the i-th record into the i-th module.
    ///
    /// # Panics
    ///
    /// Panics if `record` does not contain exactly one entry per module.
    /// Without this check `zip` would stop at the shorter side and silently
    /// drop the remaining modules (or records).
    fn load_record(self, record: Self::Record) -> Self {
        assert_eq!(
            self.len(),
            record.len(),
            "[Load Record] the number of modules does not match the number of records"
        );

        self.into_iter()
            .zip(record)
            .map(|(module, record)| module.load_record(record))
            .collect()
    }
}
/// A list of autodiff modules; its inner module is the list of the
/// elements' inner modules.
impl<T, B> ADModule<B> for Vec<T>
where
    T: ADModule<B> + Debug + Send + Sync + Clone,
    B: ADBackend,
{
    type InnerModule = Vec<T::InnerModule>;

    /// Calls `valid()` on every element and gathers the results in order.
    fn valid(&self) -> Self::InnerModule {
        let mut inner_modules = Vec::with_capacity(self.len());
        for module in self {
            inner_modules.push(module.valid());
        }
        inner_modules
    }
}
/// A fixed-size array of modules: every operation is applied to each element
/// in order and the per-element results are gathered back into an array of
/// the same length.
impl<const N: usize, T, B> Module<B> for [T; N]
where
    T: Module<B> + Debug + Send + Sync + Clone + Copy,
    T::Record: Debug,
    B: Backend,
{
    type Record = [T::Record; N];

    /// Concatenates the device lists reported by every element, in order.
    fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
        self.iter().flat_map(|module| module.devices()).collect()
    }

    /// Returns the sum of `num_params()` over all elements.
    fn num_params(&self) -> usize {
        self.iter().map(|module| module.num_params()).sum()
    }

    /// Visits every element in order.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        for module in self {
            module.visit(visitor);
        }
    }

    /// Applies `mapper` to every element in order.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        // Inherent array `map`, not a recursive call to `Module::map`.
        self.map(|module| module.map(mapper))
    }

    /// Loads the i-th record into the i-th module.
    ///
    /// Both arrays have length `N` by construction, so the pairing is done
    /// directly with the inherent array `map`, without an intermediate heap
    /// allocation or a fallible `Vec` -> array conversion.
    fn load_record(self, record: Self::Record) -> Self {
        // `IntoIterator::into_iter` is spelled out so the record array is
        // iterated by value regardless of edition.
        let mut records = IntoIterator::into_iter(record);
        self.map(|module| {
            let record = records
                .next()
                .expect("record array has the same length as the module array");
            module.load_record(record)
        })
    }

    /// Converts every element into its record, preserving order.
    fn into_record(self) -> Self::Record {
        self.map(Module::into_record)
    }
}
/// A fixed-size array of autodiff modules; its inner module is the array of
/// the elements' inner modules.
impl<const N: usize, T, B> ADModule<B> for [T; N]
where
    T: ADModule<B> + Debug + Send + Sync + Clone + Copy,
    T::InnerModule: Copy + Debug,
    <T::InnerModule as Module<B::InnerBackend>>::Record: Debug,
    <T as Module<B>>::Record: Debug,
    B: ADBackend,
{
    type InnerModule = [T::InnerModule; N];

    /// Calls `valid()` on every element and gathers the results in order.
    fn valid(&self) -> Self::InnerModule {
        // `T: Copy` makes the array `Copy`, so it can be copied out of the
        // borrow and consumed by the inherent array `map`.
        let modules: [T; N] = *self;
        modules.map(|inner| inner.valid())
    }
}