Skip to main content

folk_plugin_grpc/
service.rs

1use std::sync::Arc;
2
3use axum::body::Body;
4use axum::extract::State;
5use axum::response::IntoResponse;
6use bytes::Bytes;
7use folk_api::Executor;
8use http::Response;
9use http_body_util::BodyExt;
10use tracing::debug;
11
12#[derive(serde::Serialize, serde::Deserialize, Debug)]
13pub struct GrpcEnvelope {
14    pub service: String,
15    pub method: String,
16    pub payload: Vec<u8>,
17    pub metadata: std::collections::HashMap<String, String>,
18}
19
20#[derive(Clone)]
21pub struct GrpcState {
22    pub executor: Arc<dyn Executor>,
23}
24
25pub async fn grpc_handler(
26    State(state): State<GrpcState>,
27    req: axum::extract::Request,
28) -> impl IntoResponse {
29    let path = req.uri().path().to_string();
30
31    let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
32    let (service, method) = match parts.as_slice() {
33        [s, m] => (s.to_string(), m.to_string()),
34        _ => return grpc_error(12, "unimplemented: bad path"),
35    };
36
37    // Extract gRPC metadata from headers.
38    // gRPC metadata = all headers except standard HTTP/2 and gRPC internal ones.
39    let metadata: std::collections::HashMap<String, String> = req
40        .headers()
41        .iter()
42        .filter(|(k, _)| {
43            let k = k.as_str();
44            !k.starts_with(':')
45                && k != "content-type"
46                && k != "te"
47                && k != "user-agent"
48                && k != "grpc-timeout"
49                && k != "grpc-encoding"
50                && k != "grpc-accept-encoding"
51        })
52        .filter_map(|(k, v)| {
53            v.to_str()
54                .ok()
55                .map(|v| (k.as_str().to_string(), v.to_string()))
56        })
57        .collect();
58
59    let body_bytes = match req.into_body().collect().await {
60        Ok(collected) => collected.to_bytes(),
61        Err(e) => return grpc_error(13, &format!("body: {e}")),
62    };
63
64    // Strip 5-byte gRPC framing (compression flag + 4-byte length)
65    let proto_bytes = if body_bytes.len() >= 5 {
66        body_bytes.slice(5..)
67    } else {
68        body_bytes
69    };
70
71    debug!(
72        service,
73        method,
74        payload_len = proto_bytes.len(),
75        "gRPC call"
76    );
77
78    let envelope = GrpcEnvelope {
79        service,
80        method,
81        payload: proto_bytes.to_vec(),
82        metadata,
83    };
84    let encoded = match rmp_serde::to_vec_named(&envelope) {
85        Ok(v) => Bytes::from(v),
86        Err(e) => return grpc_error(13, &format!("encode: {e}")),
87    };
88
89    let raw_response = match state.executor.execute_method("grpc.call", encoded).await {
90        Ok(v) => v,
91        Err(e) => return grpc_error(13, &format!("worker: {e}")),
92    };
93
94    // Worker returns MessagePack-encoded result. Extract raw bytes.
95    let proto_response = match rmp_serde::from_slice::<rmpv::Value>(&raw_response) {
96        Ok(rmpv::Value::Binary(b)) => b,
97        Ok(rmpv::Value::String(s)) => s.into_bytes(),
98        Ok(other) => {
99            return grpc_error(13, &format!("unexpected response type: {other:?}"));
100        }
101        Err(_) => raw_response.to_vec(),
102    };
103
104    // Re-add 5-byte gRPC framing to response
105    let mut framed = Vec::with_capacity(5 + proto_response.len());
106    framed.push(0u8); // no compression
107    framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
108    framed.extend_from_slice(&proto_response);
109
110    Response::builder()
111        .status(200)
112        .header("content-type", "application/grpc")
113        .header("grpc-status", "0")
114        .body(Body::from(framed))
115        .unwrap()
116}
117
118fn grpc_error(code: u32, msg: &str) -> Response<Body> {
119    Response::builder()
120        .status(200)
121        .header("content-type", "application/grpc")
122        .header("grpc-status", code.to_string())
123        .header("grpc-message", msg)
124        .body(Body::empty())
125        .unwrap()
126}