use std::io::{Read as _, Write as _};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::extract::State;
use axum::response::IntoResponse;
use base64::Engine;
use bytes::Bytes;
use folk_api::Executor;
use http::{HeaderMap, HeaderValue, Response};
use http_body::Frame;
use http_body_util::BodyExt;
use tracing::debug;
use crate::metrics::GrpcMetrics;
/// Serializable description of a single gRPC call: target service/method, message
/// bytes and request metadata.
///
/// NOTE(review): not referenced in this part of the file. `grpc_handler` builds an
/// equivalent JSON object by hand with `payload` as a base64 *string*, whereas serde's
/// default encoding of `Vec<u8>` is a sequence of integers — confirm which shape the
/// consumers of this type expect.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct GrpcEnvelope {
    /// gRPC service name (presumably the first path segment, e.g. `pkg.Service`).
    pub service: String,
    /// gRPC method name (presumably the second path segment).
    pub method: String,
    /// Raw message bytes.
    pub payload: Vec<u8>,
    /// Call metadata as header name → value pairs.
    pub metadata: std::collections::HashMap<String, String>,
}
/// Base64 engine (standard alphabet) used to encode request payloads sent to the
/// worker and to decode the worker's response payloads.
const B64: base64::engine::general_purpose::GeneralPurpose =
    base64::engine::general_purpose::STANDARD;
/// Shared, cloneable state handed to [`grpc_handler`] by axum.
#[derive(Clone)]
pub struct GrpcState {
    /// Executor that runs the `"grpc.call"` job for every request.
    pub executor: Arc<dyn Executor>,
    /// Maximum accepted request message size in bytes, measured after decompression;
    /// also caps how far a gzip request is inflated.
    pub max_recv_message_size: usize,
    /// Maximum response message size in bytes, measured before any compression.
    pub max_send_message_size: usize,
    /// When `true`, responses are gzip-compressed for clients whose
    /// `grpc-accept-encoding` lists `gzip`.
    pub compression: bool,
    /// Optional metrics; when present every call is tracked from start to finish.
    pub metrics: Option<GrpcMetrics>,
}
pub async fn grpc_handler(
State(state): State<GrpcState>,
req: axum::extract::Request,
) -> impl IntoResponse {
let tracker = state.metrics.as_ref().map(|m| m.track_start());
let path = req.uri().path().to_string();
let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
let (service, method) = match parts.as_slice() {
[s, m] => (s.to_string(), m.to_string()),
_ => {
if let Some(t) = tracker {
t.finish("", "", 12, 0, 0);
}
return grpc_response(Bytes::new(), 12, "unimplemented: bad path", false);
}
};
let client_accepts_gzip = state.compression
&& req
.headers()
.get("grpc-accept-encoding")
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.split(',').any(|e| e.trim() == "gzip"));
let metadata: std::collections::HashMap<String, String> = req
.headers()
.iter()
.filter(|(k, _)| {
let k = k.as_str();
!k.starts_with(':')
&& k != "content-type"
&& k != "te"
&& k != "user-agent"
&& k != "grpc-timeout"
&& k != "grpc-encoding"
&& k != "grpc-accept-encoding"
})
.filter_map(|(k, v)| {
v.to_str()
.ok()
.map(|v| (k.as_str().to_string(), v.to_string()))
})
.collect();
let body_bytes = match req.into_body().collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
if let Some(t) = tracker {
t.finish(&service, &method, 13, 0, 0);
}
return grpc_response(Bytes::new(), 13, &format!("body: {e}"), false);
}
};
if body_bytes.len() < 5 {
if let Some(t) = tracker {
t.finish(&service, &method, 13, 0, 0);
}
return grpc_response(Bytes::new(), 13, "incomplete gRPC frame", false);
}
let compressed = body_bytes[0] == 1;
let raw_payload = body_bytes.slice(5..);
let proto_bytes = if compressed {
match gzip_decompress(&raw_payload, state.max_recv_message_size) {
Ok(decompressed) => Bytes::from(decompressed),
Err(e) => {
if let Some(t) = tracker {
t.finish(&service, &method, 8, raw_payload.len(), 0);
}
return grpc_response(
Bytes::new(),
8, &format!("decompress: {e}"),
false,
);
}
}
} else {
raw_payload
};
let recv_bytes = proto_bytes.len();
if recv_bytes > state.max_recv_message_size {
if let Some(t) = tracker {
t.finish(&service, &method, 8, recv_bytes, 0);
}
return grpc_response(
Bytes::new(),
8, &format!(
"received message larger than max ({recv_bytes} vs {} bytes)",
state.max_recv_message_size
),
false,
);
}
debug!(service, method, payload_len = recv_bytes, "gRPC call");
let payload_value = serde_json::json!({
"service": service,
"method": method,
"payload": B64.encode(&proto_bytes),
"metadata": metadata,
});
let response_value = match state
.executor
.execute_value("grpc.call", payload_value)
.await
{
Ok(v) => v,
Err(e) => {
if let Some(t) = tracker {
t.finish(&service, &method, 13, recv_bytes, 0);
}
return grpc_response(Bytes::new(), 13, &format!("worker: {e}"), false);
}
};
let response_str = match &response_value {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Object(map) => {
if let Some(serde_json::Value::String(s)) = map.get("__result") {
s.clone()
} else if let Some(serde_json::Value::String(s)) = map.get("result") {
s.clone()
} else {
if let Some(t) = tracker {
t.finish(&service, &method, 13, recv_bytes, 0);
}
return grpc_response(
Bytes::new(),
13,
&format!("unexpected response: {response_value}"),
false,
);
}
}
other => {
if let Some(t) = tracker {
t.finish(&service, &method, 13, recv_bytes, 0);
}
return grpc_response(
Bytes::new(),
13,
&format!("unexpected response type: {other}"),
false,
);
}
};
let proto_response = match B64.decode(&response_str) {
Ok(bytes) => bytes,
Err(e) => {
tracing::warn!(error = %e, "base64 decode failed for gRPC response, returning INTERNAL");
if let Some(t) = tracker {
t.finish(&service, &method, 13, recv_bytes, 0);
}
return grpc_response(
Bytes::new(),
13, "PHP worker returned invalid base64 response",
false,
);
}
};
let sent_bytes = proto_response.len();
if sent_bytes > state.max_send_message_size {
if let Some(t) = tracker {
t.finish(&service, &method, 8, recv_bytes, sent_bytes);
}
return grpc_response(
Bytes::new(),
8, &format!(
"response larger than max ({sent_bytes} vs {} bytes)",
state.max_send_message_size
),
false,
);
}
if client_accepts_gzip {
if let Ok(compressed_data) = gzip_compress(&proto_response) {
let mut framed = Vec::with_capacity(5 + compressed_data.len());
framed.push(1u8); framed.extend_from_slice(&(compressed_data.len() as u32).to_be_bytes());
framed.extend_from_slice(&compressed_data);
if let Some(t) = tracker {
t.finish(&service, &method, 0, recv_bytes, sent_bytes);
}
return grpc_response(Bytes::from(framed), 0, "", true);
}
}
let mut framed = Vec::with_capacity(5 + proto_response.len());
framed.push(0u8);
framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
framed.extend_from_slice(&proto_response);
if let Some(t) = tracker {
t.finish(&service, &method, 0, recv_bytes, sent_bytes);
}
grpc_response(Bytes::from(framed), 0, "", false)
}
/// Builds a complete unary gRPC HTTP response.
///
/// * `data` — the already length-prefixed response frame (empty for error replies,
///   in which case no data frame is emitted).
/// * `status` — sent as the `grpc-status` trailer.
/// * `message` — sent as the `grpc-message` trailer when non-empty. It is
///   percent-encoded as the gRPC HTTP/2 protocol requires, so messages containing
///   non-ASCII, control characters or `%` are transmitted instead of being dropped
///   or misdecoded.
/// * `gzip` — adds `grpc-encoding: gzip` to the response headers.
///
/// The HTTP status is always 200 with `content-type: application/grpc`.
fn grpc_response(data: Bytes, status: u32, message: &str, gzip: bool) -> Response<GrpcBody> {
    let mut trailers = HeaderMap::new();
    trailers.insert("grpc-status", HeaderValue::from(status));
    if !message.is_empty() {
        // The encoded form is printable ASCII only, so this conversion cannot fail; the
        // fallible path is kept so a bad message can never panic the handler.
        if let Ok(v) = HeaderValue::from_str(&percent_encode_grpc_message(message)) {
            trailers.insert("grpc-message", v);
        }
    }
    let mut builder = Response::builder()
        .status(200)
        .header("content-type", "application/grpc");
    if gzip {
        builder = builder.header("grpc-encoding", "gzip");
    }
    builder
        .body(GrpcBody {
            data: Some(data),
            trailers: Some(trailers),
        })
        .expect("static status code and header values are always valid")
}

/// Percent-encodes a `grpc-message` value per the gRPC HTTP/2 protocol: every byte of
/// the UTF-8 encoding outside `0x20..=0x7E`, and `%` itself, becomes `%XX`.
fn percent_encode_grpc_message(message: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(message.len());
    for &byte in message.as_bytes() {
        if (0x20..=0x7E).contains(&byte) && byte != b'%' {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
/// Decompresses a gzip stream without producing more than `limit` bytes.
///
/// At most `limit + 1` bytes are inflated, so a small compressed message cannot
/// expand into an unbounded allocation. The extra byte is what detects overflow.
///
/// # Errors
/// Returns an `InvalidData` error if the output would exceed `limit`, or any I/O
/// error raised by the decoder (e.g. for malformed input).
fn gzip_decompress(data: &[u8], limit: usize) -> std::io::Result<Vec<u8>> {
    // Saturate: with `limit == usize::MAX` a plain `limit as u64 + 1` overflows,
    // panicking in debug builds and wrapping to 0 in release, where `take(0)` would
    // silently yield an empty message.
    let read_cap = u64::try_from(limit)
        .unwrap_or(u64::MAX)
        .saturating_add(1);
    let mut out = Vec::new();
    flate2::read::GzDecoder::new(data)
        .take(read_cap)
        .read_to_end(&mut out)?;
    if out.len() > limit {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("decompressed message exceeds limit ({limit} bytes)"),
        ));
    }
    Ok(out)
}
/// Gzip-compresses `data` in memory at the fastest compression level.
///
/// # Errors
/// Propagates any I/O error from writing into or finishing the encoder.
fn gzip_compress(data: &[u8]) -> std::io::Result<Vec<u8>> {
    let level = flate2::Compression::fast();
    let mut gz = flate2::write::GzEncoder::new(Vec::new(), level);
    gz.write_all(data).and_then(|()| gz.finish())
}
/// Response body for a unary gRPC reply: at most one data frame, followed by one
/// trailers frame carrying the gRPC status.
pub struct GrpcBody {
    /// Pending message bytes; `None` once taken. An empty buffer emits no data frame.
    data: Option<Bytes>,
    /// Pending trailer map; `None` once taken.
    trailers: Option<HeaderMap>,
}

impl http_body::Body for GrpcBody {
    type Data = Bytes;
    type Error = std::convert::Infallible;

    /// Yields the data frame (if non-empty), then the trailers, then end-of-stream.
    /// Never returns `Poll::Pending`.
    fn poll_frame(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.get_mut();
        let next = match this.data.take().filter(|chunk| !chunk.is_empty()) {
            Some(chunk) => Some(Frame::data(chunk)),
            None => this.trailers.take().map(Frame::trailers),
        };
        Poll::Ready(next.map(Ok))
    }
}