Skip to main content

cognee_database/ops/
data.rs

1use chrono::{DateTime, Utc};
2use cognee_models::{Data, Dataset};
3use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
4use sea_orm::{
5    ActiveModelTrait, ActiveValue::Set, ColumnTrait, DatabaseConnection, EntityTrait,
6    IntoActiveModel, PaginatorTrait, QueryFilter, sea_query::Expr,
7};
8use tracing::{Span, instrument};
9use uuid::Uuid;
10
11use crate::conversions::map_sea_err;
12use crate::database_system_label;
13use crate::entities::{data, dataset, dataset_data};
14use crate::types::DatabaseError;
15use crate::uuid_hex;
16
17#[instrument(
18    name = "cognee.db.relational.data.create_data",
19    level = "info",
20    skip_all,
21    fields(cognee.db.system = tracing::field::Empty),
22    err,
23)]
24pub async fn create_data(db: &DatabaseConnection, d: Data) -> Result<Data, DatabaseError> {
25    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
26    data::ActiveModel::from(&d)
27        .insert(db)
28        .await
29        .map_err(map_sea_err)?;
30    Ok(d)
31}
32
33#[instrument(
34    name = "cognee.db.relational.data.get_data",
35    level = "info",
36    skip_all,
37    fields(
38        cognee.db.system = tracing::field::Empty,
39        cognee.db.row_count = tracing::field::Empty,
40    ),
41    err,
42)]
43pub async fn get_data(db: &DatabaseConnection, id: Uuid) -> Result<Option<Data>, DatabaseError> {
44    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
45    let result = data::Entity::find_by_id(uuid_hex::to_hex(id))
46        .one(db)
47        .await
48        .map_err(map_sea_err)
49        .map(|opt| opt.map(Data::from))?;
50    Span::current().record(
51        COGNEE_DB_ROW_COUNT,
52        if result.is_some() { 1i64 } else { 0i64 },
53    );
54    Ok(result)
55}
56
57#[instrument(
58    name = "cognee.db.relational.data.delete_data",
59    level = "info",
60    skip_all,
61    fields(cognee.db.system = tracing::field::Empty),
62    err,
63)]
64pub async fn delete_data(db: &DatabaseConnection, id: Uuid) -> Result<(), DatabaseError> {
65    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
66    data::Entity::delete_by_id(uuid_hex::to_hex(id))
67        .exec(db)
68        .await
69        .map_err(map_sea_err)?;
70    Ok(())
71}
72
73#[instrument(
74    name = "cognee.db.relational.data.update_data",
75    level = "info",
76    skip_all,
77    fields(cognee.db.system = tracing::field::Empty),
78    err,
79)]
80pub async fn update_data(db: &DatabaseConnection, d: Data) -> Result<Data, DatabaseError> {
81    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
82    let mut model = data::ActiveModel::from(&d);
83    model.updated_at = Set(Some(Utc::now()));
84    model.update(db).await.map_err(map_sea_err)?;
85    Ok(d)
86}
87
88#[instrument(
89    name = "cognee.db.relational.data.count_data_dataset_links",
90    level = "info",
91    skip_all,
92    fields(
93        cognee.db.system = tracing::field::Empty,
94        cognee.db.row_count = tracing::field::Empty,
95    ),
96    err,
97)]
98pub async fn count_data_dataset_links(
99    db: &DatabaseConnection,
100    data_id: Uuid,
101) -> Result<usize, DatabaseError> {
102    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
103    let count: u64 = dataset_data::Entity::find()
104        .filter(dataset_data::Column::DataId.eq(uuid_hex::to_hex(data_id)))
105        .count(db)
106        .await
107        .map_err(map_sea_err)?;
108    Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
109    Ok(count as usize)
110}
111
112/// Update only the `token_count` column for a Data record.
113///
114/// Mirrors the Python `update_document_token_count()` in
115/// `cognee/tasks/documents/extract_chunks_from_documents.py`.
116#[instrument(
117    name = "cognee.db.relational.data.update_data_token_count",
118    level = "info",
119    skip_all,
120    fields(cognee.db.system = tracing::field::Empty),
121    err,
122)]
123pub async fn update_data_token_count(
124    db: &DatabaseConnection,
125    data_id: Uuid,
126    token_count: i64,
127) -> Result<(), DatabaseError> {
128    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
129    // Single UPDATE instead of find-then-update (2 round-trips -> 1). The
130    // rows-affected count preserves the previous NotFound behaviour.
131    let result = data::Entity::update_many()
132        .col_expr(data::Column::TokenCount, Expr::value(token_count))
133        .col_expr(data::Column::UpdatedAt, Expr::value(Some(Utc::now())))
134        .filter(data::Column::Id.eq(uuid_hex::to_hex(data_id)))
135        .exec(db)
136        .await
137        .map_err(map_sea_err)?;
138
139    if result.rows_affected == 0 {
140        return Err(DatabaseError::NotFound(format!("Data {data_id} not found")));
141    }
142    Ok(())
143}
144
145/// Update `last_accessed` for a batch of Data records identified by their IDs.
146///
147/// This is a no-op when `data_ids` is empty.
148#[instrument(
149    name = "cognee.db.relational.data.update_last_accessed",
150    level = "info",
151    skip_all,
152    fields(cognee.db.system = tracing::field::Empty),
153    err,
154)]
155pub async fn update_last_accessed(
156    db: &DatabaseConnection,
157    data_ids: &[Uuid],
158    timestamp: DateTime<Utc>,
159) -> Result<(), DatabaseError> {
160    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
161    if data_ids.is_empty() {
162        return Ok(());
163    }
164
165    // Single UPDATE ... WHERE id IN (...) instead of N×(find + update) round-trips.
166    let hex_ids: Vec<_> = data_ids.iter().map(|id| uuid_hex::to_hex(*id)).collect();
167    data::Entity::update_many()
168        .col_expr(data::Column::LastAccessed, Expr::value(Some(timestamp)))
169        .filter(data::Column::Id.is_in(hex_ids))
170        .exec(db)
171        .await
172        .map_err(map_sea_err)?;
173
174    Ok(())
175}
176
177/// Clear `pipeline_status` JSON entries keyed by the given `dataset_id`
178/// from all `Data` records linked to that dataset via the `dataset_data`
179/// junction table.
180///
181/// This mirrors the Python cleanup in `delete_dataset.py` lines 33-54.
182/// Must be called **before** the junction rows are removed (before
183/// `detach_data_from_dataset` or `delete_dataset`), since the junction is
184/// needed to find related `Data` records.
185///
186/// Returns the number of `Data` records whose `pipeline_status` was modified.
187#[instrument(
188    name = "cognee.db.relational.data.clear_pipeline_status_for_dataset",
189    level = "info",
190    skip_all,
191    fields(
192        cognee.db.system = tracing::field::Empty,
193        cognee.db.row_count = tracing::field::Empty,
194    ),
195    err,
196)]
197pub async fn clear_pipeline_status_for_dataset(
198    db: &DatabaseConnection,
199    dataset_id: Uuid,
200) -> Result<usize, DatabaseError> {
201    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
202    // Find all data IDs linked to this dataset via the junction table
203    let junction_rows = dataset_data::Entity::find()
204        .filter(dataset_data::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
205        .all(db)
206        .await
207        .map_err(map_sea_err)?;
208
209    let data_ids: Vec<String> = junction_rows.into_iter().map(|j| j.data_id).collect();
210    if data_ids.is_empty() {
211        Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
212        return Ok(0);
213    }
214
215    let dataset_id_str = uuid_hex::to_hex(dataset_id);
216    let mut updated_count = 0usize;
217
218    // Read the linked Data rows in chunks instead of one find per id (N reads).
219    // Chunking keeps each `IN (...)` under the driver's bound-variable cap
220    // (SQLite ~32766, Postgres 65535) — mirroring `PROVENANCE_INSERT_BATCH` in
221    // graph_storage — so a dataset with more items than the cap can't overflow.
222    // Updates stay per-row because each row's pipeline_status JSON is mutated
223    // independently, and only rows that actually change are written back.
224    const READ_CHUNK: usize = 500;
225    let mut models = Vec::with_capacity(data_ids.len());
226    for chunk in data_ids.chunks(READ_CHUNK) {
227        models.extend(
228            data::Entity::find()
229                .filter(data::Column::Id.is_in(chunk.to_vec()))
230                .all(db)
231                .await
232                .map_err(map_sea_err)?,
233        );
234    }
235
236    for model in models {
237        let Some(ref status_json) = model.pipeline_status else {
238            continue;
239        };
240
241        let mut parsed: serde_json::Value = serde_json::from_str(status_json)
242            .unwrap_or(serde_json::Value::Object(Default::default()));
243
244        let serde_json::Value::Object(ref mut top_map) = parsed else {
245            continue;
246        };
247
248        let mut modified = false;
249        for (_pipeline_name, inner) in top_map.iter_mut() {
250            if let serde_json::Value::Object(inner_map) = inner
251                && inner_map.remove(&dataset_id_str).is_some()
252            {
253                modified = true;
254            }
255        }
256
257        if !modified {
258            continue;
259        }
260
261        // Remove pipeline entries whose inner map is now empty
262        top_map.retain(|_, v| !matches!(v, serde_json::Value::Object(m) if m.is_empty()));
263
264        let new_status = if top_map.is_empty() {
265            None
266        } else {
267            Some(serde_json::to_string(&parsed).map_err(|e| {
268                DatabaseError::QueryError(format!("Failed to serialize pipeline_status: {e}"))
269            })?)
270        };
271
272        let mut active = model.into_active_model();
273        active.pipeline_status = Set(new_status);
274        active.updated_at = Set(Some(Utc::now()));
275        active.update(db).await.map_err(map_sea_err)?;
276        updated_count += 1;
277    }
278
279    Span::current().record(COGNEE_DB_ROW_COUNT, updated_count as i64);
280    Ok(updated_count)
281}
282
283/// Clear only the `cognify_pipeline` entry for `dataset_id` from a single
284/// Data record's `pipeline_status` JSON. All other entries are preserved.
285///
286/// Mirrors Python `_forget_data_memory` lines 343-348.
287#[instrument(
288    name = "cognee.db.relational.data.clear_cognify_pipeline_status_for_data",
289    level = "info",
290    skip_all,
291    fields(
292        cognee.db.system = tracing::field::Empty,
293    ),
294    err,
295)]
296pub async fn clear_cognify_pipeline_status_for_data(
297    db: &DatabaseConnection,
298    data_id: Uuid,
299    dataset_id: Uuid,
300) -> Result<(), DatabaseError> {
301    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
302    let model = data::Entity::find_by_id(uuid_hex::to_hex(data_id))
303        .one(db)
304        .await
305        .map_err(map_sea_err)?;
306
307    let Some(model) = model else {
308        return Ok(());
309    };
310
311    let Some(ref status_json) = model.pipeline_status else {
312        return Ok(());
313    };
314
315    let mut parsed: serde_json::Value =
316        serde_json::from_str(status_json).unwrap_or(serde_json::Value::Object(Default::default()));
317
318    let serde_json::Value::Object(ref mut top_map) = parsed else {
319        return Ok(());
320    };
321
322    let dataset_id_str = uuid_hex::to_hex(dataset_id);
323    let Some(inner) = top_map.get_mut("cognify_pipeline") else {
324        return Ok(());
325    };
326    let modified = if let serde_json::Value::Object(inner_map) = inner {
327        inner_map.remove(&dataset_id_str).is_some()
328    } else {
329        false
330    };
331
332    if !modified {
333        return Ok(());
334    }
335
336    // Remove `cognify_pipeline` if its inner map is now empty.
337    top_map.retain(|k, v| {
338        k != "cognify_pipeline" || !matches!(v, serde_json::Value::Object(m) if m.is_empty())
339    });
340
341    let new_status = if top_map.is_empty() {
342        None
343    } else {
344        Some(serde_json::to_string(&parsed).map_err(|e| {
345            DatabaseError::QueryError(format!("Failed to serialize pipeline_status: {e}"))
346        })?)
347    };
348
349    let mut active = model.into_active_model();
350    active.pipeline_status = Set(new_status);
351    active.updated_at = Set(Some(Utc::now()));
352    active.update(db).await.map_err(map_sea_err)?;
353    Ok(())
354}
355
356#[instrument(
357    name = "cognee.db.relational.data.list_datasets_for_data",
358    level = "info",
359    skip_all,
360    fields(
361        cognee.db.system = tracing::field::Empty,
362        cognee.db.row_count = tracing::field::Empty,
363    ),
364    err,
365)]
366pub async fn list_datasets_for_data(
367    db: &DatabaseConnection,
368    data_id: Uuid,
369) -> Result<Vec<Dataset>, DatabaseError> {
370    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
371    let pairs = data::Entity::find_by_id(uuid_hex::to_hex(data_id))
372        .find_with_related(dataset::Entity)
373        .all(db)
374        .await
375        .map_err(map_sea_err)?;
376    let datasets: Vec<Dataset> = pairs
377        .into_iter()
378        .flat_map(|(_, ds_list)| ds_list)
379        .map(Dataset::from)
380        .collect();
381    Span::current().record(COGNEE_DB_ROW_COUNT, datasets.len() as i64);
382    Ok(datasets)
383}