1use std::time::Duration;
3use serde::{Deserialize, Serialize};
4use tokio::sync::OnceCell;
5use tracing::{error, info, instrument};
6
7use candor_core::error::CoreError;
8
9static SCHEMA_INIT: OnceCell<()> = OnceCell::const_new();
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct MemoryBlock {
14 pub project_id: String,
15 pub textual_content: String,
16 pub semantic_embedding: Vec<f32>,
17 pub timestamp: surrealdb::sql::Datetime,
18}
19
20pub struct MemorySystem {
22 db: surrealdb::Surreal<surrealdb::engine::local::Db>,
23 embedding_dim: usize,
24}
25
26impl MemorySystem {
27 pub async fn new(embedding_dim: usize) -> Result<Self, CoreError> {
30 info!("Creating SurrealDB connection (schema init deferred)");
31
32 let db = surrealdb::Surreal::new::<surrealdb::engine::local::Mem>(())
33 .await
34 .map_err(|e| CoreError::Internal(format!("SurrealDB connect failed: {e}")))?;
35
36 db.use_ns("candor_namespace")
37 .use_db("candor_database")
38 .await
39 .map_err(|e| CoreError::Internal(format!("SurrealDB ns/db error: {e}")))?;
40
41 info!("SurrealDB memory engine ready (lazy schema)");
42 Ok(Self {
43 db,
44 embedding_dim,
45 })
46 }
47
48 async fn ensure_schema(&self) -> Result<(), CoreError> {
50 SCHEMA_INIT
51 .get_or_try_init(|| async {
52 info!("Running lazy SurrealDB schema init");
53 let schema_queries = super::schema::schema_queries(self.embedding_dim);
54 let mut qr = self.db.query(&schema_queries).await
55 .map_err(|e| CoreError::Internal(format!("Schema query failed: {e}")))?;
56
57 if !qr.take_errors().is_empty() {
58 error!("Schema definition errors");
59 return Err(CoreError::Internal(
60 "Database schema init failure — check embedding dimension".into(),
61 ));
62 }
63
64 info!("SurrealDB schema initialized");
65 Ok(())
66 })
67 .await
68 .map(|_| ())
69 }
70
71 #[instrument(skip(self, embedding))]
73 pub async fn store_memory(
74 &self,
75 project_id: String,
76 content: String,
77 embedding: Vec<f32>,
78 ) -> Result<(), CoreError> {
79 self.ensure_schema().await?;
80
81 let entry = MemoryBlock {
82 project_id,
83 textual_content: content,
84 semantic_embedding: embedding,
85 timestamp: surrealdb::sql::Datetime::default(),
86 };
87
88 tokio::time::timeout(Duration::from_secs(5), async {
89 let _created: Option<MemoryBlock> = self.db
90 .create("memory_block")
91 .content(entry)
92 .await
93 .map_err(|e| CoreError::Internal(format!("Store failed: {e}")))?;
94 Ok::<_, CoreError>(())
95 })
96 .await
97 .map_err(|_| CoreError::Internal("Store memory timed out after 5s".into()))??;
98
99 info!("Memory block persisted");
100 Ok(())
101 }
102
103 #[instrument(skip(self, query_embedding))]
104 pub async fn retrieve_context(
105 &self,
106 project_id: &str,
107 query_embedding: Vec<f32>,
108 top_k: u32,
109 ) -> Result<Vec<String>, CoreError> {
110 self.ensure_schema().await?;
111
112 let sql = "
113 SELECT textual_content, vector::similarity::cosine(semantic_embedding, $query_vector) AS sim
114 FROM memory_block
115 WHERE project_id = $pid
116 ORDER BY sim DESC
117 LIMIT $limit;
118 ";
119
120 let contents: Vec<String> = tokio::time::timeout(Duration::from_secs(5), async {
121 let mut result = self.db.query(sql)
122 .bind(("query_vector", query_embedding))
123 .bind(("pid", project_id.to_string()))
124 .bind(("limit", top_k))
125 .await
126 .map_err(|e| CoreError::Internal(format!("Retrieve failed: {e}")))?;
127
128 let contents: Vec<String> = result
129 .take::<Vec<serde_json::Value>>(0)
130 .map_err(|e| CoreError::Internal(format!("Deserialize failed: {e}")))?
131 .into_iter()
132 .filter_map(|val| val.get("textual_content")?.as_str().map(|s| s.to_string()))
133 .collect();
134 Ok::<_, CoreError>(contents)
135 })
136 .await
137 .map_err(|_| CoreError::Internal("Retrieve context timed out after 5s".into()))??;
138
139 info!(count = contents.len(), "Context retrieved");
140 Ok(contents)
141 }
142
143 pub async fn store_execution_log(
144 &self,
145 session_id: &str,
146 phase: &str,
147 action: &str,
148 result: &str,
149 ) -> Result<(), CoreError> {
150 self.ensure_schema().await?;
151
152 #[derive(Debug, Serialize, Deserialize)]
153 struct LogEntry {
154 session_id: String,
155 phase: String,
156 action: String,
157 result: String,
158 timestamp: surrealdb::sql::Datetime,
159 }
160
161 let entry = LogEntry {
162 session_id: session_id.to_string(),
163 phase: phase.to_string(),
164 action: action.to_string(),
165 result: result.to_string(),
166 timestamp: surrealdb::sql::Datetime::default(),
167 };
168
169 let _created: Option<LogEntry> = tokio::time::timeout(Duration::from_secs(5), async {
170 let created: Option<LogEntry> = self.db
171 .create("execution_log")
172 .content(entry)
173 .await
174 .map_err(|e| CoreError::Internal(format!("Store log failed: {e}")))?;
175 Ok::<_, CoreError>(created)
176 })
177 .await
178 .map_err(|_| CoreError::Internal("Store execution log timed out after 5s".into()))??;
179
180 Ok(())
181 }
182
183 pub async fn delete_project_memories(&self, project_id: &str) -> Result<(), CoreError> {
184 self.ensure_schema().await?;
185
186 tokio::time::timeout(Duration::from_secs(5), async {
187 self.db
188 .query("DELETE FROM memory_block WHERE project_id = $pid")
189 .bind(("pid", project_id.to_string()))
190 .await
191 .map_err(|e| CoreError::Internal(format!("Delete failed: {e}")))
192 })
193 .await
194 .map_err(|_| CoreError::Internal("Delete project memories timed out after 5s".into()))??;
195
196 Ok(())
197 }
198
199 pub async fn get_all_execution_logs(&self) -> Result<Vec<ExecutionLogEntry>, CoreError> {
200 self.ensure_schema().await?;
201
202 #[derive(Debug, Clone, Serialize, Deserialize)]
203 struct RawLog {
204 session_id: String,
205 phase: String,
206 action: String,
207 result: String,
208 timestamp: surrealdb::sql::Datetime,
209 }
210
211 let rows: Vec<RawLog> = tokio::time::timeout(Duration::from_secs(5), async {
212 let rows: Vec<RawLog> = self
213 .db
214 .query("SELECT session_id, phase, action, result, timestamp FROM execution_log ORDER BY timestamp ASC")
215 .await
216 .map_err(|e| CoreError::Internal(format!("Query execution logs failed: {e}")))?
217 .take(0)
218 .map_err(|e| CoreError::Internal(format!("Deserialize execution logs failed: {e}")))?;
219 Ok::<_, CoreError>(rows)
220 })
221 .await
222 .map_err(|_| CoreError::Internal("Get all execution logs timed out after 5s".into()))??;
223
224 Ok(rows.into_iter().map(|r| ExecutionLogEntry {
225 session_id: r.session_id,
226 phase: r.phase,
227 action: r.action,
228 result: r.result,
229 timestamp: r.timestamp,
230 }).collect())
231 }
232
233 pub async fn delete_all_execution_logs(&self) -> Result<(), CoreError> {
235 self.ensure_schema().await?;
236
237 tokio::time::timeout(Duration::from_secs(5), async {
238 self.db
239 .query("DELETE FROM execution_log")
240 .await
241 .map_err(|e| CoreError::Internal(format!("Delete execution logs failed: {e}")))
242 })
243 .await
244 .map_err(|_| CoreError::Internal("Delete all execution logs timed out after 5s".into()))??;
245
246 Ok(())
247 }
248
249 pub async fn get_execution_logs_by_session(
251 &self,
252 session_id: &str,
253 ) -> Result<Vec<ExecutionLogEntry>, CoreError> {
254 self.ensure_schema().await?;
255
256 #[derive(Debug, Clone, Serialize, Deserialize)]
257 struct RawLog {
258 session_id: String,
259 phase: String,
260 action: String,
261 result: String,
262 timestamp: surrealdb::sql::Datetime,
263 }
264
265 let rows: Vec<RawLog> = tokio::time::timeout(Duration::from_secs(5), async {
266 let rows: Vec<RawLog> = self
267 .db
268 .query("SELECT session_id, phase, action, result, timestamp FROM execution_log WHERE session_id = $sid ORDER BY timestamp ASC")
269 .bind(("sid", session_id.to_string()))
270 .await
271 .map_err(|e| CoreError::Internal(format!("Query session logs failed: {e}")))?
272 .take(0)
273 .map_err(|e| CoreError::Internal(format!("Deserialize session logs failed: {e}")))?;
274 Ok::<_, CoreError>(rows)
275 })
276 .await
277 .map_err(|_| CoreError::Internal("Get execution logs by session timed out after 5s".into()))??;
278
279 Ok(rows.into_iter().map(|r| ExecutionLogEntry {
280 session_id: r.session_id,
281 phase: r.phase,
282 action: r.action,
283 result: r.result,
284 timestamp: r.timestamp,
285 }).collect())
286 }
287
288 pub fn embedding_dim(&self) -> usize {
289 self.embedding_dim
290 }
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
295pub struct ExecutionLogEntry {
296 pub session_id: String,
297 pub phase: String,
298 pub action: String,
299 pub result: String,
300 pub timestamp: surrealdb::sql::Datetime,
301}