use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use crate::error::{KernelError, Result};
use crate::executor::{CpuExecutor, ExecutorRegistry, GpuExecutor, RemoteExecutor};
use crate::fault::{FailureDetector, RetryPolicy, RetryState};
use crate::scheduler::{Scheduler, TaskId};
use crate::task::{ExecutionResult, Task};
/// Settings from which a [`Pool`] is constructed.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of CPU workers. `0` asks the pool to pick a count from the
    /// machine's available parallelism when it is built.
    pub cpu_workers: usize,
    /// When `true`, a GPU executor is registered alongside the CPU executor.
    pub enable_gpu: bool,
    /// Remote worker addresses; a remote executor is registered only when
    /// this list is non-empty.
    // NOTE(review): the expected address format is not visible here — confirm
    // against `RemoteExecutor::add_worker`.
    pub remote_workers: Vec<String>,
    /// Capacity handed to the scheduler when the pool is built.
    pub queue_capacity: usize,
    /// Policy consulted by `Pool::submit` when deciding whether to retry.
    pub retry_policy: RetryPolicy,
    /// Optional default timeout for tasks.
    // NOTE(review): nothing in the code visible in this file reads this
    // field — confirm where (or whether) it is enforced.
    pub default_timeout: Option<std::time::Duration>,
}

impl Default for PoolConfig {
    /// Auto-detected CPU worker count, no GPU, no remote workers, a queue
    /// capacity of 1024, the default retry policy, and no timeout.
    fn default() -> Self {
        /// Sentinel meaning "choose the worker count automatically".
        const AUTO_DETECT_WORKERS: usize = 0;
        /// Queue capacity used when the caller does not set one.
        const DEFAULT_QUEUE_CAPACITY: usize = 1024;

        Self {
            retry_policy: RetryPolicy::default(),
            default_timeout: None,
            remote_workers: Vec::new(),
            enable_gpu: false,
            queue_capacity: DEFAULT_QUEUE_CAPACITY,
            cpu_workers: AUTO_DETECT_WORKERS,
        }
    }
}
#[derive(Debug, Default)]
pub struct PoolBuilder {
config: PoolConfig,
}
impl PoolBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub const fn cpu_workers(mut self, count: usize) -> Self {
self.config.cpu_workers = count;
self
}
#[must_use]
pub const fn enable_gpu(mut self, enable: bool) -> Self {
self.config.enable_gpu = enable;
self
}
#[must_use]
pub fn remote_workers(mut self, addresses: Vec<impl Into<String>>) -> Self {
self.config.remote_workers = addresses.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub const fn queue_capacity(mut self, capacity: usize) -> Self {
self.config.queue_capacity = capacity;
self
}
#[must_use]
pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
self.config.retry_policy = policy;
self
}
#[must_use]
pub fn default_timeout(mut self, timeout: std::time::Duration) -> Self {
self.config.default_timeout = Some(timeout);
self
}
pub fn build(self) -> Result<Pool> {
Pool::from_config(self.config)
}
}
/// Point-in-time snapshot of a pool's task counters.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of task submissions recorded by the pool.
    pub tasks_submitted: u64,
    /// Number of tasks that finished successfully.
    pub tasks_completed: u64,
    /// Number of tasks that ended in failure.
    pub tasks_failed: u64,
    /// Number of tasks the scheduler reported as pending at snapshot time.
    pub tasks_pending: u64,
    /// Number of retry attempts performed across all tasks.
    pub total_retries: u64,
}

impl PoolStats {
    /// Fraction of finished tasks (completed + failed) that completed
    /// successfully, in the range `0.0..=1.0`.
    ///
    /// Returns `1.0` when no task has finished yet, so an idle pool is not
    /// reported as failing.
    #[must_use]
    pub fn success_rate(&self) -> f64 {
        // Widen before adding: summing two large `u64` counters directly
        // would panic in debug builds and wrap around in release builds.
        let finished = u128::from(self.tasks_completed) + u128::from(self.tasks_failed);
        if finished == 0 {
            return 1.0;
        }
        self.tasks_completed as f64 / finished as f64
    }
}
/// Task pool that runs tasks through a registry of executors and keeps
/// running counters of what it has processed.
pub struct Pool {
    /// Configuration the pool was built from; the retry policy is read from
    /// here on every synchronous submission.
    config: PoolConfig,
    /// Scheduler that receives tasks submitted asynchronously.
    scheduler: Arc<Scheduler<Task>>,
    /// Executors used to run synchronously submitted tasks.
    executors: Arc<ExecutorRegistry>,
    /// Failure detector; in this file it is only created and later stopped.
    failure_detector: Arc<FailureDetector>,
    /// Count of task submissions.
    tasks_submitted: AtomicU64,
    /// Count of tasks that finished successfully.
    tasks_completed: AtomicU64,
    /// Count of tasks that ended in failure.
    tasks_failed: AtomicU64,
    /// Count of retry attempts across all tasks.
    total_retries: AtomicU64,
    /// `true` until `shutdown` is called.
    running: AtomicBool,
}
impl Pool {
#[must_use]
pub fn builder() -> PoolBuilder {
PoolBuilder::new()
}
fn from_config(config: PoolConfig) -> Result<Self> {
let cpu_workers = if config.cpu_workers == 0 {
std::thread::available_parallelism()
.map(|p| p.get())
.unwrap_or(4)
} else {
config.cpu_workers
};
let scheduler = Arc::new(Scheduler::with_capacity(
cpu_workers,
config.queue_capacity,
));
let mut executors = ExecutorRegistry::new();
executors.register(Arc::new(CpuExecutor::new(cpu_workers)));
if config.enable_gpu {
executors.register(Arc::new(GpuExecutor::new()));
}
if !config.remote_workers.is_empty() {
let mut remote = RemoteExecutor::new();
for addr in &config.remote_workers {
remote.add_worker(addr.clone());
}
executors.register(Arc::new(remote));
}
let failure_detector = Arc::new(FailureDetector::new());
Ok(Self {
config,
scheduler,
executors: Arc::new(executors),
failure_detector,
tasks_submitted: AtomicU64::new(0),
tasks_completed: AtomicU64::new(0),
tasks_failed: AtomicU64::new(0),
total_retries: AtomicU64::new(0),
running: AtomicBool::new(true),
})
}
pub fn submit(&self, task: Task) -> Result<ExecutionResult> {
if !self.running.load(Ordering::Acquire) {
return Err(KernelError::DeviceNotReady);
}
self.tasks_submitted.fetch_add(1, Ordering::Relaxed);
let mut retry_state = RetryState::new();
let mut current_task = task;
loop {
let result = self.executors.execute(¤t_task);
match result {
Ok(exec_result) => {
if exec_result.is_success() {
self.tasks_completed.fetch_add(1, Ordering::Relaxed);
return Ok(exec_result);
}
if retry_state.should_retry(&self.config.retry_policy) {
retry_state.record_failure(
exec_result.error.clone().unwrap_or_default(),
);
self.total_retries.fetch_add(1, Ordering::Relaxed);
current_task.increment_retry();
let delay = retry_state.next_delay(&self.config.retry_policy);
if delay > std::time::Duration::ZERO {
std::thread::sleep(delay);
}
continue;
}
self.tasks_failed.fetch_add(1, Ordering::Relaxed);
return Ok(exec_result);
}
Err(e) => {
if e.is_retriable() && retry_state.should_retry(&self.config.retry_policy) {
retry_state.record_failure(format!("{}", e));
self.total_retries.fetch_add(1, Ordering::Relaxed);
current_task.increment_retry();
let delay = retry_state.next_delay(&self.config.retry_policy);
if delay > std::time::Duration::ZERO {
std::thread::sleep(delay);
}
continue;
}
self.tasks_failed.fetch_add(1, Ordering::Relaxed);
return Err(e);
}
}
}
}
pub fn submit_async(&self, task: Task) -> Result<TaskId> {
if !self.running.load(Ordering::Acquire) {
return Err(KernelError::DeviceNotReady);
}
self.tasks_submitted.fetch_add(1, Ordering::Relaxed);
self.scheduler
.submit(task)
.ok_or(KernelError::UblkQueueFull)
}
#[must_use]
pub fn stats(&self) -> PoolStats {
PoolStats {
tasks_submitted: self.tasks_submitted.load(Ordering::Relaxed),
tasks_completed: self.tasks_completed.load(Ordering::Relaxed),
tasks_failed: self.tasks_failed.load(Ordering::Relaxed),
tasks_pending: self.scheduler.pending_tasks() as u64,
total_retries: self.total_retries.load(Ordering::Relaxed),
}
}
#[must_use]
pub fn num_workers(&self) -> usize {
self.executors.total_workers()
}
#[must_use]
pub fn pending_tasks(&self) -> usize {
self.scheduler.pending_tasks()
}
#[must_use]
pub fn is_running(&self) -> bool {
self.running.load(Ordering::Acquire)
}
pub fn shutdown(&self) {
self.running.store(false, Ordering::Release);
self.scheduler.stop();
self.executors.shutdown();
self.failure_detector.stop();
}
#[must_use]
pub fn worker_loads(&self) -> Vec<usize> {
self.scheduler.worker_loads()
}
}
impl std::fmt::Debug for Pool {
    /// Formats a summary of the pool (worker count, pending tasks, and
    /// running flag) rather than dumping its internal fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("Pool");
        summary.field("num_workers", &self.num_workers());
        summary.field("pending_tasks", &self.pending_tasks());
        summary.field("running", &self.is_running());
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::TaskPriority;

    /// Builds a pool with `workers` CPU workers, panicking if that fails.
    fn cpu_pool(workers: usize) -> Pool {
        Pool::builder().cpu_workers(workers).build().unwrap()
    }

    /// The default configuration auto-detects workers and enables no extras.
    #[test]
    fn test_pool_config_default() {
        let defaults = PoolConfig::default();
        assert_eq!(defaults.cpu_workers, 0);
        assert!(!defaults.enable_gpu);
        assert!(defaults.remote_workers.is_empty());
    }

    /// An explicit worker count is reflected by the pool.
    #[test]
    fn test_pool_builder_cpu_workers() {
        let four_workers = cpu_pool(4);
        assert_eq!(four_workers.num_workers(), 4);
    }

    /// A worker count of zero resolves to at least one worker.
    #[test]
    fn test_pool_builder_auto_workers() {
        let auto_sized = cpu_pool(0);
        assert!(auto_sized.num_workers() >= 1);
    }

    /// A custom queue capacity still yields a running pool.
    #[test]
    fn test_pool_builder_queue_capacity() {
        let builder = Pool::builder().cpu_workers(2).queue_capacity(100);
        let pool = builder.build().unwrap();
        assert!(pool.is_running());
    }

    /// A custom retry policy still yields a running pool.
    #[test]
    fn test_pool_builder_retry_policy() {
        let no_retry = RetryPolicy::no_retry();
        let builder = Pool::builder().cpu_workers(2).retry_policy(no_retry);
        let pool = builder.build().unwrap();
        assert!(pool.is_running());
    }

    /// Building with a plain CPU configuration succeeds.
    #[test]
    fn test_pool_builder() {
        let built = Pool::builder().cpu_workers(4).build();
        assert!(built.is_ok());
    }

    /// Output of a successful `echo` task is captured.
    #[test]
    fn test_pool_submit_echo() {
        let pool = cpu_pool(2);
        let echo = Task::binary("echo")
            .args(vec!["Hello, Pool!"])
            .build();
        let outcome = pool.submit(echo).unwrap();
        assert!(outcome.is_success());
        assert!(outcome.stdout_string().contains("Hello"));
    }

    /// Ten successful submissions are all counted as submitted and completed.
    #[test]
    fn test_pool_submit_multiple() {
        let pool = cpu_pool(4);
        for index in 0..10 {
            let echo = Task::binary("echo")
                .args(vec![format!("Task {}", index)])
                .build();
            let outcome = pool.submit(echo).unwrap();
            assert!(outcome.is_success());
        }
        let snapshot = pool.stats();
        assert_eq!(snapshot.tasks_submitted, 10);
        assert_eq!(snapshot.tasks_completed, 10);
    }

    /// One success and one failure are tallied separately.
    #[test]
    fn test_pool_stats() {
        let pool = cpu_pool(2);
        let succeeding = Task::binary("true").build();
        pool.submit(succeeding).unwrap();
        let failing = Task::binary("false").build();
        let _ = pool.submit(failing);
        let snapshot = pool.stats();
        assert_eq!(snapshot.tasks_submitted, 2);
        assert_eq!(snapshot.tasks_completed, 1);
        assert_eq!(snapshot.tasks_failed, 1);
    }

    /// The success rate is 1.0 when idle and a plain ratio otherwise.
    #[test]
    fn test_pool_stats_success_rate() {
        let mut snapshot = PoolStats::default();
        assert_eq!(snapshot.success_rate(), 1.0);
        snapshot.tasks_completed = 8;
        snapshot.tasks_failed = 2;
        assert!((snapshot.success_rate() - 0.8).abs() < 0.001);
    }

    /// After shutdown the pool reports stopped and rejects submissions.
    #[test]
    fn test_pool_shutdown() {
        let pool = cpu_pool(2);
        assert!(pool.is_running());
        pool.shutdown();
        assert!(!pool.is_running());
        let echo = Task::binary("echo").args(vec!["test"]).build();
        let rejected = pool.submit(echo);
        assert!(rejected.is_err());
    }

    /// One load figure is reported per worker.
    #[test]
    fn test_pool_worker_loads() {
        let pool = cpu_pool(4);
        let per_worker = pool.worker_loads();
        assert_eq!(per_worker.len(), 4);
    }

    /// A fresh pool has nothing pending.
    #[test]
    fn test_pool_pending_tasks() {
        let pool = cpu_pool(2);
        assert_eq!(pool.pending_tasks(), 0);
    }

    /// A failing task with retries enabled is retried or recorded as failed.
    #[test]
    fn test_pool_retry() {
        let two_retries = RetryPolicy::new().with_max_retries(2);
        let builder = Pool::builder().cpu_workers(2).retry_policy(two_retries);
        let pool = builder.build().unwrap();
        let failing = Task::binary("false").build();
        let outcome = pool.submit(failing);
        assert!(outcome.is_ok());
        let snapshot = pool.stats();
        assert!(snapshot.total_retries > 0 || snapshot.tasks_failed == 1);
    }

    /// Asynchronous submission returns a task id.
    #[test]
    fn test_pool_submit_async() {
        let pool = cpu_pool(2);
        let echo = Task::binary("echo").args(vec!["async"]).build();
        let queued = pool.submit_async(echo);
        assert!(queued.is_ok());
    }

    /// A high-priority task runs to success like any other.
    #[test]
    fn test_pool_with_priority() {
        let pool = cpu_pool(2);
        let urgent = Task::binary("echo")
            .args(vec!["high priority"])
            .priority(TaskPriority::High)
            .build();
        let outcome = pool.submit(urgent).unwrap();
        assert!(outcome.is_success());
    }

    /// The `Debug` output names the type and its summary fields.
    #[test]
    fn test_pool_debug() {
        let pool = cpu_pool(2);
        let rendered = format!("{:?}", pool);
        assert!(rendered.contains("Pool"));
        assert!(rendered.contains("num_workers"));
    }
}