nt_memory/reasoningbank/
distillation.rs

1//! Memory distillation - Compress and extract patterns from trajectories
2
3use super::trajectory::Trajectory;
4use serde::{Serialize, Deserialize};
5use std::collections::HashMap;
6
7/// Distilled pattern from multiple trajectories
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct DistilledPattern {
10    /// Pattern ID
11    pub id: String,
12
13    /// Pattern type
14    pub pattern_type: String,
15
16    /// Centroid embedding
17    pub centroid: Vec<f32>,
18
19    /// Supporting trajectories
20    pub trajectory_ids: Vec<String>,
21
22    /// Pattern strength (0.0 - 1.0)
23    pub strength: f64,
24
25    /// Metadata
26    pub metadata: serde_json::Value,
27}
28
/// Distills trajectories into [`DistilledPattern`]s and merges similar patterns.
///
/// The type only holds plain configuration values, so it is cheap to clone and
/// can be printed with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct MemoryDistiller {
    /// When `true`, `distill` annotates each pattern's metadata with LZ4
    /// compression statistics for its centroid.
    enable_compression: bool,

    /// Minimum number of trajectories an agent must have before a pattern is
    /// emitted for it.
    min_trajectories: usize,

    /// Cosine-similarity threshold at or above which `merge_similar` puts two
    /// patterns into the same cluster.
    similarity_threshold: f64,
}
40
41impl MemoryDistiller {
42    /// Create new distiller
43    pub fn new(enable_compression: bool) -> Self {
44        Self {
45            enable_compression,
46            min_trajectories: 3,
47            similarity_threshold: 0.8,
48        }
49    }
50
51    /// Configure distillation parameters
52    pub fn with_params(
53        enable_compression: bool,
54        min_trajectories: usize,
55        similarity_threshold: f64,
56    ) -> Self {
57        Self {
58            enable_compression,
59            min_trajectories,
60            similarity_threshold,
61        }
62    }
63
64    /// Distill patterns from trajectories
65    pub async fn distill(&self, trajectories: &[Trajectory]) -> Vec<DistilledPattern> {
66        // Group trajectories by agent
67        let mut by_agent: HashMap<String, Vec<&Trajectory>> = HashMap::new();
68
69        for trajectory in trajectories {
70            by_agent
71                .entry(trajectory.agent_id.clone())
72                .or_insert_with(Vec::new)
73                .push(trajectory);
74        }
75
76        let mut patterns = Vec::new();
77
78        // Distill patterns per agent
79        for (agent_id, agent_trajectories) in by_agent {
80            if agent_trajectories.len() < self.min_trajectories {
81                continue;
82            }
83
84            // Extract embeddings from observations
85            let embeddings: Vec<Vec<f32>> = agent_trajectories
86                .iter()
87                .flat_map(|t| &t.observations)
88                .filter_map(|obs| obs.embedding.clone())
89                .collect();
90
91            if embeddings.is_empty() {
92                continue;
93            }
94
95            // Calculate centroid
96            let centroid = self.calculate_centroid(&embeddings);
97
98            // Calculate pattern strength (based on clustering tightness)
99            let strength = self.calculate_pattern_strength(&embeddings, &centroid);
100
101            // Create pattern
102            let pattern = DistilledPattern {
103                id: uuid::Uuid::new_v4().to_string(),
104                pattern_type: "agent_behavior".to_string(),
105                centroid,
106                trajectory_ids: agent_trajectories.iter().map(|t| t.id.clone()).collect(),
107                strength,
108                metadata: serde_json::json!({
109                    "agent_id": agent_id,
110                    "trajectory_count": agent_trajectories.len(),
111                }),
112            };
113
114            patterns.push(pattern);
115        }
116
117        // Compress if enabled
118        if self.enable_compression {
119            self.compress_patterns(&mut patterns);
120        }
121
122        patterns
123    }
124
125    /// Calculate centroid of embeddings
126    fn calculate_centroid(&self, embeddings: &[Vec<f32>]) -> Vec<f32> {
127        if embeddings.is_empty() {
128            return Vec::new();
129        }
130
131        let dimension = embeddings[0].len();
132        let count = embeddings.len() as f32;
133
134        let mut centroid = vec![0.0; dimension];
135
136        for embedding in embeddings {
137            for (i, &value) in embedding.iter().enumerate() {
138                centroid[i] += value / count;
139            }
140        }
141
142        // Normalize
143        let magnitude: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
144        if magnitude > 0.0 {
145            centroid.iter_mut().for_each(|x| *x /= magnitude);
146        }
147
148        centroid
149    }
150
151    /// Calculate pattern strength based on clustering tightness
152    fn calculate_pattern_strength(&self, embeddings: &[Vec<f32>], centroid: &[f32]) -> f64 {
153        if embeddings.is_empty() || centroid.is_empty() {
154            return 0.0;
155        }
156
157        // Calculate average cosine similarity to centroid
158        let mut total_similarity = 0.0;
159
160        for embedding in embeddings {
161            let similarity = cosine_similarity(embedding, centroid);
162            total_similarity += similarity;
163        }
164
165        let avg_similarity = total_similarity / embeddings.len() as f64;
166
167        // Normalize to 0-1 range (cosine similarity is -1 to 1)
168        (avg_similarity + 1.0) / 2.0
169    }
170
171    /// Compress patterns using LZ4
172    fn compress_patterns(&self, patterns: &mut [DistilledPattern]) {
173        use lz4::EncoderBuilder;
174        use std::io::Write;
175
176        for pattern in patterns.iter_mut() {
177            // Serialize embedding
178            let bytes = bincode::serialize(&pattern.centroid).unwrap();
179
180            // Compress
181            let mut encoder = EncoderBuilder::new()
182                .level(4)
183                .build(Vec::new())
184                .unwrap();
185
186            encoder.write_all(&bytes).unwrap();
187            let (compressed, _) = encoder.finish();
188
189            // Store compression ratio in metadata
190            let ratio = compressed.len() as f64 / bytes.len() as f64;
191
192            if let Some(obj) = pattern.metadata.as_object_mut() {
193                obj.insert("compression_ratio".to_string(), serde_json::json!(ratio));
194                obj.insert("original_size".to_string(), serde_json::json!(bytes.len()));
195                obj.insert("compressed_size".to_string(), serde_json::json!(compressed.len()));
196            }
197        }
198    }
199
200    /// Merge similar patterns
201    pub fn merge_similar(&self, patterns: &[DistilledPattern]) -> Vec<DistilledPattern> {
202        let mut merged = Vec::new();
203        let mut used = vec![false; patterns.len()];
204
205        for i in 0..patterns.len() {
206            if used[i] {
207                continue;
208            }
209
210            let mut cluster = vec![i];
211
212            // Find similar patterns
213            for j in (i + 1)..patterns.len() {
214                if used[j] {
215                    continue;
216                }
217
218                let similarity = cosine_similarity(&patterns[i].centroid, &patterns[j].centroid);
219
220                if similarity >= self.similarity_threshold {
221                    cluster.push(j);
222                    used[j] = true;
223                }
224            }
225
226            // Merge cluster into single pattern
227            let cluster_patterns: Vec<&DistilledPattern> =
228                cluster.iter().map(|&idx| &patterns[idx]).collect();
229
230            let merged_pattern = self.merge_cluster(&cluster_patterns);
231            merged.push(merged_pattern);
232
233            used[i] = true;
234        }
235
236        merged
237    }
238
239    /// Merge cluster of patterns
240    fn merge_cluster(&self, patterns: &[&DistilledPattern]) -> DistilledPattern {
241        // Collect all embeddings
242        let embeddings: Vec<Vec<f32>> = patterns.iter().map(|p| p.centroid.clone()).collect();
243
244        // Calculate new centroid
245        let centroid = self.calculate_centroid(&embeddings);
246
247        // Collect all trajectory IDs
248        let trajectory_ids: Vec<String> = patterns
249            .iter()
250            .flat_map(|p| p.trajectory_ids.clone())
251            .collect();
252
253        // Calculate merged strength
254        let strength = patterns.iter().map(|p| p.strength).sum::<f64>() / patterns.len() as f64;
255
256        DistilledPattern {
257            id: uuid::Uuid::new_v4().to_string(),
258            pattern_type: patterns[0].pattern_type.clone(),
259            centroid,
260            trajectory_ids,
261            strength,
262            metadata: serde_json::json!({
263                "merged_from": patterns.len(),
264                "pattern_ids": patterns.iter().map(|p| p.id.clone()).collect::<Vec<_>>(),
265            }),
266        }
267    }
268}
269
/// Cosine similarity between two embedding vectors.
///
/// Returns 0.0 when the slices differ in length, are empty, or when either
/// vector has zero magnitude. Otherwise the result lies in `[-1.0, 1.0]`.
///
/// All accumulation happens in `f64`. Squaring `f32` components in `f32` would
/// overflow to infinity for large values (yielding NaN) and underflow to zero
/// for tiny ones (yielding a spurious 0.0); widening first avoids both. The
/// final quotient is clamped because rounding can push it a hair past +/-1.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0_f64;
    let mut norm_a_sq = 0.0_f64;
    let mut norm_b_sq = 0.0_f64;

    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let mag_a = norm_a_sq.sqrt();
    let mag_b = norm_b_sq.sqrt();

    if mag_a > 0.0 && mag_b > 0.0 {
        (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`, used to check that centroids are normalized.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Builds a minimal pattern backed by a single trajectory ID.
    fn pattern(id: &str, centroid: Vec<f32>, trajectory_id: &str, strength: f64) -> DistilledPattern {
        DistilledPattern {
            id: id.to_string(),
            pattern_type: "test".to_string(),
            centroid,
            trajectory_ids: vec![trajectory_id.to_string()],
            strength,
            metadata: serde_json::json!({}),
        }
    }

    /// Five trajectories for one agent should distill into exactly one pattern.
    #[tokio::test]
    async fn test_distillation() {
        let distiller = MemoryDistiller::new(false);

        // Create trajectories with embeddings
        let mut trajectories = Vec::new();

        for i in 0..5 {
            let mut trajectory = Trajectory::new("agent_1".to_string());

            let embedding = vec![0.5 + i as f32 * 0.01; 128];
            trajectory.add_observation(serde_json::json!({"i": i}), Some(embedding));

            trajectories.push(trajectory);
        }

        let patterns = distiller.distill(&trajectories).await;

        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].trajectory_ids.len(), 5);
        assert_eq!(patterns[0].pattern_type, "agent_behavior");
        assert_eq!(patterns[0].metadata["trajectory_count"].as_u64(), Some(5));
    }

    /// The centroid must be the mean direction scaled to unit length.
    #[test]
    fn test_centroid_calculation() {
        let distiller = MemoryDistiller::new(false);

        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let centroid = distiller.calculate_centroid(&embeddings);

        assert_eq!(centroid.len(), 2);
        assert!(centroid[0] > 0.0 && centroid[1] > 0.0);
        // The input is symmetric in both axes, so the centroid must be too.
        assert!((centroid[0] - centroid[1]).abs() < 1e-6);
        // Normalized: unit Euclidean length.
        assert!((l2_norm(&centroid) - 1.0).abs() < 1e-6);
    }

    /// An empty input has no centroid.
    #[test]
    fn test_centroid_of_empty_input() {
        let distiller = MemoryDistiller::new(false);
        assert!(distiller.calculate_centroid(&[]).is_empty());
    }

    /// Two nearly parallel patterns collapse into one merged pattern.
    #[test]
    fn test_pattern_merging() {
        let distiller = MemoryDistiller::new(false);

        let pattern1 = pattern("p1", vec![1.0, 0.0], "t1", 0.9);
        let pattern2 = pattern("p2", vec![0.9, 0.1], "t2", 0.85); // Very similar

        let merged = distiller.merge_similar(&[pattern1, pattern2]);

        // Should merge into single pattern
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].trajectory_ids, vec!["t1".to_string(), "t2".to_string()]);
        assert_eq!(merged[0].pattern_type, "test");
        assert_eq!(merged[0].metadata["merged_from"].as_u64(), Some(2));
        // Merged strength is the mean of the member strengths.
        assert!((merged[0].strength - 0.875).abs() < 1e-12);
        assert!((l2_norm(&merged[0].centroid) - 1.0).abs() < 1e-6);
    }

    /// Orthogonal patterns fall below the default threshold and stay separate.
    #[test]
    fn test_dissimilar_patterns_are_not_merged() {
        let distiller = MemoryDistiller::new(false);

        let pattern1 = pattern("p1", vec![1.0, 0.0], "t1", 0.9);
        let pattern2 = pattern("p2", vec![0.0, 1.0], "t2", 0.85);

        let merged = distiller.merge_similar(&[pattern1, pattern2]);

        assert_eq!(merged.len(), 2);
        assert_eq!(merged[0].trajectory_ids, vec!["t1".to_string()]);
        assert_eq!(merged[1].trajectory_ids, vec!["t2".to_string()]);
    }
}