1use async_trait::async_trait;
4use chrono::Utc;
5use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
6use sea_orm::{DatabaseConnection, QueryOrder, Set, prelude::*};
7use tracing::{Span, instrument};
8use uuid::Uuid;
9
10use crate::conversions::map_sea_err;
11use crate::database_system_label;
12use crate::entities::notebook;
13use crate::traits::{Notebook, NotebookDb, NotebookUpdatePatch};
14use crate::types::DatabaseError;
15use crate::uuid_hex;
16
17fn model_to_notebook(m: notebook::Model) -> Result<Notebook, DatabaseError> {
20 Ok(Notebook {
21 id: uuid_hex::from_hex(&m.id)
22 .map_err(|e| DatabaseError::QueryError(format!("Invalid notebook id hex: {e}")))?,
23 owner_id: uuid_hex::from_hex(&m.owner_id)
24 .map_err(|e| DatabaseError::QueryError(format!("Invalid owner_id hex: {e}")))?,
25 name: m.name,
26 cells: m.cells,
27 deletable: m.deletable,
28 created_at: m.created_at,
29 })
30}
31
32#[async_trait]
35impl NotebookDb for DatabaseConnection {
36 #[instrument(
37 name = "cognee.db.relational.notebooks.list_by_owner",
38 level = "info",
39 skip_all,
40 fields(
41 cognee.db.system = tracing::field::Empty,
42 cognee.db.row_count = tracing::field::Empty,
43 ),
44 err,
45 )]
46 async fn list_by_owner(&self, owner_id: Uuid) -> Result<Vec<Notebook>, DatabaseError> {
47 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
48 let models: Vec<notebook::Model> = notebook::Entity::find()
49 .filter(notebook::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
50 .order_by_asc(notebook::Column::CreatedAt)
51 .all(self)
52 .await
53 .map_err(map_sea_err)?;
54
55 let rows: Vec<Notebook> = models
56 .into_iter()
57 .map(model_to_notebook)
58 .collect::<Result<_, _>>()?;
59 Span::current().record(COGNEE_DB_ROW_COUNT, rows.len() as i64);
60 Ok(rows)
61 }
62
63 #[instrument(
64 name = "cognee.db.relational.notebooks.create",
65 level = "info",
66 skip_all,
67 fields(cognee.db.system = tracing::field::Empty),
68 err,
69 )]
70 async fn create(
71 &self,
72 owner_id: Uuid,
73 name: String,
74 cells: serde_json::Value,
75 deletable: bool,
76 ) -> Result<Notebook, DatabaseError> {
77 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
78 self.create_seeded(Uuid::new_v4(), owner_id, name, cells, deletable)
79 .await
80 }
81
82 #[instrument(
83 name = "cognee.db.relational.notebooks.create_seeded",
84 level = "info",
85 skip_all,
86 fields(cognee.db.system = tracing::field::Empty),
87 err,
88 )]
89 async fn create_seeded(
90 &self,
91 id: Uuid,
92 owner_id: Uuid,
93 name: String,
94 cells: serde_json::Value,
95 deletable: bool,
96 ) -> Result<Notebook, DatabaseError> {
97 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
98 let now = Utc::now();
99
100 let active = notebook::ActiveModel {
101 id: Set(uuid_hex::to_hex(id)),
102 owner_id: Set(uuid_hex::to_hex(owner_id)),
103 name: Set(name),
104 cells: Set(cells),
105 deletable: Set(deletable),
106 created_at: Set(now),
107 };
108
109 active
110 .insert(self)
111 .await
112 .map_err(map_sea_err)
113 .and_then(model_to_notebook)
114 }
115
116 #[instrument(
117 name = "cognee.db.relational.notebooks.get_by_id_and_owner",
118 level = "info",
119 skip_all,
120 fields(
121 cognee.db.system = tracing::field::Empty,
122 cognee.db.row_count = tracing::field::Empty,
123 ),
124 err,
125 )]
126 async fn get_by_id_and_owner(
127 &self,
128 id: Uuid,
129 owner_id: Uuid,
130 ) -> Result<Option<Notebook>, DatabaseError> {
131 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
132 let model = notebook::Entity::find()
133 .filter(notebook::Column::Id.eq(uuid_hex::to_hex(id)))
134 .filter(notebook::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
135 .one(self)
136 .await
137 .map_err(map_sea_err)?;
138
139 let result = model.map(model_to_notebook).transpose()?;
140 Span::current().record(
141 COGNEE_DB_ROW_COUNT,
142 if result.is_some() { 1i64 } else { 0i64 },
143 );
144 Ok(result)
145 }
146
147 #[instrument(
148 name = "cognee.db.relational.notebooks.update",
149 level = "info",
150 skip_all,
151 fields(
152 cognee.db.system = tracing::field::Empty,
153 cognee.db.row_count = tracing::field::Empty,
154 ),
155 err,
156 )]
157 async fn update(
158 &self,
159 id: Uuid,
160 owner_id: Uuid,
161 patch: NotebookUpdatePatch,
162 ) -> Result<Option<Notebook>, DatabaseError> {
163 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
164 let model = notebook::Entity::find()
165 .filter(notebook::Column::Id.eq(uuid_hex::to_hex(id)))
166 .filter(notebook::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
167 .one(self)
168 .await
169 .map_err(map_sea_err)?;
170
171 let Some(model) = model else {
172 Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
173 return Ok(None);
174 };
175
176 let mut active: notebook::ActiveModel = model.into();
177
178 if let Some(new_name) = patch.name {
179 active.name = Set(new_name);
180 }
181 if let Some(new_cells) = patch.cells {
182 active.cells = Set(new_cells);
183 }
184
185 let updated = active.update(self).await.map_err(map_sea_err)?;
186 let result = model_to_notebook(updated).map(Some)?;
187 Span::current().record(
188 COGNEE_DB_ROW_COUNT,
189 if result.is_some() { 1i64 } else { 0i64 },
190 );
191 Ok(result)
192 }
193
194 #[instrument(
195 name = "cognee.db.relational.notebooks.delete",
196 level = "info",
197 skip_all,
198 fields(cognee.db.system = tracing::field::Empty),
199 err,
200 )]
201 async fn delete(&self, id: Uuid, owner_id: Uuid) -> Result<bool, DatabaseError> {
202 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(self));
203 let result = notebook::Entity::delete_many()
204 .filter(notebook::Column::Id.eq(uuid_hex::to_hex(id)))
205 .filter(notebook::Column::OwnerId.eq(uuid_hex::to_hex(owner_id)))
206 .exec(self)
207 .await
208 .map_err(map_sea_err)?;
209
210 Ok(result.rows_affected > 0)
211 }
212}
213
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::{connect, initialize};
    use serde_json::json;

    /// Fresh in-memory SQLite database with all migrations applied.
    async fn in_memory_db() -> DatabaseConnection {
        let db = connect("sqlite::memory:").await.expect("in-memory SQLite");
        initialize(&db).await.expect("migrations");
        db
    }

    /// Create → list → get → update (name, then cells) → delete, checking each step.
    #[tokio::test]
    async fn sqlite_inmem_round_trip() {
        let db = in_memory_db().await;
        let owner_id = Uuid::new_v4();

        let nb = db
            .create(owner_id, "My Notebook".into(), json!([]), true)
            .await
            .expect("create notebook");
        assert_eq!(nb.owner_id, owner_id);
        assert_eq!(nb.name, "My Notebook");
        assert_eq!(nb.cells, json!([]));
        assert!(nb.deletable);

        let list = db.list_by_owner(owner_id).await.expect("list");
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].id, nb.id);

        let fetched = db
            .get_by_id_and_owner(nb.id, owner_id)
            .await
            .expect("get")
            .expect("Some");
        assert_eq!(fetched.id, nb.id);
        assert_eq!(fetched.cells, nb.cells);

        let patch = NotebookUpdatePatch {
            name: Some("Renamed".into()),
            cells: None,
        };
        let updated = db
            .update(nb.id, owner_id, patch)
            .await
            .expect("update")
            .expect("Some");
        assert_eq!(updated.name, "Renamed");
        // Fields absent from the patch are left untouched.
        assert_eq!(updated.cells, json!([]));

        let new_cells = json!([{ "type": "code", "source": "print(1)" }]);
        let patch = NotebookUpdatePatch {
            name: None,
            cells: Some(new_cells.clone()),
        };
        let updated = db
            .update(nb.id, owner_id, patch)
            .await
            .expect("update cells")
            .expect("Some");
        assert_eq!(updated.cells, new_cells);
        assert_eq!(updated.name, "Renamed");

        let deleted = db.delete(nb.id, owner_id).await.expect("delete");
        assert!(deleted);

        let list2 = db.list_by_owner(owner_id).await.expect("list2");
        assert!(list2.is_empty());

        let gone = db
            .get_by_id_and_owner(nb.id, owner_id)
            .await
            .expect("get after delete");
        assert!(gone.is_none());

        // Deleting an already-deleted notebook reports that nothing was removed.
        let deleted_again = db.delete(nb.id, owner_id).await.expect("delete again");
        assert!(!deleted_again);
    }

    /// A second owner can neither see, modify, nor delete another owner's notebook.
    #[tokio::test]
    async fn owner_isolation() {
        let db = in_memory_db().await;
        let owner_a = Uuid::new_v4();
        let owner_b = Uuid::new_v4();

        let nb = db
            .create(owner_a, "A's notebook".into(), json!([]), true)
            .await
            .expect("create");

        let result = db.get_by_id_and_owner(nb.id, owner_b).await.expect("get");
        assert!(result.is_none());

        let b_list = db.list_by_owner(owner_b).await.expect("list by B");
        assert!(b_list.is_empty());

        let patch = NotebookUpdatePatch {
            name: Some("Hijacked".into()),
            cells: None,
        };
        let b_update = db.update(nb.id, owner_b, patch).await.expect("update by B");
        assert!(b_update.is_none());

        let deleted = db.delete(nb.id, owner_b).await.expect("delete by B");
        assert!(!deleted);

        // Owner A's notebook is still present and unchanged.
        let still_there = db
            .get_by_id_and_owner(nb.id, owner_a)
            .await
            .expect("get by A")
            .expect("Some");
        assert_eq!(still_there.name, "A's notebook");
    }
}