use crate::{Codec, Encoding, Status, frame::reader::MessageStream};
use async_channel::{Receiver, Sender, bounded};
use futures_lite::{Stream, StreamExt};
use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
time::Instant,
};
use trillium::Transport;
use trillium_http::Upgrade as HttpUpgrade;
type Upgrade = HttpUpgrade<Box<dyn Transport>>;
use trillium_server_common::Runtime;
/// Stream of decoded response messages.
///
/// Items arrive over a channel. Each item is either a decoded `T` or a `Status` error, and
/// the stream ends once every sending side has been dropped and the buffer is drained. `C` is
/// the codec type used to decode messages; it exists only at the type level.
pub struct ResponseStream<C, T> {
    // `fn() -> C` records the codec type without owning a `C`, so `Send`/`Sync` of this
    // struct do not depend on `C`.
    _marker: PhantomData<fn() -> C>,
    rx: Pin<Box<Receiver<Result<T, Status>>>>,
}
impl<C, T> ResponseStream<C, T>
where
C: Codec<T>,
T: Send + 'static,
{
pub(crate) fn spawn(
client: &trillium_client::Client,
upgrade: Upgrade,
encoding: Encoding,
deadline: Option<Instant>,
) -> Self {
let (tx, rx) = bounded(8);
let runtime = client.connector().runtime();
let _detach = runtime
.clone()
.spawn(read_loop::<C, T>(upgrade, tx, encoding, runtime, deadline));
Self {
rx: Box::pin(rx),
_marker: PhantomData,
}
}
pub(crate) fn trailers_only(result: Result<(), Status>) -> Self {
let (tx, rx) = bounded(1);
if let Err(status) = result {
let _ = tx.try_send(Err(status));
}
Self {
rx: Box::pin(rx),
_marker: PhantomData,
}
}
}
/// Background task body: reads decoded messages from `upgrade` and forwards them to `tx`.
///
/// Each wait for a message is bounded by `deadline`, if one is given. When the wait fails,
/// the resulting status is forwarded and the task stops.
///
/// The task also stops, without looking at trailers, in two cases:
/// - a message item is an error (after forwarding it);
/// - `tx` can no longer deliver items.
///
/// Only when the message stream ends on its own are the received trailers inspected. A
/// status error derived from them is forwarded as the final item.
async fn read_loop<C, T>(
    mut upgrade: Upgrade,
    tx: Sender<Result<T, Status>>,
    encoding: Encoding,
    runtime: Runtime,
    deadline: Option<Instant>,
) where
    C: Codec<T>,
    T: Send + 'static,
{
    // `messages` mutably borrows `upgrade`. Confining it to this block ends the borrow
    // before the trailers are read below.
    let stream_exhausted = {
        let mut messages = MessageStream::<T, _>::new(&mut upgrade, <C as Codec<T>>::decode)
            .with_encoding(encoding);
        loop {
            match next_with_deadline(&runtime, deadline, messages.next()).await {
                Err(expired) => {
                    // Best effort: the receiver may already be gone.
                    let _ = tx.send(Err(expired)).await;
                    break false;
                }
                Ok(None) => break true,
                Ok(Some(item)) => {
                    let terminal = item.is_err();
                    // Stop when the receiver is gone, or right after forwarding an error item.
                    if tx.send(item).await.is_err() || terminal {
                        break false;
                    }
                }
            }
        }
    };

    if !stream_exhausted {
        return;
    }

    if let Some(trailers) = upgrade.received_trailers()
        && let Err(status) = Status::from_trailers(trailers)
    {
        let _ = tx.send(Err(status)).await;
    }
}
async fn next_with_deadline<F, T>(
runtime: &Runtime,
deadline: Option<Instant>,
fut: F,
) -> Result<Option<T>, Status>
where
F: std::future::Future<Output = Option<T>>,
{
let Some(deadline) = deadline else {
return Ok(fut.await);
};
let Some(remaining) = deadline.checked_duration_since(Instant::now()) else {
return Err(Status::deadline_exceeded("deadline elapsed"));
};
let runtime = runtime.clone();
let timer = async move {
runtime.delay(remaining).await;
Err(Status::deadline_exceeded("deadline elapsed"))
};
futures_lite::future::or(async move { Ok(fut.await) }, timer).await
}
/// Resolves `fut` unless `deadline` passes first.
///
/// If the deadline has already passed when this is called, it returns `deadline_exceeded`
/// immediately and never polls `fut`. Otherwise `fut` is raced against a runtime timer for
/// the remaining time. `futures_lite::future::or` prefers its first argument, so when both
/// are ready on the same poll, `fut`'s result wins.
pub(crate) async fn race_against_deadline<T, F>(
    runtime: &Runtime,
    deadline: Instant,
    fut: F,
) -> Result<T, Status>
where
    F: std::future::Future<Output = Result<T, Status>>,
{
    match deadline.checked_duration_since(Instant::now()) {
        None => Err(Status::deadline_exceeded("deadline elapsed")),
        Some(remaining) => {
            let timer_runtime = runtime.clone();
            let expiry = async move {
                timer_runtime.delay(remaining).await;
                Err(Status::deadline_exceeded("deadline elapsed"))
            };
            futures_lite::future::or(fut, expiry).await
        }
    }
}
impl<C, T> Stream for ResponseStream<C, T>
where
    C: Codec<T>,
    T: Send + 'static,
{
    type Item = Result<T, Status>;

    /// Delegates polling to the underlying channel receiver.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Both fields (`Pin<Box<_>>` and `PhantomData<fn() -> C>`) are `Unpin`, so the outer
        // pin can be removed safely.
        let receiver = &mut Pin::into_inner(self).rx;
        Pin::as_mut(receiver).poll_next(cx)
    }
}