Skip to main content

folk_plugin_grpc/
service.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use axum::extract::State;
6use axum::response::IntoResponse;
7use bytes::Bytes;
8use folk_api::Executor;
9use http::{HeaderMap, HeaderValue, Response};
10use http_body::Frame;
11use http_body_util::BodyExt;
12use tracing::debug;
13
/// Request envelope handed to the worker for each gRPC call.
///
/// `grpc_handler` serializes this with `rmp_serde::to_vec_named`, so the field
/// names become MessagePack map keys and the field declaration order fixes the
/// key order. Renaming or reordering fields changes the wire format the worker
/// has to decode.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct GrpcEnvelope {
    /// Service segment of the request path (`/{service}/{method}`).
    pub service: String,
    /// Method segment of the request path.
    pub method: String,
    // NOTE(review): serde serializes `Vec<u8>` as a sequence, so rmp_serde emits
    // a MessagePack array of integers rather than a `bin` value — confirm this is
    // what the worker expects before changing the type.
    /// Request message bytes as extracted from the gRPC request body by `grpc_handler`.
    pub payload: Vec<u8>,
    /// Request headers forwarded as call metadata (transport headers are filtered
    /// out by `grpc_handler`).
    pub metadata: std::collections::HashMap<String, String>,
}
21
/// Axum state extracted by `grpc_handler` via `State<GrpcState>`.
///
/// Cloning is cheap: the derived `Clone` only bumps the executor's `Arc`
/// reference count.
#[derive(Clone)]
pub struct GrpcState {
    /// Executor used to run the `grpc.call` method on a worker.
    pub executor: Arc<dyn Executor>,
}
26
27pub async fn grpc_handler(
28    State(state): State<GrpcState>,
29    req: axum::extract::Request,
30) -> impl IntoResponse {
31    let path = req.uri().path().to_string();
32
33    let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
34    let (service, method) = match parts.as_slice() {
35        [s, m] => (s.to_string(), m.to_string()),
36        _ => return grpc_response(Bytes::new(), 12, "unimplemented: bad path"),
37    };
38
39    let metadata: std::collections::HashMap<String, String> = req
40        .headers()
41        .iter()
42        .filter(|(k, _)| {
43            let k = k.as_str();
44            !k.starts_with(':')
45                && k != "content-type"
46                && k != "te"
47                && k != "user-agent"
48                && k != "grpc-timeout"
49                && k != "grpc-encoding"
50                && k != "grpc-accept-encoding"
51        })
52        .filter_map(|(k, v)| {
53            v.to_str()
54                .ok()
55                .map(|v| (k.as_str().to_string(), v.to_string()))
56        })
57        .collect();
58
59    let body_bytes = match req.into_body().collect().await {
60        Ok(collected) => collected.to_bytes(),
61        Err(e) => return grpc_response(Bytes::new(), 13, &format!("body: {e}")),
62    };
63
64    let proto_bytes = if body_bytes.len() >= 5 {
65        body_bytes.slice(5..)
66    } else {
67        body_bytes
68    };
69
70    debug!(
71        service,
72        method,
73        payload_len = proto_bytes.len(),
74        "gRPC call"
75    );
76
77    let envelope = GrpcEnvelope {
78        service,
79        method,
80        payload: proto_bytes.to_vec(),
81        metadata,
82    };
83    let encoded = match rmp_serde::to_vec_named(&envelope) {
84        Ok(v) => Bytes::from(v),
85        Err(e) => return grpc_response(Bytes::new(), 13, &format!("encode: {e}")),
86    };
87
88    let raw_response = match state.executor.execute_method("grpc.call", encoded).await {
89        Ok(v) => v,
90        Err(e) => return grpc_response(Bytes::new(), 13, &format!("worker: {e}")),
91    };
92
93    let proto_response = match rmp_serde::from_slice::<rmpv::Value>(&raw_response) {
94        Ok(rmpv::Value::Binary(b)) => b,
95        Ok(rmpv::Value::String(s)) => s.into_bytes(),
96        Ok(other) => {
97            return grpc_response(
98                Bytes::new(),
99                13,
100                &format!("unexpected response type: {other:?}"),
101            );
102        }
103        Err(_) => raw_response.to_vec(),
104    };
105
106    // Build gRPC-framed response
107    let mut framed = Vec::with_capacity(5 + proto_response.len());
108    framed.push(0u8);
109    framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
110    framed.extend_from_slice(&proto_response);
111
112    grpc_response(Bytes::from(framed), 0, "")
113}
114
115/// Build a gRPC response with proper HTTP/2 trailers.
116///
117/// gRPC requires `grpc-status` (and optionally `grpc-message`) in trailers,
118/// not in response headers. This is what grpcurl and other strict clients expect.
119fn grpc_response(data: Bytes, status: u32, message: &str) -> Response<GrpcBody> {
120    let mut trailers = HeaderMap::new();
121    trailers.insert("grpc-status", HeaderValue::from(status));
122    if !message.is_empty() {
123        if let Ok(v) = HeaderValue::from_str(message) {
124            trailers.insert("grpc-message", v);
125        }
126    }
127
128    Response::builder()
129        .status(200)
130        .header("content-type", "application/grpc")
131        .body(GrpcBody {
132            data: Some(data),
133            trailers: Some(trailers),
134        })
135        .unwrap()
136}
137
138/// Custom body type that sends data followed by gRPC trailers.
139pub struct GrpcBody {
140    data: Option<Bytes>,
141    trailers: Option<HeaderMap>,
142}
143
144impl http_body::Body for GrpcBody {
145    type Data = Bytes;
146    type Error = std::convert::Infallible;
147
148    fn poll_frame(
149        mut self: Pin<&mut Self>,
150        _cx: &mut Context<'_>,
151    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
152        if let Some(data) = self.data.take() {
153            if !data.is_empty() {
154                return Poll::Ready(Some(Ok(Frame::data(data))));
155            }
156        }
157        if let Some(trailers) = self.trailers.take() {
158            return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
159        }
160        Poll::Ready(None)
161    }
162}