use crate::{ItemLazy, renderer::MetricsRenderer};
use burn_core::module::AutodiffModule;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
/// The result of a single training step: the gradients computed during the
/// step together with the step's user-defined output item.
pub struct TrainOutput<TO> {
/// Gradients produced by the step, already converted to [`GradientsParams`].
pub grads: GradientsParams,
/// The user-defined output of the step (e.g. loss/metrics payload — TODO confirm with callers).
pub item: TO,
}
impl<TO> TrainOutput<TO> {
    /// Builds a [`TrainOutput`] from raw backend gradients.
    ///
    /// Converts the backend-specific gradients into [`GradientsParams`] for
    /// the given `module`, then bundles them with the step's output `item`.
    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
        module: &M,
        grads: B::Gradients,
        item: TO,
    ) -> Self {
        Self {
            grads: GradientsParams::from_grads(grads, module),
            item,
        }
    }
}
/// A single training step: run the model on one input item and produce
/// gradients plus a user-defined output.
pub trait TrainStep {
/// The input type consumed by one step (one batch/item).
type Input: Send + 'static;
/// The output type produced by one step; must support lazy item rendering.
type Output: ItemLazy + 'static;
/// Runs the step on `item`, returning the gradients and the step output.
fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;
/// Applies one optimizer update with the given learning rate and gradients,
/// consuming the current model and returning the updated one.
///
/// Only available when the implementor is itself an [`AutodiffModule`].
fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
where
B: AutodiffBackend,
O: Optimizer<Self, B>,
Self: AutodiffModule<B>,
{
// Delegates directly to the optimizer; `self` is moved in and the
// updated module is moved back out.
optim.step(lr, self, grads)
}
/// Same as [`TrainStep::optimize`] but with multiple gradient sets
/// (presumably for gradient accumulation/multi-device — TODO confirm
/// against `Optimizer::step_multi` docs).
fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self
where
B: AutodiffBackend,
O: Optimizer<Self, B>,
Self: AutodiffModule<B>,
{
optim.step_multi(lr, self, grads)
}
}
/// A single inference (validation/test) step: run the model on one input
/// item without computing gradients.
pub trait InferenceStep {
/// The input type consumed by one step (one batch/item).
type Input: Send + 'static;
/// The output type produced by one step; must support lazy item rendering.
type Output: ItemLazy + 'static;
/// Runs the step on `item` and returns the output.
fn step(&self, item: Self::Input) -> Self::Output;
}
/// The artifacts returned at the end of a learning run.
pub struct LearningResult<M> {
/// The final trained model.
pub model: M,
/// The metrics renderer used during the run, returned so the caller can
/// keep rendering or finalize output.
pub renderer: Box<dyn MetricsRenderer>,
}