folk_plugin_grpc/
plugin.rs1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use axum::Router;
7use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
8use tonic::transport::{Identity, Server, ServerTlsConfig};
9use tracing::{debug, info, warn};
10
11use crate::config::GrpcConfig;
12use crate::metrics::GrpcMetrics;
13use crate::service::{GrpcState, grpc_handler};
14
15pub struct GrpcPlugin {
16 config: GrpcConfig,
17}
18
19impl GrpcPlugin {
20 pub fn new(config: GrpcConfig) -> Self {
21 Self { config }
22 }
23}
24
25#[async_trait]
26impl ServerPlugin for GrpcPlugin {
27 fn name(&self) -> &'static str {
28 "grpc"
29 }
30
31 async fn run(&self, ctx: PluginContext) -> Result<()> {
32 let metrics = ctx
33 .metrics_registry
34 .as_ref()
35 .map(|r| GrpcMetrics::new(r.as_ref()));
36
37 if metrics.is_some() {
38 info!("gRPC metrics registered");
39 }
40
41 let state = GrpcState {
42 executor: ctx.executor.clone(),
43 max_recv_message_size: self.config.max_recv_message_size,
44 max_send_message_size: self.config.max_send_message_size,
45 compression: self.config.compression,
46 metrics,
47 };
48
49 let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
50 let mut sd = ctx.shutdown.clone();
51
52 info!(listen = %self.config.listen, "gRPC server listening");
53
54 let mut builder = if let Some(ref tls) = self.config.tls {
56 let cert = std::fs::read(&tls.cert)
57 .with_context(|| format!("read TLS cert: {}", tls.cert.display()))?;
58 let key = std::fs::read(&tls.key)
59 .with_context(|| format!("read TLS key: {}", tls.key.display()))?;
60 let identity = Identity::from_pem(cert, key);
61 let tls_config = ServerTlsConfig::new().identity(identity);
62 info!(cert = %tls.cert.display(), "gRPC TLS enabled");
63 Server::builder()
64 .tls_config(tls_config)
65 .context("TLS config")?
66 } else {
67 Server::builder()
68 };
69
70 if let Some(timeout) = self.config.timeout {
71 builder = builder.timeout(timeout);
72 info!(timeout = ?timeout, "gRPC server timeout configured");
73 }
74
75 if let Some(max_streams) = self.config.max_concurrent_streams {
76 builder = builder.max_concurrent_streams(max_streams);
77 }
78
79 if let Some(ref ka) = self.config.keepalive {
80 builder = builder
81 .http2_keepalive_interval(Some(ka.interval))
82 .http2_keepalive_timeout(Some(ka.timeout));
83 info!(interval = ?ka.interval, timeout = ?ka.timeout, "gRPC keepalive configured");
84 }
85
86 let mut router = builder.add_routes(tonic::service::Routes::from(router));
88
89 let (_health_reporter, health_service) = tonic_health::server::health_reporter();
91 router = router.add_service(health_service);
92 info!("gRPC health checking enabled");
93
94 {
96 let proto_fds = if !self.config.proto.is_empty() {
97 match build_reflection_descriptors(&self.config.proto) {
98 Ok(fds) => {
99 info!(
100 proto_files = self.config.proto.len(),
101 "gRPC reflection enabled"
102 );
103 Some(fds)
104 }
105 Err(e) => {
106 warn!(error = %e, "failed to build proto reflection; continuing without it");
107 None
108 }
109 }
110 } else {
111 None
112 };
113
114 let mut reflection_builder = tonic_reflection::server::Builder::configure()
115 .register_encoded_file_descriptor_set(tonic_health::pb::FILE_DESCRIPTOR_SET);
116
117 if let Some(ref fds) = proto_fds {
118 reflection_builder = reflection_builder.register_encoded_file_descriptor_set(fds);
119 }
120
121 match reflection_builder.build_v1() {
122 Ok(reflection) => {
123 router = router.add_service(reflection);
124 }
125 Err(e) => {
126 warn!(error = %e, "failed to build reflection service");
127 }
128 }
129 }
130
131 router
133 .serve_with_shutdown(self.config.listen, async move {
134 sd.changed().await.ok();
135 })
136 .await?;
137
138 Ok(())
139 }
140
141 fn rpc_methods(&self) -> Vec<RpcMethodDef> {
142 vec![]
143 }
144}
145
146fn build_reflection_descriptors(proto_files: &[String]) -> Result<Vec<u8>> {
153 let cwd = std::env::current_dir().context("get cwd")?;
154 let mut include_paths = BTreeSet::new();
155
156 let mut visited = BTreeSet::new();
158 for proto in proto_files {
159 let abs = cwd.join(proto);
160 resolve_imports(&abs, &cwd, &mut include_paths, &mut visited);
161 }
162
163 for proto in proto_files {
165 let abs = cwd.join(proto);
166 if let Some(parent) = abs.parent() {
167 include_paths.insert(parent.to_path_buf());
168 }
169 }
170
171 debug!(paths = include_paths.len(), "resolved proto include paths");
172
173 let paths: Vec<_> = include_paths.iter().collect();
174 let mut compiler = protox::Compiler::new(paths.iter().map(|p| p.as_path()))
175 .context("create protox compiler")?;
176 compiler.include_imports(true);
177
178 for proto in proto_files {
179 let path = Path::new(proto);
180 let file_name = path
181 .file_name()
182 .and_then(|n| n.to_str())
183 .context("invalid proto file name")?;
184 compiler
185 .open_file(file_name)
186 .with_context(|| format!("compile {proto}"))?;
187 }
188
189 let fds = compiler.file_descriptor_set();
190 Ok(prost::Message::encode_to_vec(&fds))
191}
192
193fn resolve_imports(
197 proto_path: &Path,
198 project_root: &Path,
199 include_paths: &mut BTreeSet<PathBuf>,
200 visited: &mut BTreeSet<PathBuf>,
201) {
202 if !proto_path.is_file() || !visited.insert(proto_path.to_path_buf()) {
203 return;
204 }
205
206 let content = match std::fs::read_to_string(proto_path) {
207 Ok(c) => c,
208 Err(_) => return,
209 };
210
211 for line in content.lines() {
212 let line = line.trim();
213 if !line.starts_with("import ") {
216 continue;
217 }
218 let Some(start) = line.find('"') else {
219 continue;
220 };
221 let Some(end) = line[start + 1..].find('"') else {
222 continue;
223 };
224 let import_path = &line[start + 1..start + 1 + end];
225
226 if import_path.starts_with("google/protobuf/") {
228 continue;
229 }
230
231 if let Some(found) = find_best_match(project_root, import_path) {
233 if let Some(base) = found
235 .to_str()
236 .and_then(|f| f.strip_suffix(import_path))
237 .map(PathBuf::from)
238 {
239 if !base.as_os_str().is_empty() {
240 include_paths.insert(base);
241 }
242 } else if let Some(parent) = found.parent() {
243 include_paths.insert(parent.to_path_buf());
244 }
245
246 resolve_imports(&found, project_root, include_paths, visited);
247 }
248 }
249}
250
251fn find_best_match(root: &Path, relative: &str) -> Option<PathBuf> {
254 let mut candidates = Vec::new();
255
256 let direct = root.join(relative);
258 if direct.is_file() {
259 candidates.push(direct);
260 }
261
262 find_all_recursive(root, relative, &mut candidates);
263
264 if candidates.len() <= 1 {
265 return candidates.into_iter().next();
266 }
267
268 candidates
270 .into_iter()
271 .max_by_key(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
272}
273
/// Appends to `results` every file found at `<d>/<relative>`, where `<d>` is
/// `dir` or any of its descendant directories.
///
/// Directories whose name starts with `.` are not descended into, and paths
/// already present in `results` are not added again. Directories that cannot
/// be read are skipped silently.
///
/// Symlinked directories are still followed, but each real directory (by
/// canonical path) is visited at most once. A symlink that points back at an
/// ancestor therefore cannot cause the same tree to be walked repeatedly or the
/// same file to be reported under several ever-longer paths.
fn find_all_recursive(dir: &Path, relative: &str, results: &mut Vec<PathBuf>) {
    let mut visited_dirs = std::collections::HashSet::new();
    collect_matches(dir, relative, results, &mut visited_dirs);
}

/// Recursive worker for `find_all_recursive`; `visited_dirs` holds the
/// canonical paths of directories that have already been walked.
fn collect_matches(
    dir: &Path,
    relative: &str,
    results: &mut Vec<PathBuf>,
    visited_dirs: &mut std::collections::HashSet<PathBuf>,
) {
    // Canonicalizing resolves symlinks, so two routes to one directory compare
    // equal. If it fails (e.g. the path does not exist) the raw path is used.
    let identity = std::fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());
    if !visited_dirs.insert(identity) {
        return;
    }

    let candidate = dir.join(relative);
    if candidate.is_file() && !results.contains(&candidate) {
        results.push(candidate);
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    for entry in entries.flatten() {
        let path = entry.path();
        if !path.is_dir() {
            continue;
        }
        // Skip hidden directories such as `.git`.
        let hidden = entry
            .file_name()
            .to_str()
            .is_some_and(|name| name.starts_with('.'));
        if hidden {
            continue;
        }
        collect_matches(&path, relative, results, visited_dirs);
    }
}