#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! Crate root.
//!
//! Declares the public `checkpoint`, `renderer`, `logger` and `metric` modules and
//! re-exports the contents of `metric::processor`, `learner`, `evaluator` and
//! `components` at the top level.
// `#[macro_use]` makes the macros exported by `derive_new` available crate-wide.
#[macro_use]
extern crate derive_new;
/// The checkpoint module.
pub mod checkpoint;
// Crate-private module; its items are exposed through the glob re-export further down.
pub(crate) mod components;
/// The renderer module.
pub mod renderer;
/// The logger module.
pub mod logger;
/// The metric module.
pub mod metric;
// Surface everything from `metric::processor` at the crate root.
pub use metric::processor::*;
// `learner` is a private module; its public items are reachable only via this re-export.
mod learner;
pub use learner::*;
// `evaluator` is a private module; its public items are reachable only via this re-export.
mod evaluator;
pub use evaluator::*;
// Re-export the public items of the crate-private `components` module.
pub use components::*;
// Backend alias compiled only for test builds: `burn_ndarray::NdArray` with `f32` elements.
#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(test)]
pub(crate) mod tests {
    //! Fixtures shared by the crate's unit tests.

    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};
    use std::default::Default;

    /// Threshold value (`0.5`) made available to tests.
    ///
    /// NOTE(review): presumably the decision cut-off applied to the scores returned by
    /// [`dummy_classification_input`]; confirm against the call sites.
    pub const THRESHOLD: f64 = 0.5;

    /// Selects which fixture [`dummy_classification_input`] returns.
    #[derive(Debug, Default)]
    pub enum ClassificationType {
        /// Five rows with a single column (the default variant).
        #[default]
        Binary,
        /// Five rows with three columns; each label row has exactly one non-zero entry.
        Multiclass,
        /// Five rows with three columns; a label row may have several non-zero entries.
        Multilabel,
    }

    /// Builds a fixed pair of rank-2 tensors for the requested classification type.
    ///
    /// The first tensor holds float values and the second is a boolean tensor created
    /// from `0`/`1` integer literals. Both tensors always have five rows; `Binary` uses
    /// one column, while `Multiclass` and `Multilabel` use three. Both are created on
    /// the backend's default device.
    ///
    /// NOTE(review): presumably the pair is (predicted scores, ground-truth labels);
    /// confirm against the tests that consume it.
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        // A single default device is shared by every tensor built below.
        let device = Default::default();

        match classification_type {
            ClassificationType::Binary => (
                Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &device),
                Tensor::from_data([[0], [1], [0], [0], [1]], &device),
            ),
            ClassificationType::Multiclass => (
                Tensor::from_data(
                    [
                        [0.2, 0.8, 0.0],
                        [0.3, 0.6, 0.1],
                        [0.7, 0.25, 0.05],
                        [0.1, 0.15, 0.8],
                        [0.9, 0.03, 0.07],
                    ],
                    &device,
                ),
                Tensor::from_data(
                    [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                    &device,
                ),
            ),
            ClassificationType::Multilabel => (
                Tensor::from_data(
                    [
                        [0.1, 0.7, 0.6],
                        [0.3, 0.9, 0.05],
                        [0.8, 0.9, 0.4],
                        [0.7, 0.5, 0.9],
                        [1.0, 0.3, 0.2],
                    ],
                    &device,
                ),
                Tensor::from_data(
                    [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                    &device,
                ),
            ),
        }
    }
}