use rama_core::error::BoxErrorExt as _;
use std::any::TypeId;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use parking_lot::Mutex;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::oneshot;
use rama_core::bytes::Bytes;
use rama_core::error::BoxError;
use rama_core::extensions::Extension;
use rama_core::extensions::Extensions;
use rama_core::extensions::ExtensionsRef;
use rama_core::io::Io;
use rama_core::io::rewind::Rewind;
use rama_core::telemetry::tracing::trace;
use rama_net::extensions::StreamTransformed;
use rama_utils::macros::generate_set_and_with;
/// An upgraded HTTP connection.
///
/// Implements [`AsyncRead`] and [`AsyncWrite`] by delegating to the wrapped IO, and can be
/// converted back into the concrete IO type with [`Upgraded::downcast`].
pub struct Upgraded {
// Type-erased IO wrapped in a `Rewind`, built in `Upgraded::new` from the IO and its `read_buf`.
io: Rewind<Box<dyn UpgradeIo>>,
// Extensions returned by the `ExtensionsRef` impl of this type.
extensions: Extensions,
}
/// A future for a possible HTTP upgrade.
///
/// Resolves to the [`Upgraded`] connection (or an error) sent through the paired [`Pending`]
/// created by [`pending`]. Clones share the same underlying receiver.
#[derive(Clone, Extension)]
#[extension(tags(http))]
pub struct OnUpgrade {
// Receiving half of the oneshot channel created in `pending`, shared between clones.
rx: Arc<Mutex<oneshot::Receiver<Result<Upgraded, BoxError>>>>,
}
/// The deconstructed parts of an [`Upgraded`] type, as returned by [`Upgraded::downcast`].
#[derive(Debug)]
#[non_exhaustive]
pub struct Parts<T> {
/// The original IO object, downcast to its concrete type.
pub io: T,
/// The buffer handed back by the internal `Rewind` wrapper
/// (presumably bytes that were read but not yet consumed; confirm against `Rewind`).
pub read_buf: Bytes,
/// The extensions that were carried by the [`Upgraded`].
pub extensions: Extensions,
}
/// Waits for the pending HTTP upgrade attached to `msg` and yields the upgraded connection.
///
/// The [`OnUpgrade`] handle is looked up in the extensions of `msg` before the future is
/// created, so the returned future does not borrow from `msg` and is `'static`.
///
/// # Errors
///
/// The returned future resolves to an error when:
/// - no [`OnUpgrade`] is stored in the extensions of `msg`;
/// - the upgrade was already handled (see [`OnUpgrade::has_handled_upgrade`]);
/// - awaiting the [`OnUpgrade`] itself yields an error.
pub fn handle_upgrade<T: ExtensionsRef>(
    msg: T,
) -> impl Future<Output = Result<Upgraded, BoxError>> + 'static {
    // Only the `OnUpgrade` handle is needed afterwards, so clone just that entry
    // rather than the complete `Extensions` of the message.
    let on_upgrade = match msg.extensions().get_ref::<OnUpgrade>().cloned() {
        Some(on_upgrade) => {
            trace!("upgrading this: {:?}", on_upgrade);
            if on_upgrade.has_handled_upgrade() {
                Err(BoxError::from_static_str(
                    "upgraded has already been handled",
                ))
            } else {
                Ok(on_upgrade)
            }
        }
        None => Err(BoxError::from_static_str("no pending update found")),
    };

    // The lookup error (if any) is surfaced lazily, once the future is polled.
    async move { on_upgrade?.await }
}
/// The sending side of a pending upgrade, paired with an [`OnUpgrade`] by [`pending`].
///
/// Consumed by either [`Pending::fulfill`] or [`Pending::manual`].
pub struct Pending {
// Sending half of the oneshot channel whose receiver lives in the paired `OnUpgrade`.
tx: oneshot::Sender<Result<Upgraded, BoxError>>,
}
/// Creates a connected ([`Pending`], [`OnUpgrade`]) pair.
///
/// Whatever is sent through the returned [`Pending`] is what the returned
/// [`OnUpgrade`] future resolves to.
#[must_use]
pub fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();

    let pending = Pending { tx: sender };
    let on_upgrade = OnUpgrade {
        rx: Arc::new(Mutex::new(receiver)),
    };

    (pending, on_upgrade)
}
impl Upgraded {
pub fn new<T>(io: T, read_buf: Bytes) -> Self
where
T: Io + Unpin + ExtensionsRef,
{
let extensions = io.extensions().clone();
extensions.insert(StreamTransformed {
by: "rama-http::Upgraded",
});
Self {
extensions,
io: Rewind::new_buffered(Box::new(io), read_buf),
}
}
generate_set_and_with! {
pub fn extensions(mut self, extensions: Extensions) -> Self {
self.extensions = extensions;
self
}
}
pub fn downcast<T: Io + Unpin>(self) -> Result<Parts<T>, Self> {
let (io, buf) = self.io.into_inner();
match io.__downcast() {
Ok(t) => Ok(Parts {
io: *t,
read_buf: buf,
extensions: self.extensions,
}),
Err(io) => Err(Self {
io: Rewind::new_buffered(io, buf),
extensions: self.extensions,
}),
}
}
}
/// Private trait used to store any `Io + Unpin` type as a trait object inside `Upgraded`
/// while still being able to recover the concrete type later.
trait UpgradeIo: Io + Unpin {
/// Returns the `TypeId` of the concrete implementing type.
///
/// Called through `dyn UpgradeIo` by `__is` to identify the type behind the trait object.
fn __type_id(&self) -> TypeId {
TypeId::of::<Self>()
}
}
// Blanket impl: every `Io + Unpin` type is an `UpgradeIo`, using the default `__type_id`.
impl<T: Io + Unpin> UpgradeIo for T {}
impl dyn UpgradeIo {
    /// Returns `true` when the concrete type behind this trait object is `T`.
    fn __is<T: UpgradeIo>(&self) -> bool {
        self.__type_id() == TypeId::of::<T>()
    }

    /// Converts the boxed trait object into a `Box<T>` when its concrete type is `T`;
    /// otherwise returns the untouched box as the `Err` value.
    fn __downcast<T: UpgradeIo>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
        if !self.__is::<T>() {
            return Err(self);
        }

        let raw: *mut dyn UpgradeIo = Box::into_raw(self);
        // SAFETY: `__is::<T>()` confirmed that the value behind the trait object is a `T`,
        // so the data pointer obtained from `Box::into_raw` points to a `T` that was
        // allocated by a `Box`. Casting only discards the vtable metadata, and ownership
        // is handed straight back to a `Box<T>`.
        Ok(unsafe { Box::from_raw(raw.cast::<T>()) })
    }
}
// Exposes the extensions stored in the `extensions` field of `Upgraded`.
impl ExtensionsRef for Upgraded {
fn extensions(&self) -> &Extensions {
&self.extensions
}
}
// Reads are forwarded unchanged to the wrapped (rewindable) IO.
#[warn(clippy::missing_trait_methods)]
impl AsyncRead for Upgraded {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let inner = Pin::new(&mut this.io);
        inner.poll_read(cx, buf)
    }
}
// Every write-side operation is forwarded unchanged to the wrapped IO.
#[warn(clippy::missing_trait_methods)]
impl AsyncWrite for Upgraded {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.io).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.io).poll_write_vectored(cx, bufs)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.io).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.io).poll_shutdown(cx)
    }

    fn is_write_vectored(&self) -> bool {
        let inner = &self.io;
        inner.is_write_vectored()
    }
}
// Opaque `Debug` output: only the type name is printed, no fields.
impl fmt::Debug for Upgraded {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Upgraded");
        repr.finish()
    }
}
// Opaque `Debug` output: only the type name is printed, no fields.
impl fmt::Debug for Pending {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Pending");
        repr.finish()
    }
}
impl OnUpgrade {
    /// Reports whether the shared oneshot receiver is already terminated,
    /// which [`handle_upgrade`] treats as "the upgrade has already been handled".
    #[must_use]
    pub fn has_handled_upgrade(&self) -> bool {
        let receiver = self.rx.lock();
        receiver.is_terminated()
    }
}
impl Future for OnUpgrade {
    type Output = Result<Upgraded, BoxError>;

    /// Polls the shared oneshot receiver.
    ///
    /// A received value (success or error) is passed through as-is; a closed channel is
    /// reported as a cancellation error.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut receiver = self.rx.lock();

        let received = match Pin::new(&mut *receiver).poll(cx) {
            Poll::Ready(received) => received,
            Poll::Pending => return Poll::Pending,
        };

        Poll::Ready(match received {
            // The sender delivered a result: forward it unchanged.
            Ok(outcome) => outcome,
            // The sender was dropped without sending anything.
            Err(_oneshot_canceled) => Err(BoxError::from_static_str(
                "OnUpgrade: cancelled while expecting upgrade",
            )),
        })
    }
}
// `Debug` output shows the shared receiver under the `rx` field.
impl fmt::Debug for OnUpgrade {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("OnUpgrade");
        repr.field("rx", &self.rx);
        repr.finish()
    }
}
impl Pending {
    /// Completes the pending upgrade by sending `upgraded` to the paired [`OnUpgrade`].
    ///
    /// A failed send (no receiver left) is deliberately ignored.
    pub fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        let Self { tx } = self;
        let _ = tx.send(Ok(upgraded));
    }

    /// Does not fulfill the pending upgrade; instead sends an error to the paired
    /// [`OnUpgrade`] to signal that the upgrade is handled manually.
    ///
    /// A failed send (no receiver left) is deliberately ignored.
    pub fn manual(self) {
        trace!("pending upgrade handled manually");
        let Self { tx } = self;
        let failure = BoxError::from_static_str("OnUpgrade: manual upgrade failed");
        let _ = tx.send(Err(failure));
    }
}
#[cfg(test)]
mod tests {
    use rama_core::ServiceInput;
    use tokio_test::io::{Builder, Mock};
    use super::*;

    /// Downcasting to the wrong type hands the `Upgraded` back; the right type succeeds.
    #[test]
    fn upgraded_downcast() {
        let mock_io = ServiceInput::new(Builder::default().build());
        let upgraded = Upgraded::new(mock_io, Bytes::new());

        // Wrong concrete type: the `Upgraded` comes back as the error value.
        let returned = upgraded
            .downcast::<std::io::Cursor<Vec<u8>>>()
            .unwrap_err();

        // Correct concrete type: downcast succeeds.
        returned.downcast::<ServiceInput<Mock>>().unwrap();
    }

    /// `Upgraded::new` must tag its extensions with the `StreamTransformed` marker.
    #[test]
    fn upgraded_carries_stream_transformed_marker() {
        let mock_io = ServiceInput::new(Builder::default().build());
        let upgraded = Upgraded::new(mock_io, Bytes::new());

        let extensions = upgraded.extensions();
        let marker = extensions
            .get_ref::<StreamTransformed>()
            .expect("Upgraded::new must insert the StreamTransformed marker");

        assert_eq!(marker.by, "rama-http::Upgraded");
    }
}