use crate::{
Encoding, Status,
server::{dispatch::Cancellation, streaming::Channel},
};
use bytes::Bytes;
use std::{future::Future, pin::Pin, time::Instant};
use sync_wrapper::SyncWrapper;
use trillium::{Headers, Upgrade};
/// A handler that services one call by reading and writing through a [`Channel`].
///
/// `Req` and `Resp` are the message types the channel is parameterized over.
/// Implementors must be `Send + 'static` so they can be boxed and stored until the
/// connection upgrade is driven.
pub trait BidiResponder<Req, Resp>: Send + 'static {
/// Consumes the responder and runs it against `channel`.
///
/// The returned future must be `Send`. When driven by this module, an `Ok(())`
/// result is reported as `Status::ok()` in the trailers and an `Err(status)` is
/// reported as that status.
fn respond(
self,
channel: Channel<'_, Req, Resp>,
) -> impl Future<Output = Result<(), Status>> + Send;
}
/// Object-safe counterpart of [`BidiResponder`].
///
/// `BidiResponder::respond` returns `impl Future`, so the trait cannot be used as
/// `dyn BidiResponder`. This trait returns a boxed future instead, which lets a
/// responder be stored as `Box<dyn ErasedResponder<Req, Resp>>`.
trait ErasedResponder<Req, Resp>: Send {
/// Consumes the boxed responder and returns its response future, boxed and pinned.
///
/// The future may borrow from `channel` for `'a`; `Req` and `Resp` must outlive
/// `'a` for the resulting trait object to be well-formed.
fn respond_boxed<'a>(
self: Box<Self>,
channel: Channel<'a, Req, Resp>,
) -> Pin<Box<dyn Future<Output = Result<(), Status>> + Send + 'a>>
where
Req: 'a,
Resp: 'a;
}
/// Blanket implementation: every [`BidiResponder`] can be type-erased.
impl<Req, Resp, R> ErasedResponder<Req, Resp> for R
where
    R: BidiResponder<Req, Resp>,
{
    /// Unboxes the responder and forwards to [`BidiResponder::respond`], boxing the
    /// resulting future.
    fn respond_boxed<'a>(
        self: Box<Self>,
        channel: Channel<'a, Req, Resp>,
    ) -> Pin<Box<dyn Future<Output = Result<(), Status>> + Send + 'a>>
    where
        Req: 'a,
        Resp: 'a,
    {
        // `respond` is invoked inside the async block, so it does not run until the
        // boxed future is first polled.
        let boxed_future = async move {
            let responder: R = *self;
            responder.respond(channel).await
        };
        Box::pin(boxed_future)
    }
}
/// Type-erased entry point for running a stored call against an [`Upgrade`].
///
/// Erasing the `Req`/`Resp` type parameters lets [`BidiUpgrade`] hold any
/// `BidiState<Req, Resp>` behind a single `Box<dyn BidiDriver>`.
trait BidiDriver: Send {
/// Consumes the driver and the upgraded connection, returning a `'static`, `Send`
/// future that runs the call to completion.
fn drive(self: Box<Self>, upgrade: Upgrade) -> Pin<Box<dyn Future<Output = ()> + Send>>;
}
/// Everything needed to run one call once the connection has been upgraded.
///
/// The fields are handed to `Channel::new` and `Cancellation::for_upgrade` by the
/// [`BidiDriver`] implementation for this type.
struct BidiState<Req, Resp> {
// The type-erased user handler that will be run against the channel.
responder: Box<dyn ErasedResponder<Req, Resp>>,
// Headers lent to the channel while the responder runs; afterwards the final
// status is written into them and they are sent as the trailers.
base_trailers: Headers,
// Turns a byte slice into a `Req`, or fails with a `Status`.
decode: fn(&[u8]) -> Result<Req, Status>,
// Turns a `Resp` into `Bytes`, or fails with a `Status`.
encode: fn(&Resp) -> Result<Bytes, Status>,
// Passed through to `Channel::new`; presumably the encoding of inbound
// messages — NOTE(review): confirm against `Channel`.
request_encoding: Encoding,
// Passed through to `Channel::new`; presumably the encoding of outbound
// messages — NOTE(review): confirm against `Channel`.
response_encoding: Encoding,
// Optional deadline handed to `Cancellation::for_upgrade`.
deadline: Option<Instant>,
}
/// Runs the stored responder over the upgraded connection, then reports the outcome
/// in the trailers.
impl<Req, Resp> BidiDriver for BidiState<Req, Resp>
where
    Req: Send + 'static,
    Resp: Send + 'static,
{
    /// Races the responder against cancellation, writes the final status into the
    /// trailers, and sends them. A failure to send the trailers is logged, not
    /// propagated.
    fn drive(self: Box<Self>, mut upgrade: Upgrade) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        Box::pin(async move {
            let BidiState {
                responder: handler,
                base_trailers: mut trailers,
                decode: decode_request,
                encode: encode_response,
                request_encoding: inbound_encoding,
                response_encoding: outbound_encoding,
                deadline: call_deadline,
            } = *self;

            let cancellation = Cancellation::for_upgrade(&upgrade, call_deadline);

            // The channel borrows both the upgrade and the trailers, so it is built
            // inside the raced future; those borrows end when the race completes.
            let serve = async {
                let channel = Channel::new(
                    &mut upgrade,
                    &mut trailers,
                    decode_request,
                    encode_response,
                    inbound_encoding,
                    outbound_encoding,
                );
                handler.respond_boxed(channel).await
            };
            let outcome = cancellation.race(serve).await;

            // Record the final status: the error if there was one, otherwise OK.
            if let Err(status) = outcome {
                status.write_into(&mut trailers);
            } else {
                Status::ok().write_into(&mut trailers);
            }

            if let Err(e) = upgrade.send_trailers(trailers).await {
                log::warn!("trillium-grpc: send_trailers failed: {e}");
            }
        })
    }
}
/// A pending call, stored in the connection state until the upgrade is driven.
///
/// `BidiDriver` is only required to be `Send`; the `SyncWrapper` makes the boxed
/// driver `Sync` as well — presumably because the state container requires it
/// (NOTE(review): confirm against the trillium state-set bounds).
pub(crate) struct BidiUpgrade(SyncWrapper<Box<dyn BidiDriver>>);
impl BidiUpgrade {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new<Req, Resp, R>(
responder: R,
base_trailers: Headers,
decode: fn(&[u8]) -> Result<Req, Status>,
encode: fn(&Resp) -> Result<Bytes, Status>,
request_encoding: Encoding,
response_encoding: Encoding,
deadline: Option<Instant>,
) -> Self
where
R: BidiResponder<Req, Resp>,
Req: Send + 'static,
Resp: Send + 'static,
{
let state = BidiState {
responder: Box::new(responder),
base_trailers,
decode,
encode,
request_encoding,
response_encoding,
deadline,
};
Self(SyncWrapper::new(Box::new(state)))
}
}
/// Reports whether the state of `upgrade` currently holds a pending `BidiUpgrade`.
pub fn has_bidi_upgrade(upgrade: &Upgrade) -> bool {
    let pending = upgrade.state().get::<BidiUpgrade>();
    pending.is_some()
}
pub async fn drive_bidi_upgrade(mut upgrade: Upgrade) {
let Some(state) = upgrade.state_mut().take::<BidiUpgrade>() else {
return;
};
state.0.into_inner().drive(upgrade).await;
}