// folk_plugin_grpc/service.rs

1use std::io::{Read as _, Write as _};
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::extract::State;
7use axum::response::IntoResponse;
8use base64::Engine;
9use bytes::Bytes;
10use folk_api::Executor;
11use http::{HeaderMap, HeaderValue, Response};
12use http_body::Frame;
13use http_body_util::BodyExt;
14use tracing::debug;
15
16use crate::metrics::GrpcMetrics;
17
18#[derive(serde::Serialize, serde::Deserialize, Debug)]
19pub struct GrpcEnvelope {
20    pub service: String,
21    pub method: String,
22    pub payload: Vec<u8>,
23    pub metadata: std::collections::HashMap<String, String>,
24}
25
26const B64: base64::engine::general_purpose::GeneralPurpose =
27    base64::engine::general_purpose::STANDARD;
28
29#[derive(Clone)]
30pub struct GrpcState {
31    pub executor: Arc<dyn Executor>,
32    pub max_recv_message_size: usize,
33    pub max_send_message_size: usize,
34    pub compression: bool,
35    pub metrics: Option<GrpcMetrics>,
36}
37
38pub async fn grpc_handler(
39    State(state): State<GrpcState>,
40    req: axum::extract::Request,
41) -> impl IntoResponse {
42    let tracker = state.metrics.as_ref().map(|m| m.track_start());
43
44    let path = req.uri().path().to_string();
45
46    let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
47    let (service, method) = match parts.as_slice() {
48        [s, m] => (s.to_string(), m.to_string()),
49        _ => {
50            if let Some(t) = tracker {
51                t.finish("", "", 12, 0, 0);
52            }
53            return grpc_response(Bytes::new(), 12, "unimplemented: bad path", false);
54        }
55    };
56
57    // Check if client accepts gzip
58    let client_accepts_gzip = state.compression
59        && req
60            .headers()
61            .get("grpc-accept-encoding")
62            .and_then(|v| v.to_str().ok())
63            .is_some_and(|v| v.split(',').any(|e| e.trim() == "gzip"));
64
65    let metadata: std::collections::HashMap<String, String> = req
66        .headers()
67        .iter()
68        .filter(|(k, _)| {
69            let k = k.as_str();
70            !k.starts_with(':')
71                && k != "content-type"
72                && k != "te"
73                && k != "user-agent"
74                && k != "grpc-timeout"
75                && k != "grpc-encoding"
76                && k != "grpc-accept-encoding"
77        })
78        .filter_map(|(k, v)| {
79            v.to_str()
80                .ok()
81                .map(|v| (k.as_str().to_string(), v.to_string()))
82        })
83        .collect();
84
85    let body_bytes = match req.into_body().collect().await {
86        Ok(collected) => collected.to_bytes(),
87        Err(e) => {
88            if let Some(t) = tracker {
89                t.finish(&service, &method, 13, 0, 0);
90            }
91            return grpc_response(Bytes::new(), 13, &format!("body: {e}"), false);
92        }
93    };
94
95    if body_bytes.len() < 5 {
96        if let Some(t) = tracker {
97            t.finish(&service, &method, 13, 0, 0);
98        }
99        return grpc_response(Bytes::new(), 13, "incomplete gRPC frame", false);
100    }
101
102    // Parse gRPC frame: [compression_flag(1)][length(4)][payload]
103    let compressed = body_bytes[0] == 1;
104    let raw_payload = body_bytes.slice(5..);
105
106    // Decompress incoming if needed (size-limited to prevent zip bombs)
107    let proto_bytes = if compressed {
108        match gzip_decompress(&raw_payload, state.max_recv_message_size) {
109            Ok(decompressed) => Bytes::from(decompressed),
110            Err(e) => {
111                if let Some(t) = tracker {
112                    t.finish(&service, &method, 8, raw_payload.len(), 0);
113                }
114                return grpc_response(
115                    Bytes::new(),
116                    8, // RESOURCE_EXHAUSTED
117                    &format!("decompress: {e}"),
118                    false,
119                );
120            }
121        }
122    } else {
123        raw_payload
124    };
125
126    let recv_bytes = proto_bytes.len();
127
128    // Enforce max incoming message size (uncompressed payloads)
129    if recv_bytes > state.max_recv_message_size {
130        if let Some(t) = tracker {
131            t.finish(&service, &method, 8, recv_bytes, 0);
132        }
133        return grpc_response(
134            Bytes::new(),
135            8, // RESOURCE_EXHAUSTED
136            &format!(
137                "received message larger than max ({recv_bytes} vs {} bytes)",
138                state.max_recv_message_size
139            ),
140            false,
141        );
142    }
143
144    debug!(service, method, payload_len = recv_bytes, "gRPC call");
145
146    // Build payload as serde_json::Value (base64 for binary protobuf data)
147    let payload_value = serde_json::json!({
148        "service": service,
149        "method": method,
150        "payload": B64.encode(&proto_bytes),
151        "metadata": metadata,
152    });
153
154    let response_value = match state
155        .executor
156        .execute_value("grpc.call", payload_value)
157        .await
158    {
159        Ok(v) => v,
160        Err(e) => {
161            if let Some(t) = tracker {
162                t.finish(&service, &method, 13, recv_bytes, 0);
163            }
164            return grpc_response(Bytes::new(), 13, &format!("worker: {e}"), false);
165        }
166    };
167
168    // Extract response: may be a string directly, or wrapped in {"__result": "..."}
169    let response_str = match &response_value {
170        serde_json::Value::String(s) => s.clone(),
171        serde_json::Value::Object(map) => {
172            if let Some(serde_json::Value::String(s)) = map.get("__result") {
173                s.clone()
174            } else if let Some(serde_json::Value::String(s)) = map.get("result") {
175                s.clone()
176            } else {
177                if let Some(t) = tracker {
178                    t.finish(&service, &method, 13, recv_bytes, 0);
179                }
180                return grpc_response(
181                    Bytes::new(),
182                    13,
183                    &format!("unexpected response: {response_value}"),
184                    false,
185                );
186            }
187        }
188        other => {
189            if let Some(t) = tracker {
190                t.finish(&service, &method, 13, recv_bytes, 0);
191            }
192            return grpc_response(
193                Bytes::new(),
194                13,
195                &format!("unexpected response type: {other}"),
196                false,
197            );
198        }
199    };
200
201    // Decode base64 protobuf bytes
202    let proto_response = match B64.decode(&response_str) {
203        Ok(bytes) => bytes,
204        Err(e) => {
205            tracing::warn!(error = %e, "base64 decode failed for gRPC response, returning INTERNAL");
206            if let Some(t) = tracker {
207                t.finish(&service, &method, 13, recv_bytes, 0);
208            }
209            return grpc_response(
210                Bytes::new(),
211                13, // INTERNAL
212                "PHP worker returned invalid base64 response",
213                false,
214            );
215        }
216    };
217
218    let sent_bytes = proto_response.len();
219
220    // Enforce max outgoing message size
221    if sent_bytes > state.max_send_message_size {
222        if let Some(t) = tracker {
223            t.finish(&service, &method, 8, recv_bytes, sent_bytes);
224        }
225        return grpc_response(
226            Bytes::new(),
227            8, // RESOURCE_EXHAUSTED
228            &format!(
229                "response larger than max ({sent_bytes} vs {} bytes)",
230                state.max_send_message_size
231            ),
232            false,
233        );
234    }
235
236    // Build gRPC-framed response (optionally compressed)
237    if client_accepts_gzip {
238        if let Ok(compressed_data) = gzip_compress(&proto_response) {
239            let mut framed = Vec::with_capacity(5 + compressed_data.len());
240            framed.push(1u8); // compression flag
241            framed.extend_from_slice(&(compressed_data.len() as u32).to_be_bytes());
242            framed.extend_from_slice(&compressed_data);
243            if let Some(t) = tracker {
244                t.finish(&service, &method, 0, recv_bytes, sent_bytes);
245            }
246            return grpc_response(Bytes::from(framed), 0, "", true);
247        }
248    }
249
250    let mut framed = Vec::with_capacity(5 + proto_response.len());
251    framed.push(0u8);
252    framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
253    framed.extend_from_slice(&proto_response);
254
255    if let Some(t) = tracker {
256        t.finish(&service, &method, 0, recv_bytes, sent_bytes);
257    }
258    grpc_response(Bytes::from(framed), 0, "", false)
259}
260
261fn grpc_response(data: Bytes, status: u32, message: &str, gzip: bool) -> Response<GrpcBody> {
262    let mut trailers = HeaderMap::new();
263    trailers.insert("grpc-status", HeaderValue::from(status));
264    if !message.is_empty() {
265        if let Ok(v) = HeaderValue::from_str(message) {
266            trailers.insert("grpc-message", v);
267        }
268    }
269
270    let mut builder = Response::builder()
271        .status(200)
272        .header("content-type", "application/grpc");
273
274    if gzip {
275        builder = builder.header("grpc-encoding", "gzip");
276    }
277
278    builder
279        .body(GrpcBody {
280            data: Some(data),
281            trailers: Some(trailers),
282        })
283        .unwrap()
284}
285
286fn gzip_decompress(data: &[u8], limit: usize) -> std::io::Result<Vec<u8>> {
287    let decoder = flate2::read::GzDecoder::new(data);
288    let mut out = Vec::new();
289    decoder.take(limit as u64 + 1).read_to_end(&mut out)?;
290    if out.len() > limit {
291        return Err(std::io::Error::new(
292            std::io::ErrorKind::InvalidData,
293            format!("decompressed message exceeds limit ({limit} bytes)"),
294        ));
295    }
296    Ok(out)
297}
298
299fn gzip_compress(data: &[u8]) -> std::io::Result<Vec<u8>> {
300    let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::fast());
301    encoder.write_all(data)?;
302    encoder.finish()
303}
304
305/// Custom body type that sends data followed by gRPC trailers.
306pub struct GrpcBody {
307    data: Option<Bytes>,
308    trailers: Option<HeaderMap>,
309}
310
311impl http_body::Body for GrpcBody {
312    type Data = Bytes;
313    type Error = std::convert::Infallible;
314
315    fn poll_frame(
316        mut self: Pin<&mut Self>,
317        _cx: &mut Context<'_>,
318    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
319        if let Some(data) = self.data.take() {
320            if !data.is_empty() {
321                return Poll::Ready(Some(Ok(Frame::data(data))));
322            }
323        }
324        if let Some(trailers) = self.trailers.take() {
325            return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
326        }
327        Poll::Ready(None)
328    }
329}