use crate::error::{ClusteringError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::path::Path;
#[cfg(feature = "yaml")]
use serde_yaml;
use super::core::{EnhancedModelMetadata, PlatformInfo, SerializableModel};
use super::models::*;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExportFormat {
Json,
JsonGz,
Binary,
Csv,
Yaml,
Xml,
Hdf5,
Custom(String),
}
pub trait AdvancedExport {
fn export_with_metadata(
&self,
format: ExportFormat,
metadata: Option<ModelMetadata>,
) -> Result<Vec<u8>>;
fn export_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()>;
fn export_summary(&self) -> Result<String>;
fn export_compatible(&self, target_library: &str) -> Result<Value>;
fn validate_for_export(&self) -> Result<()>;
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelMetadata {
pub model_info: ModelInfo,
pub algorithm_config: AlgorithmConfig,
pub performance_metrics: PerformanceMetrics,
pub data_characteristics: ModelDataCharacteristics,
pub export_settings: ExportSettings,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelInfo {
pub name: String,
pub version: String,
pub created_at: String,
pub author: Option<String>,
pub description: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AlgorithmConfig {
pub algorithm: String,
pub hyperparameters: HashMap<String, Value>,
pub preprocessing: Vec<String>,
pub random_seed: Option<u64>,
pub convergence_criteria: Option<HashMap<String, f64>>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PerformanceMetrics {
pub training_time_seconds: f64,
pub peak_memory_mb: f64,
pub cpu_utilization: f64,
pub quality_metrics: HashMap<String, f64>,
pub convergence_info: Option<ConvergenceInfo>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ConvergenceInfo {
pub converged: bool,
pub iterations: usize,
pub final_objective: f64,
pub tolerance_achieved: f64,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelDataCharacteristics {
pub n_samples: usize,
pub n_features: usize,
pub feature_names: Option<Vec<String>>,
pub feature_types: Option<Vec<String>>,
pub feature_statistics: Option<HashMap<String, FeatureStats>>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FeatureStats {
pub mean: f64,
pub std: f64,
pub min: f64,
pub max: f64,
pub missing_count: usize,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ExportSettings {
pub include_raw_data: bool,
pub include_training_data: bool,
pub compression_level: Option<u8>,
pub float_precision: Option<usize>,
pub custom_options: HashMap<String, Value>,
}
impl Default for ExportSettings {
fn default() -> Self {
Self {
include_raw_data: true,
include_training_data: false,
compression_level: None,
float_precision: Some(6),
custom_options: HashMap::new(),
}
}
}
impl AdvancedExport for KMeansModel {
fn export_with_metadata(
&self,
format: ExportFormat,
metadata: Option<ModelMetadata>,
) -> Result<Vec<u8>> {
let export_data = KMeansExportData {
model: self.clone(),
metadata,
format_version: "1.0".to_string(),
export_timestamp: chrono::Utc::now().to_rfc3339(),
};
match format {
ExportFormat::Json => {
let json = serde_json::to_string_pretty(&export_data)
.map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
Ok(json.into_bytes())
}
ExportFormat::JsonGz => {
let json = serde_json::to_string(&export_data)
.map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
oxiarc_deflate::gzip_compress(json.as_bytes(), 6)
.map_err(|e| ClusteringError::InvalidInput(e.to_string()))
}
#[cfg(feature = "yaml")]
ExportFormat::Yaml => {
let yaml = serde_yaml::to_string(&export_data)
.map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
Ok(yaml.into_bytes())
}
#[cfg(not(feature = "yaml"))]
ExportFormat::Yaml => Err(ClusteringError::InvalidInput(
"YAML support not enabled. Enable the 'yaml' feature".to_string(),
)),
ExportFormat::Csv => self.export_csv(),
_ => Err(ClusteringError::InvalidInput(format!(
"Unsupported export format: {:?}",
format
))),
}
}
fn export_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let path = path.as_ref();
let format = detect_format_from_extension(path)?;
let data = self.export_with_metadata(format, None)?;
std::fs::write(path, data)
.map_err(|e| ClusteringError::InvalidInput(format!("Failed to write file: {}", e)))
}
fn export_summary(&self) -> Result<String> {
let summary = KMeansSummary {
algorithm: "K-Means".to_string(),
n_clusters: self.n_clusters,
n_features: self.centroids.ncols(),
n_iterations: self.n_iter,
inertia: self.inertia,
has_labels: self.labels.is_some(),
};
serde_json::to_string_pretty(&summary)
.map_err(|e| ClusteringError::InvalidInput(e.to_string()))
}
fn export_compatible(&self, target_library: &str) -> Result<Value> {
match target_library.to_lowercase().as_str() {
"sklearn" | "scikit-learn" => self.to_sklearn_format(),
"tensorflow" | "tf" => self.to_tensorflow_format(),
"pytorch" => self.to_pytorch_format(),
_ => Err(ClusteringError::InvalidInput(format!(
"Unsupported target library: {}",
target_library
))),
}
}
fn validate_for_export(&self) -> Result<()> {
if self.centroids.is_empty() {
return Err(ClusteringError::InvalidInput(
"Cannot export model with empty centroids".to_string(),
));
}
if self.n_clusters == 0 {
return Err(ClusteringError::InvalidInput(
"Cannot export model with zero clusters".to_string(),
));
}
if self.centroids.nrows() != self.n_clusters {
return Err(ClusteringError::InvalidInput(
"Centroids shape inconsistent with n_clusters".to_string(),
));
}
Ok(())
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
struct KMeansExportData {
model: KMeansModel,
metadata: Option<ModelMetadata>,
format_version: String,
export_timestamp: String,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
struct KMeansSummary {
algorithm: String,
n_clusters: usize,
n_features: usize,
n_iterations: usize,
inertia: f64,
has_labels: bool,
}
impl KMeansModel {
fn export_csv(&self) -> Result<Vec<u8>> {
let mut csv_content = String::new();
csv_content.push_str("cluster_id");
for i in 0..self.centroids.ncols() {
csv_content.push_str(&format!(",feature_{}", i));
}
csv_content.push('\n');
for (cluster_id, centroid) in self.centroids.rows().into_iter().enumerate() {
csv_content.push_str(&cluster_id.to_string());
for value in centroid {
csv_content.push_str(&format!(",{:.6}", value));
}
csv_content.push('\n');
}
Ok(csv_content.into_bytes())
}
fn to_sklearn_format(&self) -> Result<Value> {
use serde_json::json;
Ok(json!({
"cluster_centers_": self.centroids.as_slice().expect("Operation failed"),
"labels_": self.labels.as_ref().map(|l| l.as_slice().expect("Operation failed")),
"inertia_": self.inertia,
"n_iter_": self.n_iter,
"n_clusters": self.n_clusters,
"_sklearn_version": "1.0.0"
}))
}
fn to_tensorflow_format(&self) -> Result<Value> {
use serde_json::json;
Ok(json!({
"centroids": {
"data": self.centroids.as_slice().expect("Operation failed"),
"shape": [self.centroids.nrows(), self.centroids.ncols()],
"dtype": "float64"
},
"metadata": {
"n_clusters": self.n_clusters,
"inertia": self.inertia,
"iterations": self.n_iter
}
}))
}
fn to_pytorch_format(&self) -> Result<Value> {
use serde_json::json;
Ok(json!({
"state_dict": {
"centroids": self.centroids.as_slice().expect("Operation failed")
},
"hyperparameters": {
"n_clusters": self.n_clusters
},
"metrics": {
"inertia": self.inertia,
"n_iter": self.n_iter
}
}))
}
}
fn detect_format_from_extension<P: AsRef<Path>>(path: P) -> Result<ExportFormat> {
let path = path.as_ref();
let extension = path
.extension()
.and_then(|ext| ext.to_str())
.unwrap_or("")
.to_lowercase();
match extension.as_str() {
"json" => Ok(ExportFormat::Json),
"gz" | "json.gz" => Ok(ExportFormat::JsonGz),
"yaml" | "yml" => Ok(ExportFormat::Yaml),
"csv" => Ok(ExportFormat::Csv),
"xml" => Ok(ExportFormat::Xml),
"h5" | "hdf5" => Ok(ExportFormat::Hdf5),
_ => Err(ClusteringError::InvalidInput(format!(
"Unknown file extension: {}",
extension
))),
}
}
pub mod utils {
use super::*;
pub fn create_default_metadata(algorithm_name: &str) -> ModelMetadata {
ModelMetadata {
model_info: ModelInfo {
name: format!("{}_model", algorithm_name),
version: "1.0.0".to_string(),
created_at: chrono::Utc::now().to_rfc3339(),
author: None,
description: None,
},
algorithm_config: AlgorithmConfig {
algorithm: algorithm_name.to_string(),
hyperparameters: HashMap::new(),
preprocessing: Vec::new(),
random_seed: None,
convergence_criteria: None,
},
performance_metrics: PerformanceMetrics {
training_time_seconds: 0.0,
peak_memory_mb: 0.0,
cpu_utilization: 0.0,
quality_metrics: HashMap::new(),
convergence_info: None,
},
data_characteristics: ModelDataCharacteristics {
n_samples: 0,
n_features: 0,
feature_names: None,
feature_types: None,
feature_statistics: None,
},
export_settings: ExportSettings::default(),
}
}
pub fn compress_data(data: &[u8]) -> Result<Vec<u8>> {
oxiarc_deflate::gzip_compress(data, 6)
.map_err(|e| ClusteringError::InvalidInput(e.to_string()))
}
pub fn decompress_data(compressed: &[u8]) -> Result<Vec<u8>> {
oxiarc_deflate::gzip_decompress(compressed)
.map_err(|e| ClusteringError::InvalidInput(e.to_string()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
#[test]
fn test_kmeans_export_summary() {
let centroids =
Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).expect("Operation failed");
let model = KMeansModel::new(centroids, 2, 10, 0.5, None);
let summary = model.export_summary().expect("Operation failed");
assert!(summary.contains("K-Means"));
assert!(summary.contains("\"n_clusters\": 2"));
}
#[test]
fn test_format_detection() {
assert_eq!(
detect_format_from_extension("model.json").expect("Operation failed"),
ExportFormat::Json
);
assert_eq!(
detect_format_from_extension("model.yaml").expect("Operation failed"),
ExportFormat::Yaml
);
assert_eq!(
detect_format_from_extension("model.csv").expect("Operation failed"),
ExportFormat::Csv
);
}
#[test]
fn test_sklearn_compatibility() {
let centroids =
Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).expect("Operation failed");
let model = KMeansModel::new(centroids, 2, 10, 0.5, None);
let sklearn_format = model
.export_compatible("sklearn")
.expect("Operation failed");
assert!(sklearn_format.get("cluster_centers_").is_some());
assert!(sklearn_format.get("_sklearn_version").is_some());
}
}