1use crate::*;
3use rbp_nlhe::Flagship;
4use rbp_database::*;
5use rbp_mccfr::*;
6use rbp_nlhe::NlheProfile;
7use std::sync::Arc;
8use tokio_postgres::Client;
9use tokio_postgres::binary_copy::BinaryCopyInWriter;
10
/// Training session that drives a [`Flagship`] solver and writes its profile
/// back to Postgres through a shared client.
pub struct FastSession {
    /// Shared database connection; handed out by `Trainer::client` and used
    /// by `Trainer::sync` to persist the profile.
    client: Arc<Client>,
    /// In-memory solver advanced by `Trainer::step`; its `profile` is consumed
    /// and written to the database by `Trainer::sync`.
    solver: Flagship,
}
16
17impl FastSession {
18 pub async fn new(client: Arc<Client>) -> Self {
19 PreTraining::run(&client).await;
20 Self {
21 solver: Flagship::hydrate(client.clone()).await,
22 client,
23 }
24 }
25}
26
27#[async_trait::async_trait]
28impl Trainer for FastSession {
29 fn client(&self) -> &Arc<Client> {
30 &self.client
31 }
32 async fn step(&mut self) {
33 self.solver.step();
34 }
35 async fn epoch(&self) -> usize {
36 self.solver.profile().epochs()
37 }
38 async fn checkpoint(&self) -> Option<String> {
39 self.solver.profile().metrics().and_then(|m| m.checkpoint())
40 }
41 async fn summary(&self) -> String {
42 self.solver
43 .profile()
44 .metrics()
45 .map(|m| m.summary())
46 .unwrap_or_else(|| "training stopped".to_string())
47 }
48 async fn sync(self) {
49 let client = self.client;
50 let epochs = self.solver.profile.epochs();
51 let profile = self.solver.profile;
52 client.stage().await;
53 let copy = format!(
54 "COPY {t} (past, present, choices, edge, weight, regret, evalue, counts) FROM STDIN BINARY",
55 t = rbp_database::STAGING
56 );
57 let writer = BinaryCopyInWriter::new(
58 client.copy_in(©).await.expect("copy_in"),
59 NlheProfile::columns(),
60 );
61 futures::pin_mut!(writer);
62 for row in profile.rows() {
63 row.write(writer.as_mut()).await;
64 }
65 writer.finish().await.expect("finish stream");
66 client.merge().await;
67 client.stamp(epochs).await;
68 log::info!("profile sync complete (epoch {})", epochs);
69 }
70}