use crate::{
Code, Codec, Encoding, Status,
client::{ResponseStream, response_stream::race_against_deadline},
frame::{reader::MessageStream, writer::encode_frame},
server::content_type::parse_grpc_content_type,
timeout::parse_grpc_timeout,
};
use futures_lite::{AsyncWriteExt, Stream, StreamExt};
use std::time::Instant;
use trillium::{KnownHeaderName, Transport};
use trillium_client::{Conn, ConnExt, Version};
use trillium_http::{Status as HttpStatus, Upgrade as HttpUpgrade};
/// The upgraded transport a gRPC call writes request frames to and reads response frames
/// from once the HTTP request has been sent: trillium's HTTP `Upgrade` over a boxed,
/// type-erased `Transport`.
type Upgrade = HttpUpgrade<Box<dyn Transport>>;
/// Client-side entry points for the four gRPC call shapes.
///
/// Every method has a default body, and the blanket impl at the end of this block
/// implements the trait for every `Sized + 'static` type. A codec type `C` can therefore
/// call `C::unary_call(&client, path, request)` directly. The `Codec` bounds on each
/// method decide which request/response types a given codec supports.
///
/// Each call derives an optional deadline from the client's `grpc-timeout` default header
/// and races the call against it.
// rustc's `async_fn_in_trait` lint warns that the futures returned by these methods carry
// no `Send` bound that callers can name; the lint is silenced for this trait.
#[allow(async_fn_in_trait)]
pub trait Client: Sized + 'static {
    /// Performs a unary call: one request message, one response message.
    async fn unary_call<Req, Resp>(
        client: &trillium_client::Client,
        path: &str,
        req: Req,
    ) -> Result<Resp, Status>
    where
        Self: Codec<Req> + Codec<Resp>,
        Req: Send + 'static,
        Resp: Send + 'static,
    {
        let call = unary_call_impl::<Self, Req, Resp>(client, path, req);
        with_deadline(client, deadline_from_client(client), call).await
    }

    /// Performs a server-streaming call: one request message, a stream of responses.
    ///
    /// Setting the call up is raced against the deadline. The same deadline is also passed
    /// along to the returned [`ResponseStream`] when it is spawned.
    async fn server_streaming_call<Req, Resp>(
        client: &trillium_client::Client,
        path: &str,
        req: Req,
    ) -> Result<ResponseStream<Self, Resp>, Status>
    where
        Self: Codec<Req> + Codec<Resp>,
        Req: Send + 'static,
        Resp: Send + 'static,
    {
        let deadline = deadline_from_client(client);
        let call = server_streaming_call_impl::<Self, Req, Resp>(client, path, req, deadline);
        with_deadline(client, deadline, call).await
    }

    /// Performs a client-streaming call: a stream of request messages, one response.
    async fn client_streaming_call<Req, Resp, S>(
        client: &trillium_client::Client,
        path: &str,
        requests: S,
    ) -> Result<Resp, Status>
    where
        Self: Codec<Req> + Codec<Resp>,
        Req: Send + 'static,
        Resp: Send + 'static,
        S: Stream<Item = Req> + Send + 'static,
    {
        let call = client_streaming_call_impl::<Self, Req, Resp, S>(client, path, requests);
        with_deadline(client, deadline_from_client(client), call).await
    }

    /// Performs a bidirectional-streaming call: a stream of requests, a stream of responses.
    ///
    /// Every request is written before the response stream is returned (see
    /// `bidi_call_impl`). Setup is raced against the deadline, which is also passed along
    /// to the spawned [`ResponseStream`].
    async fn bidi_call<Req, Resp, S>(
        client: &trillium_client::Client,
        path: &str,
        requests: S,
    ) -> Result<ResponseStream<Self, Resp>, Status>
    where
        Self: Codec<Req> + Codec<Resp>,
        Req: Send + 'static,
        Resp: Send + 'static,
        S: Stream<Item = Req> + Send + 'static,
    {
        let deadline = deadline_from_client(client);
        let call = bidi_call_impl::<Self, Req, Resp, S>(client, path, requests, deadline);
        with_deadline(client, deadline, call).await
    }
}

// Every eligible type receives the provided methods; the per-method `Codec` bounds are
// what actually restrict which types can make calls.
impl<T: Sized + 'static> Client for T {}
/// Body of `Client::unary_call`, without the deadline race.
///
/// The request is encoded before the request is sent, so an encoding failure returns
/// before any network activity. The call succeeds only when exactly one response message
/// arrives and the trailers then report success.
async fn unary_call_impl<C, Req, Resp>(
    client: &trillium_client::Client,
    path: &str,
    req: Req,
) -> Result<Resp, Status>
where
    C: Codec<Req> + Codec<Resp>,
    Req: Send + 'static,
    Resp: Send + 'static,
{
    let frame = encode_frame::<C, Req>(&req, outbound_encoding_from_client(client))?;
    let (mut stream, encoding_in) = open_upgrade::<C, Req>(client, path)
        .await?
        .into_streaming()?;

    write_frame(&mut stream, &frame).await?;
    close_outbound(&mut stream).await?;

    let reply = read_one_response::<C, Resp>(&mut stream, encoding_in).await?;
    finish_with_trailers(&stream).map(|()| reply)
}
/// Body of `Client::server_streaming_call`, without the setup deadline race.
///
/// The single request is encoded up front. For a normal response it is written, the
/// outbound side is closed, and a [`ResponseStream`] is spawned over the upgraded
/// transport with `deadline`. A trailers-only response becomes a stream that only
/// carries the server's status.
async fn server_streaming_call_impl<C, Req, Resp>(
    client: &trillium_client::Client,
    path: &str,
    req: Req,
    deadline: Option<Instant>,
) -> Result<ResponseStream<C, Resp>, Status>
where
    C: Codec<Req> + Codec<Resp>,
    Req: Send + 'static,
    Resp: Send + 'static,
{
    let frame = encode_frame::<C, Req>(&req, outbound_encoding_from_client(client))?;

    match open_upgrade::<C, Req>(client, path).await? {
        OpenUpgrade::TrailersOnly(result) => Ok(ResponseStream::trailers_only(result)),
        OpenUpgrade::Streaming(mut stream, encoding_in) => {
            write_frame(&mut stream, &frame).await?;
            close_outbound(&mut stream).await?;
            Ok(ResponseStream::spawn(client, stream, encoding_in, deadline))
        }
    }
}
/// Body of `Client::client_streaming_call`, without the deadline race.
///
/// Writes every request from `requests`, closes the outbound side, and then expects
/// exactly one response message followed by successful trailers. A trailers-only response
/// is an error (see `OpenUpgrade::into_streaming`), and in that case the request stream is
/// never polled.
async fn client_streaming_call_impl<C, Req, Resp, S>(
    client: &trillium_client::Client,
    path: &str,
    requests: S,
) -> Result<Resp, Status>
where
    C: Codec<Req> + Codec<Resp>,
    Req: Send + 'static,
    Resp: Send + 'static,
    S: Stream<Item = Req> + Send + 'static,
{
    let encoding_out = outbound_encoding_from_client(client);
    let opened = open_upgrade::<C, Req>(client, path).await?;
    let (mut stream, encoding_in) = opened.into_streaming()?;

    write_request_stream::<C, Req, S>(&mut stream, requests, encoding_out).await?;
    close_outbound(&mut stream).await?;

    let reply = read_one_response::<C, Resp>(&mut stream, encoding_in).await?;
    finish_with_trailers(&stream).map(|()| reply)
}
/// Body of `Client::bidi_call`, without the setup deadline race.
///
/// For a normal response, every request from `requests` is written and the outbound side
/// is closed. Only then is a [`ResponseStream`] spawned over the upgraded transport with
/// `deadline`. A trailers-only response becomes a stream that only carries the server's
/// status.
///
/// NOTE(review): because the whole request stream is drained before any response is read,
/// a request stream that waits on responses (or never ends) keeps this function from
/// completing. Only the deadline race in `Client::bidi_call` can end it in that case.
/// Confirm whether interleaved request/response traffic is meant to be supported.
async fn bidi_call_impl<C, Req, Resp, S>(
    client: &trillium_client::Client,
    path: &str,
    requests: S,
    deadline: Option<Instant>,
) -> Result<ResponseStream<C, Resp>, Status>
where
    C: Codec<Req> + Codec<Resp>,
    Req: Send + 'static,
    Resp: Send + 'static,
    S: Stream<Item = Req> + Send + 'static,
{
    let encoding_out = outbound_encoding_from_client(client);

    match open_upgrade::<C, Req>(client, path).await? {
        OpenUpgrade::TrailersOnly(result) => Ok(ResponseStream::trailers_only(result)),
        OpenUpgrade::Streaming(mut stream, encoding_in) => {
            write_request_stream::<C, Req, S>(&mut stream, requests, encoding_out).await?;
            close_outbound(&mut stream).await?;
            Ok(ResponseStream::spawn(client, stream, encoding_in, deadline))
        }
    }
}
/// Outcome of sending a gRPC request and inspecting the response headers.
enum OpenUpgrade {
    /// The response headers carried no `grpc-status`. Messages are exchanged over the
    /// upgraded transport, and response messages use the given encoding.
    Streaming(Upgrade, Encoding),
    /// The response headers already carried `grpc-status` (a trailers-only response).
    /// Holds the status parsed from those headers.
    TrailersOnly(Result<(), Status>),
}

impl OpenUpgrade {
    /// Unwraps a streaming response, for call shapes that require a response message.
    ///
    /// # Errors
    ///
    /// A trailers-only response never carries a message, so it is always an error. The
    /// error is the server's status when it reported a failure, and otherwise
    /// `Internal("response missing message body")`.
    fn into_streaming(self) -> Result<(Upgrade, Encoding), Status> {
        match self {
            Self::Streaming(stream, encoding) => Ok((stream, encoding)),
            Self::TrailersOnly(outcome) => Err(match outcome {
                Err(status) => status,
                Ok(()) => Status::internal("response missing message body"),
            }),
        }
    }
}
/// Sends the request for a gRPC call and classifies the response by its headers.
///
/// Returns `OpenUpgrade::TrailersOnly` when the response headers already contain
/// `grpc-status` (a trailers-only response with no message body). Otherwise it returns
/// `OpenUpgrade::Streaming` with the upgraded transport and the response message encoding.
///
/// # Errors
///
/// - `Unavailable` if the transport fails while sending the request (`transport_error`).
/// - The status produced by `validate_response_headers` for a non-200 HTTP status or a
///   non-gRPC content-type.
/// - `Internal` if a streaming (non-trailers-only) response advertises a `grpc-encoding`
///   that `Encoding::from_grpc_encoding` does not recognize.
async fn open_upgrade<C, Req>(
    client: &trillium_client::Client,
    path: &str,
) -> Result<OpenUpgrade, Status>
where
    C: Codec<Req>,
{
    let conn = grpc_request(client, path, <C as Codec<Req>>::content_type_suffix())
        .upgrade()
        .await
        .map_err(transport_error)?;
    validate_response_headers(&conn)?;

    // A trailers-only response carries no messages, so any `grpc-encoding` on it is
    // irrelevant. Check for it before validating the encoding, so an unrecognized
    // encoding header cannot hide the status the server actually reported.
    if conn.response_headers().get_str("grpc-status").is_some() {
        return Ok(OpenUpgrade::TrailersOnly(Status::from_trailers(
            conn.response_headers(),
        )));
    }

    let response_encoding = extract_response_encoding(&conn)?;
    let upgrade: Upgrade = conn.into();
    Ok(OpenUpgrade::Streaming(upgrade, response_encoding))
}
async fn write_frame(upgrade: &mut Upgrade, frame: &[u8]) -> Result<(), Status> {
upgrade
.write_all(frame)
.await
.map_err(|e| Status::unavailable(format!("write error: {e}")))
}
async fn close_outbound(upgrade: &mut Upgrade) -> Result<(), Status> {
upgrade
.close()
.await
.map_err(|e| Status::unavailable(format!("close error: {e}")))
}
/// Encodes and writes every message produced by `requests`, in order.
///
/// Does not close the outbound side; callers do that afterwards.
///
/// # Errors
///
/// Stops at the first encoding failure (from `encode_frame`) or write failure (from
/// `write_frame`) and returns it.
async fn write_request_stream<C, Req, S>(
    upgrade: &mut Upgrade,
    requests: S,
    outbound_encoding: Encoding,
) -> Result<(), Status>
where
    C: Codec<Req>,
    Req: Send + 'static,
    S: Stream<Item = Req> + Send + 'static,
{
    // `StreamExt::next` needs an `Unpin` stream; pinning on the heap covers any `S`.
    let mut pending = Box::pin(requests);
    loop {
        let Some(message) = pending.next().await else {
            return Ok(());
        };
        let encoded = encode_frame::<C, Req>(&message, outbound_encoding)?;
        write_frame(upgrade, &encoded).await?;
    }
}
/// Reads the single response message of a unary or client-streaming call.
///
/// After the first message, the stream is polled once more. This rejects a second message
/// and also reaches the end of the body. Reaching the end is presumably what lets
/// `received_trailers()` observe the trailers afterwards — confirm against
/// `trillium_http::Upgrade`.
///
/// # Errors
///
/// - The error the message stream yields for the first message, or for the read that
///   follows it (for example a decode or transport failure after the first message).
/// - `Internal("expected one response message, got multiple")` if a second message
///   decodes successfully.
/// - If the body holds no message at all, the result depends on the trailers:
///   - an error status in the trailers is returned as-is;
///   - missing trailers give `Internal("stream ended without grpc-status trailer")`;
///   - trailers that report success give `Internal("response missing message body")`.
async fn read_one_response<C, Resp>(
    upgrade: &mut Upgrade,
    encoding: Encoding,
) -> Result<Resp, Status>
where
    C: Codec<Resp>,
    Resp: Send + 'static,
{
    // `None` means the body ended without any message. The message stream is scoped so
    // its mutable borrow of `upgrade` ends before the trailers are inspected below.
    let outcome = {
        let mut messages =
            MessageStream::<Resp, _>::new(&mut *upgrade, <C as Codec<Resp>>::decode)
                .with_encoding(encoding);
        match messages.next().await {
            None => None,
            // The first frame already failed; report it without reading further.
            Some(Err(status)) => Some(Err(status)),
            Some(Ok(resp)) => Some(match messages.next().await {
                None => Ok(resp),
                // A failure after the first message is reported as itself, not mistaken
                // for a second message.
                Some(Err(status)) => Err(status),
                Some(Ok(_)) => Err(Status::internal(
                    "expected one response message, got multiple",
                )),
            }),
        }
    };

    match outcome {
        Some(result) => result,
        None => Err(finish_with_trailers(upgrade)
            .err()
            .unwrap_or_else(|| Status::internal("response missing message body"))),
    }
}
/// Interprets the trailers received on `upgrade` as the call's final gRPC status.
///
/// # Errors
///
/// - Whatever error `Status::from_trailers` derives from the trailers.
/// - `Internal("stream ended without grpc-status trailer")` when no trailers were
///   received.
fn finish_with_trailers(upgrade: &Upgrade) -> Result<(), Status> {
    let Some(received) = upgrade.received_trailers() else {
        return Err(Status::internal("stream ended without grpc-status trailer"));
    };
    Status::from_trailers(received)
}
/// Computes an absolute call deadline from the client's default `grpc-timeout` header.
///
/// Returns `None` in three cases:
/// - the header is absent;
/// - `parse_grpc_timeout` rejects the value;
/// - the timeout lies too far in the future to be represented as an [`Instant`].
///
/// `with_deadline` then runs the call without a deadline.
fn deadline_from_client(client: &trillium_client::Client) -> Option<Instant> {
    let header = client.default_headers().get_str("grpc-timeout")?;
    let duration = parse_grpc_timeout(header)?;
    // `Instant + Duration` panics when the result is unrepresentable. Treat such an
    // effectively-infinite timeout as "no deadline" instead of crashing the caller.
    Instant::now().checked_add(duration)
}
/// Runs `work`, racing it against `deadline` when one is set.
///
/// Without a deadline, `work` is simply awaited. With one, the race is delegated to
/// `race_against_deadline` using the runtime of the client's connector.
async fn with_deadline<T, F>(
    client: &trillium_client::Client,
    deadline: Option<Instant>,
    work: F,
) -> Result<T, Status>
where
    F: std::future::Future<Output = Result<T, Status>>,
{
    let Some(deadline) = deadline else {
        return work.await;
    };
    let runtime = client.connector().runtime();
    race_against_deadline(&runtime, deadline, work).await
}
/// Builds the POST request for a gRPC call, carrying the standard gRPC request headers.
///
/// HTTP/3 is used only when the client's default headers contain
/// `x-trillium-grpc-version: h3`; every other case uses HTTP/2. The content-type is
/// `application/grpc+{suffix}`, where `suffix` comes from the codec.
///
/// NOTE(review): the version selector lives in the default headers, which may also be sent
/// to the server — confirm that is acceptable.
fn grpc_request(client: &trillium_client::Client, path: &str, suffix: &'static str) -> Conn {
    let wants_h3 = client.default_headers().get_str("x-trillium-grpc-version") == Some("h3");
    let version = if wants_h3 {
        Version::Http3
    } else {
        Version::Http2
    };
    let content_type = format!("application/grpc+{suffix}");

    client
        .post(path)
        .with_http_version(version)
        .with_request_header(KnownHeaderName::ContentType, content_type)
        .with_request_header(KnownHeaderName::Te, "trailers")
        .with_request_header("grpc-accept-encoding", Encoding::accepted_encodings())
}
/// Reads the `grpc-encoding` response header to learn how response messages are
/// compressed. A missing header means `Encoding::Identity`.
///
/// # Errors
///
/// `Internal` if the header names an encoding that `Encoding::from_grpc_encoding` does
/// not recognize.
fn extract_response_encoding(conn: &Conn) -> Result<Encoding, Status> {
    let Some(name) = conn.response_headers().get_str("grpc-encoding") else {
        return Ok(Encoding::Identity);
    };
    Encoding::from_grpc_encoding(name).ok_or_else(|| {
        Status::internal(format!("server returned unsupported grpc-encoding {name:?}"))
    })
}
/// Chooses the encoding for outgoing request messages from the client's default
/// `grpc-encoding` header.
///
/// A missing or unrecognized value falls back to `Encoding::Identity`.
fn outbound_encoding_from_client(client: &trillium_client::Client) -> Encoding {
    let configured = client.default_headers().get_str("grpc-encoding");
    match configured.and_then(Encoding::from_grpc_encoding) {
        Some(encoding) => encoding,
        None => Encoding::Identity,
    }
}
/// Checks that the response is HTTP 200 with a content-type `parse_grpc_content_type`
/// accepts.
///
/// # Errors
///
/// - A non-200 (or missing) HTTP status is mapped through `http_to_grpc_status`. A
///   missing status is passed to it as `0`.
/// - `Internal` when the content-type is absent or not a gRPC content-type.
fn validate_response_headers(conn: &Conn) -> Result<(), Status> {
    match conn.status() {
        Some(HttpStatus::Ok) => {}
        other => return Err(http_to_grpc_status(other.map_or(0, |s| s as u16))),
    }

    let content_type = conn
        .response_headers()
        .get_str(KnownHeaderName::ContentType);
    match content_type.and_then(parse_grpc_content_type) {
        Some(_) => Ok(()),
        None => Err(Status::internal(format!(
            "unexpected response content-type: {content_type:?}"
        ))),
    }
}
/// Maps an HTTP status received in place of a gRPC response to a gRPC status.
///
/// The mapping follows gRPC's HTTP-to-gRPC status table, and anything unlisted becomes
/// `Unknown`. The status message is `HTTP <code>`.
fn http_to_grpc_status(http: u16) -> Status {
    let code = match http {
        401 => Code::Unauthenticated,
        403 => Code::PermissionDenied,
        404 => Code::Unimplemented,
        400 => Code::Internal,
        429 | 502..=504 => Code::Unavailable,
        _ => Code::Unknown,
    };
    let message = format!("HTTP {http}");
    Status::new(code, message)
}
/// Converts a `trillium_client` error, raised while sending the request, into an
/// `Unavailable` status.
fn transport_error(err: trillium_client::Error) -> Status {
    let message = format!("transport error: {err}");
    Status::unavailable(message)
}