Skip to main content

brainwires_prompting/
clustering.rs

1//! Task Clustering System with SEAL Integration
2//!
3//! This module implements k-means clustering of tasks by semantic similarity,
4//! enhanced with SEAL's query core extraction for better classification.
5
6use super::techniques::{ComplexityLevel, PromptingTechnique};
7use crate::seal::SealProcessingResult;
8#[cfg(feature = "prompting")]
9use anyhow::Context as _;
10use anyhow::{Result, anyhow};
11#[cfg(feature = "prompting")]
12use linfa::prelude::*;
13#[cfg(feature = "prompting")]
14use linfa_clustering::KMeans;
15#[cfg(feature = "prompting")]
16use ndarray::Array2;
17use serde::{Deserialize, Serialize};
18
/// A task cluster identified by semantic similarity (SEAL-enhanced).
///
/// A cluster groups tasks whose embeddings lie close together and carries the
/// metadata used to choose prompting techniques for new tasks that match it.
/// The type is serializable so clusters can be persisted and reloaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskCluster {
    /// Unique cluster identifier.
    pub id: String,
    /// LLM-generated semantic description of this cluster.
    pub description: String,
    /// Embedding vector representing the cluster. It is compared against task
    /// embeddings with cosine similarity when matching; the k-means builder in
    /// this file stores the centroid of the member embeddings here.
    pub embedding: Vec<f32>,
    /// Prompting techniques mapped to this cluster (typically 3-4).
    pub techniques: Vec<PromptingTechnique>,
    /// Example task descriptions belonging to this cluster (the k-means
    /// builder in this file keeps at most five).
    pub example_tasks: Vec<String>,

    /// Example query cores from SEAL for tasks in this cluster.
    /// Empty until `update_seal_metrics` is called.
    pub seal_query_cores: Vec<String>,
    /// Average SEAL quality score for tasks in this cluster.
    /// `TaskCluster::new` initialises this to `0.5`.
    pub avg_seal_quality: f32,
    /// Recommended complexity level based on average SEAL quality.
    /// `TaskCluster::new` initialises this to `ComplexityLevel::Moderate`.
    pub recommended_complexity: ComplexityLevel,
}
40
41impl TaskCluster {
42    /// Create a new task cluster
43    pub fn new(
44        id: String,
45        description: String,
46        embedding: Vec<f32>,
47        techniques: Vec<PromptingTechnique>,
48        example_tasks: Vec<String>,
49    ) -> Self {
50        Self {
51            id,
52            description,
53            embedding,
54            techniques,
55            example_tasks,
56            seal_query_cores: Vec::new(),
57            avg_seal_quality: 0.5,
58            recommended_complexity: ComplexityLevel::Moderate,
59        }
60    }
61
62    /// Update SEAL-related metrics
63    pub fn update_seal_metrics(&mut self, query_cores: Vec<String>, avg_quality: f32) {
64        self.seal_query_cores = query_cores;
65        self.avg_seal_quality = avg_quality;
66        self.recommended_complexity = if avg_quality < 0.5 {
67            ComplexityLevel::Simple
68        } else if avg_quality < 0.8 {
69            ComplexityLevel::Moderate
70        } else {
71            ComplexityLevel::Advanced
72        };
73    }
74}
75
76/// Manages task clustering
77pub struct TaskClusterManager {
78    clusters: Vec<TaskCluster>,
79    _embedding_dim: usize,
80}
81
82impl TaskClusterManager {
83    /// Create a new task cluster manager
84    pub fn new() -> Self {
85        Self {
86            clusters: Vec::new(),
87            _embedding_dim: 768, // Default for most embedding models
88        }
89    }
90
91    /// Create with specific embedding dimension
92    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
93        Self {
94            clusters: Vec::new(),
95            _embedding_dim: embedding_dim,
96        }
97    }
98
99    /// Get all clusters
100    pub fn get_clusters(&self) -> &[TaskCluster] {
101        &self.clusters
102    }
103
104    /// Add a cluster
105    pub fn add_cluster(&mut self, cluster: TaskCluster) {
106        self.clusters.push(cluster);
107    }
108
109    /// Set clusters (replaces existing)
110    pub fn set_clusters(&mut self, clusters: Vec<TaskCluster>) {
111        self.clusters = clusters;
112    }
113
114    /// Find task cluster most similar to a task description (SEAL-enhanced)
115    ///
116    /// This is the core classification function that:
117    /// 1. Uses SEAL's resolved query if available (not original query)
118    /// 2. Prefers SEAL's query core for better semantic matching
119    /// 3. Boosts similarity if SEAL quality is high
120    ///
121    /// # Arguments
122    /// * `task_embedding` - Pre-computed embedding of the task
123    /// * `seal_result` - Optional SEAL processing result for enhancement
124    ///
125    /// # Returns
126    /// * Tuple of (cluster reference, similarity score)
127    pub fn find_matching_cluster(
128        &self,
129        task_embedding: &[f32],
130        seal_result: Option<&SealProcessingResult>,
131    ) -> Result<(&TaskCluster, f32)> {
132        if self.clusters.is_empty() {
133            return Err(anyhow!("No clusters available"));
134        }
135
136        let mut best_match = None;
137        let mut best_similarity = f32::NEG_INFINITY;
138
139        for cluster in &self.clusters {
140            let similarity = cosine_similarity(task_embedding, &cluster.embedding);
141
142            // Boost similarity if SEAL quality is high
143            let boosted_similarity = if let Some(seal) = seal_result {
144                if seal.quality_score > 0.7 {
145                    similarity * 1.1 // 10% boost for high-quality SEAL resolutions
146                } else {
147                    similarity
148                }
149            } else {
150                similarity
151            };
152
153            if boosted_similarity > best_similarity {
154                best_similarity = boosted_similarity;
155                best_match = Some(cluster);
156            }
157        }
158
159        let cluster = best_match.ok_or_else(|| anyhow!("No matching cluster found"))?;
160        Ok((cluster, best_similarity))
161    }
162
163    /// Build clusters from a set of task embeddings using k-means (requires prompting feature - linfa)
164    #[cfg(feature = "prompting")]
165    pub fn build_clusters_from_embeddings(
166        &mut self,
167        task_embeddings: Array2<f32>,
168        task_descriptions: Vec<String>,
169        min_clusters: usize,
170        max_clusters: usize,
171    ) -> Result<Vec<usize>> {
172        if task_embeddings.nrows() != task_descriptions.len() {
173            return Err(anyhow!(
174                "Embeddings and descriptions length mismatch: {} vs {}",
175                task_embeddings.nrows(),
176                task_descriptions.len()
177            ));
178        }
179
180        if task_embeddings.nrows() < min_clusters {
181            return Err(anyhow!(
182                "Not enough tasks ({}) for minimum clusters ({})",
183                task_embeddings.nrows(),
184                min_clusters
185            ));
186        }
187
188        // Find optimal K using silhouette scores
189        let optimal_k = self.find_optimal_k(&task_embeddings, min_clusters, max_clusters)?;
190
191        // Perform k-means clustering
192        let assignments = self.perform_kmeans(&task_embeddings, optimal_k)?;
193
194        // Build cluster objects
195        self.build_cluster_objects(
196            &task_embeddings,
197            &task_descriptions,
198            &assignments,
199            optimal_k,
200        )?;
201
202        Ok(assignments)
203    }
204
205    /// Find optimal number of clusters using silhouette scores
206    #[cfg(feature = "prompting")]
207    fn find_optimal_k(
208        &self,
209        embeddings: &Array2<f32>,
210        min_k: usize,
211        max_k: usize,
212    ) -> Result<usize> {
213        let mut best_k = min_k;
214        let mut best_score = f32::NEG_INFINITY;
215
216        let effective_max_k = max_k.min(embeddings.nrows() / 2);
217
218        for k in min_k..=effective_max_k {
219            let assignments = self.perform_kmeans(embeddings, k)?;
220            let score = self.compute_silhouette_score(embeddings, &assignments, k);
221
222            if score > best_score {
223                best_score = score;
224                best_k = k;
225            }
226        }
227
228        Ok(best_k)
229    }
230
231    /// Perform k-means clustering
232    #[cfg(feature = "prompting")]
233    fn perform_kmeans(&self, embeddings: &Array2<f32>, k: usize) -> Result<Vec<usize>> {
234        let dataset = DatasetBase::from(embeddings.clone());
235
236        let model = KMeans::params(k)
237            .max_n_iterations(100)
238            .tolerance(1e-4)
239            .fit(&dataset)
240            .context("K-means fitting failed")?;
241
242        let assignments: Vec<usize> = model.predict(embeddings).iter().copied().collect();
243
244        Ok(assignments)
245    }
246
247    /// Compute silhouette score for clustering quality
248    #[cfg(feature = "prompting")]
249    fn compute_silhouette_score(
250        &self,
251        embeddings: &Array2<f32>,
252        assignments: &[usize],
253        k: usize,
254    ) -> f32 {
255        let n = embeddings.nrows();
256        if n == 0 {
257            return 0.0;
258        }
259
260        let mut silhouette_sum = 0.0;
261        let mut count = 0;
262
263        for i in 0..n {
264            let cluster_i = assignments[i];
265
266            let mut a_i = 0.0;
267            let mut same_cluster_count = 0;
268            for (j, &assignment_j) in assignments.iter().enumerate().take(n) {
269                if i != j && assignment_j == cluster_i {
270                    a_i += euclidean_distance(
271                        &embeddings.row(i).to_vec(),
272                        &embeddings.row(j).to_vec(),
273                    );
274                    same_cluster_count += 1;
275                }
276            }
277            if same_cluster_count > 0 {
278                a_i /= same_cluster_count as f32;
279            }
280
281            let mut b_i = f32::INFINITY;
282            for other_cluster in 0..k {
283                if other_cluster == cluster_i {
284                    continue;
285                }
286
287                let mut dist_sum = 0.0;
288                let mut other_count = 0;
289                for (j, &assignment_j) in assignments.iter().enumerate().take(n) {
290                    if assignment_j == other_cluster {
291                        dist_sum += euclidean_distance(
292                            &embeddings.row(i).to_vec(),
293                            &embeddings.row(j).to_vec(),
294                        );
295                        other_count += 1;
296                    }
297                }
298                if other_count > 0 {
299                    let avg_dist = dist_sum / other_count as f32;
300                    b_i = b_i.min(avg_dist);
301                }
302            }
303
304            if b_i.is_finite() && a_i > 0.0 {
305                let s_i = (b_i - a_i) / a_i.max(b_i);
306                silhouette_sum += s_i;
307                count += 1;
308            }
309        }
310
311        if count > 0 {
312            silhouette_sum / count as f32
313        } else {
314            0.0
315        }
316    }
317
318    /// Build cluster objects from assignments
319    #[cfg(feature = "prompting")]
320    fn build_cluster_objects(
321        &mut self,
322        embeddings: &Array2<f32>,
323        descriptions: &[String],
324        assignments: &[usize],
325        k: usize,
326    ) -> Result<()> {
327        let mut clusters = Vec::new();
328
329        for cluster_id in 0..k {
330            let mut cluster_tasks = Vec::new();
331            let mut cluster_embeddings = Vec::new();
332
333            for (i, &assignment) in assignments.iter().enumerate() {
334                if assignment == cluster_id {
335                    cluster_tasks.push(descriptions[i].clone());
336                    cluster_embeddings.push(embeddings.row(i).to_vec());
337                }
338            }
339
340            if cluster_tasks.is_empty() {
341                continue;
342            }
343
344            let centroid = compute_centroid(&cluster_embeddings);
345
346            let cluster = TaskCluster::new(
347                format!("cluster_{}", cluster_id),
348                format!("Cluster {}", cluster_id),
349                centroid,
350                Vec::new(),
351                cluster_tasks.iter().take(5).cloned().collect(),
352            );
353
354            clusters.push(cluster);
355        }
356
357        self.clusters = clusters;
358        Ok(())
359    }
360
361    /// Get cluster count
362    pub fn cluster_count(&self) -> usize {
363        self.clusters.len()
364    }
365
366    /// Get cluster by ID
367    pub fn get_cluster_by_id(&self, id: &str) -> Option<&TaskCluster> {
368        self.clusters.iter().find(|c| c.id == id)
369    }
370
371    /// Get mutable cluster by ID
372    pub fn get_cluster_by_id_mut(&mut self, id: &str) -> Option<&mut TaskCluster> {
373        self.clusters.iter_mut().find(|c| c.id == id)
374    }
375}
376
/// `Default` delegates to `TaskClusterManager::new`, giving an empty manager
/// with the default embedding dimension.
impl Default for TaskClusterManager {
    fn default() -> Self {
        Self::new()
    }
}
382
/// Cosine similarity of two vectors: their dot product divided by the product
/// of their Euclidean norms.
///
/// Returns `0.0` when the vectors differ in length or when either has zero
/// norm (which includes empty input), since no angle is defined in those cases.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    /// Euclidean (L2) norm of a vector.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    if a.len() != b.len() {
        return 0.0;
    }

    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
399
/// Straight-line (Euclidean) distance between two vectors.
///
/// Vectors of different lengths are treated as infinitely far apart and
/// yield `f32::INFINITY`.
#[allow(dead_code)]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared_sum: f32 = a
            .iter()
            .zip(b)
            .map(|(x, y)| (x - y).powi(2))
            .sum();
        squared_sum.sqrt()
    } else {
        f32::INFINITY
    }
}
413
/// Compute the centroid (component-wise mean) of a set of embeddings.
///
/// The first embedding fixes the dimension of the result. Every embedding is
/// expected to have that dimension; if one does not, components beyond the
/// result's dimension are ignored and missing components count as zero, so
/// ragged input never panics.
///
/// Returns an empty vector when `embeddings` is empty.
#[allow(dead_code)]
fn compute_centroid(embeddings: &[Vec<f32>]) -> Vec<f32> {
    let dim = match embeddings.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };

    let mut centroid = vec![0.0_f32; dim];

    for embedding in embeddings {
        // `zip` stops at the shorter side, so a longer row cannot index past
        // the end of `centroid`.
        for (acc, &val) in centroid.iter_mut().zip(embedding) {
            *acc += val;
        }
    }

    let n = embeddings.len() as f32;
    for acc in &mut centroid {
        *acc /= n;
    }

    centroid
}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` is within `TOLERANCE` of `expected`.
    fn is_close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Identical vectors have similarity 1; orthogonal vectors have 0.
    #[test]
    fn test_cosine_similarity() {
        let x_axis: [f32; 3] = [1.0, 0.0, 0.0];
        let y_axis: [f32; 3] = [0.0, 1.0, 0.0];

        assert!(is_close(cosine_similarity(&x_axis, &x_axis), 1.0));
        assert!(is_close(cosine_similarity(&x_axis, &y_axis), 0.0));
    }

    /// The classic 3-4-5 right triangle.
    #[test]
    fn test_euclidean_distance() {
        let origin: [f32; 2] = [0.0, 0.0];
        let point: [f32; 2] = [3.0, 4.0];

        assert!(is_close(euclidean_distance(&origin, &point), 5.0));
    }

    /// The centroid is the component-wise mean of the rows.
    #[test]
    fn test_compute_centroid() {
        let rows: Vec<Vec<f32>> = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];

        assert_eq!(compute_centroid(&rows), vec![4.0, 5.0, 6.0]);
    }

    /// A fresh manager is empty, and adding a cluster raises the count to one.
    #[test]
    fn test_cluster_manager_basic() {
        let mut manager = TaskClusterManager::new();
        assert_eq!(manager.cluster_count(), 0);

        manager.add_cluster(TaskCluster::new(
            "test_cluster".to_string(),
            "Test cluster".to_string(),
            vec![0.1, 0.2, 0.3],
            Vec::new(),
            vec!["task1".to_string()],
        ));

        assert_eq!(manager.cluster_count(), 1);
    }
}
487}