1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![recursion_limit = "135"]
5
6#[macro_use]
9extern crate derive_new;
10
11pub use serde;
13
14pub mod config;
16
17#[cfg(feature = "std")]
19pub mod data;
20
21pub mod module;
23
24pub mod record;
26
27pub mod tensor;
29pub use tensor::Tensor;
31
32extern crate alloc;
33
/// Backend used by unit tests when no dedicated test-backend feature is
/// enabled (the default fallback).
#[cfg(all(
    test,
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm")
))]
pub type TestBackend = burn_flex::Flex;

/// Backend used by unit tests when the `test-tch` feature is enabled.
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Backend used by unit tests when the `test-wgpu` feature is enabled.
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Backend used by unit tests when the `test-cuda` feature is enabled.
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

/// Backend used by unit tests when the `test-rocm` feature is enabled.
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;

/// Autodiff-wrapped [`TestBackend`], for tests that need gradients.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
63
// Memory-leak checks provided by the fusion crate; compiled only when the
// `test-memory-checks` feature is explicitly requested, since they add
// overhead to the test run.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}
68
#[cfg(test)]
mod test_utils {
    //! Shared fixtures for this crate's unit tests.

    // The `Module` derive expects the crate to be reachable as `burn`.
    use crate as burn;
    use crate::module::Module;
    use crate::module::Param;
    use burn_tensor::Tensor;
    use burn_tensor::backend::Backend;

    /// Minimal linear-layer fixture used to exercise module/record machinery.
    #[derive(Module, Debug)]
    pub struct SimpleLinear<B: Backend> {
        // Weight matrix, created as `[out_features, in_features]` in `new`.
        pub weight: Param<Tensor<B, 2>>,
        // Optional bias vector of length `out_features`; `new` always sets it.
        pub bias: Option<Param<Tensor<B, 1>>>,
    }

    impl<B: Backend> SimpleLinear<B> {
        /// Creates a layer with weight `[out_features, in_features]` and bias
        /// `[out_features]`, both randomly initialized with the backend's
        /// default distribution on `device`.
        pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
            let weight = Tensor::random(
                [out_features, in_features],
                burn_tensor::Distribution::Default,
                device,
            );
            let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device);

            Self {
                weight: Param::from_tensor(weight),
                bias: Some(Param::from_tensor(bias)),
            }
        }
    }
}
100
101pub mod prelude {
102 pub use crate::{
106 config::Config,
107 module::Module,
108 tensor::{
109 Bool, Device, ElementConversion, Float, Int, Shape, SliceArg, Tensor, TensorData,
110 backend::Backend, cast::ToElement, s,
111 },
112 };
113 pub use burn_std::device::Device as DeviceOps;
114}