use crate::OxirsError;
use crossbeam_deque::{Injector, Stealer, Worker};
use scirs2_core::metrics::{Counter, Timer};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;
/// Module-local result alias: all fallible APIs here fail with `OxirsError`.
pub type Result<T> = std::result::Result<T, OxirsError>;
/// Thread-per-core task executor: one OS worker thread per configured core,
/// each owning a local deque, all fed from a shared global injector queue
/// with optional cross-worker stealing.
pub struct ThreadPerCore {
    /// Per-core worker handles plus their local queues and statistics.
    workers: Vec<CoreWorker>,
    /// Shared MPMC injector; `submit`/`submit_batch` push here, workers steal from it.
    global_queue: Arc<Injector<Task>>,
    /// Cleared by `shutdown` to tell every worker loop to exit.
    running: Arc<AtomicBool>,
    /// Configuration captured at construction time.
    config: ThreadPerCoreConfig,
    /// Incremented on task submission (see `submit`/`submit_batch`).
    submitted_counter: Counter,
    #[allow(dead_code)]
    completed_counter: Counter,
    #[allow(dead_code)]
    stolen_counter: Counter,
    #[allow(dead_code)]
    execution_timer: Timer,
}
/// Configuration for [`ThreadPerCore`].
///
/// NOTE(review): `queue_capacity` and `steal_batch_size` are stored but not
/// currently read by the executor implementation — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct ThreadPerCoreConfig {
    /// Number of worker threads to spawn (defaults to the logical CPU count).
    pub num_workers: usize,
    /// Pin each worker to the core matching its index (Linux only; no-op elsewhere).
    pub enable_affinity: bool,
    /// Intended per-worker queue capacity (currently advisory).
    pub queue_capacity: usize,
    /// Allow idle workers to steal tasks from sibling queues.
    pub enable_work_stealing: bool,
    /// Intended number of tasks to steal at once (currently advisory).
    pub steal_batch_size: usize,
}
impl Default for ThreadPerCoreConfig {
fn default() -> Self {
Self {
num_workers: num_cpus::get(),
enable_affinity: true,
queue_capacity: 1024,
enable_work_stealing: true,
steal_batch_size: 16,
}
}
}
/// A unit of work: a boxed one-shot closure plus a process-unique id.
pub struct Task {
    // The work itself; consumed exactly once by `execute`.
    func: Box<dyn FnOnce() + Send + 'static>,
    // Monotonically increasing id assigned in `Task::new`.
    id: usize,
}
impl Task {
pub fn new<F>(f: F) -> Self
where
F: FnOnce() + Send + 'static,
{
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
Self {
func: Box::new(f),
id: NEXT_ID.fetch_add(1, Ordering::Relaxed),
}
}
fn execute(self) {
(self.func)();
}
pub fn id(&self) -> usize {
self.id
}
}
/// Bookkeeping for one spawned worker thread.
struct CoreWorker {
    #[allow(dead_code)]
    id: usize,
    // Join handle for the spawned thread; `None` until spawn, taken on shutdown.
    handle: Option<JoinHandle<()>>,
    // NOTE(review): after spawn this holds a placeholder queue — the real
    // queue is moved into the worker thread (see `with_config`).
    local_queue: Worker<Task>,
    #[allow(dead_code)]
    stealer: Stealer<Task>,
    // Shared with the worker thread, which updates it; read by `stats()`.
    stats: Arc<WorkerStats>,
}
/// Per-worker counters, updated with relaxed atomics from the worker thread.
#[derive(Default)]
struct WorkerStats {
    // Tasks this worker has run (from any source: local, global, or stolen).
    executed: AtomicUsize,
    #[allow(dead_code)]
    stolen_from: AtomicUsize,
    // Tasks this worker obtained by stealing from siblings.
    stolen_by: AtomicUsize,
    // Accumulated back-off sleep time in microseconds.
    idle_time_us: AtomicUsize,
}
impl ThreadPerCore {
    /// Creates an executor with the default configuration
    /// (one worker per logical CPU, affinity and work stealing enabled).
    ///
    /// # Errors
    /// Returns `OxirsError::ConcurrencyError` if a worker thread cannot be spawned.
    pub fn new() -> Result<Self> {
        Self::with_config(ThreadPerCoreConfig::default())
    }

    /// Creates an executor with an explicit configuration, immediately
    /// spawning `config.num_workers` threads named `rdf-worker-<id>`.
    ///
    /// NOTE(review): `config.queue_capacity` and `config.steal_batch_size`
    /// are recorded but not enforced by the crossbeam deques used here.
    ///
    /// # Errors
    /// Returns `OxirsError::ConcurrencyError` if a worker thread cannot be spawned.
    pub fn with_config(config: ThreadPerCoreConfig) -> Result<Self> {
        tracing::info!(
            "Initializing thread-per-core executor with {} workers",
            config.num_workers
        );
        let global_queue = Arc::new(Injector::new());
        let running = Arc::new(AtomicBool::new(true));
        let mut workers = Vec::with_capacity(config.num_workers);
        let mut stealers = Vec::new();
        let mut worker_stats = Vec::new();
        // First pass: create every queue and stealer up front, so each thread
        // can be handed the complete stealer list when it starts.
        for worker_id in 0..config.num_workers {
            let local_queue = Worker::new_fifo();
            let stealer = local_queue.stealer();
            stealers.push(stealer.clone());
            let stats = Arc::new(WorkerStats::default());
            worker_stats.push(stats.clone());
            let worker = CoreWorker {
                id: worker_id,
                handle: None,
                local_queue,
                stealer,
                stats,
            };
            workers.push(worker);
        }
        let stealers_arc = Arc::new(stealers);
        // Second pass: move each real queue into its thread. The placeholder
        // left in the CoreWorker is never used; the stealers captured above
        // still reference the live queue that moved into the thread.
        for (worker_id, worker) in workers.iter_mut().enumerate() {
            let local_queue = std::mem::replace(&mut worker.local_queue, Worker::new_fifo());
            let global_queue = global_queue.clone();
            let running = running.clone();
            let stealers = stealers_arc.clone();
            let stats = worker_stats[worker_id].clone();
            let enable_affinity = config.enable_affinity;
            let enable_work_stealing = config.enable_work_stealing;
            let handle = thread::Builder::new()
                .name(format!("rdf-worker-{}", worker_id))
                .spawn(move || {
                    Self::worker_loop(
                        worker_id,
                        local_queue,
                        global_queue,
                        stealers,
                        running,
                        stats,
                        enable_affinity,
                        enable_work_stealing,
                    )
                })
                .map_err(|e| {
                    OxirsError::ConcurrencyError(format!("Failed to spawn worker: {}", e))
                })?;
            worker.handle = Some(handle);
        }
        Ok(Self {
            workers,
            global_queue,
            running,
            config,
            submitted_counter: Counter::new("threadpool.submitted".to_string()),
            completed_counter: Counter::new("threadpool.completed".to_string()),
            stolen_counter: Counter::new("threadpool.stolen".to_string()),
            execution_timer: Timer::new("threadpool.execution".to_string()),
        })
    }

    /// Enqueues a single task on the global queue.
    ///
    /// # Errors
    /// Returns `OxirsError::ConcurrencyError` if the pool is shutting down.
    pub fn submit(&self, task: Task) -> Result<()> {
        if !self.running.load(Ordering::Relaxed) {
            return Err(OxirsError::ConcurrencyError(
                "Thread pool is shutting down".to_string(),
            ));
        }
        self.global_queue.push(task);
        self.submitted_counter.add(1);
        Ok(())
    }

    /// Enqueues a batch of tasks on the global queue.
    ///
    /// # Errors
    /// Returns `OxirsError::ConcurrencyError` if the pool is shutting down.
    pub fn submit_batch(&self, tasks: Vec<Task>) -> Result<()> {
        if !self.running.load(Ordering::Relaxed) {
            return Err(OxirsError::ConcurrencyError(
                "Thread pool is shutting down".to_string(),
            ));
        }
        // BUG FIX: count every task in the batch (previously the counter was
        // incremented by 1 per batch), so `stats().submitted` stays
        // consistent with per-task accounting in `submit`.
        let batch_len = tasks.len() as u64;
        for task in tasks {
            self.global_queue.push(task);
        }
        self.submitted_counter.add(batch_len);
        Ok(())
    }

    /// Returns a snapshot of aggregate executor statistics summed across all
    /// workers. Values are racy snapshots taken with relaxed loads.
    pub fn stats(&self) -> ThreadPerCoreStats {
        let total_executed: usize = self
            .workers
            .iter()
            .map(|w| w.stats.executed.load(Ordering::Relaxed))
            .sum();
        let total_stolen: usize = self
            .workers
            .iter()
            .map(|w| w.stats.stolen_by.load(Ordering::Relaxed))
            .sum();
        let total_idle_us: usize = self
            .workers
            .iter()
            .map(|w| w.stats.idle_time_us.load(Ordering::Relaxed))
            .sum();
        ThreadPerCoreStats {
            num_workers: self.config.num_workers,
            submitted: self.submitted_counter.get(),
            completed: total_executed as u64,
            stolen: total_stolen as u64,
            avg_idle_time_us: total_idle_us as f64 / self.config.num_workers as f64,
        }
    }

    /// Main loop for one worker thread.
    ///
    /// Priority order per iteration: (1) pop the local queue, (2) steal from
    /// the global injector, (3) steal one task from a sibling worker, then
    /// (4) back off with a short sleep, accounted as idle time.
    #[allow(clippy::too_many_arguments)]
    fn worker_loop(
        worker_id: usize,
        local_queue: Worker<Task>,
        global_queue: Arc<Injector<Task>>,
        stealers: Arc<Vec<Stealer<Task>>>,
        running: Arc<AtomicBool>,
        stats: Arc<WorkerStats>,
        enable_affinity: bool,
        enable_work_stealing: bool,
    ) {
        if enable_affinity {
            // Pin to the core matching our index; failure is non-fatal.
            if let Err(e) = Self::set_cpu_affinity(worker_id) {
                tracing::warn!("Failed to set CPU affinity for worker {}: {}", worker_id, e);
            } else {
                tracing::debug!("Worker {} pinned to core {}", worker_id, worker_id);
            }
        }
        while running.load(Ordering::Relaxed) {
            // (1) Local work first: cheapest path, no contention.
            if let Some(task) = local_queue.pop() {
                task.execute();
                stats.executed.fetch_add(1, Ordering::Relaxed);
                continue;
            }
            // (2) Global injector next.
            match global_queue.steal() {
                crossbeam_deque::Steal::Success(task) => {
                    task.execute();
                    stats.executed.fetch_add(1, Ordering::Relaxed);
                    continue;
                }
                crossbeam_deque::Steal::Empty => {}
                crossbeam_deque::Steal::Retry => continue,
            }
            // (3) Steal a single task from a sibling worker.
            if enable_work_stealing {
                let mut found = false;
                for (i, stealer) in stealers.iter().enumerate() {
                    if i == worker_id {
                        continue; // never steal from our own queue
                    }
                    match stealer.steal() {
                        crossbeam_deque::Steal::Success(task) => {
                            task.execute();
                            stats.executed.fetch_add(1, Ordering::Relaxed);
                            stats.stolen_by.fetch_add(1, Ordering::Relaxed);
                            found = true;
                            break;
                        }
                        crossbeam_deque::Steal::Empty => {}
                        // On contention, try the next sibling; the outer
                        // loop revisits all queues shortly anyway.
                        crossbeam_deque::Steal::Retry => continue,
                    }
                }
                if found {
                    continue;
                }
            }
            // (4) Nothing to do: brief back-off, recorded as idle time.
            let idle_start = std::time::Instant::now();
            thread::sleep(Duration::from_micros(10));
            let idle_us = idle_start.elapsed().as_micros() as usize;
            stats.idle_time_us.fetch_add(idle_us, Ordering::Relaxed);
        }
        tracing::info!("Worker {} shutting down", worker_id);
    }

    /// Pins the calling thread to `core_id` (Linux only).
    ///
    /// # Errors
    /// Returns `OxirsError::ConcurrencyError` if `sched_setaffinity` fails.
    #[cfg(target_os = "linux")]
    fn set_cpu_affinity(core_id: usize) -> Result<()> {
        use std::mem;
        // SAFETY: `cpu_set_t` is a plain bitmask, so an all-zero value is a
        // valid (empty) set. CPU_SET and sched_setaffinity receive a
        // correctly sized, initialized set; pid 0 targets the calling thread.
        unsafe {
            let mut cpu_set: libc::cpu_set_t = mem::zeroed();
            libc::CPU_SET(core_id, &mut cpu_set);
            if libc::sched_setaffinity(0, mem::size_of::<libc::cpu_set_t>(), &cpu_set) != 0 {
                return Err(OxirsError::ConcurrencyError(format!(
                    "Failed to set CPU affinity: {}",
                    std::io::Error::last_os_error()
                )));
            }
        }
        Ok(())
    }

    /// CPU pinning is a no-op on non-Linux platforms.
    #[cfg(not(target_os = "linux"))]
    fn set_cpu_affinity(_core_id: usize) -> Result<()> {
        Ok(())
    }

    /// Stops all workers and joins their threads, consuming the pool.
    ///
    /// Tasks still queued when every worker observes the cleared `running`
    /// flag are dropped without being executed.
    ///
    /// # Errors
    /// Returns `OxirsError::ConcurrencyError` if a worker thread panicked.
    pub fn shutdown(self) -> Result<()> {
        tracing::info!("Shutting down thread-per-core executor");
        self.running.store(false, Ordering::Relaxed);
        for mut worker in self.workers {
            if let Some(handle) = worker.handle.take() {
                handle.join().map_err(|_| {
                    OxirsError::ConcurrencyError("Worker thread panicked".to_string())
                })?;
            }
        }
        tracing::info!("Thread-per-core executor shut down successfully");
        Ok(())
    }
}
impl Default for ThreadPerCore {
fn default() -> Self {
Self::new().expect("Failed to create ThreadPerCore executor")
}
}
/// Point-in-time snapshot of executor activity, produced by
/// `ThreadPerCore::stats` from relaxed atomic reads (values may be slightly stale).
#[derive(Debug, Clone)]
pub struct ThreadPerCoreStats {
    /// Number of worker threads the pool was configured with.
    pub num_workers: usize,
    /// Tasks accepted via `submit`/`submit_batch`.
    pub submitted: u64,
    /// Tasks executed so far, summed across all workers.
    pub completed: u64,
    /// Tasks obtained via cross-worker stealing.
    pub stolen: u64,
    /// Mean accumulated idle (back-off) time per worker, in microseconds.
    pub avg_idle_time_us: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::sync::Arc;
    use std::time::Instant;

    /// Polls `counter` until it reaches `expected` or `timeout` elapses.
    /// Replaces fixed sleeps so tests don't flake on slow or loaded machines.
    fn wait_for_count(counter: &AtomicUsize, expected: usize, timeout: Duration) {
        let deadline = Instant::now() + timeout;
        while counter.load(Ordering::Relaxed) < expected && Instant::now() < deadline {
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_thread_per_core_creation() -> Result<()> {
        let config = ThreadPerCoreConfig {
            num_workers: 4,
            ..Default::default()
        };
        let executor = ThreadPerCore::with_config(config)?;
        executor.shutdown()?;
        Ok(())
    }

    #[test]
    fn test_task_submission() -> Result<()> {
        let executor = ThreadPerCore::new()?;
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let task = Task::new(move || {
            counter_clone.fetch_add(1, Ordering::Relaxed);
        });
        executor.submit(task)?;
        // Bounded wait instead of a fixed 100 ms sleep.
        wait_for_count(&counter, 1, Duration::from_secs(2));
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        executor.shutdown()?;
        Ok(())
    }

    #[test]
    fn test_batch_submission() -> Result<()> {
        let executor = ThreadPerCore::new()?;
        let counter = Arc::new(AtomicUsize::new(0));
        let tasks: Vec<_> = (0..100)
            .map(|_| {
                let counter = counter.clone();
                Task::new(move || {
                    counter.fetch_add(1, Ordering::Relaxed);
                })
            })
            .collect();
        executor.submit_batch(tasks)?;
        // Bounded wait instead of a fixed 500 ms sleep.
        wait_for_count(&counter, 100, Duration::from_secs(5));
        assert_eq!(counter.load(Ordering::Relaxed), 100);
        executor.shutdown()?;
        Ok(())
    }

    #[test]
    fn test_stats() -> Result<()> {
        let executor = ThreadPerCore::new()?;
        for _ in 0..10 {
            let task = Task::new(|| {
                thread::sleep(Duration::from_millis(1));
            });
            executor.submit(task)?;
        }
        // Give workers a bounded window to drain the queue.
        let deadline = Instant::now() + Duration::from_secs(2);
        while executor.stats().completed < 10 && Instant::now() < deadline {
            thread::sleep(Duration::from_millis(1));
        }
        let stats = executor.stats();
        assert_eq!(stats.submitted, 10);
        assert!(stats.completed <= 10);
        executor.shutdown()?;
        Ok(())
    }
}