use std::fmt::Debug;
use thread_aware::ThreadAware;
use thread_aware::affinity::Affinity;
use thread_aware::closure::ThreadAwareAsyncFnOnce;
use crate::Spawner;
use crate::custom::{BoxedBlockingTask, BoxedFuture, SpawnCustom};
/// Spawner decorator: futures are passed through `future_layer` and blocking
/// closures through `blocking_layer` before being handed to `inner`.
///
/// `Clone` is derived, which requires `FL`, `BL` and `S` to be `Clone` — the same
/// bounds the builder's `impl SpawnCustom + Clone` return type relies on.
#[derive(Clone)]
struct Layered<FL, BL, S> {
    /// Decorator applied to futures (immediately for `spawn`, when the task is
    /// invoked for `spawn_anywhere`).
    future_layer: FL,
    /// Decorator applied to closures passed to `spawn_blocking`.
    blocking_layer: BL,
    /// Spawner that receives the decorated work.
    inner: S,
}
impl<FL: Send, BL: Send, S: ThreadAware> ThreadAware for Layered<FL, BL, S> {
    fn relocate(&mut self, source: Option<Affinity>, destination: Affinity) {
        // Relocation is delegated to the wrapped spawner alone; the two layer
        // closures are left untouched.
        let Self { inner, .. } = self;
        inner.relocate(source, destination);
    }
}
impl<FL, BL, S> SpawnCustom for Layered<FL, BL, S>
where
    FL: Fn(BoxedFuture) -> BoxedFuture + Clone + Send + Sync + 'static,
    BL: Fn(BoxedBlockingTask) -> BoxedBlockingTask + Clone + Send + Sync + 'static,
    S: SpawnCustom + Clone,
{
    // Decorate the future right away, then forward it.
    fn spawn(&self, task: BoxedFuture) {
        let decorated = (self.future_layer)(task);
        self.inner.spawn(decorated);
    }

    // Defer decoration: the task is wrapped together with a copy of the layer,
    // which is applied only when the inner spawner calls `call_once` on it.
    fn spawn_anywhere(&self, task: Box<dyn ThreadAwareAsyncFnOnce<()>>) {
        let layer = self.future_layer.clone();
        self.inner.spawn_anywhere(Box::new(LayeredTask { task, layer }));
    }

    // Decorate the blocking closure right away, then forward it.
    fn spawn_blocking(&self, task: BoxedBlockingTask) {
        let decorated = (self.blocking_layer)(task);
        self.inner.spawn_blocking(decorated);
    }
}
/// A task handed to `spawn_anywhere`, paired with the future layer that must
/// decorate it once it is finally invoked.
struct LayeredTask<F> {
    /// The wrapped task; relocation requests are forwarded to it.
    task: Box<dyn ThreadAwareAsyncFnOnce<()>>,
    /// Decorator applied to the future produced by `task`.
    layer: F,
}

impl<F: Send> ThreadAware for LayeredTask<F> {
    fn relocate(&mut self, source: Option<Affinity>, destination: Affinity) {
        // Only the wrapped task is relocated; the layer is left as-is.
        self.task.relocate(source, destination);
    }
}

impl<F> ThreadAwareAsyncFnOnce<()> for LayeredTask<F>
where
    F: Fn(BoxedFuture) -> BoxedFuture + Send + 'static,
{
    fn call_once(self: Box<Self>) -> thread_aware::closure::BoxFuture<'static, ()> {
        // Produce the task's future first, then let the layer wrap it.
        let Self { task, layer } = *self;
        layer(task.call_once())
    }
}
/// Base spawner backed by Tokio. It uses the stored runtime handle when one was
/// supplied, and otherwise the free `tokio::spawn` / `tokio::task::spawn_blocking`
/// functions.
#[cfg(feature = "tokio")]
#[derive(Clone)]
struct TokioSpawner(Option<::tokio::runtime::Handle>);

#[cfg(feature = "tokio")]
impl ThreadAware for TokioSpawner {
    // Intentionally a no-op: the optional handle is kept unchanged across relocation.
    fn relocate(&mut self, _source: Option<Affinity>, _destination: Affinity) {}
}

#[cfg(feature = "tokio")]
impl SpawnCustom for TokioSpawner {
    fn spawn(&self, task: BoxedFuture) {
        if let Some(handle) = &self.0 {
            handle.spawn(task);
        } else {
            ::tokio::spawn(task);
        }
    }

    fn spawn_anywhere(&self, task: Box<dyn ThreadAwareAsyncFnOnce<()>>) {
        // Materialize the future on the calling thread, then spawn it normally.
        let future = task.call_once();
        self.spawn(future);
    }

    fn spawn_blocking(&self, task: BoxedBlockingTask) {
        if let Some(handle) = &self.0 {
            handle.spawn_blocking(task);
        } else {
            ::tokio::task::spawn_blocking(task);
        }
    }
}
/// Builder that assembles a [`Spawner`] from a base [`SpawnCustom`] implementation,
/// any number of layers, and a name.
///
/// Start with [`CustomSpawnerBuilder::new`] (or `CustomSpawnerBuilder::tokio` /
/// `CustomSpawnerBuilder::tokio_with_handle` when the `tokio` feature is enabled).
/// Optionally call `name` and `layer`, then finish with `build`.
pub struct CustomSpawnerBuilder<S> {
    /// Name passed to the built spawner.
    name: &'static str,
    /// Base spawner, possibly already wrapped in layers.
    spawner: S,
}
impl CustomSpawnerBuilder<()> {
    /// Starts a builder whose base spawner uses Tokio's ambient runtime.
    ///
    /// Futures are handed to `tokio::spawn` and blocking closures to
    /// `tokio::task::spawn_blocking`. They therefore target whichever runtime
    /// those functions resolve at spawn time. The spawner is named `"tokio"`.
    #[cfg(feature = "tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    #[must_use]
    pub fn tokio() -> CustomSpawnerBuilder<impl SpawnCustom + Clone> {
        CustomSpawnerBuilder {
            spawner: TokioSpawner(None),
            name: "tokio",
        }
    }

    /// Starts a builder whose base spawner sends all work to the runtime behind
    /// `handle` (via `Handle::spawn` / `Handle::spawn_blocking`).
    ///
    /// The spawner is named `"tokio"`.
    #[cfg(feature = "tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    #[must_use]
    pub fn tokio_with_handle(handle: ::tokio::runtime::Handle) -> CustomSpawnerBuilder<impl SpawnCustom + Clone> {
        CustomSpawnerBuilder {
            spawner: TokioSpawner(Some(handle)),
            name: "tokio",
        }
    }

    /// Starts a builder from an arbitrary base spawner.
    ///
    /// The spawner is named `"custom"` unless overridden with `name`.
    // `#[must_use]` matches the Tokio constructors: dropping the builder discards `base`.
    #[must_use]
    pub fn new<S: SpawnCustom + Clone>(base: S) -> CustomSpawnerBuilder<S> {
        CustomSpawnerBuilder {
            spawner: base,
            name: "custom",
        }
    }
}
impl<S: SpawnCustom + Clone> CustomSpawnerBuilder<S> {
    /// Overrides the name given to the built [`Spawner`].
    ///
    /// The default is `"tokio"` for the Tokio constructors and `"custom"` for
    /// [`CustomSpawnerBuilder::new`].
    #[must_use]
    pub fn name(mut self, name: &'static str) -> Self {
        self.name = name;
        self
    }

    /// Wraps the current spawner so that futures pass through `future_layer` and
    /// blocking closures pass through `blocking_layer` before reaching it.
    ///
    /// For `spawn` and `spawn_blocking` the layer is applied immediately. For
    /// `spawn_anywhere` it is applied when the inner spawner invokes the task.
    ///
    /// Calling `layer` repeatedly nests the layers. The most recently added layer
    /// is applied to the task first, so the earliest-added layer ends up as the
    /// outermost wrapper. The configured name is carried over unchanged.
    // `#[must_use]` (like `name`): discarding the result silently drops the layer.
    #[must_use]
    pub fn layer<FL, BL>(self, future_layer: FL, blocking_layer: BL) -> CustomSpawnerBuilder<impl SpawnCustom + Clone>
    where
        FL: Fn(BoxedFuture) -> BoxedFuture + Clone + Send + Sync + 'static,
        BL: Fn(BoxedBlockingTask) -> BoxedBlockingTask + Clone + Send + Sync + 'static,
    {
        CustomSpawnerBuilder {
            spawner: Layered {
                future_layer,
                blocking_layer,
                inner: self.spawner,
            },
            name: self.name,
        }
    }

    /// Consumes the builder and creates a [`Spawner`] from the configured name
    /// and (possibly layered) spawner.
    pub fn build(self) -> Spawner {
        Spawner::new_custom(self.name, self.spawner)
    }
}
#[expect(clippy::missing_fields_in_debug, reason = "spawner is opaque and not useful in debug output")]
impl<S> Debug for CustomSpawnerBuilder<S> {
    // Only the name is printed; the spawner is deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CustomSpawnerBuilder")
            .field("name", &self.name)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// Inner spawner stub. It records relocation in a static flag and drops
    /// spawned futures and blocking closures. Tasks passed to `spawn_anywhere`
    /// are relocated but never run.
    #[derive(Clone)]
    struct TrackingSpawner {
        relocated: &'static AtomicBool,
    }

    impl ThreadAware for TrackingSpawner {
        fn relocate(&mut self, _source: Option<Affinity>, _destination: Affinity) {
            self.relocated.store(true, Ordering::SeqCst);
        }
    }

    impl SpawnCustom for TrackingSpawner {
        fn spawn(&self, _task: BoxedFuture) {}

        fn spawn_anywhere(&self, mut task: Box<dyn ThreadAwareAsyncFnOnce<()>>) {
            let affinities = thread_aware::affinity::pinned_affinities(&[2]);
            task.relocate(Some(affinities[0]), affinities[1]);
        }

        fn spawn_blocking(&self, _task: BoxedBlockingTask) {}
    }

    /// Task that ignores relocation and completes immediately.
    struct NoopTask;

    impl ThreadAware for NoopTask {
        fn relocate(&mut self, _source: Option<Affinity>, _destination: Affinity) {}
    }

    impl ThreadAwareAsyncFnOnce<()> for NoopTask {
        fn call_once(self: Box<Self>) -> thread_aware::closure::BoxFuture<'static, ()> {
            Box::pin(async {})
        }
    }

    #[test]
    fn layered_relocate_forwards_to_inner() {
        static RELOCATED: AtomicBool = AtomicBool::new(false);
        static FUTURE_LAYER_RAN: AtomicBool = AtomicBool::new(false);
        static BLOCKING_LAYER_RAN: AtomicBool = AtomicBool::new(false);
        let affinities = thread_aware::affinity::pinned_affinities(&[2]);
        let mut layered = Layered {
            future_layer: |task: BoxedFuture| -> BoxedFuture {
                FUTURE_LAYER_RAN.store(true, Ordering::SeqCst);
                task
            },
            blocking_layer: |task: BoxedBlockingTask| -> BoxedBlockingTask {
                BLOCKING_LAYER_RAN.store(true, Ordering::SeqCst);
                task
            },
            inner: TrackingSpawner { relocated: &RELOCATED },
        };

        layered.relocate(Some(affinities[0]), affinities[1]);
        assert!(RELOCATED.load(Ordering::SeqCst), "Layered must forward relocate to inner");

        layered.spawn(Box::pin(async {}));
        assert!(
            FUTURE_LAYER_RAN.load(Ordering::SeqCst),
            "future_layer must run on spawn"
        );

        // TrackingSpawner relocates the deferred task, exercising LayeredTask::relocate.
        layered.spawn_anywhere(Box::new(NoopTask));

        layered.spawn_blocking(Box::new(|| {}));
        assert!(
            BLOCKING_LAYER_RAN.load(Ordering::SeqCst),
            "blocking_layer must run on spawn_blocking"
        );

        let covered = (layered.future_layer)(Box::pin(async {}));
        futures::executor::block_on(covered);
        futures::executor::block_on(Box::new(NoopTask).call_once());
    }

    #[test]
    fn layered_task_relocate_forwards_to_inner() {
        static RELOCATED: AtomicBool = AtomicBool::new(false);

        #[derive(Clone)]
        struct Tracker(&'static AtomicBool);

        impl ThreadAware for Tracker {
            fn relocate(&mut self, _source: Option<Affinity>, _destination: Affinity) {
                self.0.store(true, Ordering::SeqCst);
            }
        }

        impl ThreadAwareAsyncFnOnce<()> for Tracker {
            fn call_once(self: Box<Self>) -> thread_aware::closure::BoxFuture<'static, ()> {
                Box::pin(async {})
            }
        }

        let affinities = thread_aware::affinity::pinned_affinities(&[2]);
        let mut task = LayeredTask {
            task: Box::new(Tracker(&RELOCATED)),
            layer: |task: BoxedFuture| -> BoxedFuture { task },
        };
        task.relocate(Some(affinities[0]), affinities[1]);
        assert!(RELOCATED.load(Ordering::SeqCst), "LayeredTask must forward relocate to inner task");
        let covered = (task.layer)(Box::pin(async {}));
        futures::executor::block_on(covered);
        let fut = task.task.call_once();
        futures::executor::block_on(fut);
    }

    #[test]
    fn layered_task_call_once_applies_layer() {
        static LAYER_RAN: AtomicBool = AtomicBool::new(false);
        let task = Box::new(LayeredTask {
            task: Box::new(NoopTask),
            layer: |task: BoxedFuture| -> BoxedFuture {
                LAYER_RAN.store(true, Ordering::SeqCst);
                task
            },
        });
        // The layer must be applied inside `call_once`, before anything is polled.
        let future = task.call_once();
        assert!(
            LAYER_RAN.load(Ordering::SeqCst),
            "LayeredTask::call_once must pass the task's future through the layer"
        );
        futures::executor::block_on(future);
    }
}