1use async_trait::async_trait;
2use chrono::Utc;
3use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
4use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, Set, prelude::*};
5use tracing::{Span, instrument};
6use uuid::Uuid;
7
8use crate::conversions::map_sea_err;
9use crate::database_system_label;
10use crate::entities::dataset_configuration;
11use crate::traits::{DatasetConfigDb, DatasetConfiguration, DatasetConfigurationPatch};
12use crate::types::DatabaseError;
13use crate::uuid_hex;
14
15fn model_to_dataset_configuration(
16 m: dataset_configuration::Model,
17) -> Result<DatasetConfiguration, DatabaseError> {
18 Ok(DatasetConfiguration {
19 id: uuid_hex::from_hex(&m.id).map_err(|e| {
20 DatabaseError::QueryError(format!("Invalid dataset configuration id hex: {e}"))
21 })?,
22 dataset_id: uuid_hex::from_hex(&m.dataset_id)
23 .map_err(|e| DatabaseError::QueryError(format!("Invalid dataset_id hex: {e}")))?,
24 graph_schema: m.graph_schema,
25 custom_prompt: m.custom_prompt,
26 created_at: m.created_at,
27 updated_at: m.updated_at,
28 })
29}
30
31#[async_trait]
32impl DatasetConfigDb for DatabaseConnection {
33 #[instrument(
34 name = "cognee.db.relational.dataset_configurations.get_by_dataset_id",
35 level = "info",
36 skip_all,
37 fields(
38 cognee.db.system = tracing::field::Empty,
39 cognee.db.row_count = tracing::field::Empty,
40 ),
41 err,
42 )]
43 async fn get_by_dataset_id(
44 &self,
45 dataset_id: Uuid,
46 ) -> Result<Option<DatasetConfiguration>, DatabaseError> {
47 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
48 let model = dataset_configuration::Entity::find()
49 .filter(dataset_configuration::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
50 .one(self)
51 .await
52 .map_err(map_sea_err)?;
53
54 let result = model.map(model_to_dataset_configuration).transpose()?;
55 Span::current().record(
56 COGNEE_DB_ROW_COUNT,
57 if result.is_some() { 1i64 } else { 0i64 },
58 );
59 Ok(result)
60 }
61
62 #[instrument(
63 name = "cognee.db.relational.dataset_configurations.upsert",
64 level = "info",
65 skip_all,
66 fields(cognee.db.system = tracing::field::Empty),
67 err,
68 )]
69 async fn upsert(
70 &self,
71 dataset_id: Uuid,
72 patch: DatasetConfigurationPatch,
73 ) -> Result<DatasetConfiguration, DatabaseError> {
74 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
75 let now = Utc::now();
76 let existing = dataset_configuration::Entity::find()
77 .filter(dataset_configuration::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
78 .one(self)
79 .await
80 .map_err(map_sea_err)?;
81 let has_existing = existing.is_some();
82
83 let active = if let Some(model) = existing {
84 let mut active: dataset_configuration::ActiveModel = model.into();
85 if let Some(graph_schema) = patch.graph_schema {
86 active.graph_schema = Set(Some(graph_schema));
87 }
88 if let Some(custom_prompt) = patch.custom_prompt {
89 active.custom_prompt = Set(Some(custom_prompt));
90 }
91 active.updated_at = Set(Some(now));
92 active
93 } else {
94 dataset_configuration::ActiveModel {
95 id: Set(uuid_hex::to_hex(Uuid::new_v4())),
96 dataset_id: Set(uuid_hex::to_hex(dataset_id)),
97 graph_schema: Set(patch.graph_schema),
98 custom_prompt: Set(patch.custom_prompt),
99 created_at: Set(now),
100 updated_at: Set(None),
101 }
102 };
103
104 let inserted = if has_existing {
105 active.update(self).await.map_err(map_sea_err)?
106 } else {
107 active.insert(self).await.map_err(map_sea_err)?
108 };
109 let result = model_to_dataset_configuration(inserted)?;
110 Span::current().record(COGNEE_DB_ROW_COUNT, 1i64);
111 Ok(result)
112 }
113}
114
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::entities::{dataset, dataset_configuration};
    use crate::{connect, initialize};
    use sea_orm::{EntityTrait, Set};
    use tempfile::TempDir;

    /// Opens a migrated SQLite database stored in a file inside a fresh
    /// temporary directory.
    ///
    /// The `TempDir` guard is passed to `std::mem::forget`, which skips its
    /// drop-time cleanup. The database file therefore outlives this function
    /// and is never removed by this helper.
    async fn migrated_sqlite_db() -> DatabaseConnection {
        let scratch = TempDir::new().expect("temp dir");
        let file = scratch.path().join("dataset_configurations.db");
        std::fs::File::create(&file).expect("create sqlite db file");

        let url = format!("sqlite://{}?mode=rwc", file.display());
        let conn = connect(&url).await.expect("in-memory SQLite");
        initialize(&conn).await.expect("migrations");

        std::mem::forget(scratch);
        conn
    }

    /// Builds a [`DatasetConfigurationPatch`] from optional field values.
    fn patch(
        graph_schema: Option<serde_json::Value>,
        custom_prompt: Option<&str>,
    ) -> DatasetConfigurationPatch {
        DatasetConfigurationPatch {
            graph_schema,
            custom_prompt: custom_prompt.map(str::to_owned),
        }
    }

    /// Inserts a parent `dataset` row with id `dataset_id`. It gets a random
    /// owner and no tenant, and configuration rows can then reference it.
    async fn seed_dataset(db: &DatabaseConnection, dataset_id: Uuid) {
        let row = dataset::ActiveModel {
            id: Set(uuid_hex::to_hex(dataset_id)),
            name: Set("dataset".to_owned()),
            owner_id: Set(uuid_hex::to_hex(Uuid::new_v4())),
            tenant_id: Set(None),
            created_at: Set(Utc::now()),
            updated_at: Set(None),
        };
        dataset::Entity::insert(row)
            .exec(db)
            .await
            .expect("insert dataset");
    }

    #[tokio::test]
    async fn upsert_inserts_new_row() {
        let db = migrated_sqlite_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        let schema = serde_json::json!({"type": "object"});
        let saved = db
            .upsert(dataset_id, patch(Some(schema.clone()), Some("X")))
            .await
            .expect("upsert");

        assert_eq!(saved.dataset_id, dataset_id);
        assert_eq!(saved.graph_schema, Some(schema));
        assert_eq!(saved.custom_prompt.as_deref(), Some("X"));
        assert!(saved.updated_at.is_none());

        // Reading the row back must return what the upsert reported.
        let fetched = db
            .get_by_dataset_id(dataset_id)
            .await
            .expect("get")
            .expect("row");
        assert_eq!(fetched.graph_schema, saved.graph_schema);
        assert_eq!(fetched.custom_prompt, saved.custom_prompt);
    }

    #[tokio::test]
    async fn upsert_updates_existing_row() {
        let db = migrated_sqlite_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        let first = db
            .upsert(
                dataset_id,
                patch(Some(serde_json::json!({"type": "object"})), Some("X")),
            )
            .await
            .expect("first upsert");
        let second = db
            .upsert(
                dataset_id,
                patch(Some(serde_json::json!({"new": "shape"})), None),
            )
            .await
            .expect("second upsert");

        assert_eq!(
            second.graph_schema,
            Some(serde_json::json!({"new": "shape"}))
        );
        // The prompt was omitted from the second patch, so it keeps "X".
        assert_eq!(second.custom_prompt.as_deref(), Some("X"));
        assert!(second.updated_at.is_some());
        let updated_at = second.updated_at.expect("updated_at");
        assert!(updated_at > first.created_at);
    }

    #[tokio::test]
    async fn upsert_preserves_existing_field_when_patch_omits_it() {
        let db = migrated_sqlite_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        db.upsert(
            dataset_id,
            patch(Some(serde_json::json!({"type": "object"})), Some("X")),
        )
        .await
        .expect("seed upsert");

        let updated = db
            .upsert(dataset_id, patch(None, Some("Y")))
            .await
            .expect("second upsert");

        assert_eq!(
            updated.graph_schema,
            Some(serde_json::json!({"type": "object"}))
        );
        assert_eq!(updated.custom_prompt.as_deref(), Some("Y"));
    }

    #[tokio::test]
    async fn unique_constraint_enforced() {
        let db = migrated_sqlite_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        let now = Utc::now();
        // Both rows point at the same dataset. Only their ids and payloads differ.
        let row = |id: Uuid, graph_schema: serde_json::Value, prompt: &str| {
            dataset_configuration::ActiveModel {
                id: Set(uuid_hex::to_hex(id)),
                dataset_id: Set(uuid_hex::to_hex(dataset_id)),
                graph_schema: Set(Some(graph_schema)),
                custom_prompt: Set(Some(prompt.to_owned())),
                created_at: Set(now),
                updated_at: Set(None),
            }
        };

        let first_row = row(Uuid::new_v4(), serde_json::json!({"first": true}), "X");
        let second_row = row(Uuid::new_v4(), serde_json::json!({"second": true}), "Y");

        dataset_configuration::Entity::insert(first_row)
            .exec(&db)
            .await
            .expect("first insert");

        let duplicate = dataset_configuration::Entity::insert(second_row)
            .exec(&db)
            .await;

        let error = duplicate.expect_err("expected unique constraint error");
        assert!(matches!(
            map_sea_err(error),
            DatabaseError::UniqueViolation(_)
        ));
    }

    #[tokio::test]
    async fn cascade_delete_on_dataset_removal() {
        let db = migrated_sqlite_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        db.upsert(
            dataset_id,
            patch(Some(serde_json::json!({"type": "object"})), Some("X")),
        )
        .await
        .expect("upsert");

        dataset::Entity::delete_by_id(uuid_hex::to_hex(dataset_id))
            .exec(&db)
            .await
            .expect("delete dataset");

        // Removing the parent dataset must take its configuration with it.
        let remaining = db.get_by_dataset_id(dataset_id).await.expect("get");
        assert!(remaining.is_none());
    }

    #[tokio::test]
    async fn get_returns_none_when_missing() {
        let db = migrated_sqlite_db().await;
        let missing = db.get_by_dataset_id(Uuid::new_v4()).await.expect("get");
        assert!(missing.is_none());
    }
}