1use crate::context::ContextBuilder;
4use crate::core::{Config, MemoryError, MemoryId, MemoryResult};
5use crate::embedding::EmbeddingGenerator;
6use crate::ingest::{MemoryProcessor, ProcessorConfig, ProcessingResult};
7use crate::models::{Conversation, Fact, ImageMetadata, Memory, TranscriptionMetadata, UserModel, UserRole};
8use crate::retrieval::{MemorySearcher, SearchOptions, SearchResult};
9use crate::storage::{PostgresStorage, RedisCache, StoragePool};
10use crate::user::{Permission, PermissionChecker, UserManager};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
15pub struct MemoryEngine {
17 config: Config,
18 storage_pool: Arc<StoragePool>,
19 postgres: Arc<PostgresStorage>,
20 redis: Arc<RedisCache>,
21 processor: Arc<MemoryProcessor>,
22 searcher: Arc<MemorySearcher>,
23 embedding_generator: Arc<EmbeddingGenerator>,
24 user_manager: Arc<UserManager>,
25 context_builder: Arc<ContextBuilder>,
26 initialized: Arc<RwLock<bool>>,
27}
28
29impl MemoryEngine {
30 pub async fn new(config: Config) -> MemoryResult<Self> {
32 config
33 .validate()
34 .map_err(|e| MemoryError::Config(config::ConfigError::Message(e)))?;
35
36 let storage_pool = Arc::new(StoragePool::new(&config).await?);
37 let postgres = Arc::new(PostgresStorage::new(storage_pool.postgres.clone()));
38 let redis = Arc::new(RedisCache::new(storage_pool.redis.clone()));
39
40 let embedding_generator = Arc::new(
42 EmbeddingGenerator::with_model(&config.embedding_model_path, config.embedding_dimension)
43 .unwrap_or_else(|e| {
44 warn!(error = %e, "Failed to load ONNX model, using fallback embedding generator");
45 EmbeddingGenerator::new().with_cache_size(config.cache_size)
46 })
47 );
48
49 let _processor_config = ProcessorConfig {
51 generate_embeddings: true,
52 extract_facts: true,
53 min_fact_confidence: 0.5,
54 ephemeral_fact_ttl_days: config.ephemeral_fact_ttl_days,
55 max_content_length: 100_000,
56 };
57 let processor = Arc::new(MemoryProcessor::with_embedding_generator(embedding_generator.clone()));
58
59 let searcher = Arc::new(
61 MemorySearcher::with_embedding_generator(embedding_generator.clone())
62 .with_threshold(config.similarity_threshold)
63 .with_limit(config.search_top_k)
64 );
65
66 let context_builder = Arc::new(ContextBuilder::new(config.max_context_tokens));
67
68 Ok(Self {
69 config,
70 storage_pool,
71 postgres,
72 redis,
73 processor,
74 searcher,
75 embedding_generator,
76 user_manager: Arc::new(UserManager::new()),
77 context_builder,
78 initialized: Arc::new(RwLock::new(false)),
79 })
80 }
81
82 pub async fn initialize(&self) -> MemoryResult<()> {
84 info!("Initializing memory engine...");
85
86 self.storage_pool.health_check().await?;
88 debug!("PostgreSQL connection verified");
89
90 match self.redis.ping().await {
92 Ok(true) => debug!("Redis connection verified"),
93 Ok(false) => warn!("Redis ping returned unexpected response"),
94 Err(e) => warn!(error = %e, "Redis connection check failed, cache may not work"),
95 }
96
97 let mut init = self.initialized.write().await;
98 *init = true;
99
100 info!("Memory engine initialized successfully");
101 Ok(())
102 }
103
104 pub async fn is_initialized(&self) -> bool {
106 *self.initialized.read().await
107 }
108
109 async fn ensure_initialized(&self) -> MemoryResult<()> {
111 if !self.is_initialized().await {
112 return Err(MemoryError::Generic(
113 "Engine not initialized. Call initialize() first.".to_string(),
114 ));
115 }
116 Ok(())
117 }
118
119 pub async fn create_user(
123 &self,
124 username: String,
125 email: String,
126 role: UserRole,
127 ) -> MemoryResult<UserModel> {
128 crate::utils::Validator::validate_username(&username)?;
130 crate::utils::Validator::validate_email(&email)?;
131
132 let user = self.user_manager.create_user(username, email, role).await?;
134
135 self.postgres.save_user(&user).await?;
137
138 self.log_audit(
140 Some(&user.id),
141 "user_created",
142 "user",
143 Some(&user.id),
144 None,
145 ).await;
146
147 info!(username = %user.username, role = ?user.role, "User created");
148 Ok(user)
149 }
150
151 pub async fn get_user(&self, user_id: &MemoryId) -> MemoryResult<Option<UserModel>> {
153 self.postgres.get_user_by_id(user_id).await
154 }
155
156 pub async fn get_user_by_username(&self, username: &str) -> MemoryResult<Option<UserModel>> {
158 self.postgres.get_user_by_username(username).await
159 }
160
161 pub fn check_permission(&self, user: &UserModel, permission: Permission) -> bool {
163 PermissionChecker::can(user, permission)
164 }
165
166 pub async fn store_memory(&self, user_id: &str, memory: Memory) -> MemoryResult<ProcessingResult> {
170 self.ensure_initialized().await?;
171
172 let user_id: MemoryId = user_id
173 .parse()
174 .map_err(|_| MemoryError::InvalidData("Invalid user ID format".to_string()))?;
175
176 let result = self.processor
178 .process_and_store(&self.postgres, &user_id, memory.clone())
179 .await?;
180
181 if let Err(e) = self.redis.invalidate_user_cache(&user_id.to_string()).await {
183 warn!(error = %e, "Failed to invalidate user cache");
184 }
185
186 self.log_audit(
188 Some(&user_id),
189 "memory_created",
190 memory.memory_type().as_str(),
191 Some(&result.memory_id),
192 None,
193 ).await;
194
195 Ok(result)
196 }
197
198 pub async fn store_conversation(
200 &self,
201 user_id: &MemoryId,
202 conversation: Conversation,
203 ) -> MemoryResult<ProcessingResult> {
204 self.ensure_initialized().await?;
205
206 let memory = Memory::Conversation(conversation);
207 self.processor
208 .process_and_store(&self.postgres, user_id, memory)
209 .await
210 }
211
212 pub async fn store_fact(&self, user_id: &MemoryId, fact: Fact) -> MemoryResult<ProcessingResult> {
214 self.ensure_initialized().await?;
215
216 let memory = Memory::Fact(fact);
217 self.processor
218 .process_and_store(&self.postgres, user_id, memory)
219 .await
220 }
221
222 pub async fn store_image(
224 &self,
225 user_id: &MemoryId,
226 image: ImageMetadata,
227 ) -> MemoryResult<ProcessingResult> {
228 self.ensure_initialized().await?;
229
230 self.processor
231 .process_image(&self.postgres, user_id, image)
232 .await
233 }
234
235 pub async fn store_transcription(
237 &self,
238 user_id: &MemoryId,
239 transcription: TranscriptionMetadata,
240 ) -> MemoryResult<ProcessingResult> {
241 self.ensure_initialized().await?;
242
243 self.processor
244 .process_transcription(&self.postgres, user_id, transcription)
245 .await
246 }
247
248 pub async fn search(
252 &self,
253 user_id: &str,
254 query: &str,
255 limit: usize,
256 ) -> MemoryResult<Vec<Memory>> {
257 self.ensure_initialized().await?;
258
259 let user_id: MemoryId = user_id
260 .parse()
261 .map_err(|_| MemoryError::InvalidData("Invalid user ID format".to_string()))?;
262
263 let options = SearchOptions::default()
264 .with_limit(limit)
265 .with_threshold(self.config.similarity_threshold);
266
267 let results = self.searcher
268 .search_with_storage(&self.postgres, &user_id, query, options)
269 .await?;
270
271 Ok(results.into_iter().map(|r| r.memory).collect())
272 }
273
274 pub async fn search_with_options(
276 &self,
277 user_id: &MemoryId,
278 query: &str,
279 options: SearchOptions,
280 ) -> MemoryResult<Vec<SearchResult>> {
281 self.ensure_initialized().await?;
282
283 self.searcher
284 .search_with_storage(&self.postgres, user_id, query, options)
285 .await
286 }
287
288 pub async fn hybrid_search(
290 &self,
291 user_id: &MemoryId,
292 query: &str,
293 options: SearchOptions,
294 semantic_weight: f32,
295 ) -> MemoryResult<Vec<SearchResult>> {
296 self.ensure_initialized().await?;
297
298 self.searcher
299 .hybrid_search(&self.postgres, user_id, query, options, semantic_weight)
300 .await
301 }
302
303 pub async fn find_similar(
305 &self,
306 user_id: &MemoryId,
307 memory: &Memory,
308 limit: usize,
309 ) -> MemoryResult<Vec<SearchResult>> {
310 self.ensure_initialized().await?;
311
312 self.searcher
313 .find_similar(&self.postgres, user_id, memory, limit)
314 .await
315 }
316
317 pub async fn get_conversation(&self, id: &MemoryId) -> MemoryResult<Option<Conversation>> {
321 self.postgres.get_conversation(id).await
322 }
323
324 pub async fn list_conversations(
326 &self,
327 user_id: &MemoryId,
328 limit: i64,
329 offset: i64,
330 ) -> MemoryResult<Vec<Conversation>> {
331 self.postgres.list_conversations(user_id, limit, offset).await
332 }
333
334 pub async fn delete_conversation(&self, id: &MemoryId) -> MemoryResult<bool> {
336 let result = self.postgres.delete_conversation(id).await?;
337
338 self.log_audit(
339 None,
340 "memory_deleted",
341 "conversation",
342 Some(id),
343 None,
344 ).await;
345
346 Ok(result)
347 }
348
349 pub async fn get_fact(&self, id: &MemoryId) -> MemoryResult<Option<Fact>> {
353 self.postgres.get_fact(id).await
354 }
355
356 pub async fn search_facts(
358 &self,
359 user_id: &MemoryId,
360 query: &str,
361 limit: i64,
362 ) -> MemoryResult<Vec<Fact>> {
363 self.postgres.search_facts_by_content(user_id, query, limit).await
364 }
365
366 pub async fn cleanup_expired_facts(&self) -> MemoryResult<u64> {
368 let count = self.postgres.delete_expired_facts().await?;
369 info!(count = count, "Expired facts cleaned up");
370 Ok(count)
371 }
372
373 pub async fn get_image(&self, id: &MemoryId) -> MemoryResult<Option<ImageMetadata>> {
377 self.postgres.get_image(id).await
378 }
379
380 pub async fn list_images(
382 &self,
383 user_id: &MemoryId,
384 limit: i64,
385 offset: i64,
386 ) -> MemoryResult<Vec<ImageMetadata>> {
387 self.postgres.list_images(user_id, limit, offset).await
388 }
389
390 pub async fn get_transcription(&self, id: &MemoryId) -> MemoryResult<Option<TranscriptionMetadata>> {
394 self.postgres.get_transcription(id).await
395 }
396
397 pub async fn build_context(
401 &self,
402 user_id: &MemoryId,
403 query: &str,
404 limit: usize,
405 ) -> MemoryResult<String> {
406 let options = SearchOptions::default()
407 .with_limit(limit)
408 .with_threshold(self.config.similarity_threshold);
409
410 let results = self.searcher
411 .search_with_storage(&self.postgres, user_id, query, options)
412 .await?;
413
414 let memories: Vec<Memory> = results.into_iter().map(|r| r.memory).collect();
415 self.context_builder.build_from_memories(memories, query)
416 }
417
418 pub fn build_system_prompt(&self, instruction: &str) -> String {
420 ContextBuilder::build_system_prompt(instruction)
421 }
422
423 pub fn storage(&self) -> &Arc<StoragePool> {
427 &self.storage_pool
428 }
429
430 pub fn postgres(&self) -> &Arc<PostgresStorage> {
432 &self.postgres
433 }
434
435 pub fn redis(&self) -> &Arc<RedisCache> {
437 &self.redis
438 }
439
440 pub fn users(&self) -> &Arc<UserManager> {
442 &self.user_manager
443 }
444
445 pub fn embeddings(&self) -> &Arc<EmbeddingGenerator> {
447 &self.embedding_generator
448 }
449
450 pub fn config(&self) -> &Config {
452 &self.config
453 }
454
455 pub async fn generate_embedding(&self, text: &str) -> MemoryResult<Vec<f32>> {
457 self.embedding_generator.generate(text).await
458 }
459
460 pub async fn compute_similarity(&self, text1: &str, text2: &str) -> MemoryResult<f32> {
462 self.searcher.compute_similarity(text1, text2).await
463 }
464
465 pub async fn health_check(&self) -> MemoryResult<HealthStatus> {
467 let postgres_ok = self.storage_pool.health_check().await.is_ok();
468 let redis_ok = self.redis.ping().await.unwrap_or(false);
469 let initialized = self.is_initialized().await;
470
471 Ok(HealthStatus {
472 initialized,
473 postgres_connected: postgres_ok,
474 redis_connected: redis_ok,
475 embedding_model_loaded: self.embedding_generator.dimension() > 0,
476 })
477 }
478
479 async fn log_audit(
481 &self,
482 user_id: Option<&MemoryId>,
483 action: &str,
484 resource_type: &str,
485 resource_id: Option<&MemoryId>,
486 details: Option<&serde_json::Value>,
487 ) {
488 if let Err(e) = self.postgres
489 .log_audit(user_id, action, resource_type, resource_id, details, None)
490 .await
491 {
492 warn!(error = %e, "Failed to log audit event");
493 }
494 }
495}
496
/// Snapshot of the engine's readiness and backing-service connectivity, as
/// produced by `MemoryEngine::health_check`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HealthStatus {
    /// `MemoryEngine::initialize` has completed successfully.
    pub initialized: bool,
    /// The storage pool's health check succeeded.
    pub postgres_connected: bool,
    /// Redis answered a ping with the expected response.
    pub redis_connected: bool,
    /// The embedding generator reports a non-zero dimension.
    pub embedding_model_loaded: bool,
}

impl HealthStatus {
    /// Minimal readiness: the engine is initialized and PostgreSQL is reachable.
    ///
    /// Redis and the embedding model are not considered.
    #[must_use]
    pub fn is_healthy(&self) -> bool {
        self.initialized && self.postgres_connected
    }

    /// Full readiness: [`HealthStatus::is_healthy`] holds, Redis is reachable,
    /// and the embedding model is loaded.
    #[must_use]
    pub fn is_fully_healthy(&self) -> bool {
        self.is_healthy() && self.redis_connected && self.embedding_model_loaded
    }
}
521
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for construction.
    ///
    /// A live database may not be reachable in the test environment, so `new`
    /// is allowed to fail. When it succeeds, the engine must start
    /// uninitialized and its initialization guard must reject calls.
    #[tokio::test]
    async fn test_memory_engine_creation() {
        let config = Config::test_config();
        if let Ok(engine) = MemoryEngine::new(config).await {
            assert!(!engine.is_initialized().await);
            assert!(matches!(
                engine.ensure_initialized().await,
                Err(MemoryError::Generic(_))
            ));
        }
    }

    #[test]
    fn test_health_status() {
        let status = HealthStatus {
            initialized: true,
            postgres_connected: true,
            redis_connected: true,
            embedding_model_loaded: true,
        };

        assert!(status.is_healthy());
        assert!(status.is_fully_healthy());

        let partial_status = HealthStatus {
            initialized: true,
            postgres_connected: true,
            redis_connected: false,
            embedding_model_loaded: false,
        };

        assert!(partial_status.is_healthy());
        assert!(!partial_status.is_fully_healthy());

        // PostgreSQL is mandatory for basic health.
        let no_postgres = HealthStatus {
            initialized: true,
            postgres_connected: false,
            redis_connected: true,
            embedding_model_loaded: true,
        };

        assert!(!no_postgres.is_healthy());
        assert!(!no_postgres.is_fully_healthy());

        // An uninitialized engine is never healthy, even with every service up.
        let uninitialized = HealthStatus {
            initialized: false,
            postgres_connected: true,
            redis_connected: true,
            embedding_model_loaded: true,
        };

        assert!(!uninitialized.is_healthy());
        assert!(!uninitialized.is_fully_healthy());
    }
}