burn-train 0.20.1

Training crate for the Burn framework
Documentation
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};

use super::{EventProcessorTraining, LearnerEvent};
use async_channel::{Receiver, Sender};

/// Training event processor that hands every event to a background worker thread.
///
/// Only the sending half of the event channel lives here; the wrapped processor is
/// owned by the worker.
pub struct AsyncProcessorTraining<P>
where
    P: EventProcessorTraining,
{
    sender: Sender<Message<P>>,
}

/// Evaluation event processor that hands every event to a background worker thread.
///
/// Only the sending half of the event channel lives here; the wrapped processor is
/// owned by the worker.
pub struct AsyncProcessorEvaluation<P>
where
    P: EventProcessorEvaluation,
{
    sender: Sender<EvalMessage<P>>,
}

/// State moved onto the training worker thread.
struct WorkerTraining<P>
where
    P: EventProcessorTraining,
{
    // The real processor that handles the forwarded events.
    processor: P,
    // Receiving half of the channel fed by `AsyncProcessorTraining`.
    rec: Receiver<Message<P>>,
}

/// State moved onto the evaluation worker thread.
struct WorkerEvaluation<P>
where
    P: EventProcessorEvaluation,
{
    // The real processor that handles the forwarded events.
    processor: P,
    // Receiving half of the channel fed by `AsyncProcessorEvaluation`.
    rec: Receiver<EvalMessage<P>>,
}

impl<P> WorkerTraining<P>
where
    P: EventProcessorTraining + 'static,
{
    /// Moves `processor` and `rec` onto a freshly spawned, detached thread that runs the
    /// message loop (see `run`).
    pub fn start(processor: P, rec: Receiver<Message<P>>) {
        let worker = Self { processor, rec };
        std::thread::spawn(move || worker.run());
    }

    /// Message loop executed on the worker thread.
    ///
    /// - Training and validation events are routed to the matching `process_*` method.
    /// - A [`Message::Renderer`] request consumes the processor, sends its renderer back
    ///   through the enclosed sender, and ends the loop.
    /// - The loop also ends when `recv_blocking` returns an error (the channel was closed).
    fn run(self) {
        let Self { mut processor, rec } = self;

        while let Ok(message) = rec.recv_blocking() {
            match message {
                Message::Train(event) => processor.process_train(event),
                Message::Valid(event) => processor.process_valid(event),
                Message::Renderer(reply_to) => {
                    let renderer = processor.renderer();
                    reply_to.send_blocking(renderer).unwrap();
                    break;
                }
            }
        }
    }
}
impl<P> WorkerEvaluation<P>
where
    P: EventProcessorEvaluation + 'static,
{
    /// Moves `processor` and `rec` onto a freshly spawned, detached thread that runs the
    /// message loop (see `run`).
    pub fn start(processor: P, rec: Receiver<EvalMessage<P>>) {
        let worker = Self { processor, rec };
        std::thread::spawn(move || worker.run());
    }

    /// Message loop executed on the worker thread.
    ///
    /// - Test events are routed to `process_test`.
    /// - An [`EvalMessage::Renderer`] request consumes the processor, sends its renderer
    ///   back through the enclosed sender, and ends the loop.
    /// - The loop also ends when `recv_blocking` returns an error (the channel was closed).
    fn run(self) {
        let Self { mut processor, rec } = self;

        while let Ok(message) = rec.recv_blocking() {
            match message {
                EvalMessage::Test(event) => processor.process_test(event),
                EvalMessage::Renderer(reply_to) => {
                    let renderer = processor.renderer();
                    reply_to.send_blocking(renderer).unwrap();
                    break;
                }
            }
        }
    }
}

impl<P> AsyncProcessorTraining<P>
where
    P: EventProcessorTraining + 'static,
{
    /// Wraps `processor` so that training and validation events are handled on a
    /// background worker thread.
    ///
    /// The event channel has a capacity of one, so callers block while a previous event
    /// is still waiting to be picked up by the worker.
    pub fn new(processor: P) -> Self {
        let (sender, receiver) = async_channel::bounded(1);
        WorkerTraining::start(processor, receiver);
        Self { sender }
    }
}

impl<P> AsyncProcessorEvaluation<P>
where
    P: EventProcessorEvaluation + 'static,
{
    /// Wraps `processor` so that evaluation events are handled on a background worker
    /// thread.
    ///
    /// The event channel has a capacity of one, so callers block while a previous event
    /// is still waiting to be picked up by the worker.
    pub fn new(processor: P) -> Self {
        let (sender, receiver) = async_channel::bounded(1);
        WorkerEvaluation::start(processor, receiver);
        Self { sender }
    }
}

/// Messages sent from [`AsyncProcessorTraining`] to its worker thread.
enum Message<P>
where
    P: EventProcessorTraining,
{
    /// A training event to forward to `process_train`.
    Train(LearnerEvent<P::ItemTrain>),
    /// A validation event to forward to `process_valid`.
    Valid(LearnerEvent<P::ItemValid>),
    /// Request for the processor's renderer; the worker replies on this sender and stops.
    Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
}

/// Messages sent from [`AsyncProcessorEvaluation`] to its worker thread.
enum EvalMessage<P>
where
    P: EventProcessorEvaluation,
{
    /// A test event to forward to `process_test`.
    Test(EvaluatorEvent<P::ItemTest>),
    /// Request for the processor's renderer; the worker replies on this sender and stops.
    Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
}

impl<P: EventProcessorTraining> EventProcessorTraining for AsyncProcessorTraining<P> {
    type ItemTrain = P::ItemTrain;
    type ItemValid = P::ItemValid;

    /// Hands a training event to the worker thread, blocking while the channel is full.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread has stopped. While this handle is alive, that most
    /// likely means the wrapped processor panicked; its own panic message is the root cause.
    fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
        self.sender
            .send_blocking(Message::Train(event))
            .expect("training event worker thread stopped; the wrapped processor likely panicked");
    }

    /// Hands a validation event to the worker thread, blocking while the channel is full.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread has stopped (see [`Self::process_train`]).
    fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
        self.sender
            .send_blocking(Message::Valid(event))
            .expect("training event worker thread stopped; the wrapped processor likely panicked");
    }

    /// Asks the worker to turn the wrapped processor into its renderer and waits for it.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread stopped before it could send the renderer back.
    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
        // One-shot reply channel; the worker answers once and then leaves its loop.
        let (sender, rec) = async_channel::bounded(1);
        self.sender
            .send_blocking(Message::Renderer(sender))
            .expect("training event worker thread stopped; the wrapped processor likely panicked");

        rec.recv_blocking()
            .expect("training event worker thread stopped before returning the renderer")
    }
}

impl<P: EventProcessorEvaluation> EventProcessorEvaluation for AsyncProcessorEvaluation<P> {
    type ItemTest = P::ItemTest;

    /// Hands a test event to the worker thread, blocking while the channel is full.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread has stopped. While this handle is alive, that most
    /// likely means the wrapped processor panicked; its own panic message is the root cause.
    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
        self.sender
            .send_blocking(EvalMessage::Test(event))
            .expect("evaluation event worker thread stopped; the wrapped processor likely panicked");
    }

    /// Asks the worker to turn the wrapped processor into its renderer and waits for it.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread stopped before it could send the renderer back.
    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
        // One-shot reply channel; the worker answers once and then leaves its loop.
        let (sender, rec) = async_channel::bounded(1);
        self.sender
            .send_blocking(EvalMessage::Renderer(sender))
            .expect("evaluation event worker thread stopped; the wrapped processor likely panicked");

        rec.recv_blocking()
            .expect("evaluation event worker thread stopped before returning the renderer")
    }
}