use alloc::vec::Vec;
use super::ParamId;
use crate::{
    record::Record,
    tensor::backend::{ADBackend, Backend},
};
pub use burn_derive::Module;
use burn_tensor::Tensor;
// For now, this macro stays internal while we keep experimenting with it.
// It may be made public in the future.
macro_rules! module {
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
                let func = $item;
                func(tensor)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    (map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
        struct Mapper<'a, B: Backend> {
            capture: &'a $ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
            fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
                let func = $item;
                func(tensor, self.capture)
            }
        }
        let mut mapper = Mapper {
            capture: $capture,
            backend: core::marker::PhantomData::default(),
        };
        $module.map(&mut mapper)
    }};
    (visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
                let func = $item;
                func(tensor, &mut self.state)
            }
        }
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData::default(),
        };
        $module.visit(&mut visitor);
        state
    }};
}
/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
/// This will make your module trainable, savable and loadable via
/// [into_record](Module::into_record) and [load_record](Module::load_record).
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
/// necessary to optimize and train the module on any backend.
///
/// ```rust
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
///     nn,
///     module::Module,
///     tensor::Tensor,
///     tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
///   my_param: nn::Linear<B>,
///   my_other_field: usize,
/// }
/// ```
pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
    /// Type to save and load the module.
    type Record: Record;
    /// Get the device list of the module and all of its sub-modules.
    fn devices(&self) -> Vec<B::Device> {
        module!(
            visit = self,
            ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
                let device = tensor.device();
                if !state.contains(&device) {
                    state.push(device);
                }
            },
            state = Vec<B::Device>,
            init = Vec::new
        )
    }
    /// Fork the module and all of its sub-modules to the given device.
    ///
    /// # Notes
    ///
    /// This is similar to [to_device](Module::to_device), but it ensures the module will
    /// have its own autodiff graph.
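    ///
    /// # Example
    ///
    /// A minimal usage sketch; the `transfer` helper is illustrative, not part of the API.
    ///
    /// ```rust
    /// // Not necessary when using the burn crate directly.
    /// use burn_core as burn;
    ///
    /// use burn::{module::Module, tensor::backend::Backend};
    ///
    /// // Returns a copy of the module on `device` with its own autodiff graph.
    /// fn transfer<B: Backend, M: Module<B>>(module: M, device: &B::Device) -> M {
    ///     module.fork(device)
    /// }
    /// ```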
    fn fork(self, device: &B::Device) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>, device: &B::Device| {
                let is_require_grad = tensor.is_require_grad();
                let mut tensor = tensor.to_device(device).detach();
                if is_require_grad {
                    tensor = tensor.require_grad();
                }
                tensor
            },
            capture = { device: B::Device }
        )
    }
    /// Move the module and all of its sub-modules to the given device.
    ///
    /// # Warnings
    ///
    /// The device operations will be registered in the autodiff graph. Therefore, be sure to call
    /// backward only once, even if you have the same module on multiple devices. If you need to
    /// call backward multiple times, look into using [fork](Module::fork) instead.
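    ///
    /// # Example
    ///
    /// A minimal usage sketch; the `prepare` helper is illustrative, not part of the API.
    ///
    /// ```rust
    /// // Not necessary when using the burn crate directly.
    /// use burn_core as burn;
    ///
    /// use burn::{module::Module, tensor::backend::Backend};
    ///
    /// // Moves the module to the target device; the transfer is recorded in the autodiff graph.
    /// fn prepare<B: Backend, M: Module<B>>(module: M, device: &B::Device) -> M {
    ///     module.to_device(device)
    /// }
    /// ```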
    fn to_device(self, device: &B::Device) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
            capture = { device: B::Device }
        )
    }
    /// Mark each tensor in the module tree as not requiring gradients.
    ///
    /// # Warnings
    ///
    /// This should not be used for inference; use [valid](ADModule::valid) when using
    /// AD modules. This is mostly useful for partial finetuning, where only a small fraction of
    /// the parameters is updated instead of all of them.
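    ///
    /// # Example
    ///
    /// A minimal usage sketch; the `freeze` helper is illustrative, not part of the API.
    ///
    /// ```rust
    /// // Not necessary when using the burn crate directly.
    /// use burn_core as burn;
    ///
    /// use burn::{module::Module, tensor::backend::Backend};
    ///
    /// // Marks every tensor in the module as not requiring gradients,
    /// // e.g. to keep a sub-module frozen during partial finetuning.
    /// fn freeze<B: Backend, M: Module<B>>(module: M) -> M {
    ///     module.no_grad()
    /// }
    /// ```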
    fn no_grad(self) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
        )
    }
    /// Get the number of parameters the module has, including all of its sub-modules.
    fn num_params(&self) -> usize {
        module!(
            visit = self,
            ops = |tensor: &Tensor<B, D>, state: &mut usize| {
                *state += tensor.shape().num_elements();
            },
            state = usize,
            init = || 0
        )
    }
    /// Visit each tensor in the module with a [visitor](ModuleVisitor).
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
    /// Map each tensor in the module with a [mapper](ModuleMapper).
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
    /// Load the module state from a record.
    fn load_record(self, record: Self::Record) -> Self;
    /// Convert the module into a record containing the state.
    fn into_record(self) -> Self::Record;
}
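/// Visitor trait called by [visit](Module::visit) for each tensor in a module tree.
///
/// Below is a minimal sketch of a visitor that counts parameters, mirroring what
/// [num_params](Module::num_params) does internally. The `ParamCounter` name is
/// purely illustrative.
///
/// ```rust
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
///     module::{ModuleVisitor, ParamId},
///     tensor::{backend::Backend, Tensor},
/// };
///
/// struct ParamCounter {
///     count: usize,
/// }
///
/// impl<B: Backend> ModuleVisitor<B> for ParamCounter {
///     fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
///         self.count += tensor.shape().num_elements();
///     }
/// }
/// ```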
pub trait ModuleVisitor<B: Backend> {
    fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
}
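/// Mapper trait called by [map](Module::map) for each tensor in a module tree; the returned
/// tensor replaces the visited one.
///
/// Below is a minimal sketch of a mapper that detaches every tensor from the autodiff graph.
/// The `Detacher` name is purely illustrative.
///
/// ```rust
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
///     module::{ModuleMapper, ParamId},
///     tensor::{backend::Backend, Tensor},
/// };
///
/// struct Detacher;
///
/// impl<B: Backend> ModuleMapper<B> for Detacher {
///     fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
///         tensor.detach()
///     }
/// }
/// ```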
pub trait ModuleMapper<B: Backend> {
    fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
}
/// Module with auto-differentiation backend.
pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
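    /// Inner module without auto-differentiation.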
    type InnerModule: Module<B::InnerBackend>;
    /// Get the same module, but on the inner backend without auto-differentiation.
    fn valid(&self) -> Self::InnerModule;
}