Skip to main content

cognee_database/ops/
dataset_configurations.rs

1use async_trait::async_trait;
2use chrono::Utc;
3use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
4use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, Set, prelude::*};
5use tracing::{Span, instrument};
6use uuid::Uuid;
7
8use crate::conversions::map_sea_err;
9use crate::database_system_label;
10use crate::entities::dataset_configuration;
11use crate::traits::{DatasetConfigDb, DatasetConfiguration, DatasetConfigurationPatch};
12use crate::types::DatabaseError;
13use crate::uuid_hex;
14
15fn model_to_dataset_configuration(
16    m: dataset_configuration::Model,
17) -> Result<DatasetConfiguration, DatabaseError> {
18    Ok(DatasetConfiguration {
19        id: uuid_hex::from_hex(&m.id).map_err(|e| {
20            DatabaseError::QueryError(format!("Invalid dataset configuration id hex: {e}"))
21        })?,
22        dataset_id: uuid_hex::from_hex(&m.dataset_id)
23            .map_err(|e| DatabaseError::QueryError(format!("Invalid dataset_id hex: {e}")))?,
24        graph_schema: m.graph_schema,
25        custom_prompt: m.custom_prompt,
26        created_at: m.created_at,
27        updated_at: m.updated_at,
28    })
29}
30
31#[async_trait]
32impl DatasetConfigDb for DatabaseConnection {
33    #[instrument(
34        name = "cognee.db.relational.dataset_configurations.get_by_dataset_id",
35        level = "info",
36        skip_all,
37        fields(
38            cognee.db.system = tracing::field::Empty,
39            cognee.db.row_count = tracing::field::Empty,
40        ),
41        err,
42    )]
43    async fn get_by_dataset_id(
44        &self,
45        dataset_id: Uuid,
46    ) -> Result<Option<DatasetConfiguration>, DatabaseError> {
47        Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
48        let model = dataset_configuration::Entity::find()
49            .filter(dataset_configuration::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
50            .one(self)
51            .await
52            .map_err(map_sea_err)?;
53
54        let result = model.map(model_to_dataset_configuration).transpose()?;
55        Span::current().record(
56            COGNEE_DB_ROW_COUNT,
57            if result.is_some() { 1i64 } else { 0i64 },
58        );
59        Ok(result)
60    }
61
62    #[instrument(
63        name = "cognee.db.relational.dataset_configurations.upsert",
64        level = "info",
65        skip_all,
66        fields(cognee.db.system = tracing::field::Empty),
67        err,
68    )]
69    async fn upsert(
70        &self,
71        dataset_id: Uuid,
72        patch: DatasetConfigurationPatch,
73    ) -> Result<DatasetConfiguration, DatabaseError> {
74        Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
75        let now = Utc::now();
76        let existing = dataset_configuration::Entity::find()
77            .filter(dataset_configuration::Column::DatasetId.eq(uuid_hex::to_hex(dataset_id)))
78            .one(self)
79            .await
80            .map_err(map_sea_err)?;
81        let has_existing = existing.is_some();
82
83        let active = if let Some(model) = existing {
84            let mut active: dataset_configuration::ActiveModel = model.into();
85            if let Some(graph_schema) = patch.graph_schema {
86                active.graph_schema = Set(Some(graph_schema));
87            }
88            if let Some(custom_prompt) = patch.custom_prompt {
89                active.custom_prompt = Set(Some(custom_prompt));
90            }
91            active.updated_at = Set(Some(now));
92            active
93        } else {
94            dataset_configuration::ActiveModel {
95                id: Set(uuid_hex::to_hex(Uuid::new_v4())),
96                dataset_id: Set(uuid_hex::to_hex(dataset_id)),
97                graph_schema: Set(patch.graph_schema),
98                custom_prompt: Set(patch.custom_prompt),
99                created_at: Set(now),
100                updated_at: Set(None),
101            }
102        };
103
104        let inserted = if has_existing {
105            active.update(self).await.map_err(map_sea_err)?
106        } else {
107            active.insert(self).await.map_err(map_sea_err)?
108        };
109        let result = model_to_dataset_configuration(inserted)?;
110        Span::current().record(COGNEE_DB_ROW_COUNT, 1i64);
111        Ok(result)
112    }
113}
114
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::entities::{dataset, dataset_configuration};
    use crate::{connect, initialize};
    use sea_orm::{EntityTrait, Set};
    use tempfile::TempDir;

    /// Opens a migrated SQLite database stored in a file inside a fresh
    /// temporary directory.
    ///
    /// The returned [`TempDir`] guard owns that directory. Callers must keep
    /// it bound for as long as they use the connection; the directory is
    /// removed when the guard is dropped, so test runs no longer leave
    /// database files behind on disk.
    async fn temp_file_db() -> (TempDir, DatabaseConnection) {
        let temp_dir = TempDir::new().expect("temp dir");
        let db_path = temp_dir.path().join("dataset_configurations.db");
        std::fs::File::create(&db_path).expect("create sqlite db file");
        let db_url = format!("sqlite://{}?mode=rwc", db_path.display());
        let db = connect(&db_url).await.expect("connect to temp SQLite file");
        initialize(&db).await.expect("migrations");
        (temp_dir, db)
    }

    /// Inserts a minimal `dataset` row with the given id so that
    /// configuration rows referencing it can be created.
    async fn seed_dataset(db: &DatabaseConnection, dataset_id: Uuid) {
        let owner_id = Uuid::new_v4();
        let now = Utc::now();
        dataset::Entity::insert(dataset::ActiveModel {
            id: Set(uuid_hex::to_hex(dataset_id)),
            name: Set("dataset".to_owned()),
            owner_id: Set(uuid_hex::to_hex(owner_id)),
            tenant_id: Set(None),
            created_at: Set(now),
            updated_at: Set(None),
        })
        .exec(db)
        .await
        .expect("insert dataset");
    }

    #[tokio::test]
    async fn upsert_inserts_new_row() {
        // Bind the guard to a named variable (not `_`) so it lives until the
        // end of the test.
        let (_db_dir, db) = temp_file_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        let patch = DatasetConfigurationPatch {
            graph_schema: Some(serde_json::json!({"type": "object"})),
            custom_prompt: Some("X".to_owned()),
        };
        let saved = db.upsert(dataset_id, patch).await.expect("upsert");
        assert_eq!(saved.dataset_id, dataset_id);
        assert_eq!(
            saved.graph_schema,
            Some(serde_json::json!({"type": "object"}))
        );
        assert_eq!(saved.custom_prompt.as_deref(), Some("X"));
        assert!(saved.updated_at.is_none());

        let fetched = db
            .get_by_dataset_id(dataset_id)
            .await
            .expect("get")
            .expect("row");
        assert_eq!(fetched.graph_schema, saved.graph_schema);
        assert_eq!(fetched.custom_prompt, saved.custom_prompt);
    }

    #[tokio::test]
    async fn upsert_updates_existing_row() {
        let (_db_dir, db) = temp_file_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        let first = db
            .upsert(
                dataset_id,
                DatasetConfigurationPatch {
                    graph_schema: Some(serde_json::json!({"type": "object"})),
                    custom_prompt: Some("X".to_owned()),
                },
            )
            .await
            .expect("first upsert");
        let second = db
            .upsert(
                dataset_id,
                DatasetConfigurationPatch {
                    graph_schema: Some(serde_json::json!({"new": "shape"})),
                    custom_prompt: None,
                },
            )
            .await
            .expect("second upsert");

        assert_eq!(
            second.graph_schema,
            Some(serde_json::json!({"new": "shape"}))
        );
        // A `None` patch field must leave the stored value alone.
        assert_eq!(second.custom_prompt.as_deref(), Some("X"));
        let updated_at = second.updated_at.expect("updated_at");
        assert!(updated_at > first.created_at);
    }

    #[tokio::test]
    async fn upsert_preserves_existing_field_when_patch_omits_it() {
        let (_db_dir, db) = temp_file_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        db.upsert(
            dataset_id,
            DatasetConfigurationPatch {
                graph_schema: Some(serde_json::json!({"type": "object"})),
                custom_prompt: Some("X".to_owned()),
            },
        )
        .await
        .expect("seed upsert");

        let updated = db
            .upsert(
                dataset_id,
                DatasetConfigurationPatch {
                    graph_schema: None,
                    custom_prompt: Some("Y".to_owned()),
                },
            )
            .await
            .expect("second upsert");

        assert_eq!(
            updated.graph_schema,
            Some(serde_json::json!({"type": "object"}))
        );
        assert_eq!(updated.custom_prompt.as_deref(), Some("Y"));
    }

    #[tokio::test]
    async fn unique_constraint_enforced() {
        let (_db_dir, db) = temp_file_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        let first_id = Uuid::new_v4();
        let second_id = Uuid::new_v4();
        let now = Utc::now();

        dataset_configuration::Entity::insert(dataset_configuration::ActiveModel {
            id: Set(uuid_hex::to_hex(first_id)),
            dataset_id: Set(uuid_hex::to_hex(dataset_id)),
            graph_schema: Set(Some(serde_json::json!({"first": true}))),
            custom_prompt: Set(Some("X".to_owned())),
            created_at: Set(now),
            updated_at: Set(None),
        })
        .exec(&db)
        .await
        .expect("first insert");

        // Second row for the same dataset bypasses `upsert` on purpose so the
        // database constraint itself is what gets exercised.
        let duplicate = dataset_configuration::Entity::insert(dataset_configuration::ActiveModel {
            id: Set(uuid_hex::to_hex(second_id)),
            dataset_id: Set(uuid_hex::to_hex(dataset_id)),
            graph_schema: Set(Some(serde_json::json!({"second": true}))),
            custom_prompt: Set(Some("Y".to_owned())),
            created_at: Set(now),
            updated_at: Set(None),
        })
        .exec(&db)
        .await;

        let error = duplicate.expect_err("expected unique constraint error");
        assert!(matches!(
            map_sea_err(error),
            DatabaseError::UniqueViolation(_)
        ));
    }

    #[tokio::test]
    async fn cascade_delete_on_dataset_removal() {
        let (_db_dir, db) = temp_file_db().await;
        let dataset_id = Uuid::new_v4();
        seed_dataset(&db, dataset_id).await;

        db.upsert(
            dataset_id,
            DatasetConfigurationPatch {
                graph_schema: Some(serde_json::json!({"type": "object"})),
                custom_prompt: Some("X".to_owned()),
            },
        )
        .await
        .expect("upsert");

        dataset::Entity::delete_by_id(uuid_hex::to_hex(dataset_id))
            .exec(&db)
            .await
            .expect("delete dataset");

        let result = db.get_by_dataset_id(dataset_id).await.expect("get");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn get_returns_none_when_missing() {
        let (_db_dir, db) = temp_file_db().await;
        let result = db.get_by_dataset_id(Uuid::new_v4()).await.expect("get");
        assert!(result.is_none());
    }
}
324}