use crate::{
Code, Encoding, Status,
encoding::DEFAULT_MAX_MESSAGE_SIZE,
frame::{
reader::{ReadState, poll_read_message},
writer::encode_payload,
},
metadata::Metadata,
timeout::format_grpc_timeout,
};
use bytes::Bytes;
use futures_lite::{AsyncWriteExt, future::poll_fn};
use std::{
future::Future,
marker::PhantomData,
pin::Pin,
task::Poll,
time::{Duration, Instant},
};
use trillium::{Headers, KnownHeaderName, Status as HttpStatus, Transport};
use trillium_client::{Body, Client, Conn, ConnExt, Version};
use trillium_http::Upgrade as HttpUpgrade;
use trillium_server_common::Runtime;
/// Upgraded HTTP connection over a type-erased transport; full-duplex calls
/// read and write their frames through it.
type Upgrade = HttpUpgrade<Box<dyn Transport>>;
/// State of a call whose request has not been issued yet: everything needed
/// to build the HTTP request once the call is materialized.
struct Pending {
    /// Client the request will be issued on.
    client: Client,
    /// Request path passed to `Client::post`.
    path: String,
    /// Value sent as the `content-type` request header.
    content_type: String,
    /// Custom metadata written into the request headers.
    request_metadata: Metadata,
    /// Already-framed request messages, sent as the initial request body.
    body: Vec<u8>,
    /// Set by `close_send`; `send` rejects further messages while still pending.
    send_closed: bool,
}
/// State of a call whose response head has been received and accepted.
///
/// `R` is what response frames are read from: the client `Conn` for
/// half-duplex calls, or the upgraded transport for full-duplex calls (which
/// is also written to).
struct Live<R> {
    /// Source of response frames (and, for full-duplex calls, sink for request frames).
    reader: R,
    /// Response headers as received.
    response_headers: Headers,
    /// Frame-parser state carried across polls of `poll_read_message`.
    read_state: ReadState,
    /// Encoding declared by the response `grpc-encoding` header.
    response_encoding: Encoding,
    /// `Some` when `grpc-status` arrived in the headers (a trailers-only response).
    head_status: Option<Result<(), Status>>,
}
/// Lifecycle of a call.
enum Inner {
    /// Request not issued yet; request messages are buffered.
    Pending(Pending),
    /// Half-duplex call: the whole request body was supplied up front and the
    /// response is read from the `Conn`.
    Reading(Box<Live<Conn>>),
    /// Full-duplex call over an upgraded transport that is both written and read.
    Duplex(Box<Live<Upgrade>>),
    /// Response finished; the final headers and trailers are kept for inspection.
    Done {
        response_headers: Headers,
        trailers: Headers,
    },
    /// Opening the call failed; `recv` reports this status and moves to `Done`.
    Failed(Status),
}
/// Cloneable handle that cancels the call it was obtained from.
#[derive(Clone)]
pub struct CancelHandle(async_channel::Sender<()>);

impl CancelHandle {
    /// Cancels the associated call.
    ///
    /// Closes the channel shared with the call. The call's network waits watch
    /// that channel and fail with `CANCELLED` once it is closed.
    pub fn cancel(&self) {
        let Self(sender) = self;
        sender.close();
    }
}
/// Client side of a single gRPC call.
///
/// The call starts out pending, and request messages are buffered until it is
/// opened. A half-duplex call sends the buffered body as one request. A
/// full-duplex call (`full_duplex`) uses an upgraded transport that is both
/// written and read.
pub struct GrpcClientConn<Req, Resp> {
    /// Current lifecycle state.
    inner: Inner,
    /// Deserializer for response messages.
    decode: fn(&[u8]) -> Result<Resp, Status>,
    /// Serializer for request messages.
    encode: fn(&Req) -> Result<Bytes, Status>,
    /// Encoding applied to outgoing frames and advertised via `grpc-encoding`.
    outbound_encoding: Encoding,
    /// Size limit handed to the frame reader (presumably the largest accepted
    /// response message — confirm in `poll_read_message`).
    max_message_size: usize,
    /// Whether the call is opened over an upgraded full-duplex transport.
    full_duplex: bool,
    /// Absolute deadline for the call, if any.
    deadline: Option<Instant>,
    /// Runtime used to run deadline timers.
    runtime: Runtime,
    /// Watched while waiting on the network; closing the channel cancels the call.
    cancel_rx: async_channel::Receiver<()>,
    /// Kept so `cancel_handle` can hand out senders on demand.
    cancel_tx: async_channel::Sender<()>,
    /// First error recorded while configuring the call; reported when it is opened.
    init_error: Option<Status>,
    /// Binds `Req`/`Resp` without storing values; the `fn()` form keeps the
    /// struct's auto traits independent of those types.
    _marker: PhantomData<fn() -> (Req, Resp)>,
}
impl<Req, Resp> GrpcClientConn<Req, Resp>
where
Req: Send + 'static,
Resp: Send + 'static,
{
/// Creates a call in the `Pending` state. No request is issued until the call
/// is materialized (by `open_head`, `close_send` or `recv`).
///
/// * `path`, `content_type`, `request_metadata`: used to build the request.
/// * `timeout`: converted to an absolute deadline measured from now. A timeout
///   too large for `Instant` to represent (e.g. `Duration::MAX`) yields no
///   deadline instead of panicking on overflow.
/// * `encode` / `decode`: message serializer and deserializer.
/// * `outbound_encoding`: encoding applied to every request frame.
/// * `full_duplex`: open the call over an upgraded transport instead of
///   sending the whole request body up front.
#[allow(clippy::too_many_arguments)] // crate-internal constructor taking each call parameter directly
pub(crate) fn open(
    client: &Client,
    path: &str,
    content_type: String,
    request_metadata: Metadata,
    timeout: Option<Duration>,
    encode: fn(&Req) -> Result<Bytes, Status>,
    decode: fn(&[u8]) -> Result<Resp, Status>,
    outbound_encoding: Encoding,
    full_duplex: bool,
) -> Self {
    let (cancel_tx, cancel_rx) = async_channel::bounded(1);
    Self {
        inner: Inner::Pending(Pending {
            client: client.clone(),
            path: path.to_string(),
            content_type,
            request_metadata,
            body: Vec::new(),
            send_closed: false,
        }),
        decode,
        encode,
        outbound_encoding,
        max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        full_duplex,
        // `Instant + Duration` panics on overflow. A deadline that cannot be
        // represented is effectively infinite, so treat it as "no deadline".
        deadline: timeout.and_then(|d| Instant::now().checked_add(d)),
        runtime: client.connector().runtime(),
        cancel_rx,
        cancel_tx,
        init_error: None,
        _marker: PhantomData,
    }
}
/// Adds an ASCII metadata entry to the not-yet-issued request.
///
/// An invalid entry is recorded as an `INVALID_ARGUMENT` initialization error,
/// unless an earlier one is already recorded, and is reported when the call is
/// opened. Once the call has left the `Pending` state this does nothing.
pub(crate) fn add_ascii_metadata(&mut self, key: &str, value: &str) {
    let Inner::Pending(pending) = &mut self.inner else {
        return;
    };
    if let Err(err) = pending.request_metadata.insert_ascii(key, value) {
        self.init_error
            .get_or_insert_with(|| Status::invalid_argument(format!("invalid metadata: {err}")));
    }
}
/// Adds a binary metadata entry to the not-yet-issued request.
///
/// An invalid entry is recorded as an `INVALID_ARGUMENT` initialization error,
/// unless an earlier one is already recorded, and is reported when the call is
/// opened. Once the call has left the `Pending` state, `value` is discarded.
pub(crate) fn add_binary_metadata(&mut self, key: &str, value: Vec<u8>) {
    let Inner::Pending(pending) = &mut self.inner else {
        return;
    };
    if let Err(err) = pending.request_metadata.insert_binary(key, value) {
        self.init_error
            .get_or_insert_with(|| Status::invalid_argument(format!("invalid metadata: {err}")));
    }
}
/// Replaces the call deadline with one that is `timeout` from now.
///
/// If `now + timeout` cannot be represented as an `Instant` (e.g.
/// `Duration::MAX`), the deadline is cleared instead of panicking. That is
/// equivalent to waiting forever.
pub(crate) fn set_deadline_from_now(&mut self, timeout: Duration) {
    // `Instant + Duration` panics on overflow; `checked_add` reports it as `None`.
    self.deadline = Instant::now().checked_add(timeout);
}
/// Encodes and frames `message`, then appends it to the pending request body.
///
/// An encoding failure is recorded as an initialization error (the first one
/// wins) and reported when the call is opened. Frames produced after the call
/// has left the `Pending` state are discarded.
pub(crate) fn buffer_request(&mut self, message: Req) {
    let encoded = (self.encode)(&message)
        .and_then(|payload| encode_payload(&payload, self.outbound_encoding));
    let frame = match encoded {
        Ok(frame) => frame,
        Err(status) => {
            self.init_error.get_or_insert(status);
            return;
        }
    };
    if let Inner::Pending(pending) = &mut self.inner {
        pending.body.extend_from_slice(&frame);
    }
}
/// Issues the request, if that has not happened yet, and waits for the
/// response head. Does nothing once the call has left the `Pending` state.
///
/// # Errors
/// The status that failed the call while opening it.
pub(crate) async fn open_head(&mut self) -> Result<(), Status> {
    if !matches!(self.inner, Inner::Pending(_)) {
        return Ok(());
    }
    if self.full_duplex {
        self.materialize_duplex().await
    } else {
        self.materialize_reading().await
    }
}
/// Creates a call to `path` that uses codec `C` for both request and response
/// messages.
///
/// The request content type is `application/grpc+<suffix>`, where the suffix
/// comes from the codec. Outgoing frames use the encoding named by a
/// `grpc-encoding` default header on `client` when it is recognized, and
/// identity otherwise.
pub fn new<C>(
    client: &Client,
    path: &str,
    metadata: Metadata,
    timeout: Option<Duration>,
    full_duplex: bool,
) -> Self
where
    C: crate::Codec<Req> + crate::Codec<Resp>,
{
    let suffix = <C as crate::Codec<Req>>::content_type_suffix();
    let content_type = format!("application/grpc+{suffix}");
    let outbound_encoding = match client.default_headers().get_str("grpc-encoding") {
        Some(name) => Encoding::from_grpc_encoding(name).unwrap_or(Encoding::Identity),
        None => Encoding::Identity,
    };
    let encode = <C as crate::Codec<Req>>::encode;
    let decode = decode_response::<C, Resp>;
    Self::open(
        client,
        path,
        content_type,
        metadata,
        timeout,
        encode,
        decode,
        outbound_encoding,
        full_duplex,
    )
}
/// Returns a cloneable handle that can cancel this call from elsewhere.
pub fn cancel_handle(&self) -> CancelHandle {
    let sender = self.cancel_tx.clone();
    CancelHandle(sender)
}
/// Response headers, once a response head has been received.
///
/// `None` while the call is still pending and while it sits in the failed state.
pub fn headers(&self) -> Option<&Headers> {
    let headers = match &self.inner {
        Inner::Pending(_) | Inner::Failed(_) => return None,
        Inner::Done {
            response_headers, ..
        } => response_headers,
        Inner::Reading(live) => &live.response_headers,
        Inner::Duplex(live) => &live.response_headers,
    };
    Some(headers)
}
/// Response trailers, available only after the call has finished (`None`
/// until then).
pub fn trailers(&self) -> Option<&Headers> {
    if let Inner::Done { trailers, .. } = &self.inner {
        Some(trailers)
    } else {
        None
    }
}
/// Sends one request message.
///
/// Before the call is opened, the framed message is buffered and goes out
/// with the request. On an open full-duplex call, it is written to the
/// transport straight away, racing the deadline and cancellation.
///
/// # Errors
/// * Encoding failures from the codec or framing.
/// * `INTERNAL` after `close_send` while still pending, or once the call is
///   reading a half-duplex response or has finished.
/// * `UNAVAILABLE` when the transport write fails.
/// * `DEADLINE_EXCEEDED` / `CANCELLED` if the write loses the race.
pub async fn send(&mut self, message: Req) -> Result<(), Status> {
    let payload = (self.encode)(&message)?;
    let frame = encode_payload(&payload, self.outbound_encoding)?;
    let live = match &mut self.inner {
        Inner::Pending(pending) if pending.send_closed => {
            return Err(Status::internal("send after close_send"));
        }
        Inner::Pending(pending) => {
            pending.body.extend_from_slice(&frame);
            return Ok(());
        }
        Inner::Duplex(live) => live,
        _ => return Err(Status::internal("send after response started")),
    };
    race(self.deadline, &self.runtime, &self.cancel_rx, async {
        live.reader
            .write_all(&frame)
            .await
            .map_err(|e| Status::unavailable(format!("write error: {e}")))
    })
    .await
}
/// Signals that no more request messages will be sent.
///
/// A pending half-duplex call is issued at this point with the buffered body.
/// A pending full-duplex call is opened first and then its write side is
/// closed. An already-open full-duplex call just has its write side closed.
/// In any other state this is a no-op.
///
/// # Errors
/// The status from opening the call, or `UNAVAILABLE` if closing the write
/// side fails.
pub async fn close_send(&mut self) -> Result<(), Status> {
    if let Inner::Pending(pending) = &mut self.inner {
        pending.send_closed = true;
        if !self.full_duplex {
            return self.materialize_reading().await;
        }
        self.materialize_duplex().await?;
    }
    match &mut self.inner {
        Inner::Duplex(live) => live
            .reader
            .close()
            .await
            .map_err(|e| Status::unavailable(format!("close error: {e}"))),
        _ => Ok(()),
    }
}
/// Receives the next response message.
///
/// If the call has not been issued yet, it is opened first. Returns
/// `Ok(Some(msg))` for each message and `Ok(None)` once the response finished
/// successfully. A failure to open the call is reported exactly once: by the
/// first `recv` after it happened, or by this `recv` if it triggered the open.
/// After that, `recv` returns `Ok(None)`.
pub async fn recv(&mut self) -> Result<Option<Resp>, Status> {
    if let Err(status) = self.open_head().await {
        // This call is the one reporting the failure. Retire the `Failed`
        // state left behind by opening, otherwise the next `recv` would
        // report the very same status a second time.
        self.inner = Inner::Done {
            response_headers: Headers::new(),
            trailers: Headers::new(),
        };
        return Err(status);
    }
    match &self.inner {
        Inner::Reading(_) | Inner::Duplex(_) => self.read_one().await,
        Inner::Done { .. } => Ok(None),
        Inner::Failed(_) => {
            // An earlier open_head/close_send failure that recv has not reported yet.
            let Inner::Failed(status) = std::mem::replace(
                &mut self.inner,
                Inner::Done {
                    response_headers: Headers::new(),
                    trailers: Headers::new(),
                },
            ) else {
                unreachable!()
            };
            Err(status)
        }
        Inner::Pending(_) => unreachable!("materialized above"),
    }
}
/// Reads the next response message, racing the read against the deadline and
/// cancellation.
///
/// * A message is returned as `Ok(Some(msg))`. On a trailers-only response a
///   message is a protocol violation: the call is finished and `INTERNAL` is
///   returned.
/// * End of body finishes the call from its trailers: `Ok(None)` or the
///   trailer status.
/// * A read or decode error finishes the call and is returned as-is.
/// * Deadline or cancellation errors are returned without finishing the call.
async fn read_one(&mut self) -> Result<Option<Resp>, Status> {
    let decoder = self.decode;
    let limit = self.max_message_size;
    let deadline = self.deadline;
    let runtime = self.runtime.clone();
    let cancel_rx = self.cancel_rx.clone();
    let next_message = poll_fn(|cx| match &mut self.inner {
        Inner::Reading(live) => {
            let encoding = live.response_encoding;
            let mut body = live.reader.response_body();
            poll_read_message(
                Pin::new(&mut body),
                &mut live.read_state,
                cx,
                decoder,
                encoding,
                limit,
            )
        }
        Inner::Duplex(live) => poll_read_message(
            Pin::new(&mut live.reader),
            &mut live.read_state,
            cx,
            decoder,
            live.response_encoding,
            limit,
        ),
        _ => Poll::Ready(None),
    });
    let outcome = race(deadline, &runtime, &cancel_rx, async { Ok(next_message.await) }).await?;
    match outcome {
        None => self.finish_from_trailers().map(|()| None),
        Some(Err(status)) => {
            // Finish anyway so the trailers are recorded; the read error wins.
            let _ = self.finish_from_trailers();
            Err(status)
        }
        Some(Ok(_)) if self.is_trailers_only() => {
            let _ = self.finish_from_trailers();
            Err(Status::internal(
                "trailers-only response (grpc-status in headers) carried a message body",
            ))
        }
        Some(Ok(message)) => Ok(Some(message)),
    }
}
/// Whether the live response carried `grpc-status` in its headers (a
/// trailers-only response). Always `false` outside the reading states.
fn is_trailers_only(&self) -> bool {
    let head_status = match &self.inner {
        Inner::Reading(live) => &live.head_status,
        Inner::Duplex(live) => &live.head_status,
        Inner::Pending(_) | Inner::Done { .. } | Inner::Failed(_) => return false,
    };
    head_status.is_some()
}
/// Finishes the call by moving it to `Done`, and returns the call's final
/// status.
///
/// If the received trailers contain `grpc-status`, the status comes from them.
/// Otherwise, for a trailers-only response, the status parsed from the headers
/// is used and the headers are also recorded as the trailers. With neither,
/// the status is derived from the received (status-less) trailers.
fn finish_from_trailers(&mut self) -> Result<(), Status> {
    let (response_headers, head_status, received) = match &mut self.inner {
        Inner::Reading(live) => (
            live.response_headers.clone(),
            live.head_status.clone(),
            live.reader.response_trailers().cloned().unwrap_or_default(),
        ),
        Inner::Duplex(live) => (
            live.response_headers.clone(),
            live.head_status.clone(),
            live.reader.received_trailers().cloned().unwrap_or_default(),
        ),
        Inner::Pending(_) | Inner::Done { .. } | Inner::Failed(_) => {
            (Headers::new(), None, Headers::new())
        }
    };
    let (status, trailers) = match head_status {
        Some(from_head) if received.get_str("grpc-status").is_none() => {
            (from_head, response_headers.clone())
        }
        _ => (Status::from_trailers(&received), received),
    };
    self.inner = Inner::Done {
        response_headers,
        trailers,
    };
    status
}
/// Issues the buffered request as a half-duplex call and waits for the
/// response head.
///
/// On success the call moves to `Reading`. Any failure moves it to `Failed`
/// and is returned. Failures include a recorded initialization error, a
/// transport error, the deadline or cancellation winning the race, or an
/// unacceptable response head.
async fn materialize_reading(&mut self) -> Result<(), Status> {
    if let Some(status) = self.init_error.take() {
        return self.fail(status);
    }
    let request = {
        let body = Body::from(self.take_body());
        self.build_request(body)
    };
    let response = race(self.deadline, &self.runtime, &self.cancel_rx, async {
        request.await.map_err(transport_error)
    })
    .await;
    let conn = match response {
        Ok(conn) => conn,
        Err(status) => return self.fail(status),
    };
    let Head {
        response_headers,
        response_encoding,
        head_status,
    } = match process_head(&conn) {
        Ok(head) => head,
        Err(status) => return self.fail(status),
    };
    self.inner = Inner::Reading(Box::new(Live {
        reader: conn,
        response_headers,
        read_state: ReadState::new(),
        response_encoding,
        head_status,
    }));
    Ok(())
}
/// Moves the call to `Failed` holding `status`, and returns `status` as an error.
fn fail(&mut self, status: Status) -> Result<(), Status> {
    let retained = status.clone();
    self.inner = Inner::Failed(retained);
    Err(status)
}
/// Issues the buffered request as a full-duplex call and waits for the
/// response head.
///
/// On success the call moves to `Duplex`, with the upgraded transport used
/// for both writing and reading. Any failure moves it to `Failed` and is
/// returned. Failures include a recorded initialization error, a transport
/// error, the deadline or cancellation winning the race, or an unacceptable
/// response head.
async fn materialize_duplex(&mut self) -> Result<(), Status> {
    if let Some(status) = self.init_error.take() {
        return self.fail(status);
    }
    let request = {
        let body = Body::from(self.take_body());
        self.build_request(body)
    };
    let upgraded = race(self.deadline, &self.runtime, &self.cancel_rx, async {
        request.upgrade().await.map_err(transport_error)
    })
    .await;
    let conn = match upgraded {
        Ok(conn) => conn,
        Err(status) => return self.fail(status),
    };
    let Head {
        response_headers,
        response_encoding,
        head_status,
    } = match process_head(&conn) {
        Ok(head) => head,
        Err(status) => return self.fail(status),
    };
    self.inner = Inner::Duplex(Box::new(Live {
        reader: conn.into(),
        response_headers,
        read_state: ReadState::new(),
        response_encoding,
        head_status,
    }));
    Ok(())
}
/// Moves the buffered request frames out of the `Pending` state and leaves an
/// empty buffer behind. Any other state has nothing buffered, so an empty
/// buffer is returned.
fn take_body(&mut self) -> Vec<u8> {
    if let Inner::Pending(pending) = &mut self.inner {
        std::mem::take(&mut pending.body)
    } else {
        Vec::new()
    }
}
/// Builds the HTTP/2 POST request for the pending call, with `body` as the
/// request body.
///
/// The following request headers are set first:
/// * `content-type`, `te: trailers` and `grpc-accept-encoding`;
/// * `grpc-encoding`, when frames are not identity-encoded;
/// * `grpc-timeout`, computed from the time remaining until the deadline.
///
/// The call's metadata is written after them.
///
/// # Panics
/// If the call is not in the `Pending` state.
fn build_request(&self, body: Body) -> Conn {
    let Inner::Pending(pending) = &self.inner else {
        unreachable!("build_request requires Pending");
    };
    let mut request = pending
        .client
        .post(pending.path.as_str())
        .with_http_version(Version::Http2)
        .with_request_header(KnownHeaderName::ContentType, pending.content_type.clone())
        .with_request_header(KnownHeaderName::Te, "trailers")
        .with_request_header("grpc-accept-encoding", Encoding::accepted_encodings());
    let headers = request.request_headers_mut();
    if !matches!(self.outbound_encoding, Encoding::Identity) {
        headers.insert("grpc-encoding", self.outbound_encoding.as_grpc_encoding());
    }
    if let Some(deadline) = self.deadline {
        // grpc-timeout carries the budget left at send time, not the original timeout.
        let remaining = deadline.saturating_duration_since(Instant::now());
        headers.insert("grpc-timeout", format_grpc_timeout(remaining));
    }
    pending.request_metadata.write_into(headers);
    request.with_body(body)
}
}
/// Runs `fut` to completion unless the call is cancelled or its deadline
/// passes first.
///
/// * Cancellation: the wait on `cancel_rx` ends when a message arrives or the
///   channel is closed (which is what `CancelHandle::cancel` does), and yields
///   `CANCELLED`.
/// * Deadline: a timer on `runtime` yields `DEADLINE_EXCEEDED`. If the
///   deadline has already passed, this fails immediately without polling `fut`.
///
/// On each poll, cancellation is checked first, then the timer, then `fut`.
async fn race<T, F>(
    deadline: Option<Instant>,
    runtime: &Runtime,
    cancel_rx: &async_channel::Receiver<()>,
    fut: F,
) -> Result<T, Status>
where
    F: Future<Output = Result<T, Status>>,
{
    let receiver = cancel_rx.clone();
    let cancelled = async move {
        let _ = receiver.recv().await;
        Err(Status::cancelled("call cancelled"))
    };
    let Some(deadline) = deadline else {
        return futures_lite::future::or(cancelled, fut).await;
    };
    let Some(remaining) = deadline.checked_duration_since(Instant::now()) else {
        return Err(Status::deadline_exceeded("deadline elapsed"));
    };
    let timer_runtime = runtime.clone();
    let expired = async move {
        timer_runtime.delay(remaining).await;
        Err(Status::deadline_exceeded("deadline elapsed"))
    };
    let guarded = futures_lite::future::or(cancelled, expired);
    futures_lite::future::or(guarded, fut).await
}
/// The parts of an accepted response head that the call keeps.
struct Head {
    /// All response headers.
    response_headers: Headers,
    /// Encoding declared by the response `grpc-encoding` header (identity if absent).
    response_encoding: Encoding,
    /// Status parsed from a `grpc-status` header, i.e. a trailers-only response.
    head_status: Option<Result<(), Status>>,
}
/// Validates a response head and extracts what the message reader needs.
///
/// # Errors
/// * Non-200 HTTP status: if the headers carry a non-OK `grpc-status`, that
///   status is returned, as the gRPC spec requires. Otherwise the HTTP status
///   is mapped with `http_to_grpc_status`, with a missing HTTP status mapped
///   as `0`.
/// * A content-type that is not a gRPC content-type: `UNKNOWN`.
/// * An unsupported response `grpc-encoding`: `INTERNAL`.
///
/// A `grpc-status` header on a 200 response marks a trailers-only response.
/// Its parsed status is returned in `head_status` for the reader to report.
fn process_head(conn: &Conn) -> Result<Head, Status> {
    let headers = conn.response_headers();
    let http_status = conn.status();
    if http_status != Some(HttpStatus::Ok) {
        // The HTTP-to-gRPC status mapping applies only when no grpc-status was
        // provided; an explicit grpc-status must be used. A non-error
        // grpc-status cannot make a non-200 response usable, so in that case
        // fall through to the HTTP mapping.
        if headers.get_str("grpc-status").is_some() {
            Status::from_trailers(headers)?;
        }
        return Err(http_to_grpc_status(
            http_status.map(|s| s as u16).unwrap_or(0),
        ));
    }
    let ct = headers.get_str(KnownHeaderName::ContentType);
    if ct
        .and_then(crate::server::content_type::parse_grpc_content_type)
        .is_none()
    {
        return Err(Status::unknown(format!(
            "unexpected response content-type: {ct:?}"
        )));
    }
    let response_encoding = match headers.get_str("grpc-encoding") {
        None => Encoding::Identity,
        Some(s) => Encoding::from_grpc_encoding(s)
            .ok_or_else(|| Status::internal(format!("unsupported grpc-encoding {s:?}")))?,
    };
    let response_headers = headers.clone();
    let head_status = response_headers
        .get_str("grpc-status")
        .is_some()
        .then(|| Status::from_trailers(&response_headers));
    Ok(Head {
        response_headers,
        response_encoding,
        head_status,
    })
}
/// Converts a client transport failure into an `UNAVAILABLE` status.
fn transport_error(err: trillium_client::Error) -> Status {
    let message = format!("transport error: {err}");
    Status::unavailable(message)
}
/// Decodes a response message with codec `C`.
///
/// Any codec failure is reported as `INTERNAL`, keeping the codec's message.
fn decode_response<C, Resp>(bytes: &[u8]) -> Result<Resp, Status>
where
    C: crate::Codec<Resp>,
{
    match <C as crate::Codec<Resp>>::decode(bytes) {
        Ok(message) => Ok(message),
        Err(status) => Err(Status::new(Code::Internal, status.message)),
    }
}
/// Maps an HTTP status code from a non-200 response to a gRPC status.
///
/// Follows the gRPC HTTP-to-gRPC status mapping table. The message records
/// the HTTP code.
fn http_to_grpc_status(http: u16) -> Status {
    let code = match http {
        429 | 502..=504 => Code::Unavailable,
        404 => Code::Unimplemented,
        403 => Code::PermissionDenied,
        401 => Code::Unauthenticated,
        400 => Code::Internal,
        _ => Code::Unknown,
    };
    Status::new(code, format!("HTTP {http}"))
}