folk_plugin_grpc/
plugin.rs1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use axum::Router;
7use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
8use tonic::transport::Server;
9use tracing::{debug, info, warn};
10
11use crate::config::GrpcConfig;
12use crate::service::{GrpcState, grpc_handler};
13
14pub struct GrpcPlugin {
15 config: GrpcConfig,
16}
17
18impl GrpcPlugin {
19 pub fn new(config: GrpcConfig) -> Self {
20 Self { config }
21 }
22}
23
24#[async_trait]
25impl ServerPlugin for GrpcPlugin {
26 fn name(&self) -> &'static str {
27 "grpc"
28 }
29
30 async fn run(&self, ctx: PluginContext) -> Result<()> {
31 let state = GrpcState {
32 executor: ctx.executor.clone(),
33 };
34
35 let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
36 let mut sd = ctx.shutdown.clone();
37
38 info!(listen = %self.config.listen, "gRPC server listening");
39
40 let routes = tonic::service::Routes::from(router);
41
42 if !self.config.proto.is_empty() {
43 match build_reflection_descriptors(&self.config.proto) {
44 Ok(encoded_fds) => {
45 info!(
46 proto_files = self.config.proto.len(),
47 "gRPC reflection enabled"
48 );
49 let reflection = tonic_reflection::server::Builder::configure()
50 .register_encoded_file_descriptor_set(&encoded_fds)
51 .build_v1()
52 .context("build reflection service")?;
53 Server::builder()
54 .add_routes(routes)
55 .add_service(reflection)
56 .serve_with_shutdown(self.config.listen, async move {
57 sd.changed().await.ok();
58 })
59 .await?;
60 return Ok(());
61 }
62 Err(e) => {
63 warn!(error = %e, "failed to build reflection; starting without it");
64 }
65 }
66 }
67
68 Server::builder()
69 .add_routes(routes)
70 .serve_with_shutdown(self.config.listen, async move {
71 sd.changed().await.ok();
72 })
73 .await?;
74
75 Ok(())
76 }
77
78 fn rpc_methods(&self) -> Vec<RpcMethodDef> {
79 vec![RpcMethodDef::new(
80 "grpc.services",
81 "list registered gRPC service names",
82 )]
83 }
84}
85
86fn build_reflection_descriptors(proto_files: &[String]) -> Result<Vec<u8>> {
93 let cwd = std::env::current_dir().context("get cwd")?;
94 let mut include_paths = BTreeSet::new();
95
96 let mut visited = BTreeSet::new();
98 for proto in proto_files {
99 let abs = cwd.join(proto);
100 resolve_imports(&abs, &cwd, &mut include_paths, &mut visited);
101 }
102
103 for proto in proto_files {
105 let abs = cwd.join(proto);
106 if let Some(parent) = abs.parent() {
107 include_paths.insert(parent.to_path_buf());
108 }
109 }
110
111 debug!(paths = include_paths.len(), "resolved proto include paths");
112
113 let paths: Vec<_> = include_paths.iter().collect();
114 let mut compiler = protox::Compiler::new(paths.iter().map(|p| p.as_path()))
115 .context("create protox compiler")?;
116 compiler.include_imports(true);
117
118 for proto in proto_files {
119 let path = Path::new(proto);
120 let file_name = path
121 .file_name()
122 .and_then(|n| n.to_str())
123 .context("invalid proto file name")?;
124 compiler
125 .open_file(file_name)
126 .with_context(|| format!("compile {proto}"))?;
127 }
128
129 let fds = compiler.file_descriptor_set();
130 Ok(prost::Message::encode_to_vec(&fds))
131}
132
133fn resolve_imports(
137 proto_path: &Path,
138 project_root: &Path,
139 include_paths: &mut BTreeSet<PathBuf>,
140 visited: &mut BTreeSet<PathBuf>,
141) {
142 if !proto_path.is_file() || !visited.insert(proto_path.to_path_buf()) {
143 return;
144 }
145
146 let content = match std::fs::read_to_string(proto_path) {
147 Ok(c) => c,
148 Err(_) => return,
149 };
150
151 for line in content.lines() {
152 let line = line.trim();
153 if !line.starts_with("import ") {
156 continue;
157 }
158 let Some(start) = line.find('"') else {
159 continue;
160 };
161 let Some(end) = line[start + 1..].find('"') else {
162 continue;
163 };
164 let import_path = &line[start + 1..start + 1 + end];
165
166 if import_path.starts_with("google/protobuf/") {
168 continue;
169 }
170
171 if let Some(found) = find_best_match(project_root, import_path) {
173 if let Some(base) = found
175 .to_str()
176 .and_then(|f| f.strip_suffix(import_path))
177 .map(PathBuf::from)
178 {
179 if !base.as_os_str().is_empty() {
180 include_paths.insert(base);
181 }
182 } else if let Some(parent) = found.parent() {
183 include_paths.insert(parent.to_path_buf());
184 }
185
186 resolve_imports(&found, project_root, include_paths, visited);
187 }
188 }
189}
190
191fn find_best_match(root: &Path, relative: &str) -> Option<PathBuf> {
194 let mut candidates = Vec::new();
195
196 let direct = root.join(relative);
198 if direct.is_file() {
199 candidates.push(direct);
200 }
201
202 find_all_recursive(root, relative, &mut candidates);
203
204 if candidates.len() <= 1 {
205 return candidates.into_iter().next();
206 }
207
208 candidates
210 .into_iter()
211 .max_by_key(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
212}
213
/// Appends to `results` every `<d>/relative` that is a file, for `dir` and
/// each of its descendant directories `<d>`.
///
/// Directories whose name starts with `.` are not descended into, and
/// directories that cannot be listed are skipped silently. Paths already
/// present in `results` are not added again.
///
/// Subdirectories are visited in sorted order, so the order of `results` is
/// a deterministic pre-order walk. `std::fs::read_dir` itself yields entries
/// in a platform- and filesystem-dependent order, which would otherwise leak
/// into any caller that breaks ties by position.
fn find_all_recursive(dir: &Path, relative: &str, results: &mut Vec<PathBuf>) {
    let candidate = dir.join(relative);
    if candidate.is_file() && !results.contains(&candidate) {
        results.push(candidate);
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    // Collect first so the directory handle is released before recursing and
    // the children can be put into a stable order.
    let mut subdirs: Vec<PathBuf> = entries
        .flatten()
        .filter(|entry| {
            // Skip hidden directories; names that are not valid UTF-8 are kept.
            !entry
                .file_name()
                .to_str()
                .is_some_and(|name| name.starts_with('.'))
        })
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .collect();
    subdirs.sort_unstable();

    for subdir in &subdirs {
        find_all_recursive(subdir, relative, results);
    }
}