use std::{fmt, future::Future, io, pin::Pin, sync::mpsc, thread};
use tokio::sync::mpsc as tokio_mpsc;
/// Handle to a dedicated OS thread that runs a current-thread Tokio runtime
/// and executes futures submitted through [`RuntimeBridge::block_on`].
///
/// The handle owns the only sender of the worker's dispatch channel, so
/// dropping it closes that channel, which ends the worker's receive loop.
pub struct RuntimeBridge {
    // Dispatch channel feeding boxed futures to the worker's receive loop.
    tx: tokio_mpsc::UnboundedSender<BridgedTask>,
    // Handle to the worker thread. Dropping the bridge does not join it, so the
    // thread is detached and finishes shutting down on its own.
    _worker: thread::JoinHandle<()>,
}
/// A type-erased, heap-pinned future with no output that can be moved to the
/// worker thread and polled there. (`Box<dyn …>` defaults to a `'static`
/// object lifetime, so the future may not borrow from the submitting caller.)
type BridgedTask = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Failure modes of [`RuntimeBridge::block_on`].
///
/// The variants carry no data, so the type is cheap to copy and can be
/// compared directly (e.g. `assert_eq!(err, BridgeError::ResponseLost)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BridgeError {
    /// The dispatch channel to the worker thread is closed, so the future
    /// could not be submitted.
    WorkerDead,
    /// The future was submitted but no value came back, e.g. because it
    /// panicked or the runtime shut down before it completed.
    ResponseLost,
}
impl fmt::Display for BridgeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::WorkerDead => {
f.write_str("runtime bridge: worker thread terminated; dispatch channel closed")
}
Self::ResponseLost => f.write_str(
"runtime bridge: task accepted but no reply received (task panicked \
or runtime shut down mid-flight)",
),
}
}
}
impl std::error::Error for BridgeError {}
impl RuntimeBridge {
    /// Creates a bridge whose worker thread is named `heddle-runtime-bridge`.
    ///
    /// # Errors
    ///
    /// Same as [`RuntimeBridge::with_thread_name`].
    pub fn new() -> io::Result<Self> {
        Self::with_thread_name("heddle-runtime-bridge")
    }

    /// Creates a bridge backed by a new OS thread named `thread_name`.
    ///
    /// The thread builds a current-thread Tokio runtime with all drivers
    /// enabled. It spawns every future submitted through
    /// [`RuntimeBridge::block_on`] onto that runtime as a separate task.
    ///
    /// This call waits until the worker reports that its runtime was built.
    /// A failure to build the runtime is therefore returned here, instead of
    /// killing the worker and surfacing later as [`BridgeError::WorkerDead`].
    ///
    /// # Errors
    ///
    /// Returns the [`io::Error`] from spawning the OS thread or from building
    /// the worker's Tokio runtime.
    pub fn with_thread_name(thread_name: impl Into<String>) -> io::Result<Self> {
        let (tx, mut rx) = tokio_mpsc::unbounded_channel::<BridgedTask>();
        // Readiness handshake: the worker sends exactly one message saying
        // whether its runtime came up.
        let (ready_tx, ready_rx) = mpsc::sync_channel::<io::Result<()>>(1);
        let worker = thread::Builder::new()
            .name(thread_name.into())
            .spawn(move || {
                let built = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build();
                let runtime = match built {
                    Ok(runtime) => runtime,
                    Err(err) => {
                        // The constructor is waiting on `ready_rx`; report and exit.
                        let _ = ready_tx.send(Err(err));
                        return;
                    }
                };
                let _ = ready_tx.send(Ok(()));
                runtime.block_on(async move {
                    // `recv` yields `None` once every sender of the dispatch
                    // channel has been dropped, which ends the worker.
                    while let Some(task) = rx.recv().await {
                        // Detached on purpose: each task hands its output back
                        // through its own reply channel.
                        tokio::spawn(task);
                    }
                });
            })?;
        match ready_rx.recv() {
            Ok(Ok(())) => Ok(Self {
                tx,
                _worker: worker,
            }),
            Ok(Err(err)) => {
                // The worker returns right after reporting; reap it.
                let _ = worker.join();
                Err(err)
            }
            Err(_) => {
                // The handshake sender was dropped without a report, so the
                // worker died before finishing setup.
                let _ = worker.join();
                Err(io::Error::new(
                    io::ErrorKind::Other,
                    "runtime bridge: worker thread exited before its runtime was ready",
                ))
            }
        }
    }

    /// Runs `future` to completion on the bridge's worker runtime and blocks
    /// the calling thread until its output is available.
    ///
    /// The future runs on the bridge's own thread. This can be called from
    /// plain threads and from inside other Tokio runtimes. However, it blocks
    /// the calling thread: inside an async context it stalls that executor
    /// thread for the duration.
    ///
    /// Concurrent callers each get their own spawned task. Those tasks
    /// interleave on the worker's runtime rather than running one after
    /// another.
    ///
    /// # Errors
    ///
    /// * [`BridgeError::WorkerDead`]: the dispatch channel is closed and the
    ///   future could not be submitted.
    /// * [`BridgeError::ResponseLost`]: the future was submitted but produced
    ///   no value (for example, it panicked).
    ///
    /// # Panics
    ///
    /// Panics if called from the bridge's own worker thread, e.g. from a
    /// future already running on this bridge. In that case the calling thread
    /// is the only one able to poll the submitted task, so waiting would hang
    /// forever.
    pub fn block_on<F, T>(&self, future: F) -> Result<T, BridgeError>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        assert!(
            thread::current().id() != self._worker.thread().id(),
            "RuntimeBridge::block_on called from the bridge's own worker thread; \
             this would deadlock"
        );
        // Capacity 1: the single reply never blocks the worker thread.
        let (reply_tx, reply_rx) = mpsc::sync_channel::<T>(1);
        let task: BridgedTask = Box::pin(async move {
            let value = future.await;
            // The caller is blocked on `reply_rx`, so this is not expected to
            // fail; ignore it rather than panic inside the worker.
            let _ = reply_tx.send(value);
        });
        self.tx.send(task).map_err(|_| BridgeError::WorkerDead)?;
        // If the task panics, its `reply_tx` is dropped and `recv` reports a
        // disconnect instead of blocking forever.
        reply_rx.recv().map_err(|_| BridgeError::ResponseLost)
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc, Barrier,
            atomic::{AtomicUsize, Ordering},
        },
        time::{Duration, Instant},
    };

    use super::*;

    #[test]
    fn block_on_from_non_tokio_thread() {
        let bridge = RuntimeBridge::new().expect("spawn bridge");
        let value = bridge.block_on(async { 1 + 2 }).expect("ok");
        assert_eq!(value, 3);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn block_on_from_current_thread_runtime() {
        let bridge = RuntimeBridge::new().expect("spawn bridge");
        let value = bridge.block_on(async { "ok".to_string() }).expect("ok");
        assert_eq!(value, "ok");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn block_on_from_multi_thread_runtime() {
        let bridge = RuntimeBridge::new().expect("spawn bridge");
        let value = bridge.block_on(async { 42_u64 }).expect("ok");
        assert_eq!(value, 42);
    }

    #[test]
    fn block_on_sequential_calls() {
        let bridge = RuntimeBridge::new().expect("spawn bridge");
        for i in 0..5_u32 {
            let got = bridge.block_on(async move { i * 2 }).expect("ok");
            assert_eq!(got, i * 2);
        }
    }

    /// Closing the dispatch channel (what dropping the bridge does) must make
    /// the worker thread actually exit, not merely leave it detached.
    #[test]
    fn drop_shuts_down_worker() {
        let bridge = RuntimeBridge::new().expect("spawn bridge");
        assert_eq!(bridge.block_on(async { 1 }).expect("ok"), 1);

        let RuntimeBridge { tx, _worker: worker } = bridge;
        drop(tx);

        // Join from a helper thread so a hung worker fails the test instead of
        // hanging the whole test run.
        let (done_tx, done_rx) = mpsc::channel();
        thread::spawn(move || {
            let _ = done_tx.send(worker.join().is_ok());
        });
        let exited_cleanly = done_rx
            .recv_timeout(Duration::from_secs(5))
            .expect("worker thread did not exit within 5s of the dispatch channel closing");
        assert!(exited_cleanly, "worker thread panicked during shutdown");
    }

    #[test]
    fn concurrent_callers_run_in_parallel() {
        const N: usize = 4;
        const DELAY: Duration = Duration::from_millis(250);
        let bridge = Arc::new(RuntimeBridge::new().expect("spawn bridge"));
        let barrier = Arc::new(Barrier::new(N));
        let started = Instant::now();
        let mut handles = Vec::with_capacity(N);
        for _ in 0..N {
            let bridge = Arc::clone(&bridge);
            let barrier = Arc::clone(&barrier);
            handles.push(thread::spawn(move || {
                barrier.wait();
                bridge
                    .block_on(async move {
                        tokio::time::sleep(DELAY).await;
                    })
                    .expect("ok")
            }));
        }
        for h in handles {
            h.join().expect("worker thread joined");
        }
        let elapsed = started.elapsed();
        let serial_floor = DELAY * (N as u32);
        assert!(
            elapsed < serial_floor / 2,
            "concurrent dispatch regressed to serial behaviour: \
             elapsed {elapsed:?} >= serial_floor/2 ({:?}); \
             N={N}, per-call delay={DELAY:?}",
            serial_floor / 2,
        );
    }

    #[test]
    fn panicking_task_returns_response_lost_and_bridge_stays_alive() {
        let bridge = RuntimeBridge::new().expect("spawn bridge");
        let result: Result<(), _> = bridge.block_on(async {
            panic!("simulated task panic");
        });
        assert!(
            matches!(result, Err(BridgeError::ResponseLost)),
            "panicking task must return ResponseLost; got {result:?}",
        );
        let next: u64 = bridge.block_on(async { 7 }).expect("ok");
        assert_eq!(next, 7, "bridge must keep serving after a task panic");
    }

    #[test]
    fn worker_runtime_runs_inner_spawns() {
        let bridge = RuntimeBridge::new().expect("spawn bridge");
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_for_task = Arc::clone(&counter);
        bridge
            .block_on(async move {
                let handles: Vec<_> = (0..8)
                    .map(|_| {
                        let c = Arc::clone(&counter_for_task);
                        tokio::spawn(async move {
                            c.fetch_add(1, Ordering::SeqCst);
                        })
                    })
                    .collect();
                for h in handles {
                    h.await.expect("inner spawn joined");
                }
            })
            .expect("ok");
        assert_eq!(counter.load(Ordering::SeqCst), 8);
    }
}