use crate::error::{ClusteringError, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use super::models::*;
/// Builds a scikit-learn-style hyperparameter grid for the given algorithm.
///
/// Only the parameters relevant to `algorithm` are taken from `param_ranges`;
/// unrelated keys are ignored. Supported algorithms: `"kmeans"`
/// (`n_clusters`, `init`) and `"dbscan"` (`eps`, `min_samples`).
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` for an unsupported algorithm.
pub fn create_sklearn_param_grid(
    algorithm: &str,
    mut param_ranges: HashMap<String, Vec<Value>>,
) -> Result<HashMap<String, Vec<Value>>> {
    // Per-algorithm whitelist of grid keys; extend this table to support
    // more algorithms without duplicating the copy loop.
    let keys: &[&str] = match algorithm {
        "kmeans" => &["n_clusters", "init"],
        "dbscan" => &["eps", "min_samples"],
        _ => {
            return Err(ClusteringError::InvalidInput(format!(
                "Unsupported algorithm for sklearn parameter grid: {}",
                algorithm
            )))
        }
    };
    let mut grid = HashMap::with_capacity(keys.len());
    for &key in keys {
        // `remove` moves the values out of the owned map instead of cloning.
        if let Some(values) = param_ranges.remove(key) {
            grid.insert(key.to_string(), values);
        }
    }
    Ok(grid)
}
/// Decodes a joblib-style (JSON-serialized) byte buffer into a JSON value.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when the bytes are not valid JSON.
pub fn from_joblib_format(data: Vec<u8>) -> Result<Value> {
    match serde_json::from_slice(&data) {
        Ok(value) => Ok(value),
        Err(e) => Err(ClusteringError::InvalidInput(format!(
            "Failed to parse joblib format: {}",
            e
        ))),
    }
}
/// Parses a numpy-style JSON payload into a 2-D `f64` array.
///
/// Accepts two layouts:
/// * a nested JSON array of rows (`[[1.0, 2.0], [3.0, 4.0]]`);
/// * the object form emitted by [`to_numpy_format`]
///   (`{"shape": [rows, cols], "data": [..flat..]}`), so the pair
///   round-trips.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` on malformed JSON, an unsupported
/// layout, or a shape/data length mismatch.
pub fn from_numpy_format(data: Vec<u8>) -> Result<scirs2_core::ndarray::Array2<f64>> {
    let json_data: Value = serde_json::from_slice(&data).map_err(|e| {
        ClusteringError::InvalidInput(format!("Failed to parse numpy format: {}", e))
    })?;
    match json_data {
        // Object form produced by `to_numpy_format`.
        Value::Object(obj) if obj.contains_key("shape") && obj.contains_key("data") => {
            let shape: Vec<usize> = obj["shape"]
                .as_array()
                .map(|s| {
                    s.iter()
                        .filter_map(|v| v.as_u64())
                        .map(|v| v as usize)
                        .collect()
                })
                .unwrap_or_default();
            if shape.len() != 2 {
                return Err(ClusteringError::InvalidInput(
                    "Invalid numpy format".to_string(),
                ));
            }
            let flat: Vec<f64> = obj["data"]
                .as_array()
                .map(|d| d.iter().filter_map(|v| v.as_f64()).collect())
                .unwrap_or_default();
            scirs2_core::ndarray::Array2::from_shape_vec((shape[0], shape[1]), flat).map_err(
                |e| {
                    ClusteringError::InvalidInput(format!(
                        "Failed to create array from numpy data: {}",
                        e
                    ))
                },
            )
        }
        // Nested array-of-rows form (original behavior, unchanged).
        Value::Array(array) => {
            // Column count comes from the first row; ragged or non-numeric
            // input is caught by the shape check in `from_shape_vec`.
            let ncols = match array.first() {
                Some(Value::Array(first_row)) => first_row.len(),
                _ => 0,
            };
            let nrows = array.len();
            let mut flat_data = Vec::with_capacity(nrows * ncols);
            for row in array {
                if let Value::Array(row_values) = row {
                    for val in row_values {
                        if let Value::Number(num) = val {
                            // Numbers that do not fit in f64 fall back to 0.0.
                            flat_data.push(num.as_f64().unwrap_or(0.0));
                        }
                    }
                }
            }
            scirs2_core::ndarray::Array2::from_shape_vec((nrows, ncols), flat_data).map_err(
                |e| {
                    ClusteringError::InvalidInput(format!(
                        "Failed to create array from numpy data: {}",
                        e
                    ))
                },
            )
        }
        _ => Err(ClusteringError::InvalidInput(
            "Invalid numpy format".to_string(),
        )),
    }
}
/// Accepts sklearn-formatted JSON as-is (pass-through).
///
/// Kept as a `Result` for API symmetry with the other `from_*` importers;
/// it currently never fails.
pub fn from_sklearn_format(data: Value) -> Result<Value> {
    Ok(data)
}
pub fn generate_sklearn_model_summary(model_type: &str, model_data: &Value) -> Result<String> {
match model_type {
"KMeans" => {
let summary = serde_json::json!({
"model_type": "KMeans",
"n_clusters": model_data.get("n_clusters").unwrap_or(&Value::Null),
"inertia": model_data.get("inertia_").unwrap_or(&Value::Null),
"n_iter": model_data.get("n_iter_").unwrap_or(&Value::Null)
});
Ok(serde_json::to_string_pretty(&summary)?)
}
"DBSCAN" => {
let summary = serde_json::json!({
"model_type": "DBSCAN",
"eps": model_data.get("eps").unwrap_or(&Value::Null),
"min_samples": model_data.get("min_samples").unwrap_or(&Value::Null)
});
Ok(serde_json::to_string_pretty(&summary)?)
}
_ => Err(ClusteringError::InvalidInput(format!(
"Unsupported sklearn model type: {}",
model_type
))),
}
}
/// Produces an Arrow-style schema (as JSON) describing clustering output:
/// a non-nullable `int32` cluster id plus a list of `f64` feature values.
///
/// The schema is currently identical for every model, so the model argument
/// is unused (hence `_model`); it is kept for API symmetry with the other
/// `to_*` exporters.
pub fn to_arrow_schema<T: ClusteringModel>(_model: &T) -> Result<Value> {
    let schema = serde_json::json!({
        "type": "struct",
        "fields": [
            {
                "name": "cluster_id",
                "type": {
                    "name": "int",
                    "bitWidth": 32
                },
                "nullable": false
            },
            {
                "name": "features",
                "type": {
                    "name": "list",
                    "valueType": {
                        "name": "floatingpoint",
                        "precision": "DOUBLE"
                    }
                },
                "nullable": false
            }
        ]
    });
    Ok(schema)
}
/// Renders a Hugging Face model card (YAML front matter plus markdown) for
/// the given model.
///
/// # Errors
/// Propagates failures from `model.summary()` or JSON serialization.
pub fn to_huggingface_card<T: ClusteringModel>(model: &T) -> Result<String> {
    let summary = model.summary()?;
    // Serialize once; the same pretty JSON appears in both the front matter
    // and the "Model Details" section (the original serialized it twice).
    let summary_json = serde_json::to_string_pretty(&summary)?;
    let card = format!(
        r#"
---
tags:
- clustering
- unsupervised-learning
- scirs2-cluster
library_name: scirs2-cluster
model_summary: {}
---
# Clustering Model
This is a clustering model trained using scirs2-cluster.
## Model Details
{}
## Usage
```rust
use scirs2_cluster::serialization::SerializableModel;
// Load the model
let model = Model::load_from_file("model.json")?;
// Use for prediction
let predictions = model.predict(data.view())?;
```
"#,
        summary_json, summary_json
    );
    Ok(card)
}
/// Serializes the model summary to bytes in a joblib-compatible (JSON) form.
pub fn to_joblib_format<T: ClusteringModel>(model: &T) -> Result<Vec<u8>> {
    let bytes = serde_json::to_vec(&model.summary()?)?;
    Ok(bytes)
}
/// Wraps the model summary in an MLflow `MLmodel`-style JSON document.
///
/// Each call generates a fresh `model_uuid` and creation timestamp.
pub fn to_mlflow_format<T: ClusteringModel>(model: &T) -> Result<Value> {
    let model_summary = model.summary()?;
    let document = serde_json::json!({
        "artifact_path": "model",
        "flavors": {
            "scirs2_cluster": {
                "model_type": "clustering",
                "scirs2_version": env!("CARGO_PKG_VERSION"),
                "data": model_summary
            }
        },
        "model_uuid": uuid::Uuid::new_v4().to_string(),
        "run_id": "unknown",
        "utc_time_created": chrono::Utc::now().to_rfc3339()
    });
    Ok(document)
}
/// Serializes a 2-D array to bytes as JSON `{"shape": [...], "data": [...]}`.
///
/// Works for any memory layout: contiguous arrays are serialized from their
/// backing slice, while views and non-standard layouts are copied
/// element-wise (the original silently emitted an empty `data` for those).
pub fn to_numpy_format(data: &scirs2_core::ndarray::Array2<f64>) -> Result<Vec<u8>> {
    let shape = data.shape();
    // `as_slice` returns None for non-contiguous layouts; fall back to a copy.
    let flat: Vec<f64> = match data.as_slice() {
        Some(slice) => slice.to_vec(),
        None => data.iter().copied().collect(),
    };
    let numpy_data = serde_json::json!({
        "shape": shape,
        "data": flat
    });
    Ok(serde_json::to_vec(&numpy_data)?)
}
/// Builds ONNX-style model metadata (as JSON) carrying the model summary.
pub fn to_onnx_metadata<T: ClusteringModel>(model: &T) -> Result<Value> {
    let model_summary = model.summary()?;
    let metadata = serde_json::json!({
        "ir_version": 7,
        "producer_name": "scirs2-cluster",
        "producer_version": env!("CARGO_PKG_VERSION"),
        "model_version": 1,
        "doc_string": "Clustering model exported from scirs2-cluster",
        "metadata_props": {
            "model_summary": model_summary
        }
    });
    Ok(metadata)
}
/// Builds a pandas-flavored clustering report as a JSON document.
///
/// Includes the model summary, cluster count, and a creation timestamp.
pub fn to_pandas_clustering_report<T: ClusteringModel>(model: &T) -> Result<Value> {
    let model_summary = model.summary()?;
    let report = serde_json::json!({
        "model_type": "clustering",
        "n_clusters": model.n_clusters(),
        "summary": model_summary,
        "pandas_version": "1.0.0",
        "created_at": chrono::Utc::now().to_rfc3339()
    });
    Ok(report)
}
/// Alias for [`to_pandas_clustering_report`]; kept for naming symmetry with
/// the other `to_*_format` exporters.
pub fn to_pandas_format<T: ClusteringModel>(model: &T) -> Result<Value> {
    to_pandas_clustering_report(model)
}
/// Serializes the model summary to a pickle-like byte payload (JSON-backed,
/// not actual Python pickle).
pub fn to_pickle_like_format<T: ClusteringModel>(model: &T) -> Result<Vec<u8>> {
    let payload = serde_json::to_vec(&model.summary()?)?;
    Ok(payload)
}
/// Builds a PyTorch-checkpoint-shaped JSON document around the model summary.
///
/// Optimizer state, epoch, and loss are placeholders since clustering models
/// here carry no training state of that kind.
pub fn to_pytorch_checkpoint<T: ClusteringModel>(model: &T) -> Result<Value> {
    let state = model.summary()?;
    let checkpoint = serde_json::json!({
        "model_state_dict": state,
        "optimizer_state_dict": {},
        "epoch": 1,
        "loss": 0.0,
        "pytorch_version": "1.10.0",
        "scirs2_cluster_version": env!("CARGO_PKG_VERSION")
    });
    Ok(checkpoint)
}
/// Wraps the model summary in an R-object-style JSON document.
pub fn to_r_format<T: ClusteringModel>(model: &T) -> Result<Value> {
    let model_summary = model.summary()?;
    let r_object = serde_json::json!({
        "class": "clustering_model",
        "data": model_summary,
        "r_version": "4.0.0",
        "created_by": "scirs2-cluster"
    });
    Ok(r_object)
}
/// Serializes a linkage matrix into a SciPy-dendrogram-style JSON document
/// with a flat `linkage` array plus the original `shape`.
pub fn to_scipy_dendrogram_format(
    linkage_matrix: &scirs2_core::ndarray::Array2<f64>,
) -> Result<Value> {
    // `as_slice` returns None for non-contiguous layouts; copy element-wise
    // then instead of silently emitting an empty linkage.
    let flat: Vec<f64> = match linkage_matrix.as_slice() {
        Some(slice) => slice.to_vec(),
        None => linkage_matrix.iter().copied().collect(),
    };
    Ok(serde_json::json!({
        "linkage": flat,
        "format": "scipy_dendrogram",
        "shape": linkage_matrix.shape()
    }))
}
/// Serializes a linkage matrix into a SciPy-linkage-style JSON document.
///
/// NOTE(review): `method`/`metric` are hardcoded to `"ward"`/`"euclidean"`
/// regardless of how the matrix was actually produced — confirm with callers.
pub fn to_scipy_linkage_format(
    linkage_matrix: &scirs2_core::ndarray::Array2<f64>,
) -> Result<Value> {
    // `as_slice` returns None for non-contiguous layouts; copy element-wise
    // then instead of silently emitting an empty matrix.
    let flat: Vec<f64> = match linkage_matrix.as_slice() {
        Some(slice) => slice.to_vec(),
        None => linkage_matrix.iter().copied().collect(),
    };
    Ok(serde_json::json!({
        "linkage_matrix": flat,
        "shape": linkage_matrix.shape(),
        "method": "ward",
        "metric": "euclidean"
    }))
}
/// Builds a sklearn-style clustering result document from the model summary.
///
/// `labels_` is always empty here; the full label vector lives in the
/// model-specific exporters.
pub fn to_sklearn_clustering_result<T: ClusteringModel>(model: &T) -> Result<Value> {
    let model_summary = model.summary()?;
    let result = serde_json::json!({
        "labels_": [],
        "n_clusters_": model.n_clusters(),
        "model_summary": model_summary,
        "_sklearn_version": "1.0.0"
    });
    Ok(result)
}
/// Alias for [`to_sklearn_clustering_result`]; kept for naming symmetry with
/// the other `to_*_format` exporters.
pub fn to_sklearn_format<T: ClusteringModel>(model: &T) -> Result<Value> {
    to_sklearn_clustering_result(model)
}
/// Exports a hierarchical model to a SciPy-style JSON document with a nested
/// row-major `linkage` matrix.
pub fn export_to_scipy_json(hierarchy: &HierarchicalModel) -> Result<Value> {
    let linkage = &hierarchy.linkage;
    // Convert each matrix row into a JSON array of numbers.
    let rows: Vec<Value> = (0..linkage.nrows())
        .map(|i| {
            Value::Array(
                (0..linkage.ncols())
                    .map(|j| serde_json::json!(linkage[[i, j]]))
                    .collect(),
            )
        })
        .collect();
    Ok(serde_json::json!({
        "linkage": rows,
        "n_observations": hierarchy.n_observations,
        "method": hierarchy.method,
        "labels": hierarchy.labels,
    }))
}
/// Exports a k-means model to a sklearn-style JSON document with nested
/// `cluster_centers_` and (optional) `labels_`.
pub fn export_to_sklearn_json(kmeans: &KMeansModel) -> Result<Value> {
    let centers = &kmeans.centroids;
    // Convert each centroid row into a JSON array of numbers.
    let center_rows: Vec<Value> = (0..centers.nrows())
        .map(|i| {
            Value::Array(
                (0..centers.ncols())
                    .map(|j| serde_json::json!(centers[[i, j]]))
                    .collect(),
            )
        })
        .collect();
    // Absent labels serialize as JSON null, matching the importer.
    let labels_val = kmeans
        .labels
        .as_ref()
        .map(|labels| Value::Array(labels.iter().map(|&l| serde_json::json!(l)).collect()))
        .unwrap_or(Value::Null);
    Ok(serde_json::json!({
        "cluster_centers_": center_rows,
        "labels_": labels_val,
        "n_iter_": kmeans.n_iter,
        "inertia_": kmeans.inertia,
        "n_clusters_": kmeans.n_clusters,
    }))
}
/// Imports a SciPy-style hierarchy JSON document into a `HierarchicalModel`.
///
/// Required fields: `linkage` (non-empty array of equal-length numeric rows)
/// and `n_observations`. `method` defaults to `"ward"`; `labels` is optional
/// (non-string entries become empty strings).
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` on missing/empty linkage, ragged
/// or non-array rows, non-numeric values, or a missing `n_observations`.
pub fn import_scipy_hierarchy(json: &Value) -> Result<HierarchicalModel> {
    let linkage_arr = json["linkage"]
        .as_array()
        .ok_or_else(|| ClusteringError::InvalidInput("Missing 'linkage' array".into()))?;
    let nrows = linkage_arr.len();
    if nrows == 0 {
        return Err(ClusteringError::InvalidInput("Empty linkage array".into()));
    }
    // Column count is fixed by the first row; every later row must match.
    let ncols = linkage_arr[0].as_array().map(|r| r.len()).unwrap_or(0);
    let mut flat: Vec<f64> = Vec::with_capacity(nrows * ncols);
    for (i, row) in linkage_arr.iter().enumerate() {
        let row_arr = row
            .as_array()
            .ok_or_else(|| ClusteringError::InvalidInput("Linkage row is not an array".into()))?;
        // Reject ragged input explicitly; a ragged matrix whose total element
        // count happened to equal nrows * ncols would otherwise reshape into
        // a silently misaligned array.
        if row_arr.len() != ncols {
            return Err(ClusteringError::InvalidInput(format!(
                "Linkage row {} has {} columns, expected {}",
                i,
                row_arr.len(),
                ncols
            )));
        }
        for v in row_arr {
            let val = v.as_f64().ok_or_else(|| {
                ClusteringError::InvalidInput("Non-numeric value in linkage".into())
            })?;
            flat.push(val);
        }
    }
    let linkage = scirs2_core::ndarray::Array2::from_shape_vec((nrows, ncols), flat)
        .map_err(|e| ClusteringError::InvalidInput(format!("Reshape failed: {e}")))?;
    let n_observations = json["n_observations"]
        .as_u64()
        .ok_or_else(|| ClusteringError::InvalidInput("Missing 'n_observations'".into()))?
        as usize;
    let method = json["method"].as_str().unwrap_or("ward").to_string();
    let labels: Option<Vec<String>> = json["labels"].as_array().map(|arr| {
        arr.iter()
            .map(|v| v.as_str().unwrap_or("").to_string())
            .collect()
    });
    Ok(HierarchicalModel::new(
        linkage,
        n_observations,
        method,
        labels,
    ))
}
/// Imports a scikit-learn KMeans JSON export into a `KMeansModel`.
///
/// Required field: `cluster_centers_` (non-empty array of equal-length
/// numeric rows). `n_clusters_`, `n_iter_`, `inertia_`, and `labels_` are
/// optional and fall back to defaults (`n_clusters_` defaults to the number
/// of center rows).
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` on missing/empty centers, ragged
/// or non-array rows, or non-numeric center values.
pub fn import_sklearn_kmeans(json: &Value) -> Result<KMeansModel> {
    let centers_arr = json["cluster_centers_"]
        .as_array()
        .ok_or_else(|| ClusteringError::InvalidInput("Missing 'cluster_centers_'".into()))?;
    let nrows = centers_arr.len();
    if nrows == 0 {
        return Err(ClusteringError::InvalidInput(
            "Empty cluster_centers_ array".into(),
        ));
    }
    // Column count is fixed by the first row; every later row must match.
    let ncols = centers_arr[0].as_array().map(|r| r.len()).unwrap_or(0);
    let mut flat: Vec<f64> = Vec::with_capacity(nrows * ncols);
    for (i, row) in centers_arr.iter().enumerate() {
        let row_arr = row.as_array().ok_or_else(|| {
            ClusteringError::InvalidInput("cluster_centers_ row is not an array".into())
        })?;
        // Reject ragged rows up front; a ragged matrix whose total element
        // count happened to equal nrows * ncols would otherwise reshape into
        // a silently misaligned array.
        if row_arr.len() != ncols {
            return Err(ClusteringError::InvalidInput(format!(
                "cluster_centers_ row {} has {} columns, expected {}",
                i,
                row_arr.len(),
                ncols
            )));
        }
        for v in row_arr {
            let val = v.as_f64().ok_or_else(|| {
                ClusteringError::InvalidInput("Non-numeric in cluster_centers_".into())
            })?;
            flat.push(val);
        }
    }
    let centroids = scirs2_core::ndarray::Array2::from_shape_vec((nrows, ncols), flat)
        .map_err(|e| ClusteringError::InvalidInput(format!("Reshape failed: {e}")))?;
    // Missing scalar fields default rather than fail: importers often see
    // partial exports.
    let n_clusters = json["n_clusters_"].as_u64().unwrap_or(nrows as u64) as usize;
    let n_iter = json["n_iter_"].as_u64().unwrap_or(0) as usize;
    let inertia = json["inertia_"].as_f64().unwrap_or(0.0);
    let labels: Option<scirs2_core::ndarray::Array1<usize>> =
        json["labels_"].as_array().map(|arr| {
            scirs2_core::ndarray::Array1::from_vec(
                arr.iter()
                    .map(|v| v.as_u64().unwrap_or(0) as usize)
                    .collect(),
            )
        });
    Ok(KMeansModel::new(
        centroids, n_clusters, n_iter, inertia, labels,
    ))
}
// Unit tests for the format conversion helpers, including two full
// export/import round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // The kmeans grid should retain the `n_clusters` key it was given.
    #[test]
    fn test_create_sklearn_param_grid() {
        let mut params = HashMap::new();
        params.insert(
            "n_clusters".to_string(),
            vec![serde_json::json!(2), serde_json::json!(3)],
        );
        let grid = create_sklearn_param_grid("kmeans", params).expect("Operation failed");
        assert!(grid.contains_key("n_clusters"));
    }

    // A contiguous 2x2 array must serialize without error.
    #[test]
    fn test_to_numpy_format() {
        let data =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("Operation failed");
        let result = to_numpy_format(&data);
        assert!(result.is_ok());
    }

    // A minimal 1x3 linkage matrix must serialize without error.
    #[test]
    fn test_to_scipy_linkage_format() {
        let linkage =
            Array2::from_shape_vec((1, 3), vec![0.0, 1.0, 0.5]).expect("Operation failed");
        let result = to_scipy_linkage_format(&linkage);
        assert!(result.is_ok());
    }

    // export_to_scipy_json followed by import_scipy_hierarchy must preserve
    // the linkage values, observation count, method, and labels.
    #[test]
    fn test_scipy_hierarchy_roundtrip() {
        let linkage = Array2::from_shape_vec((2, 4), vec![0.0, 1.0, 0.5, 2.0, 2.0, 2.0, 1.0, 3.0])
            .expect("shape error");
        let model = HierarchicalModel::new(
            linkage.clone(),
            3,
            "ward".to_string(),
            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
        );
        let json = export_to_scipy_json(&model).expect("export failed");
        let restored = import_scipy_hierarchy(&json).expect("import failed");
        assert_eq!(restored.n_observations, 3);
        assert_eq!(restored.method, "ward");
        assert_eq!(restored.linkage.nrows(), 2);
        // Spot-check one linkage value survives the JSON round-trip.
        assert!((restored.linkage[[0, 2]] - 0.5).abs() < 1e-12);
        assert!(restored.labels.is_some());
    }

    // export_to_sklearn_json followed by import_sklearn_kmeans must preserve
    // centroids, labels, and the scalar fields.
    #[test]
    fn test_sklearn_kmeans_roundtrip() {
        let centroids = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape error");
        let labels = scirs2_core::ndarray::Array1::from_vec(vec![0usize, 1, 0]);
        let model = KMeansModel::new(centroids, 2, 10, 0.5, Some(labels));
        let json = export_to_sklearn_json(&model).expect("export failed");
        let restored = import_sklearn_kmeans(&json).expect("import failed");
        assert_eq!(restored.n_clusters, 2);
        assert_eq!(restored.n_iter, 10);
        // Float fields compared with a tolerance after the JSON round-trip.
        assert!((restored.inertia - 0.5).abs() < 1e-12);
        assert_eq!(restored.centroids.nrows(), 2);
        assert_eq!(restored.centroids.ncols(), 3);
        assert!((restored.centroids[[0, 0]] - 1.0).abs() < 1e-12);
        assert!(restored.labels.is_some());
    }
}