use std::sync::Arc;
use axum::body::Body;
use axum::extract::State;
use axum::response::IntoResponse;
use bytes::Bytes;
use folk_api::Executor;
use http::Response;
use http_body_util::BodyExt;
use tracing::debug;
/// Request envelope passed to the worker executor for a single bridged gRPC call.
///
/// `grpc_handler` serializes it with `rmp_serde::to_vec_named`, so the fields are
/// encoded as a MessagePack map keyed by field name.
// NOTE(review): with serde's default `Vec<u8>` impl, `payload` is presumably encoded
// as a MessagePack array of integers rather than `bin`; confirm this is what the
// worker expects before changing it (e.g. to `serde_bytes`), since it is wire format.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct GrpcEnvelope {
/// Service name: the first segment of the request path.
pub service: String,
/// Method name: the remainder of the request path after the service segment.
pub method: String,
/// Request body with the 5-byte gRPC message prefix removed.
pub payload: Vec<u8>,
/// Request headers with string-representable values, minus gRPC/HTTP transport headers.
pub metadata: std::collections::HashMap<String, String>,
}
/// Shared axum state for [`grpc_handler`].
///
/// Cloning only bumps the `Arc` reference count of the executor.
#[derive(Clone)]
pub struct GrpcState {
/// Executor that receives every bridged call as the `"grpc.call"` method.
pub executor: Arc<dyn Executor>,
}
pub async fn grpc_handler(
State(state): State<GrpcState>,
req: axum::extract::Request,
) -> impl IntoResponse {
let path = req.uri().path().to_string();
let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
let (service, method) = match parts.as_slice() {
[s, m] => (s.to_string(), m.to_string()),
_ => return grpc_error(12, "unimplemented: bad path"),
};
let metadata: std::collections::HashMap<String, String> = req
.headers()
.iter()
.filter(|(k, _)| {
let k = k.as_str();
!k.starts_with(':')
&& k != "content-type"
&& k != "te"
&& k != "user-agent"
&& k != "grpc-timeout"
&& k != "grpc-encoding"
&& k != "grpc-accept-encoding"
})
.filter_map(|(k, v)| {
v.to_str()
.ok()
.map(|v| (k.as_str().to_string(), v.to_string()))
})
.collect();
let body_bytes = match req.into_body().collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => return grpc_error(13, &format!("body: {e}")),
};
let proto_bytes = if body_bytes.len() >= 5 {
body_bytes.slice(5..)
} else {
body_bytes
};
debug!(
service,
method,
payload_len = proto_bytes.len(),
"gRPC call"
);
let envelope = GrpcEnvelope {
service,
method,
payload: proto_bytes.to_vec(),
metadata,
};
let encoded = match rmp_serde::to_vec_named(&envelope) {
Ok(v) => Bytes::from(v),
Err(e) => return grpc_error(13, &format!("encode: {e}")),
};
let raw_response = match state.executor.execute_method("grpc.call", encoded).await {
Ok(v) => v,
Err(e) => return grpc_error(13, &format!("worker: {e}")),
};
let proto_response = match rmp_serde::from_slice::<rmpv::Value>(&raw_response) {
Ok(rmpv::Value::Binary(b)) => b,
Ok(rmpv::Value::String(s)) => s.into_bytes(),
Ok(other) => {
return grpc_error(13, &format!("unexpected response type: {other:?}"));
}
Err(_) => raw_response.to_vec(),
};
let mut framed = Vec::with_capacity(5 + proto_response.len());
framed.push(0u8); framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
framed.extend_from_slice(&proto_response);
Response::builder()
.status(200)
.header("content-type", "application/grpc")
.header("grpc-status", "0")
.body(Body::from(framed))
.unwrap()
}
/// Builds a gRPC "Trailers-Only" error response.
///
/// The response is HTTP 200 with an empty body; `grpc-status` is set to `code`
/// and `grpc-message` to `msg`.
///
/// `msg` is percent-encoded as the gRPC protocol requires. Besides being
/// spec-compliant, this guarantees a valid header value for arbitrary error
/// text (newlines, other control characters), which previously made the
/// builder fail and the `unwrap` panic.
fn grpc_error(code: u32, msg: &str) -> Response<Body> {
    Response::builder()
        .status(200)
        .header("content-type", "application/grpc")
        .header("grpc-status", code.to_string())
        .header("grpc-message", percent_encode_grpc_message(msg))
        .body(Body::empty())
        .expect("static headers and a percent-encoded message are always valid")
}

/// Percent-encodes a `grpc-message` value.
///
/// Bytes in 0x20..=0x7E other than `%` pass through unchanged; every other byte
/// (including each byte of a multi-byte UTF-8 character) becomes `%XX` in
/// uppercase hex.
fn percent_encode_grpc_message(msg: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(msg.len());
    for &byte in msg.as_bytes() {
        if (0x20..=0x7E).contains(&byte) && byte != b'%' {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}