Skip to main content

folk_plugin_grpc/
plugin.rs

1use std::path::Path;
2
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use axum::Router;
6use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
7use tonic::transport::Server;
8use tracing::{info, warn};
9
10use crate::config::GrpcConfig;
11use crate::service::{GrpcState, grpc_handler};
12
13pub struct GrpcPlugin {
14    config: GrpcConfig,
15}
16
17impl GrpcPlugin {
18    pub fn new(config: GrpcConfig) -> Self {
19        Self { config }
20    }
21}
22
23#[async_trait]
24impl ServerPlugin for GrpcPlugin {
25    fn name(&self) -> &'static str {
26        "grpc"
27    }
28
29    async fn run(&self, ctx: PluginContext) -> Result<()> {
30        let state = GrpcState {
31            executor: ctx.executor.clone(),
32        };
33
34        let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
35        let mut sd = ctx.shutdown.clone();
36
37        info!(listen = %self.config.listen, "gRPC server listening");
38
39        let mut routes = tonic::service::Routes::from(router);
40
41        // Enable reflection if proto files are specified
42        if !self.config.proto.is_empty() {
43            match build_reflection_service(&self.config.proto) {
44                Ok(encoded_fds) => {
45                    info!(proto_files = self.config.proto.len(), "gRPC reflection enabled");
46                    let reflection = tonic_reflection::server::Builder::configure()
47                        .register_encoded_file_descriptor_set(&encoded_fds)
48                        .build_v1()
49                        .context("build reflection service")?;
50                    Server::builder()
51                        .add_routes(routes)
52                        .add_service(reflection)
53                        .serve_with_shutdown(self.config.listen, async move {
54                            sd.changed().await.ok();
55                        })
56                        .await?;
57                    return Ok(());
58                },
59                Err(e) => {
60                    warn!(error = %e, "failed to build reflection; starting without it");
61                },
62            }
63        }
64
65        Server::builder()
66            .add_routes(routes)
67            .serve_with_shutdown(self.config.listen, async move {
68                sd.changed().await.ok();
69            })
70            .await?;
71
72        Ok(())
73    }
74
75    fn rpc_methods(&self) -> Vec<RpcMethodDef> {
76        vec![RpcMethodDef::new(
77            "grpc.services",
78            "list registered gRPC service names",
79        )]
80    }
81}
82
83/// Parse proto files and return encoded FileDescriptorSet for reflection.
84fn build_reflection_service(proto_files: &[String]) -> Result<Vec<u8>> {
85    let mut include_paths = Vec::new();
86    for proto in proto_files {
87        let path = std::path::absolute(Path::new(proto)).unwrap_or_else(|_| proto.into());
88        let mut dir = path.parent();
89        while let Some(d) = dir {
90            include_paths.push(d.to_path_buf());
91            for subdir in ["third_party", "vendor", "include"] {
92                let tp = d.join(subdir);
93                if tp.is_dir() {
94                    include_paths.push(tp.clone());
95                    for entry in std::fs::read_dir(&tp).into_iter().flatten().flatten() {
96                        if entry.path().is_dir() {
97                            include_paths.push(entry.path());
98                        }
99                    }
100                }
101            }
102            dir = d.parent();
103        }
104    }
105    include_paths.sort();
106    include_paths.dedup();
107
108    let mut compiler = protox::Compiler::new(include_paths.iter().map(|p| p.as_path()))
109        .context("create protox compiler")?;
110
111    for proto in proto_files {
112        let path = Path::new(proto);
113        let file_name = path.file_name().and_then(|n| n.to_str())
114            .context("invalid proto file name")?;
115        compiler.open_file(file_name)
116            .with_context(|| format!("compile {proto}"))?;
117    }
118
119    let fds = compiler.file_descriptor_set();
120    Ok(prost::Message::encode_to_vec(&fds))
121}