use crate::error::{ClusterError, ClusterResult};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
use std::sync::Arc;
use torsh_tensor::Tensor;
/// Tuning knobs for memory-efficient (chunked) clustering operations.
#[derive(Debug, Clone)]
pub struct MemoryEfficientConfig {
    /// Number of samples (rows) processed per chunk.
    pub chunk_size: usize,
    /// Whether chunks may be processed in parallel.
    pub parallel: bool,
    /// Optional memory budget; `None` means no explicit limit.
    /// NOTE(review): presumably megabytes per the `_mb` suffix — not read
    /// anywhere in this file, so confirm semantics against consumers.
    pub memory_limit_mb: Option<usize>,
}
impl Default for MemoryEfficientConfig {
fn default() -> Self {
Self {
chunk_size: 1000,
parallel: true,
memory_limit_mb: None,
}
}
}
/// Processes tensor data in fixed-size row chunks to bound peak memory use.
pub struct ChunkedDataProcessor {
    // Number of rows handed to the callback per chunk.
    chunk_size: usize,
    // Whether `process_parallel` may fan chunks out across threads.
    parallel: bool,
}
impl ChunkedDataProcessor {
pub fn new(chunk_size: usize) -> Self {
Self {
chunk_size,
parallel: true,
}
}
pub fn parallel(mut self, parallel: bool) -> Self {
self.parallel = parallel;
self
}
pub fn process<F>(&self, data: &Tensor, mut f: F) -> ClusterResult<()>
where
F: FnMut(ArrayView2<f32>) -> ClusterResult<()>,
{
let shape = data.shape();
let n_samples = shape.dims()[0];
let n_features = shape.dims()[1];
let data_vec = data.to_vec()?;
let data_array = Array2::from_shape_vec((n_samples, n_features), data_vec)
.map_err(|e| ClusterError::InvalidInput(format!("Shape error: {}", e)))?;
for start_idx in (0..n_samples).step_by(self.chunk_size) {
let end_idx = (start_idx + self.chunk_size).min(n_samples);
let chunk = data_array.slice(s![start_idx..end_idx, ..]);
f(chunk)?;
}
Ok(())
}
pub fn process_parallel<F, R>(&self, data: &Tensor, f: F) -> ClusterResult<Vec<R>>
where
F: Fn(ArrayView2<f32>) -> ClusterResult<R> + Send + Sync,
R: Send,
{
let shape = data.shape();
let n_samples = shape.dims()[0];
let n_features = shape.dims()[1];
let data_vec = data.to_vec()?;
let data_array = Array2::from_shape_vec((n_samples, n_features), data_vec)
.map_err(|e| ClusterError::InvalidInput(format!("Shape error: {}", e)))?;
let data_arc = Arc::new(data_array);
let chunks: Vec<(usize, usize)> = (0..n_samples)
.step_by(self.chunk_size)
.map(|start| {
let end = (start + self.chunk_size).min(n_samples);
(start, end)
})
.collect();
if !self.parallel || chunks.len() <= 1 {
let results: Result<Vec<R>, ClusterError> = chunks
.iter()
.map(|(start, end)| {
let chunk = data_arc.slice(s![*start..*end, ..]);
f(chunk)
})
.collect();
return results;
}
let results: Result<Vec<R>, ClusterError> = chunks
.into_par_iter()
.map(|(start, end)| {
let chunk = data_arc.slice(s![start..end, ..]);
f(chunk)
})
.collect();
results
}
pub fn optimal_chunk_size(
n_samples: usize,
n_features: usize,
available_memory_mb: usize,
) -> usize {
let bytes_per_sample = n_features * std::mem::size_of::<f32>();
let available_bytes = available_memory_mb * 1024 * 1024;
let safe_bytes = (available_bytes as f64 * 0.8) as usize;
let chunk_size = safe_bytes / bytes_per_sample;
chunk_size.max(10).min(n_samples)
}
}
/// Maintains running cluster means that can be updated batch-by-batch without
/// revisiting earlier samples.
pub struct IncrementalCentroidUpdater {
    // Running mean per cluster, shape (n_clusters, n_features).
    centroids: Array2<f64>,
    // Samples folded into each cluster so far (seed centroids count as 1
    // each — see `initialize`).
    counts: Array1<usize>,
    // Total samples seen across all clusters.
    n_samples: usize,
}
impl IncrementalCentroidUpdater {
    /// Creates an updater with zeroed centroids and zero counts.
    pub fn new(n_clusters: usize, n_features: usize) -> Self {
        Self {
            centroids: Array2::zeros((n_clusters, n_features)),
            counts: Array1::zeros(n_clusters),
            n_samples: 0,
        }
    }

    /// Seeds the centroids with starting positions.
    ///
    /// Each seed centroid is treated as one already-seen sample (its count is
    /// set to 1), so the first real sample averages with the seed rather than
    /// replacing it.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `initial_centroids` does not match the
    /// `(n_clusters, n_features)` shape given at construction.
    pub fn initialize(&mut self, initial_centroids: ArrayView2<f64>) -> ClusterResult<()> {
        let (n_clusters, n_features) = initial_centroids.dim();
        if (n_clusters, n_features) != self.centroids.dim() {
            return Err(ClusterError::InvalidInput(format!(
                "Expected {} clusters and {} features, got {} and {}",
                self.centroids.nrows(),
                self.centroids.ncols(),
                n_clusters,
                n_features
            )));
        }
        self.centroids.assign(&initial_centroids);
        self.counts.fill(1);
        self.n_samples = n_clusters;
        Ok(())
    }

    /// Folds a batch of labeled samples into the running centroid means.
    ///
    /// Uses the incremental mean update `c += (x - c) / (count + 1)` per
    /// feature, which keeps a running mean without accumulating large sums.
    ///
    /// # Errors
    /// Returns `InvalidInput` if the sample and label counts differ, if the
    /// sample feature count does not match the centroids, or if a label is
    /// out of range.
    pub fn update_batch(
        &mut self,
        samples: ArrayView2<f64>,
        labels: &[usize],
    ) -> ClusterResult<()> {
        if samples.nrows() != labels.len() {
            return Err(ClusterError::InvalidInput(format!(
                "Sample count {} doesn't match label count {}",
                samples.nrows(),
                labels.len()
            )));
        }
        // Previously a feature-count mismatch either panicked (too many
        // features, via `centroid[i]`) or silently updated only a prefix
        // (too few); reject it explicitly instead.
        if samples.ncols() != self.centroids.ncols() {
            return Err(ClusterError::InvalidInput(format!(
                "Sample feature count {} doesn't match centroid feature count {}",
                samples.ncols(),
                self.centroids.ncols()
            )));
        }
        for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
            if label >= self.centroids.nrows() {
                return Err(ClusterError::InvalidInput(format!(
                    "Label {} exceeds number of clusters {}",
                    label,
                    self.centroids.nrows()
                )));
            }
            let count = self.counts[label];
            let mut centroid = self.centroids.row_mut(label);
            for (i, &value) in sample.iter().enumerate() {
                centroid[i] += (value - centroid[i]) / (count + 1) as f64;
            }
            self.counts[label] += 1;
        }
        self.n_samples += samples.nrows();
        Ok(())
    }

    /// Current centroid matrix, shape (n_clusters, n_features).
    pub fn centroids(&self) -> ArrayView2<'_, f64> {
        self.centroids.view()
    }

    /// Per-cluster sample counts (seed centroids count as 1 each).
    pub fn counts(&self) -> &Array1<usize> {
        &self.counts
    }

    /// Total number of samples folded in, including seed centroids.
    pub fn n_samples(&self) -> usize {
        self.n_samples
    }
}
/// Rough peak-memory estimate, in MB, for clustering `n_samples` points of
/// `n_features` features into `n_clusters` clusters.
///
/// Accounts for the f32 data matrix, the f64 centroid matrix, the per-sample
/// label vector, and a dense f32 sample-to-cluster distance matrix.
pub fn estimate_memory_usage(n_samples: usize, n_features: usize, n_clusters: usize) -> f64 {
    const F32_BYTES: usize = std::mem::size_of::<f32>();
    const F64_BYTES: usize = std::mem::size_of::<f64>();
    const IDX_BYTES: usize = std::mem::size_of::<usize>();
    let total_bytes = n_samples * n_features * F32_BYTES // input data
        + n_clusters * n_features * F64_BYTES // centroid matrix
        + n_samples * IDX_BYTES // label vector
        + n_samples * n_clusters * F32_BYTES; // distance matrix
    total_bytes as f64 / (1024.0 * 1024.0)
}
/// Recommends a clustering strategy as a human-readable string, based on the
/// estimated memory footprint versus the available budget.
pub fn suggest_clustering_strategy(
    n_samples: usize,
    n_features: usize,
    available_memory_mb: usize,
) -> String {
    // Footprint estimate assumes 10 clusters.
    let estimated_mb = estimate_memory_usage(n_samples, n_features, 10);
    let budget_mb = available_memory_mb as f64;
    if estimated_mb < budget_mb * 0.5 {
        return format!(
            "Standard clustering (estimated {:.2} MB, available {} MB)",
            estimated_mb, available_memory_mb
        );
    }
    if estimated_mb < budget_mb * 0.8 {
        return format!(
            "Use parallel processing with caution (estimated {:.2} MB, available {} MB)",
            estimated_mb, available_memory_mb
        );
    }
    let chunk_size =
        ChunkedDataProcessor::optimal_chunk_size(n_samples, n_features, available_memory_mb);
    format!(
        "Use chunked processing with chunk_size={} (estimated {:.2} MB exceeds available {} MB)",
        chunk_size, estimated_mb, available_memory_mb
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // 10x10 matrix in chunks of 3 rows -> chunk sizes 3, 3, 3, 1.
    #[test]
    fn test_chunked_processor_basic() -> Result<(), Box<dyn std::error::Error>> {
        let data = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[10, 10])?;
        let processor = ChunkedDataProcessor::new(3);
        let mut chunk_count = 0;
        processor.process(&data, |chunk| {
            chunk_count += 1;
            assert!(chunk.nrows() <= 3);
            Ok(())
        })?;
        // ceil(10 / 3) chunks.
        assert_eq!(chunk_count, 4);
        Ok(())
    }

    // Parallel path must yield one result per chunk, covering all 10 rows.
    #[test]
    fn test_chunked_processor_parallel() -> Result<(), Box<dyn std::error::Error>> {
        let data = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[10, 10])?;
        let processor = ChunkedDataProcessor::new(3).parallel(true);
        let results = processor.process_parallel(&data, |chunk| Ok(chunk.nrows()))?;
        assert_eq!(results.len(), 4);
        // Row counts across chunks sum to the total sample count.
        assert_eq!(results.iter().sum::<usize>(), 10);
        Ok(())
    }

    // Chunk size must be positive and never exceed the sample count.
    #[test]
    fn test_optimal_chunk_size() {
        let chunk_size = ChunkedDataProcessor::optimal_chunk_size(1000, 100, 100);
        assert!(chunk_size > 0);
        assert!(chunk_size <= 1000);
    }

    #[test]
    fn test_incremental_centroid_updater() -> Result<(), Box<dyn std::error::Error>> {
        let mut updater = IncrementalCentroidUpdater::new(2, 3);
        let initial = Array2::from_shape_vec((2, 3), vec![0.0, 0.0, 0.0, 5.0, 5.0, 5.0])?;
        updater.initialize(initial.view())?;
        let samples = Array2::from_shape_vec((2, 3), vec![1.0, 1.0, 1.0, 6.0, 6.0, 6.0])?;
        let labels = vec![0, 1];
        updater.update_batch(samples.view(), &labels)?;
        let centroids = updater.centroids();
        // Seed counts as one sample: mean(0.0, 1.0) = 0.5, mean(5.0, 6.0) = 5.5.
        assert_relative_eq!(centroids[[0, 0]], 0.5, epsilon = 1e-6);
        assert_relative_eq!(centroids[[1, 0]], 5.5, epsilon = 1e-6);
        // 2 seed centroids + 2 batch samples.
        assert_eq!(updater.n_samples(), 4);
        Ok(())
    }

    // (1000*100*4 + 10*100*8 + 1000*8 + 1000*10*4) bytes ~= 0.435 MB on a
    // 64-bit target, which falls inside the (0.4, 0.5) window asserted here.
    #[test]
    fn test_memory_estimation() {
        let memory_mb = estimate_memory_usage(1000, 100, 10);
        assert!(memory_mb > 0.4);
        assert!(memory_mb < 0.5);
    }

    #[test]
    fn test_suggest_clustering_strategy() {
        // Small data, ample memory -> standard strategy.
        let strategy = suggest_clustering_strategy(100, 10, 100);
        assert!(strategy.contains("Standard"));
        // Huge data, tiny memory -> chunked strategy.
        let strategy = suggest_clustering_strategy(1_000_000, 100, 10);
        assert!(strategy.contains("chunked"));
    }
}