use super::{ComputeConfig, ComputeMetrics};
use anyhow::Result;
use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// A dedicated rayon thread pool bundled with shared task metrics and the
/// configuration it was built from.
///
/// Clones share the same pool and the same metrics (both are behind `Arc`);
/// the configuration is cloned by value.
#[derive(Clone)]
pub struct ComputePool {
    /// Rayon pool built from `config` in `ComputePool::new`.
    pool: Arc<rayon::ThreadPool>,
    /// Task start/completion metrics recorded by the async execution methods.
    metrics: Arc<ComputeMetrics>,
    /// Configuration the pool was built from (kept for `Debug` output).
    config: ComputeConfig,
}
impl std::fmt::Debug for ComputePool {
    /// Formats as `ComputePool { num_threads, metrics, config }`, reporting the
    /// pool's worker count (via `current_num_threads`) instead of the opaque
    /// rayon handle.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            pool,
            metrics,
            config,
        } = self;
        let num_threads = pool.current_num_threads();
        f.debug_struct("ComputePool")
            .field("num_threads", &num_threads)
            .field("metrics", metrics)
            .field("config", config)
            .finish()
    }
}
impl ComputePool {
    /// Builds a compute pool from `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ComputeConfig::build_pool`.
    pub fn new(config: ComputeConfig) -> Result<Self> {
        let pool = config.build_pool()?;
        Ok(Self {
            pool: Arc::new(pool),
            metrics: Arc::new(ComputeMetrics::new()),
            config,
        })
    }

    /// Builds a compute pool from `ComputeConfig::default()`.
    ///
    /// # Errors
    ///
    /// Same as [`ComputePool::new`].
    pub fn with_defaults() -> Result<Self> {
        Self::new(ComputeConfig::default())
    }

    /// Runs `f` inside this pool and blocks the calling thread until it returns.
    ///
    /// `f` may borrow from the caller (no `'static` bound). Unlike the async
    /// methods, this records nothing in [`ComputeMetrics`]. Because it blocks,
    /// call it from a blocking context (e.g. `tokio::task::spawn_blocking`)
    /// rather than directly on an async runtime worker.
    pub fn execute_sync<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        self.pool.install(f)
    }

    /// Runs `f` on one of this pool's workers and asynchronously waits for
    /// its result.
    ///
    /// Because `f` executes on a worker of this pool, rayon parallelism inside
    /// it (`rayon::join`, `rayon::scope`, parallel iterators) also runs on this
    /// pool. A task start is recorded before spawning, and the wall-clock time
    /// until the result arrives is recorded as the completion.
    ///
    /// Dropping the returned future does not cancel the job (rayon jobs cannot
    /// be cancelled once spawned), and in that case no completion is recorded.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is part of the public API.
    ///
    /// # Panics
    ///
    /// If `f` panics, tokio-rayon resumes the panic when the handle is
    /// awaited, and no completion is recorded.
    pub async fn execute<F, R>(&self, f: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.metrics.record_task_start();
        let start = std::time::Instant::now();
        // Spawn straight onto *this* pool. The previous
        // `tokio_rayon::spawn(move || pool.install(f))` ran a trampoline job on
        // rayon's global pool (initialising it and occupying one of its workers
        // for the duration) just to hop into this pool.
        let result = tokio_rayon::AsyncThreadPool::spawn_async(&*self.pool, f).await;
        self.metrics.record_task_completion(start.elapsed());
        Ok(result)
    }

    /// Runs `f` inside a `rayon::scope` on this pool and awaits its result.
    ///
    /// The result is produced only after every task spawned on the scope has
    /// finished. Metrics, errors and panics behave as for [`ComputePool::execute`].
    pub async fn execute_scoped<F, R>(&self, f: F) -> Result<R>
    where
        F: FnOnce(&rayon::Scope) -> R + Send + 'static,
        R: Send + 'static,
    {
        // `rayon::scope` already returns the closure's value and joins all
        // spawned tasks, so no `Option` slot is needed.
        self.execute(move || rayon::scope(f)).await
    }

    /// Runs `f` inside a `rayon::scope_fifo` on this pool and awaits its result.
    ///
    /// Like [`ComputePool::execute_scoped`], but tasks spawned on the scope are
    /// prioritised in FIFO order. Metrics, errors and panics behave as for
    /// [`ComputePool::execute`].
    pub async fn execute_scoped_fifo<F, R>(&self, f: F) -> Result<R>
    where
        F: FnOnce(&rayon::ScopeFifo) -> R + Send + 'static,
        R: Send + 'static,
    {
        self.execute(move || rayon::scope_fifo(f)).await
    }

    /// Runs `f1` and `f2`, potentially in parallel, via `rayon::join` on this
    /// pool, and returns both results.
    ///
    /// The pair is recorded as a single task in the metrics.
    pub async fn join<F1, F2, R1, R2>(&self, f1: F1, f2: F2) -> Result<(R1, R2)>
    where
        F1: FnOnce() -> R1 + Send + 'static,
        F2: FnOnce() -> R2 + Send + 'static,
        R1: Send + 'static,
        R2: Send + 'static,
    {
        self.execute(move || rayon::join(f1, f2)).await
    }

    /// Metrics recorded by this pool's async execution methods (shared by all clones).
    pub fn metrics(&self) -> &ComputeMetrics {
        &self.metrics
    }

    /// Number of worker threads in the underlying rayon pool.
    pub fn num_threads(&self) -> usize {
        self.pool.current_num_threads()
    }

    /// Alias of [`ComputePool::execute`]: same bounds, same metrics, same result.
    pub async fn install<F, R>(&self, f: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.execute(f).await
    }
}
/// A type-erased, heap-pinned future that resolves to `T`.
///
/// Lets any `Send + 'static` future be stored and returned behind a single
/// nameable type.
pub struct ComputeHandle<T> {
    /// The boxed future being driven; `Send` so the handle can move between threads.
    inner: Pin<Box<dyn Future<Output = T> + Send + 'static>>,
}
impl<T> ComputeHandle<T> {
pub(crate) fn new<F>(future: F) -> Self
where
F: Future<Output = T> + Send + 'static,
{
Self {
inner: Box::pin(future),
}
}
}
impl<T> Future for ComputeHandle<T> {
    type Output = T;

    /// Delegates polling to the boxed inner future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The handle is `Unpin` (its only field is a `Pin<Box<_>>`), so we can
        // take a plain `&mut` and re-pin just the inner future.
        let this = self.get_mut();
        this.inner.as_mut().poll(cx)
    }
}
/// Higher-level data-parallel helpers built on top of a compute pool.
#[async_trait]
pub trait ComputePoolExt {
    /// Applies `f` to every element of `items` in parallel and collects the
    /// results into a `Vec`.
    async fn parallel_map<T, F, R>(&self, items: Vec<T>, f: F) -> Result<Vec<R>>
    where
        T: Send + Sync + 'static,
        F: Fn(T) -> R + Send + Sync + 'static,
        R: Send + 'static;

    /// Splits `items` into consecutive chunks of `batch_size` elements (the
    /// last chunk may be shorter), applies `f` to each chunk in parallel, and
    /// concatenates the per-chunk outputs.
    ///
    /// `batch_size` must be non-zero.
    async fn parallel_batch<T, F, R>(
        &self,
        items: Vec<T>,
        batch_size: usize,
        f: F,
    ) -> Result<Vec<R>>
    where
        T: Send + Sync + 'static,
        F: Fn(&[T]) -> Vec<R> + Send + Sync + 'static,
        R: Send + 'static;
}
#[async_trait]
impl ComputePoolExt for ComputePool {
async fn parallel_batch<T, F, R>(
&self,
items: Vec<T>,
batch_size: usize,
f: F,
) -> Result<Vec<R>>
where
T: Send + Sync + 'static,
F: Fn(&[T]) -> Vec<R> + Send + Sync + 'static,
R: Send + 'static,
{
use rayon::prelude::*;
self.install(move || items.par_chunks(batch_size).flat_map(f).collect())
.await
}
async fn parallel_map<T, F, R>(&self, items: Vec<T>, f: F) -> Result<Vec<R>>
where
T: Send + Sync + 'static,
F: Fn(T) -> R + Send + Sync + 'static,
R: Send + 'static,
{
use rayon::prelude::*;
self.install(move || items.into_par_iter().map(f).collect())
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_compute_pool_execute() {
        let pool = ComputePool::with_defaults().unwrap();
        let result = pool
            .execute(|| (0..1000u64).sum::<u64>())
            .await
            .unwrap();
        assert_eq!(result, 499500);
    }

    #[tokio::test]
    async fn test_compute_pool_join() {
        let pool = ComputePool::with_defaults().unwrap();
        let (a, b) = pool.join(|| 2 + 2, || 3 * 3).await.unwrap();
        assert_eq!(a, 4);
        assert_eq!(b, 9);
    }

    #[tokio::test]
    async fn test_compute_pool_execute_sync() {
        let pool = ComputePool::with_defaults().unwrap();
        // `execute_sync` blocks, so keep it off the async runtime's worker.
        let result = tokio::task::spawn_blocking(move || {
            pool.execute_sync(|| (0..1000u64).sum::<u64>())
        })
        .await
        .unwrap();
        assert_eq!(result, 499500);
    }

    #[tokio::test]
    async fn test_compute_pool_scoped() {
        use std::sync::mpsc;
        let pool = ComputePool::with_defaults().unwrap();
        // Only spawn inside the scope and hand the receiver back. Draining it
        // *inside* the scope would block the worker on `recv` while the spawned
        // jobs sit on that worker's deque, which deadlocks on a 1-thread pool.
        let rx = pool
            .execute_scoped(|scope| {
                let (tx, rx) = mpsc::channel();
                for i in 0..4usize {
                    let tx = tx.clone();
                    scope.spawn(move |_| {
                        tx.send((i, i * 2)).unwrap();
                    });
                }
                rx
            })
            .await
            .unwrap();
        // The scope joined every spawned task, so all senders are dropped and
        // this loop terminates.
        let mut results = vec![0; 4];
        for (i, val) in rx {
            results[i] = val;
        }
        assert_eq!(results, vec![0, 2, 4, 6]);
    }

    #[tokio::test]
    async fn test_compute_pool_scoped_fifo() {
        use std::sync::mpsc;
        let pool = ComputePool::with_defaults().unwrap();
        let rx = pool
            .execute_scoped_fifo(|scope| {
                let (tx, rx) = mpsc::channel();
                for i in 0..4usize {
                    let tx = tx.clone();
                    scope.spawn_fifo(move |_| {
                        tx.send((i, i * 3)).unwrap();
                    });
                }
                rx
            })
            .await
            .unwrap();
        let mut results = vec![0; 4];
        for (i, val) in rx {
            results[i] = val;
        }
        assert_eq!(results, vec![0, 3, 6, 9]);
    }

    #[tokio::test]
    async fn test_parallel_map_preserves_order() {
        let pool = ComputePool::with_defaults().unwrap();
        let items: Vec<u64> = (0..100).collect();
        let doubled = pool.parallel_map(items, |x: u64| x * 2).await.unwrap();
        let expected: Vec<u64> = (0..100).map(|x| x * 2).collect();
        assert_eq!(doubled, expected);
    }

    #[tokio::test]
    async fn test_parallel_batch_preserves_chunk_order() {
        let pool = ComputePool::with_defaults().unwrap();
        let items: Vec<u32> = (1..=10).collect();
        let sums = pool
            .parallel_batch(items, 3, |chunk: &[u32]| vec![chunk.iter().sum::<u32>()])
            .await
            .unwrap();
        assert_eq!(sums, vec![6, 15, 24, 10]);
    }
}