#![allow(dead_code)]
#![allow(clippy::await_holding_lock)]
use crate::{ProcessGroup, TorshResult};
use log::info;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use tokio::sync::Semaphore;
use super::config::Zero3CpuOffloadConfig;
/// Asynchronous parameter-prefetch scheduler for ZeRO-3 CPU offload.
///
/// Requests are queued, executed on background tokio tasks under an
/// adaptive concurrency limit, and tracked in shared metrics.
pub struct PrefetchScheduler {
// Offload settings, cloned from the caller's config at construction.
config: Zero3CpuOffloadConfig,
// Communication group cloned into each spawned prefetch task.
process_group: Arc<ProcessGroup>,
// FIFO of pending requests; oldest entries are dropped when over capacity.
prefetch_queue: Mutex<VecDeque<PrefetchRequest>>,
// Operations currently in flight (shared with spawned tasks).
active_prefetches: Arc<Mutex<Vec<PrefetchOperation>>>,
// Aggregated success/failure/timing statistics.
metrics: Arc<Mutex<PrefetchMetrics>>,
// Tunable limits adjusted at runtime by `adapt_prefetch_strategy`.
adaptive_config: Arc<Mutex<AdaptivePrefetchConfig>>,
// NOTE(review): written only at construction; no visible method reads it —
// confirm before removing.
task_coordination: Arc<Mutex<TaskCoordination>>,
}
impl PrefetchScheduler {
/// Build a scheduler from the offload configuration and process group.
///
/// All shared state starts empty; adaptive limits are derived from `config`.
pub fn new(config: &Zero3CpuOffloadConfig, process_group: Arc<ProcessGroup>) -> Self {
    let adaptive = AdaptivePrefetchConfig::new(config);
    Self {
        config: config.clone(),
        process_group,
        prefetch_queue: Mutex::new(VecDeque::new()),
        active_prefetches: Arc::new(Mutex::new(Vec::new())),
        metrics: Arc::new(Mutex::new(PrefetchMetrics::new())),
        adaptive_config: Arc::new(Mutex::new(adaptive)),
        task_coordination: Arc::new(Mutex::new(TaskCoordination::new())),
    }
}
/// Queue a prefetch request for `layer_name` and immediately try to start it.
///
/// No-op (returns `Ok(())`) when `async_prefetch` is disabled in the config.
/// When the queue exceeds its adaptive capacity, the oldest requests are
/// dropped (logged, not an error).
pub async fn schedule_prefetch(&self, layer_name: &str) -> TorshResult<()> {
    if !self.config.async_prefetch {
        return Ok(());
    }
    let request = PrefetchRequest {
        layer_name: layer_name.to_string(),
        priority: PrefetchPriority::Normal,
        requested_at: std::time::Instant::now(),
        estimated_size_bytes: self.estimate_layer_size(layer_name),
    };
    // Read the queue bound BEFORE taking the queue lock. The original code
    // locked `adaptive_config` while holding `prefetch_queue` (queue ->
    // adaptive), the reverse of `adapt_prefetch_strategy`'s acquisition
    // order — a lock-order inversion that can deadlock under contention.
    let max_queue_size = self
        .adaptive_config
        .lock()
        .expect("lock should not be poisoned")
        .max_queue_size;
    {
        let mut queue = self
            .prefetch_queue
            .lock()
            .expect("lock should not be poisoned");
        queue.push_back(request.clone());
        // Evict from the front (oldest first) until within capacity.
        while queue.len() > max_queue_size {
            if let Some(dropped) = queue.pop_front() {
                info!(
                    " Dropped prefetch request for {} (queue full)",
                    dropped.layer_name
                );
            }
        }
    }
    info!(
        " Scheduled prefetch for layer: {} ({} bytes)",
        layer_name, request.estimated_size_bytes
    );
    self.execute_async_prefetch(request).await?;
    Ok(())
}
/// Spawn a background task that performs the prefetch and records metrics.
///
/// Applies backpressure: when the concurrency limit is already reached the
/// request is skipped (logged, returns `Ok(())`) rather than queued here.
async fn execute_async_prefetch(&self, request: PrefetchRequest) -> TorshResult<()> {
    let process_group = self.process_group.clone();
    let metrics = self.metrics.clone();
    let active_prefetches = self.active_prefetches.clone();
    if !self.can_start_prefetch().await? {
        info!(
            " = Delaying prefetch for {} (system busy)",
            request.layer_name
        );
        return Ok(());
    }
    let operation = PrefetchOperation {
        layer_name: request.layer_name.clone(),
        started_at: std::time::Instant::now(),
        status: PrefetchStatus::InProgress,
    };
    {
        let mut active = active_prefetches
            .lock()
            .expect("lock should not be poisoned");
        active.push(operation);
    }
    let layer_name = request.layer_name.clone();
    // Capture the size estimate (usize is Copy) so the success metric is
    // credited with the real transfer size. The original recorded 0 bytes
    // here, inconsistent with `prioritized_batch_prefetch`, which records
    // `estimated_size_bytes`.
    let estimated_size_bytes = request.estimated_size_bytes;
    tokio::spawn(async move {
        let start_time = std::time::Instant::now();
        let result = Self::prefetch_layer_data(&layer_name, process_group).await;
        {
            // Guard held only for the metric update; no await inside.
            let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
            let duration = start_time.elapsed();
            match result {
                Ok(()) => {
                    metrics_guard.record_successful_prefetch(duration, estimated_size_bytes);
                    info!(
                        " = Async prefetch completed for layer: {} in {:?}",
                        layer_name, duration
                    );
                }
                Err(e) => {
                    metrics_guard.record_failed_prefetch(duration, e.to_string());
                    tracing::error!("Async prefetch failed for layer {}: {}", layer_name, e);
                }
            }
        }
        {
            // Drop this operation from the in-flight list on completion.
            let mut active = active_prefetches
                .lock()
                .expect("lock should not be poisoned");
            active.retain(|op| op.layer_name != layer_name);
        }
    });
    Ok(())
}
/// Whether a new prefetch may start given the adaptive concurrency limit.
///
/// Note: this is a check-then-act gate; callers tolerate the small race
/// where two requests both pass the check simultaneously.
async fn can_start_prefetch(&self) -> TorshResult<bool> {
    // Take each lock one at a time, dropping the guard before acquiring the
    // next. The original held `adaptive_config` while locking
    // `active_prefetches`, the opposite order of `get_queue_status`
    // (active -> adaptive) — a lock-order-inversion deadlock hazard.
    let active_count = self
        .active_prefetches
        .lock()
        .expect("lock should not be poisoned")
        .len();
    let max_concurrent = self
        .adaptive_config
        .lock()
        .expect("lock should not be poisoned")
        .max_concurrent_prefetches;
    Ok(active_count < max_concurrent)
}
/// Perform the CPU->device transfer for one layer's parameters.
///
/// NOTE(review): placeholder implementation — it only sleeps for the
/// heuristic transfer time; `_process_group` is accepted but unused until
/// real collective transfers are wired in.
async fn prefetch_layer_data(
    layer_name: &str,
    _process_group: Arc<ProcessGroup>,
) -> TorshResult<()> {
    tokio::time::sleep(Self::estimate_transfer_time(layer_name)).await;
    Ok(())
}
/// Heuristic transfer-time estimate keyed off layer-name substrings.
///
/// "large" layers get 50 ms, "medium" 25 ms, anything else 10 ms.
/// (`std::time::Duration` is the exact type `tokio::time` re-exports.)
fn estimate_transfer_time(layer_name: &str) -> std::time::Duration {
    // (substring, cost-in-ms) pairs checked in priority order; first hit wins.
    const COSTS: [(&str, u64); 2] = [("large", 50), ("medium", 25)];
    let millis = COSTS
        .iter()
        .find(|(tag, _)| layer_name.contains(tag))
        .map(|&(_, ms)| ms)
        .unwrap_or(10);
    std::time::Duration::from_millis(millis)
}
/// Prefetch all `layer_names` concurrently, bounded by the adaptive
/// `max_concurrent_prefetches` limit via a semaphore.
///
/// No-op when async prefetch is disabled or the list is empty. Individual
/// task failures and panics are logged and counted, never propagated; one
/// batch-metric entry is recorded at the end.
pub async fn batch_prefetch(&self, layer_names: Vec<String>) -> TorshResult<()> {
    if !self.config.async_prefetch || layer_names.is_empty() {
        return Ok(());
    }
    info!(
        " = Starting batch prefetch for {} layers",
        layer_names.len()
    );
    // Copy the limit out so the guard is dropped before any await point.
    // (The original's `#[allow(clippy::await_holding_lock)]` here was
    // misleading: the guard was already dropped explicitly before awaiting.)
    let max_concurrent = self
        .adaptive_config
        .lock()
        .expect("lock should not be poisoned")
        .max_concurrent_prefetches;
    let semaphore = Arc::new(Semaphore::new(max_concurrent));
    let mut tasks = Vec::new();
    for layer_name in layer_names {
        let sem = semaphore.clone();
        let process_group = self.process_group.clone();
        let metrics = self.metrics.clone();
        // Estimate before spawning so successes are credited with real bytes;
        // the original recorded 0 here, inconsistent with
        // `prioritized_batch_prefetch`.
        let estimated_size_bytes = self.estimate_layer_size(&layer_name);
        let task = tokio::spawn(async move {
            let _permit = sem.acquire().await.expect("semaphore should not be closed");
            let start_time = std::time::Instant::now();
            let result = Self::prefetch_layer_data(&layer_name, process_group).await;
            {
                let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
                let duration = start_time.elapsed();
                match result {
                    Ok(()) => {
                        metrics_guard.record_successful_prefetch(duration, estimated_size_bytes)
                    }
                    Err(ref e) => metrics_guard.record_failed_prefetch(duration, e.to_string()),
                }
            }
            result
        });
        tasks.push(task);
    }
    // join_all waits for every task; a panic in one does not abort the rest.
    let results: Vec<_> = futures::future::join_all(tasks).await;
    let mut successful = 0;
    let mut failed = 0;
    for result in results {
        match result {
            Ok(Ok(())) => successful += 1,
            Ok(Err(e)) => {
                failed += 1;
                tracing::error!("Prefetch task failed: {}", e);
            }
            Err(e) => {
                failed += 1;
                tracing::error!("Prefetch task panicked: {}", e);
            }
        }
    }
    info!(
        " Batch prefetch completed: {} successful, {} failed",
        successful, failed
    );
    {
        let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
        metrics.record_batch_prefetch(successful, failed);
    }
    Ok(())
}
/// Look ahead in `execution_graph` from `current_layer` and prefetch the
/// next few layers, using the adaptively computed look-ahead distance.
///
/// Does nothing when async prefetch is disabled, when `current_layer` is
/// not found in the graph, or when there are no layers ahead of it.
pub async fn intelligent_prefetch(
    &self,
    current_layer: &str,
    execution_graph: &[String],
) -> TorshResult<()> {
    if !self.config.async_prefetch {
        return Ok(());
    }
    let pos = match execution_graph.iter().position(|l| l == current_layer) {
        Some(p) => p,
        // Unknown layer: nothing to look ahead from.
        None => return Ok(()),
    };
    let prefetch_distance = self.calculate_optimal_prefetch_distance().await?;
    // Take up to `prefetch_distance` layers that follow the current one.
    let layers_to_prefetch: Vec<String> = execution_graph
        .iter()
        .skip(pos + 1)
        .take(prefetch_distance)
        .cloned()
        .collect();
    if layers_to_prefetch.is_empty() {
        return Ok(());
    }
    info!(
        " > Intelligent prefetch: {} layers ahead from {}",
        layers_to_prefetch.len(),
        current_layer
    );
    self.prioritized_batch_prefetch(layers_to_prefetch, pos)
        .await?;
    Ok(())
}
/// Prefetch `layer_names` concurrently, giving earlier (sooner-needed)
/// layers a higher priority and a shorter start-up delay.
///
/// `_current_pos` is accepted for interface stability but not used here.
/// Failures and panics are logged and counted, never propagated.
async fn prioritized_batch_prefetch(
&self,
layer_names: Vec<String>,
_current_pos: usize,
) -> TorshResult<()> {
let mut prioritized_requests = Vec::new();
for (i, layer_name) in layer_names.iter().enumerate() {
// First layer is needed next -> High; the following two -> Normal;
// anything further out -> Low.
let priority = match i {
0 => PrefetchPriority::High, 1..=2 => PrefetchPriority::Normal, _ => PrefetchPriority::Low, };
let request = PrefetchRequest {
layer_name: layer_name.clone(),
priority,
requested_at: std::time::Instant::now(),
estimated_size_bytes: self.estimate_layer_size(layer_name),
};
prioritized_requests.push(request);
}
// Sort descending (b cmp a): highest priority first.
prioritized_requests.sort_by(|a, b| b.priority.cmp(&a.priority));
#[allow(clippy::await_holding_lock)]
// Copy the limit out and release the guard before the awaits below.
let adaptive_config = self
.adaptive_config
.lock()
.expect("lock should not be poisoned");
let max_concurrent = adaptive_config.max_concurrent_prefetches;
drop(adaptive_config);
let semaphore = Arc::new(Semaphore::new(max_concurrent));
let mut tasks = Vec::new();
for request in prioritized_requests {
let sem = semaphore.clone();
let process_group = self.process_group.clone();
let metrics = self.metrics.clone();
// Stagger task start-up so higher-priority requests reach the
// semaphore (and thus a permit) first.
let delay = match request.priority {
PrefetchPriority::High => tokio::time::Duration::from_millis(0),
PrefetchPriority::Normal => tokio::time::Duration::from_millis(10),
PrefetchPriority::Low => tokio::time::Duration::from_millis(25),
};
let task = tokio::spawn(async move {
tokio::time::sleep(delay).await; let _permit = sem.acquire().await.expect("semaphore should not be closed");
let start_time = std::time::Instant::now();
let result = Self::prefetch_layer_data(&request.layer_name, process_group).await;
{
// Guard held only for the metric update; no await inside.
let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
let duration = start_time.elapsed();
match result {
Ok(()) => metrics_guard
.record_successful_prefetch(duration, request.estimated_size_bytes),
Err(ref e) => metrics_guard.record_failed_prefetch(duration, e.to_string()),
}
}
(request.layer_name, result)
});
tasks.push(task);
}
// Wait for every task; a panic in one does not abort the others.
let results: Vec<_> = futures::future::join_all(tasks).await;
let mut successful = 0;
let mut failed = 0;
for result in results {
match result {
Ok((layer_name, Ok(()))) => {
successful += 1;
info!(" Prioritized prefetch completed: {}", layer_name);
}
Ok((layer_name, Err(e))) => {
failed += 1;
tracing::error!("Prioritized prefetch failed for {}: {}", layer_name, e);
}
Err(e) => {
failed += 1;
tracing::error!("Prioritized prefetch task panicked: {}", e);
}
}
}
info!(
" < Prioritized batch prefetch completed: {} successful, {} failed",
successful, failed
);
Ok(())
}
/// Compute, persist, and return the look-ahead distance, scaled by the
/// recent success rate and clamped to `[1, max_prefetch_distance]`.
pub async fn calculate_optimal_prefetch_distance(&self) -> TorshResult<usize> {
    // Read the success rate with its own short-lived lock. The original
    // locked `metrics` while still holding `adaptive_config`, nesting two
    // locks unnecessarily.
    let current_performance = self
        .metrics
        .lock()
        .expect("lock should not be poisoned")
        .get_success_rate();
    // Prefetch further when transfers are reliably succeeding, less when not.
    let performance_multiplier = if current_performance > 0.9 {
        1.5
    } else if current_performance > 0.7 {
        1.0
    } else {
        0.7
    };
    // Single lock acquisition for both the read and the write-back; the
    // original dropped the guard and immediately re-acquired the same lock.
    let mut adaptive_config = self
        .adaptive_config
        .lock()
        .expect("lock should not be poisoned");
    let scaled =
        (adaptive_config.base_prefetch_distance as f32 * performance_multiplier) as usize;
    // min-then-max guarantees a distance of at least 1 even when
    // `max_prefetch_distance` is 0; the original's `.max(1).min(max)`
    // could return 0 in that case, violating the documented lower bound.
    let optimal_distance = scaled.min(adaptive_config.max_prefetch_distance).max(1);
    adaptive_config.current_prefetch_distance = optimal_distance;
    Ok(optimal_distance)
}
/// Tune concurrency, queue capacity, and look-ahead distance from recent
/// metrics. Intended to be called periodically; each call adjusts every
/// knob by at most one step.
pub async fn adapt_prefetch_strategy(&self) -> TorshResult<()> {
    // Snapshot the metrics; the temporary guard is released immediately.
    let metrics = self
        .metrics
        .lock()
        .expect("lock should not be poisoned")
        .clone();
    // Read the queue length BEFORE locking `adaptive_config`.
    // `schedule_prefetch` acquires prefetch_queue -> adaptive_config; the
    // original here acquired adaptive_config -> prefetch_queue — a
    // lock-order inversion that could deadlock under contention.
    let queue_size = self
        .prefetch_queue
        .lock()
        .expect("lock should not be poisoned")
        .len();
    let mut adaptive_config = self
        .adaptive_config
        .lock()
        .expect("lock should not be poisoned");
    info!(" < Adapting prefetch strategy based on performance");
    // Concurrency: widen when reliably successful (needs a minimum sample
    // size), narrow when failing often.
    if metrics.get_success_rate() > 0.95 && metrics.total_prefetches > 10 {
        adaptive_config.max_concurrent_prefetches =
            (adaptive_config.max_concurrent_prefetches + 1).min(8);
        info!(
            " Increased max concurrent prefetches to {}",
            adaptive_config.max_concurrent_prefetches
        );
    } else if metrics.get_success_rate() < 0.8 && adaptive_config.max_concurrent_prefetches > 1
    {
        adaptive_config.max_concurrent_prefetches =
            (adaptive_config.max_concurrent_prefetches - 1).max(1);
        info!(
            " Decreased max concurrent prefetches to {}",
            adaptive_config.max_concurrent_prefetches
        );
    }
    // Queue capacity: grow when running more than 3/4 full, shrink when
    // under 1/4 full (never below 4, never above 32).
    if queue_size > adaptive_config.max_queue_size * 3 / 4 {
        adaptive_config.max_queue_size = (adaptive_config.max_queue_size + 2).min(32);
        info!(
            " = Increased max queue size to {}",
            adaptive_config.max_queue_size
        );
    } else if queue_size < adaptive_config.max_queue_size / 4
        && adaptive_config.max_queue_size > 4
    {
        adaptive_config.max_queue_size = (adaptive_config.max_queue_size - 1).max(4);
        info!(
            " = Decreased max queue size to {}",
            adaptive_config.max_queue_size
        );
    }
    // Look-ahead: back off when transfers are slow, extend when fast.
    // `saturating_sub` fixes a usize-underflow panic the original
    // `(base_prefetch_distance - 1)` hit when the distance was already 0
    // (possible when `prefetch_buffer_size < 4`).
    if metrics.average_prefetch_time > tokio::time::Duration::from_millis(100) {
        adaptive_config.base_prefetch_distance = adaptive_config
            .base_prefetch_distance
            .saturating_sub(1)
            .max(1);
        info!(
            " =; Decreased base prefetch distance to {}",
            adaptive_config.base_prefetch_distance
        );
    } else if metrics.average_prefetch_time < tokio::time::Duration::from_millis(20) {
        adaptive_config.base_prefetch_distance =
            (adaptive_config.base_prefetch_distance + 1).min(16);
        info!(
            " =: Increased base prefetch distance to {}",
            adaptive_config.base_prefetch_distance
        );
    }
    Ok(())
}
/// Drop every queued (not yet started) prefetch request and record one
/// cancellation event in the metrics.
///
/// Already-running prefetch tasks are not interrupted.
pub async fn cancel_all_prefetches(&self) -> TorshResult<()> {
    info!(" = Cancelling all pending prefetch operations");
    let cancelled_count = {
        let mut queue = self
            .prefetch_queue
            .lock()
            .expect("lock should not be poisoned");
        let pending = queue.len();
        queue.clear();
        pending
    };
    if cancelled_count > 0 {
        info!(
            " = Cancelled {} queued prefetch requests",
            cancelled_count
        );
    }
    self.metrics
        .lock()
        .expect("lock should not be poisoned")
        .record_cancellation();
    Ok(())
}
/// Point-in-time snapshot of queue depth, in-flight count, and limits.
///
/// Locks are taken one at a time and released before the next is acquired.
/// The original held all three guards simultaneously, and its
/// active -> adaptive order inverted `can_start_prefetch`'s
/// adaptive -> active order — a lock-order-inversion deadlock hazard.
/// The snapshot is therefore not atomic across fields, which is acceptable
/// for a status/diagnostics API.
pub fn get_queue_status(&self) -> PrefetchQueueStatus {
    let queued_requests = self
        .prefetch_queue
        .lock()
        .expect("lock should not be poisoned")
        .len();
    let active_operations = self
        .active_prefetches
        .lock()
        .expect("lock should not be poisoned")
        .len();
    let adaptive_config = self
        .adaptive_config
        .lock()
        .expect("lock should not be poisoned");
    PrefetchQueueStatus {
        queued_requests,
        active_operations,
        max_queue_size: adaptive_config.max_queue_size,
        max_concurrent: adaptive_config.max_concurrent_prefetches,
        current_prefetch_distance: adaptive_config.current_prefetch_distance,
    }
}
/// Snapshot of the accumulated prefetch metrics.
pub fn get_metrics(&self) -> PrefetchMetrics {
    let guard = self.metrics.lock().expect("lock should not be poisoned");
    guard.clone()
}
/// Snapshot of the current adaptive tuning parameters.
pub fn get_adaptive_config(&self) -> AdaptivePrefetchConfig {
    let guard = self
        .adaptive_config
        .lock()
        .expect("lock should not be poisoned");
    guard.clone()
}
/// Rough parameter size in bytes, keyed off layer-name substrings:
/// "large" -> 64 MiB, "medium" -> 16 MiB, otherwise 4 MiB.
fn estimate_layer_size(&self, layer_name: &str) -> usize {
    const MIB: usize = 1024 * 1024;
    if layer_name.contains("large") {
        64 * MIB
    } else if layer_name.contains("medium") {
        16 * MIB
    } else {
        4 * MIB
    }
}
}
/// A single queued request to prefetch one layer's parameters.
#[derive(Debug, Clone)]
pub struct PrefetchRequest {
/// Name of the layer whose parameters should be fetched.
pub layer_name: String,
/// Scheduling priority; higher priorities are dispatched first.
pub priority: PrefetchPriority,
/// Creation timestamp; not read by the code visible in this file.
pub requested_at: std::time::Instant,
/// Heuristic transfer size from `estimate_layer_size`.
pub estimated_size_bytes: usize,
}
/// Relative urgency of a prefetch request.
///
/// Discriminants are chosen so the derived `Ord` sorts Low < Normal < High.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrefetchPriority {
Low = 0,
Normal = 1,
High = 2,
}
/// Lifecycle state of a prefetch operation.
///
/// NOTE(review): only `InProgress` is constructed by the visible code;
/// the other variants appear reserved for future use — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchStatus {
Queued,
InProgress,
Completed,
Failed,
Cancelled,
}
/// An in-flight prefetch tracked in `active_prefetches`.
#[derive(Debug, Clone)]
pub struct PrefetchOperation {
/// Layer being fetched; used as the key when removing on completion.
pub layer_name: String,
/// Start timestamp; not read by the code visible in this file.
pub started_at: std::time::Instant,
/// Always `InProgress` in the visible code — entries are removed from the
/// active list on completion rather than transitioned to another state.
pub status: PrefetchStatus,
}
/// Snapshot returned by `PrefetchScheduler::get_queue_status`.
#[derive(Debug, Clone)]
pub struct PrefetchQueueStatus {
/// Requests waiting in the FIFO queue.
pub queued_requests: usize,
/// Prefetch operations currently running.
pub active_operations: usize,
/// Queue capacity before the oldest requests are dropped.
pub max_queue_size: usize,
/// Concurrency limit for in-flight prefetches.
pub max_concurrent: usize,
/// Look-ahead distance most recently computed by the scheduler.
pub current_prefetch_distance: usize,
}
/// Mutable tuning knobs, adjusted at runtime by `adapt_prefetch_strategy`
/// and `calculate_optimal_prefetch_distance`.
#[derive(Debug, Clone)]
pub struct AdaptivePrefetchConfig {
/// Baseline look-ahead distance before performance scaling.
pub base_prefetch_distance: usize,
/// Most recently computed look-ahead distance.
pub current_prefetch_distance: usize,
/// Upper bound for the look-ahead distance.
pub max_prefetch_distance: usize,
/// Maximum prefetches allowed in flight at once.
pub max_concurrent_prefetches: usize,
/// Capacity of the pending-request queue.
pub max_queue_size: usize,
/// NOTE(review): set to true at construction but never read by the
/// visible code — confirm intended use before relying on it.
pub adaptive_optimization_enabled: bool,
}
impl AdaptivePrefetchConfig {
/// Derive the initial tuning values from the offload configuration.
///
/// NOTE(review): if `prefetch_buffer_size < 4` the integer division makes
/// the base/current distances 0 — verify that downstream arithmetic on
/// these fields tolerates a zero distance.
pub fn new(config: &Zero3CpuOffloadConfig) -> Self {
Self {
// Start the look-ahead at a quarter of the prefetch buffer window.
base_prefetch_distance: config.prefetch_buffer_size / 4,
current_prefetch_distance: config.prefetch_buffer_size / 4,
max_prefetch_distance: config.prefetch_buffer_size,
max_concurrent_prefetches: 4,
max_queue_size: 16,
adaptive_optimization_enabled: true,
}
}
}
/// Bookkeeping for background-task limits.
///
/// NOTE(review): constructed and stored by `PrefetchScheduler` but not read
/// by any method visible in this file — confirm whether it is still needed.
#[derive(Debug)]
pub struct TaskCoordination {
/// Tasks currently tracked as active.
pub active_tasks: usize,
/// Cap on simultaneous background tasks.
pub max_background_tasks: usize,
/// Whether coordination is enabled.
pub coordination_enabled: bool,
}
impl TaskCoordination {
/// Defaults: no active tasks, a limit of 8, coordination enabled.
pub fn new() -> Self {
Self {
active_tasks: 0,
max_background_tasks: 8,
coordination_enabled: true,
}
}
}
impl Default for TaskCoordination {
fn default() -> Self {
Self::new()
}
}
/// Aggregated prefetch statistics.
///
/// Durations use `std::time::Duration` — the exact type that
/// `tokio::time::Duration` re-exports — so the struct is usable with or
/// without a tokio runtime.
#[derive(Debug, Clone)]
pub struct PrefetchMetrics {
    /// Every finished prefetch, successful or failed.
    pub total_prefetches: u64,
    /// Prefetches that completed successfully.
    pub successful_prefetches: u64,
    /// Prefetches that returned an error.
    pub failed_prefetches: u64,
    /// Cancellation events recorded (one per `record_cancellation` call).
    pub cancelled_prefetches: u64,
    /// Sum of all recorded prefetch durations.
    pub total_prefetch_time: std::time::Duration,
    /// `total_prefetch_time / total_prefetches`, maintained incrementally.
    pub average_prefetch_time: std::time::Duration,
    /// Bytes credited to successful prefetches.
    pub total_bytes_prefetched: usize,
    /// Batch operations recorded.
    pub batch_operations: u64,
    /// Batches in which at least one task failed.
    pub failed_batch_operations: u64,
    /// Up to the 10 most recent failure messages; oldest evicted first.
    pub recent_failures: Vec<String>,
}
impl PrefetchMetrics {
    /// All counters zeroed. `get_success_rate` reports 1.0 until data arrives.
    pub fn new() -> Self {
        Self {
            total_prefetches: 0,
            successful_prefetches: 0,
            failed_prefetches: 0,
            cancelled_prefetches: 0,
            total_prefetch_time: std::time::Duration::ZERO,
            average_prefetch_time: std::time::Duration::ZERO,
            total_bytes_prefetched: 0,
            batch_operations: 0,
            failed_batch_operations: 0,
            recent_failures: Vec::new(),
        }
    }
    /// Record a successful prefetch of `bytes` that took `duration`.
    pub fn record_successful_prefetch(&mut self, duration: std::time::Duration, bytes: usize) {
        self.total_prefetches += 1;
        self.successful_prefetches += 1;
        self.total_prefetch_time += duration;
        self.total_bytes_prefetched += bytes;
        self.update_average_time();
    }
    /// Record a failed prefetch; keeps only the 10 most recent error strings.
    pub fn record_failed_prefetch(&mut self, duration: std::time::Duration, error: String) {
        self.total_prefetches += 1;
        self.failed_prefetches += 1;
        self.total_prefetch_time += duration;
        self.recent_failures.push(error);
        if self.recent_failures.len() > 10 {
            self.recent_failures.remove(0);
        }
        self.update_average_time();
    }
    /// Record one batch operation. `_successful` is accepted for interface
    /// stability; only the presence of failures is tracked.
    pub fn record_batch_prefetch(&mut self, _successful: usize, failed: usize) {
        self.batch_operations += 1;
        if failed > 0 {
            self.failed_batch_operations += 1;
        }
    }
    /// Record one cancellation event.
    pub fn record_cancellation(&mut self) {
        self.cancelled_prefetches += 1;
    }
    /// Fraction of prefetches that succeeded; 1.0 when nothing was recorded.
    pub fn get_success_rate(&self) -> f32 {
        if self.total_prefetches > 0 {
            self.successful_prefetches as f32 / self.total_prefetches as f32
        } else {
            1.0
        }
    }
    /// Fraction of prefetches that failed; 0.0 when nothing was recorded.
    pub fn get_failure_rate(&self) -> f32 {
        if self.total_prefetches > 0 {
            self.failed_prefetches as f32 / self.total_prefetches as f32
        } else {
            0.0
        }
    }
    /// Observed throughput in bytes per second (0.0 before any timing data).
    pub fn get_throughput_bps(&self) -> f64 {
        if !self.total_prefetch_time.is_zero() {
            self.total_bytes_prefetched as f64 / self.total_prefetch_time.as_secs_f64()
        } else {
            0.0
        }
    }
    /// Recompute the running average after a new sample.
    fn update_average_time(&mut self) {
        if self.total_prefetches > 0 {
            // The original cast `total_prefetches as u32`, which truncates for
            // counts >= 2^32 — and a truncated value of 0 makes `Duration / 0`
            // panic. Saturate instead; past u32::MAX the average is merely
            // approximate.
            let divisor = u32::try_from(self.total_prefetches).unwrap_or(u32::MAX);
            self.average_prefetch_time = self.total_prefetch_time / divisor;
        }
    }
    /// Reset every counter to its initial state.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
impl Default for PrefetchMetrics {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{init_process_group, BackendType};
// Descending sort on priority must order High, Normal, Low.
#[test]
fn test_prefetch_request_priority_ordering() {
let mut requests = [
PrefetchRequest {
layer_name: "low".to_string(),
priority: PrefetchPriority::Low,
requested_at: std::time::Instant::now(),
estimated_size_bytes: 1000,
},
PrefetchRequest {
layer_name: "high".to_string(),
priority: PrefetchPriority::High,
requested_at: std::time::Instant::now(),
estimated_size_bytes: 1000,
},
PrefetchRequest {
layer_name: "normal".to_string(),
priority: PrefetchPriority::Normal,
requested_at: std::time::Instant::now(),
estimated_size_bytes: 1000,
},
];
requests.sort_by(|a, b| b.priority.cmp(&a.priority));
assert_eq!(requests[0].layer_name, "high");
assert_eq!(requests[1].layer_name, "normal");
assert_eq!(requests[2].layer_name, "low");
}
// Adaptive defaults are derived from the offload config's buffer size.
#[test]
fn test_adaptive_prefetch_config() {
let zero3_config = Zero3CpuOffloadConfig::default();
let config = AdaptivePrefetchConfig::new(&zero3_config);
assert_eq!(
config.base_prefetch_distance,
zero3_config.prefetch_buffer_size / 4
);
assert_eq!(
config.max_prefetch_distance,
zero3_config.prefetch_buffer_size
);
assert!(config.adaptive_optimization_enabled);
}
// Counters and rates track successes and failures as they are recorded.
#[test]
fn test_prefetch_metrics() {
let mut metrics = PrefetchMetrics::new();
metrics.record_successful_prefetch(tokio::time::Duration::from_millis(100), 1000);
assert_eq!(metrics.total_prefetches, 1);
assert_eq!(metrics.successful_prefetches, 1);
assert_eq!(metrics.get_success_rate(), 1.0);
metrics.record_failed_prefetch(
tokio::time::Duration::from_millis(50),
"test error".to_string(),
);
assert_eq!(metrics.total_prefetches, 2);
assert_eq!(metrics.failed_prefetches, 1);
assert_eq!(metrics.get_success_rate(), 0.5);
assert_eq!(metrics.recent_failures.len(), 1);
}
// A fresh scheduler has an empty queue, no active ops, and zero metrics.
#[tokio::test]
async fn test_prefetch_scheduler_creation() {
let config = Zero3CpuOffloadConfig::default();
let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
.await
.expect("operation should succeed");
let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));
let status = scheduler.get_queue_status();
assert_eq!(status.queued_requests, 0);
assert_eq!(status.active_operations, 0);
let metrics = scheduler.get_metrics();
assert_eq!(metrics.total_prefetches, 0);
}
// Computed look-ahead distance stays within [1, prefetch_buffer_size].
#[tokio::test]
async fn test_prefetch_distance_calculation() {
let config = Zero3CpuOffloadConfig::default();
let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
.await
.expect("operation should succeed");
let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));
let distance = scheduler
.calculate_optimal_prefetch_distance()
.await
.expect("operation should succeed");
assert!(distance >= 1);
assert!(distance <= config.prefetch_buffer_size);
}
// A two-layer batch records exactly one batch operation.
#[tokio::test]
async fn test_batch_prefetch() {
let config = Zero3CpuOffloadConfig::default();
let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
.await
.expect("operation should succeed");
let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));
let layers = vec!["layer1".to_string(), "layer2".to_string()];
scheduler
.batch_prefetch(layers)
.await
.expect("operation should succeed");
let metrics = scheduler.get_metrics();
assert_eq!(metrics.batch_operations, 1);
}
// TaskCoordination constructor defaults.
#[test]
fn test_task_coordination() {
let coordination = TaskCoordination::new();
assert_eq!(coordination.active_tasks, 0);
assert!(coordination.coordination_enabled);
assert_eq!(coordination.max_background_tasks, 8);
}
// Cancelling with an empty queue is a no-op that leaves it empty.
#[tokio::test]
async fn test_cancel_prefetches() {
let config = Zero3CpuOffloadConfig::default();
let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
.await
.expect("operation should succeed");
let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));
scheduler
.cancel_all_prefetches()
.await
.expect("operation should succeed");
let status = scheduler.get_queue_status();
assert_eq!(status.queued_requests, 0);
}
}