// burn_train/lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_auto_cfg))]
3
4//! A library for training neural networks using the burn crate.
5
6#[macro_use]
7extern crate derive_new;
8
9/// The checkpoint module.
10pub mod checkpoint;
11
12pub(crate) mod components;
13
14/// Renderer modules to display metrics and training information.
15pub mod renderer;
16
17/// The logger module.
18pub mod logger;
19
20/// The metric module.
21pub mod metric;
22
23mod learner;
24
25pub use learner::*;
26
27#[cfg(test)]
28pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
29
#[cfg(test)]
pub(crate) mod tests {
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};
    use std::default::Default;

    /// Probability of tp before adding errors.
    ///
    /// Has the same value as the `threshold=0.5` cut-off quoted in the fixture
    /// comments of [`dummy_classification_input`].
    pub const THRESHOLD: f64 = 0.5;

    /// Selects which fixture [`dummy_classification_input`] builds.
    #[derive(Debug, Default)]
    pub enum ClassificationType {
        /// One score column per sample (the default variant).
        #[default]
        Binary,
        /// Three score columns per sample; each target row has exactly one class set.
        Multiclass,
        /// Three score columns per sample; a target row may have several classes set.
        Multilabel,
    }

    /// Sample x Class shaped matrix for use in
    /// classification metrics testing
    ///
    /// Returns a `(scores, targets)` pair built from hard-coded literals on the
    /// default device:
    ///
    /// * `scores` - float tensor with one row per sample (5 rows); 1 column for
    ///   [`ClassificationType::Binary`], 3 columns for the other variants.
    /// * `targets` - boolean tensor of the same shape, created from `0`/`1` literals.
    ///
    /// The `predictions @ ...` comments next to each fixture record what the scores
    /// become after thresholding / top-k selection; they are documentation only and
    /// are not returned.
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        match classification_type {
            ClassificationType::Binary => {
                (
                    // scores
                    Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()),
                    // targets
                    Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),
                    // predictions @ threshold=0.5
                    //                     [[0], [0], [1], [0], [1]]
                )
            }
            ClassificationType::Multiclass => {
                (
                    // scores
                    Tensor::from_data(
                        [
                            [0.2, 0.8, 0.0],
                            [0.3, 0.6, 0.1],
                            [0.7, 0.25, 0.05],
                            [0.1, 0.15, 0.8],
                            [0.9, 0.03, 0.07],
                        ],
                        &Default::default(),
                    ),
                    Tensor::from_data(
                        // targets
                        [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                        // predictions @ top_k=1
                        //   [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0,  0]]
                        // predictions @ top_k=2
                        //   [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0,  1]]
                        &Default::default(),
                    ),
                )
            }
            ClassificationType::Multilabel => {
                (
                    // scores
                    Tensor::from_data(
                        [
                            [0.1, 0.7, 0.6],
                            [0.3, 0.9, 0.05],
                            [0.8, 0.9, 0.4],
                            [0.7, 0.5, 0.9],
                            [1.0, 0.3, 0.2],
                        ],
                        &Default::default(),
                    ),
                    // targets
                    Tensor::from_data(
                        [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                        // predictions @ threshold=0.5
                        //   [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
                        // Note: the score of exactly 0.5 (4th row, 2nd column) is listed as 0,
                        // i.e. the expected predictions assume a strict `score > threshold`.
                        &Default::default(),
                    ),
                )
            }
        }
    }
}