1use chrono::{DateTime, Utc};
2use cognee_models::{Data, Dataset};
3use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
4use sea_orm::{
5 ActiveModelTrait, ActiveValue::Set, ColumnTrait, DatabaseConnection, EntityTrait,
6 IntoActiveModel, PaginatorTrait, QueryFilter, sea_query::Expr,
7};
8use tracing::{Span, instrument};
9use uuid::Uuid;
10
11use crate::conversions::map_sea_err;
12use crate::database_system_label;
13use crate::entities::{data, dataset, dataset_data};
14use crate::types::DatabaseError;
15use crate::uuid_hex;
16
17#[instrument(
18 name = "cognee.db.relational.data.create_data",
19 level = "info",
20 skip_all,
21 fields(cognee.db.system = tracing::field::Empty),
22 err,
23)]
24pub async fn create_data(db: &DatabaseConnection, d: Data) -> Result<Data, DatabaseError> {
25 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
26 data::ActiveModel::from(&d)
27 .insert(db)
28 .await
29 .map_err(map_sea_err)?;
30 Ok(d)
31}
32
33#[instrument(
34 name = "cognee.db.relational.data.get_data",
35 level = "info",
36 skip_all,
37 fields(
38 cognee.db.system = tracing::field::Empty,
39 cognee.db.row_count = tracing::field::Empty,
40 ),
41 err,
42)]
43pub async fn get_data(db: &DatabaseConnection, id: Uuid) -> Result<Option<Data>, DatabaseError> {
44 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
45 let result = data::Entity::find_by_id(uuid_hex::to_hex(id))
46 .one(db)
47 .await
48 .map_err(map_sea_err)
49 .map(|opt| opt.map(Data::from))?;
50 Span::current().record(
51 COGNEE_DB_ROW_COUNT,
52 if result.is_some() { 1i64 } else { 0i64 },
53 );
54 Ok(result)
55}
56
57#[instrument(
58 name = "cognee.db.relational.data.delete_data",
59 level = "info",
60 skip_all,
61 fields(cognee.db.system = tracing::field::Empty),
62 err,
63)]
64pub async fn delete_data(db: &DatabaseConnection, id: Uuid) -> Result<(), DatabaseError> {
65 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
66 data::Entity::delete_by_id(uuid_hex::to_hex(id))
67 .exec(db)
68 .await
69 .map_err(map_sea_err)?;
70 Ok(())
71}
72
73#[instrument(
74 name = "cognee.db.relational.data.update_data",
75 level = "info",
76 skip_all,
77 fields(cognee.db.system = tracing::field::Empty),
78 err,
79)]
80pub async fn update_data(db: &DatabaseConnection, d: Data) -> Result<Data, DatabaseError> {
81 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
82 let mut model = data::ActiveModel::from(&d);
83 model.updated_at = Set(Some(Utc::now()));
84 model.update(db).await.map_err(map_sea_err)?;
85 Ok(d)
86}
87
88#[instrument(
89 name = "cognee.db.relational.data.count_data_dataset_links",
90 level = "info",
91 skip_all,
92 fields(
93 cognee.db.system = tracing::field::Empty,
94 cognee.db.row_count = tracing::field::Empty,
95 ),
96 err,
97)]
98pub async fn count_data_dataset_links(
99 db: &DatabaseConnection,
100 data_id: Uuid,
101) -> Result<usize, DatabaseError> {
102 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
103 let count: u64 = dataset_data::Entity::find()
104 .filter(dataset_data::Column::DataId.eq(uuid_hex::to_hex(data_id)))
105 .count(db)
106 .await
107 .map_err(map_sea_err)?;
108 Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
109 Ok(count as usize)
110}
111
112#[instrument(
117 name = "cognee.db.relational.data.update_data_token_count",
118 level = "info",
119 skip_all,
120 fields(cognee.db.system = tracing::field::Empty),
121 err,
122)]
123pub async fn update_data_token_count(
124 db: &DatabaseConnection,
125 data_id: Uuid,
126 token_count: i64,
127) -> Result<(), DatabaseError> {
128 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
129 let result = data::Entity::update_many()
132 .col_expr(data::Column::TokenCount, Expr::value(token_count))
133 .col_expr(data::Column::UpdatedAt, Expr::value(Some(Utc::now())))
134 .filter(data::Column::Id.eq(uuid_hex::to_hex(data_id)))
135 .exec(db)
136 .await
137 .map_err(map_sea_err)?;
138
139 if result.rows_affected == 0 {
140 return Err(DatabaseError::NotFound(format!("Data {data_id} not found")));
141 }
142 Ok(())
143}
144
145#[instrument(
149 name = "cognee.db.relational.data.update_last_accessed",
150 level = "info",
151 skip_all,
152 fields(cognee.db.system = tracing::field::Empty),
153 err,
154)]
155pub async fn update_last_accessed(
156 db: &DatabaseConnection,
157 data_ids: &[Uuid],
158 timestamp: DateTime<Utc>,
159) -> Result<(), DatabaseError> {
160 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
161 if data_ids.is_empty() {
162 return Ok(());
163 }
164
165 let hex_ids: Vec<_> = data_ids.iter().map(|id| uuid_hex::to_hex(*id)).collect();
167 data::Entity::update_many()
168 .col_expr(data::Column::LastAccessed, Expr::value(Some(timestamp)))
169 .filter(data::Column::Id.is_in(hex_ids))
170 .exec(db)
171 .await
172 .map_err(map_sea_err)?;
173
174 Ok(())
175}
176
177#[instrument(
188 name = "cognee.db.relational.data.clear_pipeline_status_for_dataset",
189 level = "info",
190 skip_all,
191 fields(
192 cognee.db.system = tracing::field::Empty,
193 cognee.db.row_count = tracing::field::Empty,
194 ),
195 err,
196)]
197pub async fn clear_pipeline_status_for_dataset(
198 db: &DatabaseConnection,
199 dataset_id: Uuid,
200) -> Result<usize, DatabaseError> {
201 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
202 let junction_rows = dataset_data::Entity::find()
204 .filter(dataset_data::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
205 .all(db)
206 .await
207 .map_err(map_sea_err)?;
208
209 let data_ids: Vec<String> = junction_rows.into_iter().map(|j| j.data_id).collect();
210 if data_ids.is_empty() {
211 Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
212 return Ok(0);
213 }
214
215 let dataset_id_str = uuid_hex::to_hex(dataset_id);
216 let mut updated_count = 0usize;
217
218 const READ_CHUNK: usize = 500;
225 let mut models = Vec::with_capacity(data_ids.len());
226 for chunk in data_ids.chunks(READ_CHUNK) {
227 models.extend(
228 data::Entity::find()
229 .filter(data::Column::Id.is_in(chunk.to_vec()))
230 .all(db)
231 .await
232 .map_err(map_sea_err)?,
233 );
234 }
235
236 for model in models {
237 let Some(ref status_json) = model.pipeline_status else {
238 continue;
239 };
240
241 let mut parsed: serde_json::Value = serde_json::from_str(status_json)
242 .unwrap_or(serde_json::Value::Object(Default::default()));
243
244 let serde_json::Value::Object(ref mut top_map) = parsed else {
245 continue;
246 };
247
248 let mut modified = false;
249 for (_pipeline_name, inner) in top_map.iter_mut() {
250 if let serde_json::Value::Object(inner_map) = inner
251 && inner_map.remove(&dataset_id_str).is_some()
252 {
253 modified = true;
254 }
255 }
256
257 if !modified {
258 continue;
259 }
260
261 top_map.retain(|_, v| !matches!(v, serde_json::Value::Object(m) if m.is_empty()));
263
264 let new_status = if top_map.is_empty() {
265 None
266 } else {
267 Some(serde_json::to_string(&parsed).map_err(|e| {
268 DatabaseError::QueryError(format!("Failed to serialize pipeline_status: {e}"))
269 })?)
270 };
271
272 let mut active = model.into_active_model();
273 active.pipeline_status = Set(new_status);
274 active.updated_at = Set(Some(Utc::now()));
275 active.update(db).await.map_err(map_sea_err)?;
276 updated_count += 1;
277 }
278
279 Span::current().record(COGNEE_DB_ROW_COUNT, updated_count as i64);
280 Ok(updated_count)
281}
282
283#[instrument(
288 name = "cognee.db.relational.data.clear_cognify_pipeline_status_for_data",
289 level = "info",
290 skip_all,
291 fields(
292 cognee.db.system = tracing::field::Empty,
293 ),
294 err,
295)]
296pub async fn clear_cognify_pipeline_status_for_data(
297 db: &DatabaseConnection,
298 data_id: Uuid,
299 dataset_id: Uuid,
300) -> Result<(), DatabaseError> {
301 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
302 let model = data::Entity::find_by_id(uuid_hex::to_hex(data_id))
303 .one(db)
304 .await
305 .map_err(map_sea_err)?;
306
307 let Some(model) = model else {
308 return Ok(());
309 };
310
311 let Some(ref status_json) = model.pipeline_status else {
312 return Ok(());
313 };
314
315 let mut parsed: serde_json::Value =
316 serde_json::from_str(status_json).unwrap_or(serde_json::Value::Object(Default::default()));
317
318 let serde_json::Value::Object(ref mut top_map) = parsed else {
319 return Ok(());
320 };
321
322 let dataset_id_str = uuid_hex::to_hex(dataset_id);
323 let Some(inner) = top_map.get_mut("cognify_pipeline") else {
324 return Ok(());
325 };
326 let modified = if let serde_json::Value::Object(inner_map) = inner {
327 inner_map.remove(&dataset_id_str).is_some()
328 } else {
329 false
330 };
331
332 if !modified {
333 return Ok(());
334 }
335
336 top_map.retain(|k, v| {
338 k != "cognify_pipeline" || !matches!(v, serde_json::Value::Object(m) if m.is_empty())
339 });
340
341 let new_status = if top_map.is_empty() {
342 None
343 } else {
344 Some(serde_json::to_string(&parsed).map_err(|e| {
345 DatabaseError::QueryError(format!("Failed to serialize pipeline_status: {e}"))
346 })?)
347 };
348
349 let mut active = model.into_active_model();
350 active.pipeline_status = Set(new_status);
351 active.updated_at = Set(Some(Utc::now()));
352 active.update(db).await.map_err(map_sea_err)?;
353 Ok(())
354}
355
356#[instrument(
357 name = "cognee.db.relational.data.list_datasets_for_data",
358 level = "info",
359 skip_all,
360 fields(
361 cognee.db.system = tracing::field::Empty,
362 cognee.db.row_count = tracing::field::Empty,
363 ),
364 err,
365)]
366pub async fn list_datasets_for_data(
367 db: &DatabaseConnection,
368 data_id: Uuid,
369) -> Result<Vec<Dataset>, DatabaseError> {
370 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
371 let pairs = data::Entity::find_by_id(uuid_hex::to_hex(data_id))
372 .find_with_related(dataset::Entity)
373 .all(db)
374 .await
375 .map_err(map_sea_err)?;
376 let datasets: Vec<Dataset> = pairs
377 .into_iter()
378 .flat_map(|(_, ds_list)| ds_list)
379 .map(Dataset::from)
380 .collect();
381 Span::current().record(COGNEE_DB_ROW_COUNT, datasets.len() as i64);
382 Ok(datasets)
383}