use crate::error::{ClusterError, ClusterResult};
use std::collections::HashMap;
use std::time::{Duration, Instant};
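/// Memory-usage statistics for a clustering run: the most recently reported
/// usage in bytes, the peak observed so far, and the number of updates recorded.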
#[derive(Debug, Clone)]
pub struct MemoryStats {
pub memory_bytes: usize,
pub peak_memory_bytes: usize,
pub allocations: usize,
}
impl MemoryStats {
pub fn new() -> Self {
Self {
memory_bytes: 0,
peak_memory_bytes: 0,
allocations: 0,
}
}
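    /// Records a new memory reading, tracking the running peak and bumping the
    /// allocation counter.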
pub fn update(&mut self, bytes: usize) {
self.memory_bytes = bytes;
self.peak_memory_bytes = self.peak_memory_bytes.max(bytes);
self.allocations += 1;
}
pub fn memory_mb(&self) -> f64 {
self.memory_bytes as f64 / (1024.0 * 1024.0)
}
pub fn peak_memory_mb(&self) -> f64 {
self.peak_memory_bytes as f64 / (1024.0 * 1024.0)
}
}
impl Default for MemoryStats {
fn default() -> Self {
Self::new()
}
}
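/// Collects wall-clock timings, operation counts, and memory statistics for a
/// clustering run. The intended flow is to obtain an [`OperationTimer`] from
/// [`ProfilingContext::start_operation`], call [`OperationTimer::stop`] when the
/// work finishes, and pass the resulting name/duration pair back to
/// [`ProfilingContext::record_timing`]; a summary can then be produced with
/// [`ProfilingContext::report`].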
#[derive(Debug)]
pub struct ProfilingContext {
start_time: Instant,
memory_stats: MemoryStats,
timings: HashMap<String, Duration>,
counts: HashMap<String, usize>,
}
impl ProfilingContext {
pub fn new() -> Self {
Self {
start_time: Instant::now(),
memory_stats: MemoryStats::new(),
timings: HashMap::new(),
counts: HashMap::new(),
}
}
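    /// Starts timing a named operation. The returned [`OperationTimer`] is
    /// independent of the context; pass the result of [`OperationTimer::stop`]
    /// back to [`ProfilingContext::record_timing`] to accumulate the measurement.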
pub fn start_operation(&mut self, name: &str) -> OperationTimer {
OperationTimer::new(name.to_string())
}
pub fn record_timing(&mut self, name: String, duration: Duration) {
*self.timings.entry(name).or_insert(Duration::ZERO) += duration;
}
pub fn increment_count(&mut self, name: &str) {
*self.counts.entry(name.to_string()).or_insert(0) += 1;
}
pub fn record_memory(&mut self, bytes: usize) {
self.memory_stats.update(bytes);
}
pub fn elapsed(&self) -> Duration {
self.start_time.elapsed()
}
pub fn memory_stats(&self) -> &MemoryStats {
&self.memory_stats
}
pub fn get_timing(&self, name: &str) -> Option<Duration> {
self.timings.get(name).copied()
}
pub fn get_count(&self, name: &str) -> Option<usize> {
self.counts.get(name).copied()
}
pub fn all_timings(&self) -> &HashMap<String, Duration> {
&self.timings
}
pub fn all_counts(&self) -> &HashMap<String, usize> {
&self.counts
}
pub fn report(&self) -> ProfilingReport {
ProfilingReport {
total_time: self.elapsed(),
memory_stats: self.memory_stats.clone(),
timings: self.timings.clone(),
counts: self.counts.clone(),
}
}
}
impl Default for ProfilingContext {
fn default() -> Self {
Self::new()
}
}
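/// A simple one-shot timer: created with an operation name, it measures the
/// elapsed wall-clock time until [`OperationTimer::stop`] is called.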
#[derive(Debug)]
pub struct OperationTimer {
name: String,
start: Instant,
}
impl OperationTimer {
pub fn new(name: String) -> Self {
Self {
name,
start: Instant::now(),
}
}
pub fn stop(self) -> (String, Duration) {
(self.name, self.start.elapsed())
}
}
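/// A point-in-time snapshot of a [`ProfilingContext`]: total elapsed time plus
/// copies of the memory statistics, per-operation timings, and counts.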
#[derive(Debug, Clone)]
pub struct ProfilingReport {
pub total_time: Duration,
pub memory_stats: MemoryStats,
pub timings: HashMap<String, Duration>,
pub counts: HashMap<String, usize>,
}
impl ProfilingReport {
pub fn print(&self) {
println!("\n=== Clustering Profiling Report ===");
println!("Total Time: {:.3}s", self.total_time.as_secs_f64());
println!("Peak Memory: {:.2} MB", self.memory_stats.peak_memory_mb());
println!("Current Memory: {:.2} MB", self.memory_stats.memory_mb());
println!("\nOperation Timings:");
for (name, duration) in &self.timings {
println!(" {}: {:.3}s", name, duration.as_secs_f64());
}
println!("\nOperation Counts:");
for (name, count) in &self.counts {
println!(" {}: {}", name, count);
}
println!("===================================\n");
}
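    /// Serializes the report as a JSON string. The document is built by hand
    /// (no serde dependency), with timings in seconds and memory in megabytes.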
pub fn to_json(&self) -> String {
let mut json = String::from("{\n");
json.push_str(&format!(
" \"total_time_s\": {},\n",
self.total_time.as_secs_f64()
));
json.push_str(&format!(
" \"peak_memory_mb\": {},\n",
self.memory_stats.peak_memory_mb()
));
json.push_str(&format!(
" \"current_memory_mb\": {},\n",
self.memory_stats.memory_mb()
));
json.push_str(" \"timings\": {\n");
for (i, (name, duration)) in self.timings.iter().enumerate() {
json.push_str(&format!(
" \"{}\": {}{}",
name,
duration.as_secs_f64(),
if i < self.timings.len() - 1 {
",\n"
} else {
"\n"
}
));
}
json.push_str(" },\n");
json.push_str(" \"counts\": {\n");
for (i, (name, count)) in self.counts.iter().enumerate() {
json.push_str(&format!(
" \"{}\": {}{}",
name,
count,
if i < self.counts.len() - 1 {
",\n"
} else {
"\n"
}
));
}
json.push_str(" }\n");
json.push_str("}\n");
json
}
}
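/// Estimates the working-set size in bytes for a run with the given shape:
/// the `f32` data matrix, `f32` centroids, `i32` labels, and an
/// `n_samples x n_clusters` distance matrix, plus a 20% overhead factor.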
pub fn estimate_memory_usage(n_samples: usize, n_features: usize, n_clusters: usize) -> usize {
let data_size = n_samples * n_features * std::mem::size_of::<f32>();
let centroids_size = n_clusters * n_features * std::mem::size_of::<f32>();
let labels_size = n_samples * std::mem::size_of::<i32>();
let distance_matrix_size = n_samples * n_clusters * std::mem::size_of::<f32>();
let total = data_size + centroids_size + labels_size + distance_matrix_size;
(total as f64 * 1.2) as usize
}
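/// Checks the estimate from [`estimate_memory_usage`] against an 8 GB warning
/// threshold and returns a configuration error suggesting mini-batch or
/// incremental algorithms when it is exceeded.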
pub fn check_memory_feasibility(
n_samples: usize,
n_features: usize,
n_clusters: usize,
) -> ClusterResult<()> {
let estimated = estimate_memory_usage(n_samples, n_features, n_clusters);
const WARNING_THRESHOLD: usize = 8 * 1024 * 1024 * 1024;
if estimated > WARNING_THRESHOLD {
return Err(ClusterError::ConfigError(format!(
"Estimated memory usage ({:.2} GB) exceeds recommended threshold. \
Consider using mini-batch or incremental algorithms.",
estimated as f64 / (1024.0 * 1024.0 * 1024.0)
)));
}
Ok(())
}
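/// Suggests a clustering algorithm based on dataset shape, preferring
/// memory-efficient variants for large sample counts, then checking
/// dimensionality and cluster count.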
pub fn suggest_algorithm(n_samples: usize, n_features: usize, n_clusters: usize) -> &'static str {
if n_samples > 100_000 {
return "MiniBatch K-Means (for memory efficiency)";
}
if n_features > 100 {
return "K-Means or Mini-Batch K-Means (efficient for high dimensions)";
}
if n_clusters > 50 {
return "Elkan's K-Means (optimized for large k)";
}
if n_samples < 1000 {
return "Any algorithm (dataset is small)";
}
"K-Means Lloyd (standard, good all-around performance)"
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_memory_stats() {
let mut stats = MemoryStats::new();
assert_eq!(stats.memory_bytes, 0);
assert_eq!(stats.peak_memory_bytes, 0);
        stats.update(1024 * 1024);
        assert_eq!(stats.memory_mb(), 1.0);
        stats.update(2 * 1024 * 1024);
        assert_eq!(stats.peak_memory_mb(), 2.0);
        stats.update(512 * 1024);
        assert_eq!(stats.peak_memory_mb(), 2.0);
    }
#[test]
fn test_profiling_context() {
let mut ctx = ProfilingContext::new();
ctx.increment_count("e_step");
ctx.increment_count("e_step");
ctx.increment_count("m_step");
assert_eq!(ctx.get_count("e_step"), Some(2));
assert_eq!(ctx.get_count("m_step"), Some(1));
ctx.record_memory(1024 * 1024);
assert_eq!(ctx.memory_stats().memory_mb(), 1.0);
}
#[test]
fn test_operation_timer() {
let timer = OperationTimer::new("test_op".to_string());
std::thread::sleep(std::time::Duration::from_millis(10));
let (name, duration) = timer.stop();
assert_eq!(name, "test_op");
assert!(duration.as_millis() >= 10);
}
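    #[test]
    fn test_start_operation_roundtrip() {
        // Exercises the intended start_operation / stop / record_timing flow
        // using only the APIs defined in this module.
        let mut ctx = ProfilingContext::new();
        let timer = ctx.start_operation("distance_computation");
        std::thread::sleep(std::time::Duration::from_millis(5));
        let (name, duration) = timer.stop();
        ctx.record_timing(name, duration);
        assert!(ctx.get_timing("distance_computation").is_some());
        assert!(ctx.get_timing("distance_computation").unwrap().as_millis() >= 5);
    }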
#[test]
fn test_memory_estimation() {
let memory = estimate_memory_usage(1000, 10, 5);
assert!(memory > 70_000 && memory < 100_000);
}
#[test]
fn test_algorithm_suggestion() {
assert_eq!(
suggest_algorithm(100, 10, 5),
"Any algorithm (dataset is small)"
);
assert_eq!(
suggest_algorithm(200_000, 10, 5),
"MiniBatch K-Means (for memory efficiency)"
);
assert_eq!(
suggest_algorithm(5_000, 200, 5),
"K-Means or Mini-Batch K-Means (efficient for high dimensions)"
);
assert_eq!(
suggest_algorithm(5_000, 10, 100),
"Elkan's K-Means (optimized for large k)"
);
}
#[test]
fn test_profiling_report() {
let mut ctx = ProfilingContext::new();
ctx.increment_count("test");
ctx.record_memory(1024 * 1024);
let report = ctx.report();
assert!(report.total_time.as_secs_f64() >= 0.0);
assert_eq!(report.counts.get("test"), Some(&1));
let json = report.to_json();
assert!(json.contains("total_time_s"));
assert!(json.contains("peak_memory_mb"));
}
}