// burn_train/lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A library for training neural networks using the burn crate.
5
6#[macro_use]
7extern crate derive_new;
8
9/// The checkpoint module.
10pub mod checkpoint;
11
12pub(crate) mod components;
13
14/// Renderer modules to display metrics and training information.
15pub mod renderer;
16
17/// The logger module.
18pub mod logger;
19
20/// The metric module.
21pub mod metric;
22
23mod learner;
24
25pub use learner::*;
26
27mod evaluator;
28
29pub use evaluator::*;
30
/// Backend shared by the unit tests in this crate: `NdArray` with `f32` elements.
#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;

#[cfg(test)]
pub(crate) mod tests {
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};

    /// Decision threshold that the dummy binary and multilabel inputs are built around.
    ///
    /// It matches the `threshold=0.5` prediction comments in [`dummy_classification_input`].
    pub const THRESHOLD: f64 = 0.5;

    /// Kind of classification problem to generate dummy data for.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
    pub enum ClassificationType {
        /// A single score column and a single 0/1 target per sample.
        #[default]
        Binary,
        /// Several classes with exactly one target class per sample (one-hot targets).
        Multiclass,
        /// Several classes with any number of target classes per sample.
        Multilabel,
    }

    /// Returns a small, hand-checkable `(predictions, targets)` pair for testing
    /// classification metrics.
    ///
    /// Both tensors are `samples x classes` matrices with 5 samples: one column for
    /// [`ClassificationType::Binary`], three columns for the multiclass and multilabel
    /// variants. `predictions` holds float scores in `[0, 1]`; `targets` marks the
    /// ground-truth class(es) of each sample.
    ///
    /// The comments next to each dataset list the predictions obtained after applying the
    /// decision rule (threshold or top-k), so expected metric values can be worked out by
    /// hand.
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        match classification_type {
            ClassificationType::Binary => (
                // predictions
                Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()),
                // targets
                Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),
                // predictions @ threshold=0.5
                //   [[0], [0], [1], [0], [1]]
            ),
            ClassificationType::Multiclass => (
                // predictions
                Tensor::from_data(
                    [
                        [0.2, 0.8, 0.0],
                        [0.3, 0.6, 0.1],
                        [0.7, 0.25, 0.05],
                        [0.1, 0.15, 0.8],
                        [0.9, 0.03, 0.07],
                    ],
                    &Default::default(),
                ),
                // targets
                Tensor::from_data(
                    [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                    &Default::default(),
                ),
                // predictions @ top_k=1
                //   [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
                // predictions @ top_k=2
                //   [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
            ),
            ClassificationType::Multilabel => (
                // predictions
                Tensor::from_data(
                    [
                        [0.1, 0.7, 0.6],
                        [0.3, 0.9, 0.05],
                        [0.8, 0.9, 0.4],
                        [0.7, 0.5, 0.9],
                        [1.0, 0.3, 0.2],
                    ],
                    &Default::default(),
                ),
                // targets
                Tensor::from_data(
                    [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                    &Default::default(),
                ),
                // predictions @ threshold=0.5 (the 0.5 score in row 4 maps to 0, i.e. these
                // expected predictions assume a strict `score > threshold` comparison)
                //   [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
            ),
        }
    }
}