use super::{ParamId, Quantizer};
use crate::{
record::Record,
tensor::backend::{AutodiffBackend, Backend},
};
use alloc::vec::Vec;
pub use burn_derive::Module;
use burn_tensor::{quantization::Calibration, Bool, Int, Tensor};
/// Type alias for `Vec<B::Device>`, which supports `no_std` environments by automatically
/// using the `alloc` crate.
pub type Devices<B> = Vec<<B as Backend>::Device>;
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
macro_rules! module {
(map=$module:ident, ops=$item:expr) => {{
struct Mapper;
impl<B: Backend> ModuleMapper<B> for Mapper {
fn map_float<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D>,
) -> Tensor<B, D> {
let func = $item;
func(tensor)
}
}
let mut mapper = Mapper;
$module.map(&mut mapper)
}};
(visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
struct Visitor<'a, B: Backend> {
state: &'a mut $state_ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
let func = $item;
func(tensor, &mut self.state)
}
}
#[allow(clippy::redundant_closure_call)]
let mut state = $init();
let mut visitor = Visitor {
state: &mut state,
backend: core::marker::PhantomData,
};
$module.visit(&mut visitor);
state
}};
}
/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
/// This will make your module trainable, savable and loadable via
/// [into_record](Module::into_record) and [load_record](Module::load_record).
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
/// parameter `B`. This will be used by the [derive](burn_derive::Module) attribute to generate
/// the code necessary to optimize and train the module on any backend.
///
/// ```no_run
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
/// nn,
/// module::Module,
/// tensor::Tensor,
/// tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
/// my_param: nn::Linear<B>,
/// my_other_field: usize,
/// }
/// ```
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
/// Type to save and load the module.
type Record: Record<B>;
    /// Collect all the devices found in the module tree, appending them to the given vector
    /// without duplicates.
fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
    /// Return all the devices found in the module tree without duplicates.
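    ///
    /// # Example
    ///
    /// A minimal sketch; each device used by the module's parameters appears exactly once:
    ///
    /// ```ignore
    /// let devices = model.devices();
    /// ```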
fn devices(&self) -> Devices<B> {
self.collect_devices(Devices::<B>::new())
}
/// Fork the module and all of its sub-modules to the given device.
///
/// # Notes
///
/// This is similar to [to_device](Module::to_device), but it ensures the output module on the
/// new device will have its own autodiff graph.
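    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `device` is the target device:
    ///
    /// ```ignore
    /// // The forked module lives on `device` with its own autodiff graph, so it can be
    /// // trained independently of the original module.
    /// let model = model.fork(&device);
    /// ```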
fn fork(self, device: &B::Device) -> Self;
/// Move the module and all of its sub-modules to the given device.
///
/// # Warnings
///
    /// The operation supports autodiff, and the device transfer will be registered in the
    /// autodiff graph when autodiff is active. However, this may not be what you want: the
    /// output module will be an intermediary module, meaning that you can't optimize it with
    /// gradient descent. If you want to optimize the output module on the target device, use
    /// [fork](Module::fork) instead.
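    ///
    /// # Example
    ///
    /// A minimal sketch; prefer [fork](Module::fork) if you intend to train on `device`:
    ///
    /// ```ignore
    /// let model = model.to_device(&device);
    /// ```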
fn to_device(self, device: &B::Device) -> Self;
    /// Mark each tensor in the module tree as not requiring gradients.
///
/// # Warnings
///
    /// This should not be used for inference; use [valid](AutodiffModule::valid) when working
    /// with AD modules. This is mostly useful for partial fine-tuning, i.e., updating only a
    /// small fraction of the parameters instead of fine-tuning all of them.
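    ///
    /// # Example
    ///
    /// A minimal sketch of partial fine-tuning; `encoder` and `head` are hypothetical
    /// sub-modules of the model:
    ///
    /// ```ignore
    /// // Freeze the encoder while keeping the head trainable.
    /// model.encoder = model.encoder.no_grad();
    /// ```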
fn no_grad(self) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
)
}
/// Get the number of parameters the module has, including all of its sub-modules.
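    ///
    /// # Example
    ///
    /// A minimal sketch; counts the elements of every float tensor parameter:
    ///
    /// ```ignore
    /// println!("parameters: {}", model.num_params());
    /// ```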
fn num_params(&self) -> usize {
module!(
visit_float = self,
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
*state += tensor.shape().num_elements();
},
state = usize,
init = || 0
)
}
/// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);
/// Map each tensor parameter in the module with a [mapper](ModuleMapper).
fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;
/// Load the module state from a record.
fn load_record(self, record: Self::Record) -> Self;
/// Convert the module into a record containing the state.
fn into_record(self) -> Self::Record;
#[cfg(feature = "std")]
/// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
///
/// List of supported file recorders:
///
/// * [default](crate::record::DefaultFileRecorder)
/// * [bincode](crate::record::BinFileRecorder)
/// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
/// * [json pretty](crate::record::PrettyJsonFileRecorder)
/// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
/// * [named mpk](crate::record::NamedMpkFileRecorder)
/// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
///
/// ## Notes
///
    /// The file extension is automatically added depending on the file recorder provided; you
    /// don't have to specify it.
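    ///
    /// ## Example
    ///
    /// A minimal sketch using the named MessagePack recorder with full precision:
    ///
    /// ```ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// model.save_file("/path/to/model", &recorder)?;
    /// ```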
fn save_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
) -> Result<(), crate::record::RecorderError>
where
FR: crate::record::FileRecorder<B>,
PB: Into<std::path::PathBuf>,
{
let record = Self::into_record(self);
recorder.record(record, file_path.into())
}
#[cfg(feature = "std")]
/// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
///
/// The recorder should be the same as the one used to save the module, see
/// [save_file](Self::save_file).
///
/// ## Notes
///
    /// The file extension is automatically added depending on the file recorder provided; you
    /// don't have to specify it.
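    ///
    /// ## Example
    ///
    /// A minimal sketch using the same recorder as in [save_file](Self::save_file):
    ///
    /// ```ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// let model = model.load_file("/path/to/model", &recorder, &device)?;
    /// ```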
fn load_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
device: &B::Device,
) -> Result<Self, crate::record::RecorderError>
where
FR: crate::record::FileRecorder<B>,
PB: Into<std::path::PathBuf>,
{
let record = recorder.load(file_path.into(), device)?;
Ok(self.load_record(record))
}
/// Quantize the weights of the module.
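    ///
    /// # Example
    ///
    /// A minimal sketch; constructing the [Quantizer] (calibration and quantization scheme)
    /// is omitted here:
    ///
    /// ```ignore
    /// let model = model.quantize_weights(&mut quantizer);
    /// ```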
fn quantize_weights<C: Calibration>(self, quantizer: &mut Quantizer<C>) -> Self {
self.map(quantizer)
}
}
/// Module visitor trait.
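///
/// # Example
///
/// A minimal sketch of a visitor that counts float parameter elements, mirroring what
/// [num_params](Module::num_params) does internally:
///
/// ```ignore
/// struct ParamCounter {
///     count: usize,
/// }
///
/// impl<B: Backend> ModuleVisitor<B> for ParamCounter {
///     fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
///         self.count += tensor.shape().num_elements();
///     }
/// }
///
/// let mut counter = ParamCounter { count: 0 };
/// model.visit(&mut counter);
/// ```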
pub trait ModuleVisitor<B: Backend> {
/// Visit a float tensor in the module.
fn visit_float<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {}
/// Visit an int tensor in the module.
fn visit_int<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Int>) {}
/// Visit a bool tensor in the module.
fn visit_bool<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Bool>) {}
}
/// Module mapper trait.
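///
/// # Example
///
/// A minimal sketch of a mapper that freezes every float parameter, mirroring what
/// [no_grad](Module::no_grad) does internally:
///
/// ```ignore
/// struct Freezer;
///
/// impl<B: Backend> ModuleMapper<B> for Freezer {
///     fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
///         tensor.set_require_grad(false)
///     }
/// }
///
/// let model = model.map(&mut Freezer);
/// ```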
pub trait ModuleMapper<B: Backend> {
/// Map a float tensor in the module.
fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor
}
/// Map an int tensor in the module.
fn map_int<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D, Int>,
) -> Tensor<B, D, Int> {
tensor
}
/// Map a bool tensor in the module.
fn map_bool<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D, Bool>,
) -> Tensor<B, D, Bool> {
tensor
}
}
/// Module with auto-differentiation backend.
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
/// Inner module without auto-differentiation.
type InnerModule: Module<B::InnerBackend>;
/// Get the same module, but on the inner backend without auto-differentiation.
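    ///
    /// # Example
    ///
    /// A minimal sketch; `forward` is a hypothetical method on the user-defined module:
    ///
    /// ```ignore
    /// // Validation runs on the inner backend, without gradient tracking.
    /// let output = model.valid().forward(input);
    /// ```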
fn valid(&self) -> Self::InnerModule;
}