use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::oneshot;
use crate::SessionError;
/// The concrete source of completion that a [`CompletionHandle`] waits on.
#[derive(Debug)]
enum CompletionFuture {
    /// Completion is reported by a value sent over a tokio oneshot channel; the sent
    /// `Result` becomes the handle's output. A receive error is mapped to
    /// `SessionError::AckReception`.
    OneshotReceiver(oneshot::Receiver<Result<(), SessionError>>),
    /// Completion is the exit of a spawned tokio task that returns `()`. A normal exit
    /// yields `Ok(())`; a join error is mapped to `SessionError::AckReception`.
    JoinHandle(tokio::task::JoinHandle<()>),
}
/// An awaitable handle that resolves to `Result<(), SessionError>` once the
/// underlying work completes.
///
/// Build it with [`CompletionHandle::from_oneshot_receiver`] or
/// [`CompletionHandle::from_join_handle`], then `.await` it.
#[derive(Debug)]
pub struct CompletionHandle {
    // Which kind of underlying future this handle drives.
    inner: CompletionFuture,
}
impl CompletionHandle {
    /// Creates a handle backed by a oneshot receiver.
    ///
    /// Awaiting the handle yields whatever `Result` is sent on the channel. If receiving
    /// fails (for example, the sender is dropped without sending), the handle resolves to
    /// `SessionError::AckReception`.
    pub fn from_oneshot_receiver(receiver: oneshot::Receiver<Result<(), SessionError>>) -> Self {
        let inner = CompletionFuture::OneshotReceiver(receiver);
        Self { inner }
    }

    /// Creates a handle backed by a spawned task's join handle.
    ///
    /// Awaiting the handle yields `Ok(())` when the task finishes normally. If the join
    /// fails, it resolves to `SessionError::AckReception`.
    pub fn from_join_handle(handle: tokio::task::JoinHandle<()>) -> Self {
        let inner = CompletionFuture::JoinHandle(handle);
        Self { inner }
    }
}
impl Future for CompletionHandle {
type Output = Result<(), SessionError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match &mut this.inner {
CompletionFuture::OneshotReceiver(receiver) => match Pin::new(receiver).poll(cx) {
Poll::Ready(Ok(result)) => Poll::Ready(result),
Poll::Ready(Err(e)) => Poll::Ready(Err(SessionError::AckReception(e.to_string()))),
Poll::Pending => Poll::Pending,
},
CompletionFuture::JoinHandle(handle) => match Pin::new(handle).poll(cx) {
Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
Poll::Ready(Err(e)) => Poll::Ready(Err(SessionError::AckReception(format!(
"Join handle error: {}",
e
)))),
Poll::Pending => Poll::Pending,
},
}
}
}