Skip to main content

cognee_database/ops/
checkpoint.rs

1//! Checkpoint store abstraction for Stage 4 of `improve()`.
2//!
3//! Provides a generic key/timestamp storage interface used by
4//! `sync_graph_to_session` to track the high-water mark of edges that have
5//! already been merged into a session's graph context.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
10use sea_orm::sea_query::OnConflict;
11use sea_orm::{ActiveValue, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter};
12use tracing::{Span, instrument};
13
14use crate::conversions::map_sea_err;
15use crate::database_system_label;
16use crate::entities::graph_sync_checkpoint;
17use crate::types::DatabaseError;
18
19/// Abstraction over persistent timestamp checkpoints keyed by string.
20///
21/// Analogous to Python's cache-engine interface used for
22/// `graph_sync_checkpoint:{user_id}:{dataset_id}:{session_id}` keys.
23#[async_trait]
24pub trait CheckpointStore: Send + Sync {
25    /// Read the timestamp stored under `key`, or `None` if missing.
26    async fn load(&self, key: &str) -> Result<Option<DateTime<Utc>>, DatabaseError>;
27
28    /// Write `ts` under `key`, overwriting any previous value.
29    async fn save(&self, key: &str, ts: DateTime<Utc>) -> Result<(), DatabaseError>;
30}
31
32/// Load the checkpoint timestamp for a key from the `graph_sync_checkpoints`
33/// table, or `None` if the key does not exist.
34#[instrument(
35    name = "cognee.db.relational.checkpoint.load_checkpoint",
36    level = "info",
37    skip_all,
38    fields(
39        cognee.db.system = tracing::field::Empty,
40        cognee.db.row_count = tracing::field::Empty,
41    ),
42    err,
43)]
44pub async fn load_checkpoint(
45    db: &DatabaseConnection,
46    key: &str,
47) -> Result<Option<DateTime<Utc>>, DatabaseError> {
48    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
49    let row = graph_sync_checkpoint::Entity::find()
50        .filter(graph_sync_checkpoint::Column::Key.eq(key))
51        .one(db)
52        .await
53        .map_err(map_sea_err)?;
54    let result = row.map(|m| m.ts);
55    Span::current().record(
56        COGNEE_DB_ROW_COUNT,
57        if result.is_some() { 1i64 } else { 0i64 },
58    );
59    Ok(result)
60}
61
62/// Persist `ts` under `key` in the `graph_sync_checkpoints` table. Inserts
63/// a new row or updates the existing one (upsert on the primary key).
64#[instrument(
65    name = "cognee.db.relational.checkpoint.save_checkpoint",
66    level = "info",
67    skip_all,
68    fields(cognee.db.system = tracing::field::Empty),
69    err,
70)]
71pub async fn save_checkpoint(
72    db: &DatabaseConnection,
73    key: &str,
74    ts: DateTime<Utc>,
75) -> Result<(), DatabaseError> {
76    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
77    let model = graph_sync_checkpoint::ActiveModel {
78        key: ActiveValue::Set(key.to_string()),
79        ts: ActiveValue::Set(ts),
80    };
81    graph_sync_checkpoint::Entity::insert(model)
82        .on_conflict(
83            OnConflict::column(graph_sync_checkpoint::Column::Key)
84                .update_column(graph_sync_checkpoint::Column::Ts)
85                .to_owned(),
86        )
87        .exec(db)
88        .await
89        .map_err(map_sea_err)?;
90    Ok(())
91}
92
93/// SeaORM-backed implementation of [`CheckpointStore`] that writes to the
94/// `graph_sync_checkpoints` table.
95pub struct SeaOrmCheckpointStore {
96    db: std::sync::Arc<DatabaseConnection>,
97}
98
99impl SeaOrmCheckpointStore {
100    pub fn new(db: std::sync::Arc<DatabaseConnection>) -> Self {
101        Self { db }
102    }
103}
104
105#[async_trait]
106impl CheckpointStore for SeaOrmCheckpointStore {
107    async fn load(&self, key: &str) -> Result<Option<DateTime<Utc>>, DatabaseError> {
108        load_checkpoint(self.db.as_ref(), key).await
109    }
110
111    async fn save(&self, key: &str, ts: DateTime<Utc>) -> Result<(), DatabaseError> {
112        save_checkpoint(self.db.as_ref(), key, ts).await
113    }
114}