Skip to main content

folk_plugin_grpc/
plugin.rs

1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use axum::Router;
7use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
8use tonic::transport::{Identity, Server, ServerTlsConfig};
9use tracing::{debug, info, warn};
10
11use crate::config::GrpcConfig;
12use crate::metrics::GrpcMetrics;
13use crate::service::{GrpcState, grpc_handler};
14
15pub struct GrpcPlugin {
16    config: GrpcConfig,
17}
18
19impl GrpcPlugin {
20    pub fn new(config: GrpcConfig) -> Self {
21        Self { config }
22    }
23}
24
25#[async_trait]
26impl ServerPlugin for GrpcPlugin {
27    fn name(&self) -> &'static str {
28        "grpc"
29    }
30
31    async fn run(&self, ctx: PluginContext) -> Result<()> {
32        let metrics = ctx
33            .metrics_registry
34            .as_ref()
35            .map(|r| GrpcMetrics::new(r.as_ref()));
36
37        if metrics.is_some() {
38            info!("gRPC metrics registered");
39        }
40
41        let state = GrpcState {
42            executor: ctx.executor.clone(),
43            max_recv_message_size: self.config.max_recv_message_size,
44            max_send_message_size: self.config.max_send_message_size,
45            compression: self.config.compression,
46            metrics,
47        };
48
49        let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
50        let mut sd = ctx.shutdown.clone();
51
52        info!(listen = %self.config.listen, "gRPC server listening");
53
54        // --- Server builder ---
55        let mut builder = if let Some(ref tls) = self.config.tls {
56            let cert = std::fs::read(&tls.cert)
57                .with_context(|| format!("read TLS cert: {}", tls.cert.display()))?;
58            let key = std::fs::read(&tls.key)
59                .with_context(|| format!("read TLS key: {}", tls.key.display()))?;
60            let identity = Identity::from_pem(cert, key);
61            let tls_config = ServerTlsConfig::new().identity(identity);
62            info!(cert = %tls.cert.display(), "gRPC TLS enabled");
63            Server::builder()
64                .tls_config(tls_config)
65                .context("TLS config")?
66        } else {
67            Server::builder()
68        };
69
70        if let Some(timeout) = self.config.timeout {
71            builder = builder.timeout(timeout);
72            info!(timeout = ?timeout, "gRPC server timeout configured");
73        }
74
75        if let Some(max_streams) = self.config.max_concurrent_streams {
76            builder = builder.max_concurrent_streams(max_streams);
77        }
78
79        if let Some(ref ka) = self.config.keepalive {
80            builder = builder
81                .http2_keepalive_interval(Some(ka.interval))
82                .http2_keepalive_timeout(Some(ka.timeout));
83            info!(interval = ?ka.interval, timeout = ?ka.timeout, "gRPC keepalive configured");
84        }
85
86        // --- Accumulate services ---
87        let mut router = builder.add_routes(tonic::service::Routes::from(router));
88
89        // Health checking (grpc.health.v1)
90        let (_health_reporter, health_service) = tonic_health::server::health_reporter();
91        router = router.add_service(health_service);
92        info!("gRPC health checking enabled");
93
94        // Reflection (always includes health service descriptor)
95        {
96            let proto_fds = if !self.config.proto.is_empty() {
97                match build_reflection_descriptors(&self.config.proto) {
98                    Ok(fds) => {
99                        info!(
100                            proto_files = self.config.proto.len(),
101                            "gRPC reflection enabled"
102                        );
103                        Some(fds)
104                    }
105                    Err(e) => {
106                        warn!(error = %e, "failed to build proto reflection; continuing without it");
107                        None
108                    }
109                }
110            } else {
111                None
112            };
113
114            let mut reflection_builder = tonic_reflection::server::Builder::configure()
115                .register_encoded_file_descriptor_set(tonic_health::pb::FILE_DESCRIPTOR_SET);
116
117            if let Some(ref fds) = proto_fds {
118                reflection_builder = reflection_builder.register_encoded_file_descriptor_set(fds);
119            }
120
121            match reflection_builder.build_v1() {
122                Ok(reflection) => {
123                    router = router.add_service(reflection);
124                }
125                Err(e) => {
126                    warn!(error = %e, "failed to build reflection service");
127                }
128            }
129        }
130
131        // --- Serve ---
132        router
133            .serve_with_shutdown(self.config.listen, async move {
134                sd.changed().await.ok();
135            })
136            .await?;
137
138        Ok(())
139    }
140
141    fn rpc_methods(&self) -> Vec<RpcMethodDef> {
142        vec![]
143    }
144}
145
146/// Compile proto files and return encoded FileDescriptorSet.
147///
148/// Import paths are resolved automatically:
149/// - Parses `import` statements from each proto file
150/// - Searches for imported files within the project root (cwd)
151/// - Never searches outside the working directory
152fn build_reflection_descriptors(proto_files: &[String]) -> Result<Vec<u8>> {
153    let cwd = std::env::current_dir().context("get cwd")?;
154    let mut include_paths = BTreeSet::new();
155
156    // Resolve imports starting from each proto file
157    let mut visited = BTreeSet::new();
158    for proto in proto_files {
159        let abs = cwd.join(proto);
160        resolve_imports(&abs, &cwd, &mut include_paths, &mut visited);
161    }
162
163    // Always include the directory of each proto file itself
164    for proto in proto_files {
165        let abs = cwd.join(proto);
166        if let Some(parent) = abs.parent() {
167            include_paths.insert(parent.to_path_buf());
168        }
169    }
170
171    debug!(paths = include_paths.len(), "resolved proto include paths");
172
173    let paths: Vec<_> = include_paths.iter().collect();
174    let mut compiler = protox::Compiler::new(paths.iter().map(|p| p.as_path()))
175        .context("create protox compiler")?;
176    compiler.include_imports(true);
177
178    for proto in proto_files {
179        let path = Path::new(proto);
180        let file_name = path
181            .file_name()
182            .and_then(|n| n.to_str())
183            .context("invalid proto file name")?;
184        compiler
185            .open_file(file_name)
186            .with_context(|| format!("compile {proto}"))?;
187    }
188
189    let fds = compiler.file_descriptor_set();
190    Ok(prost::Message::encode_to_vec(&fds))
191}
192
193/// Parse a proto file for `import` statements and find each imported file
194/// within the project root. Adds the containing directory to include_paths.
195/// Recurses into found imports.
196fn resolve_imports(
197    proto_path: &Path,
198    project_root: &Path,
199    include_paths: &mut BTreeSet<PathBuf>,
200    visited: &mut BTreeSet<PathBuf>,
201) {
202    if !proto_path.is_file() || !visited.insert(proto_path.to_path_buf()) {
203        return;
204    }
205
206    let content = match std::fs::read_to_string(proto_path) {
207        Ok(c) => c,
208        Err(_) => return,
209    };
210
211    for line in content.lines() {
212        let line = line.trim();
213        // Match: import "path/to/file.proto";
214        // Match: import public "path/to/file.proto";
215        if !line.starts_with("import ") {
216            continue;
217        }
218        let Some(start) = line.find('"') else {
219            continue;
220        };
221        let Some(end) = line[start + 1..].find('"') else {
222            continue;
223        };
224        let import_path = &line[start + 1..start + 1 + end];
225
226        // Skip well-known google types — protox has them built-in
227        if import_path.starts_with("google/protobuf/") {
228            continue;
229        }
230
231        // Find the best match for this import within the project
232        if let Some(found) = find_best_match(project_root, import_path) {
233            // Compute the include base: found = base / import_path
234            if let Some(base) = found
235                .to_str()
236                .and_then(|f| f.strip_suffix(import_path))
237                .map(PathBuf::from)
238            {
239                if !base.as_os_str().is_empty() {
240                    include_paths.insert(base);
241                }
242            } else if let Some(parent) = found.parent() {
243                include_paths.insert(parent.to_path_buf());
244            }
245
246            resolve_imports(&found, project_root, include_paths, visited);
247        }
248    }
249}
250
251/// Find the best match for a proto import within the project tree.
252/// If multiple files match, picks the largest one (avoids stubs/overlays).
253fn find_best_match(root: &Path, relative: &str) -> Option<PathBuf> {
254    let mut candidates = Vec::new();
255
256    // Direct from root
257    let direct = root.join(relative);
258    if direct.is_file() {
259        candidates.push(direct);
260    }
261
262    find_all_recursive(root, relative, &mut candidates);
263
264    if candidates.len() <= 1 {
265        return candidates.into_iter().next();
266    }
267
268    // Multiple matches — pick the largest file (real source, not a stub)
269    candidates
270        .into_iter()
271        .max_by_key(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
272}
273
/// Collect every existing file `<d>/<relative>`, where `d` is `dir` itself
/// or any directory beneath it, appending each new match to `results`.
/// Paths already present in `results` are not added again.
///
/// Hidden directories (names starting with `.`) are not descended into, and
/// unreadable directories are skipped.
///
/// Directory symlinks are followed, but each physical directory is walked at
/// most once. Without that guard, a symlink cycle is re-entered until the OS
/// refuses to resolve the ever-longer path, re-walking the tree at every
/// level and reporting the same file under many different paths.
fn find_all_recursive(dir: &Path, relative: &str, results: &mut Vec<PathBuf>) {
    let mut seen_dirs = BTreeSet::new();
    walk_for_import(dir, relative, results, &mut seen_dirs);
}

/// Recursive worker for `find_all_recursive`; `seen_dirs` holds the
/// canonical path of every directory already walked.
fn walk_for_import(
    dir: &Path,
    relative: &str,
    results: &mut Vec<PathBuf>,
    seen_dirs: &mut BTreeSet<PathBuf>,
) {
    // If `dir` cannot be canonicalized (e.g. it does not exist), key it by the
    // path as given; `read_dir` below will typically fail and end this branch.
    let key = std::fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());
    if !seen_dirs.insert(key) {
        return;
    }

    let candidate = dir.join(relative);
    if candidate.is_file() && !results.contains(&candidate) {
        results.push(candidate);
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.flatten() {
        // Only hidden directories are skipped; everything else is searched.
        let hidden = entry
            .file_name()
            .to_str()
            .is_some_and(|n| n.starts_with('.'));
        let path = entry.path();
        if !hidden && path.is_dir() {
            walk_for_import(&path, relative, results, seen_dirs);
        }
    }
}