burn_train/
lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A library for training neural networks using the burn crate.
5
6#[macro_use]
7extern crate derive_new;
8
9/// The checkpoint module.
10pub mod checkpoint;
11
12pub(crate) mod components;
13
14/// Renderer modules to display metrics and training information.
15pub mod renderer;
16
17/// The logger module.
18pub mod logger;
19
20/// The metric module.
21pub mod metric;
22
23pub use metric::processor::*;
24
25mod learner;
26
27pub use learner::*;
28
29mod evaluator;
30
31pub use evaluator::*;
32
33pub use components::LearnerComponentTypes;
34
/// Backend alias used by this crate's unit tests: `burn_ndarray::NdArray`
/// instantiated with `f32`. Only compiled under `cfg(test)`.
35#[cfg(test)]
36pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
37
#[cfg(test)]
pub(crate) mod tests {
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};
    use std::default::Default;

    /// Cut-off value matching the `threshold=0.5` at which the expected
    /// predictions noted in [`dummy_classification_input`] were derived.
    pub const THRESHOLD: f64 = 0.5;

    /// Selects which fixture [`dummy_classification_input`] builds.
    #[derive(Debug, Default)]
    pub enum ClassificationType {
        /// Fixture with a single score column per sample (the default variant).
        #[default]
        Binary,
        /// Fixture with three score columns and one-hot target rows.
        Multiclass,
        /// Fixture with three score columns whose target rows may set
        /// several columns at once.
        Multilabel,
    }

    /// Returns a fixed `(scores, targets)` pair for classification metric tests.
    ///
    /// Both tensors are Sample x Class shaped matrices holding five samples.
    /// The first element contains the predicted scores, the second the boolean
    /// targets. The predictions that the scores produce at the relevant
    /// threshold / top-k are recorded in comments next to each fixture.
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        // Every fixture lives on the backend's default device.
        let device = Default::default();

        match classification_type {
            ClassificationType::Binary => {
                let scores = Tensor::<TestBackend, 2>::from_data(
                    [[0.3], [0.2], [0.7], [0.1], [0.55]],
                    &device,
                );
                // predictions @ threshold=0.5
                //   [[0], [0], [1], [0], [1]]
                let targets = Tensor::<TestBackend, 2, Bool>::from_data(
                    [[0], [1], [0], [0], [1]],
                    &device,
                );
                (scores, targets)
            }
            ClassificationType::Multiclass => {
                let scores = Tensor::<TestBackend, 2>::from_data(
                    [
                        [0.2, 0.8, 0.0],
                        [0.3, 0.6, 0.1],
                        [0.7, 0.25, 0.05],
                        [0.1, 0.15, 0.8],
                        [0.9, 0.03, 0.07],
                    ],
                    &device,
                );
                // predictions @ top_k=1
                //   [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
                // predictions @ top_k=2
                //   [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
                let targets = Tensor::<TestBackend, 2, Bool>::from_data(
                    [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                    &device,
                );
                (scores, targets)
            }
            ClassificationType::Multilabel => {
                let scores = Tensor::<TestBackend, 2>::from_data(
                    [
                        [0.1, 0.7, 0.6],
                        [0.3, 0.9, 0.05],
                        [0.8, 0.9, 0.4],
                        [0.7, 0.5, 0.9],
                        [1.0, 0.3, 0.2],
                    ],
                    &device,
                );
                // predictions @ threshold=0.5
                //   [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
                let targets = Tensor::<TestBackend, 2, Bool>::from_data(
                    [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                    &device,
                );
                (scores, targets)
            }
        }
    }
}