use std::collections::BTreeMap;
use std::time::Duration;
use async_trait::async_trait;
use prost::Message;
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
use url::Url;
use crate::error::{Result, ZeraError};
use crate::types::{ResolvedRpcEndpoint, RpcConfig};
/// Media type sent as both `Content-Type` and `Accept` on every gRPC-Web request
/// (binary protobuf encoding rather than the base64 text variant).
const GRPC_WEB_CONTENT_TYPE: &str = "application/grpc-web+proto";
/// Extra request header that `send_to` sets to `1` on every gRPC-Web request.
const GRPC_WEB_HEADER: &str = "x-grpc-web";
/// Raw result of one unary transport call, before any gRPC-Web decoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransportResponse {
/// HTTP status code of the response.
pub status: u16,
/// Response headers as `name -> value`. When filled by `GrpcWebTransport`, names are
/// lowercase and values rejected by `HeaderValue::to_str` are omitted; for repeated
/// headers only one value is kept.
pub headers: BTreeMap<String, String>,
/// Raw response body; for gRPC-Web this is a sequence of length-prefixed frames.
pub body: Vec<u8>,
}
impl TransportResponse {
    /// Creates a `200` response carrying `body` and no headers.
    pub fn ok(body: Vec<u8>) -> Self {
        let headers = BTreeMap::default();
        Self {
            body,
            headers,
            status: 200,
        }
    }
}
/// Byte-level transport for unary gRPC-Web calls.
///
/// Implementations receive an already-framed request body (see
/// [`frame_grpc_web_message`]). [`unary`] decodes whatever they return with
/// [`decode_grpc_web_response`], so the returned response should be the raw HTTP result.
#[async_trait]
pub trait UnaryTransport: Clone + Send + Sync {
/// Sends `framed_request` to the RPC `path` and returns the raw response.
async fn unary_bytes(&self, path: &str, framed_request: Vec<u8>) -> Result<TransportResponse>;
}
/// [`UnaryTransport`] that POSTs gRPC-Web frames over HTTP(S) using `reqwest`,
/// optionally retrying over plain HTTP when configured to.
#[derive(Clone)]
pub struct GrpcWebTransport {
/// HTTP client built with the request timeout taken from `config.timeout_ms`.
client: reqwest::Client,
/// Configuration the transport was built from (timeout and HTTP-fallback settings).
config: RpcConfig,
/// Endpoint resolved from `config` when the transport was constructed.
endpoint: ResolvedRpcEndpoint,
}
impl GrpcWebTransport {
    /// Creates a transport from `config`.
    ///
    /// Resolves the endpoint up front and builds an HTTP client whose total
    /// per-request timeout is `config.timeout_ms` milliseconds.
    ///
    /// # Errors
    ///
    /// Propagates the error from `RpcConfig::resolve_endpoint` or from building the
    /// `reqwest` client.
    pub fn new(config: RpcConfig) -> Result<Self> {
        let endpoint = config.resolve_endpoint()?;
        let client = reqwest::Client::builder()
            .timeout(Duration::from_millis(config.timeout_ms))
            .build()?;
        Ok(Self {
            client,
            config,
            endpoint,
        })
    }

    /// Returns the endpoint resolved when this transport was created.
    pub fn endpoint(&self) -> &ResolvedRpcEndpoint {
        &self.endpoint
    }

    /// Builds the absolute URL for an RPC `path`.
    ///
    /// A leading `/` is added to `path` when missing. Trailing `/`s on the base URL
    /// are dropped so a base such as `https://host/` does not yield `//` at the join.
    fn url_for(&self, path: &str) -> String {
        let base = self.endpoint.base_url.to_string();
        let base = base.trim_end_matches('/');
        if path.starts_with('/') {
            format!("{base}{path}")
        } else {
            format!("{base}/{path}")
        }
    }

    /// Computes the plain-HTTP retry URL for `original_url`, if the fallback applies.
    ///
    /// Returns `Ok(None)` when `config.fallback_to_http` is disabled, or when the URL
    /// is not `https` on the configured endpoint host. Otherwise the scheme is rewritten
    /// to `http` and the port is set to `config.fallback_port`.
    ///
    /// # Errors
    ///
    /// Returns `ZeraError::Transport` if `original_url` cannot be parsed or rewritten.
    fn fallback_url(&self, original_url: &str) -> Result<Option<String>> {
        if !self.config.fallback_to_http {
            return Ok(None);
        }
        let mut url = Url::parse(original_url).map_err(|error| {
            ZeraError::Transport(format!(
                "Unable to parse transport URL \"{original_url}\": {error}"
            ))
        })?;
        let on_endpoint_host = matches!(
            url.host_str(),
            Some(host) if Self::same_host(host, self.endpoint.hostname.as_str())
        );
        if url.scheme() != "https" || !on_endpoint_host {
            return Ok(None);
        }
        url.set_scheme("http")
            .map_err(|_| ZeraError::Transport("Unable to rewrite HTTPS URL to HTTP".to_string()))?;
        url.set_port(Some(self.config.fallback_port)).map_err(|_| {
            ZeraError::Transport(format!(
                "Unable to apply fallback port {} to transport URL",
                self.config.fallback_port
            ))
        })?;
        Ok(Some(url.to_string()))
    }

    /// Reports whether a parsed URL host and the configured hostname name the same host.
    ///
    /// `Url` lowercases domain hosts and wraps IPv6 literals in `[...]`. The configured
    /// hostname is not guaranteed to be normalized the same way, so the comparison
    /// ignores ASCII case and one pair of surrounding brackets.
    fn same_host(url_host: &str, configured_host: &str) -> bool {
        fn unbracketed(host: &str) -> &str {
            host.strip_prefix('[')
                .and_then(|inner| inner.strip_suffix(']'))
                .unwrap_or(host)
        }
        unbracketed(url_host).eq_ignore_ascii_case(unbracketed(configured_host))
    }

    /// POSTs `framed_request` to `url` with gRPC-Web headers and collects the whole
    /// response: status, normalized headers and body.
    ///
    /// # Errors
    ///
    /// Propagates `reqwest` errors from sending the request or reading the body. HTTP
    /// error statuses are not errors here; they are returned in the response.
    async fn send_to(&self, url: String, framed_request: &[u8]) -> Result<TransportResponse> {
        let mut headers = HeaderMap::new();
        headers.insert(
            CONTENT_TYPE,
            HeaderValue::from_static(GRPC_WEB_CONTENT_TYPE),
        );
        headers.insert(ACCEPT, HeaderValue::from_static(GRPC_WEB_CONTENT_TYPE));
        headers.insert(GRPC_WEB_HEADER, HeaderValue::from_static("1"));
        let response = self
            .client
            .post(url)
            .headers(headers)
            .body(framed_request.to_vec())
            .send()
            .await?;
        let status = response.status().as_u16();
        let headers = normalize_headers(response.headers());
        let body = response.bytes().await?.to_vec();
        Ok(TransportResponse {
            status,
            headers,
            body,
        })
    }
}
#[async_trait]
impl UnaryTransport for GrpcWebTransport {
    /// Sends `framed_request` to `path` and falls back to plain HTTP on failure.
    ///
    /// When the primary request fails and the HTTP fallback applies (see
    /// `fallback_url`), the request is retried once against the rewritten URL.
    ///
    /// # Errors
    ///
    /// If the primary request fails and no successful fallback happens, the primary
    /// error is returned. That includes a disabled fallback, a fallback URL that could
    /// not be built, and a fallback request that also failed. The fallback's own error
    /// is intentionally discarded.
    async fn unary_bytes(&self, path: &str, framed_request: Vec<u8>) -> Result<TransportResponse> {
        let url = self.url_for(path);
        let primary_error = match self.send_to(url.clone(), &framed_request).await {
            Ok(response) => return Ok(response),
            Err(error) => error,
        };
        // The fallback is best-effort: a failure to build its URL must not replace
        // the real transport error the caller needs to see.
        match self.fallback_url(&url) {
            Ok(Some(fallback_url)) => self
                .send_to(fallback_url, &framed_request)
                .await
                .map_err(|_| primary_error),
            Ok(None) | Err(_) => Err(primary_error),
        }
    }
}
/// Performs one unary gRPC-Web call.
///
/// Frames `request`, sends it to `path` through `transport`, and decodes the reply
/// as `Res`.
///
/// # Errors
///
/// Propagates framing errors from [`frame_grpc_web_message`], transport errors from
/// `transport.unary_bytes`, and status/decoding errors from
/// [`decode_grpc_web_response`].
pub async fn unary<Req, Res, T>(transport: &T, path: &str, request: &Req) -> Result<Res>
where
    T: UnaryTransport,
    Req: Message,
    Res: Message + Default,
{
    let response = transport
        .unary_bytes(path, frame_grpc_web_message(request)?)
        .await?;
    decode_grpc_web_response(&response)
}
pub fn frame_grpc_web_message<MessageType: Message>(message: &MessageType) -> Result<Vec<u8>> {
let payload = message.encode_to_vec();
let mut framed = Vec::with_capacity(payload.len() + 5);
framed.push(0);
framed.extend_from_slice(&(payload.len() as u32).to_be_bytes());
framed.extend_from_slice(&payload);
Ok(framed)
}
pub fn decode_grpc_web_response<MessageType: Message + Default>(
response: &TransportResponse,
) -> Result<MessageType> {
let (message_frame, trailers) = split_frames(&response.body)?;
let grpc_status = trailers
.get("grpc-status")
.or_else(|| response.headers.get("grpc-status"))
.cloned()
.unwrap_or_else(|| "0".to_string());
let grpc_message = trailers
.get("grpc-message")
.or_else(|| response.headers.get("grpc-message"))
.cloned();
if response.status >= 400 && grpc_status == "0" {
return Err(ZeraError::Rpc(format!(
"HTTP {} with no grpc-status trailer",
response.status
)));
}
if grpc_status != "0" {
let detail = grpc_message.unwrap_or_else(|| "unknown gRPC error".to_string());
return Err(ZeraError::Rpc(format!("[{}] {}", grpc_status, detail)));
}
if let Some(frame) = message_frame {
return Ok(MessageType::decode(frame.as_slice())?);
}
Ok(MessageType::default())
}
/// Output of [`split_frames`]: the data-frame payload (if any) and the merged trailers.
type GrpcWebFrameSplit = (Option<Vec<u8>>, BTreeMap<String, String>);

/// Splits a gRPC-Web response body into its data payload and trailer map.
///
/// Each frame is a 1-byte flag, a 4-byte big-endian length, and `length` payload bytes.
/// - Flag `0x00` is an uncompressed data frame; a later data frame replaces an earlier one.
/// - Flag `0x80` is an uncompressed trailer frame; its entries are merged into the map,
///   and later keys overwrite earlier ones.
///
/// An empty body yields `(None, {})`.
///
/// # Errors
///
/// Returns `ZeraError::Serialization` for a truncated frame header, a declared length
/// that runs past the end of the body, or any other flag value (including compressed
/// frames, which are not decompressed here).
fn split_frames(body: &[u8]) -> Result<GrpcWebFrameSplit> {
    let mut message_frame = None;
    let mut trailers = BTreeMap::new();
    let mut remaining = body;
    while !remaining.is_empty() {
        if remaining.len() < 5 {
            return Err(ZeraError::Serialization(
                "Malformed gRPC-Web frame header".to_string(),
            ));
        }
        let (header, rest) = remaining.split_at(5);
        let flag = header[0];
        let declared_length = u32::from_be_bytes([header[1], header[2], header[3], header[4]]);
        // Compare against the bytes actually left instead of computing `offset + length`,
        // which can overflow `usize` on 32-bit targets (e.g. wasm32) for hostile lengths.
        let length = match usize::try_from(declared_length) {
            Ok(length) if length <= rest.len() => length,
            _ => {
                return Err(ZeraError::Serialization(
                    "Malformed gRPC-Web frame length".to_string(),
                ))
            }
        };
        let (frame, after_frame) = rest.split_at(length);
        remaining = after_frame;
        match flag {
            0 => message_frame = Some(frame.to_vec()),
            0x80 => trailers.extend(parse_trailer_frame(frame)),
            _ => {
                return Err(ZeraError::Serialization(format!(
                    "Unsupported gRPC-Web frame flag: {flag}"
                )))
            }
        }
    }
    Ok((message_frame, trailers))
}
/// Parses the payload of a gRPC-Web trailer frame into a `name -> value` map.
///
/// Lines have the HTTP/1 header shape `name: value`.
/// - Lines may end in `\r\n` (as the gRPC-Web protocol specifies) or in a bare `\n`,
///   which is tolerated for lenient servers.
/// - Only the first `:` separates, so values may themselves contain `:`.
/// - Names are trimmed and lowercased; values are trimmed.
/// - Lines without a `:` or with an empty name are skipped.
/// - Invalid UTF-8 is replaced lossily.
/// - If a name repeats, the last value wins.
fn parse_trailer_frame(frame: &[u8]) -> BTreeMap<String, String> {
    let text = String::from_utf8_lossy(frame);
    text.lines()
        .filter_map(|line| {
            let (name, value) = line.split_once(':')?;
            let name = name.trim();
            (!name.is_empty()).then(|| (name.to_ascii_lowercase(), value.trim().to_string()))
        })
        .collect()
}
/// Copies a `reqwest` header map into owned `name -> value` strings.
///
/// Names are lowercased. Values that `HeaderValue::to_str` rejects (non-visible-ASCII)
/// are skipped. When a header name repeats, the last value encountered is kept.
fn normalize_headers(headers: &HeaderMap) -> BTreeMap<String, String> {
    let mut normalized = BTreeMap::new();
    for (name, raw_value) in headers {
        if let Ok(text) = raw_value.to_str() {
            normalized.insert(name.as_str().to_ascii_lowercase(), text.to_owned());
        }
    }
    normalized
}