use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::entities::ClusterId;
/// Selects a clustering algorithm together with the settings that only that
/// algorithm carries.
///
/// The `Default` implementation yields [`ClusteringMethod::HDBSCAN`]. The
/// `Display` implementation renders a short label that includes the per-variant
/// settings.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ClusteringMethod {
/// HDBSCAN; the only variant without settings of its own.
HDBSCAN,
/// K-Means.
KMeans {
/// Rendered as `k=` by `Display`; presumably the number of clusters to
/// fit — confirm against the code that consumes this value.
k: usize,
},
/// Spectral clustering.
Spectral {
/// Rendered as `n=` by `Display`; presumably the number of clusters to
/// produce — confirm against the consumer.
n_clusters: usize,
},
/// Agglomerative clustering.
Agglomerative {
/// Rendered as `n=` by `Display`; presumably the number of clusters to
/// produce — confirm against the consumer.
n_clusters: usize,
/// Linkage method carried with this variant.
linkage: LinkageMethod,
},
}
impl Default for ClusteringMethod {
fn default() -> Self {
Self::HDBSCAN
}
}
impl std::fmt::Display for ClusteringMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ClusteringMethod::HDBSCAN => write!(f, "HDBSCAN"),
ClusteringMethod::KMeans { k } => write!(f, "K-Means (k={})", k),
ClusteringMethod::Spectral { n_clusters } => {
write!(f, "Spectral (n={})", n_clusters)
}
ClusteringMethod::Agglomerative { n_clusters, linkage } => {
write!(f, "Agglomerative (n={}, {:?})", n_clusters, linkage)
}
}
}
}
/// Linkage method carried by [`ClusteringMethod::Agglomerative`].
///
/// The `Default` implementation yields [`LinkageMethod::Ward`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkageMethod {
/// Ward linkage (the default).
Ward,
/// Complete linkage.
Complete,
/// Average linkage.
Average,
/// Single linkage.
Single,
}
impl Default for LinkageMethod {
fn default() -> Self {
Self::Ward
}
}
/// Distance metric selector stored in [`ClusteringParameters::metric`].
///
/// The `Default` implementation yields [`DistanceMetric::Cosine`]; `Display`
/// renders the variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
/// Euclidean distance.
Euclidean,
/// Cosine distance (the default).
Cosine,
/// Manhattan distance.
Manhattan,
/// Presumably the Poincaré (hyperbolic) distance — confirm against the
/// code that evaluates this metric.
Poincare,
}
impl Default for DistanceMetric {
fn default() -> Self {
Self::Cosine
}
}
impl std::fmt::Display for DistanceMetric {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DistanceMetric::Euclidean => write!(f, "Euclidean"),
DistanceMetric::Cosine => write!(f, "Cosine"),
DistanceMetric::Manhattan => write!(f, "Manhattan"),
DistanceMetric::Poincare => write!(f, "Poincare"),
}
}
}
/// Tunable parameters that accompany a clustering method.
///
/// The `Default` implementation uses `min_cluster_size = 5`,
/// `min_samples = 3`, no `epsilon`, the cosine metric, no `max_clusters`, and
/// `allow_single_cluster = false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringParameters {
/// Presumably the smallest group of points accepted as a cluster — confirm
/// against the clustering implementation. Defaults to 5.
pub min_cluster_size: usize,
/// Presumably the neighbourhood size used for density estimation — confirm
/// against the clustering implementation. Defaults to 3.
pub min_samples: usize,
/// Optional epsilon value; `None` by default. How it is interpreted is not
/// visible in this file.
pub epsilon: Option<f32>,
/// Distance metric to use; defaults to [`DistanceMetric::Cosine`].
pub metric: DistanceMetric,
/// Optional cap on the number of clusters; `None` by default. Enforcement
/// is not visible in this file.
pub max_clusters: Option<usize>,
/// Flag that presumably permits a result consisting of a single cluster —
/// confirm against the consumer. Defaults to `false`.
pub allow_single_cluster: bool,
}
impl Default for ClusteringParameters {
fn default() -> Self {
Self {
min_cluster_size: 5,
min_samples: 3,
epsilon: None,
metric: DistanceMetric::Cosine,
max_clusters: None,
allow_single_cluster: false,
}
}
}
impl ClusteringParameters {
    /// Builds parameters with the given `min_cluster_size` and `min_samples`;
    /// every other field takes its default value.
    #[must_use]
    pub fn hdbscan(min_cluster_size: usize, min_samples: usize) -> Self {
        let base = Self::default();
        Self {
            min_cluster_size,
            min_samples,
            ..base
        }
    }

    /// Builds the K-Means preset: `min_cluster_size` and `min_samples` are
    /// both 1 and `allow_single_cluster` is `true`; every other field takes
    /// its default value.
    #[must_use]
    pub fn kmeans() -> Self {
        let base = Self::default();
        Self {
            allow_single_cluster: true,
            min_cluster_size: 1,
            min_samples: 1,
            ..base
        }
    }

    /// Returns `self` with `metric` replaced.
    #[must_use]
    pub fn with_metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }

    /// Returns `self` with `epsilon` set to `Some(epsilon)`.
    #[must_use]
    pub fn with_epsilon(self, epsilon: f32) -> Self {
        Self {
            epsilon: Some(epsilon),
            ..self
        }
    }
}
/// Complete configuration for a clustering run: the method, its parameters,
/// and flags for optional post-processing.
///
/// The `Default` implementation selects HDBSCAN with default parameters,
/// enables prototype and silhouette computation, requests 3 prototypes per
/// cluster, and leaves the random seed unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringConfig {
/// Clustering algorithm to run.
pub method: ClusteringMethod,
/// Parameters passed along with the method.
pub parameters: ClusteringParameters,
/// Whether prototypes should be computed; defaults to `true`.
pub compute_prototypes: bool,
/// Number of prototypes per cluster; defaults to 3.
pub prototypes_per_cluster: usize,
/// Whether a silhouette score should be computed; defaults to `true`.
pub compute_silhouette: bool,
/// Optional random seed; `None` by default. Presumably used to make runs
/// reproducible — confirm against the consumer.
pub random_seed: Option<u64>,
}
impl Default for ClusteringConfig {
fn default() -> Self {
Self {
method: ClusteringMethod::HDBSCAN,
parameters: ClusteringParameters::default(),
compute_prototypes: true,
prototypes_per_cluster: 3,
compute_silhouette: true,
random_seed: None,
}
}
}
impl ClusteringConfig {
    /// Builds an HDBSCAN configuration whose parameters come from
    /// [`ClusteringParameters::hdbscan`]; the remaining fields take their
    /// default values.
    #[must_use]
    pub fn hdbscan(min_cluster_size: usize, min_samples: usize) -> Self {
        let parameters = ClusteringParameters::hdbscan(min_cluster_size, min_samples);
        Self {
            method: ClusteringMethod::HDBSCAN,
            parameters,
            ..Self::default()
        }
    }

    /// Builds a K-Means configuration for the given `k`, using the
    /// [`ClusteringParameters::kmeans`] preset; the remaining fields take
    /// their default values.
    #[must_use]
    pub fn kmeans(k: usize) -> Self {
        let method = ClusteringMethod::KMeans { k };
        let parameters = ClusteringParameters::kmeans();
        Self {
            method,
            parameters,
            ..Self::default()
        }
    }

    /// Returns `self` with `random_seed` set to `Some(seed)`.
    #[must_use]
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            random_seed: Some(seed),
            ..self
        }
    }
}
/// Settings for motif discovery.
///
/// The `Default` implementation uses lengths 2..=10, at least 3 occurrences,
/// a minimum confidence of 0.5, no overlap and no gaps. The `strict` and
/// `relaxed` constructors provide tighter and looser presets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MotifConfig {
/// Minimum motif length; defaults to 2.
pub min_length: usize,
/// Maximum motif length; defaults to 10.
pub max_length: usize,
/// Minimum number of occurrences; defaults to 3.
pub min_occurrences: usize,
/// Minimum confidence; defaults to 0.5. The scale is not established in
/// this file — presumably 0.0..=1.0, confirm against the consumer.
pub min_confidence: f32,
/// Whether overlap is allowed; defaults to `false`.
pub allow_overlap: bool,
/// Maximum gap; defaults to 0. Presumably the number of skipped elements
/// tolerated inside a motif — confirm against the consumer.
pub max_gap: usize,
}
impl Default for MotifConfig {
fn default() -> Self {
Self {
min_length: 2,
max_length: 10,
min_occurrences: 3,
min_confidence: 0.5,
allow_overlap: false,
max_gap: 0,
}
}
}
impl MotifConfig {
#[must_use]
pub fn strict() -> Self {
Self {
min_length: 3,
max_length: 8,
min_occurrences: 5,
min_confidence: 0.7,
allow_overlap: false,
max_gap: 0,
}
}
#[must_use]
pub fn relaxed() -> Self {
Self {
min_length: 2,
max_length: 15,
min_occurrences: 2,
min_confidence: 0.3,
allow_overlap: true,
max_gap: 2,
}
}
#[must_use]
pub fn with_length_range(mut self, min: usize, max: usize) -> Self {
self.min_length = min;
self.max_length = max;
self
}
}
/// Summary statistics describing a sequence of cluster assignments.
///
/// The `Default` implementation sets every numeric field to zero except
/// `stereotypy`, which is 1.0, and leaves `dominant_transition` unset. How the
/// values are computed is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SequenceMetrics {
/// Entropy value; defaults to 0.0.
pub entropy: f32,
/// Normalized entropy value; defaults to 0.0.
pub normalized_entropy: f32,
/// Stereotypy value; defaults to 1.0. Presumably the complement of the
/// normalized entropy — confirm against the code that fills it in.
pub stereotypy: f32,
/// Number of distinct clusters; defaults to 0.
pub unique_clusters: usize,
/// Number of distinct transitions; defaults to 0.
pub unique_transitions: usize,
/// Total number of transitions; defaults to 0.
pub total_transitions: usize,
/// Dominant transition as two cluster identifiers and an `f32`; `None` by
/// default. Presumably `(from, to, probability)` — confirm against the
/// producer.
pub dominant_transition: Option<(ClusterId, ClusterId, f32)>,
/// Repetition rate; defaults to 0.0.
pub repetition_rate: f32,
}
impl Default for SequenceMetrics {
fn default() -> Self {
Self {
entropy: 0.0,
normalized_entropy: 0.0,
stereotypy: 1.0,
unique_clusters: 0,
unique_transitions: 0,
total_transitions: 0,
dominant_transition: None,
repetition_rate: 0.0,
}
}
}
/// Square matrix of transition counts and probabilities between clusters.
///
/// Row `i` / column `j` of both matrices refer to `cluster_ids[i]` and
/// `cluster_ids[j]`; rows are the "from" cluster and columns the "to" cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransitionMatrix {
/// Cluster identifiers in matrix order.
pub cluster_ids: Vec<ClusterId>,
/// Transition probabilities, filled in by `compute_probabilities` from
/// `observations`.
pub probabilities: Vec<Vec<f32>>,
/// Raw transition counts, incremented by `record_transition`.
pub observations: Vec<Vec<u32>>,
/// Lookup from cluster identifier to its row/column index.
///
/// The field is skipped by serde, so a deserialized matrix starts with an
/// empty map; `rebuild_index_map` repopulates it from `cluster_ids`.
#[serde(skip)]
index_map: HashMap<ClusterId, usize>,
}
impl TransitionMatrix {
#[must_use]
pub fn new(cluster_ids: Vec<ClusterId>) -> Self {
let n = cluster_ids.len();
let index_map: HashMap<ClusterId, usize> = cluster_ids
.iter()
.enumerate()
.map(|(i, id)| (*id, i))
.collect();
Self {
cluster_ids,
probabilities: vec![vec![0.0; n]; n],
observations: vec![vec![0; n]; n],
index_map,
}
}
#[must_use]
pub fn size(&self) -> usize {
self.cluster_ids.len()
}
#[must_use]
pub fn index_of(&self, cluster_id: &ClusterId) -> Option<usize> {
self.index_map.get(cluster_id).copied()
}
pub fn record_transition(&mut self, from: &ClusterId, to: &ClusterId) {
if let (Some(i), Some(j)) = (self.index_of(from), self.index_of(to)) {
self.observations[i][j] += 1;
}
}
pub fn compute_probabilities(&mut self) {
for i in 0..self.size() {
let row_sum: u32 = self.observations[i].iter().sum();
if row_sum > 0 {
for j in 0..self.size() {
self.probabilities[i][j] = self.observations[i][j] as f32 / row_sum as f32;
}
}
}
}
#[must_use]
pub fn probability(&self, from: &ClusterId, to: &ClusterId) -> Option<f32> {
match (self.index_of(from), self.index_of(to)) {
(Some(i), Some(j)) => Some(self.probabilities[i][j]),
_ => None,
}
}
#[must_use]
pub fn observation_count(&self, from: &ClusterId, to: &ClusterId) -> Option<u32> {
match (self.index_of(from), self.index_of(to)) {
(Some(i), Some(j)) => Some(self.observations[i][j]),
_ => None,
}
}
#[must_use]
pub fn non_zero_transitions(&self) -> Vec<(ClusterId, ClusterId, f32)> {
let mut transitions = Vec::new();
for (i, from) in self.cluster_ids.iter().enumerate() {
for (j, to) in self.cluster_ids.iter().enumerate() {
let prob = self.probabilities[i][j];
if prob > 0.0 {
transitions.push((*from, *to, prob));
}
}
}
transitions
}
#[must_use]
pub fn stationary_distribution(&self) -> Option<Vec<f32>> {
let n = self.size();
if n == 0 {
return None;
}
let mut dist = vec![1.0 / n as f32; n];
let max_iterations = 1000;
let tolerance = 1e-8;
for _ in 0..max_iterations {
let mut new_dist = vec![0.0; n];
for j in 0..n {
for i in 0..n {
new_dist[j] += dist[i] * self.probabilities[i][j];
}
}
let diff: f32 = dist
.iter()
.zip(new_dist.iter())
.map(|(a, b)| (a - b).abs())
.sum();
dist = new_dist;
if diff < tolerance {
return Some(dist);
}
}
Some(dist)
}
pub fn rebuild_index_map(&mut self) {
self.index_map = self
.cluster_ids
.iter()
.enumerate()
.map(|(i, id)| (*id, i))
.collect();
}
}
/// Outcome of a clustering run, together with the settings that produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
/// Clusters that were found.
pub clusters: Vec<super::entities::Cluster>,
/// Identifiers of embeddings classified as noise, i.e. not counted in any
/// cluster by `noise_rate`.
pub noise: Vec<super::entities::EmbeddingId>,
/// Silhouette score, if one was computed.
pub silhouette_score: Option<f32>,
/// V-measure, if one was computed.
pub v_measure: Option<f32>,
/// Prototypes associated with the result.
pub prototypes: Vec<super::entities::Prototype>,
/// Parameters recorded with the result; presumably the ones the run used —
/// confirm against the producer.
pub parameters: ClusteringParameters,
/// Method recorded with the result; presumably the one the run used —
/// confirm against the producer.
pub method: ClusteringMethod,
}
impl ClusteringResult {
    /// Number of clusters in the result.
    #[must_use]
    pub fn cluster_count(&self) -> usize {
        self.clusters.len()
    }

    /// Fraction of all points that are noise.
    ///
    /// The denominator is the sum of every cluster's `member_count()` plus
    /// the number of noise entries. Returns `0.0` when that total is zero.
    #[must_use]
    pub fn noise_rate(&self) -> f32 {
        let noise_points = self.noise.len();
        let clustered_points: usize = self
            .clusters
            .iter()
            .map(|cluster| cluster.member_count())
            .sum();
        match clustered_points + noise_points {
            0 => 0.0,
            total => noise_points as f32 / total as f32,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed probabilities.
    const TOLERANCE: f32 = 0.001;

    /// The HDBSCAN constructor selects the method and forwards both sizes.
    #[test]
    fn test_clustering_config_creation() {
        let ClusteringConfig {
            method, parameters, ..
        } = ClusteringConfig::hdbscan(10, 5);
        assert!(matches!(method, ClusteringMethod::HDBSCAN));
        assert_eq!(parameters.min_cluster_size, 10);
        assert_eq!(parameters.min_samples, 5);
    }

    /// Recorded transitions are turned into row-normalised probabilities.
    #[test]
    fn test_transition_matrix() {
        let first = ClusterId::new();
        let second = ClusterId::new();
        let third = ClusterId::new();
        let mut matrix = TransitionMatrix::new(vec![first, second, third]);

        let recorded = [
            (first, second),
            (first, second),
            (first, third),
            (second, first),
        ];
        for (from, to) in &recorded {
            matrix.record_transition(from, to);
        }
        matrix.compute_probabilities();

        let expectations = [
            (first, second, 2.0 / 3.0),
            (first, third, 1.0 / 3.0),
            (second, first, 1.0),
        ];
        for (from, to, expected) in &expectations {
            let actual = matrix.probability(from, to).unwrap();
            assert!((actual - expected).abs() < TOLERANCE);
        }
    }

    /// The strict and relaxed presets expose the documented settings.
    #[test]
    fn test_motif_config() {
        let strict = MotifConfig::strict();
        assert_eq!(strict.min_length, 3);
        assert_eq!(strict.min_occurrences, 5);
        assert!(!strict.allow_overlap);

        let relaxed = MotifConfig::relaxed();
        assert!(relaxed.allow_overlap);
        assert_eq!(relaxed.max_gap, 2);
    }

    /// `Display` renders the bare variant name.
    #[test]
    fn test_distance_metric_display() {
        assert_eq!(DistanceMetric::Cosine.to_string(), "Cosine");
        assert_eq!(DistanceMetric::Euclidean.to_string(), "Euclidean");
    }
}