brainwires_cognition/prompting/
clustering.rs1use super::techniques::{ComplexityLevel, PromptingTechnique};
7use crate::prompting::seal::SealProcessingResult;
8use anyhow::{Context as _, Result, anyhow};
9#[cfg(feature = "prompting")]
10use linfa::prelude::*;
11#[cfg(feature = "prompting")]
12use linfa_clustering::KMeans;
13#[cfg(feature = "prompting")]
14use ndarray::Array2;
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TaskCluster {
20 pub id: String,
22 pub description: String,
24 pub embedding: Vec<f32>,
26 pub techniques: Vec<PromptingTechnique>,
28 pub example_tasks: Vec<String>,
30
31 pub seal_query_cores: Vec<String>,
33 pub avg_seal_quality: f32,
35 pub recommended_complexity: ComplexityLevel,
37}
38
39impl TaskCluster {
40 pub fn new(
42 id: String,
43 description: String,
44 embedding: Vec<f32>,
45 techniques: Vec<PromptingTechnique>,
46 example_tasks: Vec<String>,
47 ) -> Self {
48 Self {
49 id,
50 description,
51 embedding,
52 techniques,
53 example_tasks,
54 seal_query_cores: Vec::new(),
55 avg_seal_quality: 0.5,
56 recommended_complexity: ComplexityLevel::Moderate,
57 }
58 }
59
60 pub fn update_seal_metrics(&mut self, query_cores: Vec<String>, avg_quality: f32) {
62 self.seal_query_cores = query_cores;
63 self.avg_seal_quality = avg_quality;
64 self.recommended_complexity = if avg_quality < 0.5 {
65 ComplexityLevel::Simple
66 } else if avg_quality < 0.8 {
67 ComplexityLevel::Moderate
68 } else {
69 ComplexityLevel::Advanced
70 };
71 }
72}
73
74pub struct TaskClusterManager {
76 clusters: Vec<TaskCluster>,
77 _embedding_dim: usize,
78}
79
80impl TaskClusterManager {
81 pub fn new() -> Self {
83 Self {
84 clusters: Vec::new(),
85 _embedding_dim: 768, }
87 }
88
89 pub fn with_embedding_dim(embedding_dim: usize) -> Self {
91 Self {
92 clusters: Vec::new(),
93 _embedding_dim: embedding_dim,
94 }
95 }
96
97 pub fn get_clusters(&self) -> &[TaskCluster] {
99 &self.clusters
100 }
101
102 pub fn add_cluster(&mut self, cluster: TaskCluster) {
104 self.clusters.push(cluster);
105 }
106
107 pub fn set_clusters(&mut self, clusters: Vec<TaskCluster>) {
109 self.clusters = clusters;
110 }
111
112 pub fn find_matching_cluster(
126 &self,
127 task_embedding: &[f32],
128 seal_result: Option<&SealProcessingResult>,
129 ) -> Result<(&TaskCluster, f32)> {
130 if self.clusters.is_empty() {
131 return Err(anyhow!("No clusters available"));
132 }
133
134 let mut best_match = None;
135 let mut best_similarity = f32::NEG_INFINITY;
136
137 for cluster in &self.clusters {
138 let similarity = cosine_similarity(task_embedding, &cluster.embedding);
139
140 let boosted_similarity = if let Some(seal) = seal_result {
142 if seal.quality_score > 0.7 {
143 similarity * 1.1 } else {
145 similarity
146 }
147 } else {
148 similarity
149 };
150
151 if boosted_similarity > best_similarity {
152 best_similarity = boosted_similarity;
153 best_match = Some(cluster);
154 }
155 }
156
157 let cluster = best_match.ok_or_else(|| anyhow!("No matching cluster found"))?;
158 Ok((cluster, best_similarity))
159 }
160
/// Clusters the given tasks with k-means and replaces the manager's clusters
/// with the result.
///
/// `task_embeddings` holds one row per task and `task_descriptions` one entry
/// per row. The number of clusters is chosen by `find_optimal_k` between
/// `min_clusters` and `max_clusters`.
///
/// Returns the cluster index assigned to each task, in row order.
///
/// # Errors
///
/// Fails when the row count differs from the number of descriptions, when
/// there are fewer tasks than `min_clusters`, or when k-means fitting fails.
#[cfg(feature = "prompting")]
pub fn build_clusters_from_embeddings(
    &mut self,
    task_embeddings: Array2<f32>,
    task_descriptions: Vec<String>,
    min_clusters: usize,
    max_clusters: usize,
) -> Result<Vec<usize>> {
    let task_count = task_embeddings.nrows();
    let description_count = task_descriptions.len();

    if task_count != description_count {
        return Err(anyhow!(
            "Embeddings and descriptions length mismatch: {} vs {}",
            task_count,
            description_count
        ));
    }
    if task_count < min_clusters {
        return Err(anyhow!(
            "Not enough tasks ({}) for minimum clusters ({})",
            task_count,
            min_clusters
        ));
    }

    let chosen_k = self.find_optimal_k(&task_embeddings, min_clusters, max_clusters)?;
    let labels = self.perform_kmeans(&task_embeddings, chosen_k)?;
    self.build_cluster_objects(&task_embeddings, &task_descriptions, &labels, chosen_k)?;

    Ok(labels)
}
202
/// Picks the cluster count with the highest silhouette score.
///
/// Candidates run from `min_k` up to `max_k`, capped at half the number of
/// rows. Each candidate is clustered with `perform_kmeans` and scored with
/// `compute_silhouette_score`; the first candidate with the strictly highest
/// score wins. If the capped upper bound is below `min_k`, no candidate is
/// evaluated and `min_k` is returned unchanged.
///
/// # Errors
///
/// Propagates any k-means failure.
#[cfg(feature = "prompting")]
fn find_optimal_k(
    &self,
    embeddings: &Array2<f32>,
    min_k: usize,
    max_k: usize,
) -> Result<usize> {
    let upper_k = (embeddings.nrows() / 2).min(max_k);

    // (best k so far, its score)
    let mut leader = (min_k, f32::NEG_INFINITY);

    for candidate_k in min_k..=upper_k {
        let labels = self.perform_kmeans(embeddings, candidate_k)?;
        let quality = self.compute_silhouette_score(embeddings, &labels, candidate_k);

        if quality > leader.1 {
            leader = (candidate_k, quality);
        }
    }

    Ok(leader.0)
}
228
/// Fits a k-means model with `k` clusters on `embeddings` (at most 100
/// iterations, tolerance `1e-4`) and returns the predicted cluster index for
/// every row, in row order.
///
/// # Errors
///
/// Returns an error with context "K-means fitting failed" when fitting fails.
#[cfg(feature = "prompting")]
fn perform_kmeans(&self, embeddings: &Array2<f32>, k: usize) -> Result<Vec<usize>> {
    // The dataset wrapper takes ownership, hence the copy of the matrix.
    let training_set = DatasetBase::from(embeddings.clone());

    let fitted = KMeans::params(k)
        .max_n_iterations(100)
        .tolerance(1e-4)
        .fit(&training_set)
        .context("K-means fitting failed")?;

    let labels = fitted.predict(embeddings);
    Ok(labels.iter().copied().collect())
}
244
/// Computes the mean silhouette coefficient of a clustering.
///
/// For every point `i`, `a` is the mean Euclidean distance to the other
/// members of its own cluster and `b` is the smallest mean distance to the
/// members of any other non-empty cluster among `0..k`. The point's
/// silhouette is `(b - a) / max(a, b)`.
///
/// Following the standard definition, a point contributes `0` when it is the
/// only member of its cluster. It also contributes `0` when no other cluster
/// has members, when `a` and `b` are both zero, or when a distance is not
/// finite. The result is the mean over all considered points, so singleton
/// clusters lower the score instead of being ignored.
///
/// Row `i` is paired with `assignments[i]`; rows without a corresponding
/// assignment are not considered. Returns `0.0` when there are no points.
#[cfg(feature = "prompting")]
fn compute_silhouette_score(
    &self,
    embeddings: &Array2<f32>,
    assignments: &[usize],
    k: usize,
) -> f32 {
    let n = embeddings.nrows();

    // Copy each row out once; the pairwise loop below would otherwise
    // allocate two vectors for every pair of points.
    let points: Vec<(usize, Vec<f32>)> = assignments
        .iter()
        .take(n)
        .enumerate()
        .map(|(i, &label)| (label, embeddings.row(i).to_vec()))
        .collect();
    if points.is_empty() {
        return 0.0;
    }

    // Per-cluster (distance sum, member count) scratch space, reset per point.
    let mut other_clusters: Vec<(f32, usize)> = vec![(0.0, 0); k];
    let mut silhouette_sum = 0.0_f32;

    for (i, (label_i, row_i)) in points.iter().enumerate() {
        for slot in other_clusters.iter_mut() {
            *slot = (0.0, 0);
        }
        let mut own_sum = 0.0_f32;
        let mut own_count = 0_usize;

        // One pass over all other points, bucketing distances by cluster.
        for (j, (label_j, row_j)) in points.iter().enumerate() {
            if j == i {
                continue;
            }
            let distance = euclidean_distance(row_i, row_j);
            if label_j == label_i {
                own_sum += distance;
                own_count += 1;
            } else if let Some((sum, count)) = other_clusters.get_mut(*label_j) {
                // Labels outside 0..k have no bucket and are skipped.
                *sum += distance;
                *count += 1;
            }
        }

        // Nearest other cluster by mean distance; stays infinite when every
        // other cluster is empty.
        let b_i = other_clusters
            .iter()
            .filter(|&&(_, count)| count > 0)
            .map(|&(sum, count)| sum / count as f32)
            .fold(f32::INFINITY, f32::min);

        let s_i = if own_count == 0 || !b_i.is_finite() {
            // Singleton cluster, or nothing to compare against.
            0.0
        } else {
            let a_i = own_sum / own_count as f32;
            let scale = a_i.max(b_i);
            if a_i.is_finite() && scale > 0.0 {
                (b_i - a_i) / scale
            } else {
                0.0
            }
        };

        silhouette_sum += s_i;
    }

    silhouette_sum / points.len() as f32
}
315
/// Replaces the manager's clusters with one [`TaskCluster`] per non-empty
/// cluster index in `0..k`.
///
/// `assignments[i]` is the cluster index of row `i` of `embeddings` and of
/// `descriptions[i]`. Each cluster gets the id `cluster_<index>`, the
/// description `Cluster <index>`, the centroid of its member rows as its
/// embedding, no techniques, and up to five example descriptions taken in row
/// order. Cluster indices with no members produce no cluster; assignments
/// outside `0..k` are ignored.
///
/// Always returns `Ok(())`.
#[cfg(feature = "prompting")]
fn build_cluster_objects(
    &mut self,
    embeddings: &Array2<f32>,
    descriptions: &[String],
    assignments: &[usize],
    k: usize,
) -> Result<()> {
    /// Maximum number of example task descriptions stored per cluster.
    const MAX_EXAMPLES: usize = 5;

    // Group row indices by cluster in a single pass instead of rescanning
    // the assignments once per cluster.
    let mut members: Vec<Vec<usize>> = vec![Vec::new(); k];
    for (row_index, &label) in assignments.iter().enumerate() {
        if let Some(bucket) = members.get_mut(label) {
            bucket.push(row_index);
        }
    }

    let clusters: Vec<TaskCluster> = members
        .iter()
        .enumerate()
        .filter(|(_, member_rows)| !member_rows.is_empty())
        .map(|(cluster_id, member_rows)| {
            let member_embeddings: Vec<Vec<f32>> = member_rows
                .iter()
                .map(|&row_index| embeddings.row(row_index).to_vec())
                .collect();

            // Only the examples that are kept get cloned.
            let example_tasks: Vec<String> = member_rows
                .iter()
                .filter_map(|&row_index| descriptions.get(row_index).cloned())
                .take(MAX_EXAMPLES)
                .collect();

            TaskCluster::new(
                format!("cluster_{}", cluster_id),
                format!("Cluster {}", cluster_id),
                compute_centroid(&member_embeddings),
                Vec::new(),
                example_tasks,
            )
        })
        .collect();

    self.clusters = clusters;
    Ok(())
}
358
359 pub fn cluster_count(&self) -> usize {
361 self.clusters.len()
362 }
363
364 pub fn get_cluster_by_id(&self, id: &str) -> Option<&TaskCluster> {
366 self.clusters.iter().find(|c| c.id == id)
367 }
368
369 pub fn get_cluster_by_id_mut(&mut self, id: &str) -> Option<&mut TaskCluster> {
371 self.clusters.iter_mut().find(|c| c.id == id)
372 }
373}
374
375impl Default for TaskClusterManager {
376 fn default() -> Self {
377 Self::new()
378 }
379}
380
/// Returns the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes empty slices).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Vectors of different dimensionality are treated as unrelated.
    if a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));

    // A zero vector has no direction; avoid dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
397
/// Returns the Euclidean distance between `a` and `b`, or positive infinity
/// when the slices differ in length.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared_sum: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        squared_sum.sqrt()
    } else {
        // Mismatched dimensionality: report the points as infinitely far apart.
        f32::INFINITY
    }
}
410
/// Returns the component-wise mean of `embeddings`.
///
/// The dimensionality of the result is that of the first embedding; an empty
/// input yields an empty vector. Embeddings are expected to share one length.
/// If they do not, a shorter embedding contributes nothing to the components
/// it lacks and the extra components of a longer embedding are ignored, so
/// ragged input never causes an out-of-bounds panic.
fn compute_centroid(embeddings: &[Vec<f32>]) -> Vec<f32> {
    let dim = match embeddings.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };

    let mut centroid = vec![0.0_f32; dim];

    for embedding in embeddings {
        // `zip` stops at the shorter side, which bounds every access.
        for (acc, &val) in centroid.iter_mut().zip(embedding) {
            *acc += val;
        }
    }

    let n = embeddings.len() as f32;
    for val in &mut centroid {
        *val /= n;
    }

    centroid
}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f32 = 1e-6;

    /// Identical vectors have similarity 1, orthogonal vectors similarity 0.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0, 0.0, 0.0];
        let unit_y = vec![0.0, 1.0, 0.0];

        let same = cosine_similarity(&unit_x, &unit_x.clone());
        assert!((same - 1.0).abs() < EPS);

        let orthogonal = cosine_similarity(&unit_x, &unit_y);
        assert!((orthogonal - 0.0).abs() < EPS);
    }

    /// The 3-4-5 right triangle gives a distance of 5.
    #[test]
    fn test_euclidean_distance() {
        let origin = vec![0.0, 0.0];
        let corner = vec![3.0, 4.0];

        let distance = euclidean_distance(&origin, &corner);
        assert!((distance - 5.0).abs() < EPS);
    }

    /// The centroid is the component-wise mean of the inputs.
    #[test]
    fn test_compute_centroid() {
        let embeddings = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];

        let expected = vec![4.0, 5.0, 6.0];
        assert_eq!(compute_centroid(&embeddings), expected);
    }

    /// A new manager is empty and grows by one when a cluster is added.
    #[test]
    fn test_cluster_manager_basic() {
        let mut manager = TaskClusterManager::new();
        assert_eq!(manager.cluster_count(), 0);

        manager.add_cluster(TaskCluster::new(
            "test_cluster".to_string(),
            "Test cluster".to_string(),
            vec![0.1, 0.2, 0.3],
            Vec::new(),
            vec!["task1".to_string()],
        ));
        assert_eq!(manager.cluster_count(), 1);
    }
}