#![cfg_attr(docsrs, doc(cfg(feature = "grpc")))]
pub mod health;
pub mod interceptor;
pub mod reflection;
pub mod web;
use std::convert::Infallible;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use std::time::Instant;
use bytes::Bytes;
use bytes::BytesMut;
/// Largest gRPC message payload (in bytes, not counting the 5-byte frame
/// header) accepted when decoding; a larger declared length is rejected with
/// `GrpcError::MessageTooLarge`. Currently 4 MiB.
pub const MAX_GRPC_MESSAGE_SIZE: usize = 4 * 1024 * 1024;
use futures_util::Stream;
use futures_util::StreamExt;
use http::HeaderMap;
use http::StatusCode;
use http_body::Frame;
use http_body_util::BodyExt;
use http_body_util::StreamBody;
use prost::Message;
use crate::body::TakoBody;
use crate::extractors::FromRequest;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// gRPC status codes.
///
/// The explicit discriminants are the numeric values written (as decimal text)
/// into the `grpc-status` header/trailer by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum GrpcStatusCode {
Ok = 0,
Cancelled = 1,
Unknown = 2,
InvalidArgument = 3,
DeadlineExceeded = 4,
NotFound = 5,
AlreadyExists = 6,
PermissionDenied = 7,
ResourceExhausted = 8,
FailedPrecondition = 9,
Aborted = 10,
OutOfRange = 11,
Unimplemented = 12,
Internal = 13,
Unavailable = 14,
DataLoss = 15,
Unauthenticated = 16,
}
/// Extractor for a unary gRPC call: the protobuf message decoded from the
/// length-prefixed frame in the request body.
pub struct GrpcRequest<T: Message + Default> {
/// The decoded request message.
pub message: T,
}
/// Failures while extracting or decoding gRPC request data.
///
/// When returned from an extractor, each variant is converted into a
/// Trailers-Only gRPC error response by its `Responder` impl.
#[derive(Debug)]
pub enum GrpcError {
/// The `content-type` header does not start with `application/grpc`.
InvalidContentType,
/// Reading the HTTP body failed; holds the underlying error's text.
BodyReadError(String),
/// The bytes do not form a complete gRPC frame (e.g. fewer than the 5 header
/// bytes, or fewer payload bytes than the header declares).
InvalidFrame,
/// The gRPC message is larger than `MAX_GRPC_MESSAGE_SIZE`.
MessageTooLarge,
/// Protobuf decoding failed; holds the decoder's error text.
DecodeError(String),
/// The frame's compressed flag is set, but no decompression is implemented.
CompressionUnsupported,
}
impl Responder for GrpcError {
    /// Renders the error as a Trailers-Only gRPC response.
    ///
    /// The inner detail strings of `BodyReadError` / `DecodeError` are not sent
    /// to the client; only a fixed, human-readable summary goes into
    /// `grpc-message`.
    fn into_response(self) -> Response {
        // Map the failure class onto a gRPC status first...
        let status = match &self {
            GrpcError::InvalidContentType | GrpcError::CompressionUnsupported => {
                GrpcStatusCode::Unimplemented
            }
            GrpcError::BodyReadError(_) => GrpcStatusCode::Internal,
            GrpcError::InvalidFrame | GrpcError::DecodeError(_) => {
                GrpcStatusCode::InvalidArgument
            }
            GrpcError::MessageTooLarge => GrpcStatusCode::ResourceExhausted,
        };
        // ...then pick the summary text reported alongside it.
        let summary = match self {
            GrpcError::InvalidContentType => "invalid content-type; expected application/grpc",
            GrpcError::BodyReadError(_) => "failed to read request body",
            GrpcError::InvalidFrame => "malformed gRPC frame",
            GrpcError::MessageTooLarge => "grpc message exceeds MAX_GRPC_MESSAGE_SIZE",
            GrpcError::DecodeError(_) => "failed to decode protobuf message",
            GrpcError::CompressionUnsupported => "frame is compressed but no codec is configured",
        };
        build_grpc_error_response(status, summary)
    }
}
impl<'a, T> FromRequest<'a> for GrpcRequest<T>
where
T: Message + Default + Send + 'static,
{
type Error = GrpcError;
fn from_request(
req: &'a mut Request,
) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
async move {
let ct = req
.headers()
.get(http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !ct.starts_with("application/grpc") {
return Err(GrpcError::InvalidContentType);
}
let body_bytes = req
.body_mut()
.collect()
.await
.map_err(|e| GrpcError::BodyReadError(e.to_string()))?
.to_bytes();
if body_bytes.len() < 5 {
return Err(GrpcError::InvalidFrame);
}
if body_bytes[0] != 0 {
return Err(GrpcError::CompressionUnsupported);
}
let msg_len =
u32::from_be_bytes([body_bytes[1], body_bytes[2], body_bytes[3], body_bytes[4]]) as usize;
if msg_len > MAX_GRPC_MESSAGE_SIZE {
return Err(GrpcError::MessageTooLarge);
}
if body_bytes.len() < 5 + msg_len {
return Err(GrpcError::InvalidFrame);
}
let message = T::decode(&body_bytes[5..5 + msg_len])
.map_err(|e| GrpcError::DecodeError(e.to_string()))?;
Ok(GrpcRequest { message })
}
}
}
/// A unary gRPC reply: either a single message (`GrpcResponse::ok`) or an
/// error status with a message (`GrpcResponse::error`).
pub struct GrpcResponse<T: Message> {
// `Some` only for replies built with `ok`.
message: Option<T>,
// Status reported to the client via `grpc-status`.
status: GrpcStatusCode,
// Text reported via `grpc-message` for error replies.
error_message: Option<String>,
}
impl<T: Message> GrpcResponse<T> {
    /// Builds a successful reply carrying `message` with status `Ok`.
    pub fn ok(message: T) -> Self {
        Self {
            status: GrpcStatusCode::Ok,
            error_message: None,
            message: Some(message),
        }
    }

    /// Builds a message-less reply reporting `status` with the given text.
    ///
    /// A non-`Ok` status is rendered as a Trailers-Only error response whose
    /// `grpc-message` is the percent-encoded `message`.
    pub fn error(status: GrpcStatusCode, message: impl Into<String>) -> Self {
        let error_message = Some(message.into());
        Self {
            message: None,
            status,
            error_message,
        }
    }
}
impl<T: Message> Responder for GrpcResponse<T> {
fn into_response(self) -> Response {
if self.status != GrpcStatusCode::Ok {
return build_grpc_error_response(self.status, self.error_message.as_deref().unwrap_or(""));
}
let body_bytes = match self.message {
Some(msg) => grpc_encode(&msg),
None => Vec::new(),
};
let mut resp = Response::new(TakoBody::from(body_bytes));
*resp.status_mut() = StatusCode::OK;
resp.headers_mut().insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/grpc"),
);
if let Ok(val) = http::HeaderValue::from_str(&(self.status as u8).to_string()) {
resp.headers_mut().insert("grpc-status", val);
}
resp
}
}
/// Encodes `msg` as one gRPC length-prefixed frame.
///
/// Frame layout: a 1-byte compressed flag (always `0` here), a 4-byte
/// big-endian payload length, then the protobuf-encoded payload. The payload
/// is encoded directly into the pre-sized frame buffer, avoiding a second
/// allocation and a full copy of the message.
///
/// # Panics
///
/// If the encoded message is longer than `u32::MAX` bytes, because the length
/// prefix could not represent it.
pub fn grpc_encode<T: Message>(msg: &T) -> Vec<u8> {
    let msg_len = msg.encoded_len();
    let Ok(prefix) = u32::try_from(msg_len) else {
        panic!(
            "grpc_encode: message of {msg_len} bytes exceeds u32::MAX (4 GiB) — gRPC length-prefix would wrap"
        );
    };
    let mut frame = Vec::with_capacity(5 + msg_len);
    frame.push(0);
    frame.extend_from_slice(&prefix.to_be_bytes());
    // prost only fails when the buffer has too little room left; a Vec grows
    // on demand (and was pre-sized anyway), so this cannot fail.
    msg.encode(&mut frame)
        .expect("Vec<u8> grows on demand, so protobuf encoding cannot run out of space");
    frame
}
/// Decodes the first gRPC length-prefixed frame in `data`.
///
/// Returns the decoded message and the frame's compressed flag. Compressed
/// frames are rejected, so the flag is always `false` on success. Bytes after
/// the first frame are ignored.
///
/// # Errors
///
/// * `InvalidFrame` if `data` is shorter than the 5-byte header or than the
///   declared payload length.
/// * `CompressionUnsupported` if the compressed flag is non-zero.
/// * `MessageTooLarge` if the declared length exceeds `MAX_GRPC_MESSAGE_SIZE`.
/// * `DecodeError` if protobuf decoding fails.
pub fn grpc_decode<T: Message + Default>(data: &[u8]) -> Result<(T, bool), GrpcError> {
    let [flag, l0, l1, l2, l3, rest @ ..] = data else {
        return Err(GrpcError::InvalidFrame);
    };
    if *flag != 0 {
        return Err(GrpcError::CompressionUnsupported);
    }
    let declared = u32::from_be_bytes([*l0, *l1, *l2, *l3]) as usize;
    if declared > MAX_GRPC_MESSAGE_SIZE {
        return Err(GrpcError::MessageTooLarge);
    }
    let payload = rest.get(..declared).ok_or(GrpcError::InvalidFrame)?;
    T::decode(payload)
        // Only uncompressed frames get this far.
        .map(|msg| (msg, false))
        .map_err(|e| GrpcError::DecodeError(e.to_string()))
}
/// Percent-encodes `s` for use as a `grpc-message` value.
///
/// Bytes in the printable ASCII range `0x20..=0x7E` pass through unchanged,
/// except `%` itself. Every other byte (including each byte of a multi-byte
/// UTF-8 character) becomes `%XX` with upper-case hex digits.
fn percent_encode_grpc_message(s: &str) -> String {
    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, byte| {
        let passthrough = matches!(byte, 0x20..=0x7E) && byte != b'%';
        if passthrough {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(hex_upper(byte >> 4));
            encoded.push(hex_upper(byte & 0x0F));
        }
        encoded
    })
}

/// Returns the upper-case hexadecimal digit for a nibble (`0..=15`).
///
/// # Panics
///
/// If `n > 15`.
#[inline]
fn hex_upper(n: u8) -> char {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    match DIGITS.get(usize::from(n)) {
        Some(&digit) => char::from(digit),
        None => unreachable!("hex_upper called with value > 15"),
    }
}
/// Builds a Trailers-Only gRPC response.
///
/// The response has HTTP status 200 and an empty body. The headers carry
/// `content-type: application/grpc`, `grpc-status`, and (for a non-empty
/// `message`) a percent-encoded `grpc-message`.
fn build_grpc_error_response(status: GrpcStatusCode, message: &str) -> Response {
    let mut response = Response::new(TakoBody::empty());
    *response.status_mut() = StatusCode::OK;

    let headers = response.headers_mut();
    headers.insert(
        http::header::CONTENT_TYPE,
        http::HeaderValue::from_static("application/grpc"),
    );
    // A status code renders as plain ASCII digits, so this conversion cannot fail.
    headers.insert("grpc-status", http::HeaderValue::from(status as u16));

    let encoded_message = Some(message)
        .filter(|m| !m.is_empty())
        .and_then(|m| http::HeaderValue::from_str(&percent_encode_grpc_message(m)).ok());
    if let Some(value) = encoded_message {
        headers.insert("grpc-message", value);
    }

    response
}
/// Server-streaming gRPC response built from a stream of messages.
///
/// Each `Ok(message)` item is sent as a length-prefixed data frame. An
/// `Err(status)` item is reported as that status in trailers.
pub struct GrpcServerStream<S, T>
where
S: Stream<Item = Result<T, GrpcStatus>> + Send + 'static,
T: Message + Send + 'static,
{
/// Source of response messages (or the terminal error status).
pub stream: S,
/// Extra headers added to the HTTP response headers.
pub initial_metadata: HeaderMap,
}
/// A gRPC status (code plus optional message) that is written as
/// `grpc-status` / `grpc-message` trailers.
#[derive(Debug, Clone)]
pub struct GrpcStatus {
/// Status code, sent as `grpc-status`.
pub code: GrpcStatusCode,
/// Optional text, percent-encoded into `grpc-message`.
pub message: Option<String>,
}
impl GrpcStatus {
pub fn ok() -> Self {
Self {
code: GrpcStatusCode::Ok,
message: None,
}
}
pub fn error(code: GrpcStatusCode, message: impl Into<String>) -> Self {
Self {
code,
message: Some(message.into()),
}
}
fn write_trailers(&self) -> HeaderMap {
let mut t = HeaderMap::new();
if let Ok(v) = http::HeaderValue::from_str(&(self.code as u8).to_string()) {
t.insert("grpc-status", v);
}
if let Some(msg) = self.message.as_deref()
&& let Ok(v) = http::HeaderValue::from_str(&percent_encode_grpc_message(msg))
{
t.insert("grpc-message", v);
}
t
}
}
impl<S, T> GrpcServerStream<S, T>
where
    S: Stream<Item = Result<T, GrpcStatus>> + Send + 'static,
    T: Message + Send + 'static,
{
    /// Wraps `stream` with no extra response headers.
    pub fn new(stream: S) -> Self {
        let initial_metadata = HeaderMap::new();
        Self {
            stream,
            initial_metadata,
        }
    }

    /// Replaces any previously set initial metadata with `headers`.
    pub fn with_metadata(self, headers: HeaderMap) -> Self {
        Self {
            initial_metadata: headers,
            ..self
        }
    }
}
impl<S, T> Responder for GrpcServerStream<S, T>
where
S: Stream<Item = Result<T, GrpcStatus>> + Send + 'static,
T: Message + Send + 'static,
{
fn into_response(self) -> Response {
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
let error_emitted = Arc::new(AtomicBool::new(false));
let mark_err = error_emitted.clone();
let stream = self.stream.map(move |item| match item {
Ok(msg) => {
let bytes = grpc_encode(&msg);
Ok::<_, Infallible>(Frame::data(Bytes::from(bytes)))
}
Err(status) => {
mark_err.store(true, Ordering::Release);
Ok(Frame::trailers(status.write_trailers()))
}
});
let check_err = error_emitted.clone();
let mut once = false;
let trailer = futures_util::stream::iter(std::iter::from_fn(move || {
if once {
None
} else {
once = true;
if check_err.load(Ordering::Acquire) {
None
} else {
Some(Ok::<_, Infallible>(Frame::trailers(
GrpcStatus::ok().write_trailers(),
)))
}
}
}));
let combined = stream.chain(trailer);
let mut resp = http::Response::builder()
.status(StatusCode::OK)
.header(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/grpc"),
)
.body(TakoBody::new(StreamBody::new(combined)))
.expect("static headers + body construction is infallible");
let headers = resp.headers_mut();
for (k, v) in &self.initial_metadata {
headers.insert(k.clone(), v.clone());
}
resp
}
}
/// Extractor for client-streaming gRPC calls: a stream of messages decoded
/// lazily from the request body's length-prefixed frames.
pub struct GrpcClientStream<T: Message + Default + Send + 'static> {
/// Decoded messages (or decoding/read errors), in arrival order.
pub stream: Pin<Box<dyn Stream<Item = Result<T, GrpcError>> + Send>>,
}
impl<'a, T> FromRequest<'a> for GrpcClientStream<T>
where
    T: Message + Default + Send + 'static,
{
    type Error = GrpcError;

    /// Checks the `content-type`, then moves the request body out of `req`
    /// (leaving `TakoBody::default()` behind) and wraps it in a frame decoder.
    /// No body data is read here; decoding happens as the stream is polled.
    ///
    /// # Errors
    ///
    /// `InvalidContentType` if `content-type` is missing, not valid visible
    /// ASCII, or does not start with `application/grpc`.
    fn from_request(
        req: &'a mut Request,
    ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
        async move {
            let is_grpc = req
                .headers()
                .get(http::header::CONTENT_TYPE)
                .is_some_and(|v| v.to_str().is_ok_and(|ct| ct.starts_with("application/grpc")));
            if !is_grpc {
                return Err(GrpcError::InvalidContentType);
            }

            let frames = GrpcFrameStream::<T>::new(std::mem::take(req.body_mut()));
            Ok(GrpcClientStream {
                stream: Box::pin(frames),
            })
        }
    }
}
/// Incremental decoder turning a request body into a stream of gRPC messages.
struct GrpcFrameStream<T> {
// Source of incoming DATA frames.
body: TakoBody,
// Received bytes not yet consumed as a complete gRPC frame.
buffer: BytesMut,
// Once set, the body is not polled again.
finished: bool,
// Ties the decoded message type to the stream without storing a `T`.
_marker: std::marker::PhantomData<fn() -> T>,
}
impl<T> GrpcFrameStream<T> {
    /// Creates a decoder over `body` with an empty reassembly buffer.
    fn new(body: TakoBody) -> Self {
        let buffer = BytesMut::new();
        Self {
            body,
            buffer,
            finished: false,
            _marker: std::marker::PhantomData,
        }
    }
}
impl<T> Stream for GrpcFrameStream<T>
where
    T: Message + Default,
{
    type Item = Result<T, GrpcError>;

    /// Yields the next decoded message, reading more body data as needed.
    ///
    /// Frame-level failures are terminal: the error is yielded once and every
    /// later poll returns `None`. These failures are a compressed flag, a
    /// declared length above `MAX_GRPC_MESSAGE_SIZE`, a body read error, and a
    /// body that ends in the middle of a frame (`InvalidFrame`).
    ///
    /// A protobuf decode failure consumes only the offending frame, so the
    /// stream can continue with the next one.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            // Try to carve one complete frame out of what is already buffered.
            if this.buffer.len() >= 5 {
                // The header alone is enough to reject a compressed frame.
                if this.buffer[0] != 0 {
                    return this.terminate(GrpcError::CompressionUnsupported);
                }
                let msg_len = u32::from_be_bytes([
                    this.buffer[1],
                    this.buffer[2],
                    this.buffer[3],
                    this.buffer[4],
                ]) as usize;
                if msg_len > MAX_GRPC_MESSAGE_SIZE {
                    return this.terminate(GrpcError::MessageTooLarge);
                }
                if this.buffer.len() >= 5 + msg_len {
                    let frame = this.buffer.split_to(5 + msg_len);
                    let decoded =
                        T::decode(&frame[5..]).map_err(|e| GrpcError::DecodeError(e.to_string()));
                    return Poll::Ready(Some(decoded));
                }
            }

            if this.finished {
                if this.buffer.is_empty() {
                    return Poll::Ready(None);
                }
                // The body ended with a partial frame still buffered: truncated input.
                return this.terminate(GrpcError::InvalidFrame);
            }

            match http_body::Body::poll_frame(Pin::new(&mut this.body), cx) {
                Poll::Ready(Some(Ok(frame))) => {
                    // Only DATA frames carry messages; trailer frames are ignored.
                    if let Some(data) = frame.data_ref() {
                        this.buffer.extend_from_slice(data);
                    }
                }
                Poll::Ready(Some(Err(e))) => {
                    return this.terminate(GrpcError::BodyReadError(e.to_string()));
                }
                Poll::Ready(None) => this.finished = true,
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

impl<T> GrpcFrameStream<T> {
    /// Enters the terminal state (buffer dropped, body no longer polled) and
    /// yields `err` as the stream's final item.
    fn terminate(&mut self, err: GrpcError) -> Poll<Option<Result<T, GrpcError>>> {
        self.buffer.clear();
        self.finished = true;
        Poll::Ready(Some(Err(err)))
    }
}
/// Extractor for bidirectional-streaming gRPC calls.
///
/// It provides only the inbound half (the client's message stream); `Resp`
/// records the response message type without storing one.
pub struct GrpcBidi<Req, Resp>
where
Req: Message + Default + Send + 'static,
Resp: Message + Send + 'static,
{
/// Decoded client messages.
pub inbound: GrpcClientStream<Req>,
/// Marker for the response message type.
pub _phantom: std::marker::PhantomData<Resp>,
}
impl<'a, Req, Resp> FromRequest<'a> for GrpcBidi<Req, Resp>
where
    Req: Message + Default + Send + 'static,
    Resp: Message + Send + 'static,
{
    type Error = GrpcError;

    /// Builds the inbound half by delegating to `GrpcClientStream`'s extractor
    /// and propagates its errors unchanged.
    fn from_request(
        req: &'a mut Request,
    ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
        async move {
            let inbound = GrpcClientStream::<Req>::from_request(req).await?;
            Ok(GrpcBidi {
                inbound,
                _phantom: std::marker::PhantomData,
            })
        }
    }
}
/// Absolute call deadline derived from the `grpc-timeout` header.
/// `read_grpc_deadline` stores it in the request extensions.
#[derive(Debug, Clone, Copy)]
pub struct GrpcDeadline(pub Instant);
/// Parses a `grpc-timeout` header value such as `"100m"` or `"5S"`.
///
/// The value is ASCII digits followed by one unit character:
///
/// | unit | meaning      |
/// |------|--------------|
/// | `n`  | nanoseconds  |
/// | `u`  | microseconds |
/// | `m`  | milliseconds |
/// | `S`  | seconds      |
/// | `M`  | minutes      |
/// | `H`  | hours        |
///
/// Surrounding whitespace is trimmed. The spec limits the value to 8 digits;
/// longer values are accepted as long as they fit in a `u64`.
///
/// Returns `None` if the input is malformed, uses an unknown unit, or
/// overflows. This function never panics, even on non-ASCII input.
pub fn parse_grpc_timeout(value: &str) -> Option<Duration> {
    let value = value.trim();
    // Peel the unit off as a `char`. A byte-index `split_at(len - 1)` would
    // panic when the final character is multi-byte.
    let mut chars = value.chars();
    let unit = chars.next_back()?;
    let digits = chars.as_str();
    // `u64::from_str` also accepts a leading `+`; the spec allows digits only.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let num: u64 = digits.parse().ok()?;
    let dur = match unit {
        'n' => Duration::from_nanos(num),
        'u' => Duration::from_micros(num),
        'm' => Duration::from_millis(num),
        'S' => Duration::from_secs(num),
        'M' => Duration::from_secs(num.checked_mul(60)?),
        'H' => Duration::from_secs(num.checked_mul(3600)?),
        _ => return None,
    };
    Some(dur)
}
pub fn read_grpc_deadline(req: &mut Request) -> Option<GrpcDeadline> {
let raw = req
.headers()
.get("grpc-timeout")
.and_then(|v| v.to_str().ok())?;
let dur = parse_grpc_timeout(raw)?;
let deadline = GrpcDeadline(Instant::now().checked_add(dur)?);
req.extensions_mut().insert(deadline);
Some(deadline)
}