burn_core/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![recursion_limit = "135"]
5
//! The core crate of Burn: configuration, neural-network modules, records,
//! and tensors shared by the rest of the framework.
7
8#[macro_use]
9extern crate derive_new;
10
/// Re-export serde for proc macros.
pub use serde;

/// The configuration module (see [`config::Config`]).
pub mod config;

/// The data module; only available with the `std` feature.
#[cfg(feature = "std")]
pub mod data;

/// The neural-network module system, including the [`module::Module`]
/// trait and [`module::Param`] parameters.
pub mod module;

/// The recorder module, for saving and loading module state.
pub mod record;

/// The tensor module.
pub mod tensor;
// Re-export the tensor type at the crate root so users can write
// `burn::Tensor` instead of `burn::tensor::Tensor`.
pub use tensor::Tensor;

/// Module for vision operations; only available with the `vision` feature.
#[cfg(feature = "vision")]
pub mod vision;

// `alloc` provides heap allocation in `no_std` builds (see the
// `cfg_attr(not(feature = "std"), no_std)` attribute at the top of the crate).
extern crate alloc;
37
38/// Backend for test cases
/// Backend for test cases.
///
/// Defaults to the `NdArray` CPU backend whenever no `test-*` backend
/// feature is enabled; each `test-*` feature below swaps in its own
/// definition instead.
#[cfg(all(
    test,
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;
47
/// Backend for test cases, backed by LibTorch (`test-tch` feature).
// Doc comment placed before the attribute for consistency with the
// default `TestBackend` definition above.
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;
51
/// Backend for test cases, backed by wgpu (`test-wgpu` feature).
// Doc comment placed before the attribute for consistency with the
// default `TestBackend` definition above.
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;
55
/// Backend for test cases, backed by CUDA (`test-cuda` feature).
// Doc comment placed before the attribute for consistency with the
// default `TestBackend` definition above.
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;
59
/// Backend for test cases, backed by ROCm (`test-rocm` feature).
// Doc comment placed before the attribute for consistency with the
// default `TestBackend` definition above.
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;
63
/// Backend for autodiff test cases.
///
/// Wraps whichever `TestBackend` the active feature flags selected, so
/// autodiff tests run on the same backend as the rest of the test suite.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
67
/// Memory checks, only compiled with the `test-memory-checks` feature.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    // Expands to memory-check test cases; see the macro definition in
    // the `burn_fusion` crate for what is verified.
    burn_fusion::memory_checks!();
}
72
#[cfg(test)]
mod test_utils {
    // The `Module` derive expands to paths rooted at `burn::…`.
    use crate as burn;
    use crate::module::{Module, Param};
    use burn_tensor::backend::Backend;
    use burn_tensor::{Distribution, Tensor};

    /// Minimal linear layer used by tests within this crate.
    #[derive(Module, Debug)]
    pub struct SimpleLinear<B: Backend> {
        /// Weight matrix with shape `[out_features, in_features]`.
        pub weight: Param<Tensor<B, 2>>,
        /// Optional bias vector with shape `[out_features]`.
        pub bias: Option<Param<Tensor<B, 1>>>,
    }

    impl<B: Backend> SimpleLinear<B> {
        /// Builds a layer with randomly initialized weight and bias on `device`.
        pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
            let weight = Param::from_tensor(Tensor::random(
                [out_features, in_features],
                Distribution::Default,
                device,
            ));
            let bias = Param::from_tensor(Tensor::random(
                [out_features],
                Distribution::Default,
                device,
            ));

            Self {
                weight,
                bias: Some(bias),
            }
        }
    }
}
104
105pub mod prelude {
106    //! Structs and macros used by most projects. Add `use
107    //! burn::prelude::*` to your code to quickly get started with
108    //! Burn.
109    pub use crate::{
110        config::Config,
111        module::Module,
112        tensor::{
113            Bool, Device, ElementConversion, Float, Int, Shape, SliceArg, Tensor, TensorData,
114            backend::Backend, cast::ToElement, s,
115        },
116    };
117    pub use burn_common::device::Device as DeviceOps;
118}