cognee_database/pipelines/
sea_orm_impl.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::Utc;
6use sea_orm::{
7 ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, QueryOrder,
8 QuerySelect, RelationTrait,
9};
10use serde_json::json;
11use uuid::Uuid;
12
13use crate::conversions::{domain_status_to_entity, entity_status_to_domain};
14use crate::entities::{dataset, pipeline_run, pipeline_run_payload_field};
15use crate::types::{DatabaseError, PipelineRun, PipelineRunStatus};
16use crate::uuid_hex;
17
18use super::repository::{PipelineRunRepository, PipelineRunRow, PipelineRunWithAttributionRow};
19
20pub struct SeaOrmPipelineRunRepository {
27 db: Arc<DatabaseConnection>,
28}
29
30impl SeaOrmPipelineRunRepository {
31 pub fn new(db: Arc<DatabaseConnection>) -> Self {
33 Self { db }
34 }
35}
36
37#[async_trait]
38impl PipelineRunRepository for SeaOrmPipelineRunRepository {
39 async fn log_pipeline_run(
40 &self,
41 pipeline_run_id: Uuid,
42 pipeline_id: Uuid,
43 pipeline_name: &str,
44 dataset_id: Option<Uuid>,
45 status: PipelineRunStatus,
46 run_info: Option<serde_json::Value>,
47 ) -> Result<Uuid, DatabaseError> {
48 let row_id = Uuid::new_v4();
49
50 let active = pipeline_run::ActiveModel {
53 id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(row_id)),
54 created_at: sea_orm::ActiveValue::Set(Utc::now()),
55 status: sea_orm::ActiveValue::Set(domain_status_to_entity(status)),
56 pipeline_run_id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(pipeline_run_id)),
57 pipeline_name: sea_orm::ActiveValue::Set(pipeline_name.to_string()),
58 pipeline_id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(pipeline_id)),
59 dataset_id: sea_orm::ActiveValue::Set(uuid_hex::to_hex_opt(dataset_id)),
60 run_info: sea_orm::ActiveValue::Set(run_info),
61 };
62
63 active.insert(self.db.as_ref()).await.map_err(|e| {
64 DatabaseError::QueryError(format!("log_pipeline_run insert failed: {e}"))
65 })?;
66
67 Ok(row_id)
68 }
69
70 async fn latest_status(
71 &self,
72 dataset_ids: &[Uuid],
73 pipeline_name: &str,
74 ) -> Result<HashMap<Uuid, PipelineRunStatus>, DatabaseError> {
75 if dataset_ids.is_empty() {
76 return Ok(HashMap::new());
77 }
78
79 let hex_ids: Vec<String> = dataset_ids.iter().map(|id| uuid_hex::to_hex(*id)).collect();
80
81 let rows = pipeline_run::Entity::find()
84 .filter(pipeline_run::Column::PipelineName.eq(pipeline_name))
85 .filter(pipeline_run::Column::DatasetId.is_in(hex_ids))
86 .order_by_desc(pipeline_run::Column::CreatedAt)
87 .all(self.db.as_ref())
88 .await
89 .map_err(|e| DatabaseError::QueryError(format!("latest_status query failed: {e}")))?;
90
91 let mut result: HashMap<Uuid, PipelineRunStatus> = HashMap::new();
92 for row in rows {
93 let run: PipelineRun = row.into();
94 if let Some(did) = run.dataset_id {
101 result.entry(did).or_insert(run.status);
102 }
103 }
104
105 Ok(result)
106 }
107
108 async fn list_recent(
109 &self,
110 dataset_id: Option<Uuid>,
111 limit: u32,
112 ) -> Result<Vec<PipelineRunRow>, DatabaseError> {
113 let mut query = pipeline_run::Entity::find()
114 .order_by_desc(pipeline_run::Column::CreatedAt)
115 .limit(u64::from(limit));
116
117 if let Some(did) = dataset_id {
118 query = query.filter(pipeline_run::Column::DatasetId.eq(uuid_hex::to_hex(did)));
119 }
120
121 let rows = query
122 .all(self.db.as_ref())
123 .await
124 .map_err(|e| DatabaseError::QueryError(format!("list_recent query failed: {e}")))?;
125
126 Ok(rows.into_iter().map(PipelineRun::from).collect())
127 }
128
129 async fn list_recent_with_attribution(
130 &self,
131 dataset_id: Option<Uuid>,
132 limit: u32,
133 ) -> Result<Vec<PipelineRunWithAttributionRow>, DatabaseError> {
134 use sea_orm::JoinType;
135
136 let mut query = pipeline_run::Entity::find()
145 .select_only()
146 .column(pipeline_run::Column::Id)
147 .column(pipeline_run::Column::CreatedAt)
148 .column(pipeline_run::Column::Status)
149 .column(pipeline_run::Column::PipelineRunId)
150 .column(pipeline_run::Column::PipelineName)
151 .column(pipeline_run::Column::PipelineId)
152 .column(pipeline_run::Column::DatasetId)
153 .column_as(dataset::Column::Name, "dataset_name")
154 .column_as(dataset::Column::OwnerId, "dataset_owner_id")
155 .join(JoinType::LeftJoin, pipeline_run::Relation::Dataset.def())
156 .order_by_desc(pipeline_run::Column::CreatedAt)
157 .limit(u64::from(limit));
158
159 if let Some(did) = dataset_id {
160 query = query.filter(pipeline_run::Column::DatasetId.eq(uuid_hex::to_hex(did)));
161 }
162
163 let raw = query
164 .into_tuple::<(
165 String,
166 chrono::DateTime<Utc>,
167 pipeline_run::PipelineRunStatus,
168 String,
169 String,
170 String,
171 Option<String>,
172 Option<String>,
173 Option<String>,
174 )>()
175 .all(self.db.as_ref())
176 .await
177 .map_err(|e| {
178 DatabaseError::QueryError(format!("list_recent_with_attribution query failed: {e}"))
179 })?;
180
181 let mut rows = Vec::with_capacity(raw.len());
182 for (
183 id_hex,
184 created_at,
185 status,
186 pipeline_run_hex,
187 pipeline_name,
188 pipeline_id_hex,
189 dataset_id_hex,
190 dataset_name,
191 owner_id_hex,
192 ) in raw
193 {
194 let dataset_uuid = dataset_id_hex
199 .as_deref()
200 .and_then(|s| uuid_hex::from_hex(s).ok());
201 let owner_uuid = owner_id_hex
202 .as_deref()
203 .and_then(|s| uuid_hex::from_hex(s).ok());
204 let (dataset_id_field, dataset_name_field) = if dataset_name.is_some() {
207 (dataset_uuid, dataset_name)
208 } else {
209 (dataset_uuid, None)
210 };
211
212 rows.push(PipelineRunWithAttributionRow {
213 id: uuid_hex::from_hex(&id_hex)
214 .map_err(|e| DatabaseError::QueryError(format!("invalid id hex: {e}")))?,
215 created_at,
216 status: entity_status_to_domain(status),
217 pipeline_run_id: uuid_hex::from_hex(&pipeline_run_hex).map_err(|e| {
218 DatabaseError::QueryError(format!("invalid pipeline_run_id hex: {e}"))
219 })?,
220 pipeline_name,
221 pipeline_id: uuid_hex::from_hex(&pipeline_id_hex).map_err(|e| {
222 DatabaseError::QueryError(format!("invalid pipeline_id hex: {e}"))
223 })?,
224 dataset_id: dataset_id_field,
225 dataset_name: dataset_name_field,
226 owner_id: owner_uuid,
227 owner_email: None,
228 });
229 }
230 Ok(rows)
231 }
232
233 async fn reset_orphans(&self, reason: &str) -> Result<u64, DatabaseError> {
234 let all_rows = pipeline_run::Entity::find()
243 .order_by_desc(pipeline_run::Column::CreatedAt)
244 .all(self.db.as_ref())
245 .await
246 .map_err(|e| DatabaseError::QueryError(format!("reset_orphans fetch failed: {e}")))?;
247
248 let mut latest_per_run: HashMap<String, pipeline_run::Model> = HashMap::new();
250 for row in all_rows {
251 latest_per_run
252 .entry(row.pipeline_run_id.clone())
253 .or_insert(row);
254 }
255
256 let orphan_ids: Vec<String> = latest_per_run
258 .into_values()
259 .filter(|row| {
260 matches!(
261 row.status,
262 pipeline_run::PipelineRunStatus::Initiated
263 | pipeline_run::PipelineRunStatus::Started
264 )
265 })
266 .map(|row| row.id)
267 .collect();
268
269 if orphan_ids.is_empty() {
270 return Ok(0);
271 }
272
273 let reason_info = json!({"reason": reason});
275 let mut count = 0u64;
276 for orphan_id in &orphan_ids {
277 let orphan_opt = pipeline_run::Entity::find_by_id(orphan_id.clone())
279 .one(self.db.as_ref())
280 .await
281 .map_err(|e| {
282 DatabaseError::QueryError(format!("reset_orphans fetch orphan failed: {e}"))
283 })?;
284
285 if let Some(orphan) = orphan_opt {
286 let new_id = Uuid::new_v4();
287 let active = pipeline_run::ActiveModel {
288 id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(new_id)),
289 created_at: sea_orm::ActiveValue::Set(Utc::now()),
290 status: sea_orm::ActiveValue::Set(pipeline_run::PipelineRunStatus::Errored),
291 pipeline_run_id: sea_orm::ActiveValue::Set(orphan.pipeline_run_id),
292 pipeline_name: sea_orm::ActiveValue::Set(orphan.pipeline_name),
293 pipeline_id: sea_orm::ActiveValue::Set(orphan.pipeline_id),
294 dataset_id: sea_orm::ActiveValue::Set(orphan.dataset_id),
295 run_info: sea_orm::ActiveValue::Set(Some(reason_info.clone())),
296 };
297 active.insert(self.db.as_ref()).await.map_err(|e| {
298 DatabaseError::QueryError(format!("reset_orphans insert failed: {e}"))
299 })?;
300 count += 1;
301 }
302 }
303
304 Ok(count)
305 }
306
307 async fn set_payload_field(
308 &self,
309 run_id: Uuid,
310 key: &str,
311 value: serde_json::Value,
312 ) -> Result<(), DatabaseError> {
313 use sea_orm::sea_query::OnConflict;
314
315 let now = Utc::now();
316 let model = pipeline_run_payload_field::ActiveModel {
317 pipeline_run_id: sea_orm::ActiveValue::Set(uuid_hex::to_hex(run_id)),
318 key: sea_orm::ActiveValue::Set(key.to_owned()),
319 value: sea_orm::ActiveValue::Set(value),
320 created_at: sea_orm::ActiveValue::Set(now),
321 updated_at: sea_orm::ActiveValue::Set(now),
322 };
323
324 pipeline_run_payload_field::Entity::insert(model)
325 .on_conflict(
326 OnConflict::columns([
327 pipeline_run_payload_field::Column::PipelineRunId,
328 pipeline_run_payload_field::Column::Key,
329 ])
330 .update_columns([
331 pipeline_run_payload_field::Column::Value,
332 pipeline_run_payload_field::Column::UpdatedAt,
333 ])
334 .to_owned(),
335 )
336 .exec(self.db.as_ref())
337 .await
338 .map_err(|e| {
339 DatabaseError::QueryError(format!("set_payload_field upsert failed: {e}"))
340 })?;
341 Ok(())
342 }
343
344 async fn get_payload(
345 &self,
346 run_id: Uuid,
347 ) -> Result<serde_json::Map<String, serde_json::Value>, DatabaseError> {
348 let rows = pipeline_run_payload_field::Entity::find()
349 .filter(pipeline_run_payload_field::Column::PipelineRunId.eq(uuid_hex::to_hex(run_id)))
350 .all(self.db.as_ref())
351 .await
352 .map_err(|e| DatabaseError::QueryError(format!("get_payload query failed: {e}")))?;
353
354 Ok(rows.into_iter().map(|m| (m.key, m.value)).collect())
355 }
356
357 async fn get_pipeline_run(
358 &self,
359 pipeline_run_id: Uuid,
360 ) -> Result<Option<PipelineRun>, DatabaseError> {
361 let row = pipeline_run::Entity::find()
362 .filter(pipeline_run::Column::PipelineRunId.eq(uuid_hex::to_hex(pipeline_run_id)))
363 .order_by_desc(pipeline_run::Column::CreatedAt)
364 .one(self.db.as_ref())
365 .await
366 .map_err(|e| {
367 DatabaseError::QueryError(format!("get_pipeline_run query failed: {e}"))
368 })?;
369 Ok(row.map(PipelineRun::from))
370 }
371
372 async fn get_pipeline_run_by_dataset(
373 &self,
374 dataset_id: Uuid,
375 pipeline_name: &str,
376 ) -> Result<Option<PipelineRun>, DatabaseError> {
377 let row = pipeline_run::Entity::find()
382 .filter(pipeline_run::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
383 .filter(pipeline_run::Column::PipelineName.eq(pipeline_name))
384 .order_by_desc(pipeline_run::Column::CreatedAt)
385 .one(self.db.as_ref())
386 .await
387 .map_err(|e| {
388 DatabaseError::QueryError(format!("get_pipeline_run_by_dataset query failed: {e}"))
389 })?;
390 Ok(row.map(PipelineRun::from))
391 }
392
393 async fn get_pipeline_runs_by_dataset(
394 &self,
395 dataset_id: Uuid,
396 ) -> Result<Vec<PipelineRun>, DatabaseError> {
397 let rows = pipeline_run::Entity::find()
402 .filter(pipeline_run::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
403 .order_by_desc(pipeline_run::Column::CreatedAt)
404 .all(self.db.as_ref())
405 .await
406 .map_err(|e| {
407 DatabaseError::QueryError(format!("get_pipeline_runs_by_dataset query failed: {e}"))
408 })?;
409
410 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
411 let mut out = Vec::new();
412 for row in rows {
413 if seen.insert(row.pipeline_name.clone()) {
414 out.push(PipelineRun::from(row));
415 }
416 }
417 Ok(out)
418 }
419}