//! `cognee_database/ops/graph_storage.rs` — relational (SeaORM) storage for
//! provenance graph nodes and edges.

use chrono::{DateTime, Utc};
use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
use sea_orm::sea_query::{Alias, Expr, OnConflict, Query};
use sea_orm::{
    ColumnTrait, Condition, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter,
    QueryOrder, QuerySelect, TransactionTrait,
};
use tracing::{Span, instrument};
use uuid::Uuid;

use crate::conversions::map_sea_err;
use crate::database_system_label;
use crate::entities::{edge, node};
use crate::types::{DatabaseError, GraphEdge, GraphNode};
use crate::uuid_hex;
16
/// Maximum number of rows per provenance `INSERT`.
///
/// A multi-row `insert_many` binds `rows × columns` parameters in a single
/// statement, and every backend caps that count: SQLite at
/// `SQLITE_MAX_VARIABLE_NUMBER` (default 32 766 since 3.32.0, but only 999 on
/// older builds) and Postgres at 65 535. The node/edge tables have roughly ten
/// columns, so 500 rows bind about 5 000 values. That is comfortably under the
/// modern SQLite limit and the Postgres limit, but NOT under the legacy 999
/// default. A deployment still on SQLite < 3.32.0 would need a batch of well
/// under 100 rows.
///
/// Without batching, a large graph (e.g. a full-length book) overflows the cap
/// and the upsert fails with "too many SQL variables".
const PROVENANCE_INSERT_BATCH: usize = 500;
25
26#[instrument(
27    name = "cognee.db.relational.graph_storage.upsert_nodes",
28    level = "info",
29    skip_all,
30    fields(cognee.db.system = tracing::field::Empty),
31    err,
32)]
33pub async fn upsert_nodes(
34    db: &DatabaseConnection,
35    nodes: &[GraphNode],
36) -> Result<(), DatabaseError> {
37    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
38    if nodes.is_empty() {
39        return Ok(());
40    }
41    // Chunk so a single statement never exceeds the DB's bound-variable cap.
42    for batch in nodes.chunks(PROVENANCE_INSERT_BATCH) {
43        let models: Vec<node::ActiveModel> = batch.iter().map(node::ActiveModel::from).collect();
44        node::Entity::insert_many(models)
45            .on_conflict(
46                OnConflict::column(node::Column::Id)
47                    .update_columns([
48                        node::Column::Slug,
49                        node::Column::UserId,
50                        node::Column::DataId,
51                        node::Column::DatasetId,
52                        node::Column::Label,
53                        node::Column::NodeType,
54                        node::Column::IndexedFields,
55                        node::Column::Attributes,
56                    ])
57                    .to_owned(),
58            )
59            .exec(db)
60            .await
61            .map_err(map_sea_err)?;
62    }
63    Ok(())
64}
65
66#[instrument(
67    name = "cognee.db.relational.graph_storage.get_nodes_by_dataset",
68    level = "info",
69    skip_all,
70    fields(
71        cognee.db.system = tracing::field::Empty,
72        cognee.db.row_count = tracing::field::Empty,
73    ),
74    err,
75)]
76pub async fn get_nodes_by_dataset(
77    db: &DatabaseConnection,
78    dataset_id: Uuid,
79) -> Result<Vec<GraphNode>, DatabaseError> {
80    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
81    let rows: Vec<GraphNode> = node::Entity::find()
82        .filter(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
83        .order_by_asc(node::Column::CreatedAt)
84        .all(db)
85        .await
86        .map_err(map_sea_err)?
87        .into_iter()
88        .map(GraphNode::from)
89        .collect();
90    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
91    Ok(rows)
92}
93
94#[instrument(
95    name = "cognee.db.relational.graph_storage.delete_nodes_by_data",
96    level = "info",
97    skip_all,
98    fields(cognee.db.system = tracing::field::Empty),
99    err,
100)]
101pub async fn delete_nodes_by_data(
102    db: &DatabaseConnection,
103    data_id: Uuid,
104) -> Result<(), DatabaseError> {
105    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
106    node::Entity::delete_many()
107        .filter(node::Column::DataId.eq(uuid_hex::to_hex(data_id)))
108        .exec(db)
109        .await
110        .map_err(map_sea_err)?;
111    Ok(())
112}
113
114#[instrument(
115    name = "cognee.db.relational.graph_storage.upsert_edges",
116    level = "info",
117    skip_all,
118    fields(cognee.db.system = tracing::field::Empty),
119    err,
120)]
121pub async fn upsert_edges(
122    db: &DatabaseConnection,
123    edges: &[GraphEdge],
124) -> Result<(), DatabaseError> {
125    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
126    if edges.is_empty() {
127        return Ok(());
128    }
129    // Chunk so a single statement never exceeds the DB's bound-variable cap.
130    for batch in edges.chunks(PROVENANCE_INSERT_BATCH) {
131        let models: Vec<edge::ActiveModel> = batch.iter().map(edge::ActiveModel::from).collect();
132        edge::Entity::insert_many(models)
133            .on_conflict(
134                OnConflict::column(edge::Column::Id)
135                    .update_columns([
136                        edge::Column::Slug,
137                        edge::Column::UserId,
138                        edge::Column::DataId,
139                        edge::Column::DatasetId,
140                        edge::Column::SourceNodeId,
141                        edge::Column::DestinationNodeId,
142                        edge::Column::RelationshipName,
143                        edge::Column::Label,
144                        edge::Column::Attributes,
145                    ])
146                    .to_owned(),
147            )
148            .exec(db)
149            .await
150            .map_err(map_sea_err)?;
151    }
152    Ok(())
153}
154
155#[instrument(
156    name = "cognee.db.relational.graph_storage.get_edges_by_dataset",
157    level = "info",
158    skip_all,
159    fields(
160        cognee.db.system = tracing::field::Empty,
161        cognee.db.row_count = tracing::field::Empty,
162    ),
163    err,
164)]
165pub async fn get_edges_by_dataset(
166    db: &DatabaseConnection,
167    dataset_id: Uuid,
168) -> Result<Vec<GraphEdge>, DatabaseError> {
169    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
170    let rows: Vec<GraphEdge> = edge::Entity::find()
171        .filter(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
172        .order_by_asc(edge::Column::CreatedAt)
173        .all(db)
174        .await
175        .map_err(map_sea_err)?
176        .into_iter()
177        .map(GraphEdge::from)
178        .collect();
179    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
180    Ok(rows)
181}
182
183/// Return edges for `dataset_id` created strictly after `since`, ordered by
184/// `created_at` ascending and limited to `limit` rows. Used by Stage 4 of
185/// `improve()` for incremental graph→session synchronisation.
186///
187/// When `since` is `None`, returns the oldest `limit` edges in the dataset.
188#[instrument(
189    name = "cognee.db.relational.graph_storage.get_edges_since",
190    level = "info",
191    skip_all,
192    fields(
193        cognee.db.system = tracing::field::Empty,
194        cognee.db.row_count = tracing::field::Empty,
195    ),
196    err,
197)]
198pub async fn get_edges_since(
199    db: &DatabaseConnection,
200    dataset_id: Uuid,
201    since: Option<DateTime<Utc>>,
202    limit: u64,
203) -> Result<Vec<GraphEdge>, DatabaseError> {
204    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
205    let mut q = edge::Entity::find()
206        .filter(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
207        .order_by_asc(edge::Column::CreatedAt)
208        .limit(limit);
209    if let Some(ts) = since {
210        q = q.filter(edge::Column::CreatedAt.gt(ts));
211    }
212    let rows: Vec<GraphEdge> = q
213        .all(db)
214        .await
215        .map_err(map_sea_err)?
216        .into_iter()
217        .map(GraphEdge::from)
218        .collect();
219    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
220    Ok(rows)
221}
222
223/// Batch-fetch nodes by their string IDs (hex form). Used by Stage 4 to
224/// resolve edge endpoints to full node metadata for JSON-line rendering.
225#[instrument(
226    name = "cognee.db.relational.graph_storage.get_nodes_by_ids",
227    level = "info",
228    skip_all,
229    fields(
230        cognee.db.system = tracing::field::Empty,
231        cognee.db.row_count = tracing::field::Empty,
232    ),
233    err,
234)]
235pub async fn get_nodes_by_ids(
236    db: &DatabaseConnection,
237    ids: &[String],
238) -> Result<Vec<GraphNode>, DatabaseError> {
239    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
240    if ids.is_empty() {
241        Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
242        return Ok(Vec::new());
243    }
244    let rows: Vec<GraphNode> = node::Entity::find()
245        .filter(node::Column::Id.is_in(ids.to_vec()))
246        .all(db)
247        .await
248        .map_err(map_sea_err)?
249        .into_iter()
250        .map(GraphNode::from)
251        .collect();
252    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
253    Ok(rows)
254}
255
256#[instrument(
257    name = "cognee.db.relational.graph_storage.delete_edges_by_data",
258    level = "info",
259    skip_all,
260    fields(cognee.db.system = tracing::field::Empty),
261    err,
262)]
263pub async fn delete_edges_by_data(
264    db: &DatabaseConnection,
265    data_id: Uuid,
266) -> Result<(), DatabaseError> {
267    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
268    edge::Entity::delete_many()
269        .filter(edge::Column::DataId.eq(uuid_hex::to_hex(data_id)))
270        .exec(db)
271        .await
272        .map_err(map_sea_err)?;
273    Ok(())
274}
275
276// ---------------------------------------------------------------------------
277// Queries scoped by (data_id, dataset_id)
278// ---------------------------------------------------------------------------
279
280/// Get all provenance nodes for a specific `(data_id, dataset_id)` pair.
281#[instrument(
282    name = "cognee.db.relational.graph_storage.get_nodes_by_data",
283    level = "info",
284    skip_all,
285    fields(
286        cognee.db.system = tracing::field::Empty,
287        cognee.db.row_count = tracing::field::Empty,
288    ),
289    err,
290)]
291pub async fn get_nodes_by_data(
292    db: &DatabaseConnection,
293    data_id: Uuid,
294    dataset_id: Uuid,
295) -> Result<Vec<GraphNode>, DatabaseError> {
296    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
297    let rows: Vec<GraphNode> = node::Entity::find()
298        .filter(
299            Condition::all()
300                .add(node::Column::DataId.eq(uuid_hex::to_hex(data_id)))
301                .add(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
302        )
303        .order_by_asc(node::Column::CreatedAt)
304        .all(db)
305        .await
306        .map_err(map_sea_err)?
307        .into_iter()
308        .map(GraphNode::from)
309        .collect();
310    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
311    Ok(rows)
312}
313
314/// Get all provenance edges for a specific `(data_id, dataset_id)` pair.
315#[instrument(
316    name = "cognee.db.relational.graph_storage.get_edges_by_data",
317    level = "info",
318    skip_all,
319    fields(
320        cognee.db.system = tracing::field::Empty,
321        cognee.db.row_count = tracing::field::Empty,
322    ),
323    err,
324)]
325pub async fn get_edges_by_data(
326    db: &DatabaseConnection,
327    data_id: Uuid,
328    dataset_id: Uuid,
329) -> Result<Vec<GraphEdge>, DatabaseError> {
330    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
331    let rows: Vec<GraphEdge> = edge::Entity::find()
332        .filter(
333            Condition::all()
334                .add(edge::Column::DataId.eq(uuid_hex::to_hex(data_id)))
335                .add(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
336        )
337        .order_by_asc(edge::Column::CreatedAt)
338        .all(db)
339        .await
340        .map_err(map_sea_err)?
341        .into_iter()
342        .map(GraphEdge::from)
343        .collect();
344    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
345    Ok(rows)
346}
347
348// ---------------------------------------------------------------------------
349// Dataset-scoped deletion of provenance rows
350// ---------------------------------------------------------------------------
351
352/// Delete all provenance node rows for a given dataset.
353#[instrument(
354    name = "cognee.db.relational.graph_storage.delete_nodes_by_dataset",
355    level = "info",
356    skip_all,
357    fields(cognee.db.system = tracing::field::Empty),
358    err,
359)]
360pub async fn delete_nodes_by_dataset(
361    db: &DatabaseConnection,
362    dataset_id: Uuid,
363) -> Result<(), DatabaseError> {
364    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
365    node::Entity::delete_many()
366        .filter(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
367        .exec(db)
368        .await
369        .map_err(map_sea_err)?;
370    Ok(())
371}
372
373/// Delete all provenance edge rows for a given dataset.
374#[instrument(
375    name = "cognee.db.relational.graph_storage.delete_edges_by_dataset",
376    level = "info",
377    skip_all,
378    fields(cognee.db.system = tracing::field::Empty),
379    err,
380)]
381pub async fn delete_edges_by_dataset(
382    db: &DatabaseConnection,
383    dataset_id: Uuid,
384) -> Result<(), DatabaseError> {
385    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
386    edge::Entity::delete_many()
387        .filter(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
388        .exec(db)
389        .await
390        .map_err(map_sea_err)?;
391    Ok(())
392}
393
394// ---------------------------------------------------------------------------
395// Data-scoped deletion of provenance rows
396// ---------------------------------------------------------------------------
397
398/// Delete provenance node rows for a specific `(data_id, dataset_id)` pair.
399#[instrument(
400    name = "cognee.db.relational.graph_storage.delete_nodes_for_data",
401    level = "info",
402    skip_all,
403    fields(cognee.db.system = tracing::field::Empty),
404    err,
405)]
406pub async fn delete_nodes_for_data(
407    db: &DatabaseConnection,
408    data_id: Uuid,
409    dataset_id: Uuid,
410) -> Result<(), DatabaseError> {
411    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
412    node::Entity::delete_many()
413        .filter(
414            Condition::all()
415                .add(node::Column::DataId.eq(uuid_hex::to_hex(data_id)))
416                .add(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
417        )
418        .exec(db)
419        .await
420        .map_err(map_sea_err)?;
421    Ok(())
422}
423
424/// Delete provenance edge rows for a specific `(data_id, dataset_id)` pair.
425#[instrument(
426    name = "cognee.db.relational.graph_storage.delete_edges_for_data",
427    level = "info",
428    skip_all,
429    fields(cognee.db.system = tracing::field::Empty),
430    err,
431)]
432pub async fn delete_edges_for_data(
433    db: &DatabaseConnection,
434    data_id: Uuid,
435    dataset_id: Uuid,
436) -> Result<(), DatabaseError> {
437    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
438    edge::Entity::delete_many()
439        .filter(
440            Condition::all()
441                .add(edge::Column::DataId.eq(uuid_hex::to_hex(data_id)))
442                .add(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
443        )
444        .exec(db)
445        .await
446        .map_err(map_sea_err)?;
447    Ok(())
448}
449
450// ---------------------------------------------------------------------------
451// Count queries scoped by (data_id, dataset_id)
452// ---------------------------------------------------------------------------
453
454/// Count provenance node rows for a specific `(data_id, dataset_id)` pair.
455#[instrument(
456    name = "cognee.db.relational.graph_storage.count_nodes_for_data",
457    level = "info",
458    skip_all,
459    fields(
460        cognee.db.system = tracing::field::Empty,
461        cognee.db.row_count = tracing::field::Empty,
462    ),
463    err,
464)]
465pub async fn count_nodes_for_data(
466    db: &DatabaseConnection,
467    data_id: Uuid,
468    dataset_id: Uuid,
469) -> Result<usize, DatabaseError> {
470    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
471    let count = node::Entity::find()
472        .filter(
473            Condition::all()
474                .add(node::Column::DataId.eq(uuid_hex::to_hex(data_id)))
475                .add(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
476        )
477        .count(db)
478        .await
479        .map_err(map_sea_err)?;
480    Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
481    Ok(count as usize)
482}
483
484/// Count provenance edge rows for a specific `(data_id, dataset_id)` pair.
485#[instrument(
486    name = "cognee.db.relational.graph_storage.count_edges_for_data",
487    level = "info",
488    skip_all,
489    fields(
490        cognee.db.system = tracing::field::Empty,
491        cognee.db.row_count = tracing::field::Empty,
492    ),
493    err,
494)]
495pub async fn count_edges_for_data(
496    db: &DatabaseConnection,
497    data_id: Uuid,
498    dataset_id: Uuid,
499) -> Result<usize, DatabaseError> {
500    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
501    let count = edge::Entity::find()
502        .filter(
503            Condition::all()
504                .add(edge::Column::DataId.eq(uuid_hex::to_hex(data_id)))
505                .add(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
506        )
507        .count(db)
508        .await
509        .map_err(map_sea_err)?;
510    Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
511    Ok(count as usize)
512}
513
514// ---------------------------------------------------------------------------
515// Unique (non-shared) node/edge queries for safe single-data deletion
516// ---------------------------------------------------------------------------
517
518/// Return nodes belonging to `(data_id, dataset_id)` whose slug does NOT
519/// appear in any other row within the same dataset with a different `data_id`.
520///
521/// This is the Rust equivalent of Python's shared-slug exclusion logic.
522#[instrument(
523    name = "cognee.db.relational.graph_storage.get_unique_nodes_for_data",
524    level = "info",
525    skip_all,
526    fields(
527        cognee.db.system = tracing::field::Empty,
528        cognee.db.row_count = tracing::field::Empty,
529    ),
530    err,
531)]
532pub async fn get_unique_nodes_for_data(
533    db: &DatabaseConnection,
534    data_id: Uuid,
535    dataset_id: Uuid,
536) -> Result<Vec<GraphNode>, DatabaseError> {
537    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
538    let data_hex = uuid_hex::to_hex(data_id);
539    let dataset_hex = uuid_hex::to_hex(dataset_id);
540
541    // One query instead of two: select the (data_id, dataset_id) nodes whose slug
542    // is NOT shared by any other data_id in the same dataset. The correlated
543    // NOT EXISTS replaces the previous fetch-all + fetch-shared-slugs + in-memory
544    // filter.
545    let n2 = Alias::new("n2");
546    let shared = Query::select()
547        .expr(Expr::val(1))
548        .from_as(Alias::new("nodes"), n2.clone())
549        .and_where(Expr::col((n2.clone(), node::Column::DatasetId)).eq(dataset_hex.clone()))
550        .and_where(Expr::col((n2.clone(), node::Column::DataId)).ne(data_hex.clone()))
551        .and_where(
552            Expr::col((n2.clone(), node::Column::Slug))
553                .equals((Alias::new("nodes"), node::Column::Slug)),
554        )
555        .to_owned();
556
557    let rows: Vec<GraphNode> = node::Entity::find()
558        .filter(node::Column::DataId.eq(&data_hex))
559        .filter(node::Column::DatasetId.eq(&dataset_hex))
560        .filter(Expr::exists(shared).not())
561        .all(db)
562        .await
563        .map_err(map_sea_err)?
564        .into_iter()
565        .map(GraphNode::from)
566        .collect();
567    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
568    Ok(rows)
569}
570
571/// Return edges belonging to `(data_id, dataset_id)` whose slug does NOT
572/// appear in any other row within the same dataset with a different `data_id`.
573#[instrument(
574    name = "cognee.db.relational.graph_storage.get_unique_edges_for_data",
575    level = "info",
576    skip_all,
577    fields(
578        cognee.db.system = tracing::field::Empty,
579        cognee.db.row_count = tracing::field::Empty,
580    ),
581    err,
582)]
583pub async fn get_unique_edges_for_data(
584    db: &DatabaseConnection,
585    data_id: Uuid,
586    dataset_id: Uuid,
587) -> Result<Vec<GraphEdge>, DatabaseError> {
588    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
589    let data_hex = uuid_hex::to_hex(data_id);
590    let dataset_hex = uuid_hex::to_hex(dataset_id);
591
592    // One query instead of two (see `get_unique_nodes_for_data`): edges for
593    // (data_id, dataset_id) whose slug is not shared by another data_id in the
594    // same dataset, via a correlated NOT EXISTS.
595    let e2 = Alias::new("e2");
596    let shared = Query::select()
597        .expr(Expr::val(1))
598        .from_as(Alias::new("edges"), e2.clone())
599        .and_where(Expr::col((e2.clone(), edge::Column::DatasetId)).eq(dataset_hex.clone()))
600        .and_where(Expr::col((e2.clone(), edge::Column::DataId)).ne(data_hex.clone()))
601        .and_where(
602            Expr::col((e2.clone(), edge::Column::Slug))
603                .equals((Alias::new("edges"), edge::Column::Slug)),
604        )
605        .to_owned();
606
607    let rows: Vec<GraphEdge> = edge::Entity::find()
608        .filter(edge::Column::DataId.eq(&data_hex))
609        .filter(edge::Column::DatasetId.eq(&dataset_hex))
610        .filter(Expr::exists(shared).not())
611        .all(db)
612        .await
613        .map_err(map_sea_err)?
614        .into_iter()
615        .map(GraphEdge::from)
616        .collect();
617    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
618    Ok(rows)
619}