Skip to main content

agentic_memory/index/
cluster_map.rs

//! Cluster map — pre-computed k-means clustering of feature vectors.
2
/// k-means clustering of node feature vectors, computed ahead of time so that
/// semantically similar nodes can be grouped without re-clustering.
pub struct ClusterMap {
    /// Number of components used when recomputing centroids.
    dimension: usize,
    /// One centroid vector per cluster; index `i` belongs to cluster `i`.
    centroids: Vec<Vec<f32>>,
    /// Node IDs in each cluster, indexed by cluster and kept in ascending order.
    assignments: Vec<Vec<u64>>,
}
12
13impl ClusterMap {
14    /// Create a new, empty cluster map.
15    pub fn new(dimension: usize) -> Self {
16        Self {
17            centroids: Vec::new(),
18            assignments: Vec::new(),
19            dimension,
20        }
21    }
22
23    /// Run k-means clustering on all feature vectors in the graph.
24    /// k = min(sqrt(node_count), 256). Skips if node_count < 4.
25    pub fn build(&mut self, nodes: &[(u64, &[f32])], max_iterations: usize) {
26        // Filter out zero vectors
27        let non_zero: Vec<(u64, &[f32])> = nodes
28            .iter()
29            .filter(|(_, v)| v.iter().any(|&x| x != 0.0))
30            .copied()
31            .collect();
32
33        if non_zero.len() < 4 {
34            self.centroids.clear();
35            self.assignments.clear();
36            return;
37        }
38
39        let k = ((non_zero.len() as f64).sqrt().ceil() as usize).min(256);
40
41        // Initialize centroids: pick k evenly-spaced nodes
42        let step = non_zero.len() / k;
43        self.centroids = (0..k)
44            .map(|i| {
45                let idx = (i * step).min(non_zero.len() - 1);
46                non_zero[idx].1.to_vec()
47            })
48            .collect();
49
50        self.assignments = vec![Vec::new(); k];
51
52        for _ in 0..max_iterations {
53            // Clear assignments
54            for a in &mut self.assignments {
55                a.clear();
56            }
57
58            // Assign each node to nearest centroid
59            for &(id, vec) in &non_zero {
60                let nearest = self.find_nearest_centroid(vec);
61                self.assignments[nearest].push(id);
62            }
63
64            // Update centroids
65            let mut changed = false;
66            for (ci, cluster_ids) in self.assignments.iter().enumerate() {
67                if cluster_ids.is_empty() {
68                    continue;
69                }
70                let mut new_centroid = vec![0.0f32; self.dimension];
71                let count = cluster_ids.len() as f32;
72                for &node_id in cluster_ids {
73                    if let Some((_, vec)) = non_zero.iter().find(|(id, _)| *id == node_id) {
74                        for (j, &val) in vec.iter().enumerate() {
75                            new_centroid[j] += val;
76                        }
77                    }
78                }
79                for val in &mut new_centroid {
80                    *val /= count;
81                }
82                if new_centroid != self.centroids[ci] {
83                    changed = true;
84                    self.centroids[ci] = new_centroid;
85                }
86            }
87
88            if !changed {
89                break;
90            }
91        }
92
93        // Sort assignments
94        for a in &mut self.assignments {
95            a.sort_unstable();
96        }
97    }
98
99    fn find_nearest_centroid(&self, vec: &[f32]) -> usize {
100        let mut best = 0;
101        let mut best_sim = f32::NEG_INFINITY;
102        for (i, centroid) in self.centroids.iter().enumerate() {
103            let sim = cosine_similarity(vec, centroid);
104            if sim > best_sim {
105                best_sim = sim;
106                best = i;
107            }
108        }
109        best
110    }
111
112    /// Find the nearest cluster for a query vector.
113    pub fn nearest_cluster(&self, query: &[f32]) -> Option<usize> {
114        if self.centroids.is_empty() {
115            return None;
116        }
117        Some(self.find_nearest_centroid(query))
118    }
119
120    /// Get all node IDs in a specific cluster.
121    pub fn get_cluster(&self, cluster_index: usize) -> &[u64] {
122        self.assignments
123            .get(cluster_index)
124            .map(|v| v.as_slice())
125            .unwrap_or(&[])
126    }
127
128    /// Get the centroid for a cluster.
129    pub fn centroid(&self, cluster_index: usize) -> Option<&[f32]> {
130        self.centroids.get(cluster_index).map(|v| v.as_slice())
131    }
132
133    /// Number of clusters.
134    pub fn cluster_count(&self) -> usize {
135        self.centroids.len()
136    }
137
138    /// Assign a new node to the nearest cluster without rebuilding.
139    pub fn assign_node(&mut self, node_id: u64, feature_vec: &[f32]) {
140        if self.centroids.is_empty() {
141            return;
142        }
143        if feature_vec.iter().all(|&x| x == 0.0) {
144            return;
145        }
146        let nearest = self.find_nearest_centroid(feature_vec);
147        let list = &mut self.assignments[nearest];
148        let pos = list.binary_search(&node_id).unwrap_or_else(|p| p);
149        list.insert(pos, node_id);
150    }
151
152    /// Clear the cluster map.
153    pub fn clear(&mut self) {
154        self.centroids.clear();
155        self.assignments.clear();
156    }
157
158    /// Whether the cluster map is empty.
159    pub fn is_empty(&self) -> bool {
160        self.centroids.is_empty()
161    }
162
163    /// Get the dimension.
164    pub fn dimension(&self) -> usize {
165        self.dimension
166    }
167
168    /// Get centroids (for serialization).
169    pub fn centroids(&self) -> &[Vec<f32>] {
170        &self.centroids
171    }
172
173    /// Get assignments (for serialization).
174    pub fn assignments(&self) -> &[Vec<u64>] {
175        &self.assignments
176    }
177}
178
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Only the first `min(a.len(), b.len())` components are compared. Returns `0.0`
/// when the product of the two magnitudes is zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, sa, sb), (&x, &y)| {
            (d + x * y, sa + x * x, sb + y * y)
        });
    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    (dot / magnitude).clamp(-1.0, 1.0)
}
196
197impl Default for ClusterMap {
198    fn default() -> Self {
199        Self::new(128)
200    }
201}