pub mod cnn;
pub mod rnn;
pub mod serialization;
pub mod training;
pub mod transformer_models;
pub mod sequential;
pub mod high_level;
pub mod sequential_basic;
#[cfg(test)]
pub mod sequential_tests;
use crate::autograd::Variable;
use crate::nn::Module;
use num_traits::Float;
use std::collections::HashMap;
/// Operating mode reported by [`Model::mode`] and set via [`Model::train`] /
/// [`Model::eval`].
///
/// What actually differs between the two modes is up to each `Model`
/// implementation; this enum only records which one is active.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so it can be used as a
/// key in `HashMap`/`HashSet` and compared in `assert_eq!` on generic code that
/// requires `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelMode {
    /// Training mode.
    Train,
    /// Evaluation / inference mode.
    Eval,
}
/// A [`Module`] extended with train/eval mode switching and introspection.
///
/// `T` is the scalar element type shared with the underlying `Module<T>`.
pub trait Model<
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>: Module<T>
{
    /// Returns the mode the model currently reports.
    fn mode(&self) -> ModelMode;

    /// Intended to switch the model to [`ModelMode::Train`]; what changes in
    /// that mode is implementation-defined.
    fn train(&mut self);

    /// Intended to switch the model to [`ModelMode::Eval`]; what changes in
    /// that mode is implementation-defined.
    fn eval(&mut self);

    /// Returns the model's configuration as string key/value pairs.
    /// The set of keys is implementation-defined.
    fn config(&self) -> HashMap<String, String>;

    /// Returns a textual summary of the model (format is implementation-defined).
    fn summary(&self) -> String;
}
/// A builder that is consumed to produce a concrete [`Model`].
///
/// `T` is the scalar element type of the produced model.
pub trait ModelBuilder<
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>
{
    /// The concrete model type returned by [`ModelBuilder::build`].
    type Model: Model<T>;

    /// Consumes the builder and returns the model it describes.
    fn build(self) -> Self::Model;
}
pub use cnn::{CNNBuilder, ResNet, ResNetBuilder, CNN};
pub use rnn::{LSTMModel, LSTMModelBuilder, RNNModel, RNNModelBuilder};
pub use serialization::{ModelLoader, ModelSaver, SerializationFormat};
pub use training::{Trainer, TrainingConfig, TrainingResult};
pub use transformer_models::{BERTBuilder, TransformerModel, TransformerModelBuilder, BERT};
pub use high_level::{FitConfig, HighLevelModel, TrainingHistory};
pub use sequential::{Sequential, SequentialBuilder};
pub use sequential_basic::{BasicSequential, BasicSequentialBuilder};
/// Helper for running inference with a [`Model`].
///
/// The engine stores no data of its own. `T` (the model's scalar element type)
/// is tracked only at the type level via `PhantomData`.
///
/// Trait bounds are deliberately kept off the struct definition and placed on
/// the `impl` blocks that need them, so merely naming `InferenceEngine<T>` in a
/// signature does not force callers to repeat the full bound list.
#[derive(Debug)]
pub struct InferenceEngine<T> {
    _phantom: std::marker::PhantomData<T>,
}
impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
    InferenceEngine<T>
{
    /// Creates a new inference engine. The engine carries no state.
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }

    /// Runs `model` on `input` and returns the output.
    ///
    /// This delegates directly to the model's `forward`. It does NOT switch the
    /// model to [`ModelMode::Eval`]: `predict` only borrows the model immutably,
    /// while `Model::eval` needs `&mut self`. If a model behaves differently
    /// while training, callers should call `eval()` on it beforehand.
    pub fn predict<M: Model<T>>(&self, model: &M, input: &Variable<T>) -> Variable<T> {
        model.forward(input)
    }
}

/// Produces the same stateless engine as [`InferenceEngine::new`].
///
/// Provided so the engine can be used with `#[derive(Default)]` containers and
/// `..Default::default()` (Clippy `new_without_default`). Only the bounds
/// needed to name the type are required here.
impl<T: Float + Send + Sync + 'static> Default for InferenceEngine<T> {
    fn default() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
/// A bundle of scalar evaluation metrics for a model.
///
/// Values are stored exactly as given. No range or consistency checks are
/// performed (e.g. `f1_score` is not recomputed from `precision`/`recall`).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Metrics {
    /// Accuracy value (not validated).
    pub accuracy: f64,
    /// Precision value (not validated).
    pub precision: f64,
    /// Recall value (not validated).
    pub recall: f64,
    /// F1 score (stored as given, not derived from precision/recall).
    pub f1_score: f64,
    /// Loss value (not validated).
    pub loss: f64,
}

impl Metrics {
    /// Creates a `Metrics` with every field set to `0.0`.
    ///
    /// Identical to [`Metrics::default`], which is derived so `Metrics` works
    /// with `..Default::default()` and generic `T: Default` code.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a `Metrics` from explicit values, stored without modification.
    pub fn with_values(
        accuracy: f64,
        precision: f64,
        recall: f64,
        f1_score: f64,
        loss: f64,
    ) -> Self {
        Self {
            accuracy,
            precision,
            recall,
            f1_score,
            loss,
        }
    }
}