burn_train/lib.rs
1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A library for training neural networks using the burn crate.
5
6#[macro_use]
7extern crate derive_new;
8
9/// The checkpoint module.
10pub mod checkpoint;
11
12pub(crate) mod components;
13
14/// Renderer modules to display metrics and training information.
15pub mod renderer;
16
17/// The logger module.
18pub mod logger;
19
20/// The metric module.
21pub mod metric;
22
23pub use metric::processor::*;
24
25mod learner;
26
27pub use learner::*;
28
29mod evaluator;
30
31pub use evaluator::*;
32
33pub use components::*;
34
/// Backend type used throughout this crate's unit tests.
#[cfg(test)]
pub(crate) type TestBackend = ::burn_flex::Flex;
37
/// Shared helpers and fixtures for this crate's unit tests.
#[cfg(test)]
pub(crate) mod tests {
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};

    /// [`TestBackend`] wrapped in `burn_autodiff::Autodiff`.
    pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;

    /// Probability of a true positive before errors are added.
    ///
    /// Its value (0.5) also matches the `threshold=0.5` at which the expected
    /// predictions noted in [`dummy_classification_input`] were worked out.
    pub const THRESHOLD: f64 = 0.5;

    /// Selects which fixture [`dummy_classification_input`] returns.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
    pub enum ClassificationType {
        /// A single score column per sample (5 x 1 tensors).
        #[default]
        Binary,
        /// Three classes with exactly one positive target per sample (5 x 3 tensors).
        Multiclass,
        /// Three classes where a sample may have several positive targets (5 x 3 tensors).
        Multilabel,
    }

    /// Returns a `(scores, targets)` pair of `samples x classes` tensors for
    /// classification-metric tests.
    ///
    /// Every fixture has 5 samples. `scores` holds float scores and `targets`
    /// holds the boolean ground truth. The comments next to each fixture list
    /// the predictions the scores yield at the stated threshold or `top_k`, so
    /// expected metric values can be checked by hand. In those notes a score
    /// equal to the threshold is listed as negative (e.g. the multilabel score
    /// `0.5` in the fourth sample maps to `0`).
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        // Every tensor below is created on the backend's default device.
        let device = Default::default();

        match classification_type {
            ClassificationType::Binary => (
                // scores
                Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &device),
                // targets
                Tensor::from_data([[0], [1], [0], [0], [1]], &device),
                // predictions @ threshold=0.5
                // [[0], [0], [1], [0], [1]]
            ),
            ClassificationType::Multiclass => (
                // scores
                Tensor::from_data(
                    [
                        [0.2, 0.8, 0.0],
                        [0.3, 0.6, 0.1],
                        [0.7, 0.25, 0.05],
                        [0.1, 0.15, 0.8],
                        [0.9, 0.03, 0.07],
                    ],
                    &device,
                ),
                // targets
                Tensor::from_data(
                    [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                    &device,
                ),
                // predictions @ top_k=1
                // [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
                // predictions @ top_k=2
                // [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
            ),
            ClassificationType::Multilabel => (
                // scores
                Tensor::from_data(
                    [
                        [0.1, 0.7, 0.6],
                        [0.3, 0.9, 0.05],
                        [0.8, 0.9, 0.4],
                        [0.7, 0.5, 0.9],
                        [1.0, 0.3, 0.2],
                    ],
                    &device,
                ),
                // targets
                Tensor::from_data(
                    [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                    &device,
                ),
                // predictions @ threshold=0.5
                // [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
            ),
        }
    }
}