mockforge_intelligence/intelligent_behavior/
memory.rs1use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use super::config::VectorStoreConfig;
11use super::embedding_client::{cosine_similarity, EmbeddingClient};
12use super::types::InteractionRecord;
13use mockforge_foundation::Result;
14
15pub struct VectorMemoryStore {
17 storage: Arc<RwLock<HashMap<String, Vec<InteractionRecord>>>>,
19
20 embedding_client: Option<Arc<EmbeddingClient>>,
22
23 config: VectorStoreConfig,
25}
26
27impl VectorMemoryStore {
28 pub fn new(config: VectorStoreConfig) -> Self {
30 let embedding_client = if config.enabled {
31 Some(Arc::new(EmbeddingClient::new(
32 config.embedding_provider.clone(),
33 config.embedding_model.clone(),
34 None,
35 None,
36 )))
37 } else {
38 None
39 };
40
41 Self {
42 storage: Arc::new(RwLock::new(HashMap::new())),
43 embedding_client,
44 config,
45 }
46 }
47
48 pub async fn store_interaction(
54 &self,
55 session_id: &str,
56 interaction: &InteractionRecord,
57 ) -> Result<()> {
58 let mut interaction_with_embedding = interaction.clone();
59
60 if let Some(ref client) = self.embedding_client {
62 let summary = interaction.summary();
63 match client.generate_embedding(&summary).await {
64 Ok(embedding) => {
65 interaction_with_embedding.embedding = Some(embedding);
66 }
67 Err(e) => {
68 tracing::warn!("Failed to generate embedding: {}", e);
69 }
71 }
72 }
73
74 let mut storage = self.storage.write().await;
75 storage
76 .entry(session_id.to_string())
77 .or_insert_with(Vec::new)
78 .push(interaction_with_embedding);
79
80 Ok(())
81 }
82
83 pub async fn retrieve_context(
90 &self,
91 session_id: &str,
92 query: &str,
93 limit: usize,
94 ) -> Result<Vec<InteractionRecord>> {
95 let storage = self.storage.read().await;
96
97 let interactions = storage.get(session_id).cloned().unwrap_or_default();
98
99 if self.embedding_client.is_none() || interactions.is_empty() {
101 return Ok(interactions.into_iter().rev().take(limit).collect());
102 }
103
104 let query_embedding = match &self.embedding_client {
106 Some(client) => match client.generate_embedding(query).await {
107 Ok(emb) => emb,
108 Err(e) => {
109 tracing::warn!("Failed to generate query embedding: {}", e);
110 return Ok(interactions.into_iter().rev().take(limit).collect());
111 }
112 },
113 None => return Ok(interactions.into_iter().rev().take(limit).collect()),
114 };
115
116 let mut scored_interactions: Vec<(InteractionRecord, f32)> = interactions
118 .into_iter()
119 .filter_map(|interaction| {
120 interaction.embedding.as_ref().map(|emb| {
121 let score = cosine_similarity(&query_embedding, emb);
122 (interaction.clone(), score)
123 })
124 })
125 .collect();
126
127 scored_interactions.retain(|(_, score)| *score >= self.config.similarity_threshold);
129 scored_interactions
130 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
131
132 Ok(scored_interactions
134 .into_iter()
135 .take(limit)
136 .map(|(interaction, _)| interaction)
137 .collect())
138 }
139
140 pub async fn get_session_interactions(
142 &self,
143 session_id: &str,
144 ) -> Result<Vec<InteractionRecord>> {
145 let storage = self.storage.read().await;
146
147 Ok(storage.get(session_id).cloned().unwrap_or_default())
148 }
149
150 pub async fn clear_session(&self, session_id: &str) -> Result<()> {
152 let mut storage = self.storage.write().await;
153 storage.remove(session_id);
154 Ok(())
155 }
156
157 pub async fn clear_all(&self) -> Result<()> {
159 let mut storage = self.storage.write().await;
160 storage.clear();
161 Ok(())
162 }
163}
164
165impl Default for VectorMemoryStore {
166 fn default() -> Self {
167 Self::new(VectorStoreConfig::default())
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let store = VectorMemoryStore::new(VectorStoreConfig::default());

        let interaction = InteractionRecord::new(
            "POST",
            "/api/users",
            Some(serde_json::json!({"name": "Alice"})),
            201,
            Some(serde_json::json!({"id": "user_1", "name": "Alice"})),
        );

        store.store_interaction("session_1", &interaction).await.unwrap();

        let retrieved = store.retrieve_context("session_1", "user creation", 10).await.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].method, "POST");
    }

    #[tokio::test]
    async fn test_clear_session() {
        let store = VectorMemoryStore::new(VectorStoreConfig::default());

        let interaction = InteractionRecord::new("GET", "/api/users", None, 200, None);
        store.store_interaction("session_1", &interaction).await.unwrap();

        store.clear_session("session_1").await.unwrap();

        let retrieved = store.get_session_interactions("session_1").await.unwrap();
        assert!(retrieved.is_empty());
    }

    /// `store_interaction` always appends, whatever the embedding outcome, so the
    /// session history must come back in insertion order.
    #[tokio::test]
    async fn test_session_interactions_keep_insertion_order() {
        let store = VectorMemoryStore::default();

        let first = InteractionRecord::new("POST", "/api/users", None, 201, None);
        let second = InteractionRecord::new("GET", "/api/users", None, 200, None);
        store.store_interaction("session_1", &first).await.unwrap();
        store.store_interaction("session_1", &second).await.unwrap();

        let all = store.get_session_interactions("session_1").await.unwrap();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].method, "POST");
        assert_eq!(all[1].method, "GET");
    }

    /// Clearing one session must not touch the others.
    #[tokio::test]
    async fn test_clear_session_is_isolated() {
        let store = VectorMemoryStore::default();

        let interaction = InteractionRecord::new("GET", "/api/users", None, 200, None);
        store.store_interaction("session_1", &interaction).await.unwrap();
        store.store_interaction("session_2", &interaction).await.unwrap();

        store.clear_session("session_1").await.unwrap();

        assert!(store.get_session_interactions("session_1").await.unwrap().is_empty());
        assert_eq!(store.get_session_interactions("session_2").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_clear_all_removes_every_session() {
        let store = VectorMemoryStore::default();

        let interaction = InteractionRecord::new("GET", "/api/users", None, 200, None);
        store.store_interaction("session_1", &interaction).await.unwrap();
        store.store_interaction("session_2", &interaction).await.unwrap();

        store.clear_all().await.unwrap();

        assert!(store.get_session_interactions("session_1").await.unwrap().is_empty());
        assert!(store.get_session_interactions("session_2").await.unwrap().is_empty());
    }

    /// Unknown sessions and `limit == 0` yield no context on every retrieval path.
    #[tokio::test]
    async fn test_retrieve_context_edge_cases() {
        let store = VectorMemoryStore::default();

        let unknown = store.retrieve_context("missing", "anything", 10).await.unwrap();
        assert!(unknown.is_empty());

        let interaction = InteractionRecord::new("GET", "/api/users", None, 200, None);
        store.store_interaction("session_1", &interaction).await.unwrap();

        let none = store.retrieve_context("session_1", "anything", 0).await.unwrap();
        assert!(none.is_empty());
    }
}