folk-plugin-grpc 0.2.3

gRPC plugin for Folk — unary call passthrough to PHP workers via tonic
Documentation
use std::io::{Read as _, Write as _};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use axum::extract::State;
use axum::response::IntoResponse;
use base64::Engine;
use bytes::Bytes;
use folk_api::Executor;
use http::{HeaderMap, HeaderValue, Response};
use http_body::Frame;
use http_body_util::BodyExt;
use tracing::debug;

use crate::metrics::GrpcMetrics;

/// Serde-serializable description of a single gRPC call.
///
/// NOTE(review): `payload` is a plain `Vec<u8>`, which serde serializes as a
/// sequence of byte values (a JSON array of numbers). The executor payload
/// built in `grpc_handler` instead carries the message as a base64 string, so
/// this type is not that wire shape — confirm which consumers rely on it.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct GrpcEnvelope {
    /// Service name of the call (presumably the first request-path segment).
    pub service: String,
    /// Method name of the call (presumably the rest of the request path).
    pub method: String,
    /// Message bytes — presumably the raw protobuf request; TODO confirm.
    pub payload: Vec<u8>,
    /// String-valued call metadata (e.g. forwarded request headers).
    pub metadata: std::collections::HashMap<String, String>,
}

/// Base64 engine used to carry binary protobuf bytes through JSON payloads:
/// the crate's `STANDARD` configuration (standard alphabet, padded).
const B64: base64::engine::general_purpose::GeneralPurpose = base64::engine::general_purpose::STANDARD;

/// Shared configuration and dependencies handed to [`grpc_handler`] via
/// axum's `State` extractor.
#[derive(Clone)]
pub struct GrpcState {
    /// Executor that runs each call as a `"grpc.call"` job.
    pub executor: Arc<dyn Executor>,
    /// Largest request message, in bytes, that the handler accepts
    /// (enforced on the decompressed message as well).
    pub max_recv_message_size: usize,
    /// Largest response message, in bytes, accepted from the worker before
    /// it is framed and sent to the client.
    pub max_send_message_size: usize,
    /// When `true`, responses are gzip-compressed for clients that list
    /// `gzip` in `grpc-accept-encoding`.
    pub compression: bool,
    /// Optional metrics recorder; `None` disables per-call tracking.
    pub metrics: Option<GrpcMetrics>,
}

pub async fn grpc_handler(
    State(state): State<GrpcState>,
    req: axum::extract::Request,
) -> impl IntoResponse {
    let tracker = state.metrics.as_ref().map(|m| m.track_start());

    let path = req.uri().path().to_string();

    let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
    let (service, method) = match parts.as_slice() {
        [s, m] => (s.to_string(), m.to_string()),
        _ => {
            if let Some(t) = tracker {
                t.finish("", "", 12, 0, 0);
            }
            return grpc_response(Bytes::new(), 12, "unimplemented: bad path", false);
        }
    };

    // Check if client accepts gzip
    let client_accepts_gzip = state.compression
        && req
            .headers()
            .get("grpc-accept-encoding")
            .and_then(|v| v.to_str().ok())
            .is_some_and(|v| v.split(',').any(|e| e.trim() == "gzip"));

    let metadata: std::collections::HashMap<String, String> = req
        .headers()
        .iter()
        .filter(|(k, _)| {
            let k = k.as_str();
            !k.starts_with(':')
                && k != "content-type"
                && k != "te"
                && k != "user-agent"
                && k != "grpc-timeout"
                && k != "grpc-encoding"
                && k != "grpc-accept-encoding"
        })
        .filter_map(|(k, v)| {
            v.to_str()
                .ok()
                .map(|v| (k.as_str().to_string(), v.to_string()))
        })
        .collect();

    let body_bytes = match req.into_body().collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            if let Some(t) = tracker {
                t.finish(&service, &method, 13, 0, 0);
            }
            return grpc_response(Bytes::new(), 13, &format!("body: {e}"), false);
        }
    };

    if body_bytes.len() < 5 {
        if let Some(t) = tracker {
            t.finish(&service, &method, 13, 0, 0);
        }
        return grpc_response(Bytes::new(), 13, "incomplete gRPC frame", false);
    }

    // Parse gRPC frame: [compression_flag(1)][length(4)][payload]
    let compressed = body_bytes[0] == 1;
    let raw_payload = body_bytes.slice(5..);

    // Decompress incoming if needed (size-limited to prevent zip bombs)
    let proto_bytes = if compressed {
        match gzip_decompress(&raw_payload, state.max_recv_message_size) {
            Ok(decompressed) => Bytes::from(decompressed),
            Err(e) => {
                if let Some(t) = tracker {
                    t.finish(&service, &method, 8, raw_payload.len(), 0);
                }
                return grpc_response(
                    Bytes::new(),
                    8, // RESOURCE_EXHAUSTED
                    &format!("decompress: {e}"),
                    false,
                );
            }
        }
    } else {
        raw_payload
    };

    let recv_bytes = proto_bytes.len();

    // Enforce max incoming message size (uncompressed payloads)
    if recv_bytes > state.max_recv_message_size {
        if let Some(t) = tracker {
            t.finish(&service, &method, 8, recv_bytes, 0);
        }
        return grpc_response(
            Bytes::new(),
            8, // RESOURCE_EXHAUSTED
            &format!(
                "received message larger than max ({recv_bytes} vs {} bytes)",
                state.max_recv_message_size
            ),
            false,
        );
    }

    debug!(service, method, payload_len = recv_bytes, "gRPC call");

    // Build payload as serde_json::Value (base64 for binary protobuf data)
    let payload_value = serde_json::json!({
        "service": service,
        "method": method,
        "payload": B64.encode(&proto_bytes),
        "metadata": metadata,
    });

    let response_value = match state
        .executor
        .execute_value("grpc.call", payload_value)
        .await
    {
        Ok(v) => v,
        Err(e) => {
            if let Some(t) = tracker {
                t.finish(&service, &method, 13, recv_bytes, 0);
            }
            return grpc_response(Bytes::new(), 13, &format!("worker: {e}"), false);
        }
    };

    // Extract response: may be a string directly, or wrapped in {"__result": "..."}
    let response_str = match &response_value {
        serde_json::Value::String(s) => s.clone(),
        serde_json::Value::Object(map) => {
            if let Some(serde_json::Value::String(s)) = map.get("__result") {
                s.clone()
            } else if let Some(serde_json::Value::String(s)) = map.get("result") {
                s.clone()
            } else {
                if let Some(t) = tracker {
                    t.finish(&service, &method, 13, recv_bytes, 0);
                }
                return grpc_response(
                    Bytes::new(),
                    13,
                    &format!("unexpected response: {response_value}"),
                    false,
                );
            }
        }
        other => {
            if let Some(t) = tracker {
                t.finish(&service, &method, 13, recv_bytes, 0);
            }
            return grpc_response(
                Bytes::new(),
                13,
                &format!("unexpected response type: {other}"),
                false,
            );
        }
    };

    // Decode base64 protobuf bytes
    let proto_response = match B64.decode(&response_str) {
        Ok(bytes) => bytes,
        Err(_) => response_str.into_bytes(),
    };

    let sent_bytes = proto_response.len();

    // Enforce max outgoing message size
    if sent_bytes > state.max_send_message_size {
        if let Some(t) = tracker {
            t.finish(&service, &method, 8, recv_bytes, sent_bytes);
        }
        return grpc_response(
            Bytes::new(),
            8, // RESOURCE_EXHAUSTED
            &format!(
                "response larger than max ({sent_bytes} vs {} bytes)",
                state.max_send_message_size
            ),
            false,
        );
    }

    // Build gRPC-framed response (optionally compressed)
    if client_accepts_gzip {
        if let Ok(compressed_data) = gzip_compress(&proto_response) {
            let mut framed = Vec::with_capacity(5 + compressed_data.len());
            framed.push(1u8); // compression flag
            framed.extend_from_slice(&(compressed_data.len() as u32).to_be_bytes());
            framed.extend_from_slice(&compressed_data);
            if let Some(t) = tracker {
                t.finish(&service, &method, 0, recv_bytes, sent_bytes);
            }
            return grpc_response(Bytes::from(framed), 0, "", true);
        }
    }

    let mut framed = Vec::with_capacity(5 + proto_response.len());
    framed.push(0u8);
    framed.extend_from_slice(&(proto_response.len() as u32).to_be_bytes());
    framed.extend_from_slice(&proto_response);

    if let Some(t) = tracker {
        t.finish(&service, &method, 0, recv_bytes, sent_bytes);
    }
    grpc_response(Bytes::from(framed), 0, "", false)
}

/// Builds an HTTP 200 `application/grpc` response whose body is `data`
/// followed by `grpc-status` / `grpc-message` trailers.
///
/// * `data` — an already length-prefixed gRPC frame (empty on errors).
/// * `status` — gRPC status code (0 = OK).
/// * `message` — human-readable status detail, omitted when empty. It is
///   percent-encoded as the gRPC HTTP/2 protocol requires for
///   `grpc-message`, so `%`, control characters and non-ASCII text survive
///   instead of producing an invalid header that would be dropped.
/// * `gzip` — when true, sets `grpc-encoding: gzip` for the frame in `data`.
fn grpc_response(data: Bytes, status: u32, message: &str, gzip: bool) -> Response<GrpcBody> {
    let mut trailers = HeaderMap::new();
    trailers.insert("grpc-status", HeaderValue::from(status));
    if !message.is_empty() {
        // The encoded form is printable ASCII only, so this should always
        // succeed; keep the check rather than panic on a header edge case.
        if let Ok(v) = HeaderValue::from_str(&percent_encode_grpc_message(message)) {
            trailers.insert("grpc-message", v);
        }
    }

    let mut builder = Response::builder()
        .status(200)
        .header("content-type", "application/grpc");

    if gzip {
        builder = builder.header("grpc-encoding", "gzip");
    }

    builder
        .body(GrpcBody {
            data: Some(data),
            trailers: Some(trailers),
        })
        .expect("static status code and header values are always valid")
}

/// Percent-encodes a `grpc-message` value: every UTF-8 byte outside printable
/// ASCII (0x20..=0x7E), plus `%` itself, becomes `%XX` with uppercase hex.
fn percent_encode_grpc_message(message: &str) -> String {
    use std::fmt::Write as _;

    let mut out = String::with_capacity(message.len());
    for &b in message.as_bytes() {
        if (0x20..=0x7E).contains(&b) && b != b'%' {
            out.push(char::from(b));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(out, "%{b:02X}");
        }
    }
    out
}

/// Gunzips `data`, refusing to produce more than `limit` bytes.
///
/// At most `limit + 1` bytes are read from the decoder, so an oversized or
/// malicious stream (a "zip bomb") is cut off one byte past the limit instead
/// of being inflated in full.
///
/// # Errors
/// * `InvalidData` when the decompressed output would exceed `limit`.
/// * Any I/O error reported by the gzip decoder (e.g. a malformed stream).
fn gzip_decompress(data: &[u8], limit: usize) -> std::io::Result<Vec<u8>> {
    // Reading `limit + 1` bytes distinguishes "exactly at the limit" from
    // "over it". Saturate so a huge limit (e.g. `usize::MAX` meaning
    // "unbounded") cannot overflow: that would panic in debug builds or wrap
    // to a zero-byte read that silently returned an empty message.
    let read_cap = u64::try_from(limit)
        .unwrap_or(u64::MAX)
        .saturating_add(1);
    let mut out = Vec::new();
    flate2::read::GzDecoder::new(data)
        .take(read_cap)
        .read_to_end(&mut out)?;
    if out.len() > limit {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("decompressed message exceeds limit ({limit} bytes)"),
        ));
    }
    Ok(out)
}

/// Gzip-compresses `data` at flate2's fastest compression level and returns
/// the complete gzip stream.
///
/// # Errors
/// Propagates any I/O error from the encoder while writing or finishing.
fn gzip_compress(data: &[u8]) -> std::io::Result<Vec<u8>> {
    let level = flate2::Compression::fast();
    let mut gz = flate2::write::GzEncoder::new(Vec::new(), level);
    gz.write_all(data)?;
    let packed = gz.finish()?;
    Ok(packed)
}

/// In-memory response body for a gRPC reply: at most one data frame,
/// followed by a trailers frame carrying the gRPC status.
///
/// An empty data chunk is skipped, so error replies go straight to trailers.
pub struct GrpcBody {
    /// Payload to emit first; taken on the first poll.
    data: Option<Bytes>,
    /// Trailers emitted after the payload; taken once.
    trailers: Option<HeaderMap>,
}

impl http_body::Body for GrpcBody {
    type Data = Bytes;
    type Error = std::convert::Infallible;

    /// Yields the data frame (if non-empty), then the trailers, then end of
    /// stream. Never returns `Pending`: everything is already in memory.
    fn poll_frame(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let body = self.get_mut();
        let next = body
            .data
            .take()
            .filter(|chunk| !chunk.is_empty())
            .map(Frame::data)
            .or_else(|| body.trailers.take().map(Frame::trailers));
        Poll::Ready(next.map(Ok))
    }
}