//! Profiler for classification models: measures optimizer training time and
//! reports classification metrics via the `Profiler` trait.

use crate::bench::classification_metrics::ClassificationMetrics;
use crate::bench::core::error::ProfilerError;
use crate::bench::core::profiler::Profiler;
use crate::bench::core::train_metrics::TrainMetrics;
use crate::model::core::base::BaseModel;
use crate::model::core::classification_model::ClassificationModel;
use crate::optim::core::optimizer::Optimizer;
use std::marker::PhantomData;
use std::time::Instant;

/// Profiles classification models: `train` times an optimizer fit and pairs the
/// wall-clock duration with the model's classification metrics, while
/// `profile_evaluation` computes the metrics alone.
///
/// The struct holds no data; `PhantomData` only pins the type parameters.
pub struct ClassificationProfiler<Model, Opt, Input, Output> {
    _phantom: PhantomData<(Model, Opt, Input, Output)>,
}

impl<Model, Opt, Input, Output> ClassificationProfiler<Model, Opt, Input, Output> {
    /// Creates a new, zero-sized profiler.
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<Model, Opt, Input, Output> Default for ClassificationProfiler<Model, Opt, Input, Output> {
    fn default() -> Self {
        Self::new()
    }
}
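
// A minimal compile-and-run check, assuming nothing beyond this file: `new` and
// `Default` carry no trait bounds, so the profiler is constructible with any
// type parameters; the trait bounds apply only to the `Profiler` impl below.
#[cfg(test)]
mod construction_tests {
    use super::*;

    #[test]
    fn constructs_with_arbitrary_type_parameters() {
        let _from_new: ClassificationProfiler<(), (), (), ()> = ClassificationProfiler::new();
        let _from_default: ClassificationProfiler<(), (), (), ()> = Default::default();
    }
}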

impl<Model, Opt, Input, Output> Profiler<Model, Opt, Input, Output>
    for ClassificationProfiler<Model, Opt, Input, Output>
where
    Model: BaseModel<Input, Output> + ClassificationModel<Input, Output>,
    Opt: Optimizer<Input, Output, Model>,
{
    type EvalMetrics = ClassificationMetrics;

    /// Fits `model` on `(x, y)` with `optimizer`, timing the fit, then computes
    /// classification metrics on the same data.
    fn train(
        &self,
        model: &mut Model,
        optimizer: &mut Opt,
        x: &Input,
        y: &Output,
    ) -> Result<(TrainMetrics, Self::EvalMetrics), ProfilerError> {
        // Time the full optimizer fit with a monotonic clock.
        let tick = Instant::now();
        optimizer.fit(model, x, y)?;
        let elapsed = tick.elapsed().as_secs_f64();
        let train_metrics = TrainMetrics::new(elapsed);
        // Metrics are computed on the training data itself, so they measure
        // fit quality rather than generalization.
        let eval_metrics = model.compute_metrics(x, y)?;
        Ok((train_metrics, eval_metrics))
    }

    /// Computes classification metrics for `model` on `(x, y)` without training.
    fn profile_evaluation(
        &self,
        model: &mut Model,
        x: &Input,
        y: &Output,
    ) -> Result<Self::EvalMetrics, ProfilerError> {
        let eval_metrics = model.compute_metrics(x, y)?;
        Ok(eval_metrics)
    }
}
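
// Illustrative usage sketch, not a definitive API: `MyClassifier`, `MyOptimizer`,
// `Features`, and `Labels` are hypothetical stand-ins for concrete
// `ClassificationModel`/`Optimizer` implementations and their data types.
//
//     let profiler: ClassificationProfiler<MyClassifier, MyOptimizer, Features, Labels> =
//         ClassificationProfiler::new();
//     let (train_metrics, eval_metrics) =
//         profiler.train(&mut model, &mut optimizer, &features, &labels)?;
//     let metrics_only = profiler.profile_evaluation(&mut model, &features, &labels)?;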