llm_brain/
db.rs

1use std::path::Path;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::Value as JsonValue;
7use surrealdb::Surreal;
8use surrealdb::engine::local::{Db, SurrealKv};
9use surrealdb::sql::Thing;
10
11use crate::error::{LLMBrainError, Result};
12use crate::models::MemoryFragment;
13
14/// Trait defining the interface for the memory storage layer.
15#[async_trait]
16pub trait Store: Send + Sync {
17    /// Adds a new memory fragment to the store.
18    /// Returns the ID (Thing) of the newly created fragment.
19    async fn add_memory(&self, fragment: MemoryFragment) -> Result<Thing>;
20
21    /// Retrieves a memory fragment by its unique ID.
22    async fn get_memory(&self, id: Thing) -> Result<Option<MemoryFragment>>;
23
24    /// Finds the top_k most similar memory fragments to the given embedding.
25    /// Returns a list of (MemoryFragment, similarity_score) tuples.
26    async fn find_similar(
27        &self, embedding: &[f32], top_k: usize,
28    ) -> Result<Vec<(MemoryFragment, f32)>>;
29
30    /// Updates the last_accessed_at timestamp for a given memory fragment.
31    async fn record_access(&self, id: Thing) -> Result<()>;
32}
33
34/// Concrete implementation of the Store trait using `SurrealDB`.
35#[derive(Clone)]
36pub struct SurrealStore {
37    db: Arc<Surreal<Db>>,
38}
39
40impl SurrealStore {
41    /// Establishes a connection to the embedded `SurrealDB` (`SurrealKV`)
42    /// database at the specified path and selects the namespace and
43    /// database.
44    pub async fn connect(
45        db_path: impl AsRef<Path>, namespace: &str, database: &str,
46    ) -> Result<Self> {
47        // Ensure the parent directory exists
48        if let Some(parent) = db_path.as_ref().parent() {
49            tokio::fs::create_dir_all(parent).await.map_err(|e| {
50                LLMBrainError::InitializationError(format!(
51                    "Failed to create database directory '{}': {}",
52                    parent.display(),
53                    e
54                ))
55            })?;
56        }
57
58        // Configure and open the SurrealKV engine
59        let db = Surreal::new::<SurrealKv>(db_path.as_ref()).await?;
60
61        // Select the namespace and database
62        db.use_ns(namespace).use_db(database).await?;
63
64        // Perform any necessary initialization
65        Self::initialize_database(&db).await?;
66
67        Ok(Self { db: Arc::new(db) })
68    }
69
70    /// Performs initial database setup like defining tables and vector indexes.
71    async fn initialize_database(db: &Surreal<Db>) -> Result<()> {
72        // Define the 'memory' table schema
73        let _ = db
74            .query(
75                r#"
76            DEFINE TABLE memory SCHEMAFULL;
77            DEFINE FIELD content ON TABLE memory TYPE string;
78            DEFINE FIELD embedding ON TABLE memory TYPE array<float>;
79            DEFINE FIELD metadata ON TABLE memory TYPE object;
80            DEFINE FIELD created_at ON TABLE memory TYPE datetime DEFAULT time::now();
81            DEFINE FIELD last_accessed_at ON TABLE memory TYPE option<datetime>;
82            "#,
83            )
84            .await?;
85
86        // Define a vector index on the embedding field for efficient similarity search
87        let _ = db
88            .query(
89                r#"
90            DEFINE INDEX memory_embedding_index ON TABLE memory FIELDS embedding
91                MTREE DIMENSION 1536 DISTANCE COSINE;
92            "#,
93            )
94            .await?;
95
96        // Define an index on created_at for potential time-based queries
97        let _ = db
98            .query(
99                r#"
100            DEFINE INDEX memory_created_at_index ON TABLE memory FIELDS created_at;
101            "#,
102            )
103            .await?;
104
105        Ok(())
106    }
107}
108
109#[async_trait]
110impl Store for SurrealStore {
111    async fn add_memory(&self, fragment: MemoryFragment) -> Result<Thing> {
112        #[derive(Deserialize, Debug)]
113        struct CreatedId {
114            id: Thing,
115        }
116
117        // Rely on database default for created_at and initial last_accessed_at.
118        // Bind CLONES of content, embedding, and metadata.
119        let mut response = self
120            .db
121            .query(
122                "CREATE memory SET content = $content, embedding = $embedding, metadata = $metadata RETURN id"
123            )
124            .bind(("content", fragment.content.clone())) // Clone
125            .bind(("embedding", fragment.embedding.clone())) // Clone
126            .bind(("metadata", fragment.metadata.clone())) // Clone
127            // Do not bind created_at or last_accessed_at
128            .await?;
129
130        let created: Option<CreatedId> = response.take(0)?;
131        match created {
132            Some(created) => Ok(created.id),
133            None => {
134                Err(LLMBrainError::DatabaseError(Box::new(
135                    surrealdb::Error::Api(surrealdb::error::Api::Query(
136                        "No record returned after creation".into(),
137                    )),
138                )))
139            }
140        }
141    }
142
143    async fn get_memory(&self, id: Thing) -> Result<Option<MemoryFragment>> {
144        // Bind the owned Thing directly
145        let mut response = self
146            .db
147            .query("SELECT * FROM $id")
148            .bind(("id", id)) // Bind owned Thing
149            .await?;
150
151        let memory: Option<MemoryFragment> = response.take(0)?;
152        Ok(memory)
153    }
154
155    async fn find_similar(
156        &self, embedding: &[f32], top_k: usize,
157    ) -> Result<Vec<(MemoryFragment, f32)>> {
158        // Explicitly select fields and calculate similarity
159        let query = "SELECT
160                        content,
161                        embedding,
162                        metadata,
163                        created_at,
164                        last_accessed_at,
165                        vector::similarity::cosine(embedding, $embedding) as similarity
166                    FROM memory
167                    ORDER BY similarity DESC
168                    LIMIT $limit";
169
170        // Define QueryResult matching the explicit SELECT list
171        #[derive(Deserialize, Debug)]
172        struct QueryResult {
173            content: String,
174            embedding: Vec<f32>,
175            metadata: JsonValue,
176            created_at: JsonValue,       // Expect JsonValue from DB datetime
177            last_accessed_at: JsonValue, // Expect JsonValue from DB option<datetime>
178            similarity: f32,
179        }
180
181        let embedding_vec = embedding.to_vec();
182        let mut response = self
183            .db
184            .query(query)
185            .bind(("embedding", embedding_vec))
186            .bind(("limit", top_k))
187            .await?;
188
189        let results: Vec<QueryResult> = response.take(0)?;
190
191        // Manually reconstruct MemoryFragment from QueryResult fields
192        let scored_fragments = results
193            .into_iter()
194            .map(|res| {
195                let fragment = MemoryFragment {
196                    content: res.content,
197                    embedding: res.embedding,
198                    metadata: res.metadata,
199                    created_at: res.created_at,
200                    last_accessed_at: res.last_accessed_at,
201                };
202                (fragment, res.similarity.clamp(0.0, 1.0))
203            })
204            .collect();
205
206        Ok(scored_fragments)
207    }
208
209    async fn record_access(&self, id: Thing) -> Result<()> {
210        // Bind the owned Thing directly
211        let _ = self
212            .db
213            .query("UPDATE $id SET last_accessed_at = time::now()")
214            .bind(("id", id)) // Bind owned Thing
215            .await?;
216        Ok(())
217    }
218}