Skip to main content

cognee_database/ops/
data.rs

1use chrono::{DateTime, Utc};
2use cognee_models::{Data, Dataset};
3use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
4use sea_orm::{
5    ActiveModelTrait, ActiveValue::Set, ColumnTrait, DatabaseConnection, EntityTrait,
6    IntoActiveModel, PaginatorTrait, QueryFilter,
7};
8use tracing::{Span, instrument};
9use uuid::Uuid;
10
11use crate::conversions::map_sea_err;
12use crate::database_system_label;
13use crate::entities::{data, dataset, dataset_data};
14use crate::types::DatabaseError;
15use crate::uuid_hex;
16
17#[instrument(
18    name = "cognee.db.relational.data.create_data",
19    level = "info",
20    skip_all,
21    fields(cognee.db.system = tracing::field::Empty),
22    err,
23)]
24pub async fn create_data(db: &DatabaseConnection, d: Data) -> Result<Data, DatabaseError> {
25    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
26    data::ActiveModel::from(&d)
27        .insert(db)
28        .await
29        .map_err(map_sea_err)?;
30    Ok(d)
31}
32
33#[instrument(
34    name = "cognee.db.relational.data.get_data",
35    level = "info",
36    skip_all,
37    fields(
38        cognee.db.system = tracing::field::Empty,
39        cognee.db.row_count = tracing::field::Empty,
40    ),
41    err,
42)]
43pub async fn get_data(db: &DatabaseConnection, id: Uuid) -> Result<Option<Data>, DatabaseError> {
44    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
45    let result = data::Entity::find_by_id(uuid_hex::to_hex(id))
46        .one(db)
47        .await
48        .map_err(map_sea_err)
49        .map(|opt| opt.map(Data::from))?;
50    Span::current().record(
51        COGNEE_DB_ROW_COUNT,
52        if result.is_some() { 1i64 } else { 0i64 },
53    );
54    Ok(result)
55}
56
57#[instrument(
58    name = "cognee.db.relational.data.delete_data",
59    level = "info",
60    skip_all,
61    fields(cognee.db.system = tracing::field::Empty),
62    err,
63)]
64pub async fn delete_data(db: &DatabaseConnection, id: Uuid) -> Result<(), DatabaseError> {
65    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
66    data::Entity::delete_by_id(uuid_hex::to_hex(id))
67        .exec(db)
68        .await
69        .map_err(map_sea_err)?;
70    Ok(())
71}
72
73#[instrument(
74    name = "cognee.db.relational.data.update_data",
75    level = "info",
76    skip_all,
77    fields(cognee.db.system = tracing::field::Empty),
78    err,
79)]
80pub async fn update_data(db: &DatabaseConnection, d: Data) -> Result<Data, DatabaseError> {
81    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
82    let mut model = data::ActiveModel::from(&d);
83    model.updated_at = Set(Some(Utc::now()));
84    model.update(db).await.map_err(map_sea_err)?;
85    Ok(d)
86}
87
88#[instrument(
89    name = "cognee.db.relational.data.count_data_dataset_links",
90    level = "info",
91    skip_all,
92    fields(
93        cognee.db.system = tracing::field::Empty,
94        cognee.db.row_count = tracing::field::Empty,
95    ),
96    err,
97)]
98pub async fn count_data_dataset_links(
99    db: &DatabaseConnection,
100    data_id: Uuid,
101) -> Result<usize, DatabaseError> {
102    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
103    let count: u64 = dataset_data::Entity::find()
104        .filter(dataset_data::Column::DataId.eq(uuid_hex::to_hex(data_id)))
105        .count(db)
106        .await
107        .map_err(map_sea_err)?;
108    Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
109    Ok(count as usize)
110}
111
112/// Update only the `token_count` column for a Data record.
113///
114/// Mirrors the Python `update_document_token_count()` in
115/// `cognee/tasks/documents/extract_chunks_from_documents.py`.
116#[instrument(
117    name = "cognee.db.relational.data.update_data_token_count",
118    level = "info",
119    skip_all,
120    fields(cognee.db.system = tracing::field::Empty),
121    err,
122)]
123pub async fn update_data_token_count(
124    db: &DatabaseConnection,
125    data_id: Uuid,
126    token_count: i64,
127) -> Result<(), DatabaseError> {
128    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
129    let model = data::Entity::find_by_id(uuid_hex::to_hex(data_id))
130        .one(db)
131        .await
132        .map_err(map_sea_err)?
133        .ok_or_else(|| DatabaseError::NotFound(format!("Data {data_id} not found")))?;
134
135    let mut active = model.into_active_model();
136    active.token_count = Set(token_count);
137    active.updated_at = Set(Some(Utc::now()));
138    active.update(db).await.map_err(map_sea_err)?;
139    Ok(())
140}
141
142/// Update `last_accessed` for a batch of Data records identified by their IDs.
143///
144/// This is a no-op when `data_ids` is empty.
145#[instrument(
146    name = "cognee.db.relational.data.update_last_accessed",
147    level = "info",
148    skip_all,
149    fields(cognee.db.system = tracing::field::Empty),
150    err,
151)]
152pub async fn update_last_accessed(
153    db: &DatabaseConnection,
154    data_ids: &[Uuid],
155    timestamp: DateTime<Utc>,
156) -> Result<(), DatabaseError> {
157    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
158    if data_ids.is_empty() {
159        return Ok(());
160    }
161
162    for id in data_ids {
163        let model = data::Entity::find_by_id(uuid_hex::to_hex(*id))
164            .one(db)
165            .await
166            .map_err(map_sea_err)?;
167
168        if let Some(m) = model {
169            let mut active = m.into_active_model();
170            active.last_accessed = Set(Some(timestamp));
171            active.update(db).await.map_err(map_sea_err)?;
172        }
173    }
174
175    Ok(())
176}
177
178/// Clear `pipeline_status` JSON entries keyed by the given `dataset_id`
179/// from all `Data` records linked to that dataset via the `dataset_data`
180/// junction table.
181///
182/// This mirrors the Python cleanup in `delete_dataset.py` lines 33-54.
183/// Must be called **before** the junction rows are removed (before
184/// `detach_data_from_dataset` or `delete_dataset`), since the junction is
185/// needed to find related `Data` records.
186///
187/// Returns the number of `Data` records whose `pipeline_status` was modified.
188#[instrument(
189    name = "cognee.db.relational.data.clear_pipeline_status_for_dataset",
190    level = "info",
191    skip_all,
192    fields(
193        cognee.db.system = tracing::field::Empty,
194        cognee.db.row_count = tracing::field::Empty,
195    ),
196    err,
197)]
198pub async fn clear_pipeline_status_for_dataset(
199    db: &DatabaseConnection,
200    dataset_id: Uuid,
201) -> Result<usize, DatabaseError> {
202    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
203    // Find all data IDs linked to this dataset via the junction table
204    let junction_rows = dataset_data::Entity::find()
205        .filter(dataset_data::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
206        .all(db)
207        .await
208        .map_err(map_sea_err)?;
209
210    let data_ids: Vec<String> = junction_rows.into_iter().map(|j| j.data_id).collect();
211    if data_ids.is_empty() {
212        Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
213        return Ok(0);
214    }
215
216    let dataset_id_str = uuid_hex::to_hex(dataset_id);
217    let mut updated_count = 0usize;
218
219    for data_hex_id in &data_ids {
220        let model = data::Entity::find_by_id(data_hex_id.clone())
221            .one(db)
222            .await
223            .map_err(map_sea_err)?;
224
225        let Some(model) = model else { continue };
226
227        let Some(ref status_json) = model.pipeline_status else {
228            continue;
229        };
230
231        let mut parsed: serde_json::Value = serde_json::from_str(status_json)
232            .unwrap_or(serde_json::Value::Object(Default::default()));
233
234        let serde_json::Value::Object(ref mut top_map) = parsed else {
235            continue;
236        };
237
238        let mut modified = false;
239        for (_pipeline_name, inner) in top_map.iter_mut() {
240            if let serde_json::Value::Object(inner_map) = inner
241                && inner_map.remove(&dataset_id_str).is_some()
242            {
243                modified = true;
244            }
245        }
246
247        if !modified {
248            continue;
249        }
250
251        // Remove pipeline entries whose inner map is now empty
252        top_map.retain(|_, v| !matches!(v, serde_json::Value::Object(m) if m.is_empty()));
253
254        let new_status = if top_map.is_empty() {
255            None
256        } else {
257            Some(serde_json::to_string(&parsed).map_err(|e| {
258                DatabaseError::QueryError(format!("Failed to serialize pipeline_status: {e}"))
259            })?)
260        };
261
262        let mut active = model.into_active_model();
263        active.pipeline_status = Set(new_status);
264        active.updated_at = Set(Some(Utc::now()));
265        active.update(db).await.map_err(map_sea_err)?;
266        updated_count += 1;
267    }
268
269    Span::current().record(COGNEE_DB_ROW_COUNT, updated_count as i64);
270    Ok(updated_count)
271}
272
273/// Clear only the `cognify_pipeline` entry for `dataset_id` from a single
274/// Data record's `pipeline_status` JSON. All other entries are preserved.
275///
276/// Mirrors Python `_forget_data_memory` lines 343-348.
277#[instrument(
278    name = "cognee.db.relational.data.clear_cognify_pipeline_status_for_data",
279    level = "info",
280    skip_all,
281    fields(
282        cognee.db.system = tracing::field::Empty,
283    ),
284    err,
285)]
286pub async fn clear_cognify_pipeline_status_for_data(
287    db: &DatabaseConnection,
288    data_id: Uuid,
289    dataset_id: Uuid,
290) -> Result<(), DatabaseError> {
291    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
292    let model = data::Entity::find_by_id(uuid_hex::to_hex(data_id))
293        .one(db)
294        .await
295        .map_err(map_sea_err)?;
296
297    let Some(model) = model else {
298        return Ok(());
299    };
300
301    let Some(ref status_json) = model.pipeline_status else {
302        return Ok(());
303    };
304
305    let mut parsed: serde_json::Value =
306        serde_json::from_str(status_json).unwrap_or(serde_json::Value::Object(Default::default()));
307
308    let serde_json::Value::Object(ref mut top_map) = parsed else {
309        return Ok(());
310    };
311
312    let dataset_id_str = uuid_hex::to_hex(dataset_id);
313    let Some(inner) = top_map.get_mut("cognify_pipeline") else {
314        return Ok(());
315    };
316    let modified = if let serde_json::Value::Object(inner_map) = inner {
317        inner_map.remove(&dataset_id_str).is_some()
318    } else {
319        false
320    };
321
322    if !modified {
323        return Ok(());
324    }
325
326    // Remove `cognify_pipeline` if its inner map is now empty.
327    top_map.retain(|k, v| {
328        k != "cognify_pipeline" || !matches!(v, serde_json::Value::Object(m) if m.is_empty())
329    });
330
331    let new_status = if top_map.is_empty() {
332        None
333    } else {
334        Some(serde_json::to_string(&parsed).map_err(|e| {
335            DatabaseError::QueryError(format!("Failed to serialize pipeline_status: {e}"))
336        })?)
337    };
338
339    let mut active = model.into_active_model();
340    active.pipeline_status = Set(new_status);
341    active.updated_at = Set(Some(Utc::now()));
342    active.update(db).await.map_err(map_sea_err)?;
343    Ok(())
344}
345
346#[instrument(
347    name = "cognee.db.relational.data.list_datasets_for_data",
348    level = "info",
349    skip_all,
350    fields(
351        cognee.db.system = tracing::field::Empty,
352        cognee.db.row_count = tracing::field::Empty,
353    ),
354    err,
355)]
356pub async fn list_datasets_for_data(
357    db: &DatabaseConnection,
358    data_id: Uuid,
359) -> Result<Vec<Dataset>, DatabaseError> {
360    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
361    let pairs = data::Entity::find_by_id(uuid_hex::to_hex(data_id))
362        .find_with_related(dataset::Entity)
363        .all(db)
364        .await
365        .map_err(map_sea_err)?;
366    let datasets: Vec<Dataset> = pairs
367        .into_iter()
368        .flat_map(|(_, ds_list)| ds_list)
369        .map(Dataset::from)
370        .collect();
371    Span::current().record(COGNEE_DB_ROW_COUNT, datasets.len() as i64);
372    Ok(datasets)
373}