use bytes::{BufMut, Bytes, BytesMut};
use http::{
header::{CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING},
HeaderMap,
};
use pingora_error::{ErrorType::ReadError, OrErr, Result};
use pingora_http::{RequestHeader, ResponseHeader};
/// Per-request state of the gRPC-web <-> gRPC bridging module.
///
/// On the happy path the state advances `Init -> Upgrade -> Trailers -> Done`
/// as the request headers, response headers and response trailers are
/// converted. A non gRPC-web request leaves the context in `Init`; a non-gRPC
/// response, or trailers arriving outside of `Trailers`, move it to `Disabled`.
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)]
pub enum GrpcWebCtx {
    /// The module is not active for this request (the default).
    #[default]
    Disabled,
    /// The module was enabled; the request headers have not been inspected yet.
    Init,
    /// The request was rewritten from gRPC-web to gRPC; awaiting response headers.
    Upgrade,
    /// The response headers were rewritten to gRPC-web; awaiting upstream trailers.
    Trailers,
    /// The upstream trailers have been encoded into a gRPC-web trailer frame.
    Done,
}

/// Content-type prefix used by native gRPC.
const GRPC: &str = "application/grpc";
/// Content-type prefix used by gRPC-web.
const GRPC_WEB: &str = "application/grpc-web";
impl GrpcWebCtx {
    /// Enable the module for the current request.
    ///
    /// Moves the context into [`GrpcWebCtx::Init`], which is the only state in
    /// which [`request_header_filter`](Self::request_header_filter) acts.
    pub fn init(&mut self) {
        *self = Self::Init;
    }

    /// Rewrite a gRPC-web request into a gRPC request.
    ///
    /// Does nothing unless the context is in [`GrpcWebCtx::Init`] and the
    /// request `Content-Type` starts, ASCII case-insensitively, with
    /// `application/grpc-web`. When both hold, it:
    /// - lowercases the content type and replaces its leading
    ///   `application/grpc-web` with `application/grpc`,
    /// - sets `te: trailers`,
    /// - calls `set_send_end_stream(false)` on the request header,
    /// - advances the context to [`GrpcWebCtx::Upgrade`].
    ///
    /// # Panics
    /// Panics if inserting either header fails. The inserted values are static
    /// or derived from an existing, already valid header value, so this is not
    /// expected to happen.
    pub fn request_header_filter(&mut self, req: &mut RequestHeader) {
        if *self != Self::Init {
            // module not enabled for this request
            return;
        }

        // `to_str` only succeeds for visible-ASCII values; anything else is
        // treated as "no content type".
        let content_type = req
            .headers
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();

        // Case-insensitive prefix check done on bytes, so slicing can never hit
        // a char boundary.
        let is_grpc_web = content_type.len() >= GRPC_WEB.len()
            && content_type.as_bytes()[..GRPC_WEB.len()]
                .eq_ignore_ascii_case(GRPC_WEB.as_bytes());
        if !is_grpc_web {
            // not grpc-web
            return;
        }

        // Rewrite only the leading media type. After lowercasing, the prefix
        // check guarantees the first match sits at index 0; any later
        // occurrence (e.g. inside a parameter) must be left untouched.
        let ct = content_type
            .to_ascii_lowercase()
            .replacen(GRPC_WEB, GRPC, 1);
        req.insert_header(CONTENT_TYPE, ct).expect("insert header");

        // The 'te' request header is used to detect incompatible proxies,
        // which are supposed to remove 'te' if it is unsupported. This header
        // is required by the gRPC over HTTP/2 protocol.
        req.insert_header("te", "trailers").expect("insert header");

        // gRPC signals end-of-stream with END_STREAM on the last DATA frame.
        // For requests without a body, make sure END_STREAM is not sent on the
        // HEADERS frame.
        req.set_send_end_stream(false);

        *self = Self::Upgrade
    }

    /// Rewrite a gRPC response into a gRPC-web response.
    ///
    /// Acts only in [`GrpcWebCtx::Upgrade`]:
    /// - Informational (1xx) responses pass through unchanged, and the state is
    ///   kept.
    /// - A response whose `Content-Type` does not start with `application/grpc`
    ///   disables the module.
    /// - Otherwise the leading `application/grpc` is replaced with
    ///   `application/grpc-web`, `Content-Length` is removed,
    ///   `Transfer-Encoding: chunked` is set (a trailer frame will be appended
    ///   to the body), and the state becomes [`GrpcWebCtx::Trailers`].
    ///
    /// # Panics
    /// Panics if inserting a header fails. This is not expected, since the
    /// values are static or derived from an existing valid header value.
    pub fn response_header_filter(&mut self, resp: &mut ResponseHeader) {
        if *self != Self::Upgrade {
            // not an upgrade
            return;
        }

        if resp.status.is_informational() {
            // proxy informational statuses through
            return;
        }

        let content_type = resp
            .headers
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();

        // The upstream speaks h2, where header values are expected to be
        // lowercase already, so the case is not normalized here.
        if !content_type.starts_with(GRPC) {
            // not grpc
            *self = Self::Disabled;
            return;
        }

        // Rewrite only the leading media type (the first match is guaranteed
        // to be the prefix checked above).
        let ct = content_type.replacen(GRPC, GRPC_WEB, 1);
        resp.insert_header(CONTENT_TYPE, ct).expect("insert header");

        // Always use chunked encoding for grpc-web: the body grows by the
        // trailer frame, so any upstream Content-Length would be wrong.
        resp.remove_header(&CONTENT_LENGTH);
        resp.insert_header(TRANSFER_ENCODING, "chunked")
            .expect("insert header");

        *self = Self::Trailers
    }

    /// Convert gRPC trailers into a gRPC-web trailer frame.
    ///
    /// gRPC-web carries trailers inside the response body, so the encoded frame
    /// is returned for the caller to append to the body. Frame layout:
    ///
    /// ```text
    /// | 0x80 (1 byte) | length (u32, big-endian) | "key:value\r\n" ... |
    /// ```
    ///
    /// Returns `Ok(None)` and switches to [`GrpcWebCtx::Disabled`] when the
    /// context is not in [`GrpcWebCtx::Trailers`]. On success it moves to
    /// [`GrpcWebCtx::Done`].
    ///
    /// # Errors
    /// Returns a `ReadError` if the encoded trailers do not fit in a `u32`
    /// length.
    pub fn response_trailer_filter(
        &mut self,
        resp_trailers: &mut HeaderMap,
    ) -> Result<Option<Bytes>> {
        // grpc-web trailer frame indicator byte
        const GRPC_WEB_TRAILER: u8 = 0x80;
        // indicator byte + u32 length
        const GRPC_TRAILER_HEADER_LEN: usize = 5;
        // initial capacity estimate; the buffer grows if needed
        const DEFAULT_TRAILER_BUFFER_SIZE: usize = 256;

        if *self != Self::Trailers {
            // not an upgrade
            *self = Self::Disabled;
            return Ok(None);
        }

        // Reserve room for the frame header in `buf` and encode the trailer
        // block into the split-off tail so it can be length-prefixed afterwards.
        let mut buf = BytesMut::with_capacity(DEFAULT_TRAILER_BUFFER_SIZE);
        let mut trailers = buf.split_off(GRPC_TRAILER_HEADER_LEN);

        // One "key:value\r\n" line per trailer value.
        for (key, value) in resp_trailers.iter() {
            trailers.put_slice(key.as_ref());
            trailers.put_slice(b":");
            trailers.put_slice(value.as_ref());
            trailers.put_slice(b"\r\n");
        }

        // The frame length field is a u32.
        let len: u32 = trailers.len().try_into().or_err_with(ReadError, || {
            format!("invalid gRPC trailer length: {}", trailers.len())
        })?;
        buf.put_u8(GRPC_WEB_TRAILER);
        buf.put_u32(len);
        buf.unsplit(trailers);

        *self = Self::Done;
        Ok(Some(buf.freeze()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::{request::Request, response::Response, Version};

    #[test]
    fn non_grpc_web_request_ignored() {
        let request = Request::get("https://pingora.dev/")
            .header(CONTENT_TYPE, "application/grpc-we")
            .version(Version::HTTP_2)
            .body(())
            .unwrap();
        let mut request = request.into_parts().0.into();

        let mut filter = GrpcWebCtx::default();
        filter.init();
        filter.request_header_filter(&mut request);
        assert_eq!(filter, GrpcWebCtx::Init);

        let headers = &request.headers;
        assert_eq!(headers.get("te"), None);
        // the content type of a non grpc-web request must be left untouched
        assert_eq!(headers.get(CONTENT_TYPE).unwrap(), "application/grpc-we");
        assert_eq!(request.send_end_stream(), Some(true));
    }

    #[test]
    fn grpc_web_request_module_disabled_ignored() {
        let request = Request::get("https://pingora.dev/")
            .header(CONTENT_TYPE, "application/grpc-web")
            .version(Version::HTTP_2)
            .body(())
            .unwrap();
        let mut request = request.into_parts().0.into();

        // init() not called: module stays disabled
        let mut filter = GrpcWebCtx::default();
        filter.request_header_filter(&mut request);
        assert_eq!(filter, GrpcWebCtx::Disabled);

        let headers = &request.headers;
        assert_eq!(headers.get("te"), None);
        assert_eq!(headers.get(CONTENT_TYPE).unwrap(), "application/grpc-web");
        assert_eq!(request.send_end_stream(), Some(true));
    }

    #[test]
    fn grpc_web_request_upgrade() {
        let request = Request::get("https://pingora.org/")
            .header(CONTENT_TYPE, "application/gRPC-web+thrift")
            .version(Version::HTTP_2)
            .body(())
            .unwrap();
        let mut request = request.into_parts().0.into();

        let mut filter = GrpcWebCtx::default();
        filter.init();
        filter.request_header_filter(&mut request);
        assert_eq!(filter, GrpcWebCtx::Upgrade);

        let headers = &request.headers;
        assert_eq!(headers.get("te").unwrap(), "trailers");
        assert_eq!(
            headers.get(CONTENT_TYPE).unwrap(),
            "application/grpc+thrift"
        );
        assert_eq!(request.send_end_stream(), Some(false));
    }

    #[test]
    fn non_grpc_response_ignored() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "text/html")
            .header(CONTENT_LENGTH, "10")
            .body(())
            .unwrap();
        let mut response = response.into_parts().0.into();

        let mut filter = GrpcWebCtx::Upgrade;
        filter.response_header_filter(&mut response);
        assert_eq!(filter, GrpcWebCtx::Disabled);

        let headers = &response.headers;
        assert_eq!(headers.get(CONTENT_TYPE).unwrap(), "text/html");
        assert_eq!(headers.get(CONTENT_LENGTH).unwrap(), "10");
    }

    #[test]
    fn grpc_response_module_disabled_ignored() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/grpc")
            .body(())
            .unwrap();
        let mut response = response.into_parts().0.into();

        let mut filter = GrpcWebCtx::default();
        filter.response_header_filter(&mut response);
        assert_eq!(filter, GrpcWebCtx::Disabled);

        let headers = &response.headers;
        assert_eq!(headers.get(CONTENT_TYPE).unwrap(), "application/grpc");
    }

    #[test]
    fn grpc_response_upgrade() {
        let response = Response::builder()
            .header(CONTENT_TYPE, "application/grpc+proto")
            .header(CONTENT_LENGTH, "0")
            .body(())
            .unwrap();
        let mut response = response.into_parts().0.into();

        let mut filter = GrpcWebCtx::Upgrade;
        filter.response_header_filter(&mut response);
        assert_eq!(filter, GrpcWebCtx::Trailers);

        let headers = &response.headers;
        assert_eq!(
            headers.get(CONTENT_TYPE).unwrap(),
            "application/grpc-web+proto"
        );
        assert_eq!(headers.get(TRANSFER_ENCODING).unwrap(), "chunked");
        assert!(headers.get(CONTENT_LENGTH).is_none());
    }

    #[test]
    fn grpc_response_informational_proxied() {
        let response = Response::builder().status(100).body(()).unwrap();
        let mut response = response.into_parts().0.into();

        let mut filter = GrpcWebCtx::Upgrade;
        filter.response_header_filter(&mut response);
        // 1xx responses do not advance the state
        assert_eq!(filter, GrpcWebCtx::Upgrade);
    }

    #[test]
    fn grpc_response_trailer_headers_convert_to_byte_buf() {
        let mut response = Response::builder()
            .header("grpc-status", "0")
            .header("grpc-message", "OK")
            .body(())
            .unwrap();
        let response = response.headers_mut();

        let mut filter = GrpcWebCtx::Trailers;
        let buf = filter.response_trailer_filter(response).unwrap().unwrap();
        assert_eq!(filter, GrpcWebCtx::Done);

        let expected = b"grpc-status:0\r\ngrpc-message:OK\r\n";
        let expected_len = expected.len() as u32;
        assert_eq!(0x80, buf[0]);
        assert_eq!(expected_len.to_be_bytes(), buf[1..5]);
        assert_eq!(&expected[..], &buf[5..]);
    }

    #[test]
    fn trailers_ignored_when_not_upgraded() {
        let mut trailers = HeaderMap::new();
        trailers.insert("grpc-status", "0".parse().unwrap());

        let mut filter = GrpcWebCtx::Upgrade;
        let out = filter.response_trailer_filter(&mut trailers).unwrap();
        assert!(out.is_none());
        assert_eq!(filter, GrpcWebCtx::Disabled);
    }

    #[test]
    fn empty_trailers_encode_zero_length_frame() {
        let mut trailers = HeaderMap::new();

        let mut filter = GrpcWebCtx::Trailers;
        let buf = filter
            .response_trailer_filter(&mut trailers)
            .unwrap()
            .unwrap();
        assert_eq!(filter, GrpcWebCtx::Done);
        assert_eq!(&buf[..], &[0x80u8, 0, 0, 0, 0][..]);
    }
}