use crate::core::{ClusterResult, FormicaXError, OHLCV};
/// Common interface shared by the clustering algorithms in this crate.
///
/// An implementor is configured through its associated [`Config`](Self::Config)
/// type, trained with [`fit`](Self::fit), and applied to further data with
/// [`predict`](Self::predict).
pub trait ClusteringAlgorithm {
    /// Algorithm-specific configuration accepted by [`with_config`](Self::with_config).
    type Config;

    /// Creates an instance; only available when the implementor is also [`Default`].
    fn new() -> Self
    where
        Self: Default;

    /// Creates an instance from an explicit configuration.
    fn with_config(config: Self::Config) -> Self;

    /// Fits the model to `data` and returns the resulting clustering.
    ///
    /// # Errors
    /// Returns a [`FormicaXError`] when fitting fails; the failure conditions
    /// are implementation-defined.
    fn fit(&mut self, data: &[OHLCV]) -> Result<ClusterResult, FormicaXError>;

    /// Returns cluster labels for `data` (presumably one label per input bar).
    fn predict(&self, data: &[OHLCV]) -> Result<Vec<usize>, FormicaXError>;

    /// Returns the cluster centers, or `None` when the implementation has none to expose.
    fn get_cluster_centers(&self) -> Option<Vec<Vec<f64>>>;

    /// Validates the current configuration against `data`.
    fn validate_config(&self, data: &[OHLCV]) -> Result<(), FormicaXError>;

    /// Static display name of the algorithm.
    fn algorithm_name(&self) -> &'static str;

    /// Whether [`update`](Self::update) is supported; `false` unless overridden.
    fn supports_incremental(&self) -> bool {
        false
    }

    /// Incrementally updates the model with `_data`.
    ///
    /// # Errors
    /// The default implementation always returns a
    /// `ClusteringError::AlgorithmError` wrapped in
    /// [`FormicaXError::Clustering`]. Implementors that report `true` from
    /// [`supports_incremental`](Self::supports_incremental) should override it.
    fn update(&mut self, _data: &[OHLCV]) -> Result<ClusterResult, FormicaXError> {
        let unsupported = crate::core::ClusteringError::AlgorithmError {
            message: String::from("Incremental updates not supported for this algorithm"),
        };
        Err(FormicaXError::Clustering(unsupported))
    }
}
/// Extension for algorithms that can execute across multiple threads.
pub trait ParallelClusteringAlgorithm: ClusteringAlgorithm {
    /// Reports whether parallel execution is currently enabled.
    fn is_parallel(&self) -> bool;

    /// Enables (`true`) or disables (`false`) parallel execution.
    fn set_parallel(&mut self, parallel: bool);

    /// Reports the configured number of threads.
    fn num_threads(&self) -> usize;

    /// Sets the number of threads; how the value is honored is implementation-defined.
    fn set_num_threads(&mut self, num_threads: usize);
}
/// Extension for algorithms that offer SIMD-accelerated code paths.
pub trait SimdClusteringAlgorithm: ClusteringAlgorithm {
    /// Reports whether SIMD acceleration is currently enabled.
    fn is_simd(&self) -> bool;

    /// Enables (`true`) or disables (`false`) SIMD acceleration.
    fn set_simd(&mut self, simd: bool);

    /// Name of the SIMD instruction set in use, or `None` when not applicable.
    fn simd_instruction_set(&self) -> Option<&'static str>;
}
/// Extension for algorithms that can consume bars from an iterator instead of a slice.
pub trait StreamingClusteringAlgorithm: ClusteringAlgorithm {
    /// Fits the model from a stream of owned bars.
    fn fit_streaming<I: Iterator<Item = OHLCV>>(
        &mut self,
        data_stream: I,
    ) -> Result<ClusterResult, FormicaXError>;

    /// Snapshot of the streaming progress, or `None` when unavailable.
    fn streaming_state(&self) -> Option<StreamingState>;
}
/// Progress snapshot of a streaming clustering run, as returned by
/// `StreamingClusteringAlgorithm::streaming_state`.
#[derive(Debug, Clone)]
pub struct StreamingState {
    /// Number of data points processed so far (presumably cumulative — confirm with implementors).
    pub n_processed: usize,
    /// Number of clusters currently tracked.
    pub n_clusters: usize,
    /// Whether the implementation considers the clustering stable; the
    /// stability criterion is implementation-defined.
    pub is_stable: bool,
    /// Time of the most recent update, as a monotonic `std::time::Instant`.
    pub last_update: std::time::Instant,
}
/// Extension for algorithms that can be run repeatedly and merged into a consensus.
pub trait EnsembleClusteringAlgorithm: ClusteringAlgorithm {
    /// Merges several clustering `results` into one using the chosen `method`.
    fn consensus(
        results: &[ClusterResult],
        method: ConsensusMethod,
    ) -> Result<ClusterResult, FormicaXError>;

    /// Fits `data` `n_runs` times and returns the individual results
    /// (presumably one per run).
    fn fit_ensemble(
        &mut self,
        data: &[OHLCV],
        n_runs: usize,
    ) -> Result<Vec<ClusterResult>, FormicaXError>;
}
/// Strategy selector passed to `EnsembleClusteringAlgorithm::consensus`.
///
/// This enum only names the strategy; how each variant is realized is up to
/// the implementor of `consensus`. It derives `PartialEq`/`Eq`/`Hash` so
/// callers can compare strategies or use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConsensusMethod {
    /// Majority vote across ensemble members.
    MajorityVote,
    /// Weighted vote across ensemble members.
    WeightedVote,
    /// Hierarchical combination of ensemble members.
    Hierarchical,
    /// Graph-based combination of ensemble members.
    GraphBased,
}
/// Extension for algorithms that can assess the quality of their own results.
pub trait ValidatedClusteringAlgorithm: ClusteringAlgorithm {
    /// Decides whether `result` meets `threshold`. No input data is passed, so
    /// the quality measure used is implementation-defined.
    fn validate_quality(
        &self,
        result: &ClusterResult,
        threshold: f64,
    ) -> Result<bool, FormicaXError>;

    /// Computes [`ValidationMetrics`] for `result` using `data`
    /// (presumably the data `result` was fitted on).
    fn calculate_validation_metrics(
        &self,
        data: &[OHLCV],
        result: &ClusterResult,
    ) -> Result<ValidationMetrics, FormicaXError>;
}
/// Cluster-validity indices describing one clustering result.
///
/// The four internal indices are always present. The `Option` fields hold
/// additional scores that are only filled in when available (presumably when
/// ground-truth labels exist — confirm with the code that produces them).
#[derive(Debug, Clone)]
pub struct ValidationMetrics {
    /// Silhouette coefficient; meaningful range is `[-1, 1]`, higher is better.
    pub silhouette_score: f64,
    /// Calinski–Harabasz index; non-negative, higher is better.
    pub calinski_harabasz: f64,
    /// Davies–Bouldin index; non-negative, lower is better.
    pub davies_bouldin: f64,
    /// Dunn index; non-negative (may be `+inf`), higher is better.
    pub dunn_index: f64,
    /// Adjusted Rand index, if computed.
    pub adjusted_rand: Option<f64>,
    /// Normalized mutual information, if computed.
    pub normalized_mutual_info: Option<f64>,
    /// Homogeneity score, if computed.
    pub homogeneity: Option<f64>,
    /// Completeness score, if computed.
    pub completeness: Option<f64>,
    /// V-measure, if computed.
    pub v_measure: Option<f64>,
}

impl ValidationMetrics {
    /// Creates metrics holding only `silhouette_score`.
    ///
    /// Every other internal index starts at `0.0` and every optional score at `None`.
    pub fn new(silhouette_score: f64) -> Self {
        Self {
            silhouette_score,
            calinski_harabasz: 0.0,
            davies_bouldin: 0.0,
            dunn_index: 0.0,
            adjusted_rand: None,
            normalized_mutual_info: None,
            homogeneity: None,
            completeness: None,
            v_measure: None,
        }
    }

    /// Aggregates the four internal indices into one score in `[0, 1]` (higher is better).
    ///
    /// Each index is mapped onto `[0, 1]` and the four are averaged with equal weight:
    /// - silhouette `s` → `(s + 1) / 2`, with `s` clamped to `[-1, 1]`;
    /// - Calinski–Harabasz `ch` → `ch / (ch + 1)`;
    /// - Davies–Bouldin `db` → `1 / (1 + db)`;
    /// - Dunn `d` → `d / (d + 1)`.
    ///
    /// Negative values of the non-negative indices are treated as `0`. An
    /// infinite Calinski–Harabasz or Dunn index maps to `1`; the unguarded
    /// formula would give `inf / inf = NaN` there. A `NaN` input still yields `NaN`.
    pub fn overall_score(&self) -> f64 {
        // Maps a non-negative, higher-is-better index onto [0, 1] via x / (x + 1).
        // `clamp` leaves NaN untouched, so an undefined index stays visible.
        fn bounded_ratio(x: f64) -> f64 {
            let x = x.clamp(0.0, f64::INFINITY);
            if x.is_infinite() {
                1.0
            } else {
                x / (x + 1.0)
            }
        }

        let silhouette_norm = (self.silhouette_score.clamp(-1.0, 1.0) + 1.0) / 2.0;
        let ch_norm = bounded_ratio(self.calinski_harabasz);
        // Lower Davies–Bouldin is better, so invert; +inf correctly maps to 0.
        let db_norm = 1.0 / (1.0 + self.davies_bouldin.clamp(0.0, f64::INFINITY));
        let dunn_norm = bounded_ratio(self.dunn_index);
        (silhouette_norm + ch_norm + db_norm + dunn_norm) / 4.0
    }

    /// Returns `true` when [`overall_score`](Self::overall_score) is at least `threshold`.
    ///
    /// A `NaN` score never compares `>=`, so it is never acceptable.
    pub fn is_acceptable(&self, threshold: f64) -> bool {
        self.overall_score() >= threshold
    }
}
/// Extension for algorithms whose configuration can be tuned automatically.
pub trait OptimizableClusteringAlgorithm: ClusteringAlgorithm {
    /// Description of the candidate parameter values to search over.
    type ParameterSpace;

    /// Evaluates candidates from `parameter_space` on `data` and returns each
    /// candidate configuration paired with its score.
    fn grid_search(
        &self,
        data: &[OHLCV],
        parameter_space: Self::ParameterSpace,
    ) -> Result<Vec<(Self::Config, f64)>, FormicaXError>;

    /// Selects a configuration from `parameter_space` for `data`, using
    /// `cv_folds` cross-validation folds (the selection procedure is
    /// implementation-defined).
    fn optimize_parameters(
        &self,
        data: &[OHLCV],
        parameter_space: Self::ParameterSpace,
        cv_folds: usize,
    ) -> Result<Self::Config, FormicaXError>;
}
/// Extension for algorithms whose state can be saved and restored.
pub trait PersistableClusteringAlgorithm: ClusteringAlgorithm {
    /// Serializes the model into an in-memory byte buffer.
    fn to_bytes(&self) -> Result<Vec<u8>, FormicaXError>;

    /// Rebuilds a model from `bytes` (presumably produced by [`to_bytes`](Self::to_bytes)).
    fn from_bytes(bytes: &[u8]) -> Result<Self, FormicaXError>
    where
        Self: Sized;

    /// Persists the model to the file at `path`.
    fn save_model(&self, path: &str) -> Result<(), FormicaXError>;

    /// Restores a model from the file at `path`.
    fn load_model(path: &str) -> Result<Self, FormicaXError>
    where
        Self: Sized;
}
/// Extension for algorithms that assign clusters one data point at a time.
pub trait RealTimeClusteringAlgorithm: ClusteringAlgorithm {
    /// Replaces the current real-time tuning parameters with `params`.
    fn set_realtime_params(&mut self, params: RealTimeParams);

    /// Processes a single bar and returns its cluster, or `None` when no
    /// assignment is produced for it (conditions are implementation-defined).
    fn process_realtime(&mut self, data_point: &OHLCV) -> Result<Option<usize>, FormicaXError>;

    /// Returns a snapshot of the real-time processing statistics.
    fn realtime_metrics(&self) -> RealTimeMetrics;
}
/// Tuning knobs passed to `RealTimeClusteringAlgorithm::set_realtime_params`.
#[derive(Debug, Clone)]
pub struct RealTimeParams {
    /// Processing-time budget (presumably per data point — confirm with implementors).
    pub max_processing_time: std::time::Duration,
    /// Capacity of the implementation's input buffer.
    pub buffer_size: usize,
    /// Whether adaptive behavior is enabled; its meaning is implementation-defined.
    pub adaptive: bool,
}

impl Default for RealTimeParams {
    /// A 1 ms processing budget, a 1000-entry buffer, and adaptive mode on.
    fn default() -> Self {
        Self {
            max_processing_time: std::time::Duration::from_millis(1),
            buffer_size: 1000,
            adaptive: true,
        }
    }
}

/// Runtime statistics reported by `RealTimeClusteringAlgorithm::realtime_metrics`.
///
/// `Default` is derived: every field starts at zero (`Duration::ZERO`, `0`, `0.0`).
#[derive(Debug, Clone, Default)]
pub struct RealTimeMetrics {
    /// Mean processing time observed so far.
    pub avg_processing_time: std::time::Duration,
    /// Largest processing time observed so far.
    pub max_processing_time: std::time::Duration,
    /// Number of data points processed.
    pub n_processed: usize,
    /// Buffer occupancy (presumably a fraction in `[0, 1]` — confirm with implementors).
    pub buffer_utilization: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// `new` sets only the silhouette score and leaves everything else at its neutral value.
    #[test]
    fn test_validation_metrics_creation() {
        let metrics = ValidationMetrics::new(0.75);
        assert_eq!(metrics.silhouette_score, 0.75);
        assert_eq!(metrics.calinski_harabasz, 0.0);
        assert_eq!(metrics.davies_bouldin, 0.0);
        assert_eq!(metrics.dunn_index, 0.0);
        assert!(metrics.adjusted_rand.is_none());
        assert!(metrics.normalized_mutual_info.is_none());
        assert!(metrics.homogeneity.is_none());
        assert!(metrics.completeness.is_none());
        assert!(metrics.v_measure.is_none());
    }

    /// The overall score is the equal-weight mean of the four normalized indices.
    #[test]
    fn test_validation_metrics_overall_score() {
        let mut metrics = ValidationMetrics::new(0.5);
        metrics.calinski_harabasz = 100.0;
        metrics.davies_bouldin = 0.5;
        metrics.dunn_index = 2.0;
        let score = metrics.overall_score();
        assert!((0.0..=1.0).contains(&score));
        let expected = (0.75 + 100.0 / 101.0 + 1.0 / 1.5 + 2.0 / 3.0) / 4.0;
        assert!((score - expected).abs() < 1e-12);
    }

    #[test]
    fn test_validation_metrics_acceptability() {
        let mut metrics = ValidationMetrics::new(0.8);
        metrics.calinski_harabasz = 100.0;
        metrics.davies_bouldin = 0.5;
        metrics.dunn_index = 2.0;
        assert!(metrics.is_acceptable(0.7));
        assert!(!metrics.is_acceptable(0.9));
    }

    #[test]
    fn test_realtime_params_default() {
        let params = RealTimeParams::default();
        assert_eq!(
            params.max_processing_time,
            std::time::Duration::from_millis(1)
        );
        assert_eq!(params.buffer_size, 1000);
        assert!(params.adaptive);
    }

    #[test]
    fn test_realtime_metrics_default() {
        let metrics = RealTimeMetrics::default();
        assert_eq!(metrics.avg_processing_time, std::time::Duration::ZERO);
        assert_eq!(metrics.max_processing_time, std::time::Duration::ZERO);
        assert_eq!(metrics.n_processed, 0);
        assert_eq!(metrics.buffer_utilization, 0.0);
    }

    proptest! {
        /// For in-range inputs the overall score always lies in [0, 1].
        #[test]
        fn test_validation_metrics_properties(
            silhouette in -1.0..1.0f64,
            calinski_harabasz in 0.0..1000.0f64,
            davies_bouldin in 0.0..10.0f64,
            dunn_index in 0.0..10.0f64
        ) {
            let mut metrics = ValidationMetrics::new(silhouette);
            metrics.calinski_harabasz = calinski_harabasz;
            metrics.davies_bouldin = davies_bouldin;
            metrics.dunn_index = dunn_index;
            let score = metrics.overall_score();
            // prop_assert! reports failures through proptest's shrinking machinery
            // instead of panicking out of the test case.
            prop_assert!((0.0..=1.0).contains(&score));
            prop_assert!(metrics.is_acceptable(0.0));
            prop_assert!(!metrics.is_acceptable(1.1));
        }
    }
}