Skip to main content

burn_derive/
lib.rs

1#![warn(missing_docs)]
2
3//! The derive crate of Burn.
4
5#[macro_use]
6extern crate derive_new;
7
8use proc_macro::TokenStream;
9
10pub(crate) mod config;
11pub(crate) mod module;
12pub(crate) mod record;
13pub(crate) mod shared;
14
15/// Derive macro for the `Module` trait.
16///
17/// # Sub-modules
18///
19/// By default, the macro automatically detects sub-modules and parameters as module types.
20///
21/// Any field not recognized as a module type is assumed to be a non-module
22/// and is skipped by the module system (not persistent, not visited).
23///
24/// ## Generics
25///
26/// Generic type parameters (e.g., `field: M`) are assumed to be sub-modules by default.
27/// If a generic field represents some other runtime state or configuration, you can use
28/// the `#[module(skip)]` attribute to provide a hint.
29///
30/// # Field Attributes
31///
32/// ## `#[module(skip)]`
33///
34/// Explicitly marks a field to be ignored by the module derive.
35///
36/// Skipped fields are not parameters, not modules, and are not persistent.
37/// This is equivalent to the deprecated `Ignored<T>` wrapper.
38///
39/// ### Requirements
40///
41/// The field must implement: `Debug + Clone + Send`.
42///
43/// # Example
44///
45/// ```ignore
46/// #[derive(Module, Debug)]
47/// pub struct MyModule<B: Backend, M, N: NonModuleTrait> {
48///     /// A normal parameter.
49///     weights: Param<Tensor<B, 2>>,
50///     /// A field configured at runtime.
51///     dropout_prob: f64,
52///     /// A field that is recomputed at runtime.
53///     cached_mask: Option<Tensor<B, 2>>,
54///     /// A field that contains some debug state.
55///     debug_state: String,
56///     /// Treated as a module (default for generics).
57///     inner: M,
58///     /// Hint required: this generic is NOT a module.
59///     #[module(skip)]
60///     other: N,
61/// }
62/// ```
63#[proc_macro_derive(Module, attributes(module))]
64pub fn module_derive(input: TokenStream) -> TokenStream {
65    let input = syn::parse(input).unwrap();
66    module::derive_impl(&input)
67}
68
69/// Derive macro for the record.
70#[proc_macro_derive(Record)]
71pub fn record_derive(input: TokenStream) -> TokenStream {
72    let input = syn::parse(input).unwrap();
73    record::derive_impl(&input)
74}
75
76/// Derive macro for the config.
77#[proc_macro_derive(Config, attributes(config))]
78pub fn config_derive(input: TokenStream) -> TokenStream {
79    let item = syn::parse(input).unwrap();
80    config::derive_impl(&item)
81}