1use chrono::{DateTime, Utc};
2use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
3use sea_orm::sea_query::OnConflict;
4use sea_orm::{
5 ColumnTrait, Condition, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter,
6 QueryOrder, QuerySelect,
7};
8use tracing::{Span, instrument};
9use uuid::Uuid;
10
11use crate::conversions::map_sea_err;
12use crate::database_system_label;
13use crate::entities::{edge, node};
14use crate::types::{DatabaseError, GraphEdge, GraphNode};
15use crate::uuid_hex;
16
/// Maximum number of rows sent in one `INSERT ... ON CONFLICT` statement by
/// `upsert_nodes` and `upsert_edges`; larger inputs are split into slices of
/// this size.
///
/// NOTE(review): each inserted row presumably binds one parameter per set
/// column, so this value times the column count must stay below the backend's
/// bind-parameter limit — confirm when adding columns.
const PROVENANCE_INSERT_BATCH: usize = 500;
25
26#[instrument(
27 name = "cognee.db.relational.graph_storage.upsert_nodes",
28 level = "info",
29 skip_all,
30 fields(cognee.db.system = tracing::field::Empty),
31 err,
32)]
33pub async fn upsert_nodes(
34 db: &DatabaseConnection,
35 nodes: &[GraphNode],
36) -> Result<(), DatabaseError> {
37 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
38 if nodes.is_empty() {
39 return Ok(());
40 }
41 for batch in nodes.chunks(PROVENANCE_INSERT_BATCH) {
43 let models: Vec<node::ActiveModel> = batch.iter().map(node::ActiveModel::from).collect();
44 node::Entity::insert_many(models)
45 .on_conflict(
46 OnConflict::column(node::Column::Id)
47 .update_columns([
48 node::Column::Slug,
49 node::Column::UserId,
50 node::Column::DataId,
51 node::Column::DatasetId,
52 node::Column::Label,
53 node::Column::NodeType,
54 node::Column::IndexedFields,
55 node::Column::Attributes,
56 ])
57 .to_owned(),
58 )
59 .exec(db)
60 .await
61 .map_err(map_sea_err)?;
62 }
63 Ok(())
64}
65
66#[instrument(
67 name = "cognee.db.relational.graph_storage.get_nodes_by_dataset",
68 level = "info",
69 skip_all,
70 fields(
71 cognee.db.system = tracing::field::Empty,
72 cognee.db.row_count = tracing::field::Empty,
73 ),
74 err,
75)]
76pub async fn get_nodes_by_dataset(
77 db: &DatabaseConnection,
78 dataset_id: Uuid,
79) -> Result<Vec<GraphNode>, DatabaseError> {
80 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
81 let rows: Vec<GraphNode> = node::Entity::find()
82 .filter(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
83 .order_by_asc(node::Column::CreatedAt)
84 .all(db)
85 .await
86 .map_err(map_sea_err)?
87 .into_iter()
88 .map(GraphNode::from)
89 .collect();
90 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
91 Ok(rows)
92}
93
94#[instrument(
95 name = "cognee.db.relational.graph_storage.delete_nodes_by_data",
96 level = "info",
97 skip_all,
98 fields(cognee.db.system = tracing::field::Empty),
99 err,
100)]
101pub async fn delete_nodes_by_data(
102 db: &DatabaseConnection,
103 data_id: Uuid,
104) -> Result<(), DatabaseError> {
105 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
106 node::Entity::delete_many()
107 .filter(node::Column::DataId.eq(uuid_hex::to_hex(data_id)))
108 .exec(db)
109 .await
110 .map_err(map_sea_err)?;
111 Ok(())
112}
113
114#[instrument(
115 name = "cognee.db.relational.graph_storage.upsert_edges",
116 level = "info",
117 skip_all,
118 fields(cognee.db.system = tracing::field::Empty),
119 err,
120)]
121pub async fn upsert_edges(
122 db: &DatabaseConnection,
123 edges: &[GraphEdge],
124) -> Result<(), DatabaseError> {
125 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
126 if edges.is_empty() {
127 return Ok(());
128 }
129 for batch in edges.chunks(PROVENANCE_INSERT_BATCH) {
131 let models: Vec<edge::ActiveModel> = batch.iter().map(edge::ActiveModel::from).collect();
132 edge::Entity::insert_many(models)
133 .on_conflict(
134 OnConflict::column(edge::Column::Id)
135 .update_columns([
136 edge::Column::Slug,
137 edge::Column::UserId,
138 edge::Column::DataId,
139 edge::Column::DatasetId,
140 edge::Column::SourceNodeId,
141 edge::Column::DestinationNodeId,
142 edge::Column::RelationshipName,
143 edge::Column::Label,
144 edge::Column::Attributes,
145 ])
146 .to_owned(),
147 )
148 .exec(db)
149 .await
150 .map_err(map_sea_err)?;
151 }
152 Ok(())
153}
154
155#[instrument(
156 name = "cognee.db.relational.graph_storage.get_edges_by_dataset",
157 level = "info",
158 skip_all,
159 fields(
160 cognee.db.system = tracing::field::Empty,
161 cognee.db.row_count = tracing::field::Empty,
162 ),
163 err,
164)]
165pub async fn get_edges_by_dataset(
166 db: &DatabaseConnection,
167 dataset_id: Uuid,
168) -> Result<Vec<GraphEdge>, DatabaseError> {
169 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
170 let rows: Vec<GraphEdge> = edge::Entity::find()
171 .filter(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
172 .order_by_asc(edge::Column::CreatedAt)
173 .all(db)
174 .await
175 .map_err(map_sea_err)?
176 .into_iter()
177 .map(GraphEdge::from)
178 .collect();
179 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
180 Ok(rows)
181}
182
183#[instrument(
189 name = "cognee.db.relational.graph_storage.get_edges_since",
190 level = "info",
191 skip_all,
192 fields(
193 cognee.db.system = tracing::field::Empty,
194 cognee.db.row_count = tracing::field::Empty,
195 ),
196 err,
197)]
198pub async fn get_edges_since(
199 db: &DatabaseConnection,
200 dataset_id: Uuid,
201 since: Option<DateTime<Utc>>,
202 limit: u64,
203) -> Result<Vec<GraphEdge>, DatabaseError> {
204 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
205 let mut q = edge::Entity::find()
206 .filter(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
207 .order_by_asc(edge::Column::CreatedAt)
208 .limit(limit);
209 if let Some(ts) = since {
210 q = q.filter(edge::Column::CreatedAt.gt(ts));
211 }
212 let rows: Vec<GraphEdge> = q
213 .all(db)
214 .await
215 .map_err(map_sea_err)?
216 .into_iter()
217 .map(GraphEdge::from)
218 .collect();
219 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
220 Ok(rows)
221}
222
223#[instrument(
226 name = "cognee.db.relational.graph_storage.get_nodes_by_ids",
227 level = "info",
228 skip_all,
229 fields(
230 cognee.db.system = tracing::field::Empty,
231 cognee.db.row_count = tracing::field::Empty,
232 ),
233 err,
234)]
235pub async fn get_nodes_by_ids(
236 db: &DatabaseConnection,
237 ids: &[String],
238) -> Result<Vec<GraphNode>, DatabaseError> {
239 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
240 if ids.is_empty() {
241 Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
242 return Ok(Vec::new());
243 }
244 let rows: Vec<GraphNode> = node::Entity::find()
245 .filter(node::Column::Id.is_in(ids.to_vec()))
246 .all(db)
247 .await
248 .map_err(map_sea_err)?
249 .into_iter()
250 .map(GraphNode::from)
251 .collect();
252 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
253 Ok(rows)
254}
255
256#[instrument(
257 name = "cognee.db.relational.graph_storage.delete_edges_by_data",
258 level = "info",
259 skip_all,
260 fields(cognee.db.system = tracing::field::Empty),
261 err,
262)]
263pub async fn delete_edges_by_data(
264 db: &DatabaseConnection,
265 data_id: Uuid,
266) -> Result<(), DatabaseError> {
267 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
268 edge::Entity::delete_many()
269 .filter(edge::Column::DataId.eq(uuid_hex::to_hex(data_id)))
270 .exec(db)
271 .await
272 .map_err(map_sea_err)?;
273 Ok(())
274}
275
276#[instrument(
282 name = "cognee.db.relational.graph_storage.get_nodes_by_data",
283 level = "info",
284 skip_all,
285 fields(
286 cognee.db.system = tracing::field::Empty,
287 cognee.db.row_count = tracing::field::Empty,
288 ),
289 err,
290)]
291pub async fn get_nodes_by_data(
292 db: &DatabaseConnection,
293 data_id: Uuid,
294 dataset_id: Uuid,
295) -> Result<Vec<GraphNode>, DatabaseError> {
296 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
297 let rows: Vec<GraphNode> = node::Entity::find()
298 .filter(
299 Condition::all()
300 .add(node::Column::DataId.eq(uuid_hex::to_hex(data_id)))
301 .add(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
302 )
303 .order_by_asc(node::Column::CreatedAt)
304 .all(db)
305 .await
306 .map_err(map_sea_err)?
307 .into_iter()
308 .map(GraphNode::from)
309 .collect();
310 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
311 Ok(rows)
312}
313
314#[instrument(
316 name = "cognee.db.relational.graph_storage.get_edges_by_data",
317 level = "info",
318 skip_all,
319 fields(
320 cognee.db.system = tracing::field::Empty,
321 cognee.db.row_count = tracing::field::Empty,
322 ),
323 err,
324)]
325pub async fn get_edges_by_data(
326 db: &DatabaseConnection,
327 data_id: Uuid,
328 dataset_id: Uuid,
329) -> Result<Vec<GraphEdge>, DatabaseError> {
330 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
331 let rows: Vec<GraphEdge> = edge::Entity::find()
332 .filter(
333 Condition::all()
334 .add(edge::Column::DataId.eq(uuid_hex::to_hex(data_id)))
335 .add(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
336 )
337 .order_by_asc(edge::Column::CreatedAt)
338 .all(db)
339 .await
340 .map_err(map_sea_err)?
341 .into_iter()
342 .map(GraphEdge::from)
343 .collect();
344 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
345 Ok(rows)
346}
347
348#[instrument(
354 name = "cognee.db.relational.graph_storage.delete_nodes_by_dataset",
355 level = "info",
356 skip_all,
357 fields(cognee.db.system = tracing::field::Empty),
358 err,
359)]
360pub async fn delete_nodes_by_dataset(
361 db: &DatabaseConnection,
362 dataset_id: Uuid,
363) -> Result<(), DatabaseError> {
364 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
365 node::Entity::delete_many()
366 .filter(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
367 .exec(db)
368 .await
369 .map_err(map_sea_err)?;
370 Ok(())
371}
372
373#[instrument(
375 name = "cognee.db.relational.graph_storage.delete_edges_by_dataset",
376 level = "info",
377 skip_all,
378 fields(cognee.db.system = tracing::field::Empty),
379 err,
380)]
381pub async fn delete_edges_by_dataset(
382 db: &DatabaseConnection,
383 dataset_id: Uuid,
384) -> Result<(), DatabaseError> {
385 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
386 edge::Entity::delete_many()
387 .filter(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
388 .exec(db)
389 .await
390 .map_err(map_sea_err)?;
391 Ok(())
392}
393
394#[instrument(
400 name = "cognee.db.relational.graph_storage.delete_nodes_for_data",
401 level = "info",
402 skip_all,
403 fields(cognee.db.system = tracing::field::Empty),
404 err,
405)]
406pub async fn delete_nodes_for_data(
407 db: &DatabaseConnection,
408 data_id: Uuid,
409 dataset_id: Uuid,
410) -> Result<(), DatabaseError> {
411 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
412 node::Entity::delete_many()
413 .filter(
414 Condition::all()
415 .add(node::Column::DataId.eq(uuid_hex::to_hex(data_id)))
416 .add(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
417 )
418 .exec(db)
419 .await
420 .map_err(map_sea_err)?;
421 Ok(())
422}
423
424#[instrument(
426 name = "cognee.db.relational.graph_storage.delete_edges_for_data",
427 level = "info",
428 skip_all,
429 fields(cognee.db.system = tracing::field::Empty),
430 err,
431)]
432pub async fn delete_edges_for_data(
433 db: &DatabaseConnection,
434 data_id: Uuid,
435 dataset_id: Uuid,
436) -> Result<(), DatabaseError> {
437 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
438 edge::Entity::delete_many()
439 .filter(
440 Condition::all()
441 .add(edge::Column::DataId.eq(uuid_hex::to_hex(data_id)))
442 .add(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
443 )
444 .exec(db)
445 .await
446 .map_err(map_sea_err)?;
447 Ok(())
448}
449
450#[instrument(
456 name = "cognee.db.relational.graph_storage.count_nodes_for_data",
457 level = "info",
458 skip_all,
459 fields(
460 cognee.db.system = tracing::field::Empty,
461 cognee.db.row_count = tracing::field::Empty,
462 ),
463 err,
464)]
465pub async fn count_nodes_for_data(
466 db: &DatabaseConnection,
467 data_id: Uuid,
468 dataset_id: Uuid,
469) -> Result<usize, DatabaseError> {
470 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
471 let count = node::Entity::find()
472 .filter(
473 Condition::all()
474 .add(node::Column::DataId.eq(uuid_hex::to_hex(data_id)))
475 .add(node::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
476 )
477 .count(db)
478 .await
479 .map_err(map_sea_err)?;
480 Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
481 Ok(count as usize)
482}
483
484#[instrument(
486 name = "cognee.db.relational.graph_storage.count_edges_for_data",
487 level = "info",
488 skip_all,
489 fields(
490 cognee.db.system = tracing::field::Empty,
491 cognee.db.row_count = tracing::field::Empty,
492 ),
493 err,
494)]
495pub async fn count_edges_for_data(
496 db: &DatabaseConnection,
497 data_id: Uuid,
498 dataset_id: Uuid,
499) -> Result<usize, DatabaseError> {
500 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
501 let count = edge::Entity::find()
502 .filter(
503 Condition::all()
504 .add(edge::Column::DataId.eq(uuid_hex::to_hex(data_id)))
505 .add(edge::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id))),
506 )
507 .count(db)
508 .await
509 .map_err(map_sea_err)?;
510 Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
511 Ok(count as usize)
512}
513
514#[instrument(
523 name = "cognee.db.relational.graph_storage.get_unique_nodes_for_data",
524 level = "info",
525 skip_all,
526 fields(
527 cognee.db.system = tracing::field::Empty,
528 cognee.db.row_count = tracing::field::Empty,
529 ),
530 err,
531)]
532pub async fn get_unique_nodes_for_data(
533 db: &DatabaseConnection,
534 data_id: Uuid,
535 dataset_id: Uuid,
536) -> Result<Vec<GraphNode>, DatabaseError> {
537 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
538 let data_hex = uuid_hex::to_hex(data_id);
539 let dataset_hex = uuid_hex::to_hex(dataset_id);
540
541 let all_nodes = node::Entity::find()
543 .filter(
544 Condition::all()
545 .add(node::Column::DataId.eq(&data_hex))
546 .add(node::Column::DatasetId.eq(&dataset_hex)),
547 )
548 .all(db)
549 .await
550 .map_err(map_sea_err)?;
551
552 if all_nodes.is_empty() {
553 Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
554 return Ok(vec![]);
555 }
556
557 let shared_slugs: Vec<String> = node::Entity::find()
559 .filter(
560 Condition::all()
561 .add(node::Column::DatasetId.eq(&dataset_hex))
562 .add(node::Column::DataId.ne(&data_hex)),
563 )
564 .column(node::Column::Slug)
565 .all(db)
566 .await
567 .map_err(map_sea_err)?
568 .into_iter()
569 .map(|m| m.slug)
570 .collect();
571
572 let shared_set: std::collections::HashSet<&str> =
573 shared_slugs.iter().map(|s| s.as_str()).collect();
574
575 let rows: Vec<GraphNode> = all_nodes
576 .into_iter()
577 .filter(|n| !shared_set.contains(n.slug.as_str()))
578 .map(GraphNode::from)
579 .collect();
580 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
581 Ok(rows)
582}
583
584#[instrument(
587 name = "cognee.db.relational.graph_storage.get_unique_edges_for_data",
588 level = "info",
589 skip_all,
590 fields(
591 cognee.db.system = tracing::field::Empty,
592 cognee.db.row_count = tracing::field::Empty,
593 ),
594 err,
595)]
596pub async fn get_unique_edges_for_data(
597 db: &DatabaseConnection,
598 data_id: Uuid,
599 dataset_id: Uuid,
600) -> Result<Vec<GraphEdge>, DatabaseError> {
601 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
602 let data_hex = uuid_hex::to_hex(data_id);
603 let dataset_hex = uuid_hex::to_hex(dataset_id);
604
605 let all_edges = edge::Entity::find()
607 .filter(
608 Condition::all()
609 .add(edge::Column::DataId.eq(&data_hex))
610 .add(edge::Column::DatasetId.eq(&dataset_hex)),
611 )
612 .all(db)
613 .await
614 .map_err(map_sea_err)?;
615
616 if all_edges.is_empty() {
617 Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
618 return Ok(vec![]);
619 }
620
621 let shared_slugs: Vec<String> = edge::Entity::find()
623 .filter(
624 Condition::all()
625 .add(edge::Column::DatasetId.eq(&dataset_hex))
626 .add(edge::Column::DataId.ne(&data_hex)),
627 )
628 .column(edge::Column::Slug)
629 .all(db)
630 .await
631 .map_err(map_sea_err)?
632 .into_iter()
633 .map(|m| m.slug)
634 .collect();
635
636 let shared_set: std::collections::HashSet<&str> =
637 shared_slugs.iter().map(|s| s.as_str()).collect();
638
639 let rows: Vec<GraphEdge> = all_edges
640 .into_iter()
641 .filter(|e| !shared_set.contains(e.slug.as_str()))
642 .map(GraphEdge::from)
643 .collect();
644 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
645 Ok(rows)
646}