Skip to main content

folk_plugin_grpc/
service.rs

1use std::io::{Read as _, Write as _};
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::extract::State;
7use axum::response::IntoResponse;
8use bytes::Bytes;
9use folk_api::Executor;
10use http::{HeaderMap, HeaderValue, Response};
11use http_body::Frame;
12use http_body_util::BodyExt;
13use tracing::debug;
14
15use crate::metrics::GrpcMetrics;
16
17#[derive(serde::Serialize, serde::Deserialize, Debug)]
18pub struct GrpcEnvelope {
19    pub service: String,
20    pub method: String,
21    pub payload: Vec<u8>,
22    pub metadata: std::collections::HashMap<String, String>,
23}
24
25#[derive(Clone)]
26pub struct GrpcState {
27    pub executor: Arc<dyn Executor>,
28    pub max_recv_message_size: usize,
29    pub max_send_message_size: usize,
30    pub compression: bool,
31    pub metrics: Option<GrpcMetrics>,
32}
33
34pub async fn grpc_handler(
35    State(state): State<GrpcState>,
36    req: axum::extract::Request,
37) -> impl IntoResponse {
38    let tracker = state.metrics.as_ref().map(|m| m.track_start());
39
40    let path = req.uri().path().to_string();
41
42    let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
43    let (service, method) = match parts.as_slice() {
44        [s, m] => (s.to_string(), m.to_string()),
45        _ => {
46            if let Some(t) = tracker {
47                t.finish("", "", 12, 0, 0);
48            }
49            return grpc_response(Bytes::new(), 12, "unimplemented: bad path", false);
50        }
51    };
52
53    // Check if client accepts gzip
54    let client_accepts_gzip = state.compression
55        && req
56            .headers()
57            .get("grpc-accept-encoding")
58            .and_then(|v| v.to_str().ok())
59            .is_some_and(|v| v.split(',').any(|e| e.trim() == "gzip"));
60
61    let metadata: std::collections::HashMap<String, String> = req
62        .headers()
63        .iter()
64        .filter(|(k, _)| {
65            let k = k.as_str();
66            !k.starts_with(':')
67                && k != "content-type"
68                && k != "te"
69                && k != "user-agent"
70                && k != "grpc-timeout"
71                && k != "grpc-encoding"
72                && k != "grpc-accept-encoding"
73        })
74        .filter_map(|(k, v)| {
75            v.to_str()
76                .ok()
77                .map(|v| (k.as_str().to_string(), v.to_string()))
78        })
79        .collect();
80
81    let body_bytes = match req.into_body().collect().await {
82        Ok(collected) => collected.to_bytes(),
83        Err(e) => {
84            if let Some(t) = tracker {
85                t.finish(&service, &method, 13, 0, 0);
86            }
87            return grpc_response(Bytes::new(), 13, &format!("body: {e}"), false);
88        }
89    };
90
91    if body_bytes.len() < 5 {
92        if let Some(t) = tracker {
93            t.finish(&service, &method, 13, 0, 0);
94        }
95        return grpc_response(Bytes::new(), 13, "incomplete gRPC frame", false);
96    }
97
98    // Parse gRPC frame: [compression_flag(1)][length(4)][payload]
99    let compressed = body_bytes[0] == 1;
100    let raw_payload = body_bytes.slice(5..);
101
102    // Decompress incoming if needed
103    let proto_bytes = if compressed {
104        match gzip_decompress(&raw_payload) {
105            Ok(decompressed) => Bytes::from(decompressed),
106            Err(e) => {
107                if let Some(t) = tracker {
108                    t.finish(&service, &method, 13, raw_payload.len(), 0);
109                }
110                return grpc_response(Bytes::new(), 13, &format!("decompress: {e}"), false);
111            }
112        }
113    } else {
114        raw_payload
115    };
116
117    let recv_bytes = proto_bytes.len();
118
119    // Enforce max incoming message size (after decompression)
120    if recv_bytes > state.max_recv_message_size {
121        if let Some(t) = tracker {
122            t.finish(&service, &method, 11, recv_bytes, 0);
123        }
124        return grpc_response(
125            Bytes::new(),
126            11, // RESOURCE_EXHAUSTED
127            &format!(
128                "received message larger than max ({recv_bytes} vs {} bytes)",
129                state.max_recv_message_size
130            ),
131            false,
132        );
133    }
134
135    debug!(service, method, payload_len = recv_bytes, "gRPC call");
136
137    let envelope = GrpcEnvelope {
138        service: service.clone(),
139        method: method.clone(),
140        payload: proto_bytes.to_vec(),
141        metadata,
142    };
143    let encoded = match rmp_serde::to_vec_named(&envelope) {
144        Ok(v) => Bytes::from(v),
145        Err(e) => {
146            if let Some(t) = tracker {
147                t.finish(&service, &method, 13, recv_bytes, 0);
148            }
149            return grpc_response(Bytes::new(), 13, &format!("encode: {e}"), false);
150        }
151    };
152
153    let raw_response = match state.executor.execute_method("grpc.call", encoded).await {
154        Ok(v) => v,
155        Err(e) => {
156            if let Some(t) = tracker {
157                t.finish(&service, &method, 13, recv_bytes, 0);
158            }
159            return grpc_response(Bytes::new(), 13, &format!("worker: {e}"), false);
160        }
161    };
162
163    let proto_response = match rmp_serde::from_slice::<rmpv::Value>(&raw_response) {
164        Ok(rmpv::Value::Binary(b)) => b,
165        Ok(rmpv::Value::String(s)) => s.into_bytes(),
166        Ok(other) => {
167            if let Some(t) = tracker {
168                t.finish(&service, &method, 13, recv_bytes, 0);
169            }
170            return grpc_response(
171                Bytes::new(),
172                13,
173                &format!("unexpected response type: {other:?}"),
174                false,
175            );
176        }
177        Err(_) => raw_response.to_vec(),
178    };
179
180    let sent_bytes = proto_response.len();
181
182    // Enforce max outgoing message size
183    if sent_bytes > state.max_send_message_size {
184        if let Some(t) = tracker {
185            t.finish(&service, &method, 11, recv_bytes, sent_bytes);
186        }
187        return grpc_response(
188            Bytes::new(),
189            11, // RESOURCE_EXHAUSTED
190            &format!(
191                "response larger than max ({sent_bytes} vs {} bytes)",
192                state.max_send_message_size
193            ),
194            false,
195        );
196    }
197
198    // Build gRPC-framed response (optionally compressed)
199    if client_accepts_gzip {
200        if let Ok(compressed_data) = gzip_compress(&proto_response) {
201            let mut framed = Vec::with_capacity(5 + compressed_data.len());
202            framed.push(1u8); // compression flag
203            framed.extend_from_slice(&(compressed_data.len() as u32).to_be_bytes());
204            framed.extend_from_slice(&compressed_data);
205            if let Some(t) = tracker {
206                t.finish(&service, &method, 0, recv_bytes, sent_bytes);
207            }
208            return grpc_response(Bytes::from(framed), 0, "", true);
209        }
210    }
211
212    let mut framed = Vec::with_capacity(5 + proto_response.len());
213    framed.push(0u8);
214    framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
215    framed.extend_from_slice(&proto_response);
216
217    if let Some(t) = tracker {
218        t.finish(&service, &method, 0, recv_bytes, sent_bytes);
219    }
220    grpc_response(Bytes::from(framed), 0, "", false)
221}
222
223fn grpc_response(data: Bytes, status: u32, message: &str, gzip: bool) -> Response<GrpcBody> {
224    let mut trailers = HeaderMap::new();
225    trailers.insert("grpc-status", HeaderValue::from(status));
226    if !message.is_empty() {
227        if let Ok(v) = HeaderValue::from_str(message) {
228            trailers.insert("grpc-message", v);
229        }
230    }
231
232    let mut builder = Response::builder()
233        .status(200)
234        .header("content-type", "application/grpc");
235
236    if gzip {
237        builder = builder.header("grpc-encoding", "gzip");
238    }
239
240    builder
241        .body(GrpcBody {
242            data: Some(data),
243            trailers: Some(trailers),
244        })
245        .unwrap()
246}
247
248fn gzip_decompress(data: &[u8]) -> std::io::Result<Vec<u8>> {
249    let mut decoder = flate2::read::GzDecoder::new(data);
250    let mut out = Vec::new();
251    decoder.read_to_end(&mut out)?;
252    Ok(out)
253}
254
255fn gzip_compress(data: &[u8]) -> std::io::Result<Vec<u8>> {
256    let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::fast());
257    encoder.write_all(data)?;
258    encoder.finish()
259}
260
261/// Custom body type that sends data followed by gRPC trailers.
262pub struct GrpcBody {
263    data: Option<Bytes>,
264    trailers: Option<HeaderMap>,
265}
266
267impl http_body::Body for GrpcBody {
268    type Data = Bytes;
269    type Error = std::convert::Infallible;
270
271    fn poll_frame(
272        mut self: Pin<&mut Self>,
273        _cx: &mut Context<'_>,
274    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
275        if let Some(data) = self.data.take() {
276            if !data.is_empty() {
277                return Poll::Ready(Some(Ok(Frame::data(data))));
278            }
279        }
280        if let Some(trailers) = self.trailers.take() {
281            return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
282        }
283        Poll::Ready(None)
284    }
285}