// nt_memory/reasoningbank/distillation.rs

use super::trajectory::Trajectory;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;

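/// A behavior pattern distilled from a set of related trajectories: a normalized
/// centroid embedding plus provenance and strength metadata.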
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistilledPattern {
    /// Unique identifier for this pattern.
    pub id: String,

    /// Category of the pattern (currently "agent_behavior").
    pub pattern_type: String,

    /// L2-normalized centroid of the contributing observation embeddings.
    pub centroid: Vec<f32>,

    /// IDs of the trajectories this pattern was distilled from.
    pub trajectory_ids: Vec<String>,

    /// Cohesion score in [0, 1]; higher means the source embeddings agree more.
    pub strength: f64,

    /// Free-form metadata (agent ID, trajectory count, compression stats, ...).
    pub metadata: serde_json::Value,
}

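/// Distills trajectories into `DistilledPattern`s and merges similar patterns.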
pub struct MemoryDistiller {
    /// Whether to record LZ4 compression statistics on each distilled pattern.
    enable_compression: bool,

    /// Minimum number of trajectories an agent needs before a pattern is emitted.
    min_trajectories: usize,

    /// Cosine-similarity threshold used by `merge_similar`.
    similarity_threshold: f64,
}

impl MemoryDistiller {
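    /// Creates a distiller with default thresholds (3 trajectories per agent, 0.8 similarity).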
    pub fn new(enable_compression: bool) -> Self {
        Self {
            enable_compression,
            min_trajectories: 3,
            similarity_threshold: 0.8,
        }
    }

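    /// Creates a distiller with explicit trajectory and similarity thresholds.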
    pub fn with_params(
        enable_compression: bool,
        min_trajectories: usize,
        similarity_threshold: f64,
    ) -> Self {
        Self {
            enable_compression,
            min_trajectories,
            similarity_threshold,
        }
    }

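    /// Groups trajectories by agent, averages their observation embeddings into a
    /// normalized centroid, and emits one `DistilledPattern` per agent that has at
    /// least `min_trajectories` trajectories and at least one embedded observation.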
    pub async fn distill(&self, trajectories: &[Trajectory]) -> Vec<DistilledPattern> {
        let mut by_agent: HashMap<String, Vec<&Trajectory>> = HashMap::new();

        for trajectory in trajectories {
            by_agent
                .entry(trajectory.agent_id.clone())
                .or_insert_with(Vec::new)
                .push(trajectory);
        }

        let mut patterns = Vec::new();

        for (agent_id, agent_trajectories) in by_agent {
            if agent_trajectories.len() < self.min_trajectories {
                continue;
            }

            let embeddings: Vec<Vec<f32>> = agent_trajectories
                .iter()
                .flat_map(|t| &t.observations)
                .filter_map(|obs| obs.embedding.clone())
                .collect();

            if embeddings.is_empty() {
                continue;
            }

            let centroid = self.calculate_centroid(&embeddings);
            let strength = self.calculate_pattern_strength(&embeddings, &centroid);

            let pattern = DistilledPattern {
                id: uuid::Uuid::new_v4().to_string(),
                pattern_type: "agent_behavior".to_string(),
                centroid,
                trajectory_ids: agent_trajectories.iter().map(|t| t.id.clone()).collect(),
                strength,
                metadata: serde_json::json!({
                    "agent_id": agent_id,
                    "trajectory_count": agent_trajectories.len(),
                }),
            };

            patterns.push(pattern);
        }

        if self.enable_compression {
            self.compress_patterns(&mut patterns);
        }

        patterns
    }

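    /// Element-wise mean of the embeddings, L2-normalized to unit length.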
    fn calculate_centroid(&self, embeddings: &[Vec<f32>]) -> Vec<f32> {
        if embeddings.is_empty() {
            return Vec::new();
        }

        let dimension = embeddings[0].len();
        let count = embeddings.len() as f32;

        let mut centroid = vec![0.0; dimension];

        for embedding in embeddings {
            for (i, &value) in embedding.iter().enumerate() {
                centroid[i] += value / count;
            }
        }

        let magnitude: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            centroid.iter_mut().for_each(|x| *x /= magnitude);
        }

        centroid
    }

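    /// Mean cosine similarity between each embedding and the centroid, rescaled
    /// from [-1, 1] to [0, 1].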
    fn calculate_pattern_strength(&self, embeddings: &[Vec<f32>], centroid: &[f32]) -> f64 {
        if embeddings.is_empty() || centroid.is_empty() {
            return 0.0;
        }

        let mut total_similarity = 0.0;

        for embedding in embeddings {
            let similarity = cosine_similarity(embedding, centroid);
            total_similarity += similarity;
        }

        let avg_similarity = total_similarity / embeddings.len() as f64;

        (avg_similarity + 1.0) / 2.0
    }

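    /// Records LZ4 compression statistics (ratio, original and compressed sizes) for
    /// each pattern's serialized centroid in the pattern's metadata.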
    fn compress_patterns(&self, patterns: &mut [DistilledPattern]) {
        use lz4::EncoderBuilder;
        use std::io::Write;

        for pattern in patterns.iter_mut() {
            let bytes = bincode::serialize(&pattern.centroid).unwrap();

            let mut encoder = EncoderBuilder::new()
                .level(4)
                .build(Vec::new())
                .unwrap();

            encoder.write_all(&bytes).unwrap();
            let (compressed, _) = encoder.finish();

            let ratio = compressed.len() as f64 / bytes.len() as f64;

            if let Some(obj) = pattern.metadata.as_object_mut() {
                obj.insert("compression_ratio".to_string(), serde_json::json!(ratio));
                obj.insert("original_size".to_string(), serde_json::json!(bytes.len()));
                obj.insert("compressed_size".to_string(), serde_json::json!(compressed.len()));
            }
        }
    }

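    /// Greedily clusters patterns whose centroids reach `similarity_threshold`
    /// cosine similarity and merges each cluster into a single pattern.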
    pub fn merge_similar(&self, patterns: &[DistilledPattern]) -> Vec<DistilledPattern> {
        let mut merged = Vec::new();
        let mut used = vec![false; patterns.len()];

        for i in 0..patterns.len() {
            if used[i] {
                continue;
            }

            let mut cluster = vec![i];

            for j in (i + 1)..patterns.len() {
                if used[j] {
                    continue;
                }

                let similarity = cosine_similarity(&patterns[i].centroid, &patterns[j].centroid);

                if similarity >= self.similarity_threshold {
                    cluster.push(j);
                    used[j] = true;
                }
            }

            let cluster_patterns: Vec<&DistilledPattern> =
                cluster.iter().map(|&idx| &patterns[idx]).collect();

            let merged_pattern = self.merge_cluster(&cluster_patterns);
            merged.push(merged_pattern);

            used[i] = true;
        }

        merged
    }

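    /// Merges a cluster of patterns: recomputes the centroid from their centroids,
    /// concatenates trajectory IDs, and averages strengths.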
    fn merge_cluster(&self, patterns: &[&DistilledPattern]) -> DistilledPattern {
        let embeddings: Vec<Vec<f32>> = patterns.iter().map(|p| p.centroid.clone()).collect();

        let centroid = self.calculate_centroid(&embeddings);

        let trajectory_ids: Vec<String> = patterns
            .iter()
            .flat_map(|p| p.trajectory_ids.clone())
            .collect();

        let strength = patterns.iter().map(|p| p.strength).sum::<f64>() / patterns.len() as f64;

        DistilledPattern {
            id: uuid::Uuid::new_v4().to_string(),
            pattern_type: patterns[0].pattern_type.clone(),
            centroid,
            trajectory_ids,
            strength,
            metadata: serde_json::json!({
                "merged_from": patterns.len(),
                "pattern_ids": patterns.iter().map(|p| p.id.clone()).collect::<Vec<_>>(),
            }),
        }
    }
}

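/// Cosine similarity of two equal-length vectors as an `f64` in [-1, 1]; returns 0.0
/// for mismatched lengths or zero-magnitude inputs.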
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if mag_a > 0.0 && mag_b > 0.0 {
        (dot / (mag_a * mag_b)) as f64
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_distillation() {
        let distiller = MemoryDistiller::new(false);

        let mut trajectories = Vec::new();

        for i in 0..5 {
            let mut trajectory = Trajectory::new("agent_1".to_string());

            let embedding = vec![0.5 + i as f32 * 0.01; 128];
            trajectory.add_observation(serde_json::json!({"i": i}), Some(embedding));

            trajectories.push(trajectory);
        }

        let patterns = distiller.distill(&trajectories).await;

        assert!(!patterns.is_empty());
        assert_eq!(patterns[0].trajectory_ids.len(), 5);
    }

    #[test]
    fn test_centroid_calculation() {
        let distiller = MemoryDistiller::new(false);

        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let centroid = distiller.calculate_centroid(&embeddings);

        assert_eq!(centroid.len(), 2);
        assert!(centroid[0] > 0.0 && centroid[1] > 0.0);
    }

    #[test]
    fn test_pattern_merging() {
        let distiller = MemoryDistiller::new(false);

        let pattern1 = DistilledPattern {
            id: "p1".to_string(),
            pattern_type: "test".to_string(),
            centroid: vec![1.0, 0.0],
            trajectory_ids: vec!["t1".to_string()],
            strength: 0.9,
            metadata: serde_json::json!({}),
        };

        let pattern2 = DistilledPattern {
            id: "p2".to_string(),
            pattern_type: "test".to_string(),
            centroid: vec![0.9, 0.1],
            trajectory_ids: vec!["t2".to_string()],
            strength: 0.85,
            metadata: serde_json::json!({}),
        };

        let merged = distiller.merge_similar(&[pattern1, pattern2]);

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].trajectory_ids.len(), 2);
    }
}