use crate::episode::Episode;
use crate::pattern::Pattern;
use crate::patterns::clustering::{ClusterCentroid, ClusteringConfig, EpisodeCluster};
use crate::types::TaskContext;
use std::collections::HashMap;
/// Deduplicates [`Pattern`]s and clusters [`Episode`]s with a lightweight
/// k-means-style loop, driven by the thresholds in a [`ClusteringConfig`].
pub struct PatternClusterer {
// Tuning knobs read by the methods below: `deduplication_threshold`,
// `min_confidence`, `num_clusters`, `max_iterations`.
config: ClusteringConfig,
}
impl PatternClusterer {
#[must_use]
pub fn new() -> Self {
Self {
config: ClusteringConfig::default(),
}
}
#[must_use]
pub fn with_config(config: ClusteringConfig) -> Self {
Self { config }
}
#[must_use]
pub fn deduplicate_patterns(&self, patterns: Vec<Pattern>) -> Vec<Pattern> {
if patterns.is_empty() {
return Vec::new();
}
let mut deduplicated: Vec<Pattern> = Vec::new();
for pattern in patterns {
let mut merged = false;
for existing in &mut deduplicated {
let similarity = pattern.similarity_score(existing);
if similarity >= self.config.deduplication_threshold {
existing.merge_with(&pattern);
merged = true;
break;
}
}
if !merged {
deduplicated.push(pattern);
}
}
deduplicated
.into_iter()
.filter(|p| p.confidence() >= self.config.min_confidence)
.collect()
}
#[must_use]
pub fn group_by_similarity_key(&self, patterns: Vec<Pattern>) -> HashMap<String, Vec<Pattern>> {
let mut groups: HashMap<String, Vec<Pattern>> = HashMap::new();
for pattern in patterns {
let key = pattern.similarity_key();
groups.entry(key).or_default().push(pattern);
}
groups
}
#[must_use]
pub fn cluster_episodes(&self, episodes: Vec<Episode>) -> Vec<EpisodeCluster> {
if episodes.is_empty() {
return Vec::new();
}
let k = if self.config.num_clusters > 0 {
self.config.num_clusters
} else {
((episodes.len() as f32 / 2.0).sqrt().ceil() as usize).max(1)
};
let mut clusters = self.initialize_clusters(&episodes, k);
for _iteration in 0..self.config.max_iterations {
let mut changed = false;
let mut new_assignments: Vec<Vec<Episode>> = vec![Vec::new(); k];
for episode in &episodes {
let nearest_cluster = self.find_nearest_cluster(episode, &clusters);
new_assignments[nearest_cluster].push(episode.clone());
}
for (i, cluster) in clusters.iter_mut().enumerate() {
if !new_assignments[i].is_empty() {
let new_centroid = self.calculate_centroid(&new_assignments[i]);
#[allow(clippy::excessive_nesting)]
if !self.centroids_equal(&cluster.centroid, &new_centroid) {
cluster.centroid = new_centroid;
cluster.episodes = new_assignments[i].clone();
changed = true;
}
}
}
if !changed {
break;
}
}
clusters
.into_iter()
.filter(|c| !c.episodes.is_empty())
.collect()
}
#[must_use]
pub fn find_similar_patterns(
&self,
target: &Pattern,
candidates: &[Pattern],
limit: usize,
) -> Vec<(Pattern, f32)> {
let mut similarities: Vec<(Pattern, f32)> = candidates
.iter()
.map(|p| (p.clone(), target.similarity_score(p)))
.filter(|(_, score)| *score > 0.0)
.collect();
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
similarities.into_iter().take(limit).collect()
}
fn initialize_clusters(&self, episodes: &[Episode], k: usize) -> Vec<EpisodeCluster> {
let mut clusters = Vec::new();
let step = episodes.len() / k.max(1);
for i in 0..k {
let idx = (i * step).min(episodes.len() - 1);
let episode = &episodes[idx];
clusters.push(EpisodeCluster {
centroid: ClusterCentroid::from_episode(episode),
episodes: vec![episode.clone()],
});
}
clusters
}
fn find_nearest_cluster(&self, episode: &Episode, clusters: &[EpisodeCluster]) -> usize {
let mut min_distance = f32::MAX;
let mut nearest = 0;
for (i, cluster) in clusters.iter().enumerate() {
let distance = self.episode_distance(episode, &cluster.centroid);
if distance < min_distance {
min_distance = distance;
nearest = i;
}
}
nearest
}
fn episode_distance(&self, episode: &Episode, centroid: &ClusterCentroid) -> f32 {
let context_dist = 1.0 - self.context_distance(&episode.context, ¢roid.context);
let steps_dist =
(episode.steps.len() as f32 - centroid.avg_steps).abs() / centroid.avg_steps.max(1.0);
let outcome_dist = if episode.outcome.is_some() == centroid.has_outcome {
0.0
} else {
1.0
};
context_dist * 0.5 + steps_dist * 0.3 + outcome_dist * 0.2
}
fn context_distance(&self, ctx1: &TaskContext, ctx2: &TaskContext) -> f32 {
let mut distance = 0.0;
if ctx1.domain != ctx2.domain {
distance += 0.4;
}
if ctx1.language != ctx2.language {
distance += 0.3;
}
let common_tags: Vec<_> = ctx1.tags.iter().filter(|t| ctx2.tags.contains(t)).collect();
let total_unique = ctx1
.tags
.iter()
.chain(ctx2.tags.iter())
.collect::<std::collections::HashSet<_>>()
.len();
let tag_similarity = if total_unique > 0 {
common_tags.len() as f32 / total_unique as f32
} else {
1.0
};
distance += (1.0 - tag_similarity) * 0.3;
distance
}
fn calculate_centroid(&self, episodes: &[Episode]) -> ClusterCentroid {
if episodes.is_empty() {
return ClusterCentroid::default();
}
let representative_context = episodes[0].context.clone();
let avg_steps =
episodes.iter().map(|e| e.steps.len()).sum::<usize>() as f32 / episodes.len() as f32;
let outcome_count = episodes.iter().filter(|e| e.outcome.is_some()).count();
let has_outcome = outcome_count > episodes.len() / 2;
ClusterCentroid {
context: representative_context,
avg_steps,
has_outcome,
}
}
fn centroids_equal(&self, c1: &ClusterCentroid, c2: &ClusterCentroid) -> bool {
c1.context.domain == c2.context.domain
&& (c1.avg_steps - c2.avg_steps).abs() < 0.5
&& c1.has_outcome == c2.has_outcome
}
}
impl Default for PatternClusterer {
fn default() -> Self {
Self::new()
}
}