Skip to main content

cognee_database/ops/
datasets.rs

1use cognee_models::{Data, Dataset};
2use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
3use sea_orm::sea_query::OnConflict;
4use sea_orm::{
5    ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter,
6    QueryOrder,
7};
8use tracing::{Span, instrument};
9use uuid::Uuid;
10
11use crate::conversions::{ignore_do_nothing, make_dataset_data_active, map_sea_err};
12use crate::database_system_label;
13use crate::entities::{data, dataset, dataset_data};
14use crate::types::DatabaseError;
15use crate::uuid_hex;
16
17#[instrument(
18    name = "cognee.db.relational.datasets.create_dataset",
19    level = "info",
20    skip_all,
21    fields(cognee.db.system = tracing::field::Empty),
22    err,
23)]
24pub async fn create_dataset(
25    db: &DatabaseConnection,
26    ds: Dataset,
27) -> Result<Dataset, DatabaseError> {
28    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
29    dataset::ActiveModel::from(&ds)
30        .insert(db)
31        .await
32        .map_err(map_sea_err)?;
33    Ok(ds)
34}
35
36#[instrument(
37    name = "cognee.db.relational.datasets.get_dataset",
38    level = "info",
39    skip_all,
40    fields(
41        cognee.db.system = tracing::field::Empty,
42        cognee.db.row_count = tracing::field::Empty,
43    ),
44    err,
45)]
46pub async fn get_dataset(
47    db: &DatabaseConnection,
48    id: Uuid,
49) -> Result<Option<Dataset>, DatabaseError> {
50    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
51    let result = dataset::Entity::find_by_id(uuid_hex::to_hex(id))
52        .one(db)
53        .await
54        .map_err(map_sea_err)
55        .map(|opt| opt.map(Dataset::from))?;
56    Span::current().record(
57        COGNEE_DB_ROW_COUNT,
58        if result.is_some() { 1i64 } else { 0i64 },
59    );
60    Ok(result)
61}
62
63#[instrument(
64    name = "cognee.db.relational.datasets.get_dataset_by_name",
65    level = "info",
66    skip_all,
67    fields(
68        cognee.db.system = tracing::field::Empty,
69        cognee.db.row_count = tracing::field::Empty,
70    ),
71    err,
72)]
73pub async fn get_dataset_by_name(
74    db: &DatabaseConnection,
75    name: &str,
76    owner_id: Uuid,
77    tenant_id: Option<Uuid>,
78) -> Result<Option<Dataset>, DatabaseError> {
79    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
80    let mut q = dataset::Entity::find().filter(
81        dataset::Column::Name
82            .eq(name)
83            .and(dataset::Column::OwnerId.eq(uuid_hex::to_hex(owner_id))),
84    );
85    if let Some(tid) = tenant_id {
86        q = q.filter(dataset::Column::TenantId.eq(uuid_hex::to_hex(tid)));
87    }
88    let result = q
89        .one(db)
90        .await
91        .map_err(map_sea_err)
92        .map(|opt| opt.map(Dataset::from))?;
93    Span::current().record(
94        COGNEE_DB_ROW_COUNT,
95        if result.is_some() { 1i64 } else { 0i64 },
96    );
97    Ok(result)
98}
99
100#[instrument(
101    name = "cognee.db.relational.datasets.list_datasets_by_owner",
102    level = "info",
103    skip_all,
104    fields(
105        cognee.db.system = tracing::field::Empty,
106        cognee.db.row_count = tracing::field::Empty,
107    ),
108    err,
109)]
110pub async fn list_datasets_by_owner(
111    db: &DatabaseConnection,
112    owner_id: Uuid,
113) -> Result<Vec<Dataset>, DatabaseError> {
114    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
115    let rows: Vec<Dataset> = dataset::Entity::find()
116        .filter(dataset::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
117        .order_by_asc(dataset::Column::CreatedAt)
118        .all(db)
119        .await
120        .map_err(map_sea_err)?
121        .into_iter()
122        .map(Dataset::from)
123        .collect();
124    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
125    Ok(rows)
126}
127
128#[instrument(
129    name = "cognee.db.relational.datasets.list_datasets",
130    level = "info",
131    skip_all,
132    fields(
133        cognee.db.system = tracing::field::Empty,
134        cognee.db.row_count = tracing::field::Empty,
135    ),
136    err,
137)]
138pub async fn list_datasets(db: &DatabaseConnection) -> Result<Vec<Dataset>, DatabaseError> {
139    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
140    let rows: Vec<Dataset> = dataset::Entity::find()
141        .order_by_asc(dataset::Column::CreatedAt)
142        .all(db)
143        .await
144        .map_err(map_sea_err)?
145        .into_iter()
146        .map(Dataset::from)
147        .collect();
148    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
149    Ok(rows)
150}
151
152#[instrument(
153    name = "cognee.db.relational.datasets.delete_dataset",
154    level = "info",
155    skip_all,
156    fields(cognee.db.system = tracing::field::Empty),
157    err,
158)]
159pub async fn delete_dataset(db: &DatabaseConnection, id: Uuid) -> Result<(), DatabaseError> {
160    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
161    dataset::Entity::delete_by_id(uuid_hex::to_hex(id))
162        .exec(db)
163        .await
164        .map_err(map_sea_err)?;
165    Ok(())
166}
167
168#[instrument(
169    name = "cognee.db.relational.datasets.attach_data_to_dataset",
170    level = "info",
171    skip_all,
172    fields(cognee.db.system = tracing::field::Empty),
173    err,
174)]
175pub async fn attach_data_to_dataset(
176    db: &DatabaseConnection,
177    dataset_id: Uuid,
178    data_id: Uuid,
179) -> Result<(), DatabaseError> {
180    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
181    let model = make_dataset_data_active(dataset_id, data_id);
182    let res = dataset_data::Entity::insert(model)
183        .on_conflict(
184            OnConflict::columns([
185                dataset_data::Column::DatasetId,
186                dataset_data::Column::DataId,
187            ])
188            .do_nothing()
189            .to_owned(),
190        )
191        .exec(db)
192        .await
193        .map_err(map_sea_err)
194        .map(|_| ());
195    ignore_do_nothing(res)
196}
197
198#[instrument(
199    name = "cognee.db.relational.datasets.detach_data_from_dataset",
200    level = "info",
201    skip_all,
202    fields(cognee.db.system = tracing::field::Empty),
203    err,
204)]
205pub async fn detach_data_from_dataset(
206    db: &DatabaseConnection,
207    dataset_id: Uuid,
208    data_id: Uuid,
209) -> Result<(), DatabaseError> {
210    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
211    dataset_data::Entity::delete_many()
212        .filter(
213            dataset_data::Column::DatasetId
214                .eq(uuid_hex::to_hex(dataset_id))
215                .and(dataset_data::Column::DataId.eq(uuid_hex::to_hex(data_id))),
216        )
217        .exec(db)
218        .await
219        .map_err(map_sea_err)?;
220    Ok(())
221}
222
223/// Count the number of data items linked to a dataset without loading them.
224///
225/// Uses `SELECT COUNT(*)` on the `dataset_data` junction table for efficiency.
226#[instrument(
227    name = "cognee.db.relational.datasets.count_dataset_data",
228    level = "info",
229    skip_all,
230    fields(
231        cognee.db.system = tracing::field::Empty,
232        cognee.db.row_count = tracing::field::Empty,
233    ),
234    err,
235)]
236pub async fn count_dataset_data(
237    db: &DatabaseConnection,
238    dataset_id: Uuid,
239) -> Result<usize, DatabaseError> {
240    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
241    let count: u64 = dataset_data::Entity::find()
242        .filter(dataset_data::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
243        .count(db)
244        .await
245        .map_err(map_sea_err)?;
246    Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
247    Ok(count as usize)
248}
249
250#[instrument(
251    name = "cognee.db.relational.datasets.get_dataset_data",
252    level = "info",
253    skip_all,
254    fields(
255        cognee.db.system = tracing::field::Empty,
256        cognee.db.row_count = tracing::field::Empty,
257    ),
258    err,
259)]
260pub async fn get_dataset_data(
261    db: &DatabaseConnection,
262    dataset_id: Uuid,
263) -> Result<Vec<Data>, DatabaseError> {
264    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
265    let pairs = dataset::Entity::find_by_id(uuid_hex::to_hex(dataset_id))
266        .find_with_related(data::Entity)
267        .all(db)
268        .await
269        .map_err(map_sea_err)?;
270    let rows: Vec<Data> = pairs
271        .into_iter()
272        .flat_map(|(_, data_list)| data_list)
273        .map(Data::from)
274        .collect();
275    Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
276    Ok(rows)
277}