use super::{SklearnClusterer, SklearnEstimator};
use crate::clustering::core::QuantumClusterer;
use crate::error::{MLError, Result};
use crate::simulator_backends::{SimulatorBackend, StatevectorBackend};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::sync::Arc;
/// K-means style estimator that delegates clustering to a [`QuantumClusterer`]
/// and exposes a scikit-learn-like interface (`fit`, `predict`, `get_params`, ...).
pub struct QuantumKMeans {
/// Clusterer created by `fit`; `None` until a fit has succeeded.
clusterer: Option<QuantumClusterer>,
/// Number of clusters requested from the clusterer.
n_clusters: usize,
/// Iteration cap forwarded to the clustering configuration (`new` sets 300).
max_iter: usize,
/// Tolerance forwarded to the clustering configuration (`new` sets 1e-4).
tol: f64,
/// Optional random state forwarded to the clustering configuration.
random_state: Option<u64>,
/// Simulator backend created in `new`.
/// NOTE(review): not read by any method visible in this file — confirm it is still needed.
backend: Arc<dyn SimulatorBackend>,
/// Set to `true` at the end of a successful `fit`.
fitted: bool,
/// Per-cluster feature means computed by `fit`; `None` before fitting.
cluster_centers_: Option<Array2<f64>>,
/// Training labels returned by the clusterer, cast to `i32`; `None` before fitting.
labels_: Option<Array1<i32>>,
}
impl QuantumKMeans {
pub fn new(n_clusters: usize) -> Self {
Self {
clusterer: None,
n_clusters,
max_iter: 300,
tol: 1e-4,
random_state: None,
backend: Arc::new(StatevectorBackend::new(10)),
fitted: false,
cluster_centers_: None,
labels_: None,
}
}
pub fn set_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
pub fn set_tol(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
pub fn set_random_state(mut self, random_state: u64) -> Self {
self.random_state = Some(random_state);
self
}
}
impl SklearnEstimator for QuantumKMeans {
    /// Fits the underlying [`QuantumClusterer`] on `X` and stores the training
    /// labels and per-cluster feature means.
    ///
    /// `_y` is ignored. Labels returned by the clusterer that exceed
    /// `n_clusters - 1` are counted towards the last cluster when the centres are
    /// computed. A cluster that receives no samples keeps an all-zero centre.
    ///
    /// # Errors
    ///
    /// Returns `MLError::InvalidConfiguration` when `n_clusters` is zero, and
    /// propagates any error from `QuantumClusterer::fit_predict`. The estimator's
    /// stored state is only updated when fitting succeeds.
    #[allow(non_snake_case)]
    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        // Guard first: `n_clusters - 1` below would underflow for zero clusters
        // (panic in debug builds, out-of-bounds indexing in release builds).
        if self.n_clusters == 0 {
            return Err(MLError::InvalidConfiguration(
                "n_clusters must be at least 1".to_string(),
            ));
        }

        let config = crate::clustering::config::QuantumClusteringConfig {
            algorithm: crate::clustering::config::ClusteringAlgorithm::QuantumKMeans,
            n_clusters: self.n_clusters,
            max_iterations: self.max_iter,
            tolerance: self.tol,
            // The qubit count is fixed here; it is not exposed as a parameter.
            num_qubits: 4,
            random_state: self.random_state,
        };
        let mut clusterer = QuantumClusterer::new(config);
        let result = clusterer.fit_predict(X)?;

        let n_features = X.ncols();
        let n_clusters = self.n_clusters;
        let last_cluster = n_clusters - 1;

        // Accumulate per-cluster feature sums and member counts.
        let mut centers = Array2::<f64>::zeros((n_clusters, n_features));
        let mut counts = vec![0usize; n_clusters];
        for (i, &label) in result.iter().enumerate() {
            let k = label.min(last_cluster);
            counts[k] += 1;
            for j in 0..n_features {
                centers[[k, j]] += X[[i, j]];
            }
        }

        // Turn sums into means; empty clusters are left at zero.
        for (k, &count) in counts.iter().enumerate() {
            if count > 0 {
                let members = count as f64;
                for j in 0..n_features {
                    centers[[k, j]] /= members;
                }
            }
        }

        // Commit all fitted state together once nothing can fail any more.
        self.labels_ = Some(result.mapv(|x| x as i32));
        self.cluster_centers_ = Some(centers);
        self.clusterer = Some(clusterer);
        self.fitted = true;
        Ok(())
    }

    /// Returns the hyper-parameters as strings keyed by name.
    ///
    /// `random_state` is only present when it has been set.
    fn get_params(&self) -> HashMap<String, String> {
        let mut params = HashMap::new();
        params.insert("n_clusters".to_string(), self.n_clusters.to_string());
        params.insert("max_iter".to_string(), self.max_iter.to_string());
        params.insert("tol".to_string(), self.tol.to_string());
        if let Some(rs) = self.random_state {
            params.insert("random_state".to_string(), rs.to_string());
        }
        params
    }

    /// Updates hyper-parameters from their string representations.
    ///
    /// Recognised keys are `n_clusters`, `max_iter`, `tol` and `random_state`;
    /// any other key is ignored. Parameters are applied in the map's iteration
    /// order, so values applied before a failing one remain in effect.
    ///
    /// # Errors
    ///
    /// Returns `MLError::InvalidConfiguration` when a recognised value cannot be
    /// parsed into the parameter's type.
    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (key, value) in params {
            match key.as_str() {
                "n_clusters" => {
                    self.n_clusters = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
                    })?;
                }
                "max_iter" => {
                    self.max_iter = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
                    })?;
                }
                "tol" => {
                    self.tol = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid tol: {}", value))
                    })?;
                }
                "random_state" => {
                    self.random_state = Some(value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid random_state: {}", value))
                    })?);
                }
                // Unknown keys are deliberately ignored.
                _ => {}
            }
        }
        Ok(())
    }

    /// Reports whether `fit` has completed successfully.
    fn is_fitted(&self) -> bool {
        self.fitted
    }
}
impl SklearnClusterer for QuantumKMeans {
    /// Assigns each row of `X` to a cluster using the fitted clusterer and
    /// returns the labels cast to `i32`.
    ///
    /// # Errors
    ///
    /// Returns `MLError::ModelNotTrained` when the estimator has not been fitted
    /// or holds no clusterer, and propagates errors from the clusterer's
    /// `predict`.
    #[allow(non_snake_case)]
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        if !self.fitted {
            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
        }
        match self.clusterer.as_ref() {
            Some(model) => {
                let assigned = model.predict(X)?;
                Ok(assigned.mapv(|cluster| cluster as i32))
            }
            None => Err(MLError::ModelNotTrained(
                "Clusterer not initialized".to_string(),
            )),
        }
    }

    /// Returns the cluster centres computed by `fit`, or `None` before fitting.
    fn cluster_centers(&self) -> Option<&Array2<f64>> {
        self.cluster_centers_.as_ref()
    }
}
/// Density-based clustering estimator with a scikit-learn-like interface.
pub struct DBSCAN {
/// Neighbourhood radius: two samples are neighbours when their Euclidean
/// distance is `<= eps`.
eps: f64,
/// Minimum number of neighbours (the sample itself is not counted) for a
/// sample to be treated as a core sample.
min_samples: usize,
/// Labels from the most recent fit (`-1` marks noise); `None` before fitting.
labels: Option<Array1<i32>>,
/// Indices of the core samples found by the most recent fit, in ascending order.
core_sample_indices: Vec<usize>,
}
impl DBSCAN {
pub fn new(eps: f64, min_samples: usize) -> Self {
Self {
eps,
min_samples,
labels: None,
core_sample_indices: Vec::new(),
}
}
pub fn eps(mut self, eps: f64) -> Self {
self.eps = eps;
self
}
pub fn min_samples(mut self, min_samples: usize) -> Self {
self.min_samples = min_samples;
self
}
pub fn labels(&self) -> Option<&Array1<i32>> {
self.labels.as_ref()
}
pub fn core_sample_indices(&self) -> &[usize] {
&self.core_sample_indices
}
#[allow(non_snake_case)]
fn compute_distances(&self, X: &Array2<f64>) -> Array2<f64> {
let n = X.nrows();
let mut distances = Array2::zeros((n, n));
for i in 0..n {
for j in i + 1..n {
let mut dist = 0.0;
for k in 0..X.ncols() {
let diff = X[[i, k]] - X[[j, k]];
dist += diff * diff;
}
let dist = dist.sqrt();
distances[[i, j]] = dist;
distances[[j, i]] = dist;
}
}
distances
}
pub fn n_clusters(&self) -> Option<usize> {
self.labels.as_ref().map(|labels| {
let max_label = labels.iter().max().copied().unwrap_or(-1);
if max_label >= 0 {
(max_label + 1) as usize
} else {
0
}
})
}
#[allow(non_snake_case)]
fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
let n = X.nrows();
let distances = self.compute_distances(X);
let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
for i in 0..n {
for j in 0..n {
if i != j && distances[[i, j]] <= self.eps {
neighbors[i].push(j);
}
}
}
self.core_sample_indices.clear();
for (i, n_neighbors) in neighbors.iter().enumerate() {
if n_neighbors.len() >= self.min_samples {
self.core_sample_indices.push(i);
}
}
let mut labels = Array1::from_elem(n, -1_i32); let mut visited = vec![false; n];
let mut cluster_id = 0_i32;
for &core_idx in &self.core_sample_indices {
if visited[core_idx] {
continue;
}
let mut stack = vec![core_idx];
while let Some(idx) = stack.pop() {
if visited[idx] {
continue;
}
visited[idx] = true;
labels[idx] = cluster_id;
if neighbors[idx].len() >= self.min_samples {
for &neighbor in &neighbors[idx] {
if !visited[neighbor] {
stack.push(neighbor);
}
}
}
}
cluster_id += 1;
}
self.labels = Some(labels);
Ok(())
}
}
impl SklearnEstimator for DBSCAN {
    /// Clusters the rows of `X`; `_y` is ignored.
    #[allow(non_snake_case)]
    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        self.fit_internal(X)
    }

    /// Returns `eps` and `min_samples` as strings keyed by name.
    fn get_params(&self) -> HashMap<String, String> {
        vec![
            ("eps".to_string(), self.eps.to_string()),
            ("min_samples".to_string(), self.min_samples.to_string()),
        ]
        .into_iter()
        .collect()
    }

    /// Updates `eps` and/or `min_samples` from their string representations.
    ///
    /// Unrecognised keys are ignored. Returns `MLError::InvalidConfiguration`
    /// when a recognised value fails to parse; parameters applied before the
    /// failure stay in effect.
    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (name, raw) in params {
            match name.as_str() {
                "eps" => match raw.parse::<f64>() {
                    Ok(radius) => self.eps = radius,
                    Err(_) => {
                        return Err(MLError::InvalidConfiguration(format!(
                            "Invalid eps: {}",
                            raw
                        )))
                    }
                },
                "min_samples" => match raw.parse::<usize>() {
                    Ok(threshold) => self.min_samples = threshold,
                    Err(_) => {
                        return Err(MLError::InvalidConfiguration(format!(
                            "Invalid min_samples: {}",
                            raw
                        )))
                    }
                },
                _ => {}
            }
        }
        Ok(())
    }

    /// The estimator counts as fitted once labels have been stored.
    fn is_fitted(&self) -> bool {
        self.labels.is_some()
    }
}
impl SklearnClusterer for DBSCAN {
    /// Returns a copy of the labels computed by the most recent fit.
    ///
    /// The argument is ignored: no new samples are classified. Fails with
    /// `MLError::ModelNotTrained` when the estimator has not been fitted.
    #[allow(non_snake_case)]
    fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
        match &self.labels {
            Some(stored) => Ok(stored.clone()),
            None => Err(MLError::ModelNotTrained("DBSCAN not fitted".to_string())),
        }
    }
}
/// Bottom-up hierarchical clustering estimator with a scikit-learn-like interface.
pub struct AgglomerativeClustering {
/// Target number of clusters: merging stops once this many clusters remain.
n_clusters: usize,
/// Linkage rule name. `"single"` and `"complete"` have dedicated distance
/// updates; every other value (including the default `"ward"`) is handled by
/// the size-weighted average update in `fit_internal`.
linkage: String,
/// Labels from the most recent fit; `None` before fitting.
labels: Option<Array1<i32>>,
}
impl AgglomerativeClustering {
    /// Creates an unfitted estimator targeting `n_clusters` clusters with the
    /// linkage name `"ward"`.
    pub fn new(n_clusters: usize) -> Self {
        Self {
            n_clusters,
            linkage: "ward".to_string(),
            labels: None,
        }
    }

    /// Builder-style setter: returns the estimator with the linkage name replaced.
    pub fn linkage(mut self, linkage: &str) -> Self {
        self.linkage = linkage.to_string();
        self
    }

    /// Number of clusters present in the labels of the most recent fit, or
    /// `None` when not fitted.
    ///
    /// This is derived from the labels rather than the requested `n_clusters`,
    /// so it stays correct when the two differ (for example when there are
    /// fewer samples than requested clusters).
    pub fn get_n_clusters(&self) -> Option<usize> {
        self.labels.as_ref().map(|labels| {
            // Labels are numbered contiguously from zero by `fit_internal`.
            labels
                .iter()
                .copied()
                .max()
                .map_or(0, |highest| (highest + 1) as usize)
        })
    }

    /// Clusters the rows of `X` by repeatedly merging the two closest clusters
    /// until `n_clusters` remain (or no mergeable pair is left).
    ///
    /// Distances start as pairwise Euclidean distances and are updated after
    /// each merge according to `linkage`: `"single"` takes the minimum,
    /// `"complete"` the maximum, and any other value — including the default
    /// `"ward"` — the size-weighted average of the two merged clusters'
    /// distances.
    ///
    /// Labels are numbered from zero in order of each cluster's first sample,
    /// so repeated fits on the same data yield identical labels.
    /// This function always returns `Ok(())`.
    #[allow(non_snake_case)]
    fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
        let n = X.nrows();
        let n_features = X.ncols();

        // Pairwise Euclidean distances with a zero diagonal.
        let mut distances = Array2::from_elem((n, n), f64::INFINITY);
        for i in 0..n {
            for j in i + 1..n {
                let mut squared = 0.0_f64;
                for k in 0..n_features {
                    let diff = X[[i, k]] - X[[j, k]];
                    squared += diff * diff;
                }
                let dist = squared.sqrt();
                distances[[i, j]] = dist;
                distances[[j, i]] = dist;
            }
            distances[[i, i]] = 0.0;
        }

        // Each cluster is identified by the index of its representative sample.
        let mut cluster_assignment: Vec<usize> = (0..n).collect();
        let mut active_clusters = vec![true; n];
        let mut cluster_sizes = vec![1_usize; n];
        let mut num_clusters = n;

        while num_clusters > self.n_clusters {
            // Find the closest pair of active clusters (first minimum wins).
            let mut min_dist = f64::INFINITY;
            let mut closest: Option<(usize, usize)> = None;
            for i in (0..n).filter(|&i| active_clusters[i]) {
                for j in (i + 1..n).filter(|&j| active_clusters[j]) {
                    if distances[[i, j]] < min_dist {
                        min_dist = distances[[i, j]];
                        closest = Some((i, j));
                    }
                }
            }

            // Nothing left to merge (a single cluster remains, or every
            // remaining distance is non-finite): stop instead of merging a
            // cluster with itself.
            let (merge_i, merge_j) = match closest {
                Some(pair) => pair,
                None => break,
            };

            for representative in cluster_assignment.iter_mut() {
                if *representative == merge_j {
                    *representative = merge_i;
                }
            }
            active_clusters[merge_j] = false;
            cluster_sizes[merge_i] += cluster_sizes[merge_j];

            // Sizes used by the average rule: the merged total and its two parts.
            let merged_size = cluster_sizes[merge_i] as f64;
            let size_j = cluster_sizes[merge_j] as f64;
            let size_i = merged_size - size_j;

            for k in 0..n {
                if k == merge_i || !active_clusters[k] {
                    continue;
                }
                let d_ik = distances[[merge_i, k]];
                let d_jk = distances[[merge_j, k]];
                let new_dist = match self.linkage.as_str() {
                    "single" => d_ik.min(d_jk),
                    "complete" => d_ik.max(d_jk),
                    // "average" and every unrecognised name (including "ward").
                    _ => (d_ik * size_i + d_jk * size_j) / merged_size,
                };
                distances[[merge_i, k]] = new_dist;
                distances[[k, merge_i]] = new_dist;
            }
            num_clusters -= 1;
        }

        // Relabel representatives as 0, 1, 2, ... in order of first appearance.
        // (Going through a HashSet here would make the numbering depend on the
        // hasher's random iteration order.)
        let mut label_of_representative: Vec<Option<i32>> = vec![None; n];
        let mut next_label = 0_i32;
        let labels: Vec<i32> = cluster_assignment
            .iter()
            .map(|&representative| {
                *label_of_representative[representative].get_or_insert_with(|| {
                    let assigned = next_label;
                    next_label += 1;
                    assigned
                })
            })
            .collect();

        self.labels = Some(Array1::from_vec(labels));
        Ok(())
    }
}
impl SklearnEstimator for AgglomerativeClustering {
    /// Clusters the rows of `X`; `_y` is ignored.
    #[allow(non_snake_case)]
    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        self.fit_internal(X)
    }

    /// Returns `n_clusters` and `linkage` as strings keyed by name.
    fn get_params(&self) -> HashMap<String, String> {
        vec![
            ("n_clusters".to_string(), self.n_clusters.to_string()),
            ("linkage".to_string(), self.linkage.clone()),
        ]
        .into_iter()
        .collect()
    }

    /// Updates `n_clusters` and/or `linkage` from their string representations.
    ///
    /// The linkage value is stored as given, without validation. Unrecognised
    /// keys are ignored. Returns `MLError::InvalidConfiguration` when
    /// `n_clusters` fails to parse; parameters applied before the failure stay
    /// in effect.
    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (name, raw) in params {
            match name.as_str() {
                "n_clusters" => match raw.parse::<usize>() {
                    Ok(count) => self.n_clusters = count,
                    Err(_) => {
                        return Err(MLError::InvalidConfiguration(format!(
                            "Invalid n_clusters: {}",
                            raw
                        )))
                    }
                },
                "linkage" => self.linkage = raw,
                _ => {}
            }
        }
        Ok(())
    }

    /// The estimator counts as fitted once labels have been stored.
    fn is_fitted(&self) -> bool {
        self.labels.is_some()
    }
}
impl SklearnClusterer for AgglomerativeClustering {
    /// Returns a copy of the labels computed by the most recent fit.
    ///
    /// The argument is ignored: no new samples are classified. Fails with
    /// `MLError::ModelNotTrained` when the estimator has not been fitted.
    #[allow(non_snake_case)]
    fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
        match &self.labels {
            Some(stored) => Ok(stored.clone()),
            None => Err(MLError::ModelNotTrained("Not fitted".to_string())),
        }
    }
}