folk_plugin_grpc/
plugin.rs1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use axum::Router;
7use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
8use tonic::transport::{Identity, Server, ServerTlsConfig};
9use tracing::{debug, info, warn};
10
11use crate::config::GrpcConfig;
12use crate::metrics::GrpcMetrics;
13use crate::service::{GrpcState, grpc_handler};
14
15pub struct GrpcPlugin {
16 config: GrpcConfig,
17}
18
19impl GrpcPlugin {
20 pub fn new(config: GrpcConfig) -> Self {
21 Self { config }
22 }
23}
24
25#[async_trait]
26impl ServerPlugin for GrpcPlugin {
27 fn name(&self) -> &'static str {
28 "grpc"
29 }
30
31 async fn run(&self, ctx: PluginContext) -> Result<()> {
32 let metrics = ctx
33 .metrics_registry
34 .as_ref()
35 .map(|r| GrpcMetrics::new(r.as_ref()));
36
37 if metrics.is_some() {
38 info!("gRPC metrics registered");
39 }
40
41 let state = GrpcState {
42 executor: ctx.executor.clone(),
43 max_recv_message_size: self.config.max_recv_message_size,
44 max_send_message_size: self.config.max_send_message_size,
45 compression: self.config.compression,
46 metrics,
47 };
48
49 let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
50 let mut sd = ctx.shutdown.clone();
51
52 info!(listen = %self.config.listen, "gRPC server listening");
53
54 let mut builder = if let Some(ref tls) = self.config.tls {
56 let cert = std::fs::read(&tls.cert)
57 .with_context(|| format!("read TLS cert: {}", tls.cert.display()))?;
58 let key = std::fs::read(&tls.key)
59 .with_context(|| format!("read TLS key: {}", tls.key.display()))?;
60 let identity = Identity::from_pem(cert, key);
61 let tls_config = ServerTlsConfig::new().identity(identity);
62 info!(cert = %tls.cert.display(), "gRPC TLS enabled");
63 Server::builder()
64 .tls_config(tls_config)
65 .context("TLS config")?
66 } else {
67 Server::builder()
68 };
69
70 if let Some(timeout) = self.config.timeout {
71 builder = builder.timeout(timeout);
72 info!(timeout = ?timeout, "gRPC server timeout configured");
73 }
74
75 if let Some(max_streams) = self.config.max_concurrent_streams {
76 builder = builder.max_concurrent_streams(max_streams);
77 }
78
79 if let Some(ref ka) = self.config.keepalive {
80 builder = builder
81 .http2_keepalive_interval(Some(ka.interval))
82 .http2_keepalive_timeout(Some(ka.timeout));
83 info!(interval = ?ka.interval, timeout = ?ka.timeout, "gRPC keepalive configured");
84 }
85
86 let mut router = builder.add_routes(tonic::service::Routes::from(router));
88
89 let (health_reporter, health_service) = tonic_health::server::health_reporter();
91 health_reporter
92 .set_service_status("", tonic_health::ServingStatus::Serving)
93 .await;
94 router = router.add_service(health_service);
95 info!("gRPC health checking enabled");
96
97 {
99 let proto_fds = if !self.config.proto.is_empty() {
100 match build_reflection_descriptors(&self.config.proto) {
101 Ok(fds) => {
102 info!(
103 proto_files = self.config.proto.len(),
104 "gRPC reflection enabled"
105 );
106 Some(fds)
107 }
108 Err(e) => {
109 warn!(error = %e, "failed to build proto reflection; continuing without it");
110 None
111 }
112 }
113 } else {
114 None
115 };
116
117 let mut reflection_builder = tonic_reflection::server::Builder::configure()
118 .register_encoded_file_descriptor_set(tonic_health::pb::FILE_DESCRIPTOR_SET);
119
120 if let Some(ref fds) = proto_fds {
121 reflection_builder = reflection_builder.register_encoded_file_descriptor_set(fds);
122 }
123
124 match reflection_builder.build_v1() {
125 Ok(reflection) => {
126 router = router.add_service(reflection);
127 }
128 Err(e) => {
129 warn!(error = %e, "failed to build reflection service");
130 }
131 }
132 }
133
134 router
136 .serve_with_shutdown(self.config.listen, async move {
137 sd.changed().await.ok();
138 })
139 .await?;
140
141 Ok(())
142 }
143
144 fn rpc_methods(&self) -> Vec<RpcMethodDef> {
145 vec![]
146 }
147}
148
149fn build_reflection_descriptors(proto_files: &[String]) -> Result<Vec<u8>> {
156 let cwd = std::env::current_dir().context("get cwd")?;
157 let mut include_paths = BTreeSet::new();
158
159 let mut visited = BTreeSet::new();
161 for proto in proto_files {
162 let abs = cwd.join(proto);
163 resolve_imports(&abs, &cwd, &mut include_paths, &mut visited);
164 }
165
166 for proto in proto_files {
168 let abs = cwd.join(proto);
169 if let Some(parent) = abs.parent() {
170 include_paths.insert(parent.to_path_buf());
171 }
172 }
173
174 debug!(paths = include_paths.len(), "resolved proto include paths");
175
176 let paths: Vec<_> = include_paths.iter().collect();
177 let mut compiler = protox::Compiler::new(paths.iter().map(|p| p.as_path()))
178 .context("create protox compiler")?;
179 compiler.include_imports(true);
180
181 for proto in proto_files {
182 let path = Path::new(proto);
183 let file_name = path
184 .file_name()
185 .and_then(|n| n.to_str())
186 .context("invalid proto file name")?;
187 compiler
188 .open_file(file_name)
189 .with_context(|| format!("compile {proto}"))?;
190 }
191
192 let fds = compiler.file_descriptor_set();
193 Ok(prost::Message::encode_to_vec(&fds))
194}
195
196fn resolve_imports(
200 proto_path: &Path,
201 project_root: &Path,
202 include_paths: &mut BTreeSet<PathBuf>,
203 visited: &mut BTreeSet<PathBuf>,
204) {
205 if !proto_path.is_file() || !visited.insert(proto_path.to_path_buf()) {
206 return;
207 }
208
209 let content = match std::fs::read_to_string(proto_path) {
210 Ok(c) => c,
211 Err(_) => return,
212 };
213
214 for line in content.lines() {
215 let line = line.trim();
216 if !line.starts_with("import ") {
219 continue;
220 }
221 let Some(start) = line.find('"') else {
222 continue;
223 };
224 let Some(end) = line[start + 1..].find('"') else {
225 continue;
226 };
227 let import_path = &line[start + 1..start + 1 + end];
228
229 if import_path.starts_with("google/protobuf/") {
231 continue;
232 }
233
234 if let Some(found) = find_best_match(project_root, import_path) {
236 if let Some(base) = found
238 .to_str()
239 .and_then(|f| f.strip_suffix(import_path))
240 .map(PathBuf::from)
241 {
242 if !base.as_os_str().is_empty() {
243 include_paths.insert(base);
244 }
245 } else if let Some(parent) = found.parent() {
246 include_paths.insert(parent.to_path_buf());
247 }
248
249 resolve_imports(&found, project_root, include_paths, visited);
250 }
251 }
252}
253
254fn find_best_match(root: &Path, relative: &str) -> Option<PathBuf> {
257 let mut candidates = Vec::new();
258
259 let direct = root.join(relative);
261 if direct.is_file() {
262 candidates.push(direct);
263 }
264
265 find_all_recursive(root, relative, &mut candidates);
266
267 if candidates.len() <= 1 {
268 return candidates.into_iter().next();
269 }
270
271 candidates
273 .into_iter()
274 .max_by_key(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
275}
276
/// Appends to `results` every existing file `<d>/relative`, where `<d>` is
/// `dir` or any directory beneath it. Paths already in `results` are not
/// added again.
///
/// Subdirectories whose name starts with `.` are not descended into.
/// Symlinked directories are followed, but each directory is scanned at most
/// once (keyed by its canonical path), so symlink loops cannot cause unbounded
/// recursion. Unreadable directories are skipped silently.
fn find_all_recursive(dir: &Path, relative: &str, results: &mut Vec<PathBuf>) {
    let mut seen_dirs = BTreeSet::new();
    walk_for_relative(dir, relative, results, &mut seen_dirs);
}

/// Depth-first worker for `find_all_recursive`; `seen_dirs` holds the
/// canonical paths of directories already scanned.
fn walk_for_relative(
    dir: &Path,
    relative: &str,
    results: &mut Vec<PathBuf>,
    seen_dirs: &mut BTreeSet<PathBuf>,
) {
    // Canonicalize so a symlink back to an ancestor is recognized as already
    // visited; fall back to the literal path if canonicalization fails.
    let key = dir.canonicalize().unwrap_or_else(|_| dir.to_path_buf());
    if !seen_dirs.insert(key) {
        return;
    }

    let candidate = dir.join(relative);
    if candidate.is_file() && !results.contains(&candidate) {
        results.push(candidate);
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.flatten() {
        let path = entry.path();
        let hidden = entry
            .file_name()
            .to_str()
            .is_some_and(|name| name.starts_with('.'));
        if path.is_dir() && !hidden {
            walk_for_relative(&path, relative, results, seen_dirs);
        }
    }
}