#![allow(dead_code)]
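//! Fault-tolerant training support: periodic checkpointing of training state to
//! disk and elastic scaling of the worker pool in response to worker failures,
//! joins, and manual or load-based resize requests.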
use crate::error_recovery::{CircuitBreakerConfig, FailureDetector, RetryConfig};
use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::fs;
use torsh_nn::Parameter;
use torsh_tensor::Tensor;
use tracing::{debug, info, warn};
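
/// Configuration for [`CheckpointManager`]: output directory, checkpointing
/// frequency, number of checkpoints to retain, and save behavior (async save,
/// compression level, verify-after-save).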
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
pub checkpoint_dir: PathBuf,
pub checkpoint_frequency: usize,
pub max_checkpoints: usize,
pub async_save: bool,
pub compression_level: u8,
pub verify_after_save: bool,
}
impl Default for CheckpointConfig {
fn default() -> Self {
Self {
checkpoint_dir: PathBuf::from("./checkpoints"),
checkpoint_frequency: 1000,
max_checkpoints: 5,
async_save: true,
compression_level: 3,
verify_after_save: true,
}
}
}
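
/// Configuration for [`ElasticTrainingManager`]: allowed worker-count range,
/// scaling timeout and check interval, and rendezvous backend settings.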
#[derive(Debug, Clone)]
pub struct ElasticConfig {
pub min_workers: usize,
pub max_workers: usize,
pub scaling_timeout: Duration,
pub scaling_check_interval: Duration,
pub enable_elastic_scheduling: bool,
pub rendezvous_backend: String,
pub rendezvous_endpoint: String,
}
impl Default for ElasticConfig {
fn default() -> Self {
Self {
min_workers: 1,
max_workers: 64,
            scaling_timeout: Duration::from_secs(300),
            scaling_check_interval: Duration::from_secs(30),
enable_elastic_scheduling: true,
rendezvous_backend: "etcd".to_string(),
rendezvous_endpoint: "localhost:2379".to_string(),
}
}
}
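
/// A serializable snapshot of training state: model and optimizer tensors
/// flattened to `Vec<f32>`, scheduler and RNG state, loss and metrics, plus
/// metadata describing the distributed layout the checkpoint was taken under.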
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
pub step: usize,
pub epoch: usize,
pub model_state: HashMap<String, Vec<f32>>,
pub optimizer_state: HashMap<String, Vec<f32>>,
pub scheduler_state: HashMap<String, f32>,
pub rng_states: HashMap<String, Vec<u8>>,
pub loss: f32,
pub metrics: HashMap<String, f32>,
pub config: HashMap<String, String>,
pub timestamp: u64,
pub version: String,
pub distributed_meta: DistributedMetadata,
}
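
/// Distributed layout (world size, rank, and data/tensor/pipeline-parallel
/// sizes) recorded alongside a checkpoint.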
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMetadata {
pub world_size: usize,
pub rank: usize,
pub process_group_info: HashMap<String, String>,
pub dp_size: usize,
pub tp_size: usize,
pub pp_size: usize,
pub fsdp_sharding: HashMap<String, Vec<usize>>,
}
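
/// A change to the worker pool that requires rescaling the training job.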
#[derive(Debug, Clone, PartialEq)]
pub enum ScalingEvent {
WorkerFailure { failed_ranks: Vec<usize> },
WorkerJoin { new_ranks: Vec<usize> },
ManualScale { target_workers: usize },
AutoScale {
target_workers: usize,
reason: String,
},
}
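
/// Phase of the elastic scaling state machine: `Stable` (training may proceed),
/// `Scaling` (waiting for the worker set to change), or `Synchronizing`
/// (converging on the new world size).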
#[derive(Debug, Clone)]
pub enum ScalingState {
Stable,
Scaling {
event: ScalingEvent,
start_time: SystemTime,
expected_workers: usize,
},
Synchronizing {
current_workers: usize,
target_workers: usize,
},
}
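
/// Saves and loads [`TrainingCheckpoint`]s as JSON files in a configured
/// directory, retrying I/O through a [`FailureDetector`] and pruning old
/// checkpoints beyond `max_checkpoints`.
///
/// Usage sketch (not compiled as a doctest; mirrors the unit test at the bottom
/// of this file and assumes `checkpoint: TrainingCheckpoint` and `rank: usize`
/// are already in scope):
///
/// ```ignore
/// let manager = CheckpointManager::new(
///     CheckpointConfig::default(),
///     Duration::from_secs(30),  // health check interval
///     Duration::from_secs(120), // health timeout
/// )?;
/// let path = manager.save_checkpoint(checkpoint, rank).await?;
/// let latest = manager.load_latest_checkpoint().await?;
/// ```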
#[derive(Debug)]
pub struct CheckpointManager {
config: CheckpointConfig,
failure_detector: FailureDetector,
latest_checkpoint: Arc<RwLock<Option<TrainingCheckpoint>>>,
checkpoint_history: Arc<RwLock<Vec<PathBuf>>>,
}
impl CheckpointManager {
pub fn new(
config: CheckpointConfig,
health_check_interval: Duration,
health_timeout: Duration,
) -> TorshResult<Self> {
std::fs::create_dir_all(&config.checkpoint_dir).map_err(|e| {
TorshDistributedError::backend_error(
"checkpoint",
format!("Failed to create checkpoint directory: {}", e),
)
})?;
let retry_config = RetryConfig::default();
let circuit_breaker_config = CircuitBreakerConfig::default();
let failure_detector = FailureDetector::new(
health_check_interval,
health_timeout,
retry_config,
Some(circuit_breaker_config),
);
Ok(Self {
config,
failure_detector,
latest_checkpoint: Arc::new(RwLock::new(None)),
checkpoint_history: Arc::new(RwLock::new(Vec::new())),
})
}
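
    /// Serializes `checkpoint` to pretty-printed JSON and writes it to
    /// `checkpoint_step_{step}_rank_{rank}.json` under the checkpoint directory,
    /// retrying the write through the failure detector. Optionally re-loads the
    /// file to verify it, records it in the in-memory history, and prunes the
    /// oldest checkpoint once `max_checkpoints` is exceeded.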
pub async fn save_checkpoint(
&self,
checkpoint: TrainingCheckpoint,
rank: usize,
) -> TorshResult<PathBuf> {
let checkpoint_path = self.config.checkpoint_dir.join(format!(
"checkpoint_step_{}_rank_{}.json",
checkpoint.step, rank
));
info!(
"Saving checkpoint at step {} to {:?}",
checkpoint.step, checkpoint_path
);
let checkpoint_data = serde_json::to_string_pretty(&checkpoint).map_err(|e| {
TorshDistributedError::backend_error(
"checkpoint",
format!("Failed to serialize checkpoint: {}", e),
)
})?;
self.failure_detector
.execute_with_recovery(
|| async {
fs::write(&checkpoint_path, &checkpoint_data)
.await
.map_err(|e| {
TorshDistributedError::backend_error(
"checkpoint",
format!("Failed to write checkpoint: {}", e),
)
})
},
None,
)
.await?;
if self.config.verify_after_save {
self.verify_checkpoint(&checkpoint_path).await?;
}
let checkpoint_step = checkpoint.step;
{
let mut latest = self
.latest_checkpoint
.write()
.expect("lock should not be poisoned");
*latest = Some(checkpoint);
}
{
let mut history = self
.checkpoint_history
.write()
.expect("lock should not be poisoned");
history.push(checkpoint_path.clone());
if history.len() > self.config.max_checkpoints {
let old_checkpoint = history.remove(0);
if let Err(e) = std::fs::remove_file(&old_checkpoint) {
warn!(
"Failed to remove old checkpoint {:?}: {}",
old_checkpoint, e
);
}
}
}
info!("Successfully saved checkpoint at step {}", checkpoint_step);
Ok(checkpoint_path)
}
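
    /// Scans the checkpoint directory and loads the checkpoint whose filename
    /// encodes the highest step number, or returns `Ok(None)` if no checkpoint
    /// files exist.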
pub async fn load_latest_checkpoint(&self) -> TorshResult<Option<TrainingCheckpoint>> {
let checkpoint_files = self.find_checkpoint_files().await?;
if checkpoint_files.is_empty() {
info!("No checkpoints found");
return Ok(None);
}
let latest_file = checkpoint_files
.iter()
.max_by_key(|path| self.extract_step_from_filename(path))
.expect("checkpoint_files should not be empty");
info!("Loading latest checkpoint from {:?}", latest_file);
self.load_checkpoint(latest_file).await
}
pub async fn load_checkpoint(
&self,
        checkpoint_path: &Path,
) -> TorshResult<Option<TrainingCheckpoint>> {
self.failure_detector
.execute_with_recovery(
|| async {
let checkpoint_data =
fs::read_to_string(checkpoint_path).await.map_err(|e| {
TorshDistributedError::backend_error(
"checkpoint",
format!("Failed to read checkpoint: {}", e),
)
})?;
let checkpoint: TrainingCheckpoint = serde_json::from_str(&checkpoint_data)
.map_err(|e| {
TorshDistributedError::backend_error(
"checkpoint",
format!("Failed to deserialize checkpoint: {}", e),
)
})?;
info!(
"Successfully loaded checkpoint from step {}",
checkpoint.step
);
Ok(Some(checkpoint))
},
None,
)
.await
}
    async fn verify_checkpoint(&self, checkpoint_path: &Path) -> TorshResult<()> {
debug!("Verifying checkpoint {:?}", checkpoint_path);
let checkpoint = self.load_checkpoint(checkpoint_path).await?;
if checkpoint.is_none() {
return Err(TorshDistributedError::backend_error(
"checkpoint",
"Checkpoint verification failed: could not load",
));
}
debug!("Checkpoint verification successful");
Ok(())
}
async fn find_checkpoint_files(&self) -> TorshResult<Vec<PathBuf>> {
let mut checkpoint_files = Vec::new();
let mut dir_entries = fs::read_dir(&self.config.checkpoint_dir)
.await
.map_err(|e| {
TorshDistributedError::backend_error(
"checkpoint",
format!("Failed to read checkpoint directory: {}", e),
)
})?;
while let Some(entry) = dir_entries.next_entry().await.map_err(|e| {
TorshDistributedError::backend_error(
"checkpoint",
format!("Failed to read directory entry: {}", e),
)
})? {
let path = entry.path();
if let Some(filename) = path.file_name() {
if filename.to_string_lossy().starts_with("checkpoint_")
&& filename.to_string_lossy().ends_with(".json")
{
checkpoint_files.push(path);
}
}
}
Ok(checkpoint_files)
}
fn extract_step_from_filename(&self, path: &Path) -> usize {
if let Some(filename) = path.file_stem() {
let filename_str = filename.to_string_lossy();
if let Some(step_start) = filename_str.find("step_") {
let step_part = &filename_str[step_start + 5..];
if let Some(rank_pos) = step_part.find("_rank_") {
let step_str = &step_part[..rank_pos];
return step_str.parse().unwrap_or(0);
}
}
}
0
}
pub fn get_latest_checkpoint_info(&self) -> Option<TrainingCheckpoint> {
self.latest_checkpoint
.read()
.expect("lock should not be poisoned")
.clone()
}
pub async fn cleanup_all_checkpoints(&self) -> TorshResult<()> {
let checkpoint_files = self.find_checkpoint_files().await?;
for file in checkpoint_files {
if let Err(e) = fs::remove_file(&file).await {
warn!("Failed to remove checkpoint file {:?}: {}", file, e);
}
}
{
let mut history = self
.checkpoint_history
.write()
.expect("lock should not be poisoned");
history.clear();
}
{
let mut latest = self
.latest_checkpoint
.write()
.expect("lock should not be poisoned");
*latest = None;
}
info!("Cleaned up all checkpoints");
Ok(())
}
}
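
/// Coordinates elastic training: tracks the current world size, detects worker
/// failures and joins, drives the `Stable -> Scaling -> Synchronizing -> Stable`
/// state machine, and checkpoints the latest known state before a rescale.
///
/// Usage sketch (not compiled as a doctest; mirrors the unit test at the bottom
/// of this file and assumes `elastic_config` and `checkpoint_config` are in scope):
///
/// ```ignore
/// let manager = ElasticTrainingManager::new(elastic_config, checkpoint_config, 4)?;
/// assert!(manager.can_proceed_training());
/// if let Some(event) = manager.check_scaling_needs().await? {
///     // Pause training until the state machine returns to ScalingState::Stable.
/// }
/// ```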
#[derive(Debug)]
pub struct ElasticTrainingManager {
config: ElasticConfig,
scaling_state: Arc<RwLock<ScalingState>>,
checkpoint_manager: CheckpointManager,
current_world_size: Arc<RwLock<usize>>,
worker_registry: Arc<RwLock<HashMap<usize, SystemTime>>>,
scaling_events: Arc<Mutex<Vec<ScalingEvent>>>,
}
impl ElasticTrainingManager {
pub fn new(
config: ElasticConfig,
checkpoint_config: CheckpointConfig,
initial_world_size: usize,
) -> TorshResult<Self> {
let checkpoint_manager = CheckpointManager::new(
checkpoint_config,
            Duration::from_secs(30),  // health check interval
            Duration::from_secs(120), // health timeout
        )?;
Ok(Self {
config,
scaling_state: Arc::new(RwLock::new(ScalingState::Stable)),
checkpoint_manager,
current_world_size: Arc::new(RwLock::new(initial_world_size)),
worker_registry: Arc::new(RwLock::new(HashMap::new())),
scaling_events: Arc::new(Mutex::new(Vec::new())),
})
}
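
    /// Advances the scaling state machine by one tick. In `Stable`, looks for
    /// failed workers, newly joined workers, or a load-based scaling target and
    /// initiates scaling if one is found, returning the event. In `Scaling` and
    /// `Synchronizing`, checks whether the current phase has finished and
    /// transitions to the next state.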
pub async fn check_scaling_needs(&self) -> TorshResult<Option<ScalingEvent>> {
let current_state = self
.scaling_state
.read()
.expect("lock should not be poisoned")
.clone();
match current_state {
ScalingState::Stable => {
let current_workers = *self
.current_world_size
.read()
.expect("lock should not be poisoned");
let failed_workers = self.detect_failed_workers().await?;
if !failed_workers.is_empty() {
let event = ScalingEvent::WorkerFailure {
failed_ranks: failed_workers,
};
info!("Detected worker failures, initiating scaling: {:?}", event);
self.initiate_scaling(event.clone()).await?;
return Ok(Some(event));
}
let new_workers = self.detect_new_workers().await?;
if !new_workers.is_empty() {
let event = ScalingEvent::WorkerJoin {
new_ranks: new_workers,
};
info!("Detected new workers, initiating scaling: {:?}", event);
self.initiate_scaling(event.clone()).await?;
return Ok(Some(event));
}
if self.config.enable_elastic_scheduling {
if let Some(target) = self.calculate_optimal_workers(current_workers).await? {
if target != current_workers {
let event = ScalingEvent::AutoScale {
target_workers: target,
reason: "Load-based scaling".to_string(),
};
info!("Initiating auto-scaling: {:?}", event);
self.initiate_scaling(event.clone()).await?;
return Ok(Some(event));
}
}
}
}
ScalingState::Scaling { .. } => {
if self.is_scaling_complete().await? {
self.complete_scaling().await?;
}
}
ScalingState::Synchronizing { .. } => {
if self.is_synchronization_complete().await? {
self.complete_synchronization().await?;
}
}
}
Ok(None)
}
async fn initiate_scaling(&self, event: ScalingEvent) -> TorshResult<()> {
info!("Initiating scaling for event: {:?}", event);
if let Some(checkpoint) = self.checkpoint_manager.get_latest_checkpoint_info() {
self.checkpoint_manager
.save_checkpoint(checkpoint, 0)
.await?;
}
        let expected_workers = match &event {
            ScalingEvent::WorkerFailure { failed_ranks } => {
                // Saturating subtraction guards against underflow if more failed
                // ranks are reported than there are live workers.
                self.current_world_size
                    .read()
                    .expect("lock should not be poisoned")
                    .saturating_sub(failed_ranks.len())
            }
ScalingEvent::WorkerJoin { new_ranks } => {
*self
.current_world_size
.read()
.expect("lock should not be poisoned")
+ new_ranks.len()
}
ScalingEvent::ManualScale { target_workers }
| ScalingEvent::AutoScale { target_workers, .. } => *target_workers,
};
let expected_workers = expected_workers
.max(self.config.min_workers)
.min(self.config.max_workers);
{
let mut state = self
.scaling_state
.write()
.expect("lock should not be poisoned");
*state = ScalingState::Scaling {
event: event.clone(),
start_time: SystemTime::now(),
expected_workers,
};
}
{
let mut events = self
.scaling_events
.lock()
.expect("lock should not be poisoned");
events.push(event);
if events.len() > 100 {
events.drain(0..50);
}
}
Ok(())
}
    async fn is_scaling_complete(&self) -> TorshResult<bool> {
        // Completion is currently time-based: the scaling phase is considered
        // finished once `scaling_timeout` has elapsed since it began.
        if let ScalingState::Scaling { start_time, .. } = *self
            .scaling_state
            .read()
            .expect("lock should not be poisoned")
        {
            Ok(start_time.elapsed().unwrap_or(Duration::ZERO) >= self.config.scaling_timeout)
        } else {
            Ok(false)
        }
    }
async fn complete_scaling(&self) -> TorshResult<()> {
info!("Completing scaling process");
let expected_workers = if let ScalingState::Scaling {
expected_workers, ..
} = *self
.scaling_state
.read()
.expect("lock should not be poisoned")
{
expected_workers
} else {
return Ok(());
};
{
let mut state = self
.scaling_state
.write()
.expect("lock should not be poisoned");
*state = ScalingState::Synchronizing {
current_workers: *self
.current_world_size
.read()
.expect("lock should not be poisoned"),
target_workers: expected_workers,
};
}
info!("Transitioning to synchronization phase");
Ok(())
}
    async fn is_synchronization_complete(&self) -> TorshResult<bool> {
        // Placeholder: cross-worker state synchronization is not implemented yet,
        // so the synchronization phase completes immediately.
        Ok(true)
    }
async fn complete_synchronization(&self) -> TorshResult<()> {
info!("Completing synchronization process");
let target_workers = if let ScalingState::Synchronizing { target_workers, .. } = *self
.scaling_state
.read()
.expect("lock should not be poisoned")
{
target_workers
} else {
return Ok(());
};
{
let mut world_size = self
.current_world_size
.write()
.expect("lock should not be poisoned");
*world_size = target_workers;
}
{
let mut state = self
.scaling_state
.write()
.expect("lock should not be poisoned");
*state = ScalingState::Stable;
}
info!(
"Elastic scaling completed, new world size: {}",
target_workers
);
Ok(())
}
    async fn detect_failed_workers(&self) -> TorshResult<Vec<usize>> {
        // Placeholder: worker failure detection is not implemented; no failures
        // are ever reported.
        Ok(Vec::new())
    }
    async fn detect_new_workers(&self) -> TorshResult<Vec<usize>> {
        // Placeholder: detection of newly joined workers is not implemented.
        Ok(Vec::new())
    }
    async fn calculate_optimal_workers(
        &self,
        _current_workers: usize,
    ) -> TorshResult<Option<usize>> {
        // Placeholder: load-based auto-scaling is not implemented; never
        // proposes a different worker count.
        Ok(None)
    }
pub fn get_scaling_state(&self) -> ScalingState {
self.scaling_state
.read()
.expect("lock should not be poisoned")
.clone()
}
pub fn get_world_size(&self) -> usize {
*self
.current_world_size
.read()
.expect("lock should not be poisoned")
}
pub async fn scale_to(&self, target_workers: usize) -> TorshResult<()> {
let event = ScalingEvent::ManualScale { target_workers };
self.initiate_scaling(event).await
}
pub fn get_scaling_history(&self) -> Vec<ScalingEvent> {
self.scaling_events
.lock()
.expect("lock should not be poisoned")
.clone()
}
pub fn can_proceed_training(&self) -> bool {
matches!(
*self
.scaling_state
.read()
.expect("lock should not be poisoned"),
ScalingState::Stable
)
}
pub fn checkpoint_manager(&self) -> &CheckpointManager {
&self.checkpoint_manager
}
}
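
/// Helpers for converting between live model/optimizer tensors and the
/// flattened `Vec<f32>` representation stored in [`TrainingCheckpoint`].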
pub mod checkpoint_utils {
use super::*;
#[allow(dead_code)]
pub struct CheckpointParams {
pub step: usize,
pub epoch: usize,
pub model_params: HashMap<String, Parameter>,
pub optimizer_state: HashMap<String, Tensor>,
pub loss: f32,
pub metrics: HashMap<String, f32>,
pub world_size: usize,
pub rank: usize,
}
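
    /// Builds a [`TrainingCheckpoint`] from live model parameters and optimizer
    /// tensors by flattening each tensor into a `Vec<f32>`. Tensor shapes are
    /// not recorded, so the restore helpers below return flat 1-D tensors.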
pub fn create_checkpoint(params: CheckpointParams) -> TorshResult<TrainingCheckpoint> {
let CheckpointParams {
step,
epoch,
model_params,
optimizer_state,
loss,
metrics,
world_size,
rank,
} = params;
let mut model_state = HashMap::new();
for (name, param) in model_params {
let tensor = param.tensor();
let tensor_guard = tensor.read();
let data = tensor_guard.flatten()?.to_vec()?;
model_state.insert(name, data);
}
let mut opt_state = HashMap::new();
for (name, tensor) in optimizer_state {
let data = tensor.flatten()?.to_vec()?;
opt_state.insert(name, data);
}
let distributed_meta = DistributedMetadata {
world_size,
rank,
process_group_info: HashMap::new(),
            dp_size: world_size,
            tp_size: 1,
pp_size: 1,
fsdp_sharding: HashMap::new(),
};
Ok(TrainingCheckpoint {
step,
epoch,
model_state,
optimizer_state: opt_state,
scheduler_state: HashMap::new(),
rng_states: HashMap::new(),
loss,
metrics,
config: HashMap::new(),
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_secs(),
version: "1.0.0".to_string(),
distributed_meta,
})
}
pub fn restore_model_from_checkpoint(
checkpoint: &TrainingCheckpoint,
) -> TorshResult<HashMap<String, Tensor>> {
let mut model_params = HashMap::new();
for (name, data) in &checkpoint.model_state {
            // Tensor shapes are not stored in the checkpoint, so each parameter is
            // restored as a flat 1-D tensor with its original element count.
            let shape = vec![data.len()];
            let tensor = Tensor::from_vec(data.clone(), &shape)?;
model_params.insert(name.clone(), tensor);
}
Ok(model_params)
}
pub fn restore_optimizer_from_checkpoint(
checkpoint: &TrainingCheckpoint,
) -> TorshResult<HashMap<String, Tensor>> {
let mut optimizer_state = HashMap::new();
for (name, data) in &checkpoint.optimizer_state {
            // As with model parameters, optimizer tensors are restored as flat 1-D tensors.
            let shape = vec![data.len()];
            let tensor = Tensor::from_vec(data.clone(), &shape)?;
optimizer_state.insert(name.clone(), tensor);
}
Ok(optimizer_state)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_checkpoint_manager() -> TorshResult<()> {
let temp_dir = TempDir::new().unwrap();
let config = CheckpointConfig {
checkpoint_dir: temp_dir.path().to_path_buf(),
checkpoint_frequency: 100,
max_checkpoints: 3,
..Default::default()
};
let manager = CheckpointManager::new(
config,
Duration::from_millis(100),
Duration::from_millis(200),
)?;
let checkpoint = TrainingCheckpoint {
step: 1000,
epoch: 10,
model_state: {
let mut state = HashMap::new();
state.insert("weight".to_string(), vec![1.0, 2.0, 3.0]);
state
},
optimizer_state: HashMap::new(),
scheduler_state: HashMap::new(),
rng_states: HashMap::new(),
loss: 0.5,
metrics: HashMap::new(),
config: HashMap::new(),
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_secs(),
version: "1.0.0".to_string(),
distributed_meta: DistributedMetadata {
world_size: 4,
rank: 0,
process_group_info: HashMap::new(),
dp_size: 4,
tp_size: 1,
pp_size: 1,
fsdp_sharding: HashMap::new(),
},
};
let checkpoint_path = manager.save_checkpoint(checkpoint.clone(), 0).await?;
assert!(checkpoint_path.exists());
let loaded = manager.load_latest_checkpoint().await?;
assert!(loaded.is_some());
let loaded_checkpoint = loaded.unwrap();
assert_eq!(loaded_checkpoint.step, checkpoint.step);
assert_eq!(loaded_checkpoint.loss, checkpoint.loss);
Ok(())
}
#[tokio::test]
async fn test_elastic_training_manager() -> TorshResult<()> {
let temp_dir = TempDir::new().unwrap();
let elastic_config = ElasticConfig {
min_workers: 2,
max_workers: 8,
scaling_timeout: Duration::from_millis(100),
..Default::default()
};
let checkpoint_config = CheckpointConfig {
checkpoint_dir: temp_dir.path().to_path_buf(),
..Default::default()
};
let manager = ElasticTrainingManager::new(
elastic_config,
checkpoint_config,
            4,
        )?;
assert_eq!(manager.get_world_size(), 4);
assert!(manager.can_proceed_training());
manager.scale_to(6).await?;
match manager.get_scaling_state() {
ScalingState::Scaling {
expected_workers, ..
} => {
assert_eq!(expected_workers, 6);
}
_ => panic!("Expected scaling state"),
}
Ok(())
}
#[test]
fn test_checkpoint_config() {
let config = CheckpointConfig::default();
assert_eq!(config.checkpoint_frequency, 1000);
assert_eq!(config.max_checkpoints, 5);
assert!(config.async_save);
}
#[test]
fn test_elastic_config() {
let config = ElasticConfig::default();
assert_eq!(config.min_workers, 1);
assert_eq!(config.max_workers, 64);
assert!(config.enable_elastic_scheduling);
}
#[test]
fn test_scaling_events() {
let event1 = ScalingEvent::WorkerFailure {
failed_ranks: vec![1, 2],
};
let event2 = ScalingEvent::WorkerJoin {
new_ranks: vec![5, 6],
};
let event3 = ScalingEvent::ManualScale { target_workers: 8 };
assert_ne!(event1, event2);
assert_ne!(event2, event3);
}
}