folk-plugin-grpc 0.2.2

gRPC plugin for Folk — unary call passthrough to PHP workers via tonic
Documentation
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};

use anyhow::{Context, Result};
use async_trait::async_trait;
use axum::Router;
use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
use tonic::transport::{Identity, Server, ServerTlsConfig};
use tracing::{debug, info, warn};

use crate::config::GrpcConfig;
use crate::metrics::GrpcMetrics;
use crate::service::{GrpcState, grpc_handler};

/// Folk server plugin that exposes a gRPC endpoint.
///
/// Only stores its [`GrpcConfig`]; the actual server is assembled and run in
/// [`ServerPlugin::run`].
pub struct GrpcPlugin {
    /// Listener address, TLS, limits, keepalive, compression and reflection settings.
    config: GrpcConfig,
}

impl GrpcPlugin {
    /// Create a plugin that will serve according to `config`.
    pub fn new(config: GrpcConfig) -> Self {
        GrpcPlugin { config }
    }
}

#[async_trait]
impl ServerPlugin for GrpcPlugin {
    /// Plugin identifier: `"grpc"`.
    fn name(&self) -> &'static str {
        "grpc"
    }

    /// Build the gRPC server from `self.config` and serve until shutdown.
    ///
    /// Requests not claimed by the built-in health (`grpc.health.v1`) or
    /// reflection services fall through to `grpc_handler`, which receives the
    /// context's executor plus the message-size and compression settings via
    /// `GrpcState`. Serving ends once `ctx.shutdown.changed()` resolves, with
    /// either outcome.
    ///
    /// # Errors
    ///
    /// Returns an error if the TLS certificate or key cannot be read, if tonic
    /// rejects the TLS configuration, or if serving fails (for example, the
    /// listen address cannot be bound). Failures while building reflection are
    /// logged as warnings and do not stop the server.
    async fn run(&self, ctx: PluginContext) -> Result<()> {
        let metrics = ctx
            .metrics_registry
            .as_ref()
            .map(|r| GrpcMetrics::new(r.as_ref()));
        if metrics.is_some() {
            info!("gRPC metrics registered");
        }

        let state = GrpcState {
            executor: ctx.executor.clone(),
            max_recv_message_size: self.config.max_recv_message_size,
            max_send_message_size: self.config.max_send_message_size,
            compression: self.config.compression,
            metrics,
        };

        // Catch-all router: every path not claimed by a tonic service added
        // below is handed to `grpc_handler`.
        let passthrough: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
        let mut sd = ctx.shutdown.clone();
        let listen = self.config.listen;

        // --- Server builder ---
        let mut builder = if let Some(tls) = &self.config.tls {
            let cert = std::fs::read(&tls.cert)
                .with_context(|| format!("read TLS cert: {}", tls.cert.display()))?;
            let key = std::fs::read(&tls.key)
                .with_context(|| format!("read TLS key: {}", tls.key.display()))?;
            let identity = Identity::from_pem(cert, key);
            let tls_config = ServerTlsConfig::new().identity(identity);
            info!(cert = %tls.cert.display(), "gRPC TLS enabled");
            Server::builder()
                .tls_config(tls_config)
                .context("TLS config")?
        } else {
            Server::builder()
        };

        if let Some(timeout) = self.config.timeout {
            builder = builder.timeout(timeout);
            info!(timeout = ?timeout, "gRPC server timeout configured");
        }

        if let Some(max_streams) = self.config.max_concurrent_streams {
            builder = builder.max_concurrent_streams(max_streams);
        }

        if let Some(ka) = &self.config.keepalive {
            builder = builder
                .http2_keepalive_interval(Some(ka.interval))
                .http2_keepalive_timeout(Some(ka.timeout));
            info!(interval = ?ka.interval, timeout = ?ka.timeout, "gRPC keepalive configured");
        }

        // --- Accumulate services ---
        let mut router = builder.add_routes(tonic::service::Routes::from(passthrough));

        // Health checking (grpc.health.v1). The reporter handle is not kept.
        let (_health_reporter, health_service) = tonic_health::server::health_reporter();
        router = router.add_service(health_service);
        info!("gRPC health checking enabled");

        // Reflection always includes the health service descriptor; user protos
        // are added when configured and they compile.
        // NOTE(review): build_reflection_descriptors performs synchronous
        // filesystem walks on this async task; consider moving it off the
        // runtime if project trees are large.
        {
            let proto_fds = if self.config.proto.is_empty() {
                None
            } else {
                match build_reflection_descriptors(&self.config.proto) {
                    Ok(fds) => {
                        info!(
                            proto_files = self.config.proto.len(),
                            "gRPC reflection enabled"
                        );
                        Some(fds)
                    }
                    Err(e) => {
                        warn!(error = %e, "failed to build proto reflection; continuing without it");
                        None
                    }
                }
            };

            let mut reflection_builder = tonic_reflection::server::Builder::configure()
                .register_encoded_file_descriptor_set(tonic_health::pb::FILE_DESCRIPTOR_SET);

            if let Some(fds) = &proto_fds {
                reflection_builder = reflection_builder.register_encoded_file_descriptor_set(fds);
            }

            match reflection_builder.build_v1() {
                Ok(reflection) => {
                    router = router.add_service(reflection);
                }
                Err(e) => {
                    warn!(error = %e, "failed to build reflection service");
                }
            }
        }

        // --- Serve ---
        // Logged only now, after TLS material and services are in place, so a
        // failure above is not preceded by a misleading "listening" line.
        info!(listen = %listen, "gRPC server listening");
        router
            .serve_with_shutdown(listen, async move {
                // Any result of `changed()` (including an error) ends the wait.
                sd.changed().await.ok();
            })
            .await
            .with_context(|| format!("gRPC server on {listen}"))?;

        Ok(())
    }

    /// Returns no RPC method definitions.
    fn rpc_methods(&self) -> Vec<RpcMethodDef> {
        vec![]
    }
}

/// Compile proto files and return encoded FileDescriptorSet.
///
/// Import paths are resolved automatically:
/// - Parses `import` statements from each proto file
/// - Searches for imported files within the project root (cwd)
/// - Never searches outside the working directory
fn build_reflection_descriptors(proto_files: &[String]) -> Result<Vec<u8>> {
    let cwd = std::env::current_dir().context("get cwd")?;
    let mut include_paths = BTreeSet::new();

    // Resolve imports starting from each proto file
    let mut visited = BTreeSet::new();
    for proto in proto_files {
        let abs = cwd.join(proto);
        resolve_imports(&abs, &cwd, &mut include_paths, &mut visited);
    }

    // Always include the directory of each proto file itself
    for proto in proto_files {
        let abs = cwd.join(proto);
        if let Some(parent) = abs.parent() {
            include_paths.insert(parent.to_path_buf());
        }
    }

    debug!(paths = include_paths.len(), "resolved proto include paths");

    let paths: Vec<_> = include_paths.iter().collect();
    let mut compiler = protox::Compiler::new(paths.iter().map(|p| p.as_path()))
        .context("create protox compiler")?;
    compiler.include_imports(true);

    for proto in proto_files {
        let path = Path::new(proto);
        let file_name = path
            .file_name()
            .and_then(|n| n.to_str())
            .context("invalid proto file name")?;
        compiler
            .open_file(file_name)
            .with_context(|| format!("compile {proto}"))?;
    }

    let fds = compiler.file_descriptor_set();
    Ok(prost::Message::encode_to_vec(&fds))
}

/// Parse a proto file for `import` statements and locate each imported file
/// within the project root.
///
/// For every import that is found, the directory it is imported relative to
/// (the match's path with the import path stripped off) is added to
/// `include_paths`; if that prefix cannot be computed, the match's parent
/// directory is added instead. Found files are then scanned recursively.
/// `visited` prevents rescanning a file, which also breaks import cycles.
///
/// Best effort: missing or unreadable files are skipped, imports that cannot be
/// found in the project are ignored, and `google/protobuf/` imports are not
/// searched for.
fn resolve_imports(
    proto_path: &Path,
    project_root: &Path,
    include_paths: &mut BTreeSet<PathBuf>,
    visited: &mut BTreeSet<PathBuf>,
) {
    if !proto_path.is_file() || !visited.insert(proto_path.to_path_buf()) {
        return;
    }

    let Ok(content) = std::fs::read_to_string(proto_path) else {
        return;
    };

    for import_path in content.lines().filter_map(parse_import_path) {
        // Well-known google types are skipped here.
        if import_path.starts_with("google/protobuf/") {
            continue;
        }

        let Some(found) = find_best_match(project_root, import_path) else {
            continue;
        };

        // Compute the include base: found = base / import_path
        match found.to_str().and_then(|f| f.strip_suffix(import_path)) {
            Some(base) => {
                if !base.is_empty() {
                    include_paths.insert(PathBuf::from(base));
                }
            }
            None => {
                if let Some(parent) = found.parent() {
                    include_paths.insert(parent.to_path_buf());
                }
            }
        }

        resolve_imports(&found, project_root, include_paths, visited);
    }
}

/// Extract the quoted path from a single-line proto `import` statement.
///
/// Accepts `import "a.proto";`, `import public 'a.proto';`, `import weak "a.proto";`
/// with any whitespace (or none) before the string. Protobuf string literals
/// may use either single or double quotes; the closing quote must match the
/// opening one. Returns `None` for non-import lines (including identifiers such
/// as `importance`) or an unterminated string.
fn parse_import_path(line: &str) -> Option<&str> {
    let rest = line.trim().strip_prefix("import")?;
    // Require a separator so identifiers that merely start with "import" don't match.
    if !rest.starts_with(|c: char| c.is_whitespace() || c == '"' || c == '\'') {
        return None;
    }
    let start = rest.find(|c: char| c == '"' || c == '\'')?;
    let quote = rest[start..].chars().next()?;
    // Quotes are ASCII, so `start + 1` is a char boundary.
    let literal = &rest[start + 1..];
    let end = literal.find(quote)?;
    Some(&literal[..end])
}

/// Find the best match for a proto import within the project tree.
///
/// Every file under `root` (hidden directories excluded) located at
/// `<dir>/<relative>` is a candidate. With several candidates the largest file
/// wins, on the assumption that smaller copies are stubs/overlays. Ties are
/// broken deterministically: the shallowest path wins, then the smallest by
/// path ordering, so the result does not depend on directory iteration order.
/// Returns `None` when nothing matches.
fn find_best_match(root: &Path, relative: &str) -> Option<PathBuf> {
    let mut candidates = Vec::new();

    // Direct from root
    let direct = root.join(relative);
    if direct.is_file() {
        candidates.push(direct);
    }

    find_all_recursive(root, relative, &mut candidates);

    candidates
        .into_iter()
        .map(|path| {
            // Unreadable metadata counts as size 0, i.e. least preferred.
            let size = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
            (size, path)
        })
        .max_by(|(size_a, a), (size_b, b)| {
            size_a
                .cmp(size_b)
                // Reversed: fewer components (shallower) ranks higher.
                .then_with(|| b.components().count().cmp(&a.components().count()))
                // Reversed: smaller path ranks higher.
                .then_with(|| b.cmp(a))
        })
        .map(|(_, path)| path)
}

/// Collect every existing file `<d>/<relative>` for `d` = `dir` or any
/// subdirectory of it into `results`.
///
/// Hidden directories (name starting with `.`) are skipped, paths already in
/// `results` are not added again, and unreadable directories are skipped.
/// Symlinked directories are still followed, but each physical directory is
/// walked at most once (keyed by its canonical path), so a symlink cycle such
/// as `a/loop -> ..` terminates instead of recursing until the stack overflows.
fn find_all_recursive(dir: &Path, relative: &str, results: &mut Vec<PathBuf>) {
    let mut walked = BTreeSet::new();
    collect_matches(dir, relative, results, &mut walked);
}

/// Recursive worker for [`find_all_recursive`]; `walked` holds the canonical
/// paths of directories already visited.
fn collect_matches(
    dir: &Path,
    relative: &str,
    results: &mut Vec<PathBuf>,
    walked: &mut BTreeSet<PathBuf>,
) {
    // Best effort: skip directories that cannot be canonicalized, and any
    // directory already reached through another path (symlink alias or cycle).
    let Ok(canonical) = dir.canonicalize() else {
        return;
    };
    if !walked.insert(canonical) {
        return;
    }

    let candidate = dir.join(relative);
    if candidate.is_file() && !results.contains(&candidate) {
        results.push(candidate);
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    for entry in entries.flatten() {
        let path = entry.path();
        if !path.is_dir() {
            continue;
        }
        // Only skip hidden directories
        let hidden = entry
            .file_name()
            .to_str()
            .is_some_and(|n| n.starts_with('.'));
        if hidden {
            continue;
        }
        collect_matches(&path, relative, results, walked);
    }
}