use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::JoinHandle;
use crossbeam_channel::{Receiver, Sender};
use crate::ets::OwnedTerm;
use crate::native::{ExceptionClass, NativeContinuation, NativeFn, ProcessContext, SuspendRequest};
use crate::scheduler::lock_or_recover;
use crate::term::Term;
/// Default capacity of a pool's bounded message queue, used by `DirtyPool::new`.
pub const DEFAULT_DIRTY_QUEUE_DEPTH: usize = 1024;
/// Number of worker threads started by `DirtyPool::default_io`.
pub const DEFAULT_DIRTY_IO_THREADS: usize = 10;
/// Distinguishes the two kinds of dirty scheduler. This file provides a default
/// pool for each (`DirtyPool::default_cpu` and `DirtyPool::default_io`).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DirtySchedulerKind {
/// Dirty CPU scheduler (cf. `DirtyPool::default_cpu`, sized by `num_cpus::get()`).
Cpu,
/// Dirty I/O scheduler (cf. `DirtyPool::default_io`, `DEFAULT_DIRTY_IO_THREADS` workers).
Io,
}
/// Single-use channel carrying exactly one value, such as a dirty worker's
/// result going back to the submitter.
///
/// It is built on `std::sync::mpsc::sync_channel(1)`. Both halves are consumed
/// by value and the sender is not `Clone`, so at most one value is ever sent
/// and that single `send` never has to wait for buffer space.
pub mod oneshot {
    use std::fmt;
    use std::sync::mpsc;

    /// Sending half; consumed by [`Sender::send`].
    pub struct Sender<T>(mpsc::SyncSender<T>);

    /// Receiving half; consumed by [`Receiver::recv`].
    pub struct Receiver<T>(mpsc::Receiver<T>);

    /// Returned by [`Sender::send`] when the receiver is gone; hands the value back.
    pub struct SendError<T>(pub T);

    /// Returned by [`Receiver::recv`] when the sender was dropped without sending.
    #[derive(Debug, Copy, Clone, Eq, PartialEq)]
    pub struct RecvError;

    /// Creates a connected sender/receiver pair.
    #[must_use]
    pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
        let (sender, receiver) = mpsc::sync_channel(1);
        (Sender(sender), Receiver(receiver))
    }

    impl<T> Sender<T> {
        /// Delivers `value` to the receiver.
        ///
        /// # Errors
        /// Returns `SendError(value)` if the receiver has already been dropped.
        pub fn send(self, value: T) -> Result<(), SendError<T>> {
            self.0.send(value).map_err(|error| SendError(error.0))
        }
    }

    impl<T> Receiver<T> {
        /// Blocks until the value arrives.
        ///
        /// # Errors
        /// Returns [`RecvError`] if the sender was dropped without sending.
        pub fn recv(self) -> Result<T, RecvError> {
            self.0.recv().map_err(|_| RecvError)
        }
    }

    // The `Debug` impls below deliberately do not require `T: Debug`. This
    // mirrors `std::sync::mpsc`, so results carrying these types can be
    // `expect()`ed for any payload type.
    impl<T> fmt::Debug for Sender<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Sender").finish_non_exhaustive()
        }
    }

    impl<T> fmt::Debug for Receiver<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Receiver").finish_non_exhaustive()
        }
    }

    impl<T> fmt::Debug for SendError<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("SendError").finish_non_exhaustive()
        }
    }

    impl<T> fmt::Display for SendError<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("sending on a closed oneshot channel")
        }
    }

    impl<T> std::error::Error for SendError<T> {}

    impl fmt::Display for RecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("oneshot sender dropped without sending")
        }
    }

    impl std::error::Error for RecvError {}
}
/// Outcome of one native call executed on a dirty worker. It is delivered back
/// through the job's `result_sender`.
#[derive(Debug)]
pub struct DirtyResult {
/// `Ok`/`Err` as returned by the native function. It is `Err` with the `badarg`
/// atom when a trampoline request could not be made owned. When `owned_result`
/// is `Some`, this term is that owned copy's root.
pub result: Result<Term, Term>,
/// Owned copy of the result term: either a detached result taken from the
/// context, or a `copy_term_to_ets` copy of a list/boxed term. It is `None` for
/// other terms or when copying failed.
pub owned_result: Option<OwnedTerm>,
/// Exception class taken from the process context after the call.
pub exception_class: ExceptionClass,
/// Exception stacktrace taken from the process context after the call.
pub exception_stacktrace: Term,
/// Suspend request taken from the context; only kept when `result` is `Ok`.
pub suspend: Option<SuspendRequest>,
/// Owned trampoline request; only kept when `result` is `Ok`.
pub trampoline: Option<OwnedDirtyTrampoline>,
}
/// A trampoline request converted by `own_dirty_trampoline` into owned terms,
/// so it can be returned inside a `DirtyResult`.
#[derive(Debug)]
pub struct OwnedDirtyTrampoline {
/// Function term to call next (owned copy).
pub fun: OwnedTerm,
/// Arguments for `fun` (owned copies), in the original order.
pub args: Vec<OwnedTerm>,
/// Continuation to resume with. `own_dirty_trampoline` only accepts
/// continuations whose `for_each_term` visits no terms.
pub continuation: NativeContinuation,
}
/// A native call queued on a `DirtyPool` through `DirtyPool::submit`.
pub struct DirtyJob {
/// Process identifier (presumably of the submitting process — confirm). It is
/// not otherwise used by `worker_loop` in this file.
pub pid: u64,
/// Native function, invoked by the worker as `function(&args, &mut context)`.
pub function: NativeFn,
/// Arguments passed to `function`.
pub args: Vec<Term>,
/// Context lent to the native call. Exception, suspend and trampoline state
/// are taken from it after the call returns.
pub context: ProcessContext<'static>,
/// Channel on which the worker delivers the `DirtyResult`.
pub result_sender: oneshot::Sender<DirtyResult>,
}
// SAFETY: NOTE(review): `Send` is asserted manually, presumably because `Term`
// and/or `ProcessContext<'static>` are not `Send` on their own. Soundness relies
// on nothing else touching this job's terms or context while a worker thread
// runs it. That invariant is not enforced in this file — confirm and document it.
unsafe impl Send for DirtyJob {}
/// An arbitrary closure that a dirty worker thread runs exactly once.
pub struct DirtyTask {
    task: Box<dyn FnOnce() + Send + 'static>,
}

impl DirtyTask {
    /// Boxes `task` so it can be queued with `DirtyPool::submit_task`.
    pub fn new(task: impl FnOnce() + Send + 'static) -> Self {
        let boxed: Box<dyn FnOnce() + Send + 'static> = Box::new(task);
        DirtyTask { task: boxed }
    }
}
/// Messages carried from a `DirtyPool` to its worker threads.
enum DirtyMessage {
/// Run a native call and reply on its `result_sender`.
RunNative(Box<DirtyJob>),
/// Run an arbitrary closure.
RunTask(DirtyTask),
/// Makes the receiving worker leave its loop; `shutdown` sends one per worker.
Shutdown,
}
/// Reason a submission to a [`DirtyPool`] was rejected.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DirtySubmitError {
    /// The pool has been shut down and no longer accepts work.
    ShutDown,
    /// The bounded queue has no free slot; submission never blocks.
    QueueFull,
    /// Every worker thread (and thus every queue receiver) is gone.
    Disconnected,
}

impl std::fmt::Display for DirtySubmitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ShutDown => "dirty pool has been shut down",
            Self::QueueFull => "dirty pool queue is full",
            Self::Disconnected => "dirty pool has no running workers",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DirtySubmitError {}
/// A set of named worker threads draining one bounded queue of dirty work.
pub struct DirtyPool {
/// Pool name; also the prefix of every worker thread name.
name: String,
/// Number of worker threads that actually started.
thread_count: usize,
/// Capacity of the bounded queue (at least 1).
queue_depth: usize,
/// Producer side of the queue; every worker holds a clone of the receiver.
sender: Sender<DirtyMessage>,
/// Set when `shutdown` begins; submissions are rejected once it is set.
shutdown: AtomicBool,
/// Join handles of the workers; drained by `shutdown`.
threads: Mutex<Vec<JoinHandle<()>>>,
/// Worker thread names (`"{name}-{index}"`) in spawn order.
worker_names: Vec<String>,
}
impl DirtyPool {
#[must_use]
pub fn new(name: &str, thread_count: usize) -> Self {
Self::with_queue_depth(name, thread_count, DEFAULT_DIRTY_QUEUE_DEPTH)
}
#[must_use]
pub fn default_cpu() -> Self {
Self::new("dirty-cpu", num_cpus::get())
}
#[must_use]
pub fn default_io() -> Self {
Self::new("dirty-io", DEFAULT_DIRTY_IO_THREADS)
}
#[must_use]
pub fn with_queue_depth(name: &str, thread_count: usize, queue_depth: usize) -> Self {
let pool_thread_count = thread_count.max(1);
let pool_queue_depth = queue_depth.max(1);
let (sender, receiver) = crossbeam_channel::bounded(pool_queue_depth);
let mut threads = Vec::with_capacity(pool_thread_count);
let mut worker_names = Vec::with_capacity(pool_thread_count);
for index in 0..pool_thread_count {
let thread_name = format!("{name}-{index}");
let receiver_for_thread = receiver.clone();
match std::thread::Builder::new()
.name(thread_name.clone())
.spawn(move || worker_loop(receiver_for_thread))
{
Ok(handle) => {
worker_names.push(thread_name);
threads.push(handle);
}
Err(_error) => break,
}
}
Self {
name: name.to_owned(),
thread_count: worker_names.len(),
queue_depth: pool_queue_depth,
sender,
shutdown: AtomicBool::new(false),
threads: Mutex::new(threads),
worker_names,
}
}
pub fn submit(&self, job: DirtyJob) -> Result<(), DirtySubmitError> {
if self.shutdown.load(Ordering::Acquire) {
return Err(DirtySubmitError::ShutDown);
}
self.sender
.try_send(DirtyMessage::RunNative(Box::new(job)))
.map_err(|error| match error {
crossbeam_channel::TrySendError::Full(_) => DirtySubmitError::QueueFull,
crossbeam_channel::TrySendError::Disconnected(_) => DirtySubmitError::Disconnected,
})
}
pub fn submit_task(&self, task: DirtyTask) -> Result<(), DirtySubmitError> {
if self.shutdown.load(Ordering::Acquire) {
return Err(DirtySubmitError::ShutDown);
}
self.sender
.try_send(DirtyMessage::RunTask(task))
.map_err(|error| match error {
crossbeam_channel::TrySendError::Full(_) => DirtySubmitError::QueueFull,
crossbeam_channel::TrySendError::Disconnected(_) => DirtySubmitError::Disconnected,
})
}
pub fn shutdown(&self) {
if self.shutdown.swap(true, Ordering::AcqRel) {
return;
}
let mut threads = lock_or_recover(&self.threads);
for _ in 0..threads.len() {
let _ = self.sender.send(DirtyMessage::Shutdown);
}
for handle in threads.drain(..) {
if let Err(payload) = handle.join() {
std::panic::resume_unwind(payload);
}
}
}
#[must_use]
pub fn thread_count(&self) -> usize {
self.thread_count
}
#[must_use]
pub fn queue_depth(&self) -> usize {
self.queue_depth
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn worker_names(&self) -> &[String] {
&self.worker_names
}
#[must_use]
pub fn is_shutdown(&self) -> bool {
self.shutdown.load(Ordering::Acquire)
}
}
// Dropping the pool shuts it down and joins its workers through `shutdown`.
// `shutdown` is idempotent, so this is a no-op after an explicit call. If a
// worker panicked, `shutdown` may resume that panic from here.
impl Drop for DirtyPool {
fn drop(&mut self) {
self.shutdown();
}
}
fn worker_loop(receiver: Receiver<DirtyMessage>) {
while let Ok(message) = receiver.recv() {
match message {
DirtyMessage::RunNative(mut job) => {
let _pid = job.pid;
let result = (job.function)(&job.args, &mut job.context);
let raw_result = match &result {
Ok(value) | Err(value) => *value,
};
let owned_result = job.context.take_detached_result(raw_result).or_else(|| {
if raw_result.is_list() || raw_result.is_boxed() {
crate::ets::copy_term_to_ets(raw_result).ok()
} else {
None
}
});
let result = match owned_result.as_ref() {
Some(owned) => result.map(|_| owned.root()).map_err(|_| owned.root()),
None => result,
};
let exception_class = job.context.take_exception_class();
let exception_stacktrace = job.context.take_exception_stacktrace();
let suspend = job.context.take_suspend().filter(|_| result.is_ok());
let trampoline = match job.context.take_trampoline().filter(|_| result.is_ok()) {
None => None,
Some(request) => match own_dirty_trampoline(request) {
Ok(owned) => Some(owned),
Err(reason) => {
let _ = job.result_sender.send(DirtyResult {
result: Err(Term::atom(crate::atom::Atom::BADARG)),
owned_result: None,
exception_class: ExceptionClass::Error,
exception_stacktrace: Term::NIL,
suspend: None,
trampoline: None,
});
let _trace = reason;
continue;
}
},
};
let _ = job.result_sender.send(DirtyResult {
result,
owned_result,
exception_class,
exception_stacktrace,
suspend,
trampoline,
});
}
DirtyMessage::RunTask(task) => {
(task.task)();
}
DirtyMessage::Shutdown => break,
}
}
}
fn own_dirty_trampoline(
request: crate::native::TrampolineRequest,
) -> Result<OwnedDirtyTrampoline, &'static str> {
let Some(continuation) = request.continuation else {
return Err("dirty trampoline requires a continuation");
};
let mut holds_terms = false;
continuation.for_each_term(&mut |_| holds_terms = true);
if holds_terms {
return Err("dirty trampoline continuation must not hold heap terms");
}
let fun = own_term(request.fun).map_err(|_| "dirty trampoline fun copy failed")?;
let mut args = Vec::with_capacity(request.args.len());
for arg in request.args {
args.push(own_term(arg).map_err(|_| "dirty trampoline arg copy failed")?);
}
Ok(OwnedDirtyTrampoline {
fun,
args,
continuation,
})
}
/// Produces an owned form of `term`.
///
/// Terms that are neither lists nor boxed are wrapped directly with
/// `OwnedTerm::immediate`. Lists and boxed terms are copied with
/// `crate::ets::copy_term_to_ets`, whose error is returned unchanged.
fn own_term(term: Term) -> Result<OwnedTerm, crate::ets::EtsError> {
    let needs_heap_copy = term.is_list() || term.is_boxed();
    if !needs_heap_copy {
        return Ok(OwnedTerm::immediate(term));
    }
    crate::ets::copy_term_to_ets(term)
}
#[cfg(test)]
mod tests {
    use super::{
        DirtyJob, DirtyPool, DirtySchedulerKind, DirtySubmitError, DirtyTask, oneshot,
    };
    use crate::native::{ExceptionClass, ProcessContext};
    use crate::term::Term;

    /// Native stub that ignores its inputs and returns the small integer 42.
    fn forty_two(_args: &[Term], _context: &mut ProcessContext) -> Result<Term, Term> {
        Ok(Term::small_int(42))
    }

    #[test]
    fn dirty_pool_starts_named_threads_and_shuts_down_cleanly() {
        let pool = DirtyPool::new("dirty-test", 4);
        assert_eq!(pool.thread_count(), 4);
        assert_eq!(pool.worker_names().len(), 4);
        assert_eq!(
            pool.worker_names(),
            &[
                "dirty-test-0".to_owned(),
                "dirty-test-1".to_owned(),
                "dirty-test-2".to_owned(),
                "dirty-test-3".to_owned(),
            ]
        );
        pool.shutdown();
        assert!(pool.is_shutdown());
        // A second shutdown must be a harmless no-op.
        pool.shutdown();
    }

    #[test]
    fn dirty_pool_executes_submitted_job_and_returns_result() {
        let pool = DirtyPool::with_queue_depth("dirty-test-job", 1, 1);
        let (result_sender, result_receiver) = oneshot::channel();
        assert_eq!(
            pool.submit(DirtyJob {
                pid: 7,
                function: forty_two,
                args: Vec::new(),
                context: ProcessContext::new(),
                result_sender,
            }),
            Ok(())
        );
        let result = result_receiver.recv().expect("dirty result");
        assert_eq!(result.result, Ok(Term::small_int(42)));
        assert!(result.owned_result.is_none());
        assert_eq!(result.exception_class, ExceptionClass::Error);
        assert_eq!(result.exception_stacktrace, Term::NIL);
        pool.shutdown();
    }

    #[test]
    fn dirty_pool_runs_submitted_tasks() {
        let pool = DirtyPool::with_queue_depth("dirty-test-task", 1, 1);
        let (done_sender, done_receiver) = oneshot::channel();
        assert_eq!(
            pool.submit_task(DirtyTask::new(move || {
                let _ = done_sender.send(99_u32);
            })),
            Ok(())
        );
        assert_eq!(done_receiver.recv(), Ok(99));
        pool.shutdown();
    }

    #[test]
    fn dirty_pool_rejects_work_after_shutdown() {
        let pool = DirtyPool::new("dirty-test-closed", 1);
        pool.shutdown();
        assert_eq!(
            pool.submit_task(DirtyTask::new(|| {})),
            Err(DirtySubmitError::ShutDown)
        );
        let (result_sender, result_receiver) = oneshot::channel();
        assert_eq!(
            pool.submit(DirtyJob {
                pid: 8,
                function: forty_two,
                args: Vec::new(),
                context: ProcessContext::new(),
                result_sender,
            }),
            Err(DirtySubmitError::ShutDown)
        );
        // The rejected job was dropped, so its result channel is disconnected.
        assert!(result_receiver.recv().is_err());
    }

    #[test]
    fn dirty_scheduler_kind_distinguishes_cpu_and_io() {
        assert_eq!(DirtySchedulerKind::Cpu, DirtySchedulerKind::Cpu);
        assert_ne!(DirtySchedulerKind::Cpu, DirtySchedulerKind::Io);
    }
}