1pub mod conceptnet;
8pub mod config;
9pub mod db;
10pub mod embeddings;
11pub mod error;
12pub mod llm;
13pub mod models;
14
15use std::str::FromStr;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use conceptnet::ConceptNetClient;
21pub use config::Config;
22pub use db::{Store, SurrealStore};
23pub use embeddings::{
24 ChunkingStrategy, EmbeddingMiddleware, EmbeddingModelConfig, EmbeddingProvider, LongTextHandler,
25};
26pub use error::{LLMBrainError, Result};
27use llm::OpenAiClient;
28pub use models::MemoryFragment;
29
30#[derive(Clone)]
35pub struct LLMBrain {
36 _config: Arc<Config>,
37 db: Arc<dyn Store>,
38 embedding_client: Arc<dyn EmbeddingProvider>,
39 _conceptnet_client: Option<Arc<ConceptNetClient>>,
40}
41
42impl LLMBrain {
43 pub async fn launch() -> Result<Self> {
52 Config::load()?;
54 let config = Config::get();
55
56 let db_store = SurrealStore::connect(
58 &config.database.path,
59 &config.database.namespace,
60 &config.database.database,
61 )
62 .await?;
63
64 let llm_client = OpenAiClient::new(config.llm.as_ref())?;
66
67 let embedding_client = EmbeddingMiddleware::new(llm_client, true);
69
70 EmbeddingMiddleware::<OpenAiClient>::initialize_tokenizer()?;
72
73 let conceptnet_client = config
75 .conceptnet
76 .as_ref()
77 .map(|cn_config| -> Result<_> { ConceptNetClient::new(Some(cn_config)) })
78 .transpose()?
79 .map(Arc::new);
80
81 Ok(Self {
82 _config: Arc::new(config.clone()),
83 db: Arc::new(db_store),
84 embedding_client: Arc::new(embedding_client),
85 _conceptnet_client: conceptnet_client,
86 })
87 }
88
89 pub async fn add_memory(
98 &self, content: String, metadata: serde_json::Value,
99 ) -> Result<surrealdb::sql::Thing> {
100 let embedding = self.embedding_client.generate_embedding(&content).await?;
102
103 let fragment = MemoryFragment::new(content, embedding, metadata);
105
106 self.db.add_memory(fragment).await
108 }
109
110 pub async fn recall(
118 &self, query_text: &str, top_k: usize,
119 ) -> Result<Vec<(MemoryFragment, f32)>> {
120 let query_embedding = self.embedding_client.generate_embedding(query_text).await?;
122
123 self.db.find_similar(&query_embedding, top_k).await
125 }
126
127 pub async fn process_long_text(
129 &self, text: &str, strategy: Option<ChunkingStrategy>,
130 ) -> Result<Vec<f32>> {
131 let strategy = strategy.unwrap_or_default();
133
134 let embedding_provider =
138 Box::new(CloneableEmbeddingProvider(self.embedding_client.clone()));
139 let model_config = EmbeddingModelConfig::default();
140 let handler = LongTextHandler::new(embedding_provider, model_config, strategy);
141
142 handler.process_embeddings(text).await
143 }
144
145 pub async fn get_memory_by_id_string(&self, id_string: &str) -> Result<Option<MemoryFragment>> {
149 let thing_id = surrealdb::sql::Thing::from_str(id_string)
151 .map_err(|_| LLMBrainError::NotFoundError(format!("Invalid ID format: {id_string}")))?;
152 self.db.get_memory(thing_id.clone()).await
154 }
155}
156
157struct CloneableEmbeddingProvider(Arc<dyn EmbeddingProvider>);
159
160unsafe impl Send for CloneableEmbeddingProvider {}
164unsafe impl Sync for CloneableEmbeddingProvider {}
165
166impl Clone for CloneableEmbeddingProvider {
167 fn clone(&self) -> Self {
168 Self(self.0.clone())
169 }
170}
171
172#[async_trait]
173impl EmbeddingProvider for CloneableEmbeddingProvider {
174 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
175 self.0.generate_embedding(text).await
176 }
177
178 async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
179 self.0.generate_embeddings(texts).await
180 }
181
182 fn count_tokens(&self, text: &str) -> Result<usize> {
183 self.0.count_tokens(text)
184 }
185}
186
187#[async_trait]
189impl EmbeddingProvider for Box<CloneableEmbeddingProvider> {
190 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
191 (**self).generate_embedding(text).await
192 }
193
194 async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
195 (**self).generate_embeddings(texts).await
196 }
197
198 fn count_tokens(&self, text: &str) -> Result<usize> {
199 (**self).count_tokens(text)
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: `LLMBrain::launch` succeeds against a minimal on-disk
    /// configuration.
    ///
    /// Ignored by default because it touches the filesystem of the current
    /// working directory and needs a working LLM configuration.
    #[tokio::test]
    #[ignore]
    async fn test_launch_and_load_config() {
        let config_dir = std::path::Path::new("config");
        let config_file = config_dir.join("default.toml");
        let db_path = std::path::Path::new("./llm_brain_test.db");

        // Remember what this test creates so cleanup never deletes a
        // pre-existing configuration.
        let created_config_dir = !config_dir.exists();
        if created_config_dir {
            std::fs::create_dir(config_dir).unwrap();
        }
        let created_config_file = !config_file.exists();
        if created_config_file {
            std::fs::write(
                &config_file,
                r#"
[database]
path = "./llm_brain_test.db"
namespace = "test_ns"
database = "test_db"

[llm]
embedding_model = "text-embedding-3-small"
"#,
            )
            .unwrap();
        }

        let llm_brain_instance = LLMBrain::launch().await;
        let launched = llm_brain_instance.is_ok();
        // Release the store before removing its files.
        drop(llm_brain_instance);

        // Best-effort cleanup, performed before asserting so that a failed
        // launch does not leave test artifacts behind.
        if created_config_file {
            std::fs::remove_file(&config_file).unwrap_or_default();
        }
        if created_config_dir {
            // `remove_dir` only succeeds on an empty directory, so unrelated
            // files are never lost.
            std::fs::remove_dir(config_dir).unwrap_or_default();
        }

        // Remove only the database path itself. Its parent is `.`, and
        // recursively removing that would wipe the working directory.
        if db_path.is_dir() {
            tokio::fs::remove_dir_all(db_path).await.unwrap_or_default();
        } else {
            tokio::fs::remove_file(db_path).await.unwrap_or_default();
        }

        assert!(launched);
    }
}