1use codemem_core::{CodememError, MemoryNode, MemoryType};
6use rusqlite::Connection;
7use std::collections::HashMap;
8use std::path::Path;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Mutex;
11
12mod backend;
13pub mod graph;
14mod graph_persistence;
15mod memory;
16mod migrations;
17mod queries;
18pub mod vector;
19
20pub use graph::GraphEngine;
21pub use vector::HnswIndex;
22
23pub(crate) trait MapStorageErr<T> {
25 fn storage_err(self) -> Result<T, CodememError>;
26}
27
28impl<T> MapStorageErr<T> for Result<T, rusqlite::Error> {
29 fn storage_err(self) -> Result<T, CodememError> {
30 self.map_err(|e| CodememError::Storage(e.to_string()))
31 }
32}
33
34pub struct Storage {
39 conn: Mutex<Connection>,
40 in_transaction: AtomicBool,
44}
45
46impl Storage {
47 pub(crate) fn conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, CodememError> {
49 self.conn
50 .lock()
51 .map_err(|e| CodememError::LockPoisoned(format!("Storage mutex: {e}")))
52 }
53
54 fn apply_pragmas(
59 conn: &Connection,
60 cache_size_mb: Option<u32>,
61 busy_timeout_secs: Option<u64>,
62 ) -> Result<(), CodememError> {
63 conn.pragma_update(None, "journal_mode", "WAL")
65 .storage_err()?;
66 let cache_kb = i64::from(cache_size_mb.unwrap_or(64)) * 1000;
68 conn.pragma_update(None, "cache_size", -cache_kb)
69 .storage_err()?;
70 conn.pragma_update(None, "foreign_keys", "ON")
72 .storage_err()?;
73 conn.pragma_update(None, "synchronous", "NORMAL")
75 .storage_err()?;
76 conn.pragma_update(None, "mmap_size", 268435456i64)
78 .storage_err()?;
79 conn.pragma_update(None, "temp_store", "MEMORY")
81 .storage_err()?;
82 let timeout = busy_timeout_secs.unwrap_or(5);
84 conn.busy_timeout(std::time::Duration::from_secs(timeout))
85 .storage_err()?;
86 Ok(())
87 }
88
89 pub fn open(path: &Path) -> Result<Self, CodememError> {
91 Self::open_with_config(path, None, None)
92 }
93
94 pub fn open_with_config(
96 path: &Path,
97 cache_size_mb: Option<u32>,
98 busy_timeout_secs: Option<u64>,
99 ) -> Result<Self, CodememError> {
100 let conn = Connection::open(path).storage_err()?;
101 Self::apply_pragmas(&conn, cache_size_mb, busy_timeout_secs)?;
102 migrations::run_migrations(&conn)?;
103 Ok(Self {
104 conn: Mutex::new(conn),
105 in_transaction: AtomicBool::new(false),
106 })
107 }
108
109 pub fn open_without_migrations(path: &Path) -> Result<Self, CodememError> {
115 Self::open_without_migrations_with_config(path, None, None)
116 }
117
118 pub fn open_without_migrations_with_config(
120 path: &Path,
121 cache_size_mb: Option<u32>,
122 busy_timeout_secs: Option<u64>,
123 ) -> Result<Self, CodememError> {
124 let conn = Connection::open(path).storage_err()?;
125 Self::apply_pragmas(&conn, cache_size_mb, busy_timeout_secs)?;
126 Ok(Self {
127 conn: Mutex::new(conn),
128 in_transaction: AtomicBool::new(false),
129 })
130 }
131
132 pub fn open_in_memory() -> Result<Self, CodememError> {
134 let conn = Connection::open_in_memory().storage_err()?;
135 Self::apply_pragmas(&conn, None, None)?;
136 migrations::run_migrations(&conn)?;
137 Ok(Self {
138 conn: Mutex::new(conn),
139 in_transaction: AtomicBool::new(false),
140 })
141 }
142
143 pub fn content_hash(content: &str) -> String {
145 codemem_core::content_hash(content)
146 }
147
148 pub(crate) fn has_outer_transaction(&self) -> bool {
150 self.in_transaction.load(Ordering::Acquire)
151 }
152}
153
/// Raw, untyped image of one memory record as read from a query result.
///
/// Fields are listed in the column order that `MemoryRow::from_row` reads
/// them (indices 0 through 13). `MemoryRow::into_memory_node` turns this into
/// the typed domain value.
pub(crate) struct MemoryRow {
    /// Column 0: record identifier.
    pub(crate) id: String,
    /// Column 1: the memory's text content.
    pub(crate) content: String,
    /// Column 2: memory type in string form; parsed into `MemoryType` later.
    pub(crate) memory_type: String,
    /// Column 3: importance score.
    pub(crate) importance: f64,
    /// Column 4: confidence score.
    pub(crate) confidence: f64,
    /// Column 5: access counter as stored (signed 64-bit).
    pub(crate) access_count: i64,
    /// Column 6: hash of `content`.
    pub(crate) content_hash: String,
    /// Column 7: tags as JSON text; decoded as a list of strings.
    pub(crate) tags: String,
    /// Column 8: metadata as JSON text; decoded as a string-keyed map.
    pub(crate) metadata: String,
    /// Column 9: optional namespace.
    pub(crate) namespace: Option<String>,
    /// Column 10: optional session identifier.
    pub(crate) session_id: Option<String>,
    /// Column 11: creation time, interpreted as whole Unix seconds.
    pub(crate) created_at: i64,
    /// Column 12: last-update time, interpreted as whole Unix seconds.
    pub(crate) updated_at: i64,
    /// Column 13: last-access time, interpreted as whole Unix seconds.
    pub(crate) last_accessed_at: i64,
}
171
172impl MemoryRow {
173 pub(crate) fn from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Self> {
174 Ok(Self {
175 id: row.get(0)?,
176 content: row.get(1)?,
177 memory_type: row.get(2)?,
178 importance: row.get(3)?,
179 confidence: row.get(4)?,
180 access_count: row.get(5)?,
181 content_hash: row.get(6)?,
182 tags: row.get(7)?,
183 metadata: row.get(8)?,
184 namespace: row.get(9)?,
185 session_id: row.get(10)?,
186 created_at: row.get(11)?,
187 updated_at: row.get(12)?,
188 last_accessed_at: row.get(13)?,
189 })
190 }
191
192 pub(crate) fn into_memory_node(self) -> Result<MemoryNode, CodememError> {
193 let memory_type: MemoryType = self.memory_type.parse()?;
194 let tags: Vec<String> = serde_json::from_str(&self.tags).unwrap_or_else(|e| {
195 tracing::warn!(id = %self.id, error = %e, "Malformed tags JSON for memory");
196 Vec::new()
197 });
198 let metadata: HashMap<String, serde_json::Value> = serde_json::from_str(&self.metadata)
199 .unwrap_or_else(|e| {
200 tracing::warn!(id = %self.id, error = %e, "Malformed metadata JSON for memory");
201 HashMap::new()
202 });
203
204 let created_at = chrono::DateTime::from_timestamp(self.created_at, 0)
205 .unwrap_or_else(|| {
206 tracing::warn!(id = %self.id, ts = self.created_at, "Invalid created_at timestamp");
207 chrono::DateTime::<chrono::Utc>::default()
208 })
209 .with_timezone(&chrono::Utc);
210 let updated_at = chrono::DateTime::from_timestamp(self.updated_at, 0)
211 .unwrap_or_else(|| {
212 tracing::warn!(id = %self.id, ts = self.updated_at, "Invalid updated_at timestamp");
213 chrono::DateTime::<chrono::Utc>::default()
214 })
215 .with_timezone(&chrono::Utc);
216 let last_accessed_at = chrono::DateTime::from_timestamp(self.last_accessed_at, 0)
217 .unwrap_or_else(|| {
218 tracing::warn!(id = %self.id, ts = self.last_accessed_at, "Invalid last_accessed_at timestamp");
219 chrono::DateTime::<chrono::Utc>::default()
220 })
221 .with_timezone(&chrono::Utc);
222
223 Ok(MemoryNode {
224 id: self.id,
225 content: self.content,
226 memory_type,
227 importance: self.importance,
228 confidence: self.confidence,
229 access_count: u32::try_from(self.access_count).unwrap_or(u32::MAX),
230 content_hash: self.content_hash,
231 tags,
232 metadata,
233 namespace: self.namespace,
234 session_id: self.session_id,
235 created_at,
236 updated_at,
237 last_accessed_at,
238 })
239 }
240}