use http::HeaderMap;
use http::header::CONTENT_TYPE;
use crate::codec::CodecFormat;
pub(crate) mod hdr {
    //! Protocol-specific HTTP header names used by the Connect, gRPC and
    //! gRPC-Web code paths, grouped by protocol family.

    use http::HeaderName;

    // ---- gRPC / gRPC-Web ------------------------------------------------

    /// `grpc-encoding` — returned by `Protocol::content_encoding_header` for
    /// gRPC and gRPC-Web.
    pub static GRPC_ENCODING: HeaderName = HeaderName::from_static("grpc-encoding");

    /// `grpc-accept-encoding` — returned by `Protocol::accept_encoding_header`
    /// for gRPC and gRPC-Web.
    pub static GRPC_ACCEPT_ENCODING: HeaderName =
        HeaderName::from_static("grpc-accept-encoding");

    /// `grpc-status`
    pub static GRPC_STATUS: HeaderName = HeaderName::from_static("grpc-status");

    /// `grpc-message`
    pub static GRPC_MESSAGE: HeaderName = HeaderName::from_static("grpc-message");

    /// `grpc-status-details-bin`
    pub static GRPC_STATUS_DETAILS_BIN: HeaderName =
        HeaderName::from_static("grpc-status-details-bin");

    // ---- Connect --------------------------------------------------------

    /// `connect-content-encoding` — returned by
    /// `Protocol::content_encoding_header` for Connect.
    pub static CONNECT_CONTENT_ENCODING: HeaderName =
        HeaderName::from_static("connect-content-encoding");

    /// `connect-accept-encoding` — returned by
    /// `Protocol::accept_encoding_header` for Connect.
    pub static CONNECT_ACCEPT_ENCODING: HeaderName =
        HeaderName::from_static("connect-accept-encoding");
}
/// The RPC wire protocol a request was sent with, as identified from its
/// `Content-Type` by `Protocol::detect` / `Protocol::detect_from_content_type`.
///
/// `#[non_exhaustive]`: code outside this crate must include a wildcard arm
/// when matching, so further protocols can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Protocol {
    /// Connect: `application/proto` and `application/json` (unary) or
    /// `application/connect+proto` / `application/connect+json` (streaming).
    Connect,
    /// gRPC: `application/grpc`, optionally with a `+proto` / `+json` suffix.
    Grpc,
    /// gRPC-Web: `application/grpc-web…`, including the
    /// `application/grpc-web-text…` text-mode variants.
    GrpcWeb,
}
/// Everything protocol detection learns about a request from its
/// `Content-Type`; produced by `Protocol::detect` and
/// `Protocol::detect_from_content_type`.
#[derive(Debug, Clone, Copy)]
pub struct RequestProtocol {
    /// Wire protocol the request uses.
    pub protocol: Protocol,
    /// Message serialization format (`Proto` or `Json`) named by the media type.
    pub codec_format: CodecFormat,
    /// As set by detection: `false` only for Connect unary media types
    /// (`application/proto`, `application/json`), `true` for everything else.
    pub is_streaming: bool,
    /// As set by detection: `true` only for `application/grpc-web-text`
    /// (optionally `+proto`).
    pub is_text_mode: bool,
}
impl Protocol {
    /// Detects the protocol, codec and framing of a request from its
    /// `Content-Type` header.
    ///
    /// Returns `None` in three cases:
    /// - the header is absent;
    /// - its value cannot be read as a `&str` (`HeaderValue::to_str` fails);
    /// - the media type is not recognised by [`Protocol::detect_from_content_type`].
    ///
    /// If the header is repeated, `HeaderMap::get` yields only the first value.
    pub fn detect(headers: &HeaderMap) -> Option<RequestProtocol> {
        let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
        Self::detect_from_content_type(content_type)
    }

    /// Detects the protocol, codec and framing from a raw `Content-Type` value.
    ///
    /// Media-type parameters (anything after the first `;`, e.g.
    /// `; charset=utf-8`) are ignored and surrounding whitespace is trimmed.
    /// The type and subtype are compared ASCII case-insensitively, because
    /// RFC 9110 §8.3.1 defines media-type tokens as case-insensitive.
    ///
    /// | Media type                                        | Protocol  | Codec | Streaming | Text |
    /// |---------------------------------------------------|-----------|-------|-----------|------|
    /// | `application/proto`                               | `Connect` | Proto | no        | no   |
    /// | `application/json`                                | `Connect` | Json  | no        | no   |
    /// | `application/connect+proto`                       | `Connect` | Proto | yes       | no   |
    /// | `application/connect+json`                        | `Connect` | Json  | yes       | no   |
    /// | `application/grpc`, `application/grpc+proto`      | `Grpc`    | Proto | yes       | no   |
    /// | `application/grpc+json`                           | `Grpc`    | Json  | yes       | no   |
    /// | `application/grpc-web`, `application/grpc-web+proto` | `GrpcWeb` | Proto | yes    | no   |
    /// | `application/grpc-web+json`                       | `GrpcWeb` | Json  | yes       | no   |
    /// | `application/grpc-web-text`, `application/grpc-web-text+proto` | `GrpcWeb` | Proto | yes | yes |
    ///
    /// Anything else, including `application/grpc-web-text+json`, yields `None`.
    pub fn detect_from_content_type(content_type: &str) -> Option<RequestProtocol> {
        // Keep only the bare `type/subtype`; parameters never affect detection.
        let media_type = content_type
            .split_once(';')
            .map_or(content_type, |(base, _params)| base)
            .trim();

        // Media types are case-insensitive (RFC 9110 §8.3.1). Clients nearly
        // always send lowercase, so only allocate when there is something to fold.
        let folded;
        let media_type = if media_type.bytes().any(|b| b.is_ascii_uppercase()) {
            folded = media_type.to_ascii_lowercase();
            folded.as_str()
        } else {
            media_type
        };

        // Longest prefix first: `application/grpc` is a prefix of
        // `application/grpc-web`, which is a prefix of `application/grpc-web-text`.
        // An unrecognised suffix returns `None` instead of falling through to a
        // shorter prefix.
        if let Some(rest) = media_type.strip_prefix("application/grpc-web-text") {
            // Text mode is only accepted with the proto codec.
            let codec_format = match rest {
                "" | "+proto" => CodecFormat::Proto,
                _ => return None,
            };
            return Some(RequestProtocol {
                protocol: Protocol::GrpcWeb,
                codec_format,
                is_streaming: true,
                is_text_mode: true,
            });
        }
        if let Some(rest) = media_type.strip_prefix("application/grpc-web") {
            let codec_format = Self::grpc_subtype_to_codec(rest)?;
            return Some(RequestProtocol {
                protocol: Protocol::GrpcWeb,
                codec_format,
                is_streaming: true,
                is_text_mode: false,
            });
        }
        if let Some(rest) = media_type.strip_prefix("application/grpc") {
            let codec_format = Self::grpc_subtype_to_codec(rest)?;
            return Some(RequestProtocol {
                protocol: Protocol::Grpc,
                codec_format,
                is_streaming: true,
                is_text_mode: false,
            });
        }
        if let Some(rest) = media_type.strip_prefix("application/connect+") {
            let codec_format = match rest {
                "proto" => CodecFormat::Proto,
                "json" => CodecFormat::Json,
                _ => return None,
            };
            return Some(RequestProtocol {
                protocol: Protocol::Connect,
                codec_format,
                is_streaming: true,
                is_text_mode: false,
            });
        }
        // Connect unary media types have no protocol-specific prefix.
        match media_type {
            "application/proto" => Some(RequestProtocol {
                protocol: Protocol::Connect,
                codec_format: CodecFormat::Proto,
                is_streaming: false,
                is_text_mode: false,
            }),
            "application/json" => Some(RequestProtocol {
                protocol: Protocol::Connect,
                codec_format: CodecFormat::Json,
                is_streaming: false,
                is_text_mode: false,
            }),
            _ => None,
        }
    }

    /// Maps the suffix that follows `application/grpc` or `application/grpc-web`
    /// to a codec.
    ///
    /// An empty suffix means proto. Returns `None` for any other suffix.
    fn grpc_subtype_to_codec(suffix: &str) -> Option<CodecFormat> {
        match suffix {
            "" | "+proto" => Some(CodecFormat::Proto),
            "+json" => Some(CodecFormat::Json),
            _ => None,
        }
    }

    /// Returns the `Content-Type` to send on responses for this protocol and
    /// codec.
    ///
    /// `is_streaming` only matters for Connect (`application/proto` vs
    /// `application/connect+proto`, and likewise for JSON). gRPC and gRPC-Web
    /// always use the explicit `+proto` / `+json` form.
    ///
    /// This never returns a `grpc-web-text` media type.
    /// NOTE(review): text-mode responses presumably get their content type
    /// elsewhere — confirm at call sites.
    #[inline]
    pub fn response_content_type(&self, format: CodecFormat, is_streaming: bool) -> &'static str {
        match (self, format, is_streaming) {
            (Protocol::Connect, CodecFormat::Proto, false) => "application/proto",
            (Protocol::Connect, CodecFormat::Json, false) => "application/json",
            (Protocol::Connect, CodecFormat::Proto, true) => "application/connect+proto",
            (Protocol::Connect, CodecFormat::Json, true) => "application/connect+json",
            (Protocol::Grpc, CodecFormat::Proto, _) => "application/grpc+proto",
            (Protocol::Grpc, CodecFormat::Json, _) => "application/grpc+json",
            (Protocol::GrpcWeb, CodecFormat::Proto, _) => "application/grpc-web+proto",
            (Protocol::GrpcWeb, CodecFormat::Json, _) => "application/grpc-web+json",
        }
    }

    /// Name of the timeout header for this protocol.
    ///
    /// Returns `connect-timeout-ms` for Connect and `grpc-timeout` for gRPC and
    /// gRPC-Web.
    #[inline]
    pub fn timeout_header(&self) -> &'static str {
        match self {
            Protocol::Connect => "connect-timeout-ms",
            Protocol::Grpc | Protocol::GrpcWeb => "grpc-timeout",
        }
    }

    /// Header naming the message encoding for this protocol.
    ///
    /// Returns `connect-content-encoding` for Connect and `grpc-encoding` for
    /// gRPC and gRPC-Web.
    #[inline]
    pub fn content_encoding_header(&self) -> &'static http::HeaderName {
        match self {
            Protocol::Connect => &hdr::CONNECT_CONTENT_ENCODING,
            Protocol::Grpc | Protocol::GrpcWeb => &hdr::GRPC_ENCODING,
        }
    }

    /// Header listing the encodings the peer accepts for this protocol.
    ///
    /// Returns `connect-accept-encoding` for Connect and `grpc-accept-encoding`
    /// for gRPC and gRPC-Web.
    #[inline]
    pub fn accept_encoding_header(&self) -> &'static http::HeaderName {
        match self {
            Protocol::Connect => &hdr::CONNECT_ACCEPT_ENCODING,
            Protocol::Grpc | Protocol::GrpcWeb => &hdr::GRPC_ACCEPT_ENCODING,
        }
    }

    /// `true` only for Connect.
    ///
    /// Presumably used to decide whether outcomes are signalled via HTTP
    /// status codes rather than `grpc-status` — confirm at call sites.
    #[inline]
    pub fn uses_http_status_codes(&self) -> bool {
        matches!(self, Protocol::Connect)
    }

    /// `true` only for gRPC; gRPC-Web and Connect return `false`.
    #[inline]
    pub fn requires_http2(&self) -> bool {
        matches!(self, Protocol::Grpc)
    }

    /// `true` only for gRPC; gRPC-Web and Connect return `false`.
    #[inline]
    pub fn uses_http_trailers(&self) -> bool {
        matches!(self, Protocol::Grpc)
    }
}
impl std::fmt::Display for Protocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Protocol::Connect => write!(f, "connect"),
Protocol::Grpc => write!(f, "grpc"),
Protocol::GrpcWeb => write!(f, "grpc-web"),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_connect_unary_proto() {
        let result = Protocol::detect_from_content_type("application/proto").unwrap();
        assert_eq!(result.protocol, Protocol::Connect);
        assert_eq!(result.codec_format, CodecFormat::Proto);
        assert!(!result.is_streaming);
    }

    #[test]
    fn test_detect_connect_unary_json() {
        let result = Protocol::detect_from_content_type("application/json").unwrap();
        assert_eq!(result.protocol, Protocol::Connect);
        assert_eq!(result.codec_format, CodecFormat::Json);
        assert!(!result.is_streaming);
    }

    #[test]
    fn test_detect_connect_streaming_proto() {
        let result = Protocol::detect_from_content_type("application/connect+proto").unwrap();
        assert_eq!(result.protocol, Protocol::Connect);
        assert_eq!(result.codec_format, CodecFormat::Proto);
        assert!(result.is_streaming);
    }

    #[test]
    fn test_detect_connect_streaming_json() {
        let result = Protocol::detect_from_content_type("application/connect+json").unwrap();
        assert_eq!(result.protocol, Protocol::Connect);
        assert_eq!(result.codec_format, CodecFormat::Json);
        assert!(result.is_streaming);
    }

    #[test]
    fn test_detect_grpc_default() {
        let result = Protocol::detect_from_content_type("application/grpc").unwrap();
        assert_eq!(result.protocol, Protocol::Grpc);
        assert_eq!(result.codec_format, CodecFormat::Proto);
        assert!(result.is_streaming);
    }

    #[test]
    fn test_detect_grpc_proto() {
        let result = Protocol::detect_from_content_type("application/grpc+proto").unwrap();
        assert_eq!(result.protocol, Protocol::Grpc);
        assert_eq!(result.codec_format, CodecFormat::Proto);
    }

    #[test]
    fn test_detect_grpc_json() {
        let result = Protocol::detect_from_content_type("application/grpc+json").unwrap();
        assert_eq!(result.protocol, Protocol::Grpc);
        assert_eq!(result.codec_format, CodecFormat::Json);
    }

    #[test]
    fn test_detect_grpc_web_default() {
        let result = Protocol::detect_from_content_type("application/grpc-web").unwrap();
        assert_eq!(result.protocol, Protocol::GrpcWeb);
        assert_eq!(result.codec_format, CodecFormat::Proto);
        assert!(!result.is_text_mode);
    }

    #[test]
    fn test_detect_grpc_web_proto() {
        let result = Protocol::detect_from_content_type("application/grpc-web+proto").unwrap();
        assert_eq!(result.protocol, Protocol::GrpcWeb);
        assert_eq!(result.codec_format, CodecFormat::Proto);
    }

    #[test]
    fn test_detect_grpc_web_json() {
        let result = Protocol::detect_from_content_type("application/grpc-web+json").unwrap();
        assert_eq!(result.protocol, Protocol::GrpcWeb);
        assert_eq!(result.codec_format, CodecFormat::Json);
    }

    #[test]
    fn test_detect_grpc_web_text() {
        let result = Protocol::detect_from_content_type("application/grpc-web-text").unwrap();
        assert_eq!(result.protocol, Protocol::GrpcWeb);
        assert_eq!(result.codec_format, CodecFormat::Proto);
        assert!(result.is_text_mode);
    }

    #[test]
    fn test_detect_grpc_web_text_proto() {
        let result = Protocol::detect_from_content_type("application/grpc-web-text+proto").unwrap();
        assert_eq!(result.protocol, Protocol::GrpcWeb);
        assert_eq!(result.codec_format, CodecFormat::Proto);
        assert!(result.is_text_mode);
    }

    #[test]
    fn test_detect_unknown() {
        assert!(Protocol::detect_from_content_type("text/html").is_none());
        assert!(Protocol::detect_from_content_type("application/xml").is_none());
    }

    #[test]
    fn test_detect_grpc_web_text_json_rejected() {
        assert!(Protocol::detect_from_content_type("application/grpc-web-text+json").is_none());
    }

    #[test]
    fn test_detect_with_charset_parameter() {
        let result = Protocol::detect_from_content_type("application/json; charset=utf-8").unwrap();
        assert_eq!(result.protocol, Protocol::Connect);
        assert_eq!(result.codec_format, CodecFormat::Json);
    }

    /// Whitespace around the media type and any parameters must be ignored.
    #[test]
    fn test_detect_trims_whitespace_and_parameters() {
        let result =
            Protocol::detect_from_content_type("  application/grpc-web+json ; charset=utf-8")
                .unwrap();
        assert_eq!(result.protocol, Protocol::GrpcWeb);
        assert_eq!(result.codec_format, CodecFormat::Json);
        assert!(result.is_streaming);
        assert!(!result.is_text_mode);
    }

    /// Unknown suffixes must be rejected, not matched against a shorter prefix.
    #[test]
    fn test_detect_rejects_unknown_subtypes() {
        for content_type in [
            "",
            "application/connect",
            "application/connect+xml",
            "application/grpc+xml",
            "application/grpcx",
            "application/grpc-webx",
            "application/grpc-web+thrift",
        ] {
            assert!(
                Protocol::detect_from_content_type(content_type).is_none(),
                "{:?} should not be recognised",
                content_type
            );
        }
    }

    #[test]
    fn test_detect_grpc_not_confused_with_grpc_web() {
        let result = Protocol::detect_from_content_type("application/grpc").unwrap();
        assert_eq!(result.protocol, Protocol::Grpc);
        let result = Protocol::detect_from_content_type("application/grpc-web").unwrap();
        assert_eq!(result.protocol, Protocol::GrpcWeb);
    }

    #[test]
    fn test_detect_from_header_map() {
        let mut headers = http::HeaderMap::new();
        // No Content-Type at all.
        assert!(Protocol::detect(&headers).is_none());

        headers.insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("application/grpc-web+json"),
        );
        let result = Protocol::detect(&headers).unwrap();
        assert_eq!(result.protocol, Protocol::GrpcWeb);
        assert_eq!(result.codec_format, CodecFormat::Json);

        // A header value that is not visible ASCII cannot be read as `&str`.
        headers.insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_bytes(b"application/json\xff").unwrap(),
        );
        assert!(Protocol::detect(&headers).is_none());
    }

    #[test]
    fn test_response_content_types() {
        assert_eq!(
            Protocol::Connect.response_content_type(CodecFormat::Proto, false),
            "application/proto"
        );
        assert_eq!(
            Protocol::Connect.response_content_type(CodecFormat::Json, false),
            "application/json"
        );
        assert_eq!(
            Protocol::Connect.response_content_type(CodecFormat::Proto, true),
            "application/connect+proto"
        );
        assert_eq!(
            Protocol::Connect.response_content_type(CodecFormat::Json, true),
            "application/connect+json"
        );
        // gRPC and gRPC-Web ignore `is_streaming`.
        for is_streaming in [false, true] {
            assert_eq!(
                Protocol::Grpc.response_content_type(CodecFormat::Proto, is_streaming),
                "application/grpc+proto"
            );
            assert_eq!(
                Protocol::Grpc.response_content_type(CodecFormat::Json, is_streaming),
                "application/grpc+json"
            );
            assert_eq!(
                Protocol::GrpcWeb.response_content_type(CodecFormat::Proto, is_streaming),
                "application/grpc-web+proto"
            );
            assert_eq!(
                Protocol::GrpcWeb.response_content_type(CodecFormat::Json, is_streaming),
                "application/grpc-web+json"
            );
        }
    }

    #[test]
    fn test_protocol_properties() {
        assert!(Protocol::Connect.uses_http_status_codes());
        assert!(!Protocol::Grpc.uses_http_status_codes());
        assert!(!Protocol::GrpcWeb.uses_http_status_codes());
        assert!(!Protocol::Connect.requires_http2());
        assert!(Protocol::Grpc.requires_http2());
        assert!(!Protocol::GrpcWeb.requires_http2());
        assert!(!Protocol::Connect.uses_http_trailers());
        assert!(Protocol::Grpc.uses_http_trailers());
        assert!(!Protocol::GrpcWeb.uses_http_trailers());
    }

    #[test]
    fn test_header_names() {
        assert_eq!(Protocol::Connect.timeout_header(), "connect-timeout-ms");
        assert_eq!(Protocol::Grpc.timeout_header(), "grpc-timeout");
        assert_eq!(Protocol::GrpcWeb.timeout_header(), "grpc-timeout");
        assert_eq!(
            Protocol::Connect.content_encoding_header().as_str(),
            "connect-content-encoding"
        );
        assert_eq!(
            Protocol::Grpc.content_encoding_header().as_str(),
            "grpc-encoding"
        );
        assert_eq!(
            Protocol::GrpcWeb.content_encoding_header().as_str(),
            "grpc-encoding"
        );
        assert_eq!(
            Protocol::Connect.accept_encoding_header().as_str(),
            "connect-accept-encoding"
        );
        assert_eq!(
            Protocol::Grpc.accept_encoding_header().as_str(),
            "grpc-accept-encoding"
        );
        assert_eq!(
            Protocol::GrpcWeb.accept_encoding_header().as_str(),
            "grpc-accept-encoding"
        );
    }

    #[test]
    fn test_status_header_constants() {
        assert_eq!(hdr::GRPC_STATUS.as_str(), "grpc-status");
        assert_eq!(hdr::GRPC_MESSAGE.as_str(), "grpc-message");
        assert_eq!(
            hdr::GRPC_STATUS_DETAILS_BIN.as_str(),
            "grpc-status-details-bin"
        );
    }

    #[test]
    fn test_display() {
        assert_eq!(Protocol::Connect.to_string(), "connect");
        assert_eq!(Protocol::Grpc.to_string(), "grpc");
        assert_eq!(Protocol::GrpcWeb.to_string(), "grpc-web");
    }
}