burn_train/lib.rs
1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A library for training neural networks using the burn crate.
5
// `#[macro_use]` makes the macros exported by `derive_new` usable throughout the crate.
6#[macro_use]
7extern crate derive_new;
8
9/// The checkpoint module.
10pub mod checkpoint;
11
// Crate-private module; `LearnerComponentTypes` is the one item this file re-exports from it.
12pub(crate) mod components;
13
14/// Renderer modules to display metrics and training information.
15pub mod renderer;
16
17/// The logger module.
18pub mod logger;
19
20/// The metric module.
21pub mod metric;
22
// Surface every public item of `metric::processor` directly at the crate root.
23pub use metric::processor::*;
24
// Private module: its public items are reachable only through the glob re-export that follows.
25mod learner;
26
27pub use learner::*;
28
// Private module: its public items are reachable only through the glob re-export that follows.
29mod evaluator;
30
31pub use evaluator::*;
32
// Single-item re-export out of the crate-private `components` module.
33pub use components::LearnerComponentTypes;
34
/// Backend alias compiled only under `cfg(test)`: `burn_ndarray::NdArray` parameterized
/// with `f32`. The test helpers in the `tests` module of this file build their tensors on it.
35#[cfg(test)]
36pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
37
#[cfg(test)]
pub(crate) mod tests {
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};
    use std::default::Default;

    /// Cut-off of `0.5`; the thresholded predictions noted inside
    /// [`dummy_classification_input`] are written for this same value.
    pub const THRESHOLD: f64 = 0.5;

    /// Selects which dummy fixture [`dummy_classification_input`] builds.
    #[derive(Debug, Default)]
    pub enum ClassificationType {
        /// Fixture with a single score column per sample (the default variant).
        #[default]
        Binary,
        /// Fixture with three score columns and exactly one true target per sample.
        Multiclass,
        /// Fixture with three score columns and any number of true targets per sample.
        Multilabel,
    }

    /// Builds a fixed `(scores, targets)` pair for classification metric tests.
    ///
    /// Both tensors are laid out as sample x class. Every variant yields five
    /// samples: `Binary` has one class column, `Multiclass` and `Multilabel`
    /// have three. The scores are float literals; the targets are written as
    /// `0`/`1` integers and loaded into a `Bool` tensor.
    ///
    /// The comments next to each fixture record the predictions that the scores
    /// correspond to under the stated threshold / top-k setting, so expected
    /// metric values can be derived by hand.
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        type Scores = Tensor<TestBackend, 2>;
        type Targets = Tensor<TestBackend, 2, Bool>;

        // One default device value shared by every tensor built below.
        let device = Default::default();

        match classification_type {
            ClassificationType::Binary => {
                let scores = Scores::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &device);
                // predictions @ threshold=0.5
                //   [[0], [0], [1], [0], [1]]
                let targets = Targets::from_data([[0], [1], [0], [0], [1]], &device);
                (scores, targets)
            }
            ClassificationType::Multiclass => {
                let scores = Scores::from_data(
                    [
                        [0.2, 0.8, 0.0],
                        [0.3, 0.6, 0.1],
                        [0.7, 0.25, 0.05],
                        [0.1, 0.15, 0.8],
                        [0.9, 0.03, 0.07],
                    ],
                    &device,
                );
                // predictions @ top_k=1
                //   [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
                // predictions @ top_k=2
                //   [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
                let targets = Targets::from_data(
                    [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                    &device,
                );
                (scores, targets)
            }
            ClassificationType::Multilabel => {
                let scores = Scores::from_data(
                    [
                        [0.1, 0.7, 0.6],
                        [0.3, 0.9, 0.05],
                        [0.8, 0.9, 0.4],
                        [0.7, 0.5, 0.9],
                        [1.0, 0.3, 0.2],
                    ],
                    &device,
                );
                // predictions @ threshold=0.5
                //   [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
                // (the score of exactly 0.5 is listed as 0, i.e. a strict `>` comparison)
                let targets = Targets::from_data(
                    [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                    &device,
                );
                (scores, targets)
            }
        }
    }
}