use crate::{ProcessGroup, TorshResult};
use log::{error, info};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
use super::config::Zero3CpuOffloadConfig;

/// Hard cap on concurrently running prefetch tasks, applied on top of the
/// configured `prefetch_buffer_size`.
const MAX_CONCURRENT_PREFETCHES: usize = 4;
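/// Priority-ordered prefetch scheduler for ZeRO-3 CPU offloading.
///
/// Queues per-layer prefetch requests, runs them as bounded concurrent Tokio
/// tasks, and records timing and success metrics that feed the adaptive
/// look-ahead distance used by [`Self::intelligent_prefetch`].
///
/// A minimal usage sketch (marked `ignore` because it needs a live process
/// group; the single-rank Gloo setup mirrors the tests at the bottom of this
/// module):
///
/// ```ignore
/// let config = Zero3CpuOffloadConfig::default();
/// let pg = Arc::new(init_process_group(BackendType::Gloo, 0, 1).await?);
/// let scheduler = PrefetchScheduler::new(&config, pg);
/// scheduler.schedule_prefetch("decoder.layer.0", PrefetchPriority::High).await?;
/// ```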
pub struct PrefetchScheduler {
config: Zero3CpuOffloadConfig,
process_group: Arc<ProcessGroup>,
prefetch_queue: Mutex<VecDeque<PrefetchRequest>>,
execution_history: Mutex<HashMap<String, ExecutionHistory>>,
performance_metrics: Mutex<PrefetchPerformanceMetrics>,
active_prefetch_tasks: Mutex<HashMap<String, tokio::task::JoinHandle<TorshResult<()>>>>,
}
impl PrefetchScheduler {
pub fn new(config: &Zero3CpuOffloadConfig, process_group: Arc<ProcessGroup>) -> Self {
info!(
"⏰ Prefetch Scheduler initialized: async={}, buffer_size={}, overlap={}",
config.async_prefetch, config.prefetch_buffer_size, config.overlap_computation
);
Self {
config: config.clone(),
process_group,
prefetch_queue: Mutex::new(VecDeque::new()),
execution_history: Mutex::new(HashMap::new()),
performance_metrics: Mutex::new(PrefetchPerformanceMetrics::new()),
active_prefetch_tasks: Mutex::new(HashMap::new()),
}
}
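
    /// Queues a prefetch for `layer_name`, then immediately drains as much of
    /// the queue as the concurrency budget allows.
    ///
    /// The queue is kept sorted by descending priority, so a late `Critical`
    /// request still runs before queued `Low` requests. Returns `Ok(())`
    /// without doing anything when `async_prefetch` is disabled.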
pub async fn schedule_prefetch(&self, layer_name: &str, priority: PrefetchPriority) -> TorshResult<()> {
if !self.config.async_prefetch {
return Ok(());
}
let request = PrefetchRequest {
layer_name: layer_name.to_string(),
priority,
requested_at: Instant::now(),
estimated_completion_time: None,
};
{
            let mut queue = self.prefetch_queue.lock().expect("lock should not be poisoned");
            // Keep the queue sorted by descending priority: insert before the
            // first request whose priority is strictly lower than the new one.
            let insert_pos = queue
                .iter()
                .position(|req| req.priority < priority)
                .unwrap_or(queue.len());
            queue.insert(insert_pos, request);
}
info!(" ⏰ Scheduled prefetch for layer: {} (priority: {:?})", layer_name, priority);
self.process_prefetch_queue().await?;
Ok(())
}
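
    /// Pops as many queued requests as the concurrency budget (buffer size
    /// capped at `MAX_CONCURRENT_PREFETCHES`, minus currently active tasks)
    /// allows, and launches each one.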
async fn process_prefetch_queue(&self) -> TorshResult<()> {
let mut requests_to_process = Vec::new();
{
let mut queue = self.prefetch_queue.lock().expect("lock should not be poisoned");
            let max_concurrent = self.config.prefetch_buffer_size.min(MAX_CONCURRENT_PREFETCHES);
let active_count = self.active_prefetch_tasks.lock().expect("lock should not be poisoned").len();
let can_process = max_concurrent.saturating_sub(active_count);
for _ in 0..can_process {
if let Some(request) = queue.pop_front() {
requests_to_process.push(request);
} else {
break;
}
}
}
for request in requests_to_process {
self.execute_async_prefetch(request).await?;
}
Ok(())
}
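
    /// Spawns the transfer as a detached Tokio task, registers it in
    /// `active_prefetch_tasks`, and records the start time in the metrics.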
async fn execute_async_prefetch(&self, request: PrefetchRequest) -> TorshResult<()> {
        let layer_name = request.layer_name;
let process_group = self.process_group.clone();
let start_time = Instant::now();
{
let mut metrics = self.performance_metrics.lock().expect("lock should not be poisoned");
metrics.prefetch_started(&layer_name, start_time);
}
        // Clone the name for the spawned task; `layer_name` is still needed
        // below as the key into `active_prefetch_tasks`.
        let task_layer_name = layer_name.clone();
        let task = tokio::spawn(async move {
            let result = Self::prefetch_layer_data(&task_layer_name, process_group).await;
            if let Err(ref e) = result {
                error!("Async prefetch failed for layer {}: {}", task_layer_name, e);
            }
            result
        });
{
let mut active_tasks = self.active_prefetch_tasks.lock().expect("lock should not be poisoned");
            active_tasks.insert(layer_name, task);
}
        // Opportunistically reap any tasks that have already finished.
        self.cleanup_completed_tasks().await;
Ok(())
}
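
    /// Reaps finished prefetch tasks: removes them from the active map and
    /// records success or failure (including panicked tasks) in the metrics.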
async fn cleanup_completed_tasks(&self) {
let mut completed_tasks = Vec::new();
{
let mut active_tasks = self.active_prefetch_tasks.lock().expect("lock should not be poisoned");
let mut to_remove = Vec::new();
for (layer_name, task) in active_tasks.iter() {
if task.is_finished() {
to_remove.push(layer_name.clone());
}
}
for layer_name in to_remove {
if let Some(task) = active_tasks.remove(&layer_name) {
completed_tasks.push((layer_name, task));
}
}
}
for (layer_name, task) in completed_tasks {
let end_time = Instant::now();
match task.await {
Ok(Ok(())) => {
let mut metrics = self.performance_metrics.lock().expect("lock should not be poisoned");
metrics.prefetch_completed(&layer_name, end_time, true);
}
Ok(Err(_)) | Err(_) => {
let mut metrics = self.performance_metrics.lock().expect("lock should not be poisoned");
metrics.prefetch_completed(&layer_name, end_time, false);
}
}
}
}
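
    /// Placeholder for the actual data transfer: sleeps for a small,
    /// name-derived duration so the scheduling paths can be exercised without
    /// real device traffic. A production implementation would perform the
    /// parameter gather / host-to-device copy here.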
async fn prefetch_layer_data(
layer_name: &str,
_process_group: Arc<ProcessGroup>,
) -> TorshResult<()> {
        // Simulated transfer time derived from the layer name; see the doc
        // comment above for what a real implementation would do here.
        let transfer_time = Duration::from_millis(5 + (layer_name.len() % 20) as u64);
tokio::time::sleep(transfer_time).await;
info!(" 📤 Async prefetch completed for layer: {}", layer_name);
Ok(())
}
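
    /// Prefetches several layers concurrently, bounded by a semaphore sized to
    /// the same concurrency budget as the queue path. `_priority` is currently
    /// unused: every layer in a batch runs with equal weight.
    ///
    /// ```ignore
    /// let layers = vec!["layer1".to_string(), "layer2".to_string()];
    /// let result = scheduler.batch_prefetch(layers, PrefetchPriority::Medium).await?;
    /// assert_eq!(result.total_layers, 2);
    /// ```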
    pub async fn batch_prefetch(&self, layer_names: Vec<String>, _priority: PrefetchPriority) -> TorshResult<BatchPrefetchResult> {
if !self.config.async_prefetch || layer_names.is_empty() {
return Ok(BatchPrefetchResult::default());
}
let start_time = Instant::now();
info!(
" Starting batch prefetch for {} layers",
layer_names.len()
);
        let max_concurrent = self.config.prefetch_buffer_size.min(MAX_CONCURRENT_PREFETCHES);
let semaphore = Arc::new(Semaphore::new(max_concurrent));
let mut tasks = Vec::new();
for layer_name in layer_names.iter() {
let sem = semaphore.clone();
let process_group = self.process_group.clone();
let layer_name_clone = layer_name.clone();
let task = tokio::spawn(async move {
let _permit = sem.acquire().await.expect("semaphore should not be closed");
Self::prefetch_layer_data(&layer_name_clone, process_group).await
});
            tasks.push(task);
}
        let results = futures::future::join_all(tasks).await;
let mut successful = 0;
let mut failed = 0;
let mut failed_layers = Vec::new();
for (i, result) in results.into_iter().enumerate() {
match result {
Ok(Ok(())) => successful += 1,
Ok(Err(e)) => {
failed += 1;
failed_layers.push(layer_names[i].clone());
tracing::error!("Prefetch task failed for {}: {}", layer_names[i], e);
}
Err(e) => {
failed += 1;
failed_layers.push(layer_names[i].clone());
tracing::error!("Prefetch task panicked for {}: {}", layer_names[i], e);
}
}
}
let duration = start_time.elapsed();
info!(
" Batch prefetch completed: {} successful, {} failed in {:?}",
successful, failed, duration
);
Ok(BatchPrefetchResult {
total_layers: layer_names.len(),
successful,
failed,
failed_layers,
duration,
})
}
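
    /// Looks ahead in `execution_graph` from `current_layer` and batch-prefetches
    /// the next few layers. The look-ahead distance adapts to the observed
    /// success rate, and each candidate's priority combines its historical
    /// access frequency with a penalty for being further away.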
pub async fn intelligent_prefetch(
&self,
current_layer: &str,
execution_graph: &[String],
) -> TorshResult<IntelligentPrefetchResult> {
if !self.config.async_prefetch {
return Ok(IntelligentPrefetchResult::default());
}
let start_time = Instant::now();
self.update_execution_history(current_layer).await;
let current_pos = execution_graph.iter().position(|l| l == current_layer);
if let Some(pos) = current_pos {
let prefetch_distance = self.calculate_optimal_prefetch_distance().await?;
let mut layers_to_prefetch = Vec::new();
for i in 1..=prefetch_distance {
if pos + i < execution_graph.len() {
let layer_name = &execution_graph[pos + i];
let priority = self.calculate_layer_priority(layer_name, i).await;
layers_to_prefetch.push((layer_name.clone(), priority));
}
}
if !layers_to_prefetch.is_empty() {
info!(
" 🧠 Intelligent prefetch: {} layers ahead from {}",
layers_to_prefetch.len(),
current_layer
);
layers_to_prefetch.sort_by(|a, b| b.1.cmp(&a.1));
let layer_names: Vec<String> = layers_to_prefetch.iter().map(|(name, _)| name.clone()).collect();
let highest_priority = layers_to_prefetch.first().map(|(_, p)| *p).unwrap_or(PrefetchPriority::Medium);
let batch_result = self.batch_prefetch(layer_names.clone(), highest_priority).await?;
return Ok(IntelligentPrefetchResult {
layers_scheduled: layer_names,
prefetch_distance,
duration: start_time.elapsed(),
batch_result: Some(batch_result),
});
}
}
Ok(IntelligentPrefetchResult {
layers_scheduled: Vec::new(),
prefetch_distance: 0,
duration: start_time.elapsed(),
batch_result: None,
})
}
async fn update_execution_history(&self, layer_name: &str) {
let mut history = self.execution_history.lock().expect("lock should not be poisoned");
let entry = history.entry(layer_name.to_string()).or_insert_with(ExecutionHistory::new);
entry.record_execution(Instant::now());
}
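
    /// Scores a layer as `min(frequency, 1.0) * 1 / (distance + 1)` and maps
    /// the result to a priority band; layers with no recorded history fall
    /// back to a purely distance-based heuristic.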
async fn calculate_layer_priority(&self, layer_name: &str, distance: usize) -> PrefetchPriority {
let history = self.execution_history.lock().expect("lock should not be poisoned");
if let Some(layer_history) = history.get(layer_name) {
            // access_frequency() is events-per-second and unbounded, so clamp
            // to [0, 1]; the score thresholds below assume a normalized value.
            let frequency_score = layer_history.access_frequency().min(1.0);
            let distance_penalty = 1.0 / (distance as f32 + 1.0);
            let combined_score = frequency_score * distance_penalty;
if combined_score > 0.7 {
PrefetchPriority::High
} else if combined_score > 0.3 {
PrefetchPriority::Medium
} else {
PrefetchPriority::Low
}
} else {
match distance {
1 => PrefetchPriority::High,
2..=3 => PrefetchPriority::Medium,
_ => PrefetchPriority::Low,
}
}
}
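
    /// Derives the look-ahead distance from the configured buffer size, nudged
    /// up when prefetches succeed reliably (>90%) and down when they fail
    /// often (<70%), then clamped to `1..=8`.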
async fn calculate_optimal_prefetch_distance(&self) -> TorshResult<usize> {
let base_distance = self.config.prefetch_buffer_size / 4;
let metrics = self.performance_metrics.lock().expect("lock should not be poisoned");
let success_rate = metrics.overall_success_rate();
        let adjusted_distance = if success_rate > 0.9 {
            base_distance + 2
        } else if success_rate < 0.7 {
            base_distance.saturating_sub(1)
        } else {
            base_distance
        };
        // Keep the look-ahead within a sane window regardless of buffer size.
        Ok(adjusted_distance.clamp(1, 8))
}
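
    /// Cancels a prefetch by layer name: removes it from the queue if still
    /// pending, otherwise aborts the in-flight task. Returns `Ok(true)` if
    /// anything was cancelled.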
pub async fn cancel_prefetch(&self, layer_name: &str) -> TorshResult<bool> {
{
let mut queue = self.prefetch_queue.lock().expect("lock should not be poisoned");
if let Some(pos) = queue.iter().position(|req| req.layer_name == layer_name) {
queue.remove(pos);
info!(" ❌ Cancelled queued prefetch for layer: {}", layer_name);
return Ok(true);
}
}
{
let mut active_tasks = self.active_prefetch_tasks.lock().expect("lock should not be poisoned");
if let Some(task) = active_tasks.remove(layer_name) {
task.abort();
info!(" ❌ Cancelled active prefetch for layer: {}", layer_name);
return Ok(true);
}
}
Ok(false)
}
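
    /// Returns a snapshot of queue depth, active task count, accumulated
    /// metrics, and the relevant configuration flags.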
pub fn get_statistics(&self) -> PrefetchSchedulerStats {
let queue_length = self.prefetch_queue.lock().expect("lock should not be poisoned").len();
let active_tasks = self.active_prefetch_tasks.lock().expect("lock should not be poisoned").len();
let metrics = self.performance_metrics.lock().expect("lock should not be poisoned").clone();
let history_entries = self.execution_history.lock().expect("lock should not be poisoned").len();
PrefetchSchedulerStats {
queue_length,
active_tasks,
performance_metrics: metrics,
tracked_layers: history_entries,
async_prefetch_enabled: self.config.async_prefetch,
prefetch_buffer_size: self.config.prefetch_buffer_size,
overlap_computation: self.config.overlap_computation,
}
}
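
    /// Drops every queued request and aborts every in-flight prefetch task.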
pub async fn clear_all(&self) -> TorshResult<()> {
{
let mut queue = self.prefetch_queue.lock().expect("lock should not be poisoned");
queue.clear();
}
let active_tasks: Vec<_> = {
let mut tasks = self.active_prefetch_tasks.lock().expect("lock should not be poisoned");
tasks.drain().collect()
};
for (layer_name, task) in active_tasks {
task.abort();
info!(" ❌ Cancelled prefetch task for: {}", layer_name);
}
info!(" 🧹 Cleared all prefetch queues and tasks");
Ok(())
}
}
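
/// Priority bands for prefetch requests. `Ord` follows declaration order, so
/// `Critical` compares greater than `High`, and so on down to `Low`.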
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrefetchPriority {
Low = 1,
Medium = 2,
High = 3,
Critical = 4,
}
#[derive(Debug, Clone)]
struct PrefetchRequest {
layer_name: String,
priority: PrefetchPriority,
requested_at: Instant,
estimated_completion_time: Option<Instant>,
}
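
/// Sliding window of execution timestamps (the most recent 50) for one layer,
/// used to estimate how frequently it is accessed.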
#[derive(Debug, Clone)]
struct ExecutionHistory {
access_times: VecDeque<Instant>,
total_accesses: usize,
last_access: Option<Instant>,
}
impl ExecutionHistory {
fn new() -> Self {
Self {
access_times: VecDeque::new(),
total_accesses: 0,
last_access: None,
}
}
fn record_execution(&mut self, time: Instant) {
self.access_times.push_back(time);
self.total_accesses += 1;
self.last_access = Some(time);
if self.access_times.len() > 50 {
self.access_times.pop_front();
}
}
    fn access_frequency(&self) -> f32 {
        if self.access_times.len() < 2 {
            return 0.0;
        }
        let time_span = self
            .access_times
            .back()
            .expect("access_times should have at least 2 elements")
            .duration_since(*self.access_times.front().expect("access_times should have at least 2 elements"));
        if time_span.is_zero() {
            return 1.0;
        }
        // n timestamps span n - 1 inter-access intervals, so this is the event
        // rate (per second) over the retained window.
        (self.access_times.len() as f32 - 1.0) / time_span.as_secs_f32()
}
}
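
/// Aggregate prefetch counters and durations, with per-layer breakdowns; the
/// overall success rate feeds the adaptive look-ahead distance.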
#[derive(Debug, Clone)]
pub struct PrefetchPerformanceMetrics {
total_requests: usize,
successful_requests: usize,
failed_requests: usize,
total_duration: Duration,
layer_metrics: HashMap<String, LayerPrefetchMetrics>,
}
impl PrefetchPerformanceMetrics {
fn new() -> Self {
Self {
total_requests: 0,
successful_requests: 0,
failed_requests: 0,
total_duration: Duration::ZERO,
layer_metrics: HashMap::new(),
}
}
fn prefetch_started(&mut self, layer_name: &str, start_time: Instant) {
self.total_requests += 1;
let entry = self.layer_metrics.entry(layer_name.to_string())
.or_insert_with(LayerPrefetchMetrics::new);
entry.record_start(start_time);
}
fn prefetch_completed(&mut self, layer_name: &str, end_time: Instant, success: bool) {
if success {
self.successful_requests += 1;
} else {
self.failed_requests += 1;
}
if let Some(entry) = self.layer_metrics.get_mut(layer_name) {
entry.record_completion(end_time, success);
if let Some(duration) = entry.last_duration {
self.total_duration += duration;
}
}
}
pub fn overall_success_rate(&self) -> f32 {
if self.total_requests == 0 {
0.0
} else {
self.successful_requests as f32 / self.total_requests as f32
}
}
pub fn average_duration(&self) -> Duration {
if self.successful_requests == 0 {
Duration::ZERO
} else {
self.total_duration / self.successful_requests as u32
}
}
}
#[derive(Debug, Clone)]
struct LayerPrefetchMetrics {
requests: usize,
successes: usize,
failures: usize,
total_duration: Duration,
last_start: Option<Instant>,
last_duration: Option<Duration>,
}
impl LayerPrefetchMetrics {
fn new() -> Self {
Self {
requests: 0,
successes: 0,
failures: 0,
total_duration: Duration::ZERO,
last_start: None,
last_duration: None,
}
}
fn record_start(&mut self, start_time: Instant) {
self.requests += 1;
self.last_start = Some(start_time);
}
fn record_completion(&mut self, end_time: Instant, success: bool) {
if success {
self.successes += 1;
} else {
self.failures += 1;
}
if let Some(start_time) = self.last_start {
let duration = end_time.duration_since(start_time);
self.total_duration += duration;
self.last_duration = Some(duration);
}
}
}
#[derive(Debug, Clone)]
pub struct BatchPrefetchResult {
pub total_layers: usize,
pub successful: usize,
pub failed: usize,
pub failed_layers: Vec<String>,
pub duration: Duration,
}
impl Default for BatchPrefetchResult {
fn default() -> Self {
Self {
total_layers: 0,
successful: 0,
failed: 0,
failed_layers: Vec::new(),
duration: Duration::ZERO,
}
}
}
#[derive(Debug, Clone)]
pub struct IntelligentPrefetchResult {
pub layers_scheduled: Vec<String>,
pub prefetch_distance: usize,
pub duration: Duration,
pub batch_result: Option<BatchPrefetchResult>,
}
impl Default for IntelligentPrefetchResult {
fn default() -> Self {
Self {
layers_scheduled: Vec::new(),
prefetch_distance: 0,
duration: Duration::ZERO,
batch_result: None,
}
}
}
#[derive(Debug, Clone)]
pub struct PrefetchSchedulerStats {
pub queue_length: usize,
pub active_tasks: usize,
pub performance_metrics: PrefetchPerformanceMetrics,
pub tracked_layers: usize,
pub async_prefetch_enabled: bool,
pub prefetch_buffer_size: usize,
pub overlap_computation: bool,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{init_process_group, BackendType};
#[tokio::test]
async fn test_prefetch_scheduler_creation() {
let config = Zero3CpuOffloadConfig::default();
let process_group = Arc::new(init_process_group(BackendType::Gloo, 0, 1).await.expect("operation should succeed"));
let scheduler = PrefetchScheduler::new(&config, process_group);
let stats = scheduler.get_statistics();
assert_eq!(stats.queue_length, 0);
assert_eq!(stats.active_tasks, 0);
assert!(stats.async_prefetch_enabled);
}
#[tokio::test]
async fn test_schedule_prefetch() {
let config = Zero3CpuOffloadConfig::default();
let process_group = Arc::new(init_process_group(BackendType::Gloo, 0, 1).await.expect("operation should succeed"));
let scheduler = PrefetchScheduler::new(&config, process_group);
scheduler.schedule_prefetch("layer1", PrefetchPriority::High).await.expect("operation should succeed");
tokio::time::sleep(Duration::from_millis(50)).await;
let stats = scheduler.get_statistics();
assert!(stats.performance_metrics.total_requests > 0);
}
#[tokio::test]
async fn test_batch_prefetch() {
let config = Zero3CpuOffloadConfig::default();
let process_group = Arc::new(init_process_group(BackendType::Gloo, 0, 1).await.expect("operation should succeed"));
let scheduler = PrefetchScheduler::new(&config, process_group);
let layers = vec!["layer1".to_string(), "layer2".to_string(), "layer3".to_string()];
let result = scheduler.batch_prefetch(layers.clone(), PrefetchPriority::Medium).await.expect("operation should succeed");
assert_eq!(result.total_layers, 3);
assert_eq!(result.successful + result.failed, 3);
}
#[tokio::test]
async fn test_intelligent_prefetch() {
let config = Zero3CpuOffloadConfig::default();
let process_group = Arc::new(init_process_group(BackendType::Gloo, 0, 1).await.expect("operation should succeed"));
let scheduler = PrefetchScheduler::new(&config, process_group);
let execution_graph = vec![
"layer1".to_string(),
"layer2".to_string(),
"layer3".to_string(),
"layer4".to_string(),
"layer5".to_string(),
];
let result = scheduler.intelligent_prefetch("layer1", &execution_graph).await.expect("operation should succeed");
assert!(result.prefetch_distance > 0);
assert!(!result.layers_scheduled.is_empty() || result.prefetch_distance == 0);
}
#[tokio::test]
async fn test_cancel_prefetch() {
        let config = Zero3CpuOffloadConfig {
            async_prefetch: true,
            ..Default::default()
        };
let process_group = Arc::new(init_process_group(BackendType::Gloo, 0, 1).await.expect("operation should succeed"));
let scheduler = PrefetchScheduler::new(&config, process_group);
scheduler.schedule_prefetch("layer1", PrefetchPriority::Low).await.expect("operation should succeed");
let cancelled = scheduler.cancel_prefetch("layer1").await.expect("operation should succeed");
assert!(cancelled);
let not_cancelled = scheduler.cancel_prefetch("nonexistent").await.expect("operation should succeed");
assert!(!not_cancelled);
}
#[test]
fn test_execution_history() {
let mut history = ExecutionHistory::new();
let now = Instant::now();
history.record_execution(now);
assert_eq!(history.total_accesses, 1);
assert!(history.last_access.is_some());
let freq = history.access_frequency();
assert!(freq >= 0.0);
}
#[test]
fn test_prefetch_priority_ordering() {
assert!(PrefetchPriority::Critical > PrefetchPriority::High);
assert!(PrefetchPriority::High > PrefetchPriority::Medium);
assert!(PrefetchPriority::Medium > PrefetchPriority::Low);
}
#[test]
fn test_performance_metrics() {
let mut metrics = PrefetchPerformanceMetrics::new();
let start_time = Instant::now();
metrics.prefetch_started("layer1", start_time);
assert_eq!(metrics.total_requests, 1);
let end_time = start_time + Duration::from_millis(10);
metrics.prefetch_completed("layer1", end_time, true);
assert_eq!(metrics.successful_requests, 1);
assert_eq!(metrics.overall_success_rate(), 1.0);
assert!(metrics.average_duration() > Duration::ZERO);
}
}