// llm_brain/lib.rs

1//! `LLMBrain`: A memory layer for LLM applications
2//!
3//! This library provides a vector database storage layer for semantic
4//! memory with `ConceptNet` integration and LLM-friendly interfaces.
5
6// Declare modules
7pub mod conceptnet;
8pub mod config;
9pub mod db;
10pub mod embeddings;
11pub mod error;
12pub mod llm;
13pub mod models;
14
15// Re-export core types for easier access
16use std::str::FromStr;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use conceptnet::ConceptNetClient;
21pub use config::Config;
22pub use db::{Store, SurrealStore};
23pub use embeddings::{
24    ChunkingStrategy, EmbeddingMiddleware, EmbeddingModelConfig, EmbeddingProvider, LongTextHandler,
25};
26pub use error::{LLMBrainError, Result};
27use llm::OpenAiClient;
28pub use models::MemoryFragment;
29
/// Core library struct holding initialized clients and configuration.
///
/// This is the main entry point for interacting with the `LLMBrain` library.
/// It manages database connections, LLM client, and `ConceptNet` integration.
#[derive(Clone)]
pub struct LLMBrain {
    // Loaded application configuration; kept alive alongside the clients
    // built from it (leading underscore: not read after construction).
    _config: Arc<Config>,
    // Backing vector store (SurrealDB behind the `Store` trait object).
    db: Arc<dyn Store>,
    // Embedding generator — the OpenAI client wrapped in `EmbeddingMiddleware`
    // (see `launch`), exposed through the `EmbeddingProvider` trait.
    embedding_client: Arc<dyn EmbeddingProvider>,
    // Optional ConceptNet client; `None` when no `[conceptnet]` config exists.
    _conceptnet_client: Option<Arc<ConceptNetClient>>,
}
41
42impl LLMBrain {
43    /// Loads configuration, initializes clients, and returns an instance of
44    /// `LLMBrain`.
45    ///
46    /// This should be called once at application startup. It performs:
47    /// - Loading and parsing configuration
48    /// - Setting up the `SurrealDB` connection
49    /// - Initializing LLM client
50    /// - Initializing `ConceptNet` client (if configured)
51    pub async fn launch() -> Result<Self> {
52        // Load configuration first
53        Config::load()?;
54        let config = Config::get();
55
56        // Initialize database store
57        let db_store = SurrealStore::connect(
58            &config.database.path,
59            &config.database.namespace,
60            &config.database.database,
61        )
62        .await?;
63
64        // Initialize LLM client
65        let llm_client = OpenAiClient::new(config.llm.as_ref())?;
66
67        // Use embedding middleware to wrap LLM client for vector normalization
68        let embedding_client = EmbeddingMiddleware::new(llm_client, true);
69
70        // Initialize tokenizer
71        EmbeddingMiddleware::<OpenAiClient>::initialize_tokenizer()?;
72
73        // Initialize ConceptNet client (optional)
74        let conceptnet_client = config
75            .conceptnet
76            .as_ref()
77            .map(|cn_config| -> Result<_> { ConceptNetClient::new(Some(cn_config)) })
78            .transpose()?
79            .map(Arc::new);
80
81        Ok(Self {
82            _config: Arc::new(config.clone()),
83            db: Arc::new(db_store),
84            embedding_client: Arc::new(embedding_client),
85            _conceptnet_client: conceptnet_client,
86        })
87    }
88
89    /// Adds a text fragment to the memory store.
90    ///
91    /// This method:
92    /// 1. Generates an embedding for the content using the LLM client
93    /// 2. Creates a `MemoryFragment` with the content, embedding, and metadata
94    /// 3. Stores it in the database
95    ///
96    /// Returns the `SurrealDB` ID of the created record.
97    pub async fn add_memory(
98        &self, content: String, metadata: serde_json::Value,
99    ) -> Result<surrealdb::sql::Thing> {
100        // Generate embeddings - now using embedding middleware
101        let embedding = self.embedding_client.generate_embedding(&content).await?;
102
103        // Create MemoryFragment
104        let fragment = MemoryFragment::new(content, embedding, metadata);
105
106        // Add to database
107        self.db.add_memory(fragment).await
108    }
109
110    /// Retrieves similar memories based on a text query.
111    ///
112    /// This method:
113    /// 1. Generates an embedding for the query text
114    /// 2. Finds memory fragments with similar embeddings
115    ///
116    /// Returns a vec of (`MemoryFragment`, `similarity_score`) tuples.
117    pub async fn recall(
118        &self, query_text: &str, top_k: usize,
119    ) -> Result<Vec<(MemoryFragment, f32)>> {
120        // Generate query text embeddings - now using embedding middleware
121        let query_embedding = self.embedding_client.generate_embedding(query_text).await?;
122
123        // Find similar fragments in the database
124        self.db.find_similar(&query_embedding, top_k).await
125    }
126
127    /// Process long text using appropriate chunking strategy
128    pub async fn process_long_text(
129        &self, text: &str, strategy: Option<ChunkingStrategy>,
130    ) -> Result<Vec<f32>> {
131        // Use default strategy or provided strategy
132        let strategy = strategy.unwrap_or_default();
133
134        // Create long text handler
135        // We need to wrap the embedding client as Box<dyn EmbeddingProvider> to avoid
136        // trait bounds errors
137        let embedding_provider =
138            Box::new(CloneableEmbeddingProvider(self.embedding_client.clone()));
139        let model_config = EmbeddingModelConfig::default();
140        let handler = LongTextHandler::new(embedding_provider, model_config, strategy);
141
142        handler.process_embeddings(text).await
143    }
144
145    /// Retrieves a specific memory by its `SurrealDB` ID string.
146    ///
147    /// Returns None if no memory with the given ID exists.
148    pub async fn get_memory_by_id_string(&self, id_string: &str) -> Result<Option<MemoryFragment>> {
149        // Parse Thing ID
150        let thing_id = surrealdb::sql::Thing::from_str(id_string)
151            .map_err(|_| LLMBrainError::NotFoundError(format!("Invalid ID format: {id_string}")))?;
152        // Clone the Thing ID when calling get_memory
153        self.db.get_memory(thing_id.clone()).await
154    }
155}
156
157// Add this helper struct to support cloning Arc<dyn EmbeddingProvider>
158struct CloneableEmbeddingProvider(Arc<dyn EmbeddingProvider>);
159
160// Ensure CloneableEmbeddingProvider can be safely passed between threads
161// Since it contains Arc<dyn EmbeddingProvider>, and EmbeddingProvider requires
162// Send+Sync, this is safe
163unsafe impl Send for CloneableEmbeddingProvider {}
164unsafe impl Sync for CloneableEmbeddingProvider {}
165
166impl Clone for CloneableEmbeddingProvider {
167    fn clone(&self) -> Self {
168        Self(self.0.clone())
169    }
170}
171
172#[async_trait]
173impl EmbeddingProvider for CloneableEmbeddingProvider {
174    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
175        self.0.generate_embedding(text).await
176    }
177
178    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
179        self.0.generate_embeddings(texts).await
180    }
181
182    fn count_tokens(&self, text: &str) -> Result<usize> {
183        self.0.count_tokens(text)
184    }
185}
186
187// Implement EmbeddingProvider trait for Box<CloneableEmbeddingProvider>
188#[async_trait]
189impl EmbeddingProvider for Box<CloneableEmbeddingProvider> {
190    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
191        (**self).generate_embedding(text).await
192    }
193
194    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
195        (**self).generate_embeddings(texts).await
196    }
197
198    fn count_tokens(&self, text: &str) -> Result<usize> {
199        (**self).count_tokens(text)
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore] // Ignore by default as it requires setup
    async fn test_launch_and_load_config() {
        // Create dummy config files if they don't exist
        if !std::path::Path::new("config").exists() {
            std::fs::create_dir("config").unwrap();
        }
        if !std::path::Path::new("config/default.toml").exists() {
            std::fs::write(
                "config/default.toml",
                r#"
[database]
path = "./llm_brain_test.db"
namespace = "test_ns"
database = "test_db"

[llm]
embedding_model = "text-embedding-3-small"
            "#,
            )
            .unwrap();
        }

        let llm_brain_instance = LLMBrain::launch().await;

        // Clean up BEFORE asserting so a failed assertion cannot skip the
        // cleanup and leak fixture files into the working directory.
        std::fs::remove_file("config/default.toml").unwrap_or_default();
        std::fs::remove_dir("config").unwrap_or_default();

        // SurrealKV may create the db path as either a file or a directory;
        // remove exactly that path. The previous code called
        // `remove_dir_all` on the path's PARENT — for "./llm_brain_test.db"
        // the parent is ".", so it recursively deleted the entire current
        // working directory.
        let db_path = std::path::Path::new("./llm_brain_test.db");
        if db_path.is_dir() {
            tokio::fs::remove_dir_all(db_path).await.unwrap_or_default();
        } else {
            tokio::fs::remove_file(db_path).await.unwrap_or_default();
        }

        assert!(llm_brain_instance.is_ok());
    }
}