Skip to main content

cognee_database/ops/
notebooks.rs

1//! SeaORM implementation of `NotebookDb` on `DatabaseConnection`.
2
3use async_trait::async_trait;
4use chrono::Utc;
5use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
6use sea_orm::{DatabaseConnection, QueryOrder, Set, prelude::*};
7use tracing::{Span, instrument};
8use uuid::Uuid;
9
10use crate::conversions::map_sea_err;
11use crate::database_system_label;
12use crate::entities::notebook;
13use crate::traits::{Notebook, NotebookDb, NotebookUpdatePatch};
14use crate::types::DatabaseError;
15use crate::uuid_hex;
16
17// ─── Model → domain ─────────────────────────────────────────────────────────
18
19fn model_to_notebook(m: notebook::Model) -> Result<Notebook, DatabaseError> {
20    Ok(Notebook {
21        id: uuid_hex::from_hex(&m.id)
22            .map_err(|e| DatabaseError::QueryError(format!("Invalid notebook id hex: {e}")))?,
23        owner_id: uuid_hex::from_hex(&m.owner_id)
24            .map_err(|e| DatabaseError::QueryError(format!("Invalid owner_id hex: {e}")))?,
25        name: m.name,
26        cells: m.cells,
27        deletable: m.deletable,
28        created_at: m.created_at,
29    })
30}
31
32// ─── NotebookDb impl ─────────────────────────────────────────────────────────
33
34#[async_trait]
35impl NotebookDb for DatabaseConnection {
36    #[instrument(
37        name = "cognee.db.relational.notebooks.list_by_owner",
38        level = "info",
39        skip_all,
40        fields(
41            cognee.db.system = tracing::field::Empty,
42            cognee.db.row_count = tracing::field::Empty,
43        ),
44        err,
45    )]
46    async fn list_by_owner(&self, owner_id: Uuid) -> Result<Vec<Notebook>, DatabaseError> {
47        Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
48        let models: Vec<notebook::Model> = notebook::Entity::find()
49            .filter(notebook::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
50            .order_by_asc(notebook::Column::CreatedAt)
51            .all(self)
52            .await
53            .map_err(map_sea_err)?;
54
55        let rows: Vec<Notebook> = models
56            .into_iter()
57            .map(model_to_notebook)
58            .collect::<Result<_, _>>()?;
59        Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
60        Ok(rows)
61    }
62
63    #[instrument(
64        name = "cognee.db.relational.notebooks.create",
65        level = "info",
66        skip_all,
67        fields(cognee.db.system = tracing::field::Empty),
68        err,
69    )]
70    async fn create(
71        &self,
72        owner_id: Uuid,
73        name: String,
74        cells: serde_json::Value,
75        deletable: bool,
76    ) -> Result<Notebook, DatabaseError> {
77        Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
78        self.create_seeded(Uuid::new_v4(), owner_id, name, cells, deletable)
79            .await
80    }
81
82    #[instrument(
83        name = "cognee.db.relational.notebooks.create_seeded",
84        level = "info",
85        skip_all,
86        fields(cognee.db.system = tracing::field::Empty),
87        err,
88    )]
89    async fn create_seeded(
90        &self,
91        id: Uuid,
92        owner_id: Uuid,
93        name: String,
94        cells: serde_json::Value,
95        deletable: bool,
96    ) -> Result<Notebook, DatabaseError> {
97        Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
98        let now = Utc::now();
99
100        let active = notebook::ActiveModel {
101            id: Set(uuid_hex::to_hex(id)),
102            owner_id: Set(uuid_hex::to_hex(owner_id)),
103            name: Set(name),
104            cells: Set(cells),
105            deletable: Set(deletable),
106            created_at: Set(now),
107        };
108
109        active
110            .insert(self)
111            .await
112            .map_err(map_sea_err)
113            .and_then(model_to_notebook)
114    }
115
116    #[instrument(
117        name = "cognee.db.relational.notebooks.get_by_id_and_owner",
118        level = "info",
119        skip_all,
120        fields(
121            cognee.db.system = tracing::field::Empty,
122            cognee.db.row_count = tracing::field::Empty,
123        ),
124        err,
125    )]
126    async fn get_by_id_and_owner(
127        &self,
128        id: Uuid,
129        owner_id: Uuid,
130    ) -> Result<Option<Notebook>, DatabaseError> {
131        Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
132        let model = notebook::Entity::find()
133            .filter(notebook::Column::Id.eq(uuid_hex::to_hex(id)))
134            .filter(notebook::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
135            .one(self)
136            .await
137            .map_err(map_sea_err)?;
138
139        let result = model.map(model_to_notebook).transpose()?;
140        Span::current().record(
141            COGNEE_DB_ROW_COUNT,
142            if result.is_some() { 1i64 } else { 0i64 },
143        );
144        Ok(result)
145    }
146
147    #[instrument(
148        name = "cognee.db.relational.notebooks.update",
149        level = "info",
150        skip_all,
151        fields(
152            cognee.db.system = tracing::field::Empty,
153            cognee.db.row_count = tracing::field::Empty,
154        ),
155        err,
156    )]
157    async fn update(
158        &self,
159        id: Uuid,
160        owner_id: Uuid,
161        patch: NotebookUpdatePatch,
162    ) -> Result<Option<Notebook>, DatabaseError> {
163        Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
164        let model = notebook::Entity::find()
165            .filter(notebook::Column::Id.eq(uuid_hex::to_hex(id)))
166            .filter(notebook::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
167            .one(self)
168            .await
169            .map_err(map_sea_err)?;
170
171        let Some(model) = model else {
172            Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
173            return Ok(None);
174        };
175
176        let mut active: notebook::ActiveModel = model.into();
177
178        if let Some(new_name) = patch.name {
179            active.name = Set(new_name);
180        }
181        if let Some(new_cells) = patch.cells {
182            active.cells = Set(new_cells);
183        }
184
185        let updated = active.update(self).await.map_err(map_sea_err)?;
186        let result = model_to_notebook(updated).map(Some)?;
187        Span::current().record(
188            COGNEE_DB_ROW_COUNT,
189            if result.is_some() { 1i64 } else { 0i64 },
190        );
191        Ok(result)
192    }
193
194    #[instrument(
195        name = "cognee.db.relational.notebooks.delete",
196        level = "info",
197        skip_all,
198        fields(cognee.db.system = tracing::field::Empty),
199        err,
200    )]
201    async fn delete(&self, id: Uuid, owner_id: Uuid) -> Result<bool, DatabaseError> {
202        Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
203        let result = notebook::Entity::delete_many()
204            .filter(notebook::Column::Id.eq(uuid_hex::to_hex(id)))
205            .filter(notebook::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
206            .exec(self)
207            .await
208            .map_err(map_sea_err)?;
209
210        Ok(result.rows_affected > 0)
211    }
212}
213
214// ─── Tests ───────────────────────────────────────────────────────────────────
215
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::{connect, initialize};
    use serde_json::json;

    /// Opens a fresh in-memory SQLite connection and runs `initialize` on it.
    async fn in_memory_db() -> DatabaseConnection {
        let conn = connect("sqlite::memory:").await.expect("in-memory SQLite");
        initialize(&conn).await.expect("migrations");
        conn
    }

    /// Walks one notebook through create → list → get → update → delete.
    #[tokio::test]
    async fn sqlite_inmem_round_trip() {
        let conn = in_memory_db().await;
        let owner = Uuid::new_v4();

        // Create: the returned notebook echoes what was passed in.
        let created = conn
            .create(owner, String::from("My Notebook"), json!([]), true)
            .await
            .expect("create notebook");
        assert_eq!(created.owner_id, owner);
        assert_eq!(created.name, "My Notebook");
        assert!(created.deletable);

        // List: exactly the one notebook is visible to its owner.
        let listed = conn.list_by_owner(owner).await.expect("list");
        assert_eq!(listed.len(), 1);

        // Get by id: the same notebook comes back.
        let looked_up = conn.get_by_id_and_owner(created.id, owner).await;
        let fetched = looked_up.expect("get").expect("Some");
        assert_eq!(fetched.id, created.id);

        // Update: rename only, leave the cells untouched.
        let rename = NotebookUpdatePatch {
            name: Some(String::from("Renamed")),
            cells: None,
        };
        let update_outcome = conn.update(created.id, owner, rename).await;
        let renamed = update_outcome.expect("update").expect("Some");
        assert_eq!(renamed.name, "Renamed");

        // Delete: reports success, and the owner's list is empty afterwards.
        let was_deleted = conn.delete(created.id, owner).await.expect("delete");
        assert!(was_deleted);

        let remaining = conn.list_by_owner(owner).await.expect("list2");
        assert!(remaining.is_empty());
    }

    /// A second owner can neither read nor delete the first owner's notebook.
    #[tokio::test]
    async fn owner_isolation() {
        let conn = in_memory_db().await;
        let (owner_a, owner_b) = (Uuid::new_v4(), Uuid::new_v4());

        let a_notebook = conn
            .create(owner_a, String::from("A's notebook"), json!([]), true)
            .await
            .expect("create");

        // B cannot see A's notebook.
        let seen_by_b = conn
            .get_by_id_and_owner(a_notebook.id, owner_b)
            .await
            .expect("get");
        assert!(seen_by_b.is_none());

        // B's delete attempt affects no rows.
        let removed_by_b = conn
            .delete(a_notebook.id, owner_b)
            .await
            .expect("delete by B");
        assert!(!removed_by_b);
    }
}