Skip to main content

burn_train/
lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A library for training neural networks using the burn crate.
5
6#[macro_use]
7extern crate derive_new;
8
9/// The checkpoint module.
10pub mod checkpoint;
11
12pub(crate) mod components;
13
14/// Renderer modules to display metrics and training information.
15pub mod renderer;
16
17/// The logger module.
18pub mod logger;
19
20/// The metric module.
21pub mod metric;
22
23pub use metric::processor::*;
24
25mod learner;
26
27pub use learner::*;
28
29mod evaluator;
30
31pub use evaluator::*;
32
33pub use components::*;
34
#[cfg(test)]
pub(crate) type TestBackend = burn_flex::Flex;

#[cfg(test)]
pub(crate) mod tests {
    // `Default` comes from the std prelude; no explicit import is needed.
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};

    /// [`TestBackend`] wrapped with `burn_autodiff::Autodiff`.
    pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;

    /// Probability threshold shared by the classification metric tests.
    ///
    /// The binary and multilabel prediction annotations in [`dummy_classification_input`]
    /// (`threshold=0.5`) are written against this value.
    // NOTE(review): previously documented as "Probability of tp before adding errors"; confirm
    // whether any caller relies on that meaning before changing the value.
    pub const THRESHOLD: f64 = 0.5;

    /// Kind of classification data that [`dummy_classification_input`] should produce.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
    pub enum ClassificationType {
        /// A single score column per sample.
        #[default]
        Binary,
        /// Three class columns; every target row has exactly one positive class.
        Multiclass,
        /// Three class columns; a target row may have any number of positive classes.
        Multilabel,
    }

    /// Builds a small, fixed `(scores, targets)` pair for classification metric tests.
    ///
    /// Both tensors are `samples x classes` matrices with 5 rows. The first holds the raw
    /// scores, and the second holds the boolean ground-truth labels. The column count depends
    /// on `classification_type`:
    ///
    /// * [`ClassificationType::Binary`] has 1 column.
    /// * [`ClassificationType::Multiclass`] has 3 columns and one positive label per row.
    /// * [`ClassificationType::Multilabel`] has 3 columns and any number of positive labels
    ///   per row.
    ///
    /// Each arm is annotated with the predictions its scores yield. For binary and multilabel
    /// data this is at `threshold=0.5`; for multiclass data it is at `top_k` 1 and 2. This lets
    /// expected metric values be derived by hand.
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        match classification_type {
            ClassificationType::Binary => (
                // scores
                Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()),
                // targets
                Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),
                // predictions @ threshold=0.5
                //   [[0], [0], [1], [0], [1]]
            ),
            ClassificationType::Multiclass => (
                // scores
                Tensor::from_data(
                    [
                        [0.2, 0.8, 0.0],
                        [0.3, 0.6, 0.1],
                        [0.7, 0.25, 0.05],
                        [0.1, 0.15, 0.8],
                        [0.9, 0.03, 0.07],
                    ],
                    &Default::default(),
                ),
                // targets
                Tensor::from_data(
                    [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                    &Default::default(),
                ),
                // predictions @ top_k=1
                //   [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
                // predictions @ top_k=2
                //   [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
            ),
            ClassificationType::Multilabel => (
                // scores
                Tensor::from_data(
                    [
                        [0.1, 0.7, 0.6],
                        [0.3, 0.9, 0.05],
                        [0.8, 0.9, 0.4],
                        [0.7, 0.5, 0.9],
                        [1.0, 0.3, 0.2],
                    ],
                    &Default::default(),
                ),
                // targets
                Tensor::from_data(
                    [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                    &Default::default(),
                ),
                // predictions @ threshold=0.5
                //   [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
                // Row 4 has a score of exactly 0.5 in its middle column, and the annotation
                // counts it as negative, i.e. it assumes a strict `score > threshold` comparison.
            ),
        }
    }
}