use std::fmt::Debug;
use crate::Spawner;
use crate::custom::{BoxedFuture, SpawnCustom};
use thread_aware::ThreadAware;
use thread_aware::affinity::Affinity;
use thread_aware::closure::ThreadAwareAsyncFnOnce;
/// Spawner adapter that pairs a `layer` function with an `inner` spawner.
///
/// The `SpawnCustom` implementation for this type passes every task through
/// `layer` before handing it to `inner`.
///
/// `Clone` is derived: the generated impl requires `F: Clone` and `S: Clone`
/// and clones both fields.
#[derive(Clone)]
struct Layered<F, S> {
    /// Function applied to each task before it reaches `inner`.
    layer: F,
    /// The spawner that ultimately receives the (wrapped) task.
    inner: S,
}
/// Relocation support for `Layered`: only the wrapped spawner takes part.
impl<F, S> ThreadAware for Layered<F, S>
where
    F: Send,
    S: ThreadAware,
{
    /// Forwards the relocation to `inner`.
    ///
    /// The `layer` value is only required to be `Send` and is left untouched.
    fn relocate(&mut self, source: Option<Affinity>, destination: Affinity) {
        <S as ThreadAware>::relocate(&mut self.inner, source, destination);
    }
}
/// Spawning through a `Layered`: every task goes through `layer` before it
/// reaches the inner spawner.
impl<F, S> SpawnCustom for Layered<F, S>
where
    F: Fn(BoxedFuture) -> BoxedFuture + Clone + Send + Sync + 'static,
    S: SpawnCustom + Clone,
{
    /// Applies the layer to `task` right away and passes the resulting future
    /// to the inner spawner's `spawn`.
    fn spawn(&self, task: BoxedFuture) {
        let wrapped = (self.layer)(task);
        self.inner.spawn(wrapped);
    }

    /// Bundles `task` with a clone of the layer into a `LayeredTask` and passes
    /// that to the inner spawner's `spawn_anywhere`.
    ///
    /// The layer is not applied here; `LayeredTask::call_once` applies it to
    /// the future produced by the original task.
    fn spawn_anywhere(&self, task: Box<dyn ThreadAwareAsyncFnOnce<()>>) {
        let layer = self.layer.clone();
        self.inner.spawn_anywhere(Box::new(LayeredTask { task, layer }));
    }
}
/// A deferred task bundled with the layer that must wrap its future.
///
/// Built by `Layered::spawn_anywhere`; the layer is applied in `call_once`.
struct LayeredTask<F> {
/// The original task; calling it yields the future to be wrapped.
task: Box<dyn ThreadAwareAsyncFnOnce<()>>,
/// Function applied to the future returned by `task`.
layer: F,
}
/// Relocation support for `LayeredTask`: only the wrapped task takes part.
impl<F> ThreadAware for LayeredTask<F>
where
    F: Send,
{
    /// Forwards the relocation to the boxed `task`.
    ///
    /// The `layer` value is only required to be `Send` and is left untouched.
    fn relocate(&mut self, source: Option<Affinity>, destination: Affinity) {
        self.task.relocate(source, destination);
    }
}
/// Calling a `LayeredTask` runs the original task and wraps its future.
impl<F> ThreadAwareAsyncFnOnce<()> for LayeredTask<F>
where
    F: Fn(BoxedFuture) -> BoxedFuture + Send + 'static,
{
    /// Calls the inner task to obtain its future, then returns that future
    /// after passing it through `layer`.
    fn call_once(self: Box<Self>) -> thread_aware::closure::BoxFuture<'static, ()> {
        // Take both parts out of the box so each can be consumed independently.
        let Self { task, layer } = *self;
        layer(task.call_once())
    }
}
/// Spawner backed by Tokio (only compiled with the `tokio` feature).
///
/// Wraps an optional runtime handle: with `Some(handle)` tasks are spawned
/// through that handle, with `None` they are spawned via `::tokio::spawn`.
#[cfg(feature = "tokio")]
#[derive(Clone)]
struct TokioSpawner(Option<::tokio::runtime::Handle>);
#[cfg(feature = "tokio")]
impl ThreadAware for TokioSpawner {
    /// Deliberate no-op: both arguments are ignored and `self` is not modified.
    fn relocate(&mut self, _source: Option<Affinity>, _destination: Affinity) {
        // Nothing to do here.
    }
}
#[cfg(feature = "tokio")]
impl SpawnCustom for TokioSpawner {
    /// Spawns `task` on Tokio.
    ///
    /// Uses the stored runtime handle when there is one; otherwise falls back
    /// to `::tokio::spawn`. The returned join handle is discarded either way.
    fn spawn(&self, task: BoxedFuture) {
        if let Some(handle) = self.0.as_ref() {
            handle.spawn(task);
        } else {
            ::tokio::spawn(task);
        }
    }

    /// Calls `task` immediately to obtain its future and spawns that future
    /// through [`Self::spawn`].
    fn spawn_anywhere(&self, task: Box<dyn ThreadAwareAsyncFnOnce<()>>) {
        let future = task.call_once();
        self.spawn(future);
    }
}
/// Builder that assembles a `Spawner` from a base spawner plus optional layers.
///
/// `S` is the spawner accumulated so far; each call to `layer` wraps it in a
/// new type, and `build` hands it to `Spawner::new_custom`.
pub struct CustomSpawnerBuilder<S> {
/// The spawner built up so far (base spawner plus any layers).
spawner: S,
/// Name passed to `Spawner::new_custom`; also the only field shown by `Debug`.
name: &'static str,
}
/// Entry points for starting a [`CustomSpawnerBuilder`].
impl CustomSpawnerBuilder<()> {
    /// Starts a builder whose base spawner submits tasks with `tokio::spawn`.
    ///
    /// No runtime handle is stored. The builder name is preset to `"tokio"`.
    #[cfg(feature = "tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    #[must_use]
    pub fn tokio() -> CustomSpawnerBuilder<impl SpawnCustom + Clone> {
        let spawner = TokioSpawner(None);
        CustomSpawnerBuilder { spawner, name: "tokio" }
    }

    /// Starts a builder whose base spawner submits tasks through `handle`.
    ///
    /// The builder name is preset to `"tokio"`.
    #[cfg(feature = "tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    #[must_use]
    pub fn tokio_with_handle(handle: ::tokio::runtime::Handle) -> CustomSpawnerBuilder<impl SpawnCustom + Clone> {
        let spawner = TokioSpawner(Some(handle));
        CustomSpawnerBuilder { spawner, name: "tokio" }
    }

    /// Starts a builder on top of a caller-provided [`SpawnCustom`] implementation.
    ///
    /// The builder name is preset to `"custom"`.
    pub fn new<S: SpawnCustom + Clone>(base: S) -> CustomSpawnerBuilder<S> {
        let spawner = base;
        CustomSpawnerBuilder { spawner, name: "custom" }
    }
}
/// Configuration and finalisation steps, available once a base spawner is set.
impl<S> CustomSpawnerBuilder<S>
where
    S: SpawnCustom + Clone,
{
    /// Replaces the builder's name and returns the updated builder.
    #[must_use]
    pub fn name(self, name: &'static str) -> Self {
        Self { name, ..self }
    }

    /// Wraps the current spawner so that every task passes through `layer`
    /// before reaching it.
    ///
    /// Layers nest: the most recently added layer receives a task first, and
    /// its output is then handed on to the layers added before it.
    /// The builder's name is carried over unchanged.
    pub fn layer<F>(self, layer: F) -> CustomSpawnerBuilder<impl SpawnCustom + Clone>
    where
        F: Fn(BoxedFuture) -> BoxedFuture + Clone + Send + Sync + 'static,
    {
        let Self { spawner: inner, name } = self;
        CustomSpawnerBuilder {
            spawner: Layered { layer, inner },
            name,
        }
    }

    /// Consumes the builder and creates the [`Spawner`] via
    /// `Spawner::new_custom`, passing the name and the assembled spawner.
    pub fn build(self) -> Spawner {
        let Self { spawner, name } = self;
        Spawner::new_custom(name, spawner)
    }
}
/// Debug output lists only the builder's `name`; the spawner is left out.
#[expect(clippy::missing_fields_in_debug, reason = "spawner is opaque and not useful in debug output")]
impl<S> Debug for CustomSpawnerBuilder<S> {
    /// Writes `CustomSpawnerBuilder { name: .. }` using the standard
    /// struct-debug helper, so `{:#?}` pretty-printing works as usual.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CustomSpawnerBuilder")
            .field("name", &self.name)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// Spawner stub that raises a shared flag when it is relocated.
    #[derive(Clone)]
    struct TrackingSpawner {
        relocated: &'static AtomicBool,
    }

    impl ThreadAware for TrackingSpawner {
        fn relocate(&mut self, _source: Option<Affinity>, _destination: Affinity) {
            self.relocated.store(true, Ordering::SeqCst);
        }
    }

    impl SpawnCustom for TrackingSpawner {
        /// Discards the task without running it.
        fn spawn(&self, _task: BoxedFuture) {}

        /// Relocates the task once and then drops it; the task is never called.
        fn spawn_anywhere(&self, mut task: Box<dyn ThreadAwareAsyncFnOnce<()>>) {
            let pins = thread_aware::affinity::pinned_affinities(&[2]);
            let (from, to) = (pins[0], pins[1]);
            task.relocate(Some(from), to);
        }
    }

    /// Task stub: relocation does nothing and the produced future is empty.
    struct NoopTask;

    impl ThreadAware for NoopTask {
        fn relocate(&mut self, _source: Option<Affinity>, _destination: Affinity) {}
    }

    impl ThreadAwareAsyncFnOnce<()> for NoopTask {
        fn call_once(self: Box<Self>) -> thread_aware::closure::BoxFuture<'static, ()> {
            Box::pin(async {})
        }
    }

    #[test]
    fn layered_relocate_forwards_to_inner() {
        static RELOCATED: AtomicBool = AtomicBool::new(false);

        let pins = thread_aware::affinity::pinned_affinities(&[2]);
        let (from, to) = (pins[0], pins[1]);

        let identity = |task: BoxedFuture| -> BoxedFuture { task };
        let inner = TrackingSpawner { relocated: &RELOCATED };
        let mut layered = Layered { layer: identity, inner };

        layered.relocate(Some(from), to);
        assert!(RELOCATED.load(Ordering::SeqCst), "Layered must forward relocate to inner");

        // The remaining calls only exercise the stubs and the identity layer
        // so that none of the helper code above goes unused.
        layered.inner.spawn(Box::pin(async {}));
        layered.inner.spawn_anywhere(Box::new(NoopTask));
        futures::executor::block_on((layered.layer)(Box::pin(async {})));
        let noop_future = Box::new(NoopTask).call_once();
        futures::executor::block_on(noop_future);
    }

    #[test]
    fn layered_task_relocate_forwards_to_inner() {
        /// Task stub that raises the referenced flag when it is relocated.
        #[derive(Clone)]
        struct Tracker(&'static AtomicBool);

        impl ThreadAware for Tracker {
            fn relocate(&mut self, _source: Option<Affinity>, _destination: Affinity) {
                self.0.store(true, Ordering::SeqCst);
            }
        }

        impl ThreadAwareAsyncFnOnce<()> for Tracker {
            fn call_once(self: Box<Self>) -> thread_aware::closure::BoxFuture<'static, ()> {
                Box::pin(async {})
            }
        }

        static RELOCATED: AtomicBool = AtomicBool::new(false);

        let pins = thread_aware::affinity::pinned_affinities(&[2]);
        let (from, to) = (pins[0], pins[1]);

        let mut task = LayeredTask {
            task: Box::new(Tracker(&RELOCATED)),
            layer: |task: BoxedFuture| -> BoxedFuture { task },
        };

        task.relocate(Some(from), to);
        assert!(RELOCATED.load(Ordering::SeqCst), "LayeredTask must forward relocate to inner task");

        // Exercise the identity layer and the stub's `call_once` as well.
        let LayeredTask { task: inner_task, layer } = task;
        futures::executor::block_on(layer(Box::pin(async {})));
        futures::executor::block_on(inner_task.call_once());
    }
}