Skip to main content

cognee_database/pipelines/
sea_orm_impl.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::Utc;
6use sea_orm::{
7    ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, QueryOrder,
8    QuerySelect, RelationTrait,
9};
10use serde_json::json;
11use uuid::Uuid;
12
13use crate::conversions::{domain_status_to_entity, entity_status_to_domain};
14use crate::entities::{dataset, pipeline_run, pipeline_run_payload_field};
15use crate::types::{DatabaseError, PipelineRun, PipelineRunStatus};
16use crate::uuid_hex;
17
18use super::repository::{PipelineRunRepository, PipelineRunRow, PipelineRunWithAttributionRow};
19
20/// SeaORM-backed implementation of [`PipelineRunRepository`].
21///
22/// Wraps a shared `DatabaseConnection`. All methods write or query the
23/// `pipeline_runs` table using the "new row per status transition" pattern,
24/// matching both Python's writing pattern and the cross-SDK audit trail
25/// requirement.
26pub struct SeaOrmPipelineRunRepository {
27    db: Arc<DatabaseConnection>,
28}
29
30impl SeaOrmPipelineRunRepository {
31    /// Create a new repository backed by the given database connection.
32    pub fn new(db: Arc<DatabaseConnection>) -> Self {
33        Self { db }
34    }
35}
36
37#[async_trait]
38impl PipelineRunRepository for SeaOrmPipelineRunRepository {
39    async fn log_pipeline_run(
40        &self,
41        pipeline_run_id: Uuid,
42        pipeline_id: Uuid,
43        pipeline_name: &str,
44        dataset_id: Option<Uuid>,
45        status: PipelineRunStatus,
46        run_info: Option<serde_json::Value>,
47    ) -> Result<Uuid, DatabaseError> {
48        let row_id = Uuid::new_v4();
49
50        // `dataset_id` is nullable post-08-01; ad-hoc runs without a dataset
51        // persist with `NULL` in the column rather than being silently dropped.
52        let active = pipeline_run::ActiveModel {
53            id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(row_id)),
54            created_at: sea_orm::ActiveValue::Set(Utc::now()),
55            status: sea_orm::ActiveValue::Set(domain_status_to_entity(status)),
56            pipeline_run_id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(pipeline_run_id)),
57            pipeline_name: sea_orm::ActiveValue::Set(pipeline_name.to_string()),
58            pipeline_id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(pipeline_id)),
59            dataset_id: sea_orm::ActiveValue::Set(uuid_hex::to_hex_opt(dataset_id)),
60            run_info: sea_orm::ActiveValue::Set(run_info),
61        };
62
63        active.insert(self.db.as_ref()).await.map_err(|e| {
64            DatabaseError::QueryError(format!("log_pipeline_run insert failed: {e}"))
65        })?;
66
67        Ok(row_id)
68    }
69
70    async fn latest_status(
71        &self,
72        dataset_ids: &[Uuid],
73        pipeline_name: &str,
74    ) -> Result<HashMap<Uuid, PipelineRunStatus>, DatabaseError> {
75        if dataset_ids.is_empty() {
76            return Ok(HashMap::new());
77        }
78
79        let hex_ids: Vec<String> = dataset_ids.iter().map(|id| uuid_hex::to_hex(*id)).collect();
80
81        // Fetch all matching rows, ordered by created_at DESC.
82        // We then pick the first (most recent) per dataset_id.
83        let rows = pipeline_run::Entity::find()
84            .filter(pipeline_run::Column::PipelineName.eq(pipeline_name))
85            .filter(pipeline_run::Column::DatasetId.is_in(hex_ids))
86            .order_by_desc(pipeline_run::Column::CreatedAt)
87            .all(self.db.as_ref())
88            .await
89            .map_err(|e| DatabaseError::QueryError(format!("latest_status query failed: {e}")))?;
90
91        let mut result: HashMap<Uuid, PipelineRunStatus> = HashMap::new();
92        for row in rows {
93            let run: PipelineRun = row.into();
94            // Only keep the first (most recent) entry per dataset_id.
95            // Ad-hoc rows (dataset_id = None) are not surfaced by
96            // latest_status — they don't belong to any dataset bucket the
97            // caller asked about (the input filter was already keyed by
98            // dataset_id, so a None row would not have matched the IN clause
99            // anyway; we filter defensively here for clarity).
100            if let Some(did) = run.dataset_id {
101                result.entry(did).or_insert(run.status);
102            }
103        }
104
105        Ok(result)
106    }
107
108    async fn list_recent(
109        &self,
110        dataset_id: Option<Uuid>,
111        limit: u32,
112    ) -> Result<Vec<PipelineRunRow>, DatabaseError> {
113        let mut query = pipeline_run::Entity::find()
114            .order_by_desc(pipeline_run::Column::CreatedAt)
115            .limit(u64::from(limit));
116
117        if let Some(did) = dataset_id {
118            query = query.filter(pipeline_run::Column::DatasetId.eq(uuid_hex::to_hex(did)));
119        }
120
121        let rows = query
122            .all(self.db.as_ref())
123            .await
124            .map_err(|e| DatabaseError::QueryError(format!("list_recent query failed: {e}")))?;
125
126        Ok(rows.into_iter().map(PipelineRun::from).collect())
127    }
128
129    async fn list_recent_with_attribution(
130        &self,
131        dataset_id: Option<Uuid>,
132        limit: u32,
133    ) -> Result<Vec<PipelineRunWithAttributionRow>, DatabaseError> {
134        use sea_orm::JoinType;
135
136        // SeaORM JOIN — uses the relationships defined on the entities. We
137        // perform a single LEFT JOIN to `datasets`. Owner-email attribution
138        // requires the `users` table which now lives in the closed
139        // `cognee-access-control` crate; OSS callers receive `owner_email =
140        // None` and are expected to resolve emails out-of-band (or via the
141        // closed `cognee-access-control::auth::UserAuthRepository`). The
142        // dataset/owner_id columns continue to flow through this query so
143        // downstream UIs can render attribution without the email.
144        let mut query = pipeline_run::Entity::find()
145            .select_only()
146            .column(pipeline_run::Column::Id)
147            .column(pipeline_run::Column::CreatedAt)
148            .column(pipeline_run::Column::Status)
149            .column(pipeline_run::Column::PipelineRunId)
150            .column(pipeline_run::Column::PipelineName)
151            .column(pipeline_run::Column::PipelineId)
152            .column(pipeline_run::Column::DatasetId)
153            .column_as(dataset::Column::Name, "dataset_name")
154            .column_as(dataset::Column::OwnerId, "dataset_owner_id")
155            .join(JoinType::LeftJoin, pipeline_run::Relation::Dataset.def())
156            .order_by_desc(pipeline_run::Column::CreatedAt)
157            .limit(u64::from(limit));
158
159        if let Some(did) = dataset_id {
160            query = query.filter(pipeline_run::Column::DatasetId.eq(uuid_hex::to_hex(did)));
161        }
162
163        let raw = query
164            .into_tuple::<(
165                String,
166                chrono::DateTime<Utc>,
167                pipeline_run::PipelineRunStatus,
168                String,
169                String,
170                String,
171                Option<String>,
172                Option<String>,
173                Option<String>,
174            )>()
175            .all(self.db.as_ref())
176            .await
177            .map_err(|e| {
178                DatabaseError::QueryError(format!("list_recent_with_attribution query failed: {e}"))
179            })?;
180
181        let mut rows = Vec::with_capacity(raw.len());
182        for (
183            id_hex,
184            created_at,
185            status,
186            pipeline_run_hex,
187            pipeline_name,
188            pipeline_id_hex,
189            dataset_id_hex,
190            dataset_name,
191            owner_id_hex,
192        ) in raw
193        {
194            // `dataset_id` is nullable post-08-01 — the column may genuinely
195            // be NULL (ad-hoc run without a dataset), and the LEFT JOIN may
196            // also yield NULL when the referenced dataset has been deleted.
197            // Both cases collapse to `None` in the projection.
198            let dataset_uuid = dataset_id_hex
199                .as_deref()
200                .and_then(|s| uuid_hex::from_hex(s).ok());
201            let owner_uuid = owner_id_hex
202                .as_deref()
203                .and_then(|s| uuid_hex::from_hex(s).ok());
204            // Determine dataset attribution presence: when dataset_name is
205            // None the LEFT JOIN didn't match (orphan or NULL dataset_id).
206            let (dataset_id_field, dataset_name_field) = if dataset_name.is_some() {
207                (dataset_uuid, dataset_name)
208            } else {
209                (dataset_uuid, None)
210            };
211
212            rows.push(PipelineRunWithAttributionRow {
213                id: uuid_hex::from_hex(&id_hex)
214                    .map_err(|e| DatabaseError::QueryError(format!("invalid id hex: {e}")))?,
215                created_at,
216                status: entity_status_to_domain(status),
217                pipeline_run_id: uuid_hex::from_hex(&pipeline_run_hex).map_err(|e| {
218                    DatabaseError::QueryError(format!("invalid pipeline_run_id hex: {e}"))
219                })?,
220                pipeline_name,
221                pipeline_id: uuid_hex::from_hex(&pipeline_id_hex).map_err(|e| {
222                    DatabaseError::QueryError(format!("invalid pipeline_id hex: {e}"))
223                })?,
224                dataset_id: dataset_id_field,
225                dataset_name: dataset_name_field,
226                owner_id: owner_uuid,
227                owner_email: None,
228            });
229        }
230        Ok(rows)
231    }
232
233    async fn reset_orphans(&self, reason: &str) -> Result<u64, DatabaseError> {
234        // Find all pipeline_run_ids that have INITIATED or STARTED status
235        // and do NOT have a more recent COMPLETED or ERRORED row with the same
236        // pipeline_run_id. We implement this by fetching the latest row per
237        // pipeline_run_id and checking its status.
238        //
239        // Strategy: fetch all rows ordered by (pipeline_run_id, created_at DESC),
240        // then for each unique pipeline_run_id, check if the latest row is stuck.
241
242        let all_rows = pipeline_run::Entity::find()
243            .order_by_desc(pipeline_run::Column::CreatedAt)
244            .all(self.db.as_ref())
245            .await
246            .map_err(|e| DatabaseError::QueryError(format!("reset_orphans fetch failed: {e}")))?;
247
248        // Collect the latest row per pipeline_run_id.
249        let mut latest_per_run: HashMap<String, pipeline_run::Model> = HashMap::new();
250        for row in all_rows {
251            latest_per_run
252                .entry(row.pipeline_run_id.clone())
253                .or_insert(row);
254        }
255
256        // Find rows that are stuck in INITIATED or STARTED.
257        let orphan_ids: Vec<String> = latest_per_run
258            .into_values()
259            .filter(|row| {
260                matches!(
261                    row.status,
262                    pipeline_run::PipelineRunStatus::Initiated
263                        | pipeline_run::PipelineRunStatus::Started
264                )
265            })
266            .map(|row| row.id)
267            .collect();
268
269        if orphan_ids.is_empty() {
270            return Ok(0);
271        }
272
273        // Write new ERRORED rows for each orphan (new-row-per-transition pattern).
274        let reason_info = json!({"reason": reason});
275        let mut count = 0u64;
276        for orphan_id in &orphan_ids {
277            // Fetch the orphan row to get all its fields.
278            let orphan_opt = pipeline_run::Entity::find_by_id(orphan_id.clone())
279                .one(self.db.as_ref())
280                .await
281                .map_err(|e| {
282                    DatabaseError::QueryError(format!("reset_orphans fetch orphan failed: {e}"))
283                })?;
284
285            if let Some(orphan) = orphan_opt {
286                let new_id = Uuid::new_v4();
287                let active = pipeline_run::ActiveModel {
288                    id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(new_id)),
289                    created_at: sea_orm::ActiveValue::Set(Utc::now()),
290                    status: sea_orm::ActiveValue::Set(pipeline_run::PipelineRunStatus::Errored),
291                    pipeline_run_id: sea_orm::ActiveValue::Set(orphan.pipeline_run_id),
292                    pipeline_name: sea_orm::ActiveValue::Set(orphan.pipeline_name),
293                    pipeline_id: sea_orm::ActiveValue::Set(orphan.pipeline_id),
294                    dataset_id: sea_orm::ActiveValue::Set(orphan.dataset_id),
295                    run_info: sea_orm::ActiveValue::Set(Some(reason_info.clone())),
296                };
297                active.insert(self.db.as_ref()).await.map_err(|e| {
298                    DatabaseError::QueryError(format!("reset_orphans insert failed: {e}"))
299                })?;
300                count += 1;
301            }
302        }
303
304        Ok(count)
305    }
306
307    async fn set_payload_field(
308        &self,
309        run_id: Uuid,
310        key: &str,
311        value: serde_json::Value,
312    ) -> Result<(), DatabaseError> {
313        use sea_orm::sea_query::OnConflict;
314
315        let now = Utc::now();
316        let model = pipeline_run_payload_field::ActiveModel {
317            pipeline_run_id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(run_id)),
318            key: sea_orm::ActiveValue::Set(key.to_owned()),
319            value: sea_orm::ActiveValue::Set(value),
320            created_at: sea_orm::ActiveValue::Set(now),
321            updated_at: sea_orm::ActiveValue::Set(now),
322        };
323
324        pipeline_run_payload_field::Entity::insert(model)
325            .on_conflict(
326                OnConflict::columns([
327                    pipeline_run_payload_field::Column::PipelineRunId,
328                    pipeline_run_payload_field::Column::Key,
329                ])
330                .update_columns([
331                    pipeline_run_payload_field::Column::Value,
332                    pipeline_run_payload_field::Column::UpdatedAt,
333                ])
334                .to_owned(),
335            )
336            .exec(self.db.as_ref())
337            .await
338            .map_err(|e| {
339                DatabaseError::QueryError(format!("set_payload_field upsert failed: {e}"))
340            })?;
341        Ok(())
342    }
343
344    async fn get_payload(
345        &self,
346        run_id: Uuid,
347    ) -> Result<serde_json::Map<String, serde_json::Value>, DatabaseError> {
348        let rows = pipeline_run_payload_field::Entity::find()
349            .filter(pipeline_run_payload_field::Column::PipelineRunId.eq(uuid_hex::to_hex(run_id)))
350            .all(self.db.as_ref())
351            .await
352            .map_err(|e| DatabaseError::QueryError(format!("get_payload query failed: {e}")))?;
353
354        Ok(rows.into_iter().map(|m| (m.key, m.value)).collect())
355    }
356
357    async fn get_pipeline_run(
358        &self,
359        pipeline_run_id: Uuid,
360    ) -> Result<Option<PipelineRun>, DatabaseError> {
361        let row = pipeline_run::Entity::find()
362            .filter(pipeline_run::Column::PipelineRunId.eq(uuid_hex::to_hex(pipeline_run_id)))
363            .order_by_desc(pipeline_run::Column::CreatedAt)
364            .one(self.db.as_ref())
365            .await
366            .map_err(|e| {
367                DatabaseError::QueryError(format!("get_pipeline_run query failed: {e}"))
368            })?;
369        Ok(row.map(PipelineRun::from))
370    }
371
372    async fn get_pipeline_run_by_dataset(
373        &self,
374        dataset_id: Uuid,
375        pipeline_name: &str,
376    ) -> Result<Option<PipelineRun>, DatabaseError> {
377        // `dataset_id` is the function parameter (non-nullable `Uuid`); we
378        // match the column against the hex string. Per decision 4 the column
379        // is `Option<String>` post-08-01 but a literal `eq(...)` only matches
380        // non-NULL rows — exactly what we want here.
381        let row = pipeline_run::Entity::find()
382            .filter(pipeline_run::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
383            .filter(pipeline_run::Column::PipelineName.eq(pipeline_name))
384            .order_by_desc(pipeline_run::Column::CreatedAt)
385            .one(self.db.as_ref())
386            .await
387            .map_err(|e| {
388                DatabaseError::QueryError(format!("get_pipeline_run_by_dataset query failed: {e}"))
389            })?;
390        Ok(row.map(PipelineRun::from))
391    }
392
393    async fn get_pipeline_runs_by_dataset(
394        &self,
395        dataset_id: Uuid,
396    ) -> Result<Vec<PipelineRun>, DatabaseError> {
397        // Fetch every row for `dataset_id`, newest first, then collapse to
398        // one entry per distinct `pipeline_name` (keeping the first / newest
399        // seen). Matches Python's behaviour where the helper groups by
400        // pipeline_name and picks the latest row.
401        let rows = pipeline_run::Entity::find()
402            .filter(pipeline_run::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
403            .order_by_desc(pipeline_run::Column::CreatedAt)
404            .all(self.db.as_ref())
405            .await
406            .map_err(|e| {
407                DatabaseError::QueryError(format!("get_pipeline_runs_by_dataset query failed: {e}"))
408            })?;
409
410        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
411        let mut out = Vec::new();
412        for row in rows {
413            if seen.insert(row.pipeline_name.clone()) {
414                out.push(PipelineRun::from(row));
415            }
416        }
417        Ok(out)
418    }
419}