Skip to main content

wesichain_checkpoint_postgres/
lib.rs

1use std::convert::TryFrom;
2
3use wesichain_checkpoint_sql::error::CheckpointSqlError;
4use wesichain_checkpoint_sql::migrations::run_migrations;
5use wesichain_checkpoint_sql::ops::{
6    load_latest_checkpoint, save_checkpoint_with_projections_and_queue,
7};
8use wesichain_core::checkpoint::{Checkpoint, Checkpointer};
9use wesichain_core::state::{GraphState, StateSchema};
10use wesichain_core::WesichainError;
11
12#[derive(Debug, Clone)]
13pub struct PostgresCheckpointer {
14    pool: sqlx::PgPool,
15    enable_projections: bool,
16}
17
/// Configuration collected before connecting a [`PostgresCheckpointer`].
///
/// Created by [`PostgresCheckpointer::builder`] and consumed by its `build` method.
#[derive(Clone)]
pub struct PostgresCheckpointerBuilder {
    /// Connection string handed to sqlx when the pool is opened. May contain
    /// credentials, so it is never printed by the `Debug` impl.
    database_url: String,
    /// Upper bound on pooled connections.
    max_connections: u32,
    /// Lower bound on pooled connections.
    min_connections: u32,
    /// Whether saves should also write projections.
    enable_projections: bool,
}

// Hand-written instead of derived: a derived `Debug` would print `database_url`
// verbatim, leaking any password embedded in the connection string into logs.
impl std::fmt::Debug for PostgresCheckpointerBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresCheckpointerBuilder")
            .field("database_url", &"<redacted>")
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("enable_projections", &self.enable_projections)
            .finish()
    }
}
25
26impl PostgresCheckpointer {
27    pub fn builder(database_url: impl Into<String>) -> PostgresCheckpointerBuilder {
28        PostgresCheckpointerBuilder {
29            database_url: database_url.into(),
30            max_connections: 5,
31            min_connections: 0,
32            enable_projections: false,
33        }
34    }
35
36    pub fn projections_enabled(&self) -> bool {
37        self.enable_projections
38    }
39}
40
41impl PostgresCheckpointerBuilder {
42    pub fn max_connections(mut self, max_connections: u32) -> Self {
43        self.max_connections = max_connections;
44        self
45    }
46
47    pub fn min_connections(mut self, min_connections: u32) -> Self {
48        self.min_connections = min_connections;
49        self
50    }
51
52    pub fn enable_projections(mut self, enable_projections: bool) -> Self {
53        self.enable_projections = enable_projections;
54        self
55    }
56
57    pub async fn build(self) -> Result<PostgresCheckpointer, CheckpointSqlError> {
58        let pool = sqlx::postgres::PgPoolOptions::new()
59            .max_connections(self.max_connections)
60            .min_connections(self.min_connections)
61            .connect(&self.database_url)
62            .await
63            .map_err(CheckpointSqlError::Connection)?;
64
65        run_migrations(&pool).await?;
66
67        Ok(PostgresCheckpointer {
68            pool,
69            enable_projections: self.enable_projections,
70        })
71    }
72}
73
74fn graph_checkpoint_error(message: impl Into<String>) -> WesichainError {
75    WesichainError::CheckpointFailed(message.into())
76}
77
78fn map_sql_error(error: CheckpointSqlError) -> WesichainError {
79    graph_checkpoint_error(error.to_string())
80}
81
82impl<S: StateSchema> Checkpointer<S> for PostgresCheckpointer {
83    fn save<'life0, 'life1, 'async_trait>(
84        &'life0 self,
85        checkpoint: &'life1 Checkpoint<S>,
86    ) -> core::pin::Pin<
87        Box<dyn core::future::Future<Output = Result<(), WesichainError>> + Send + 'async_trait>,
88    >
89    where
90        'life0: 'async_trait,
91        'life1: 'async_trait,
92        Self: 'async_trait,
93    {
94        Box::pin(async move {
95            let step = i64::try_from(checkpoint.step)
96                .map_err(|_| graph_checkpoint_error("checkpoint step does not fit into i64"))?;
97
98            save_checkpoint_with_projections_and_queue(
99                &self.pool,
100                &checkpoint.thread_id,
101                &checkpoint.node,
102                step,
103                &checkpoint.created_at,
104                &checkpoint.state,
105                &checkpoint.queue,
106                self.enable_projections,
107            )
108            .await
109            .map_err(map_sql_error)?;
110
111            Ok(())
112        })
113    }
114
115    fn load<'life0, 'life1, 'async_trait>(
116        &'life0 self,
117        thread_id: &'life1 str,
118    ) -> core::pin::Pin<
119        Box<
120            dyn core::future::Future<Output = Result<Option<Checkpoint<S>>, WesichainError>>
121                + Send
122                + 'async_trait,
123        >,
124    >
125    where
126        'life0: 'async_trait,
127        'life1: 'async_trait,
128        Self: 'async_trait,
129    {
130        Box::pin(async move {
131            let stored = load_latest_checkpoint(&self.pool, thread_id)
132                .await
133                .map_err(map_sql_error)?;
134
135            let Some(stored) = stored else {
136                return Ok(None);
137            };
138
139            let step_i64 = stored
140                .step
141                .ok_or_else(|| graph_checkpoint_error("checkpoint step is missing"))?;
142            let step = u64::try_from(step_i64)
143                .map_err(|_| graph_checkpoint_error("checkpoint step is negative"))?;
144
145            let node = stored
146                .node
147                .ok_or_else(|| graph_checkpoint_error("checkpoint node is missing"))?;
148
149            let state: GraphState<S> =
150                serde_json::from_value(stored.state_json).map_err(|error| {
151                    graph_checkpoint_error(format!(
152                        "failed to deserialize checkpoint state: {error}"
153                    ))
154                })?;
155
156            let queue: Vec<(String, u64)> =
157                serde_json::from_value(stored.queue_json).map_err(|error| {
158                    graph_checkpoint_error(format!(
159                        "failed to deserialize checkpoint queue: {error}"
160                    ))
161                })?;
162
163            Ok(Some(Checkpoint {
164                thread_id: stored.thread_id,
165                state,
166                step,
167                node,
168                queue,
169                created_at: stored.created_at,
170            }))
171        })
172    }
173}