Skip to main content

rbp_database/
stage.rs

1use super::*;
2use std::sync::Arc;
3use tokio_postgres::Client;
4
5/// Stage defines bulk upload operations for fast training.
6/// Manages staging table lifecycle and batch epoch updates.
7#[async_trait::async_trait]
8pub trait Stage: Send + Sync {
9    async fn stage(&self);
10    async fn merge(&self);
11    async fn stamp(&self, n: usize);
12}
13
14#[async_trait::async_trait]
15impl Stage for Client {
16    async fn stage(&self) {
17        let sql = format!(
18            "DROP   TABLE IF EXISTS {t2};
19             CREATE UNLOGGED TABLE  {t2} (LIKE {t1});",
20            t1 = BLUEPRINT,
21            t2 = STAGING
22        );
23        self.batch_execute(&sql).await.expect("create staging");
24    }
25    async fn merge(&self) {
26        let sql = format!(
27            "INSERT INTO   {t1} (past, present, choices, edge, weight, regret, evalue, counts)
28             SELECT              past, present, choices, edge, weight, regret, evalue, counts FROM {t2}
29             ON CONFLICT  (past, present, choices, edge)
30             DO UPDATE SET
31                 weight = EXCLUDED.weight,
32                 regret = EXCLUDED.regret,
33                 evalue = EXCLUDED.evalue,
34                 counts = EXCLUDED.counts;
35             DROP TABLE    {t2};",
36            t1 = BLUEPRINT,
37            t2 = STAGING
38        );
39        self.batch_execute(&sql).await.expect("upsert blueprint");
40    }
41    async fn stamp(&self, n: usize) {
42        let sql = format!(
43            "UPDATE {t} SET value = value + $1 WHERE key = 'current'",
44            t = EPOCH
45        );
46        self.execute(&sql, &[&(n as i64)])
47            .await
48            .expect("update epoch");
49    }
50}
51
52#[async_trait::async_trait]
53impl Stage for Arc<Client> {
54    async fn stage(&self) {
55        self.as_ref().stage().await
56    }
57    async fn merge(&self) {
58        self.as_ref().merge().await
59    }
60    async fn stamp(&self, n: usize) {
61        self.as_ref().stamp(n).await
62    }
63}