1use cognee_models::{Data, Dataset};
2use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
3use sea_orm::sea_query::OnConflict;
4use sea_orm::{
5 ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter,
6 QueryOrder,
7};
8use tracing::{Span, instrument};
9use uuid::Uuid;
10
11use crate::conversions::{ignore_do_nothing, make_dataset_data_active, map_sea_err};
12use crate::database_system_label;
13use crate::entities::{data, dataset, dataset_data};
14use crate::types::DatabaseError;
15use crate::uuid_hex;
16
17#[instrument(
18 name = "cognee.db.relational.datasets.create_dataset",
19 level = "info",
20 skip_all,
21 fields(cognee.db.system = tracing::field::Empty),
22 err,
23)]
24pub async fn create_dataset(
25 db: &DatabaseConnection,
26 ds: Dataset,
27) -> Result<Dataset, DatabaseError> {
28 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
29 dataset::ActiveModel::from(&ds)
30 .insert(db)
31 .await
32 .map_err(map_sea_err)?;
33 Ok(ds)
34}
35
36#[instrument(
37 name = "cognee.db.relational.datasets.get_dataset",
38 level = "info",
39 skip_all,
40 fields(
41 cognee.db.system = tracing::field::Empty,
42 cognee.db.row_count = tracing::field::Empty,
43 ),
44 err,
45)]
46pub async fn get_dataset(
47 db: &DatabaseConnection,
48 id: Uuid,
49) -> Result<Option<Dataset>, DatabaseError> {
50 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
51 let result = dataset::Entity::find_by_id(uuid_hex::to_hex(id))
52 .one(db)
53 .await
54 .map_err(map_sea_err)
55 .map(|opt| opt.map(Dataset::from))?;
56 Span::current().record(
57 COGNEE_DB_ROW_COUNT,
58 if result.is_some() { 1i64 } else { 0i64 },
59 );
60 Ok(result)
61}
62
63#[instrument(
64 name = "cognee.db.relational.datasets.get_dataset_by_name",
65 level = "info",
66 skip_all,
67 fields(
68 cognee.db.system = tracing::field::Empty,
69 cognee.db.row_count = tracing::field::Empty,
70 ),
71 err,
72)]
73pub async fn get_dataset_by_name(
74 db: &DatabaseConnection,
75 name: &str,
76 owner_id: Uuid,
77 tenant_id: Option<Uuid>,
78) -> Result<Option<Dataset>, DatabaseError> {
79 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
80 let mut q = dataset::Entity::find().filter(
81 dataset::Column::Name
82 .eq(name)
83 .and(dataset::Column::OwnerId.eq(uuid_hex::to_hex(owner_id))),
84 );
85 if let Some(tid) = tenant_id {
86 q = q.filter(dataset::Column::TenantId.eq(uuid_hex::to_hex(tid)));
87 }
88 let result = q
89 .one(db)
90 .await
91 .map_err(map_sea_err)
92 .map(|opt| opt.map(Dataset::from))?;
93 Span::current().record(
94 COGNEE_DB_ROW_COUNT,
95 if result.is_some() { 1i64 } else { 0i64 },
96 );
97 Ok(result)
98}
99
100#[instrument(
101 name = "cognee.db.relational.datasets.list_datasets_by_owner",
102 level = "info",
103 skip_all,
104 fields(
105 cognee.db.system = tracing::field::Empty,
106 cognee.db.row_count = tracing::field::Empty,
107 ),
108 err,
109)]
110pub async fn list_datasets_by_owner(
111 db: &DatabaseConnection,
112 owner_id: Uuid,
113) -> Result<Vec<Dataset>, DatabaseError> {
114 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
115 let rows: Vec<Dataset> = dataset::Entity::find()
116 .filter(dataset::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
117 .order_by_asc(dataset::Column::CreatedAt)
118 .all(db)
119 .await
120 .map_err(map_sea_err)?
121 .into_iter()
122 .map(Dataset::from)
123 .collect();
124 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
125 Ok(rows)
126}
127
128#[instrument(
129 name = "cognee.db.relational.datasets.list_datasets",
130 level = "info",
131 skip_all,
132 fields(
133 cognee.db.system = tracing::field::Empty,
134 cognee.db.row_count = tracing::field::Empty,
135 ),
136 err,
137)]
138pub async fn list_datasets(db: &DatabaseConnection) -> Result<Vec<Dataset>, DatabaseError> {
139 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
140 let rows: Vec<Dataset> = dataset::Entity::find()
141 .order_by_asc(dataset::Column::CreatedAt)
142 .all(db)
143 .await
144 .map_err(map_sea_err)?
145 .into_iter()
146 .map(Dataset::from)
147 .collect();
148 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
149 Ok(rows)
150}
151
152#[instrument(
153 name = "cognee.db.relational.datasets.delete_dataset",
154 level = "info",
155 skip_all,
156 fields(cognee.db.system = tracing::field::Empty),
157 err,
158)]
159pub async fn delete_dataset(db: &DatabaseConnection, id: Uuid) -> Result<(), DatabaseError> {
160 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
161 dataset::Entity::delete_by_id(uuid_hex::to_hex(id))
162 .exec(db)
163 .await
164 .map_err(map_sea_err)?;
165 Ok(())
166}
167
168#[instrument(
169 name = "cognee.db.relational.datasets.attach_data_to_dataset",
170 level = "info",
171 skip_all,
172 fields(cognee.db.system = tracing::field::Empty),
173 err,
174)]
175pub async fn attach_data_to_dataset(
176 db: &DatabaseConnection,
177 dataset_id: Uuid,
178 data_id: Uuid,
179) -> Result<(), DatabaseError> {
180 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
181 let model = make_dataset_data_active(dataset_id, data_id);
182 let res = dataset_data::Entity::insert(model)
183 .on_conflict(
184 OnConflict::columns([
185 dataset_data::Column::DatasetId,
186 dataset_data::Column::DataId,
187 ])
188 .do_nothing()
189 .to_owned(),
190 )
191 .exec(db)
192 .await
193 .map_err(map_sea_err)
194 .map(|_| ());
195 ignore_do_nothing(res)
196}
197
198#[instrument(
199 name = "cognee.db.relational.datasets.detach_data_from_dataset",
200 level = "info",
201 skip_all,
202 fields(cognee.db.system = tracing::field::Empty),
203 err,
204)]
205pub async fn detach_data_from_dataset(
206 db: &DatabaseConnection,
207 dataset_id: Uuid,
208 data_id: Uuid,
209) -> Result<(), DatabaseError> {
210 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
211 dataset_data::Entity::delete_many()
212 .filter(
213 dataset_data::Column::DatasetId
214 .eq(uuid_hex::to_hex(dataset_id))
215 .and(dataset_data::Column::DataId.eq(uuid_hex::to_hex(data_id))),
216 )
217 .exec(db)
218 .await
219 .map_err(map_sea_err)?;
220 Ok(())
221}
222
223#[instrument(
227 name = "cognee.db.relational.datasets.count_dataset_data",
228 level = "info",
229 skip_all,
230 fields(
231 cognee.db.system = tracing::field::Empty,
232 cognee.db.row_count = tracing::field::Empty,
233 ),
234 err,
235)]
236pub async fn count_dataset_data(
237 db: &DatabaseConnection,
238 dataset_id: Uuid,
239) -> Result<usize, DatabaseError> {
240 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
241 let count: u64 = dataset_data::Entity::find()
242 .filter(dataset_data::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
243 .count(db)
244 .await
245 .map_err(map_sea_err)?;
246 Span::current().record(COGNEE_DB_ROW_COUNT, count as i64);
247 Ok(count as usize)
248}
249
250#[instrument(
251 name = "cognee.db.relational.datasets.get_dataset_data",
252 level = "info",
253 skip_all,
254 fields(
255 cognee.db.system = tracing::field::Empty,
256 cognee.db.row_count = tracing::field::Empty,
257 ),
258 err,
259)]
260pub async fn get_dataset_data(
261 db: &DatabaseConnection,
262 dataset_id: Uuid,
263) -> Result<Vec<Data>, DatabaseError> {
264 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
265 let pairs = dataset::Entity::find_by_id(uuid_hex::to_hex(dataset_id))
266 .find_with_related(data::Entity)
267 .all(db)
268 .await
269 .map_err(map_sea_err)?;
270 let rows: Vec<Data> = pairs
271 .into_iter()
272 .flat_map(|(_, data_list)| data_list)
273 .map(Data::from)
274 .collect();
275 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
276 Ok(rows)
277}