use crate::Error;
use futures_util::{pin_mut, task::ArcWake};
use std::future::Future;
use std::net::{SocketAddr, UdpSocket};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::thread::{self, Thread};
/// Extension trait that drives a future to completion on the calling thread.
pub(crate) trait Join: Future {
    /// Blocks the current thread until the future resolves, then returns its output.
    fn join(self) -> Self::Output;
}
impl<F: Future> Join for F {
    /// Runs a minimal single-future executor on the current thread.
    ///
    /// The future is polled; whenever it returns `Pending`, the thread parks until the
    /// future's waker unparks it, and then the future is polled again.
    fn join(self) -> F::Output {
        /// Waker payload that unparks the thread which called `join`.
        struct Unparker(Thread);

        impl ArcWake for Unparker {
            fn wake_by_ref(arc_self: &Arc<Self>) {
                arc_self.0.unpark();
            }
        }

        let waker = Arc::new(Unparker(thread::current())).into_waker();
        let mut cx = Context::from_waker(&waker);

        let fut = self;
        pin_mut!(fut);

        // `thread::park` may return spuriously; that only costs one extra poll, which
        // futures are required to tolerate.
        loop {
            if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
                break output;
            }
            thread::park();
        }
    }
}
/// Builds a [`Waker`] that invokes `f` each time it is woken.
fn waker_fn(f: impl Fn() + Send + Sync + 'static) -> Waker {
    /// Adapter that lets an arbitrary closure act as an `ArcWake` implementation.
    struct FnWaker<C>(C);

    impl<C> ArcWake for FnWaker<C>
    where
        C: Fn() + Send + Sync + 'static,
    {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            let callback = &arc_self.0;
            callback();
        }
    }

    let shared = Arc::new(FnWaker(f));
    shared.into_waker()
}
/// Extension methods for [`Waker`].
pub(crate) trait WakerExt {
    /// Returns a new waker that, when woken, calls `callback` with a reference to a
    /// waker derived from `self` (implementation-defined).
    fn chain(&self, callback: impl Fn(&Waker) + Send + Sync + 'static) -> Waker;
}
impl WakerExt for Waker {
    /// Wraps this waker: waking the returned waker runs `f` with a clone of `self`.
    fn chain(&self, f: impl Fn(&Waker) + Send + Sync + 'static) -> Waker {
        // Take an owned handle so the returned waker does not borrow from `self`.
        let wrapped = Waker::clone(self);
        waker_fn(move || f(&wrapped))
    }
}
/// Waker state that holds a UDP socket connected to a target address; waking it
/// sends a datagram to that address.
pub(crate) struct UdpWaker {
    /// Loopback socket on an ephemeral port, connected to the address given to
    /// [`UdpWaker::connect`].
    socket: UdpSocket,
}

impl UdpWaker {
    /// Creates a waker whose datagrams are sent to `addr`.
    ///
    /// A loopback socket on an ephemeral port is bound in the same address family as
    /// `addr` (IPv4 or IPv6) and then connected to `addr`.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the local socket or connecting it to `addr` fails.
    pub(crate) fn connect(addr: SocketAddr) -> Result<Self, Error> {
        // An IPv4 socket cannot be connected to an IPv6 address (and vice versa), so
        // bind loopback in whichever family the target uses.
        let loopback: std::net::IpAddr = match addr {
            SocketAddr::V4(_) => std::net::Ipv4Addr::LOCALHOST.into(),
            SocketAddr::V6(_) => std::net::Ipv6Addr::LOCALHOST.into(),
        };
        let socket = UdpSocket::bind(SocketAddr::new(loopback, 0))?;
        socket.connect(addr)?;
        Ok(Self { socket })
    }
}
impl ArcWake for UdpWaker {
fn wake_by_ref(arc_self: &Arc<Self>) {
if let Err(e) = arc_self.socket.send(&[1]) {
log::debug!("agent waker produced an error: {}", e);
}
}
}