// burn_nn — crate root (lib.rs)

// Compile without the standard library unless the `std` feature is enabled.
#![cfg_attr(not(feature = "std"), no_std)]
// Public items must be documented.
#![warn(missing_docs)]
// Enable `doc_cfg` annotations when building docs on docs.rs.
#![cfg_attr(docsrs, feature(doc_cfg))]
// Raised recursion limit for deeply nested macro/type expansions.
#![recursion_limit = "256"]
5
6//! Burn neural network module.
7
8/// Loss module
9pub mod loss;
10
11/// Neural network modules implementations.
12pub mod modules;
13pub use modules::*;
14
15pub mod activation;
16pub use activation::{
17    gelu::*, glu::*, hard_sigmoid::*, leaky_relu::*, prelu::*, relu::*, sigmoid::*, softplus::*,
18    swiglu::*, tanh::*,
19};
20
21mod padding;
22pub use padding::*;
23
24// For backward compat, `burn::nn::Initializer`
25pub use burn_core::module::Initializer;
26
27extern crate alloc;
28
/// Backend for test cases when no specific test-backend feature is enabled
/// (defaults to the ndarray CPU backend).
#[cfg(all(
    test,
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;

/// Backend for test cases (LibTorch).
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Backend for test cases (WGPU).
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Backend for test cases (CUDA).
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

/// Backend for test cases (ROCm).
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;

/// Backend for autodiff test cases, wrapping whichever [`TestBackend`]
/// was selected by the feature flags above.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
58
// Fusion memory checks, compiled only when the `test-memory-checks`
// feature is enabled for the test build.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}