burn_train/lib.rs
1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A library for training neural networks using the burn crate.
5
6#[macro_use]
7extern crate derive_new;
8
9/// The checkpoint module.
10pub mod checkpoint;
11
12pub(crate) mod components;
13
14/// Renderer modules to display metrics and training information.
15pub mod renderer;
16
17/// The logger module.
18pub mod logger;
19
20/// The metric module.
21pub mod metric;
22
23mod learner;
24
25pub use learner::*;
26
27mod evaluator;
28
29pub use evaluator::*;
30
/// Backend that this crate's unit tests run on: `burn_ndarray::NdArray`
/// parameterised with `f32`. Only compiled for test builds.
#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
33
#[cfg(test)]
pub(crate) mod tests {
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};
    use std::default::Default;

    /// Decision threshold for the classification fixtures. It equals the
    /// `threshold=0.5` at which the expected predictions noted inside
    /// [`dummy_classification_input`] were worked out.
    pub const THRESHOLD: f64 = 0.5;

    /// Selects which fixture [`dummy_classification_input`] builds.
    #[derive(Debug, Default)]
    pub enum ClassificationType {
        /// One score column per sample (the default variant).
        #[default]
        Binary,
        /// Three score columns per sample; every target row marks exactly one class.
        Multiclass,
        /// Three score columns per sample; a target row may mark several classes.
        Multilabel,
    }

    /// Builds a fixed `(scores, targets)` pair for testing classification metrics.
    ///
    /// Both tensors are laid out as samples x classes and hold five samples:
    /// a single column for [`ClassificationType::Binary`], three columns for
    /// the other two variants. The first tensor carries the predicted scores,
    /// the second the targets as a `Bool` tensor.
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        // All fixtures are created on the backend's default device.
        let device = Default::default();

        match classification_type {
            ClassificationType::Binary => {
                let scores = [[0.3], [0.2], [0.7], [0.1], [0.55]];
                let targets = [[0], [1], [0], [0], [1]];
                // Expected predictions @ threshold=0.5:
                //   [[0], [0], [1], [0], [1]]
                (
                    Tensor::from_data(scores, &device),
                    Tensor::from_data(targets, &device),
                )
            }
            ClassificationType::Multiclass => {
                let scores = [
                    [0.2, 0.8, 0.0],
                    [0.3, 0.6, 0.1],
                    [0.7, 0.25, 0.05],
                    [0.1, 0.15, 0.8],
                    [0.9, 0.03, 0.07],
                ];
                let targets = [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]];
                // Expected predictions @ top_k=1:
                //   [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
                // Expected predictions @ top_k=2:
                //   [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
                (
                    Tensor::from_data(scores, &device),
                    Tensor::from_data(targets, &device),
                )
            }
            ClassificationType::Multilabel => {
                let scores = [
                    [0.1, 0.7, 0.6],
                    [0.3, 0.9, 0.05],
                    [0.8, 0.9, 0.4],
                    [0.7, 0.5, 0.9],
                    [1.0, 0.3, 0.2],
                ];
                let targets = [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]];
                // Expected predictions @ threshold=0.5:
                //   [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
                (
                    Tensor::from_data(scores, &device),
                    Tensor::from_data(targets, &device),
                )
            }
        }
    }
}