#[cfg(feature = "gpu")]
use cudarc::driver::CudaDevice;
use crate::{ClusteringError, Result};
use ndarray::ArrayView2;
#[cfg(feature = "gpu")]
pub struct GpuKMeans {
n_clusters: usize,
device: CudaDevice,
}
#[cfg(feature = "gpu")]
impl GpuKMeans {
pub fn new(n_clusters: usize) -> Result<Self> {
let device = CudaDevice::new(0)
.map_err(|e| ClusteringError::GpuError(format!("Failed to initialize CUDA: {}", e)))?;
Ok(Self { n_clusters, device })
}
pub fn fit(&self, data: &ArrayView2<f64>) -> Result<GpuClusteringResult> {
unimplemented!("GpuKMeans::fit")
}
}
#[cfg(feature = "gpu")]
pub struct GpuClusteringResult {
pub labels: Vec<usize>,
pub centroids: Vec<Vec<f64>>,
}
#[cfg(not(feature = "gpu"))]
pub struct GpuKMeans;
#[cfg(not(feature = "gpu"))]
impl GpuKMeans {
pub fn new(_n_clusters: usize) -> Result<Self> {
Err(ClusteringError::GpuError(
"GPU support not enabled. Compile with --features gpu".to_string(),
))
}
}