Skip to main content

rbp_autotrain/
fast.rs

1//! Fast in-memory training session
2use crate::*;
3use rbp_nlhe::Flagship;
4use rbp_database::*;
5use rbp_mccfr::*;
6use rbp_nlhe::NlheProfile;
7use std::sync::Arc;
8use tokio_postgres::Client;
9use tokio_postgres::binary_copy::BinaryCopyInWriter;
10
11/// Fast in-memory training using Pluribus.
12pub struct FastSession {
13    client: Arc<Client>,
14    solver: Flagship,
15}
16
17impl FastSession {
18    pub async fn new(client: Arc<Client>) -> Self {
19        PreTraining::run(&client).await;
20        Self {
21            solver: Flagship::hydrate(client.clone()).await,
22            client,
23        }
24    }
25}
26
27#[async_trait::async_trait]
28impl Trainer for FastSession {
29    fn client(&self) -> &Arc<Client> {
30        &self.client
31    }
32    async fn step(&mut self) {
33        self.solver.step();
34    }
35    async fn epoch(&self) -> usize {
36        self.solver.profile().epochs()
37    }
38    async fn checkpoint(&self) -> Option<String> {
39        self.solver.profile().metrics().and_then(|m| m.checkpoint())
40    }
41    async fn summary(&self) -> String {
42        self.solver
43            .profile()
44            .metrics()
45            .map(|m| m.summary())
46            .unwrap_or_else(|| "training stopped".to_string())
47    }
48    async fn sync(self) {
49        let client = self.client;
50        let epochs = self.solver.profile.epochs();
51        let profile = self.solver.profile;
52        client.stage().await;
53        let copy = format!(
54            "COPY {t} (past, present, choices, edge, weight, regret, evalue, counts) FROM STDIN BINARY",
55            t = rbp_database::STAGING
56        );
57        let writer = BinaryCopyInWriter::new(
58            client.copy_in(&copy).await.expect("copy_in"),
59            NlheProfile::columns(),
60        );
61        futures::pin_mut!(writer);
62        for row in profile.rows() {
63            row.write(writer.as_mut()).await;
64        }
65        writer.finish().await.expect("finish stream");
66        client.merge().await;
67        client.stamp(epochs).await;
68        log::info!("profile sync complete (epoch {})", epochs);
69    }
70}