use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use crossbeam_utils::sync::{Parker, Unparker};
use std::cell::{Cell, UnsafeCell};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
type TaskCompletionNotify = Arc<(Mutex<bool>, Condvar)>;
type TaskCompletionMap = Arc<Mutex<HashMap<usize, TaskCompletionNotify>>>;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum TaskPriority {
Background = 0,
#[default]
Normal = 1,
High = 2,
Critical = 3,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SchedulingPolicy {
Fifo,
Lifo,
#[default]
Priority,
WeightedFair,
}
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
pub numworkers: usize,
pub policy: SchedulingPolicy,
pub max_queue_size: usize,
pub adaptive: bool,
pub enable_stealing_heuristics: bool,
pub enable_priorities: bool,
pub stealing_threshold: usize,
pub sleep_ms: u64,
pub min_batch_size: usize,
pub max_batch_size: usize,
pub task_timeout_ms: u64,
pub maxretries: usize,
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self {
numworkers: num_cpus::get(),
policy: SchedulingPolicy::Priority,
max_queue_size: 10000,
adaptive: true,
enable_stealing_heuristics: true,
enable_priorities: true,
stealing_threshold: 4,
sleep_ms: 1,
min_batch_size: 1,
max_batch_size: 100,
task_timeout_ms: 0,
maxretries: 3,
}
}
}
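/// Builder for [`SchedulerConfig`].
///
/// A minimal usage sketch (the field values here are illustrative, not
/// recommendations):
///
/// ```ignore
/// let config = SchedulerConfigBuilder::new()
///     .workers(4)
///     .policy(SchedulingPolicy::WeightedFair)
///     .max_queue_size(1_000)
///     .build();
/// let scheduler = WorkStealingScheduler::new(config);
/// ```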
#[derive(Debug, Clone, Default)]
pub struct SchedulerConfigBuilder {
config: SchedulerConfig,
}
impl SchedulerConfigBuilder {
pub fn new() -> Self {
Self::default()
}
pub const fn workers(mut self, numworkers: usize) -> Self {
self.config.numworkers = numworkers;
self
}
pub const fn policy(mut self, policy: SchedulingPolicy) -> Self {
self.config.policy = policy;
self
}
pub const fn max_queue_size(mut self, size: usize) -> Self {
self.config.max_queue_size = size;
self
}
pub const fn adaptive(mut self, enable: bool) -> Self {
self.config.adaptive = enable;
self
}
pub const fn enable_stealing_heuristics(mut self, enable: bool) -> Self {
self.config.enable_stealing_heuristics = enable;
self
}
pub const fn enable_priorities(mut self, enable: bool) -> Self {
self.config.enable_priorities = enable;
self
}
pub const fn stealing_threshold(mut self, threshold: usize) -> Self {
self.config.stealing_threshold = threshold;
self
}
pub const fn sleep_ms(mut self, ms: u64) -> Self {
self.config.sleep_ms = ms;
self
}
pub const fn min_batch_size(mut self, size: usize) -> Self {
self.config.min_batch_size = size;
self
}
pub const fn max_batch_size(mut self, size: usize) -> Self {
self.config.max_batch_size = size;
self
}
pub const fn task_timeout_ms(mut self, timeout: u64) -> Self {
self.config.task_timeout_ms = timeout;
self
}
pub const fn maxretries(mut self, retries: usize) -> Self {
self.config.maxretries = retries;
self
}
pub fn build(self) -> SchedulerConfig {
self.config
}
}
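/// A unit of work that can be executed by the scheduler.
///
/// Only [`Task::execute`] and [`Task::clone_task`] are required; the other
/// methods have defaults. A minimal implementation sketch (`CountTask` is a
/// hypothetical type, shown only to illustrate the shape of an impl):
///
/// ```ignore
/// struct CountTask {
///     n: usize,
/// }
///
/// impl Task for CountTask {
///     fn execute(&mut self) -> Result<(), CoreError> {
///         self.n += 1; // stand-in for real work
///         Ok(())
///     }
///     fn clone_task(&self) -> Box<dyn Task> {
///         Box::new(CountTask { n: self.n })
///     }
/// }
/// ```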
pub trait Task: Send + 'static {
fn execute(&mut self) -> Result<(), CoreError>;
fn priority(&self) -> TaskPriority {
TaskPriority::Normal
}
fn weight(&self) -> usize {
1
}
fn estimated_cost(&self) -> usize {
1
}
fn clone_task(&self) -> Box<dyn Task>;
fn name(&self) -> &str {
"unnamed"
}
}
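/// Handle returned on task submission; used to query status, block until
/// completion, or cancel a still-pending task.
///
/// A minimal sketch (assumes a running `scheduler` in scope):
///
/// ```ignore
/// let handle = scheduler.submit_fn(|| Ok::<(), CoreError>(()));
/// match handle.wait_timeout(Duration::from_secs(1)) {
///     Ok(status) => println!("finished with {status:?}"),
///     Err(_) => println!("timed out"),
/// }
/// ```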
#[derive(Clone)]
pub struct TaskHandle {
id: usize,
status: Arc<Mutex<TaskStatus>>,
result_notify: TaskCompletionNotify,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
Pending,
Running,
Completed,
Failed(usize),
Cancelled,
TimedOut,
}
impl TaskHandle {
#[allow(dead_code)]
fn new(id: usize) -> Self {
Self {
id,
status: Arc::new(Mutex::new(TaskStatus::Pending)),
result_notify: Arc::new((Mutex::new(false), Condvar::new())),
}
}
pub fn id(&self) -> usize {
self.id
}
pub fn status(&self) -> TaskStatus {
*self.status.lock().expect("Operation failed")
}
pub fn wait(&self) -> TaskStatus {
let (lock, cvar) = &*self.result_notify;
let mut completed = lock.lock().expect("completion mutex poisoned");
// Loop to guard against spurious wakeups.
while !*completed {
completed = cvar.wait(completed).expect("completion condvar poisoned");
}
drop(completed);
self.status()
}
pub fn wait_timeout(&self, timeout: Duration) -> Result<TaskStatus, CoreError> {
let (lock, cvar) = &*self.result_notify;
let completed = lock.lock().expect("completion mutex poisoned");
// `wait_timeout_while` loops internally, so spurious wakeups are handled.
let (guard, result) = cvar
.wait_timeout_while(completed, timeout, |completed| !*completed)
.expect("completion condvar poisoned");
if result.timed_out() && !*guard {
return Err(CoreError::TimeoutError(
ErrorContext::new(format!("task {} timed out", self.id))
.with_location(ErrorLocation::new(file!(), line!())),
));
}
drop(guard);
Ok(self.status())
}
pub fn cancel(&self) -> bool {
let mut status = self.status.lock().expect("Operation failed");
if *status == TaskStatus::Pending {
*status = TaskStatus::Cancelled;
let (lock, cvar) = &*self.result_notify;
let mut completed = lock.lock().expect("Operation failed");
*completed = true;
cvar.notify_all();
true
} else {
false
}
}
}
struct TaskWrapper {
id: usize,
task: Box<dyn Task>,
priority: TaskPriority,
weight: usize,
#[allow(dead_code)]
cost: usize,
status: Arc<Mutex<TaskStatus>>,
result_notify: TaskCompletionNotify,
#[allow(dead_code)]
submission_time: Instant,
retry_count: usize,
#[allow(dead_code)]
name: String,
}
impl TaskWrapper {
fn new(id: usize, task: Box<dyn Task>) -> Self {
let priority = task.priority();
let weight = task.weight();
let cost = task.estimated_cost();
let name = task.name().to_string();
Self {
id,
task,
priority,
weight,
cost,
status: Arc::new(Mutex::new(TaskStatus::Pending)),
result_notify: Arc::new((Mutex::new(false), Condvar::new())),
submission_time: Instant::now(),
retry_count: 0,
name,
}
}
fn create_handle(&self) -> TaskHandle {
TaskHandle {
id: self.id,
status: self.status.clone(),
result_notify: self.result_notify.clone(),
}
}
fn execute(&mut self) -> Result<(), CoreError> {
{
let mut status = self.status.lock().expect("status mutex poisoned");
*status = TaskStatus::Running;
}
let result = self.task.execute();
{
let mut status = self.status.lock().expect("status mutex poisoned");
*status = match result {
Ok(_) => TaskStatus::Completed,
Err(_) => TaskStatus::Failed(self.retry_count),
};
}
result
}
// Wakes any threads blocked on this task's completion. Kept separate from
// `execute` so a failed task can be retried without waking waiters early.
fn notify_completion(&self) {
let (lock, cvar) = &*self.result_notify;
let mut completed = lock.lock().expect("completion mutex poisoned");
*completed = true;
cvar.notify_all();
}
fn increment_retry(&mut self) {
self.retry_count += 1;
}
}
thread_local! {
static WORKER_ID: Cell<Option<usize>> = const { Cell::new(None) };
}
fn set_workerid(id: usize) {
WORKER_ID.with(|cell| cell.set(Some(id)));
}
pub fn get_workerid() -> Option<usize> {
WORKER_ID.with(|cell| cell.get())
}
struct WorkerState {
#[allow(dead_code)]
id: usize,
#[allow(clippy::type_complexity)]
local_queue: UnsafeCell<Worker<TaskWrapper>>,
stealers: Vec<Stealer<TaskWrapper>>,
injector: Arc<Injector<TaskWrapper>>,
#[allow(dead_code)]
active: Arc<AtomicBool>,
parker: UnsafeCell<Parker>,
unparker: Unparker,
tasks_processed: AtomicUsize,
tasks_stolen: AtomicUsize,
failed_steals: AtomicUsize,
last_active: Mutex<Instant>,
local_queue_size: AtomicUsize,
adaptive_batch_size: AtomicUsize,
}
// SAFETY: `local_queue` and `parker` are only accessed from the owning
// worker thread (`push_local` routes pushes from other threads through the
// shared injector); all other fields are thread-safe primitives.
unsafe impl Send for WorkerState {}
unsafe impl Sync for WorkerState {}
impl WorkerState {
fn new(
id: usize,
local_queue: Worker<TaskWrapper>,
stealers: Vec<Stealer<TaskWrapper>>,
injector: Arc<Injector<TaskWrapper>>,
) -> Self {
let parker = Parker::new();
let unparker = parker.unparker().clone();
Self {
id,
// Use the queue whose stealer was shared with the other workers, so
// cross-worker stealing can actually see this worker's tasks.
local_queue: UnsafeCell::new(local_queue),
stealers,
injector,
active: Arc::new(AtomicBool::new(true)),
parker: UnsafeCell::new(parker),
unparker,
tasks_processed: AtomicUsize::new(0),
tasks_stolen: AtomicUsize::new(0),
failed_steals: AtomicUsize::new(0),
last_active: Mutex::new(Instant::now()),
local_queue_size: AtomicUsize::new(0),
adaptive_batch_size: AtomicUsize::new(1),
}
}
#[allow(dead_code)]
fn id(&self) -> usize {
self.id
}
fn local_queue_size(&self) -> usize {
self.local_queue_size.load(Ordering::Relaxed)
}
fn push_local(&self, task: TaskWrapper) {
// The crossbeam `Worker` deque is single-owner: only the owning worker
// thread may push to it directly. Pushes from any other thread (e.g. a
// submitter placing a high-priority task) are routed through the shared
// injector, which is safe to use from anywhere.
if get_workerid() == Some(self.id) {
unsafe {
(*self.local_queue.get()).push(task);
}
self.local_queue_size.fetch_add(1, Ordering::Relaxed);
} else {
self.injector.push(task);
}
}
fn pop_local(&self) -> Option<TaskWrapper> {
let result = unsafe { (*self.local_queue.get()).pop() };
if result.is_some() {
self.local_queue_size.fetch_sub(1, Ordering::Relaxed);
}
result
}
fn steal(&self) -> Option<TaskWrapper> {
// `Steal::Retry` signals a transient race rather than an empty queue,
// so retry instead of counting it as a failed steal.
loop {
match self.injector.steal() {
Steal::Success(task) => {
self.tasks_stolen.fetch_add(1, Ordering::Relaxed);
return Some(task);
}
Steal::Empty => break,
Steal::Retry => {}
}
}
for (i, stealer) in self.stealers.iter().enumerate() {
if i == self.id {
continue; // Skip this worker's own queue.
}
loop {
match stealer.steal() {
Steal::Success(task) => {
self.tasks_stolen.fetch_add(1, Ordering::Relaxed);
return Some(task);
}
Steal::Empty => break,
Steal::Retry => {}
}
}
}
self.failed_steals.fetch_add(1, Ordering::Relaxed);
None
}
fn update_last_active(&self) {
let mut last_active = self.last_active.lock().expect("Operation failed");
*last_active = Instant::now();
}
fn time_since_last_active(&self) -> Duration {
let last_active = self.last_active.lock().expect("Operation failed");
last_active.elapsed()
}
fn update_adaptive_batch_size(&self, config: &SchedulerConfig) {
if !config.adaptive {
self.adaptive_batch_size
.store(config.min_batch_size, Ordering::Relaxed);
return;
}
let _tasks_processed = self.tasks_processed.load(Ordering::Relaxed);
let tasks_stolen = self.tasks_stolen.load(Ordering::Relaxed);
let failed_steals = self.failed_steals.load(Ordering::Relaxed);
let steal_attempts = tasks_stolen + failed_steals;
let steal_success_rate = if steal_attempts > 0 {
tasks_stolen as f64 / steal_attempts as f64
} else {
0.0
};
let current_batch_size = self.adaptive_batch_size.load(Ordering::Relaxed);
let new_batch_size = if steal_success_rate > 0.8 {
(current_batch_size * 2).min(config.max_batch_size)
} else if steal_success_rate < 0.2 {
(current_batch_size / 2).max(config.min_batch_size)
} else {
current_batch_size
};
self.adaptive_batch_size
.store(new_batch_size, Ordering::Relaxed);
}
}
#[derive(Debug, Clone)]
pub struct SchedulerStats {
pub tasks_submitted: usize,
pub tasks_completed: usize,
pub tasks_failed: usize,
pub tasks_cancelled: usize,
pub tasks_timed_out: usize,
pub task_retries: usize,
pub numworkers: usize,
pub avg_queue_size: f64,
pub avg_task_latency_ms: f64,
pub avg_task_execution_ms: f64,
pub successful_steals: usize,
pub failed_steals: usize,
pub worker_utilization: Vec<f64>,
pub uptime_seconds: f64,
pub tasks_per_second: f64,
}
impl Default for SchedulerStats {
fn default() -> Self {
Self {
tasks_submitted: 0,
tasks_completed: 0,
tasks_failed: 0,
tasks_cancelled: 0,
tasks_timed_out: 0,
task_retries: 0,
numworkers: 0,
avg_queue_size: 0.0,
avg_task_latency_ms: 0.0,
avg_task_execution_ms: 0.0,
successful_steals: 0,
failed_steals: 0,
worker_utilization: Vec::new(),
uptime_seconds: 0.0,
tasks_per_second: 0.0,
}
}
}
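/// A work-stealing task scheduler: each worker thread owns a local deque and
/// falls back to stealing from a shared injector queue or from its peers.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let mut scheduler = WorkStealingScheduler::new(SchedulerConfig::default());
/// let handle = scheduler.submit_fn(|| Ok::<(), CoreError>(()));
/// assert_eq!(handle.wait(), TaskStatus::Completed);
/// scheduler.shutdown();
/// ```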
pub struct WorkStealingScheduler {
config: SchedulerConfig,
injector: Arc<Injector<TaskWrapper>>,
workers: Vec<JoinHandle<()>>,
worker_states: Vec<Arc<WorkerState>>,
state: Arc<RwLock<SchedulerState>>,
next_taskid: Arc<AtomicUsize>,
task_completion: TaskCompletionMap,
task_submissions: Arc<Mutex<HashMap<usize, Instant>>>,
task_executions: Arc<Mutex<HashMap<usize, Duration>>>,
start_time: Instant,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SchedulerState {
Running,
ShuttingDown,
ShutDown,
}
impl WorkStealingScheduler {
pub fn new(config: SchedulerConfig) -> Self {
let injector = Arc::new(Injector::new());
let state = Arc::new(RwLock::new(SchedulerState::Running));
let next_taskid = Arc::new(AtomicUsize::new(1));
let task_completion = Arc::new(Mutex::new(HashMap::new()));
let task_submissions = Arc::new(Mutex::new(HashMap::new()));
let task_executions = Arc::new(Mutex::new(HashMap::new()));
let mut worker_states = Vec::with_capacity(config.numworkers);
let mut workers = Vec::with_capacity(config.numworkers);
let worker_queues: Vec<_> = (0..config.numworkers).map(|_| Worker::new_fifo()).collect();
let stealers: Vec<_> = worker_queues
.iter()
.map(|worker| worker.stealer())
.collect();
for (i, local_queue) in worker_queues.into_iter().enumerate() {
let worker_state = Arc::new(WorkerState::new(
i,
local_queue,
stealers.clone(),
injector.clone(),
));
worker_states.push(worker_state.clone());
let state_clone = state.clone();
let config_clone = config.clone();
let task_completion_clone = task_completion.clone();
let task_executions_clone = task_executions.clone();
let worker_thread = thread::spawn(move || {
set_workerid(i);
Self::worker_loop(
worker_state,
state_clone,
config_clone,
task_completion_clone,
task_executions_clone,
);
});
workers.push(worker_thread);
}
Self {
config,
injector,
workers,
worker_states,
state,
next_taskid,
task_completion,
task_submissions,
task_executions,
start_time: Instant::now(),
}
}
fn worker_loop(
worker_state: Arc<WorkerState>,
state: Arc<RwLock<SchedulerState>>,
config: SchedulerConfig,
_task_completion: TaskCompletionMap,
task_executions: Arc<Mutex<HashMap<usize, Duration>>>,
) {
while *state.read().expect("scheduler state lock poisoned") == SchedulerState::Running {
let task = worker_state.pop_local().or_else(|| worker_state.steal());
if let Some(mut task) = task {
// Skip tasks that were cancelled while still queued; their waiters
// were already notified by `TaskHandle::cancel`.
if *task.status.lock().expect("status mutex poisoned") == TaskStatus::Cancelled {
continue;
}
worker_state.update_last_active();
let start_time = Instant::now();
let result = task.execute();
let execution_time = start_time.elapsed();
let taskid = task.id;
task_executions
.lock()
.expect("task execution map mutex poisoned")
.insert(taskid, execution_time);
worker_state.tasks_processed.fetch_add(1, Ordering::Relaxed);
match result {
Ok(_) => task.notify_completion(),
Err(_) if task.retry_count < config.maxretries => {
// Requeue without notifying, so waiters are not woken for a
// task that is about to run again.
task.increment_retry();
{
let mut status = task.status.lock().expect("status mutex poisoned");
*status = TaskStatus::Pending;
}
worker_state.push_local(task);
}
Err(_) => task.notify_completion(),
}
worker_state.update_adaptive_batch_size(&config);
} else if config.sleep_ms > 0 {
// Only the owning worker thread touches its parker; see the safety
// comment on `WorkerState`.
unsafe {
(*worker_state.parker.get())
.park_timeout(Duration::from_millis(config.sleep_ms));
}
} else {
thread::yield_now();
}
}
// Drain remaining local tasks during shutdown so waiters are released.
while let Some(mut task) = worker_state.pop_local() {
let start_time = Instant::now();
let _ = task.execute();
let execution_time = start_time.elapsed();
let taskid = task.id;
task_executions
.lock()
.expect("task execution map mutex poisoned")
.insert(taskid, execution_time);
task.notify_completion();
worker_state.tasks_processed.fetch_add(1, Ordering::Relaxed);
}
}
pub fn submit<T: Task>(&self, task: T) -> TaskHandle {
self.submit_boxed(Box::new(task))
}
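/// Submits a boxed task and returns a handle for tracking it.
///
/// # Panics
///
/// Panics if the scheduler is shutting down or has been shut down.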
pub fn submit_boxed(&self, task: Box<dyn Task>) -> TaskHandle {
if *self.state.read().expect("Operation failed") != SchedulerState::Running {
panic!("Cannot submit tasks to a stopped scheduler");
}
let taskid = self.next_taskid.fetch_add(1, Ordering::SeqCst);
let wrapper = TaskWrapper::new(taskid, task);
let handle = wrapper.create_handle();
self.task_completion
.lock()
.expect("task completion map mutex poisoned")
.insert(taskid, wrapper.result_notify.clone());
self.task_submissions
.lock()
.expect("task submission map mutex poisoned")
.insert(taskid, Instant::now());
match self.config.policy {
SchedulingPolicy::Fifo | SchedulingPolicy::Lifo => {
self.injector.push(wrapper);
}
SchedulingPolicy::Priority => {
if wrapper.priority >= TaskPriority::High {
let queue_idx = taskid % self.worker_states.len();
self.worker_states[queue_idx].push_local(wrapper);
self.worker_states[queue_idx].unparker.unpark();
} else {
self.injector.push(wrapper);
}
}
SchedulingPolicy::WeightedFair => {
if wrapper.weight > 1 {
let min_queue_idx = self
.worker_states
.iter()
.enumerate()
.min_by_key(|(_, state)| state.local_queue_size())
.map(|(idx, _)| idx)
.unwrap_or(0);
self.worker_states[min_queue_idx].push_local(wrapper);
self.worker_states[min_queue_idx].unparker.unpark();
} else {
self.injector.push(wrapper);
}
}
}
self.wake_idle_workers();
handle
}
pub fn submit_batch<T: Task + Clone>(&self, tasks: &[T]) -> Vec<TaskHandle> {
let mut handles = Vec::with_capacity(tasks.len());
for task in tasks {
handles.push(self.submit(task.clone()));
}
handles
}
pub fn submit_fn<F, R>(&self, f: F) -> TaskHandle
where
F: FnOnce() -> Result<R, CoreError> + Send + 'static,
R: Send + 'static,
{
struct FnTask<F, R> {
f: Option<F>,
phantom: std::marker::PhantomData<R>,
}
impl<F, R> Task for FnTask<F, R>
where
F: FnOnce() -> Result<R, CoreError> + Send + 'static,
R: Send + 'static,
{
fn execute(&mut self) -> Result<(), CoreError> {
if let Some(f) = self.f.take() {
f()?;
Ok(())
} else {
Err(CoreError::SchedulerError(
ErrorContext::new("Task function was already called".to_string())
.with_location(ErrorLocation::new(file!(), line!())),
))
}
}
fn clone_task(&self) -> Box<dyn Task> {
panic!("FnTask cannot be cloned")
}
}
self.submit_boxed(Box::new(FnTask {
f: Some(f),
phantom: std::marker::PhantomData,
}))
}
fn wake_idle_workers(&self) {
for worker in &self.worker_states {
if worker.time_since_last_active() > Duration::from_millis(self.config.sleep_ms) {
worker.unparker.unpark();
}
}
}
pub fn wait_all(&self) {
let taskids: Vec<_> = self
.task_completion
.lock()
.expect("task completion map mutex poisoned")
.keys()
.copied()
.collect();
for id in taskids {
// Clone the notifier out so the map lock is not held while blocking
// on the condvar (holding it would stall submitters).
let notify = self
.task_completion
.lock()
.expect("task completion map mutex poisoned")
.get(&id)
.cloned();
if let Some(notify) = notify {
let (lock, cvar) = &*notify;
let mut completed = lock.lock().expect("completion mutex poisoned");
// Loop to guard against spurious wakeups.
while !*completed {
completed = cvar.wait(completed).expect("completion condvar poisoned");
}
}
}
}
pub fn wait_all_timeout(&self, timeout: Duration) -> Result<(), CoreError> {
let deadline = Instant::now() + timeout;
let taskids: Vec<_> = self
.task_completion
.lock()
.expect("task completion map mutex poisoned")
.keys()
.copied()
.collect();
for id in taskids {
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return Err(CoreError::TimeoutError(
ErrorContext::new("Timeout waiting for tasks".to_string())
.with_location(ErrorLocation::new(file!(), line!())),
));
}
// Clone the notifier out so the map lock is not held while waiting.
let notify = self
.task_completion
.lock()
.expect("task completion map mutex poisoned")
.get(&id)
.cloned();
if let Some(notify) = notify {
let (lock, cvar) = &*notify;
let completed = lock.lock().expect("completion mutex poisoned");
// `wait_timeout_while` loops internally, handling spurious wakeups.
let (guard, result) = cvar
.wait_timeout_while(completed, remaining, |completed| !*completed)
.expect("completion condvar poisoned");
if result.timed_out() && !*guard {
return Err(CoreError::TimeoutError(
ErrorContext::new("Timeout waiting for tasks".to_string())
.with_location(ErrorLocation::new(file!(), line!())),
));
}
}
}
Ok(())
}
pub fn stats(&self) -> SchedulerStats {
let mut stats = SchedulerStats {
tasks_submitted: self.next_taskid.load(Ordering::Relaxed) - 1,
numworkers: self.config.numworkers,
..SchedulerStats::default()
};
let mut total_latency = Duration::from_secs(0);
let mut total_execution = Duration::from_secs(0);
let mut completed_tasks = 0;
let submissions = self.task_submissions.lock().expect("Operation failed");
let executions = self.task_executions.lock().expect("Operation failed");
for (id, submission_time) in submissions.iter() {
if let Some(execution_time) = executions.get(id) {
// Approximation: completion timestamps are not recorded, so time since
// submission minus execution time stands in for queueing latency; it
// overestimates for tasks that finished long ago.
let latency = submission_time.elapsed().saturating_sub(*execution_time);
total_latency += latency;
total_execution += *execution_time;
completed_tasks += 1;
}
}
stats.tasks_completed = completed_tasks;
if completed_tasks > 0 {
stats.avg_task_latency_ms = total_latency.as_millis() as f64 / completed_tasks as f64;
stats.avg_task_execution_ms =
total_execution.as_millis() as f64 / completed_tasks as f64;
}
let mut total_queue_size = 0;
let mut total_successful_steals = 0;
let mut total_failed_steals = 0;
let mut worker_utils = Vec::with_capacity(self.worker_states.len());
for worker in &self.worker_states {
total_queue_size += worker.local_queue_size();
total_successful_steals += worker.tasks_stolen.load(Ordering::Relaxed);
total_failed_steals += worker.failed_steals.load(Ordering::Relaxed);
let tasks_processed = worker.tasks_processed.load(Ordering::Relaxed);
let utilization = if stats.tasks_submitted > 0 {
tasks_processed as f64 / stats.tasks_submitted as f64
} else {
0.0
};
worker_utils.push(utilization);
}
stats.avg_queue_size = total_queue_size as f64 / self.worker_states.len() as f64;
stats.successful_steals = total_successful_steals;
stats.failed_steals = total_failed_steals;
stats.worker_utilization = worker_utils;
stats.uptime_seconds = self.start_time.elapsed().as_secs_f64();
if stats.uptime_seconds > 0.0 {
stats.tasks_per_second = stats.tasks_completed as f64 / stats.uptime_seconds;
}
stats
}
pub fn shutdown(&mut self) {
{
let mut state = self.state.write().expect("Operation failed");
*state = SchedulerState::ShuttingDown;
}
for worker in &self.worker_states {
worker.unparker.unpark();
}
while let Some(worker) = self.workers.pop() {
let _ = worker.join();
}
{
let mut state = self.state.write().expect("Operation failed");
*state = SchedulerState::ShutDown;
}
}
pub fn numworkers(&self) -> usize {
self.worker_states.len()
}
pub fn current_workerid(&self) -> Option<usize> {
get_workerid()
}
pub fn pending_tasks(&self) -> usize {
let mut total = 0;
for worker in &self.worker_states {
total += worker.local_queue_size();
}
total
}
}
impl Drop for WorkStealingScheduler {
fn drop(&mut self) {
if *self.state.read().expect("Operation failed") != SchedulerState::ShutDown {
self.shutdown();
}
}
}
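/// A task built from a cloneable closure, so it can satisfy
/// [`Task::clone_task`] (unlike the one-shot closures accepted by
/// `submit_fn`).
///
/// A minimal sketch (assumes a running `scheduler` in scope):
///
/// ```ignore
/// let task = CloneableTask::new(|| Ok::<i32, CoreError>(21 * 2))
///     .with_name("answer")
///     .with_priority(TaskPriority::High);
/// let handle = scheduler.submit(task);
/// ```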
pub struct CloneableTask<F, R>
where
F: Fn() -> Result<R, CoreError> + Send + Sync + Clone + 'static,
R: Send + 'static,
{
func: F,
name: String,
priority: TaskPriority,
weight: usize,
}
impl<F, R> CloneableTask<F, R>
where
F: Fn() -> Result<R, CoreError> + Send + Sync + Clone + 'static,
R: Send + 'static,
{
pub fn new(func: F) -> Self {
Self {
func,
name: "unnamed".to_string(),
priority: TaskPriority::Normal,
weight: 1,
}
}
pub fn with_name(mut self, name: &str) -> Self {
self.name = name.to_string();
self
}
pub fn with_priority(mut self, priority: TaskPriority) -> Self {
self.priority = priority;
self
}
pub fn with_weight(mut self, weight: usize) -> Self {
self.weight = weight;
self
}
}
impl<F, R> Task for CloneableTask<F, R>
where
F: Fn() -> Result<R, CoreError> + Send + Sync + Clone + 'static,
R: Send + 'static,
{
fn execute(&mut self) -> Result<(), CoreError> {
(self.func)().map(|_| ())
}
fn priority(&self) -> TaskPriority {
self.priority
}
fn weight(&self) -> usize {
self.weight
}
fn clone_task(&self) -> Box<dyn Task> {
Box::new(Self {
func: self.func.clone(),
name: self.name.clone(),
priority: self.priority,
weight: self.weight,
})
}
fn name(&self) -> &str {
&self.name
}
}
#[allow(dead_code)]
pub fn create_work_stealing_scheduler() -> WorkStealingScheduler {
WorkStealingScheduler::new(SchedulerConfig::default())
}
#[allow(dead_code)]
pub fn create_work_stealing_scheduler_with_workers(numworkers: usize) -> WorkStealingScheduler {
let config = SchedulerConfigBuilder::new().workers(numworkers).build();
WorkStealingScheduler::new(config)
}
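/// Applies a function to every item of a collection in parallel on a
/// dedicated scheduler, returning results in input order.
///
/// A minimal sketch:
///
/// ```ignore
/// let doubled = ParallelTask::new(vec![1, 2, 3], |x: &i32| Ok(x * 2))
///     .with_name("double")
///     .execute()?;
/// assert_eq!(doubled, vec![2, 4, 6]);
/// ```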
pub struct ParallelTask<T, F, R>
where
T: Clone + Send + Sync + 'static,
F: Fn(&T) -> Result<R, CoreError> + Send + Sync + Clone + 'static,
R: Send + 'static,
{
items: Vec<T>,
func: F,
name: String,
priority: TaskPriority,
continue_onerror: bool,
}
impl<T, F, R> ParallelTask<T, F, R>
where
T: Clone + Send + Sync + 'static,
F: Fn(&T) -> Result<R, CoreError> + Send + Sync + Clone + 'static,
R: Send + 'static,
{
pub fn new(items: Vec<T>, func: F) -> Self {
Self {
items,
func,
name: "parallel".to_string(),
priority: TaskPriority::Normal,
continue_onerror: false,
}
}
pub fn with_name(mut self, name: &str) -> Self {
self.name = name.to_string();
self
}
pub fn with_priority(mut self, priority: TaskPriority) -> Self {
self.priority = priority;
self
}
pub fn continue_onerror(mut self, continue_onerror: bool) -> Self {
self.continue_onerror = continue_onerror;
self
}
pub fn execute(self) -> Result<Vec<R>, CoreError>
where
R: Clone,
{
let scheduler = create_work_stealing_scheduler();
let items_len = self.items.len();
let mut handles = Vec::with_capacity(items_len);
let results = Arc::new(Mutex::new(Vec::with_capacity(items_len)));
for (i, item) in self.items.into_iter().enumerate() {
let func = self.func.clone();
let results_clone = results.clone();
let task_name = format!("{}_{}", self.name, i);
let priority = self.priority;
let task = CloneableTask::new(move || {
let result = func(&item)?;
results_clone
.lock()
.expect("Operation failed")
.push((i, result));
Ok(())
})
.with_name(&task_name)
.with_priority(priority);
handles.push(scheduler.submit(task));
}
for handle in &handles {
match handle.wait() {
TaskStatus::Completed => {}
TaskStatus::Failed(_) if self.continue_onerror => {}
status => {
return Err(CoreError::SchedulerError(
ErrorContext::new(format!(
"Task {} failed with status {:?}",
handle.id(),
status
))
.with_location(ErrorLocation::new(file!(), line!())),
));
}
}
}
let mut result_map = Vec::with_capacity(items_len);
{
let results_guard = results.lock().expect("Operation failed");
for (i, result) in results_guard.iter() {
result_map.push((*i, result.clone()));
}
}
result_map.sort_by_key(|(i, _)| *i);
let results = result_map.into_iter().map(|(_, r)| r).collect();
Ok(results)
}
}
pub mod parallel {
use super::*;
use crate::error::CoreResult;
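/// Parallel map over a slice, returning results in input order.
///
/// A minimal sketch:
///
/// ```ignore
/// let squares = parallel::par_map(&[1, 2, 3], |x| Ok(x * x))?;
/// assert_eq!(squares, vec![1, 4, 9]);
/// ```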
pub fn par_map<T, U, F>(items: &[T], f: F) -> CoreResult<Vec<U>>
where
T: Clone + Send + Sync + 'static,
U: Clone + Send + 'static,
F: Fn(&T) -> Result<U, CoreError> + Send + Sync + Clone + 'static,
{
let owned_items = items.to_vec();
let task = ParallelTask::new(owned_items, f);
task.execute()
}
#[allow(dead_code)]
pub fn par_filter<T, F>(items: &[T], predicate: F) -> CoreResult<Vec<T>>
where
T: Clone + Send + Sync + 'static,
F: Fn(&T) -> Result<bool, CoreError> + Send + Sync + Clone + 'static,
{
let task = ParallelTask::new(items.to_vec(), move |item| {
let include = predicate(item)?;
if include {
Ok(Some(item.clone()))
} else {
Ok(None)
}
});
let results = task.execute()?;
let filtered: Vec<_> = results.into_iter().flatten().collect();
Ok(filtered)
}
#[allow(dead_code)]
pub fn par_for_each<T, F>(items: &[T], f: F) -> CoreResult<()>
where
T: Clone + Send + Sync + 'static,
F: Fn(&T) -> Result<(), CoreError> + Send + Sync + Clone + 'static,
{
let task = ParallelTask::new(items.to_vec(), f);
task.execute()?;
Ok(())
}
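/// Parallel reduction over chunks of the input.
///
/// `init` is folded into every chunk and once more into the final combine,
/// so it should be an identity element for `f` (e.g. `0` for addition) for
/// the result to match a sequential fold.
///
/// A minimal sketch:
///
/// ```ignore
/// let sum = parallel::par_reduce(&[1, 2, 3, 4], 0, |acc, x| Ok(acc + x))?;
/// assert_eq!(sum, 10);
/// ```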
#[allow(dead_code)]
pub fn par_reduce<T, F>(items: &[T], init: T, f: F) -> CoreResult<T>
where
T: Clone + Send + Sync + 'static,
F: Fn(T, &T) -> Result<T, CoreError> + Send + Sync + Clone + 'static,
{
if items.is_empty() {
return Ok(init);
}
let items_owned: Vec<T> = items.to_vec();
let num_chunks = std::cmp::min(items_owned.len(), num_cpus::get() * 4);
let chunk_size = std::cmp::max(1, items_owned.len() / num_chunks);
let mut chunks = Vec::with_capacity(num_chunks);
for chunk_start in (0..items_owned.len()).step_by(chunk_size) {
let chunk_end = std::cmp::min(chunk_start + chunk_size, items_owned.len());
chunks.push(items_owned[chunk_start..chunk_end].to_vec());
}
let f_clone = f.clone();
let init_clone = init.clone();
let chunk_results = par_map(&chunks, move |chunk| {
let mut result = init_clone.clone();
for item in chunk {
result = f_clone(result, item)?;
}
Ok(result)
})?;
let mut final_result = init;
for result in chunk_results {
final_result = f(final_result, &result)?;
}
Ok(final_result)
}
}
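/// Extension trait for mapping a function over an `ndarray` array in
/// parallel using the work-stealing pool.
///
/// A minimal sketch (assuming the `crate::ndarray` re-export exposes
/// `Array2`):
///
/// ```ignore
/// let a = crate::ndarray::Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])?;
/// let b = a.work_stealing_map(|x| Ok(x + 1.0))?;
/// ```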
pub trait WorkStealingArray<A, S, D>
where
A: Clone + Send + Sync + 'static,
S: crate::ndarray::RawData<Elem = A>,
D: crate::ndarray::Dimension,
{
fn work_stealing_map<F, B>(&self, f: F) -> CoreResult<crate::ndarray::Array<B, D>>
where
B: Clone + Send + 'static,
F: Fn(&A) -> Result<B, CoreError> + Send + Sync + Clone + 'static;
}
impl<A, S, D> WorkStealingArray<A, S, D> for crate::ndarray::ArrayBase<S, D>
where
A: Clone + Send + Sync + 'static,
S: crate::ndarray::RawData<Elem = A> + crate::ndarray::Data,
D: crate::ndarray::Dimension + Clone + Send + 'static,
{
fn work_stealing_map<F, B>(&self, f: F) -> CoreResult<crate::ndarray::Array<B, D>>
where
B: Clone + Send + 'static,
F: Fn(&A) -> Result<B, CoreError> + Send + Sync + Clone + 'static,
{
let shape = self.raw_dim();
let flat_view = self
.view()
.into_shape_with_order(crate::ndarray::IxDyn(&[self.len()]))
.expect("reshape to 1-D failed; array must be contiguous");
let flat = flat_view.to_slice().expect("flattened view is not contiguous");
let results = parallel::par_map(flat, f)?;
let result_array = crate::ndarray::Array::from_shape_vec(shape, results).map_err(|e| {
CoreError::DimensionError(
ErrorContext::new(format!("{e}"))
.with_location(ErrorLocation::new(file!(), line!())),
)
})?;
Ok(result_array)
}
}
impl CoreError {
pub fn schedulererror(message: &str) -> Self {
CoreError::SchedulerError(
ErrorContext::new(message.to_string())
.with_location(ErrorLocation::new(file!(), line!())),
)
}
}