use async_trait::async_trait;
use chrono::{DateTime, Utc};
use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
use sea_orm::sea_query::OnConflict;
use sea_orm::{ActiveValue, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter};
use tracing::{Span, instrument};
use crate::conversions::map_sea_err;
use crate::database_system_label;
use crate::entities::graph_sync_checkpoint;
use crate::types::DatabaseError;
/// Persistence for named checkpoints, each holding a single UTC timestamp.
///
/// The `Send + Sync` supertraits allow an implementation to be shared across
/// threads, e.g. behind an `Arc<dyn CheckpointStore>`.
#[async_trait]
pub trait CheckpointStore: Sync + Send {
    /// Returns the timestamp stored under `key`, or `Ok(None)` if there is none.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the backing store cannot be read.
    async fn load(&self, key: &str) -> Result<Option<DateTime<Utc>>, DatabaseError>;

    /// Stores `ts` under `key` (presumably replacing any earlier value — see the
    /// concrete implementation).
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the backing store cannot be written.
    async fn save(&self, key: &str, ts: DateTime<Utc>) -> Result<(), DatabaseError>;
}
/// Reads the timestamp stored for checkpoint `key` from the `graph_sync_checkpoint` table.
///
/// Returns `Ok(None)` when no row has a matching `key` column. On success, the current
/// span receives the database system label and a row count of `0` or `1`.
///
/// # Errors
///
/// Returns the [`DatabaseError`] produced by `map_sea_err` when the query fails.
#[instrument(
    name = "cognee.db.relational.checkpoint.load_checkpoint",
    level = "info",
    skip_all,
    fields(
        cognee.db.system = tracing::field::Empty,
        cognee.db.row_count = tracing::field::Empty,
    ),
    err,
)]
pub async fn load_checkpoint(
    db: &DatabaseConnection,
    key: &str,
) -> Result<Option<DateTime<Utc>>, DatabaseError> {
    // Inside the instrumented body this is always the span declared above.
    let span = Span::current();
    span.record(COGNEE_DB_SYSTEM, database_system_label(db));

    let found = graph_sync_checkpoint::Entity::find()
        .filter(graph_sync_checkpoint::Column::Key.eq(key))
        .one(db)
        .await
        .map_err(map_sea_err)?;

    let checkpoint = found.map(|record| record.ts);
    // `true` -> 1, `false` -> 0: at most one row is fetched by `.one()`.
    span.record(COGNEE_DB_ROW_COUNT, i64::from(checkpoint.is_some()));
    Ok(checkpoint)
}
#[instrument(
name = "cognee.db.relational.checkpoint.save_checkpoint",
level = "info",
skip_all,
fields(cognee.db.system = tracing::field::Empty),
err,
)]
pub async fn save_checkpoint(
db: &DatabaseConnection,
key: &str,
ts: DateTime<Utc>,
) -> Result<(), DatabaseError> {
Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
let model = graph_sync_checkpoint::ActiveModel {
key: ActiveValue::Set(key.to_string()),
ts: ActiveValue::Set(ts),
};
graph_sync_checkpoint::Entity::insert(model)
.on_conflict(
OnConflict::column(graph_sync_checkpoint::Column::Key)
.update_column(graph_sync_checkpoint::Column::Ts)
.to_owned(),
)
.exec(db)
.await
.map_err(map_sea_err)?;
Ok(())
}
/// SeaORM-backed [`CheckpointStore`] that holds a shared database connection.
///
/// Cloning is cheap: it clones the inner `Arc`, so all clones share one connection.
#[derive(Clone)]
pub struct SeaOrmCheckpointStore {
    db: std::sync::Arc<DatabaseConnection>,
}

// Hand-written rather than derived: it does not require `DatabaseConnection: Debug`,
// and it keeps connection internals out of debug output.
impl std::fmt::Debug for SeaOrmCheckpointStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SeaOrmCheckpointStore").finish_non_exhaustive()
    }
}
impl SeaOrmCheckpointStore {
pub fn new(db: std::sync::Arc<DatabaseConnection>) -> Self {
Self { db }
}
}
#[async_trait]
impl CheckpointStore for SeaOrmCheckpointStore {
async fn load(&self, key: &str) -> Result<Option<DateTime<Utc>>, DatabaseError> {
load_checkpoint(self.db.as_ref(), key).await
}
async fn save(&self, key: &str, ts: DateTime<Utc>) -> Result<(), DatabaseError> {
save_checkpoint(self.db.as_ref(), key, ts).await
}
}