Skip to main content

burn_core/
lib.rs

// Build as `no_std` unless the `std` feature is enabled.
#![cfg_attr(not(feature = "std"), no_std)]
// Warn on any public item that lacks documentation.
#![warn(missing_docs)]
// When docs are built with `--cfg docsrs`, enable the nightly rustdoc `doc_cfg` feature.
#![cfg_attr(docsrs, feature(doc_cfg))]
// Raised recursion limit. NOTE(review): presumably needed by deep macro expansion or
// trait resolution somewhere in the crate — confirm before changing this value.
#![recursion_limit = "135"]

//! The core crate of Burn.
//!
//! Exposes the `config`, `module`, `record` and `tensor` modules (plus `data` when the
//! `std` feature is enabled), re-exports [`Tensor`] at the crate root, and provides a
//! [`prelude`] of commonly used items.
7
8#[macro_use]
9extern crate derive_new;
10
/// Re-export of `serde` so that code generated by proc macros can refer to it through
/// this crate.
pub use serde;

/// Configuration support; home of [`config::Config`].
pub mod config;

/// Data module. Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
pub mod data;

/// Neural network modules; home of the [`module::Module`] trait.
pub mod module;

/// Module containing the recorder.
pub mod record;

/// Tensor module; home of [`Tensor`], which is also re-exported at the crate root.
pub mod tensor;
// Tensor at root: `burn::Tensor`
pub use tensor::Tensor;
31
32extern crate alloc;
33
// The crate's own tests run against exactly one backend: Flex by default, or the backend
// selected by one of the `test-*` features below. Those features are mutually exclusive.
// Enabling two of them would define `TestBackend` more than once, so that configuration
// is rejected up front with a readable message instead of a duplicate-definition error.
#[cfg(all(
    test,
    any(
        all(feature = "test-tch", feature = "test-wgpu"),
        all(feature = "test-tch", feature = "test-cuda"),
        all(feature = "test-tch", feature = "test-rocm"),
        all(feature = "test-wgpu", feature = "test-cuda"),
        all(feature = "test-wgpu", feature = "test-rocm"),
        all(feature = "test-cuda", feature = "test-rocm"),
    )
))]
compile_error!(
    "at most one of the `test-tch`, `test-wgpu`, `test-cuda` and `test-rocm` features may be enabled"
);

/// Backend for test cases (default: Flex, used when no `test-*` backend feature is enabled).
#[cfg(all(
    test,
    not(any(
        feature = "test-tch",
        feature = "test-wgpu",
        feature = "test-cuda",
        feature = "test-rocm"
    ))
))]
pub type TestBackend = burn_flex::Flex;

/// Backend for test cases (LibTorch with `f32`, selected by the `test-tch` feature).
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Backend for test cases (Wgpu, selected by the `test-wgpu` feature).
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Backend for test cases (Cuda, selected by the `test-cuda` feature).
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

/// Backend for test cases (Rocm, selected by the `test-rocm` feature).
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;

/// Backend for autodiff test cases: [`TestBackend`] wrapped in `burn_autodiff::Autodiff`.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
63
// Memory-check test module, compiled only in test builds with the `test-memory-checks`
// feature enabled.
// NOTE(review): its contents are generated entirely by `burn_fusion::memory_checks!`;
// see that macro for what is actually checked.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}
68
/// Fixtures shared by this crate's unit tests.
#[cfg(test)]
mod test_utils {
    // Alias this crate as `burn`. NOTE(review): presumably required so that paths emitted
    // by `#[derive(Module)]` (`burn::...`) resolve inside this crate — confirm.
    use crate as burn;
    use crate::module::{Module, Param};
    use burn_tensor::backend::Backend;
    use burn_tensor::{Distribution, Tensor};

    /// Minimal linear layer fixture: an `[out_features, in_features]` weight matrix and an
    /// optional `[out_features]` bias vector.
    #[derive(Module, Debug)]
    pub struct SimpleLinear<B: Backend> {
        // Weight matrix, shaped `[out_features, in_features]` by `new`.
        pub weight: Param<Tensor<B, 2>>,
        // Bias vector, shaped `[out_features]` by `new`; always `Some` when built by `new`.
        pub bias: Option<Param<Tensor<B, 1>>>,
    }

    impl<B: Backend> SimpleLinear<B> {
        /// Builds a layer with `in_features` inputs and `out_features` outputs on `device`.
        ///
        /// Both the weight and the bias are sampled with `Distribution::Default`, weight
        /// first, then bias.
        pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
            let weight = Param::from_tensor(Tensor::random(
                [out_features, in_features],
                Distribution::Default,
                device,
            ));
            let bias = Param::from_tensor(Tensor::random(
                [out_features],
                Distribution::Default,
                device,
            ));

            Self {
                weight,
                bias: Some(bias),
            }
        }
    }
}
100
101pub mod prelude {
102    //! Structs and macros used by most projects. Add `use
103    //! burn::prelude::*` to your code to quickly get started with
104    //! Burn.
105    pub use crate::{
106        config::Config,
107        module::Module,
108        tensor::{
109            Bool, Device, ElementConversion, Float, Int, Shape, SliceArg, Tensor, TensorData,
110            backend::Backend, cast::ToElement, s,
111        },
112    };
113    pub use burn_std::device::Device as DeviceOps;
114}