Skip to main content

cognee_database/sync/
sea_orm_impl.rs

1//! SeaORM-backed [`SyncOperationRepository`] implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use chrono::Utc;
7use sea_orm::{
8    ActiveModelTrait, ActiveValue::Set, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter,
9    QueryOrder,
10};
11use uuid::Uuid;
12
13use crate::entities::sync_operation;
14use crate::types::DatabaseError;
15use crate::uuid_hex;
16
17use super::repository::{SyncOperationRepository, SyncOperationRow, SyncOperationStatus};
18
19/// SeaORM impl of [`SyncOperationRepository`]. Cheap to clone (interior `Arc`).
20#[derive(Clone)]
21pub struct SeaOrmSyncOperationRepository {
22    db: Arc<DatabaseConnection>,
23}
24
25impl SeaOrmSyncOperationRepository {
26    /// Build a new repository wrapping the supplied connection.
27    pub fn new(db: Arc<DatabaseConnection>) -> Self {
28        Self { db }
29    }
30}
31
32fn parse_uuid_list(json: Option<&serde_json::Value>) -> Vec<Uuid> {
33    match json {
34        Some(serde_json::Value::Array(items)) => items
35            .iter()
36            .filter_map(|v| v.as_str())
37            .filter_map(|s| Uuid::parse_str(s).ok())
38            .collect(),
39        _ => Vec::new(),
40    }
41}
42
43fn parse_string_list(json: Option<&serde_json::Value>) -> Vec<String> {
44    match json {
45        Some(serde_json::Value::Array(items)) => items
46            .iter()
47            .filter_map(|v| v.as_str().map(|s| s.to_string()))
48            .collect(),
49        _ => Vec::new(),
50    }
51}
52
53fn row_from_model(m: sync_operation::Model) -> Result<SyncOperationRow, DatabaseError> {
54    let id = uuid_hex::from_hex(&m.id)
55        .map_err(|e| DatabaseError::QueryError(format!("invalid sync_operations.id: {e}")))?;
56    let user_id = uuid_hex::from_hex(&m.user_id)
57        .map_err(|e| DatabaseError::QueryError(format!("invalid sync_operations.user_id: {e}")))?;
58    let dataset_ids = parse_uuid_list(m.dataset_ids.as_ref());
59    let dataset_names = parse_string_list(m.dataset_names.as_ref());
60    Ok(SyncOperationRow {
61        id,
62        run_id: m.run_id,
63        status: m.status,
64        progress_percentage: m.progress_percentage.max(0) as u32,
65        dataset_ids,
66        dataset_names,
67        user_id,
68        created_at: m.created_at,
69        started_at: m.started_at,
70        completed_at: m.completed_at,
71        total_records_to_sync: m.total_records_to_sync,
72        total_records_to_download: m.total_records_to_download,
73        total_records_to_upload: m.total_records_to_upload,
74        records_downloaded: m.records_downloaded,
75        records_uploaded: m.records_uploaded,
76        bytes_downloaded: m.bytes_downloaded,
77        bytes_uploaded: m.bytes_uploaded,
78        dataset_sync_hashes: m.dataset_sync_hashes,
79        error_message: m.error_message,
80        retry_count: m.retry_count,
81    })
82}
83
84async fn fetch_by_run_id(
85    db: &DatabaseConnection,
86    run_id: &str,
87) -> Result<Option<sync_operation::Model>, DatabaseError> {
88    sync_operation::Entity::find()
89        .filter(sync_operation::Column::RunId.eq(run_id))
90        .one(db)
91        .await
92        .map_err(|e| DatabaseError::QueryError(format!("sync_operations lookup failed: {e}")))
93}
94
95#[async_trait]
96impl SyncOperationRepository for SeaOrmSyncOperationRepository {
97    async fn create_operation(
98        &self,
99        run_id: &str,
100        dataset_ids: &[Uuid],
101        dataset_names: &[String],
102        user_id: Uuid,
103    ) -> Result<(), DatabaseError> {
104        let row_id = uuid_hex::to_hex(Uuid::new_v4());
105        let dataset_ids_json = serde_json::Value::Array(
106            dataset_ids
107                .iter()
108                .map(|u| serde_json::Value::String(u.to_string()))
109                .collect(),
110        );
111        let dataset_names_json = serde_json::Value::Array(
112            dataset_names
113                .iter()
114                .map(|s| serde_json::Value::String(s.clone()))
115                .collect(),
116        );
117        let am = sync_operation::ActiveModel {
118            id: Set(row_id),
119            run_id: Set(run_id.to_string()),
120            status: Set(SyncOperationStatus::Started.as_str().to_string()),
121            progress_percentage: Set(0),
122            dataset_ids: Set(Some(dataset_ids_json)),
123            dataset_names: Set(Some(dataset_names_json)),
124            user_id: Set(uuid_hex::to_hex(user_id)),
125            created_at: Set(Utc::now()),
126            started_at: Set(None),
127            completed_at: Set(None),
128            total_records_to_sync: Set(None),
129            total_records_to_download: Set(None),
130            total_records_to_upload: Set(None),
131            records_downloaded: Set(0),
132            records_uploaded: Set(0),
133            bytes_downloaded: Set(0),
134            bytes_uploaded: Set(0),
135            dataset_sync_hashes: Set(None),
136            error_message: Set(None),
137            retry_count: Set(0),
138        };
139        sync_operation::Entity::insert(am)
140            .exec(self.db.as_ref())
141            .await
142            .map_err(|e| {
143                DatabaseError::QueryError(format!("create_operation insert failed: {e}"))
144            })?;
145        Ok(())
146    }
147
148    async fn mark_started(&self, run_id: &str) -> Result<(), DatabaseError> {
149        let Some(row) = fetch_by_run_id(self.db.as_ref(), run_id).await? else {
150            return Err(DatabaseError::NotFound(format!(
151                "sync_operations row not found: {run_id}"
152            )));
153        };
154        let mut am: sync_operation::ActiveModel = row.into();
155        am.status = Set(SyncOperationStatus::InProgress.as_str().to_string());
156        am.started_at = Set(Some(Utc::now()));
157        am.update(self.db.as_ref())
158            .await
159            .map_err(|e| DatabaseError::QueryError(format!("mark_started update failed: {e}")))?;
160        Ok(())
161    }
162
163    async fn mark_completed(
164        &self,
165        run_id: &str,
166        records_uploaded: i32,
167        records_downloaded: i32,
168        bytes_uploaded: i64,
169        bytes_downloaded: i64,
170        dataset_sync_hashes: Option<serde_json::Value>,
171    ) -> Result<(), DatabaseError> {
172        let Some(row) = fetch_by_run_id(self.db.as_ref(), run_id).await? else {
173            return Err(DatabaseError::NotFound(format!(
174                "sync_operations row not found: {run_id}"
175            )));
176        };
177        let mut am: sync_operation::ActiveModel = row.into();
178        am.status = Set(SyncOperationStatus::Completed.as_str().to_string());
179        am.progress_percentage = Set(100);
180        am.completed_at = Set(Some(Utc::now()));
181        am.records_uploaded = Set(records_uploaded);
182        am.records_downloaded = Set(records_downloaded);
183        am.bytes_uploaded = Set(bytes_uploaded);
184        am.bytes_downloaded = Set(bytes_downloaded);
185        am.dataset_sync_hashes = Set(dataset_sync_hashes);
186        am.update(self.db.as_ref())
187            .await
188            .map_err(|e| DatabaseError::QueryError(format!("mark_completed update failed: {e}")))?;
189        Ok(())
190    }
191
192    async fn mark_failed(&self, run_id: &str, error_message: &str) -> Result<(), DatabaseError> {
193        let Some(row) = fetch_by_run_id(self.db.as_ref(), run_id).await? else {
194            return Err(DatabaseError::NotFound(format!(
195                "sync_operations row not found: {run_id}"
196            )));
197        };
198        let mut am: sync_operation::ActiveModel = row.into();
199        am.status = Set(SyncOperationStatus::Failed.as_str().to_string());
200        am.completed_at = Set(Some(Utc::now()));
201        am.error_message = Set(Some(error_message.to_string()));
202        am.update(self.db.as_ref())
203            .await
204            .map_err(|e| DatabaseError::QueryError(format!("mark_failed update failed: {e}")))?;
205        Ok(())
206    }
207
208    async fn update_progress(&self, run_id: &str, percent: u32) -> Result<(), DatabaseError> {
209        let Some(row) = fetch_by_run_id(self.db.as_ref(), run_id).await? else {
210            return Err(DatabaseError::NotFound(format!(
211                "sync_operations row not found: {run_id}"
212            )));
213        };
214        let mut am: sync_operation::ActiveModel = row.into();
215        am.progress_percentage = Set(percent.min(100) as i32);
216        am.update(self.db.as_ref()).await.map_err(|e| {
217            DatabaseError::QueryError(format!("update_progress update failed: {e}"))
218        })?;
219        Ok(())
220    }
221
222    async fn running_for_user(
223        &self,
224        user_id: Uuid,
225    ) -> Result<Vec<SyncOperationRow>, DatabaseError> {
226        let user_hex = uuid_hex::to_hex(user_id);
227        let rows = sync_operation::Entity::find()
228            .filter(sync_operation::Column::UserId.eq(user_hex))
229            .filter(sync_operation::Column::Status.is_in([
230                SyncOperationStatus::Started.as_str(),
231                SyncOperationStatus::InProgress.as_str(),
232            ]))
233            .order_by_desc(sync_operation::Column::CreatedAt)
234            .all(self.db.as_ref())
235            .await
236            .map_err(|e| DatabaseError::QueryError(format!("running_for_user failed: {e}")))?;
237
238        let mut out = Vec::with_capacity(rows.len());
239        for row in rows {
240            out.push(row_from_model(row)?);
241        }
242        Ok(out)
243    }
244
245    async fn get_by_run_id(&self, run_id: &str) -> Result<Option<SyncOperationRow>, DatabaseError> {
246        match fetch_by_run_id(self.db.as_ref(), run_id).await? {
247            Some(model) => Ok(Some(row_from_model(model)?)),
248            None => Ok(None),
249        }
250    }
251}