1use std::io::{Read as _, Write as _};
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::extract::State;
7use axum::response::IntoResponse;
8use base64::Engine;
9use bytes::Bytes;
10use folk_api::Executor;
11use http::{HeaderMap, HeaderValue, Response};
12use http_body::Frame;
13use http_body_util::BodyExt;
14use tracing::debug;
15
16use crate::metrics::GrpcMetrics;
17
18#[derive(serde::Serialize, serde::Deserialize, Debug)]
19pub struct GrpcEnvelope {
20 pub service: String,
21 pub method: String,
22 pub payload: Vec<u8>,
23 pub metadata: std::collections::HashMap<String, String>,
24}
25
26const B64: base64::engine::general_purpose::GeneralPurpose =
27 base64::engine::general_purpose::STANDARD;
28
29#[derive(Clone)]
30pub struct GrpcState {
31 pub executor: Arc<dyn Executor>,
32 pub max_recv_message_size: usize,
33 pub max_send_message_size: usize,
34 pub compression: bool,
35 pub metrics: Option<GrpcMetrics>,
36}
37
38pub async fn grpc_handler(
39 State(state): State<GrpcState>,
40 req: axum::extract::Request,
41) -> impl IntoResponse {
42 let tracker = state.metrics.as_ref().map(|m| m.track_start());
43
44 let path = req.uri().path().to_string();
45
46 let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
47 let (service, method) = match parts.as_slice() {
48 [s, m] => (s.to_string(), m.to_string()),
49 _ => {
50 if let Some(t) = tracker {
51 t.finish("", "", 12, 0, 0);
52 }
53 return grpc_response(Bytes::new(), 12, "unimplemented: bad path", false);
54 }
55 };
56
57 let client_accepts_gzip = state.compression
59 && req
60 .headers()
61 .get("grpc-accept-encoding")
62 .and_then(|v| v.to_str().ok())
63 .is_some_and(|v| v.split(',').any(|e| e.trim() == "gzip"));
64
65 let metadata: std::collections::HashMap<String, String> = req
66 .headers()
67 .iter()
68 .filter(|(k, _)| {
69 let k = k.as_str();
70 !k.starts_with(':')
71 && k != "content-type"
72 && k != "te"
73 && k != "user-agent"
74 && k != "grpc-timeout"
75 && k != "grpc-encoding"
76 && k != "grpc-accept-encoding"
77 })
78 .filter_map(|(k, v)| {
79 v.to_str()
80 .ok()
81 .map(|v| (k.as_str().to_string(), v.to_string()))
82 })
83 .collect();
84
85 let body_bytes = match req.into_body().collect().await {
86 Ok(collected) => collected.to_bytes(),
87 Err(e) => {
88 if let Some(t) = tracker {
89 t.finish(&service, &method, 13, 0, 0);
90 }
91 return grpc_response(Bytes::new(), 13, &format!("body: {e}"), false);
92 }
93 };
94
95 if body_bytes.len() < 5 {
96 if let Some(t) = tracker {
97 t.finish(&service, &method, 13, 0, 0);
98 }
99 return grpc_response(Bytes::new(), 13, "incomplete gRPC frame", false);
100 }
101
102 let compressed = body_bytes[0] == 1;
104 let raw_payload = body_bytes.slice(5..);
105
106 let proto_bytes = if compressed {
108 match gzip_decompress(&raw_payload, state.max_recv_message_size) {
109 Ok(decompressed) => Bytes::from(decompressed),
110 Err(e) => {
111 if let Some(t) = tracker {
112 t.finish(&service, &method, 8, raw_payload.len(), 0);
113 }
114 return grpc_response(
115 Bytes::new(),
116 8, &format!("decompress: {e}"),
118 false,
119 );
120 }
121 }
122 } else {
123 raw_payload
124 };
125
126 let recv_bytes = proto_bytes.len();
127
128 if recv_bytes > state.max_recv_message_size {
130 if let Some(t) = tracker {
131 t.finish(&service, &method, 8, recv_bytes, 0);
132 }
133 return grpc_response(
134 Bytes::new(),
135 8, &format!(
137 "received message larger than max ({recv_bytes} vs {} bytes)",
138 state.max_recv_message_size
139 ),
140 false,
141 );
142 }
143
144 debug!(service, method, payload_len = recv_bytes, "gRPC call");
145
146 let payload_value = serde_json::json!({
148 "service": service,
149 "method": method,
150 "payload": B64.encode(&proto_bytes),
151 "metadata": metadata,
152 });
153
154 let response_value = match state
155 .executor
156 .execute_value("grpc.call", payload_value)
157 .await
158 {
159 Ok(v) => v,
160 Err(e) => {
161 if let Some(t) = tracker {
162 t.finish(&service, &method, 13, recv_bytes, 0);
163 }
164 return grpc_response(Bytes::new(), 13, &format!("worker: {e}"), false);
165 }
166 };
167
168 let response_str = match &response_value {
170 serde_json::Value::String(s) => s.clone(),
171 serde_json::Value::Object(map) => {
172 if let Some(serde_json::Value::String(s)) = map.get("__result") {
173 s.clone()
174 } else if let Some(serde_json::Value::String(s)) = map.get("result") {
175 s.clone()
176 } else {
177 if let Some(t) = tracker {
178 t.finish(&service, &method, 13, recv_bytes, 0);
179 }
180 return grpc_response(
181 Bytes::new(),
182 13,
183 &format!("unexpected response: {response_value}"),
184 false,
185 );
186 }
187 }
188 other => {
189 if let Some(t) = tracker {
190 t.finish(&service, &method, 13, recv_bytes, 0);
191 }
192 return grpc_response(
193 Bytes::new(),
194 13,
195 &format!("unexpected response type: {other}"),
196 false,
197 );
198 }
199 };
200
201 let proto_response = match B64.decode(&response_str) {
203 Ok(bytes) => bytes,
204 Err(e) => {
205 tracing::warn!(error = %e, "base64 decode failed for gRPC response, returning INTERNAL");
206 if let Some(t) = tracker {
207 t.finish(&service, &method, 13, recv_bytes, 0);
208 }
209 return grpc_response(
210 Bytes::new(),
211 13, "PHP worker returned invalid base64 response",
213 false,
214 );
215 }
216 };
217
218 let sent_bytes = proto_response.len();
219
220 if sent_bytes > state.max_send_message_size {
222 if let Some(t) = tracker {
223 t.finish(&service, &method, 8, recv_bytes, sent_bytes);
224 }
225 return grpc_response(
226 Bytes::new(),
227 8, &format!(
229 "response larger than max ({sent_bytes} vs {} bytes)",
230 state.max_send_message_size
231 ),
232 false,
233 );
234 }
235
236 if client_accepts_gzip {
238 if let Ok(compressed_data) = gzip_compress(&proto_response) {
239 let mut framed = Vec::with_capacity(5 + compressed_data.len());
240 framed.push(1u8); framed.extend_from_slice(&(compressed_data.len() as u32).to_be_bytes());
242 framed.extend_from_slice(&compressed_data);
243 if let Some(t) = tracker {
244 t.finish(&service, &method, 0, recv_bytes, sent_bytes);
245 }
246 return grpc_response(Bytes::from(framed), 0, "", true);
247 }
248 }
249
250 let mut framed = Vec::with_capacity(5 + proto_response.len());
251 framed.push(0u8);
252 framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
253 framed.extend_from_slice(&proto_response);
254
255 if let Some(t) = tracker {
256 t.finish(&service, &method, 0, recv_bytes, sent_bytes);
257 }
258 grpc_response(Bytes::from(framed), 0, "", false)
259}
260
261fn grpc_response(data: Bytes, status: u32, message: &str, gzip: bool) -> Response<GrpcBody> {
262 let mut trailers = HeaderMap::new();
263 trailers.insert("grpc-status", HeaderValue::from(status));
264 if !message.is_empty() {
265 if let Ok(v) = HeaderValue::from_str(message) {
266 trailers.insert("grpc-message", v);
267 }
268 }
269
270 let mut builder = Response::builder()
271 .status(200)
272 .header("content-type", "application/grpc");
273
274 if gzip {
275 builder = builder.header("grpc-encoding", "gzip");
276 }
277
278 builder
279 .body(GrpcBody {
280 data: Some(data),
281 trailers: Some(trailers),
282 })
283 .unwrap()
284}
285
286fn gzip_decompress(data: &[u8], limit: usize) -> std::io::Result<Vec<u8>> {
287 let decoder = flate2::read::GzDecoder::new(data);
288 let mut out = Vec::new();
289 decoder.take(limit as u64 + 1).read_to_end(&mut out)?;
290 if out.len() > limit {
291 return Err(std::io::Error::new(
292 std::io::ErrorKind::InvalidData,
293 format!("decompressed message exceeds limit ({limit} bytes)"),
294 ));
295 }
296 Ok(out)
297}
298
299fn gzip_compress(data: &[u8]) -> std::io::Result<Vec<u8>> {
300 let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::fast());
301 encoder.write_all(data)?;
302 encoder.finish()
303}
304
305pub struct GrpcBody {
307 data: Option<Bytes>,
308 trailers: Option<HeaderMap>,
309}
310
311impl http_body::Body for GrpcBody {
312 type Data = Bytes;
313 type Error = std::convert::Infallible;
314
315 fn poll_frame(
316 mut self: Pin<&mut Self>,
317 _cx: &mut Context<'_>,
318 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
319 if let Some(data) = self.data.take() {
320 if !data.is_empty() {
321 return Poll::Ready(Some(Ok(Frame::data(data))));
322 }
323 }
324 if let Some(trailers) = self.trailers.take() {
325 return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
326 }
327 Poll::Ready(None)
328 }
329}