1use codemem_core::{CodememError, MemoryNode, MemoryType};
6use rusqlite::Connection;
7use std::collections::HashMap;
8use std::path::Path;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Mutex;
11
12mod backend;
13pub mod cross_repo;
14pub mod graph;
15mod graph_persistence;
16mod memory;
17mod migrations;
18mod queries;
19pub mod vector;
20
21pub use cross_repo::{ApiEndpointEntry, PackageRegistryEntry, UnresolvedRefEntry};
22pub use graph::GraphEngine;
23pub use vector::HnswIndex;
24
25pub(crate) trait MapStorageErr<T> {
27 fn storage_err(self) -> Result<T, CodememError>;
28}
29
30impl<T> MapStorageErr<T> for Result<T, rusqlite::Error> {
31 fn storage_err(self) -> Result<T, CodememError> {
32 self.map_err(|e| CodememError::Storage(e.to_string()))
33 }
34}
35
36pub struct Storage {
41 conn: Mutex<Connection>,
42 in_transaction: AtomicBool,
46}
47
48impl Storage {
49 pub(crate) fn conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, CodememError> {
51 self.conn
52 .lock()
53 .map_err(|e| CodememError::LockPoisoned(format!("Storage mutex: {e}")))
54 }
55
56 fn apply_pragmas(
61 conn: &Connection,
62 cache_size_mb: Option<u32>,
63 busy_timeout_secs: Option<u64>,
64 ) -> Result<(), CodememError> {
65 conn.pragma_update(None, "journal_mode", "WAL")
67 .storage_err()?;
68 let cache_kb = i64::from(cache_size_mb.unwrap_or(64)) * 1000;
70 conn.pragma_update(None, "cache_size", -cache_kb)
71 .storage_err()?;
72 conn.pragma_update(None, "foreign_keys", "ON")
74 .storage_err()?;
75 conn.pragma_update(None, "synchronous", "NORMAL")
77 .storage_err()?;
78 conn.pragma_update(None, "mmap_size", 268435456i64)
80 .storage_err()?;
81 conn.pragma_update(None, "temp_store", "MEMORY")
83 .storage_err()?;
84 let timeout = busy_timeout_secs.unwrap_or(5);
86 conn.busy_timeout(std::time::Duration::from_secs(timeout))
87 .storage_err()?;
88 Ok(())
89 }
90
91 pub fn open(path: &Path) -> Result<Self, CodememError> {
93 Self::open_with_config(path, None, None)
94 }
95
96 pub fn open_with_config(
98 path: &Path,
99 cache_size_mb: Option<u32>,
100 busy_timeout_secs: Option<u64>,
101 ) -> Result<Self, CodememError> {
102 let conn = Connection::open(path).storage_err()?;
103 Self::apply_pragmas(&conn, cache_size_mb, busy_timeout_secs)?;
104 migrations::run_migrations(&conn)?;
105 Ok(Self {
106 conn: Mutex::new(conn),
107 in_transaction: AtomicBool::new(false),
108 })
109 }
110
111 pub fn open_without_migrations(path: &Path) -> Result<Self, CodememError> {
117 Self::open_without_migrations_with_config(path, None, None)
118 }
119
120 pub fn open_without_migrations_with_config(
122 path: &Path,
123 cache_size_mb: Option<u32>,
124 busy_timeout_secs: Option<u64>,
125 ) -> Result<Self, CodememError> {
126 let conn = Connection::open(path).storage_err()?;
127 Self::apply_pragmas(&conn, cache_size_mb, busy_timeout_secs)?;
128 Ok(Self {
129 conn: Mutex::new(conn),
130 in_transaction: AtomicBool::new(false),
131 })
132 }
133
134 pub fn open_in_memory() -> Result<Self, CodememError> {
136 let conn = Connection::open_in_memory().storage_err()?;
137 Self::apply_pragmas(&conn, None, None)?;
138 migrations::run_migrations(&conn)?;
139 Ok(Self {
140 conn: Mutex::new(conn),
141 in_transaction: AtomicBool::new(false),
142 })
143 }
144
145 pub fn content_hash(content: &str) -> String {
147 codemem_core::content_hash(content)
148 }
149
150 pub(crate) fn has_outer_transaction(&self) -> bool {
152 self.in_transaction.load(Ordering::Acquire)
153 }
154}
155
/// Raw, undecoded columns of one stored memory, as read by `MemoryRow::from_row`.
///
/// Decoding into a [`MemoryNode`] happens in `MemoryRow::into_memory_node`.
pub(crate) struct MemoryRow {
    /// Memory identifier.
    pub(crate) id: String,
    /// The memory's text content.
    pub(crate) content: String,
    /// Textual memory type; parsed into [`MemoryType`] during decoding.
    pub(crate) memory_type: String,
    pub(crate) importance: f64,
    pub(crate) confidence: f64,
    /// Stored signed; narrowed to `u32` during decoding.
    pub(crate) access_count: i64,
    pub(crate) content_hash: String,
    /// JSON text expected to decode as an array of strings.
    pub(crate) tags: String,
    /// JSON text expected to decode as an object of string keys to JSON values.
    pub(crate) metadata: String,
    pub(crate) namespace: Option<String>,
    pub(crate) session_id: Option<String>,
    pub(crate) repo: Option<String>,
    pub(crate) git_ref: Option<String>,
    /// Optional expiry as Unix seconds.
    pub(crate) expires_at: Option<i64>,
    /// Unix seconds.
    pub(crate) created_at: i64,
    /// Unix seconds.
    pub(crate) updated_at: i64,
    /// Unix seconds.
    pub(crate) last_accessed_at: i64,
}
176
177impl MemoryRow {
178 pub(crate) fn from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Self> {
179 Ok(Self {
180 id: row.get(0)?,
181 content: row.get(1)?,
182 memory_type: row.get(2)?,
183 importance: row.get(3)?,
184 confidence: row.get(4)?,
185 access_count: row.get(5)?,
186 content_hash: row.get(6)?,
187 tags: row.get(7)?,
188 metadata: row.get(8)?,
189 namespace: row.get(9)?,
190 session_id: row.get(10)?,
191 repo: row.get(11)?,
192 git_ref: row.get(12)?,
193 expires_at: row.get(13)?,
194 created_at: row.get(14)?,
195 updated_at: row.get(15)?,
196 last_accessed_at: row.get(16)?,
197 })
198 }
199
200 pub(crate) fn into_memory_node(self) -> Result<MemoryNode, CodememError> {
201 let memory_type: MemoryType = self.memory_type.parse()?;
202 let tags: Vec<String> = serde_json::from_str(&self.tags).unwrap_or_else(|e| {
203 tracing::warn!(id = %self.id, error = %e, "Malformed tags JSON for memory");
204 Vec::new()
205 });
206 let metadata: HashMap<String, serde_json::Value> = serde_json::from_str(&self.metadata)
207 .unwrap_or_else(|e| {
208 tracing::warn!(id = %self.id, error = %e, "Malformed metadata JSON for memory");
209 HashMap::new()
210 });
211
212 let created_at = chrono::DateTime::from_timestamp(self.created_at, 0)
213 .unwrap_or_else(|| {
214 tracing::warn!(id = %self.id, ts = self.created_at, "Invalid created_at timestamp");
215 chrono::DateTime::<chrono::Utc>::default()
216 })
217 .with_timezone(&chrono::Utc);
218 let updated_at = chrono::DateTime::from_timestamp(self.updated_at, 0)
219 .unwrap_or_else(|| {
220 tracing::warn!(id = %self.id, ts = self.updated_at, "Invalid updated_at timestamp");
221 chrono::DateTime::<chrono::Utc>::default()
222 })
223 .with_timezone(&chrono::Utc);
224 let last_accessed_at = chrono::DateTime::from_timestamp(self.last_accessed_at, 0)
225 .unwrap_or_else(|| {
226 tracing::warn!(id = %self.id, ts = self.last_accessed_at, "Invalid last_accessed_at timestamp");
227 chrono::DateTime::<chrono::Utc>::default()
228 })
229 .with_timezone(&chrono::Utc);
230
231 let expires_at = self
232 .expires_at
233 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
234 .map(|dt| dt.with_timezone(&chrono::Utc));
235
236 Ok(MemoryNode {
237 id: self.id,
238 content: self.content,
239 memory_type,
240 importance: self.importance,
241 confidence: self.confidence,
242 access_count: u32::try_from(self.access_count).unwrap_or(u32::MAX),
243 content_hash: self.content_hash,
244 tags,
245 metadata,
246 namespace: self.namespace,
247 session_id: self.session_id,
248 repo: self.repo,
249 git_ref: self.git_ref,
250 expires_at,
251 created_at,
252 updated_at,
253 last_accessed_at,
254 })
255 }
256}