Skip to main content

mockforge_intelligence/intelligent_behavior/
memory.rs

1//! Vector memory store for long-term semantic memory
2//!
3//! This module provides persistent memory using vector embeddings for
4//! semantic search over past interactions.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use super::config::VectorStoreConfig;
11use super::embedding_client::{cosine_similarity, EmbeddingClient};
12use super::types::InteractionRecord;
13use mockforge_foundation::Result;
14
15/// Vector memory store for persistent, searchable interaction history
16pub struct VectorMemoryStore {
17    /// In-memory storage (session_id -> interactions)
18    storage: Arc<RwLock<HashMap<String, Vec<InteractionRecord>>>>,
19
20    /// Embedding client
21    embedding_client: Option<Arc<EmbeddingClient>>,
22
23    /// Configuration
24    config: VectorStoreConfig,
25}
26
27impl VectorMemoryStore {
28    /// Create a new vector memory store
29    pub fn new(config: VectorStoreConfig) -> Self {
30        let embedding_client = if config.enabled {
31            Some(Arc::new(EmbeddingClient::new(
32                config.embedding_provider.clone(),
33                config.embedding_model.clone(),
34                None,
35                None,
36            )))
37        } else {
38            None
39        };
40
41        Self {
42            storage: Arc::new(RwLock::new(HashMap::new())),
43            embedding_client,
44            config,
45        }
46    }
47
48    /// Store an interaction with semantic embedding
49    ///
50    /// # Arguments
51    /// * `session_id` - Session identifier
52    /// * `interaction` - Interaction to store
53    pub async fn store_interaction(
54        &self,
55        session_id: &str,
56        interaction: &InteractionRecord,
57    ) -> Result<()> {
58        let mut interaction_with_embedding = interaction.clone();
59
60        // Generate embedding if enabled
61        if let Some(ref client) = self.embedding_client {
62            let summary = interaction.summary();
63            match client.generate_embedding(&summary).await {
64                Ok(embedding) => {
65                    interaction_with_embedding.embedding = Some(embedding);
66                }
67                Err(e) => {
68                    tracing::warn!("Failed to generate embedding: {}", e);
69                    // Continue without embedding
70                }
71            }
72        }
73
74        let mut storage = self.storage.write().await;
75        storage
76            .entry(session_id.to_string())
77            .or_insert_with(Vec::new)
78            .push(interaction_with_embedding);
79
80        Ok(())
81    }
82
83    /// Retrieve relevant past interactions using semantic search
84    ///
85    /// # Arguments
86    /// * `session_id` - Session identifier
87    /// * `query` - Search query
88    /// * `limit` - Maximum number of results to return
89    pub async fn retrieve_context(
90        &self,
91        session_id: &str,
92        query: &str,
93        limit: usize,
94    ) -> Result<Vec<InteractionRecord>> {
95        let storage = self.storage.read().await;
96
97        let interactions = storage.get(session_id).cloned().unwrap_or_default();
98
99        // If no embedding client or no interactions, return recent ones
100        if self.embedding_client.is_none() || interactions.is_empty() {
101            return Ok(interactions.into_iter().rev().take(limit).collect());
102        }
103
104        // Generate embedding for query
105        let query_embedding = match &self.embedding_client {
106            Some(client) => match client.generate_embedding(query).await {
107                Ok(emb) => emb,
108                Err(e) => {
109                    tracing::warn!("Failed to generate query embedding: {}", e);
110                    return Ok(interactions.into_iter().rev().take(limit).collect());
111                }
112            },
113            None => return Ok(interactions.into_iter().rev().take(limit).collect()),
114        };
115
116        // Calculate similarity scores
117        let mut scored_interactions: Vec<(InteractionRecord, f32)> = interactions
118            .into_iter()
119            .filter_map(|interaction| {
120                interaction.embedding.as_ref().map(|emb| {
121                    let score = cosine_similarity(&query_embedding, emb);
122                    (interaction.clone(), score)
123                })
124            })
125            .collect();
126
127        // Filter by threshold and sort by score
128        scored_interactions.retain(|(_, score)| *score >= self.config.similarity_threshold);
129        scored_interactions
130            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
131
132        // Return top-k results
133        Ok(scored_interactions
134            .into_iter()
135            .take(limit)
136            .map(|(interaction, _)| interaction)
137            .collect())
138    }
139
140    /// Get all interactions for a session
141    pub async fn get_session_interactions(
142        &self,
143        session_id: &str,
144    ) -> Result<Vec<InteractionRecord>> {
145        let storage = self.storage.read().await;
146
147        Ok(storage.get(session_id).cloned().unwrap_or_default())
148    }
149
150    /// Clear all interactions for a session
151    pub async fn clear_session(&self, session_id: &str) -> Result<()> {
152        let mut storage = self.storage.write().await;
153        storage.remove(session_id);
154        Ok(())
155    }
156
157    /// Clear all stored interactions
158    pub async fn clear_all(&self) -> Result<()> {
159        let mut storage = self.storage.write().await;
160        storage.clear();
161        Ok(())
162    }
163}
164
165impl Default for VectorMemoryStore {
166    fn default() -> Self {
167        Self::new(VectorStoreConfig::default())
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a store with embeddings explicitly disabled.
    ///
    /// The tests then never depend on an embedding backend or on the default
    /// config's `enabled` value, and retrieval is deterministic (newest first).
    // Only `enabled` is known to be accessible here, so mutate after `default()`
    // rather than using struct-update syntax.
    #[allow(clippy::field_reassign_with_default)]
    fn store_without_embeddings() -> VectorMemoryStore {
        let mut config = VectorStoreConfig::default();
        config.enabled = false;
        VectorMemoryStore::new(config)
    }

    fn simple_interaction(method: &str) -> InteractionRecord {
        InteractionRecord::new(method, "/api/users", None, 200, None)
    }

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let store = store_without_embeddings();

        let interaction = InteractionRecord::new(
            "POST",
            "/api/users",
            Some(serde_json::json!({"name": "Alice"})),
            201,
            Some(serde_json::json!({"id": "user_1", "name": "Alice"})),
        );

        store.store_interaction("session_1", &interaction).await.unwrap();

        let retrieved = store.retrieve_context("session_1", "user creation", 10).await.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].method, "POST");
    }

    #[tokio::test]
    async fn test_retrieve_without_embeddings_returns_most_recent_first() {
        let store = store_without_embeddings();
        for method in ["GET", "POST", "DELETE"] {
            store.store_interaction("session_1", &simple_interaction(method)).await.unwrap();
        }

        let retrieved = store.retrieve_context("session_1", "anything", 2).await.unwrap();
        assert_eq!(retrieved.len(), 2);
        assert_eq!(retrieved[0].method, "DELETE");
        assert_eq!(retrieved[1].method, "POST");
    }

    #[tokio::test]
    async fn test_retrieve_with_zero_limit_is_empty() {
        let store = store_without_embeddings();
        store.store_interaction("session_1", &simple_interaction("GET")).await.unwrap();

        let retrieved = store.retrieve_context("session_1", "anything", 0).await.unwrap();
        assert!(retrieved.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_unknown_session_is_empty() {
        let store = store_without_embeddings();

        let retrieved = store.retrieve_context("missing", "anything", 5).await.unwrap();
        assert!(retrieved.is_empty());
    }

    #[tokio::test]
    async fn test_clear_session() {
        let store = store_without_embeddings();

        store.store_interaction("session_1", &simple_interaction("GET")).await.unwrap();
        store.store_interaction("session_2", &simple_interaction("POST")).await.unwrap();

        store.clear_session("session_1").await.unwrap();

        assert!(store.get_session_interactions("session_1").await.unwrap().is_empty());
        // Other sessions must be unaffected.
        assert_eq!(store.get_session_interactions("session_2").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_clear_all() {
        let store = store_without_embeddings();

        store.store_interaction("session_1", &simple_interaction("GET")).await.unwrap();
        store.store_interaction("session_2", &simple_interaction("POST")).await.unwrap();

        store.clear_all().await.unwrap();

        assert!(store.get_session_interactions("session_1").await.unwrap().is_empty());
        assert!(store.get_session_interactions("session_2").await.unwrap().is_empty());
    }
}