1use std::sync::Arc;
3use tokio_postgres::Client;
4
5#[async_trait::async_trait]
8pub trait Trainer: Send + Sync + Sized {
9 fn client(&self) -> &Arc<Client>;
11 async fn sync(self);
13 async fn step(&mut self);
15 async fn epoch(&self) -> usize;
17 async fn summary(&self) -> String;
19 async fn checkpoint(&self) -> Option<String>;
21
22 async fn train(mut self) {
23 log::info!("training blueprint");
24 log::info!("press 'Q + ↵' to stop gracefully");
25 loop {
26 self.step().await;
27 self.checkpoint().await.map(|s| log::info!("{}", s));
28 if rbp_core::interrupted() {
29 log::info!("{}", self.summary().await);
30 break;
31 }
32 }
33 self.sync().await;
34 }
35}