1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![recursion_limit = "135"]
5
// `#[macro_use]` makes the macros exported by `derive_new` usable throughout this crate.
#[macro_use]
extern crate derive_new;

/// Re-export of the `serde` crate.
pub use serde;

/// The `config` module; provides `Config` (re-exported by the prelude).
pub mod config;

/// The `data` module; only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
pub mod data;

/// The `module` module; provides `Module` and `Param`.
pub mod module;

/// The `record` module.
pub mod record;

/// The `tensor` module; provides `Tensor`, `Shape`, `TensorData`, the `backend` submodule and more.
pub mod tensor;
/// Re-export of `tensor::Tensor` at the crate root.
pub use tensor::Tensor;

/// The `vision` module; only compiled when the `vision` feature is enabled.
#[cfg(feature = "vision")]
pub mod vision;

// The crate is `no_std` when the `std` feature is off, so `alloc` is linked explicitly.
extern crate alloc;
37
/// Backend used by this crate's tests when none of the `test-tch`, `test-wgpu`,
/// `test-cuda` or `test-rocm` features is enabled: `burn_ndarray::NdArray<f32>`.
///
/// Only this fallback excludes the other features. The feature-specific aliases
/// are gated on their own feature alone, so enabling two of those features in the
/// same test build defines `TestBackend` more than once.
#[cfg(all(
    test,
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;
47
/// Backend used by this crate's tests when the `test-tch` feature is enabled:
/// `burn_tch::LibTorch<f32>`.
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;
51
/// Backend used by this crate's tests when the `test-wgpu` feature is enabled:
/// `burn_wgpu::Wgpu` with its default type parameters.
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;
55
/// Backend used by this crate's tests when the `test-cuda` feature is enabled:
/// `burn_cuda::Cuda` with its default type parameters.
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;
59
/// Backend used by this crate's tests when the `test-rocm` feature is enabled:
/// `burn_rocm::Rocm` with its default type parameters.
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;
63
/// Test-only alias that wraps whichever `TestBackend` is selected in
/// `burn_autodiff::Autodiff`.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
67
// Only built for test runs with the `test-memory-checks` feature enabled.
// The module body is produced entirely by the `burn_fusion::memory_checks!` macro.
// NOTE(review): presumably this expands to fusion memory-leak checks — confirm against burn_fusion.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}
72
#[cfg(test)]
mod test_utils {
    // NOTE(review): presumably the `Module` derive expands to paths rooted at `burn`,
    // which is why this crate is aliased under that name — confirm against the derive.
    use crate as burn;
    use crate::module::{Module, Param};
    use burn_tensor::{Distribution, Tensor, backend::Backend};

    /// Small test module holding a rank-2 `weight` parameter and an optional
    /// rank-1 `bias` parameter.
    #[derive(Module, Debug)]
    pub struct SimpleLinear<B: Backend> {
        /// Weight parameter; `new` creates it with shape `[out_features, in_features]`.
        pub weight: Param<Tensor<B, 2>>,
        /// Optional bias parameter; `new` always fills it with shape `[out_features]`.
        pub bias: Option<Param<Tensor<B, 1>>>,
    }

    impl<B: Backend> SimpleLinear<B> {
        /// Creates a `SimpleLinear` on `device` whose weight and bias are both
        /// sampled from `Distribution::Default`.
        pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
            // Both tensors are sampled first (weight, then bias) and only afterwards
            // wrapped into parameters, in that same order.
            let weight_values: Tensor<B, 2> =
                Tensor::random([out_features, in_features], Distribution::Default, device);
            let bias_values: Tensor<B, 1> =
                Tensor::random([out_features], Distribution::Default, device);

            let weight = Param::from_tensor(weight_values);
            let bias = Some(Param::from_tensor(bias_values));

            Self { weight, bias }
        }
    }
}
104
105pub mod prelude {
106 pub use crate::{
110 config::Config,
111 module::Module,
112 tensor::{
113 Bool, Device, ElementConversion, Float, Int, Shape, SliceArg, Tensor, TensorData,
114 backend::Backend, cast::ToElement, s,
115 },
116 };
117 pub use burn_common::device::Device as DeviceOps;
118}