use std::future::Future;
use std::pin::Pin;
use crate::error::CoreError;
/// A pool that runs type-erased, run-once closures and reports their
/// completion through a boxed future.
///
/// The `Send + Sync` supertraits allow a pool to be shared across threads.
/// See [`CpuPoolExt::spawn`] for a typed wrapper that also hands back the
/// closure's return value.
pub trait CpuPool: Send + Sync {
/// Hands `task` to the pool and returns a future signalling its completion.
///
/// The future resolves to `Ok(())` once the task has run. It resolves to a
/// `CoreError` when the implementation reports a failure; what counts as a
/// failure (e.g. a panic inside `task`) is up to the implementation.
fn spawn_raw(
&self,
task: Box<dyn FnOnce() + Send + 'static>,
) -> Pin<Box<dyn Future<Output = Result<(), CoreError>> + Send + 'static>>;
}
/// Typed convenience layer over [`CpuPool`].
pub trait CpuPoolExt: CpuPool {
    /// Runs `f` on the pool and resolves to the value it returns.
    ///
    /// `f` is wrapped in a type-erased task and handed to
    /// [`CpuPool::spawn_raw`] immediately, not lazily on first poll. The
    /// returned future only waits for completion and for the result.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the `spawn_raw` future. It returns
    /// `CoreError::TaskAborted` if the task completed without delivering a
    /// result.
    fn spawn<F, R>(
        &self,
        f: F,
    ) -> Pin<Box<dyn Future<Output = Result<R, CoreError>> + Send + 'static>>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (result_tx, result_rx) = tokio::sync::oneshot::channel::<R>();

        // The closure coerces to `Box<dyn FnOnce() + Send + 'static>` at the call site.
        let completion = self.spawn_raw(Box::new(move || {
            // A send error only means the caller dropped the future; nothing to do.
            let _ = result_tx.send(f());
        }));

        Box::pin(async move {
            // Surface pool-level failures (e.g. a caught panic) before reading the value.
            completion.await?;
            match result_rx.await {
                Ok(value) => Ok(value),
                Err(_) => Err(CoreError::TaskAborted {
                    reason: "task result channel dropped (task panicked or pool shut down)".into(),
                }),
            }
        })
    }
}
// Every `CpuPool`, including unsized ones such as `dyn CpuPool`, gets `spawn` for free.
impl<P> CpuPoolExt for P where P: CpuPool + ?Sized {}
/// [`CpuPool`] implementation backed by its own rayon thread pool.
///
/// `Debug` is derived so the public type can be logged and embedded in other
/// `Debug` types (Rust API guideline C-DEBUG); `rayon::ThreadPool` implements it.
#[derive(Debug)]
pub struct RayonThreadPool {
    /// The owned rayon pool that runs submitted tasks.
    pool: rayon::ThreadPool,
}
impl RayonThreadPool {
    /// Builds a pool, passing `num_threads` straight to rayon's
    /// `ThreadPoolBuilder::num_threads`.
    ///
    /// NOTE(review): per rayon's docs, `0` lets rayon choose the thread count.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::ThreadPoolBuild` with rayon's error text if the pool
    /// cannot be built.
    pub fn new(num_threads: usize) -> Result<Self, CoreError> {
        Self::build_from(rayon::ThreadPoolBuilder::new().num_threads(num_threads))
    }

    /// Builds a pool using an unconfigured rayon `ThreadPoolBuilder`.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::ThreadPoolBuild` with rayon's error text if the pool
    /// cannot be built.
    pub fn with_default_threads() -> Result<Self, CoreError> {
        Self::build_from(rayon::ThreadPoolBuilder::new())
    }

    /// Borrows the underlying rayon pool for direct use.
    pub fn rayon_pool(&self) -> &rayon::ThreadPool {
        &self.pool
    }

    /// Shared tail of the constructors: builds the pool and maps rayon's error.
    fn build_from(builder: rayon::ThreadPoolBuilder) -> Result<Self, CoreError> {
        builder
            .build()
            .map(|pool| Self { pool })
            .map_err(|err| CoreError::ThreadPoolBuild(err.to_string()))
    }
}
/// Runs boxed closures on the wrapped rayon pool.
impl CpuPool for RayonThreadPool {
    /// Queues `task` on the rayon pool with `rayon::ThreadPool::spawn` and
    /// returns a future that waits for it to finish.
    ///
    /// The task is queued immediately, before the returned future is polled.
    /// A panic inside `task` is caught with `catch_unwind` and reported through
    /// the future instead of unwinding into the rayon worker.
    ///
    /// # Errors
    ///
    /// The future yields `CoreError::TaskAborted` in two cases:
    /// - `task` panicked. The reason includes the panic message when the
    ///   payload was a `String` or `&str`, and "unknown panic" otherwise.
    /// - The completion channel closed without delivering a value.
    fn spawn_raw(
        &self,
        task: Box<dyn FnOnce() + Send + 'static>,
    ) -> Pin<Box<dyn Future<Output = Result<(), CoreError>> + Send + 'static>> {
        let (done_tx, done_rx) = tokio::sync::oneshot::channel::<Result<(), String>>();

        self.pool.spawn(move || {
            // `AssertUnwindSafe`: `task` is consumed by the call, so this closure
            // never touches it again after a panic.
            let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task))
                .map_err(|payload| {
                    if let Some(owned) = payload.downcast_ref::<String>() {
                        owned.clone()
                    } else if let Some(borrowed) = payload.downcast_ref::<&str>() {
                        (*borrowed).to_string()
                    } else {
                        "unknown panic".to_string()
                    }
                });
            // A send error only means the caller dropped the future; ignore it.
            let _ = done_tx.send(outcome);
        });

        Box::pin(async move {
            done_rx
                .await
                .map_err(|_| CoreError::TaskAborted {
                    reason: "pool shut down before task completed".into(),
                })
                .and_then(|outcome| {
                    outcome.map_err(|panic_msg| CoreError::TaskAborted {
                        reason: format!("task panicked: {panic_msg}"),
                    })
                })
        })
    }
}