burn_core/
lib.rs

#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![recursion_limit = "135"]

//! The core crate of Burn.

#[macro_use]
extern crate derive_new;

/// Re-export serde for proc macros.
pub use serde;

/// The configuration module.
pub mod config;

/// Data module.
#[cfg(feature = "std")]
pub mod data;

/// Optimizer module.
pub mod optim;

/// Learning rate scheduler module.
#[cfg(feature = "std")]
pub mod lr_scheduler;

/// Gradient clipping module.
pub mod grad_clipping;

/// The neural network module abstraction.
pub mod module;

/// Neural network module.
pub mod nn;

/// Module for the recorder.
pub mod record;

/// Module for the tensor.
pub mod tensor;

extern crate alloc;

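// The test backend is chosen at compile time through cargo features: when none of the
// `test-*` features is enabled, tests fall back to the ndarray backend below.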
/// Backend for test cases
#[cfg(all(
    test,
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;

#[cfg(all(test, feature = "test-tch"))]
/// Backend for test cases
pub type TestBackend = burn_tch::LibTorch<f32>;

#[cfg(all(test, feature = "test-wgpu"))]
/// Backend for test cases
pub type TestBackend = burn_wgpu::Wgpu;

#[cfg(all(test, feature = "test-cuda"))]
/// Backend for test cases
pub type TestBackend = burn_cuda::Cuda;

/// Backend for autodiff test cases
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;

#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}

/// Type alias for the learning rate.
///
/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it
/// can be used as a constant learning rate.
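///
/// A minimal sketch of using a plain `f64` as a constant schedule; the `step` method shown
/// here is assumed from the scheduler trait and is not verified against its exact signature:
///
/// ```ignore
/// use burn_core::{LearningRate, lr_scheduler::LrScheduler};
///
/// // A bare `f64` acts as a constant schedule: every step yields the same value.
/// let mut lr: LearningRate = 1.0e-3;
/// assert_eq!(lr.step(), 1.0e-3);
/// assert_eq!(lr.step(), 1.0e-3);
/// ```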
pub type LearningRate = f64; // We could potentially change the type.

pub mod prelude {
    //! Structs and macros used by most projects. Add `use burn::prelude::*` to your code
    //! to quickly get started with Burn.
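    //!
    //! A minimal sketch of what the prelude brings into scope, assuming the `burn` facade
    //! crate re-exports this prelude as the note above suggests:
    //!
    //! ```ignore
    //! use burn::prelude::*;
    //!
    //! // Build a 2x3 tensor of zeros on the given device, generic over the backend.
    //! fn zeros<B: Backend>(device: &Device<B>) -> Tensor<B, 2> {
    //!     Tensor::zeros([2, 3], device)
    //! }
    //! ```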
    pub use crate::{
        config::Config,
        module::Module,
        nn,
        tensor::{
            Bool, Device, ElementConversion, Float, Int, Shape, Tensor, TensorData,
            backend::Backend, s,
        },
    };
}