folk_plugin_grpc/
plugin.rs1use std::path::Path;
2
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use axum::Router;
6use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
7use tonic::transport::Server;
8use tracing::{info, warn};
9
10use crate::config::GrpcConfig;
11use crate::service::{GrpcState, grpc_handler};
12
13pub struct GrpcPlugin {
14 config: GrpcConfig,
15}
16
17impl GrpcPlugin {
18 pub fn new(config: GrpcConfig) -> Self {
19 Self { config }
20 }
21}
22
23#[async_trait]
24impl ServerPlugin for GrpcPlugin {
25 fn name(&self) -> &'static str {
26 "grpc"
27 }
28
29 async fn run(&self, ctx: PluginContext) -> Result<()> {
30 let state = GrpcState {
31 executor: ctx.executor.clone(),
32 };
33
34 let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
35 let mut sd = ctx.shutdown.clone();
36
37 info!(listen = %self.config.listen, "gRPC server listening");
38
39 let mut routes = tonic::service::Routes::from(router);
40
41 if !self.config.proto.is_empty() {
43 match build_reflection_service(&self.config.proto) {
44 Ok(encoded_fds) => {
45 info!(proto_files = self.config.proto.len(), "gRPC reflection enabled");
46 let reflection = tonic_reflection::server::Builder::configure()
47 .register_encoded_file_descriptor_set(&encoded_fds)
48 .build_v1()
49 .context("build reflection service")?;
50 Server::builder()
51 .add_routes(routes)
52 .add_service(reflection)
53 .serve_with_shutdown(self.config.listen, async move {
54 sd.changed().await.ok();
55 })
56 .await?;
57 return Ok(());
58 },
59 Err(e) => {
60 warn!(error = %e, "failed to build reflection; starting without it");
61 },
62 }
63 }
64
65 Server::builder()
66 .add_routes(routes)
67 .serve_with_shutdown(self.config.listen, async move {
68 sd.changed().await.ok();
69 })
70 .await?;
71
72 Ok(())
73 }
74
75 fn rpc_methods(&self) -> Vec<RpcMethodDef> {
76 vec![RpcMethodDef::new(
77 "grpc.services",
78 "list registered gRPC service names",
79 )]
80 }
81}
82
83fn build_reflection_service(proto_files: &[String]) -> Result<Vec<u8>> {
85 let mut include_paths = Vec::new();
86 for proto in proto_files {
87 let path = std::path::absolute(Path::new(proto)).unwrap_or_else(|_| proto.into());
88 let mut dir = path.parent();
89 while let Some(d) = dir {
90 include_paths.push(d.to_path_buf());
91 for subdir in ["third_party", "vendor", "include"] {
92 let tp = d.join(subdir);
93 if tp.is_dir() {
94 include_paths.push(tp.clone());
95 for entry in std::fs::read_dir(&tp).into_iter().flatten().flatten() {
96 if entry.path().is_dir() {
97 include_paths.push(entry.path());
98 }
99 }
100 }
101 }
102 dir = d.parent();
103 }
104 }
105 include_paths.sort();
106 include_paths.dedup();
107
108 let mut compiler = protox::Compiler::new(include_paths.iter().map(|p| p.as_path()))
109 .context("create protox compiler")?;
110
111 for proto in proto_files {
112 let path = Path::new(proto);
113 let file_name = path.file_name().and_then(|n| n.to_str())
114 .context("invalid proto file name")?;
115 compiler.open_file(file_name)
116 .with_context(|| format!("compile {proto}"))?;
117 }
118
119 let fds = compiler.file_descriptor_set();
120 Ok(prost::Message::encode_to_vec(&fds))
121}