use crate::error::{ClusteringError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use super::core::SerializableModel;
/// A resumable, checkpointable pipeline of clustering training steps.
///
/// Steps run in order; each step may name other steps as dependencies,
/// and per-step JSON results are cached in `intermediate_results`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ClusteringWorkflow {
    /// Unique identifier; also used to name checkpoint files.
    pub workflow_id: String,
    /// Index of the step currently (or most recently) being executed.
    pub current_step: usize,
    /// Ordered list of steps; dependency entries refer to step names.
    pub steps: Vec<TrainingStep>,
    /// Lifecycle state of this run (not started / running / paused / ...).
    pub current_state: AlgorithmState,
    /// Execution options (auto-save, retries, checkpoint directory, ...).
    pub config: WorkflowConfig,
    /// Append-only log of step executions.
    pub execution_history: Vec<ExecutionRecord>,
    /// Step name -> JSON result produced by that step.
    pub intermediate_results: HashMap<String, serde_json::Value>,
}
/// Lifecycle state of a [`ClusteringWorkflow`] run.
///
/// Timestamps are whole seconds since the Unix epoch.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum AlgorithmState {
    /// `execute` has not been called yet.
    NotStarted,
    /// A run is in progress.
    Running {
        /// Number of steps processed so far in this run.
        iteration: usize,
        /// Unix timestamp (seconds) when the run started.
        start_time: u64,
        /// Completion percentage in `[0.0, 100.0]`.
        progress: f32,
    },
    /// The run finished successfully.
    Completed {
        /// Total number of steps processed.
        iterations: usize,
        /// Wall-clock duration of the run, in seconds.
        execution_time: f64,
        /// Summary metrics (see `collect_final_metrics`).
        final_metrics: HashMap<String, f64>,
    },
    /// The run stopped due to an error.
    Failed {
        /// Human-readable error description.
        error: String,
        /// Unix timestamp (seconds) when the failure was recorded.
        failure_time: u64,
    },
    /// The run was paused and can be resumed.
    Paused {
        /// Unix timestamp (seconds) when the pause was recorded.
        pause_time: u64,
        /// Iteration count at the moment of pausing.
        paused_at_iteration: usize,
    },
}
/// One unit of work in a [`ClusteringWorkflow`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TrainingStep {
    /// Unique step name; other steps reference it in `dependencies`.
    pub name: String,
    /// Algorithm key dispatched by `execute_step` (e.g. "kmeans", "dbscan").
    pub algorithm: String,
    /// Free-form algorithm parameters.
    pub parameters: HashMap<String, serde_json::Value>,
    /// Names of steps that must have completed before this one runs.
    pub dependencies: Vec<String>,
    /// Set to `true` once the step has executed successfully.
    pub completed: bool,
    /// Wall-clock duration of the step in seconds, once executed.
    pub execution_time: Option<f64>,
    /// JSON result produced by the step, once executed.
    pub results: Option<serde_json::Value>,
}
/// Execution options for a single [`ClusteringWorkflow`].
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct WorkflowConfig {
    /// Auto-save threshold in seconds; `None` disables auto-save.
    /// NOTE(review): `execute` checkpoints when a single step exceeds this
    /// duration — confirm that is the intended semantics.
    pub auto_save_interval: Option<u64>,
    /// Maximum retry count. Not consulted by any code visible in this module.
    pub max_retries: usize,
    /// Per-step timeout in seconds. Not consulted by any code visible here.
    pub step_timeout: Option<u64>,
    /// Parallel-execution flag. Not consulted by any code visible here.
    pub parallel_execution: bool,
    /// Directory for checkpoint files; `None` makes `save_checkpoint` a no-op.
    pub checkpoint_dir: Option<PathBuf>,
}
/// One entry in a workflow's append-only execution history.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ExecutionRecord {
    /// Unix timestamp (seconds) when the record was written.
    pub timestamp: u64,
    /// Name of the step this record refers to.
    pub step_name: String,
    /// Action label, e.g. "execute" or "resume_execute".
    pub action: String,
    /// Outcome of the action.
    pub result: ExecutionResult,
    /// Extra context; currently always empty when written by this module.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// Outcome of a single recorded step action.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ExecutionResult {
    /// The step ran to completion.
    Success {
        /// Wall-clock duration in seconds.
        duration: f64,
        /// JSON output of the step, if any.
        output: Option<serde_json::Value>,
    },
    /// The step (or its dependency check) failed.
    Failure {
        /// Human-readable error description.
        error: String,
        /// Optional machine-readable error code.
        error_code: Option<String>,
    },
    /// The step was deliberately not run.
    Skipped {
        /// Why the step was skipped.
        reason: String,
    },
}
impl ClusteringWorkflow {
pub fn new(workflow_id: String, config: WorkflowConfig) -> Self {
Self {
workflow_id,
current_step: 0,
steps: Vec::new(),
current_state: AlgorithmState::NotStarted,
config,
execution_history: Vec::new(),
intermediate_results: HashMap::new(),
}
}
pub fn add_step(&mut self, step: TrainingStep) {
self.steps.push(step);
}
pub fn execute(&mut self) -> Result<()> {
self.current_state = AlgorithmState::Running {
iteration: 0,
start_time: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
progress: 0.0,
};
let start_time = std::time::Instant::now();
let steps_len = self.steps.len();
for i in 0..steps_len {
self.current_step = i;
let dependencies = self.steps[i].dependencies.clone();
if !self.check_dependencies(&dependencies)? {
return Err(ClusteringError::InvalidInput(format!(
"Dependencies not satisfied for step: {}",
self.steps[i].name
)));
}
let step_start = std::time::Instant::now();
let step_clone = self.steps[i].clone();
let result = self.execute_step(&step_clone)?;
let step_duration = step_start.elapsed().as_secs_f64();
self.steps[i].completed = true;
self.steps[i].execution_time = Some(step_duration);
self.steps[i].results = Some(result.clone());
let step_name = self.steps[i].name.clone();
self.record_execution(
&step_name,
"execute",
ExecutionResult::Success {
duration: step_duration,
output: Some(result),
},
);
let progress = ((i + 1) as f32 / steps_len as f32) * 100.0;
self.update_progress(progress);
if let Some(interval) = self.config.auto_save_interval {
if step_duration > interval as f64 {
self.save_checkpoint()?;
}
}
}
let total_time = start_time.elapsed().as_secs_f64();
self.current_state = AlgorithmState::Completed {
iterations: self.steps.len(),
execution_time: total_time,
final_metrics: self.collect_final_metrics(),
};
Ok(())
}
fn execute_step(&mut self, step: &TrainingStep) -> Result<serde_json::Value> {
use serde_json::json;
let result = match step.algorithm.as_str() {
"kmeans" => {
json!({
"algorithm": "kmeans",
"centroids": [[0.0, 0.0], [1.0, 1.0]],
"inertia": 0.5,
"iterations": 10
})
}
"dbscan" => {
json!({
"algorithm": "dbscan",
"n_clusters": 2,
"core_samples": [0, 1, 2],
"noise_points": []
})
}
_ => {
return Err(ClusteringError::InvalidInput(format!(
"Unknown algorithm: {}",
step.algorithm
)));
}
};
self.intermediate_results
.insert(step.name.clone(), result.clone());
Ok(result)
}
fn check_dependencies(&self, dependencies: &[String]) -> Result<bool> {
for dep in dependencies {
if !self.steps.iter().any(|s| s.name == *dep && s.completed) {
return Ok(false);
}
}
Ok(true)
}
fn update_progress(&mut self, progress: f32) {
if let AlgorithmState::Running {
iteration,
start_time,
..
} = &mut self.current_state
{
self.current_state = AlgorithmState::Running {
iteration: *iteration + 1,
start_time: *start_time,
progress,
};
}
}
fn record_execution(&mut self, step_name: &str, action: &str, result: ExecutionResult) {
let record = ExecutionRecord {
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
step_name: step_name.to_string(),
action: action.to_string(),
result,
metadata: HashMap::new(),
};
self.execution_history.push(record);
}
fn collect_final_metrics(&self) -> HashMap<String, f64> {
let mut metrics = HashMap::new();
let total_steps = self.steps.len() as f64;
let completed_steps = self.steps.iter().filter(|s| s.completed).count() as f64;
let total_time: f64 = self.steps.iter().filter_map(|s| s.execution_time).sum();
metrics.insert("total_steps".to_string(), total_steps);
metrics.insert("completed_steps".to_string(), completed_steps);
metrics.insert("completion_rate".to_string(), completed_steps / total_steps);
metrics.insert("total_execution_time".to_string(), total_time);
metrics
}
pub fn save_checkpoint(&self) -> Result<()> {
if let Some(ref checkpoint_dir) = self.config.checkpoint_dir {
std::fs::create_dir_all(checkpoint_dir)
.map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
let checkpoint_file =
checkpoint_dir.join(format!("{}_checkpoint.json", self.workflow_id));
self.save_to_file(checkpoint_file)?;
}
Ok(())
}
pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> Result<Self> {
Self::load_from_file(path)
}
pub fn pause(&mut self) {
let current_iteration = match &self.current_state {
AlgorithmState::Running { iteration, .. } => *iteration,
_ => 0,
};
self.current_state = AlgorithmState::Paused {
pause_time: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
paused_at_iteration: current_iteration,
};
}
pub fn resume(&mut self) -> Result<()> {
if let AlgorithmState::Paused {
paused_at_iteration,
..
} = &self.current_state
{
self.current_state = AlgorithmState::Running {
iteration: *paused_at_iteration,
start_time: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
progress: (*paused_at_iteration as f32 / self.steps.len() as f32) * 100.0,
};
self.execute_remaining_steps()
} else {
Err(ClusteringError::InvalidInput(
"Workflow is not in paused state".to_string(),
))
}
}
fn execute_remaining_steps(&mut self) -> Result<()> {
let start_index = self.current_step;
let steps_len = self.steps.len();
for i in start_index..steps_len {
if !self.steps[i].completed {
self.current_step = i;
let step_start = std::time::Instant::now();
let step_clone = self.steps[i].clone();
let result = self.execute_step(&step_clone)?;
let step_duration = step_start.elapsed().as_secs_f64();
self.steps[i].completed = true;
self.steps[i].execution_time = Some(step_duration);
self.steps[i].results = Some(result.clone());
let step_name = self.steps[i].name.clone();
self.record_execution(
&step_name,
"resume_execute",
ExecutionResult::Success {
duration: step_duration,
output: Some(result),
},
);
}
}
let final_metrics = self.collect_final_metrics();
self.current_state = AlgorithmState::Completed {
iterations: self.steps.len(),
execution_time: final_metrics
.get("total_execution_time")
.copied()
.unwrap_or(0.0),
final_metrics,
};
Ok(())
}
pub fn get_progress(&self) -> f32 {
match &self.current_state {
AlgorithmState::Running { progress, .. } => *progress,
AlgorithmState::Completed { .. } => 100.0,
AlgorithmState::Failed { .. } => 0.0,
AlgorithmState::Paused {
paused_at_iteration,
..
} => (*paused_at_iteration as f32 / self.steps.len() as f32) * 100.0,
AlgorithmState::NotStarted => 0.0,
}
}
pub fn get_status(&self) -> WorkflowStatus {
WorkflowStatus {
workflow_id: self.workflow_id.clone(),
current_step: self.current_step,
total_steps: self.steps.len(),
state: self.current_state.clone(),
progress: self.get_progress(),
completed_steps: self.steps.iter().filter(|s| s.completed).count(),
}
}
}
// Marker impl: `save_to_file` / `load_from_file` (used by the checkpoint
// methods) presumably come as default methods of `SerializableModel` —
// confirm against super::core.
impl SerializableModel for ClusteringWorkflow {}
/// Point-in-time snapshot of a workflow, produced by `get_status`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WorkflowStatus {
    /// Identifier of the workflow this snapshot describes.
    pub workflow_id: String,
    /// Index of the current (or most recent) step.
    pub current_step: usize,
    /// Total number of steps in the pipeline.
    pub total_steps: usize,
    /// Lifecycle state at snapshot time.
    pub state: AlgorithmState,
    /// Completion percentage in `[0.0, 100.0]`.
    pub progress: f32,
    /// How many steps have completed.
    pub completed_steps: usize,
}
/// Registry of workflows keyed by id, bounded by `config.max_concurrent_workflows`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ClusteringWorkflowManager {
    /// Registered workflows, keyed by `workflow_id`.
    pub workflows: HashMap<String, ClusteringWorkflow>,
    /// Manager-level limits and defaults.
    pub config: ManagerConfig,
}
/// Configuration for [`ClusteringWorkflowManager`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ManagerConfig {
    /// Hard cap on how many workflows may be registered at once.
    pub max_concurrent_workflows: usize,
    /// Fallback checkpoint directory. Not consulted by code visible here.
    pub default_checkpoint_dir: Option<PathBuf>,
    /// Manager-wide auto-save interval in seconds. Not consulted by code
    /// visible here.
    pub global_auto_save_interval: Option<u64>,
}
impl Default for ManagerConfig {
fn default() -> Self {
Self {
max_concurrent_workflows: 10,
default_checkpoint_dir: None,
global_auto_save_interval: Some(300), }
}
}
impl ClusteringWorkflowManager {
    /// Creates a manager with no registered workflows.
    pub fn new(config: ManagerConfig) -> Self {
        Self {
            workflows: HashMap::new(),
            config,
        }
    }

    /// Registers a workflow under its `workflow_id`.
    ///
    /// # Errors
    /// Fails when the concurrency limit is reached, or when a workflow with
    /// the same id is already registered. (Previously a duplicate id silently
    /// overwrote — and thereby dropped — the existing workflow.)
    pub fn add_workflow(&mut self, workflow: ClusteringWorkflow) -> Result<()> {
        use std::collections::hash_map::Entry;
        if self.workflows.len() >= self.config.max_concurrent_workflows {
            return Err(ClusteringError::InvalidInput(
                "Maximum number of concurrent workflows reached".to_string(),
            ));
        }
        // The entry API does the lookup once and rejects duplicates.
        match self.workflows.entry(workflow.workflow_id.clone()) {
            Entry::Occupied(existing) => Err(ClusteringError::InvalidInput(format!(
                "Workflow already exists: {}",
                existing.key()
            ))),
            Entry::Vacant(slot) => {
                slot.insert(workflow);
                Ok(())
            }
        }
    }

    /// Immutable lookup by workflow id.
    pub fn get_workflow(&self, workflow_id: &str) -> Option<&ClusteringWorkflow> {
        self.workflows.get(workflow_id)
    }

    /// Mutable lookup by workflow id.
    pub fn get_workflow_mut(&mut self, workflow_id: &str) -> Option<&mut ClusteringWorkflow> {
        self.workflows.get_mut(workflow_id)
    }

    /// Runs the named workflow to completion.
    ///
    /// # Errors
    /// Fails when the id is unknown or the workflow's own `execute` fails.
    pub fn execute_workflow(&mut self, workflow_id: &str) -> Result<()> {
        match self.workflows.get_mut(workflow_id) {
            Some(workflow) => workflow.execute(),
            None => Err(ClusteringError::InvalidInput(format!(
                "Workflow not found: {}",
                workflow_id
            ))),
        }
    }

    /// Status snapshots for every registered workflow, keyed by id.
    pub fn get_all_statuses(&self) -> HashMap<String, WorkflowStatus> {
        self.workflows
            .iter()
            .map(|(id, workflow)| (id.clone(), workflow.get_status()))
            .collect()
    }

    /// Drops every workflow that has reached the `Completed` state.
    pub fn cleanup_completed(&mut self) {
        self.workflows.retain(|_, workflow| {
            !matches!(workflow.current_state, AlgorithmState::Completed { .. })
        });
    }
}
/// Auto-save settings.
///
/// NOTE(review): not referenced by any code visible in this module —
/// presumably consumed elsewhere in the crate.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AutoSaveConfig {
    /// Whether auto-saving is active.
    pub enabled: bool,
    /// Interval between saves, in seconds.
    pub interval_seconds: u64,
    /// Directory where auto-saves are written.
    pub save_directory: PathBuf,
}
impl Default for AutoSaveConfig {
fn default() -> Self {
Self {
enabled: true,
interval_seconds: 300, save_directory: PathBuf::from("./checkpoints"),
}
}
}
/// Simplified workflow lifecycle.
///
/// NOTE(review): overlaps with [`AlgorithmState`] and is not referenced by
/// any code visible in this module — confirm which one external callers use.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum WorkflowState {
    /// Constructed but never run.
    Created,
    /// Currently executing.
    Running,
    /// Suspended; may be resumed.
    Paused,
    /// Finished successfully.
    Completed,
    /// Stopped with the contained error message.
    Failed(String),
    /// Aborted by request.
    Cancelled,
}
/// Outcome of a single step, with optional metrics or failure details.
///
/// NOTE(review): overlaps with [`ExecutionResult`] and is not referenced by
/// code visible in this module — presumably part of the public API surface.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StepResult {
    /// The step produced output and metrics.
    Success {
        /// JSON output of the step.
        output: serde_json::Value,
        /// Named numeric metrics for the step.
        metrics: HashMap<String, f64>,
    },
    /// The step failed.
    Failure {
        /// Human-readable error description.
        error: String,
        /// Optional structured failure details.
        details: Option<serde_json::Value>,
    },
    /// The step was deliberately not run.
    Skipped {
        /// Why the step was skipped.
        reason: String,
    },
}
/// Declarative description of a step.
///
/// NOTE(review): overlaps with [`TrainingStep`] and is not referenced by
/// code visible in this module — confirm intended usage.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WorkflowStep {
    /// Unique step name.
    pub name: String,
    /// Kind of step to run.
    pub step_type: String,
    /// Free-form step parameters.
    pub parameters: HashMap<String, serde_json::Value>,
    /// Names of steps that must complete first.
    pub dependencies: Vec<String>,
    /// Estimated duration in seconds, if known.
    pub expected_duration: Option<f64>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh workflow starts at step 0 with no steps registered.
    #[test]
    fn test_workflow_creation() {
        let cfg = WorkflowConfig {
            auto_save_interval: Some(60),
            max_retries: 3,
            step_timeout: Some(300),
            parallel_execution: false,
            checkpoint_dir: None,
        };
        let wf = ClusteringWorkflow::new("test_workflow".to_string(), cfg);
        assert_eq!(wf.workflow_id, "test_workflow");
        assert_eq!(wf.current_step, 0);
        assert!(wf.steps.is_empty());
    }

    /// `add_step` appends to the pipeline in order.
    #[test]
    fn test_workflow_step_addition() {
        let cfg = WorkflowConfig {
            auto_save_interval: None,
            max_retries: 1,
            step_timeout: None,
            parallel_execution: false,
            checkpoint_dir: None,
        };
        let mut wf = ClusteringWorkflow::new("test".to_string(), cfg);
        wf.add_step(TrainingStep {
            name: "kmeans_step".to_string(),
            algorithm: "kmeans".to_string(),
            parameters: HashMap::new(),
            dependencies: Vec::new(),
            completed: false,
            execution_time: None,
            results: None,
        });
        assert_eq!(wf.steps.len(), 1);
        assert_eq!(wf.steps[0].name, "kmeans_step");
    }

    /// A registered workflow is retrievable by its id.
    #[test]
    fn test_workflow_manager() {
        let mut manager = ClusteringWorkflowManager::new(ManagerConfig::default());
        let wf_cfg = WorkflowConfig {
            auto_save_interval: None,
            max_retries: 1,
            step_timeout: None,
            parallel_execution: false,
            checkpoint_dir: None,
        };
        let wf = ClusteringWorkflow::new("test_workflow".to_string(), wf_cfg);
        manager.add_workflow(wf).expect("Operation failed");
        assert!(manager.get_workflow("test_workflow").is_some());
        assert_eq!(manager.workflows.len(), 1);
    }
}