Skip to main content

brainwires_cognition/prompting/
clustering.rs

1//! Task Clustering System with SEAL Integration
2//!
3//! This module implements k-means clustering of tasks by semantic similarity,
4//! enhanced with SEAL's query core extraction for better classification.
5
6use super::techniques::{ComplexityLevel, PromptingTechnique};
7use crate::prompting::seal::SealProcessingResult;
8use anyhow::{Context as _, Result, anyhow};
9#[cfg(feature = "prompting")]
10use linfa::prelude::*;
11#[cfg(feature = "prompting")]
12use linfa_clustering::KMeans;
13#[cfg(feature = "prompting")]
14use ndarray::Array2;
15use serde::{Deserialize, Serialize};
16
17/// A task cluster identified by semantic similarity (SEAL-enhanced)
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TaskCluster {
20    /// Unique cluster identifier.
21    pub id: String,
22    /// LLM-generated semantic description of this cluster.
23    pub description: String,
24    /// Embedding vector of the cluster description.
25    pub embedding: Vec<f32>,
26    /// Prompting techniques mapped to this cluster (typically 3-4).
27    pub techniques: Vec<PromptingTechnique>,
28    /// Example task descriptions belonging to this cluster.
29    pub example_tasks: Vec<String>,
30
31    /// Example query cores from SEAL for tasks in this cluster.
32    pub seal_query_cores: Vec<String>,
33    /// Average SEAL quality score for tasks in this cluster.
34    pub avg_seal_quality: f32,
35    /// Recommended complexity level based on average SEAL quality.
36    pub recommended_complexity: ComplexityLevel,
37}
38
39impl TaskCluster {
40    /// Create a new task cluster
41    pub fn new(
42        id: String,
43        description: String,
44        embedding: Vec<f32>,
45        techniques: Vec<PromptingTechnique>,
46        example_tasks: Vec<String>,
47    ) -> Self {
48        Self {
49            id,
50            description,
51            embedding,
52            techniques,
53            example_tasks,
54            seal_query_cores: Vec::new(),
55            avg_seal_quality: 0.5,
56            recommended_complexity: ComplexityLevel::Moderate,
57        }
58    }
59
60    /// Update SEAL-related metrics
61    pub fn update_seal_metrics(&mut self, query_cores: Vec<String>, avg_quality: f32) {
62        self.seal_query_cores = query_cores;
63        self.avg_seal_quality = avg_quality;
64        self.recommended_complexity = if avg_quality < 0.5 {
65            ComplexityLevel::Simple
66        } else if avg_quality < 0.8 {
67            ComplexityLevel::Moderate
68        } else {
69            ComplexityLevel::Advanced
70        };
71    }
72}
73
74/// Manages task clustering
75pub struct TaskClusterManager {
76    clusters: Vec<TaskCluster>,
77    _embedding_dim: usize,
78}
79
80impl TaskClusterManager {
81    /// Create a new task cluster manager
82    pub fn new() -> Self {
83        Self {
84            clusters: Vec::new(),
85            _embedding_dim: 768, // Default for most embedding models
86        }
87    }
88
89    /// Create with specific embedding dimension
90    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
91        Self {
92            clusters: Vec::new(),
93            _embedding_dim: embedding_dim,
94        }
95    }
96
97    /// Get all clusters
98    pub fn get_clusters(&self) -> &[TaskCluster] {
99        &self.clusters
100    }
101
102    /// Add a cluster
103    pub fn add_cluster(&mut self, cluster: TaskCluster) {
104        self.clusters.push(cluster);
105    }
106
107    /// Set clusters (replaces existing)
108    pub fn set_clusters(&mut self, clusters: Vec<TaskCluster>) {
109        self.clusters = clusters;
110    }
111
112    /// Find task cluster most similar to a task description (SEAL-enhanced)
113    ///
114    /// This is the core classification function that:
115    /// 1. Uses SEAL's resolved query if available (not original query)
116    /// 2. Prefers SEAL's query core for better semantic matching
117    /// 3. Boosts similarity if SEAL quality is high
118    ///
119    /// # Arguments
120    /// * `task_embedding` - Pre-computed embedding of the task
121    /// * `seal_result` - Optional SEAL processing result for enhancement
122    ///
123    /// # Returns
124    /// * Tuple of (cluster reference, similarity score)
125    pub fn find_matching_cluster(
126        &self,
127        task_embedding: &[f32],
128        seal_result: Option<&SealProcessingResult>,
129    ) -> Result<(&TaskCluster, f32)> {
130        if self.clusters.is_empty() {
131            return Err(anyhow!("No clusters available"));
132        }
133
134        let mut best_match = None;
135        let mut best_similarity = f32::NEG_INFINITY;
136
137        for cluster in &self.clusters {
138            let similarity = cosine_similarity(task_embedding, &cluster.embedding);
139
140            // Boost similarity if SEAL quality is high
141            let boosted_similarity = if let Some(seal) = seal_result {
142                if seal.quality_score > 0.7 {
143                    similarity * 1.1 // 10% boost for high-quality SEAL resolutions
144                } else {
145                    similarity
146                }
147            } else {
148                similarity
149            };
150
151            if boosted_similarity > best_similarity {
152                best_similarity = boosted_similarity;
153                best_match = Some(cluster);
154            }
155        }
156
157        let cluster = best_match.ok_or_else(|| anyhow!("No matching cluster found"))?;
158        Ok((cluster, best_similarity))
159    }
160
161    /// Build clusters from a set of task embeddings using k-means (requires prompting feature - linfa)
162    #[cfg(feature = "prompting")]
163    pub fn build_clusters_from_embeddings(
164        &mut self,
165        task_embeddings: Array2<f32>,
166        task_descriptions: Vec<String>,
167        min_clusters: usize,
168        max_clusters: usize,
169    ) -> Result<Vec<usize>> {
170        if task_embeddings.nrows() != task_descriptions.len() {
171            return Err(anyhow!(
172                "Embeddings and descriptions length mismatch: {} vs {}",
173                task_embeddings.nrows(),
174                task_descriptions.len()
175            ));
176        }
177
178        if task_embeddings.nrows() < min_clusters {
179            return Err(anyhow!(
180                "Not enough tasks ({}) for minimum clusters ({})",
181                task_embeddings.nrows(),
182                min_clusters
183            ));
184        }
185
186        // Find optimal K using silhouette scores
187        let optimal_k = self.find_optimal_k(&task_embeddings, min_clusters, max_clusters)?;
188
189        // Perform k-means clustering
190        let assignments = self.perform_kmeans(&task_embeddings, optimal_k)?;
191
192        // Build cluster objects
193        self.build_cluster_objects(
194            &task_embeddings,
195            &task_descriptions,
196            &assignments,
197            optimal_k,
198        )?;
199
200        Ok(assignments)
201    }
202
203    /// Find optimal number of clusters using silhouette scores
204    #[cfg(feature = "prompting")]
205    fn find_optimal_k(
206        &self,
207        embeddings: &Array2<f32>,
208        min_k: usize,
209        max_k: usize,
210    ) -> Result<usize> {
211        let mut best_k = min_k;
212        let mut best_score = f32::NEG_INFINITY;
213
214        let effective_max_k = max_k.min(embeddings.nrows() / 2);
215
216        for k in min_k..=effective_max_k {
217            let assignments = self.perform_kmeans(embeddings, k)?;
218            let score = self.compute_silhouette_score(embeddings, &assignments, k);
219
220            if score > best_score {
221                best_score = score;
222                best_k = k;
223            }
224        }
225
226        Ok(best_k)
227    }
228
229    /// Perform k-means clustering
230    #[cfg(feature = "prompting")]
231    fn perform_kmeans(&self, embeddings: &Array2<f32>, k: usize) -> Result<Vec<usize>> {
232        let dataset = DatasetBase::from(embeddings.clone());
233
234        let model = KMeans::params(k)
235            .max_n_iterations(100)
236            .tolerance(1e-4)
237            .fit(&dataset)
238            .context("K-means fitting failed")?;
239
240        let assignments: Vec<usize> = model.predict(embeddings).iter().copied().collect();
241
242        Ok(assignments)
243    }
244
245    /// Compute silhouette score for clustering quality
246    #[cfg(feature = "prompting")]
247    fn compute_silhouette_score(
248        &self,
249        embeddings: &Array2<f32>,
250        assignments: &[usize],
251        k: usize,
252    ) -> f32 {
253        let n = embeddings.nrows();
254        if n == 0 {
255            return 0.0;
256        }
257
258        let mut silhouette_sum = 0.0;
259        let mut count = 0;
260
261        for i in 0..n {
262            let cluster_i = assignments[i];
263
264            let mut a_i = 0.0;
265            let mut same_cluster_count = 0;
266            for (j, &assignment_j) in assignments.iter().enumerate().take(n) {
267                if i != j && assignment_j == cluster_i {
268                    a_i += euclidean_distance(
269                        &embeddings.row(i).to_vec(),
270                        &embeddings.row(j).to_vec(),
271                    );
272                    same_cluster_count += 1;
273                }
274            }
275            if same_cluster_count > 0 {
276                a_i /= same_cluster_count as f32;
277            }
278
279            let mut b_i = f32::INFINITY;
280            for other_cluster in 0..k {
281                if other_cluster == cluster_i {
282                    continue;
283                }
284
285                let mut dist_sum = 0.0;
286                let mut other_count = 0;
287                for (j, &assignment_j) in assignments.iter().enumerate().take(n) {
288                    if assignment_j == other_cluster {
289                        dist_sum += euclidean_distance(
290                            &embeddings.row(i).to_vec(),
291                            &embeddings.row(j).to_vec(),
292                        );
293                        other_count += 1;
294                    }
295                }
296                if other_count > 0 {
297                    let avg_dist = dist_sum / other_count as f32;
298                    b_i = b_i.min(avg_dist);
299                }
300            }
301
302            if b_i.is_finite() && a_i > 0.0 {
303                let s_i = (b_i - a_i) / a_i.max(b_i);
304                silhouette_sum += s_i;
305                count += 1;
306            }
307        }
308
309        if count > 0 {
310            silhouette_sum / count as f32
311        } else {
312            0.0
313        }
314    }
315
316    /// Build cluster objects from assignments
317    #[cfg(feature = "prompting")]
318    fn build_cluster_objects(
319        &mut self,
320        embeddings: &Array2<f32>,
321        descriptions: &[String],
322        assignments: &[usize],
323        k: usize,
324    ) -> Result<()> {
325        let mut clusters = Vec::new();
326
327        for cluster_id in 0..k {
328            let mut cluster_tasks = Vec::new();
329            let mut cluster_embeddings = Vec::new();
330
331            for (i, &assignment) in assignments.iter().enumerate() {
332                if assignment == cluster_id {
333                    cluster_tasks.push(descriptions[i].clone());
334                    cluster_embeddings.push(embeddings.row(i).to_vec());
335                }
336            }
337
338            if cluster_tasks.is_empty() {
339                continue;
340            }
341
342            let centroid = compute_centroid(&cluster_embeddings);
343
344            let cluster = TaskCluster::new(
345                format!("cluster_{}", cluster_id),
346                format!("Cluster {}", cluster_id),
347                centroid,
348                Vec::new(),
349                cluster_tasks.iter().take(5).cloned().collect(),
350            );
351
352            clusters.push(cluster);
353        }
354
355        self.clusters = clusters;
356        Ok(())
357    }
358
359    /// Get cluster count
360    pub fn cluster_count(&self) -> usize {
361        self.clusters.len()
362    }
363
364    /// Get cluster by ID
365    pub fn get_cluster_by_id(&self, id: &str) -> Option<&TaskCluster> {
366        self.clusters.iter().find(|c| c.id == id)
367    }
368
369    /// Get mutable cluster by ID
370    pub fn get_cluster_by_id_mut(&mut self, id: &str) -> Option<&mut TaskCluster> {
371        self.clusters.iter_mut().find(|c| c.id == id)
372    }
373}
374
375impl Default for TaskClusterManager {
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
/// Cosine similarity of `a` and `b`, i.e. `a·b / (‖a‖·‖b‖)`.
///
/// Returns `0.0` instead of dividing by zero when:
/// - the slices have different lengths, or
/// - either slice has zero magnitude (which includes empty slices).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>();
    dot / (mag_a * mag_b)
}
397
/// Euclidean (L2) distance between `a` and `b`.
///
/// Slices of different lengths are treated as infinitely far apart.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
        squared.sqrt()
    } else {
        f32::INFINITY
    }
}
410
/// Component-wise mean of a set of embeddings.
///
/// The output has the dimension of the first embedding. An empty input yields
/// an empty vector.
///
/// # Panics
/// Panics if a later embedding is longer than the first one.
fn compute_centroid(embeddings: &[Vec<f32>]) -> Vec<f32> {
    let Some(first) = embeddings.first() else {
        return Vec::new();
    };

    let mut sums = vec![0.0_f32; first.len()];
    for row in embeddings {
        row.iter()
            .enumerate()
            .for_each(|(dim, &value)| sums[dim] += value);
    }

    let count = embeddings.len() as f32;
    sums.iter_mut().for_each(|total| *total /= count);
    sums
}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal cluster with the given id and embedding.
    fn cluster(id: &str, embedding: Vec<f32>) -> TaskCluster {
        TaskCluster::new(
            id.to_string(),
            format!("{id} description"),
            embedding,
            Vec::new(),
            Vec::new(),
        )
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Mismatched lengths and zero-magnitude inputs yield 0.0, not NaN.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert!((cosine_similarity(&[1.0, 1.0], &[-1.0, -1.0]) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_length_mismatch() {
        assert_eq!(euclidean_distance(&[1.0], &[1.0, 2.0]), f32::INFINITY);
    }

    #[test]
    fn test_compute_centroid() {
        let embeddings = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let centroid = compute_centroid(&embeddings);
        assert_eq!(centroid, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_compute_centroid_empty() {
        assert!(compute_centroid(&[]).is_empty());
    }

    #[test]
    fn test_cluster_manager_basic() {
        let mut manager = TaskClusterManager::new();
        assert_eq!(manager.cluster_count(), 0);

        let cluster = TaskCluster::new(
            "test_cluster".to_string(),
            "Test cluster".to_string(),
            vec![0.1, 0.2, 0.3],
            Vec::new(),
            vec!["task1".to_string()],
        );

        manager.add_cluster(cluster);
        assert_eq!(manager.cluster_count(), 1);
    }

    #[test]
    fn test_find_matching_cluster_requires_clusters() {
        let manager = TaskClusterManager::new();
        assert!(manager.find_matching_cluster(&[1.0, 0.0], None).is_err());
    }

    #[test]
    fn test_find_matching_cluster_picks_most_similar() {
        let mut manager = TaskClusterManager::with_embedding_dim(2);
        manager.set_clusters(vec![
            cluster("x_axis", vec![1.0, 0.0]),
            cluster("y_axis", vec![0.0, 1.0]),
        ]);

        let (best, score) = manager.find_matching_cluster(&[0.2, 0.9], None).unwrap();
        assert_eq!(best.id, "y_axis");
        let expected = cosine_similarity(&[0.2, 0.9], &[0.0, 1.0]);
        assert!((score - expected).abs() < 1e-6);
    }

    #[test]
    fn test_get_cluster_by_id() {
        let mut manager = TaskClusterManager::default();
        manager.add_cluster(cluster("a", vec![1.0]));

        assert!(manager.get_cluster_by_id("a").is_some());
        assert!(manager.get_cluster_by_id("missing").is_none());

        manager.get_cluster_by_id_mut("a").unwrap().description = "renamed".to_string();
        assert_eq!(manager.get_cluster_by_id("a").unwrap().description, "renamed");
    }

    #[test]
    fn test_update_seal_metrics_thresholds() {
        let mut c = cluster("c", vec![1.0]);
        assert!(matches!(c.recommended_complexity, ComplexityLevel::Moderate));

        c.update_seal_metrics(vec!["core".to_string()], 0.3);
        assert!(matches!(c.recommended_complexity, ComplexityLevel::Simple));
        assert_eq!(c.seal_query_cores, vec!["core".to_string()]);

        c.update_seal_metrics(Vec::new(), 0.5);
        assert!(matches!(c.recommended_complexity, ComplexityLevel::Moderate));

        c.update_seal_metrics(Vec::new(), 0.8);
        assert!(matches!(c.recommended_complexity, ComplexityLevel::Advanced));
        assert!(c.seal_query_cores.is_empty());
    }
}