1use std::path::Path;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::Value as JsonValue;
7use surrealdb::Surreal;
8use surrealdb::engine::local::{Db, SurrealKv};
9use surrealdb::sql::Thing;
10
11use crate::error::{LLMBrainError, Result};
12use crate::models::MemoryFragment;
13
14#[async_trait]
16pub trait Store: Send + Sync {
17 async fn add_memory(&self, fragment: MemoryFragment) -> Result<Thing>;
20
21 async fn get_memory(&self, id: Thing) -> Result<Option<MemoryFragment>>;
23
24 async fn find_similar(
27 &self, embedding: &[f32], top_k: usize,
28 ) -> Result<Vec<(MemoryFragment, f32)>>;
29
30 async fn record_access(&self, id: Thing) -> Result<()>;
32}
33
34#[derive(Clone)]
36pub struct SurrealStore {
37 db: Arc<Surreal<Db>>,
38}
39
40impl SurrealStore {
41 pub async fn connect(
45 db_path: impl AsRef<Path>, namespace: &str, database: &str,
46 ) -> Result<Self> {
47 if let Some(parent) = db_path.as_ref().parent() {
49 tokio::fs::create_dir_all(parent).await.map_err(|e| {
50 LLMBrainError::InitializationError(format!(
51 "Failed to create database directory '{}': {}",
52 parent.display(),
53 e
54 ))
55 })?;
56 }
57
58 let db = Surreal::new::<SurrealKv>(db_path.as_ref()).await?;
60
61 db.use_ns(namespace).use_db(database).await?;
63
64 Self::initialize_database(&db).await?;
66
67 Ok(Self { db: Arc::new(db) })
68 }
69
70 async fn initialize_database(db: &Surreal<Db>) -> Result<()> {
72 let _ = db
74 .query(
75 r#"
76 DEFINE TABLE memory SCHEMAFULL;
77 DEFINE FIELD content ON TABLE memory TYPE string;
78 DEFINE FIELD embedding ON TABLE memory TYPE array<float>;
79 DEFINE FIELD metadata ON TABLE memory TYPE object;
80 DEFINE FIELD created_at ON TABLE memory TYPE datetime DEFAULT time::now();
81 DEFINE FIELD last_accessed_at ON TABLE memory TYPE option<datetime>;
82 "#,
83 )
84 .await?;
85
86 let _ = db
88 .query(
89 r#"
90 DEFINE INDEX memory_embedding_index ON TABLE memory FIELDS embedding
91 MTREE DIMENSION 1536 DISTANCE COSINE;
92 "#,
93 )
94 .await?;
95
96 let _ = db
98 .query(
99 r#"
100 DEFINE INDEX memory_created_at_index ON TABLE memory FIELDS created_at;
101 "#,
102 )
103 .await?;
104
105 Ok(())
106 }
107}
108
109#[async_trait]
110impl Store for SurrealStore {
111 async fn add_memory(&self, fragment: MemoryFragment) -> Result<Thing> {
112 #[derive(Deserialize, Debug)]
113 struct CreatedId {
114 id: Thing,
115 }
116
117 let mut response = self
120 .db
121 .query(
122 "CREATE memory SET content = $content, embedding = $embedding, metadata = $metadata RETURN id"
123 )
124 .bind(("content", fragment.content.clone())) .bind(("embedding", fragment.embedding.clone())) .bind(("metadata", fragment.metadata.clone())) .await?;
129
130 let created: Option<CreatedId> = response.take(0)?;
131 match created {
132 Some(created) => Ok(created.id),
133 None => {
134 Err(LLMBrainError::DatabaseError(Box::new(
135 surrealdb::Error::Api(surrealdb::error::Api::Query(
136 "No record returned after creation".into(),
137 )),
138 )))
139 }
140 }
141 }
142
143 async fn get_memory(&self, id: Thing) -> Result<Option<MemoryFragment>> {
144 let mut response = self
146 .db
147 .query("SELECT * FROM $id")
148 .bind(("id", id)) .await?;
150
151 let memory: Option<MemoryFragment> = response.take(0)?;
152 Ok(memory)
153 }
154
155 async fn find_similar(
156 &self, embedding: &[f32], top_k: usize,
157 ) -> Result<Vec<(MemoryFragment, f32)>> {
158 let query = "SELECT
160 content,
161 embedding,
162 metadata,
163 created_at,
164 last_accessed_at,
165 vector::similarity::cosine(embedding, $embedding) as similarity
166 FROM memory
167 ORDER BY similarity DESC
168 LIMIT $limit";
169
170 #[derive(Deserialize, Debug)]
172 struct QueryResult {
173 content: String,
174 embedding: Vec<f32>,
175 metadata: JsonValue,
176 created_at: JsonValue, last_accessed_at: JsonValue, similarity: f32,
179 }
180
181 let embedding_vec = embedding.to_vec();
182 let mut response = self
183 .db
184 .query(query)
185 .bind(("embedding", embedding_vec))
186 .bind(("limit", top_k))
187 .await?;
188
189 let results: Vec<QueryResult> = response.take(0)?;
190
191 let scored_fragments = results
193 .into_iter()
194 .map(|res| {
195 let fragment = MemoryFragment {
196 content: res.content,
197 embedding: res.embedding,
198 metadata: res.metadata,
199 created_at: res.created_at,
200 last_accessed_at: res.last_accessed_at,
201 };
202 (fragment, res.similarity.clamp(0.0, 1.0))
203 })
204 .collect();
205
206 Ok(scored_fragments)
207 }
208
209 async fn record_access(&self, id: Thing) -> Result<()> {
210 let _ = self
212 .db
213 .query("UPDATE $id SET last_accessed_at = time::now()")
214 .bind(("id", id)) .await?;
216 Ok(())
217 }
218}