Skip to main content

folk_plugin_grpc/
service.rs

1use std::sync::Arc;
2
3use axum::body::Body;
4use axum::extract::State;
5use axum::response::IntoResponse;
6use bytes::Bytes;
7use folk_api::Executor;
8use http::Response;
9use http_body_util::BodyExt;
10use tracing::debug;
11
12#[derive(serde::Serialize, serde::Deserialize, Debug)]
13pub struct GrpcEnvelope {
14    pub service: String,
15    pub method: String,
16    pub payload: Vec<u8>,
17}
18
19#[derive(Clone)]
20pub struct GrpcState {
21    pub executor: Arc<dyn Executor>,
22}
23
24pub async fn grpc_handler(
25    State(state): State<GrpcState>,
26    req: axum::extract::Request,
27) -> impl IntoResponse {
28    let path = req.uri().path().to_string();
29
30    let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
31    let (service, method) = match parts.as_slice() {
32        [s, m] => (s.to_string(), m.to_string()),
33        _ => return grpc_error(12, "unimplemented: bad path"),
34    };
35
36    let body_bytes = match req.into_body().collect().await {
37        Ok(collected) => collected.to_bytes(),
38        Err(e) => return grpc_error(13, &format!("body: {e}")),
39    };
40
41    // Strip 5-byte gRPC framing (compression flag + 4-byte length)
42    let proto_bytes = if body_bytes.len() >= 5 {
43        body_bytes.slice(5..)
44    } else {
45        body_bytes
46    };
47
48    debug!(
49        service,
50        method,
51        payload_len = proto_bytes.len(),
52        "gRPC call"
53    );
54
55    let envelope = GrpcEnvelope {
56        service,
57        method,
58        payload: proto_bytes.to_vec(),
59    };
60    let encoded = match rmp_serde::to_vec_named(&envelope) {
61        Ok(v) => Bytes::from(v),
62        Err(e) => return grpc_error(13, &format!("encode: {e}")),
63    };
64
65    let raw_response = match state.executor.execute_method("grpc.call", encoded).await {
66        Ok(v) => v,
67        Err(e) => return grpc_error(13, &format!("worker: {e}")),
68    };
69
70    // Worker returns MessagePack-encoded result. Extract raw bytes.
71    let proto_response = match rmp_serde::from_slice::<rmpv::Value>(&raw_response) {
72        Ok(rmpv::Value::Binary(b)) => b,
73        Ok(rmpv::Value::String(s)) => s.into_bytes(),
74        Ok(other) => {
75            return grpc_error(13, &format!("unexpected response type: {other:?}"));
76        }
77        Err(_) => raw_response.to_vec(),
78    };
79
80    // Re-add 5-byte gRPC framing to response
81    let mut framed = Vec::with_capacity(5 + proto_response.len());
82    framed.push(0u8); // no compression
83    framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
84    framed.extend_from_slice(&proto_response);
85
86    Response::builder()
87        .status(200)
88        .header("content-type", "application/grpc")
89        .header("grpc-status", "0")
90        .body(Body::from(framed))
91        .unwrap()
92}
93
94fn grpc_error(code: u32, msg: &str) -> Response<Body> {
95    Response::builder()
96        .status(200)
97        .header("content-type", "application/grpc")
98        .header("grpc-status", code.to_string())
99        .header("grpc-message", msg)
100        .body(Body::empty())
101        .unwrap()
102}