Skip to main content

force_sync/store/pg/
checkpoint.rs

1//! Checkpoint repository helpers for the `PostgreSQL` sync store.
2
3use tokio_postgres::GenericClient;
4
5use crate::error::ForceSyncError;
6
7use super::PgStore;
8
/// Persisted checkpoint for a single capture stream.
///
/// One instance corresponds to one row of the `sync_checkpoint` table as read
/// back by this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CheckpointState {
    /// Name of the stream this checkpoint belongs to.
    pub stream_name: String,
    /// Monotonic numeric cursor position; the write path in this module only
    /// replaces it with a strictly larger value.
    pub cursor_position: i64,
    /// The cursor in its original string form (`None` when the column is NULL).
    pub cursor: Option<String>,
}
19
20async fn advance_checkpoint_if_greater_query<C>(
21    client: &C,
22    stream_name: &str,
23    cursor_position: i64,
24    cursor: &str,
25) -> Result<u64, ForceSyncError>
26where
27    C: GenericClient + Sync + ?Sized,
28{
29    let rows = client
30        .execute(
31            "insert into sync_checkpoint (stream_name, cursor_position, cursor)
32             values ($1, $2, $3)
33             on conflict (stream_name) do update set
34                 cursor_position = excluded.cursor_position,
35                 cursor = excluded.cursor,
36                 updated_at = now()
37             where sync_checkpoint.cursor_position < excluded.cursor_position",
38            &[&stream_name, &cursor_position, &cursor],
39        )
40        .await?;
41
42    Ok(rows)
43}
44
45async fn get_checkpoint_query<C>(
46    client: &C,
47    stream_name: &str,
48) -> Result<Option<CheckpointState>, ForceSyncError>
49where
50    C: GenericClient + Sync + ?Sized,
51{
52    let row = client
53        .query_opt(
54            "select stream_name, cursor_position, cursor
55             from sync_checkpoint
56             where stream_name = $1",
57            &[&stream_name],
58        )
59        .await?;
60
61    Ok(row.map(|row| CheckpointState {
62        stream_name: row.get(0),
63        cursor_position: row.get(1),
64        cursor: row.get(2),
65    }))
66}
67
68impl PgStore {
69    /// Advances a checkpoint only if the new position is greater than the stored one.
70    ///
71    /// Returns the number of affected rows.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the database write fails.
76    pub async fn advance_checkpoint_if_greater(
77        &self,
78        stream_name: impl AsRef<str>,
79        cursor_position: i64,
80        cursor: impl AsRef<str>,
81    ) -> Result<u64, ForceSyncError> {
82        let stream_name = stream_name.as_ref().to_owned();
83        let cursor = cursor.as_ref().to_owned();
84        let client = self.pool().get().await?;
85        advance_checkpoint_if_greater_query(&**client, &stream_name, cursor_position, &cursor).await
86    }
87
88    /// Advances a checkpoint in an existing transaction when the position increases.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if the database write fails.
93    pub async fn advance_checkpoint_if_greater_in_tx<C>(
94        client: &C,
95        stream_name: &str,
96        cursor_position: i64,
97        cursor: &str,
98    ) -> Result<u64, ForceSyncError>
99    where
100        C: GenericClient + Sync + ?Sized,
101    {
102        advance_checkpoint_if_greater_query(client, stream_name, cursor_position, cursor).await
103    }
104
105    /// Loads the stored checkpoint for a stream, if one exists.
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the database query fails.
110    pub async fn get_checkpoint(
111        &self,
112        stream_name: impl AsRef<str>,
113    ) -> Result<Option<CheckpointState>, ForceSyncError> {
114        let client = self.pool().get().await?;
115        get_checkpoint_query(&**client, stream_name.as_ref()).await
116    }
117}