1use super::techniques::{ComplexityLevel, PromptingTechnique};
7use crate::seal::SealProcessingResult;
8#[cfg(feature = "prompting")]
9use anyhow::Context as _;
10use anyhow::{Result, anyhow};
11#[cfg(feature = "prompting")]
12use linfa::prelude::*;
13#[cfg(feature = "prompting")]
14use linfa_clustering::KMeans;
15#[cfg(feature = "prompting")]
16use ndarray::Array2;
17use serde::{Deserialize, Serialize};
18
/// A group of related tasks together with the metadata used to match new
/// tasks against it.
///
/// The struct is serializable, so field order and names are part of its
/// persisted representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskCluster {
    /// Identifier used by the manager's by-id lookups.
    pub id: String,
    /// Human-readable description of the cluster.
    pub description: String,
    /// Embedding vector that incoming task embeddings are compared against.
    pub embedding: Vec<f32>,
    /// Prompting techniques attached to this cluster.
    pub techniques: Vec<PromptingTechnique>,
    /// Sample task descriptions belonging to this cluster.
    pub example_tasks: Vec<String>,

    /// SEAL query cores recorded for this cluster (empty until
    /// `update_seal_metrics` is called).
    pub seal_query_cores: Vec<String>,
    /// Average SEAL quality recorded for this cluster; `TaskCluster::new`
    /// initialises it to `0.5`.
    pub avg_seal_quality: f32,
    /// Complexity level recommended for tasks in this cluster; re-derived from
    /// the average quality whenever `update_seal_metrics` runs.
    pub recommended_complexity: ComplexityLevel,
}
40
41impl TaskCluster {
42 pub fn new(
44 id: String,
45 description: String,
46 embedding: Vec<f32>,
47 techniques: Vec<PromptingTechnique>,
48 example_tasks: Vec<String>,
49 ) -> Self {
50 Self {
51 id,
52 description,
53 embedding,
54 techniques,
55 example_tasks,
56 seal_query_cores: Vec::new(),
57 avg_seal_quality: 0.5,
58 recommended_complexity: ComplexityLevel::Moderate,
59 }
60 }
61
62 pub fn update_seal_metrics(&mut self, query_cores: Vec<String>, avg_quality: f32) {
64 self.seal_query_cores = query_cores;
65 self.avg_seal_quality = avg_quality;
66 self.recommended_complexity = if avg_quality < 0.5 {
67 ComplexityLevel::Simple
68 } else if avg_quality < 0.8 {
69 ComplexityLevel::Moderate
70 } else {
71 ComplexityLevel::Advanced
72 };
73 }
74}
75
/// Holds a collection of [`TaskCluster`]s and answers similarity queries
/// against them.
pub struct TaskClusterManager {
    /// Clusters currently known to the manager, in insertion order.
    clusters: Vec<TaskCluster>,
    /// Embedding dimensionality supplied at construction time.
    /// NOTE(review): not read by any method visible in this file — confirm
    /// whether it is still needed.
    _embedding_dim: usize,
}
81
82impl TaskClusterManager {
83 pub fn new() -> Self {
85 Self {
86 clusters: Vec::new(),
87 _embedding_dim: 768, }
89 }
90
91 pub fn with_embedding_dim(embedding_dim: usize) -> Self {
93 Self {
94 clusters: Vec::new(),
95 _embedding_dim: embedding_dim,
96 }
97 }
98
99 pub fn get_clusters(&self) -> &[TaskCluster] {
101 &self.clusters
102 }
103
104 pub fn add_cluster(&mut self, cluster: TaskCluster) {
106 self.clusters.push(cluster);
107 }
108
109 pub fn set_clusters(&mut self, clusters: Vec<TaskCluster>) {
111 self.clusters = clusters;
112 }
113
114 pub fn find_matching_cluster(
128 &self,
129 task_embedding: &[f32],
130 seal_result: Option<&SealProcessingResult>,
131 ) -> Result<(&TaskCluster, f32)> {
132 if self.clusters.is_empty() {
133 return Err(anyhow!("No clusters available"));
134 }
135
136 let mut best_match = None;
137 let mut best_similarity = f32::NEG_INFINITY;
138
139 for cluster in &self.clusters {
140 let similarity = cosine_similarity(task_embedding, &cluster.embedding);
141
142 let boosted_similarity = if let Some(seal) = seal_result {
144 if seal.quality_score > 0.7 {
145 similarity * 1.1 } else {
147 similarity
148 }
149 } else {
150 similarity
151 };
152
153 if boosted_similarity > best_similarity {
154 best_similarity = boosted_similarity;
155 best_match = Some(cluster);
156 }
157 }
158
159 let cluster = best_match.ok_or_else(|| anyhow!("No matching cluster found"))?;
160 Ok((cluster, best_similarity))
161 }
162
163 #[cfg(feature = "prompting")]
165 pub fn build_clusters_from_embeddings(
166 &mut self,
167 task_embeddings: Array2<f32>,
168 task_descriptions: Vec<String>,
169 min_clusters: usize,
170 max_clusters: usize,
171 ) -> Result<Vec<usize>> {
172 if task_embeddings.nrows() != task_descriptions.len() {
173 return Err(anyhow!(
174 "Embeddings and descriptions length mismatch: {} vs {}",
175 task_embeddings.nrows(),
176 task_descriptions.len()
177 ));
178 }
179
180 if task_embeddings.nrows() < min_clusters {
181 return Err(anyhow!(
182 "Not enough tasks ({}) for minimum clusters ({})",
183 task_embeddings.nrows(),
184 min_clusters
185 ));
186 }
187
188 let optimal_k = self.find_optimal_k(&task_embeddings, min_clusters, max_clusters)?;
190
191 let assignments = self.perform_kmeans(&task_embeddings, optimal_k)?;
193
194 self.build_cluster_objects(
196 &task_embeddings,
197 &task_descriptions,
198 &assignments,
199 optimal_k,
200 )?;
201
202 Ok(assignments)
203 }
204
205 #[cfg(feature = "prompting")]
207 fn find_optimal_k(
208 &self,
209 embeddings: &Array2<f32>,
210 min_k: usize,
211 max_k: usize,
212 ) -> Result<usize> {
213 let mut best_k = min_k;
214 let mut best_score = f32::NEG_INFINITY;
215
216 let effective_max_k = max_k.min(embeddings.nrows() / 2);
217
218 for k in min_k..=effective_max_k {
219 let assignments = self.perform_kmeans(embeddings, k)?;
220 let score = self.compute_silhouette_score(embeddings, &assignments, k);
221
222 if score > best_score {
223 best_score = score;
224 best_k = k;
225 }
226 }
227
228 Ok(best_k)
229 }
230
231 #[cfg(feature = "prompting")]
233 fn perform_kmeans(&self, embeddings: &Array2<f32>, k: usize) -> Result<Vec<usize>> {
234 let dataset = DatasetBase::from(embeddings.clone());
235
236 let model = KMeans::params(k)
237 .max_n_iterations(100)
238 .tolerance(1e-4)
239 .fit(&dataset)
240 .context("K-means fitting failed")?;
241
242 let assignments: Vec<usize> = model.predict(embeddings).iter().copied().collect();
243
244 Ok(assignments)
245 }
246
247 #[cfg(feature = "prompting")]
249 fn compute_silhouette_score(
250 &self,
251 embeddings: &Array2<f32>,
252 assignments: &[usize],
253 k: usize,
254 ) -> f32 {
255 let n = embeddings.nrows();
256 if n == 0 {
257 return 0.0;
258 }
259
260 let mut silhouette_sum = 0.0;
261 let mut count = 0;
262
263 for i in 0..n {
264 let cluster_i = assignments[i];
265
266 let mut a_i = 0.0;
267 let mut same_cluster_count = 0;
268 for (j, &assignment_j) in assignments.iter().enumerate().take(n) {
269 if i != j && assignment_j == cluster_i {
270 a_i += euclidean_distance(
271 &embeddings.row(i).to_vec(),
272 &embeddings.row(j).to_vec(),
273 );
274 same_cluster_count += 1;
275 }
276 }
277 if same_cluster_count > 0 {
278 a_i /= same_cluster_count as f32;
279 }
280
281 let mut b_i = f32::INFINITY;
282 for other_cluster in 0..k {
283 if other_cluster == cluster_i {
284 continue;
285 }
286
287 let mut dist_sum = 0.0;
288 let mut other_count = 0;
289 for (j, &assignment_j) in assignments.iter().enumerate().take(n) {
290 if assignment_j == other_cluster {
291 dist_sum += euclidean_distance(
292 &embeddings.row(i).to_vec(),
293 &embeddings.row(j).to_vec(),
294 );
295 other_count += 1;
296 }
297 }
298 if other_count > 0 {
299 let avg_dist = dist_sum / other_count as f32;
300 b_i = b_i.min(avg_dist);
301 }
302 }
303
304 if b_i.is_finite() && a_i > 0.0 {
305 let s_i = (b_i - a_i) / a_i.max(b_i);
306 silhouette_sum += s_i;
307 count += 1;
308 }
309 }
310
311 if count > 0 {
312 silhouette_sum / count as f32
313 } else {
314 0.0
315 }
316 }
317
318 #[cfg(feature = "prompting")]
320 fn build_cluster_objects(
321 &mut self,
322 embeddings: &Array2<f32>,
323 descriptions: &[String],
324 assignments: &[usize],
325 k: usize,
326 ) -> Result<()> {
327 let mut clusters = Vec::new();
328
329 for cluster_id in 0..k {
330 let mut cluster_tasks = Vec::new();
331 let mut cluster_embeddings = Vec::new();
332
333 for (i, &assignment) in assignments.iter().enumerate() {
334 if assignment == cluster_id {
335 cluster_tasks.push(descriptions[i].clone());
336 cluster_embeddings.push(embeddings.row(i).to_vec());
337 }
338 }
339
340 if cluster_tasks.is_empty() {
341 continue;
342 }
343
344 let centroid = compute_centroid(&cluster_embeddings);
345
346 let cluster = TaskCluster::new(
347 format!("cluster_{}", cluster_id),
348 format!("Cluster {}", cluster_id),
349 centroid,
350 Vec::new(),
351 cluster_tasks.iter().take(5).cloned().collect(),
352 );
353
354 clusters.push(cluster);
355 }
356
357 self.clusters = clusters;
358 Ok(())
359 }
360
361 pub fn cluster_count(&self) -> usize {
363 self.clusters.len()
364 }
365
366 pub fn get_cluster_by_id(&self, id: &str) -> Option<&TaskCluster> {
368 self.clusters.iter().find(|c| c.id == id)
369 }
370
371 pub fn get_cluster_by_id_mut(&mut self, id: &str) -> Option<&mut TaskCluster> {
373 self.clusters.iter_mut().find(|c| c.id == id)
374 }
375}
376
377impl Default for TaskClusterManager {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes empty slices).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Euclidean norm of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    let dot_product: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
    let magnitude_a = magnitude(a);
    let magnitude_b = magnitude(b);

    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        0.0
    } else {
        dot_product / (magnitude_a * magnitude_b)
    }
}
399
/// Computes the Euclidean distance between two vectors.
///
/// Returns `f32::INFINITY` when the slices differ in length.
// Within this file the only callers are `prompting`-gated code and the unit
// tests, so the function is unused in builds without that feature.
#[allow(dead_code)]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared_total: f32 = a
            .iter()
            .zip(b)
            .map(|(lhs, rhs)| (lhs - rhs).powi(2))
            .sum();
        squared_total.sqrt()
    } else {
        f32::INFINITY
    }
}
413
/// Computes the component-wise mean of a set of embeddings.
///
/// The result has the dimensionality of the first embedding. If embeddings
/// differ in length, components beyond that dimensionality are ignored and
/// missing components count as zero, rather than panicking on an
/// out-of-bounds index.
///
/// Returns an empty vector when `embeddings` is empty.
// Within this file the only callers are `prompting`-gated code and the unit
// tests, so the function is unused in builds without that feature.
#[allow(dead_code)]
fn compute_centroid(embeddings: &[Vec<f32>]) -> Vec<f32> {
    let dim = match embeddings.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };

    let mut centroid = vec![0.0_f32; dim];

    for embedding in embeddings {
        // `zip` stops at the shorter side, so ragged input cannot index out of
        // bounds.
        for (acc, &val) in centroid.iter_mut().zip(embedding) {
            *acc += val;
        }
    }

    let n = embeddings.len() as f32;
    for val in &mut centroid {
        *val /= n;
    }

    centroid
}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point results.
    const EPSILON: f32 = 1e-6;

    /// Identical vectors have similarity 1; orthogonal vectors have 0.
    #[test]
    fn test_cosine_similarity() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let unit_y: [f32; 3] = [0.0, 1.0, 0.0];

        let parallel = cosine_similarity(&unit_x, &unit_x);
        assert!((parallel - 1.0).abs() < EPSILON);

        let orthogonal = cosine_similarity(&unit_x, &unit_y);
        assert!((orthogonal - 0.0).abs() < EPSILON);
    }

    /// A 3-4-5 right triangle gives a distance of exactly 5.
    #[test]
    fn test_euclidean_distance() {
        let origin: [f32; 2] = [0.0, 0.0];
        let corner: [f32; 2] = [3.0, 4.0];

        let distance = euclidean_distance(&origin, &corner);
        assert!((distance - 5.0).abs() < EPSILON);
    }

    /// The centroid is the per-component mean of the rows.
    #[test]
    fn test_compute_centroid() {
        let rows: Vec<Vec<f32>> = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];

        let expected: Vec<f32> = vec![4.0, 5.0, 6.0];
        assert_eq!(compute_centroid(&rows), expected);
    }

    /// A new manager is empty and grows by one per added cluster.
    #[test]
    fn test_cluster_manager_basic() {
        let mut manager = TaskClusterManager::new();
        assert_eq!(manager.cluster_count(), 0);

        manager.add_cluster(TaskCluster::new(
            String::from("test_cluster"),
            String::from("Test cluster"),
            vec![0.1, 0.2, 0.3],
            Vec::new(),
            vec![String::from("task1")],
        ));

        assert_eq!(manager.cluster_count(), 1);
    }
}