Skip to main content

folk_plugin_grpc/
plugin.rs

1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use axum::Router;
7use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
8use tonic::transport::{Identity, Server, ServerTlsConfig};
9use tracing::{debug, info, warn};
10
11use crate::config::GrpcConfig;
12use crate::metrics::GrpcMetrics;
13use crate::service::{GrpcState, grpc_handler};
14
15pub struct GrpcPlugin {
16    config: GrpcConfig,
17}
18
19impl GrpcPlugin {
20    pub fn new(config: GrpcConfig) -> Self {
21        Self { config }
22    }
23}
24
25#[async_trait]
26impl ServerPlugin for GrpcPlugin {
27    fn name(&self) -> &'static str {
28        "grpc"
29    }
30
31    async fn run(&self, ctx: PluginContext) -> Result<()> {
32        let metrics = ctx
33            .metrics_registry
34            .as_ref()
35            .map(|r| GrpcMetrics::new(r.as_ref()));
36
37        if metrics.is_some() {
38            info!("gRPC metrics registered");
39        }
40
41        let state = GrpcState {
42            executor: ctx.executor.clone(),
43            max_recv_message_size: self.config.max_recv_message_size,
44            max_send_message_size: self.config.max_send_message_size,
45            compression: self.config.compression,
46            metrics,
47        };
48
49        let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
50        let mut sd = ctx.shutdown.clone();
51
52        info!(listen = %self.config.listen, "gRPC server listening");
53
54        // --- Server builder ---
55        let mut builder = if let Some(ref tls) = self.config.tls {
56            let cert = std::fs::read(&tls.cert)
57                .with_context(|| format!("read TLS cert: {}", tls.cert.display()))?;
58            let key = std::fs::read(&tls.key)
59                .with_context(|| format!("read TLS key: {}", tls.key.display()))?;
60            let identity = Identity::from_pem(cert, key);
61            let tls_config = ServerTlsConfig::new().identity(identity);
62            info!(cert = %tls.cert.display(), "gRPC TLS enabled");
63            Server::builder()
64                .tls_config(tls_config)
65                .context("TLS config")?
66        } else {
67            Server::builder()
68        };
69
70        if let Some(timeout) = self.config.timeout {
71            builder = builder.timeout(timeout);
72            info!(timeout = ?timeout, "gRPC server timeout configured");
73        }
74
75        if let Some(max_streams) = self.config.max_concurrent_streams {
76            builder = builder.max_concurrent_streams(max_streams);
77        }
78
79        if let Some(ref ka) = self.config.keepalive {
80            builder = builder
81                .http2_keepalive_interval(Some(ka.interval))
82                .http2_keepalive_timeout(Some(ka.timeout));
83            info!(interval = ?ka.interval, timeout = ?ka.timeout, "gRPC keepalive configured");
84        }
85
86        // --- Accumulate services ---
87        let mut router = builder.add_routes(tonic::service::Routes::from(router));
88
89        // Health checking (grpc.health.v1)
90        let (health_reporter, health_service) = tonic_health::server::health_reporter();
91        health_reporter
92            .set_service_status("", tonic_health::ServingStatus::Serving)
93            .await;
94        router = router.add_service(health_service);
95        info!("gRPC health checking enabled");
96
97        // Reflection (always includes health service descriptor)
98        {
99            let proto_fds = if !self.config.proto.is_empty() {
100                match build_reflection_descriptors(&self.config.proto) {
101                    Ok(fds) => {
102                        info!(
103                            proto_files = self.config.proto.len(),
104                            "gRPC reflection enabled"
105                        );
106                        Some(fds)
107                    }
108                    Err(e) => {
109                        warn!(error = %e, "failed to build proto reflection; continuing without it");
110                        None
111                    }
112                }
113            } else {
114                None
115            };
116
117            let mut reflection_builder = tonic_reflection::server::Builder::configure()
118                .register_encoded_file_descriptor_set(tonic_health::pb::FILE_DESCRIPTOR_SET);
119
120            if let Some(ref fds) = proto_fds {
121                reflection_builder = reflection_builder.register_encoded_file_descriptor_set(fds);
122            }
123
124            match reflection_builder.build_v1() {
125                Ok(reflection) => {
126                    router = router.add_service(reflection);
127                }
128                Err(e) => {
129                    warn!(error = %e, "failed to build reflection service");
130                }
131            }
132        }
133
134        // --- Serve ---
135        router
136            .serve_with_shutdown(self.config.listen, async move {
137                sd.changed().await.ok();
138            })
139            .await?;
140
141        Ok(())
142    }
143
144    fn rpc_methods(&self) -> Vec<RpcMethodDef> {
145        vec![]
146    }
147}
148
149/// Compile proto files and return encoded FileDescriptorSet.
150///
151/// Import paths are resolved automatically:
152/// - Parses `import` statements from each proto file
153/// - Searches for imported files within the project root (cwd)
154/// - Never searches outside the working directory
155fn build_reflection_descriptors(proto_files: &[String]) -> Result<Vec<u8>> {
156    let cwd = std::env::current_dir().context("get cwd")?;
157    let mut include_paths = BTreeSet::new();
158
159    // Resolve imports starting from each proto file
160    let mut visited = BTreeSet::new();
161    for proto in proto_files {
162        let abs = cwd.join(proto);
163        resolve_imports(&abs, &cwd, &mut include_paths, &mut visited);
164    }
165
166    // Always include the directory of each proto file itself
167    for proto in proto_files {
168        let abs = cwd.join(proto);
169        if let Some(parent) = abs.parent() {
170            include_paths.insert(parent.to_path_buf());
171        }
172    }
173
174    debug!(paths = include_paths.len(), "resolved proto include paths");
175
176    let paths: Vec<_> = include_paths.iter().collect();
177    let mut compiler = protox::Compiler::new(paths.iter().map(|p| p.as_path()))
178        .context("create protox compiler")?;
179    compiler.include_imports(true);
180
181    for proto in proto_files {
182        let path = Path::new(proto);
183        let file_name = path
184            .file_name()
185            .and_then(|n| n.to_str())
186            .context("invalid proto file name")?;
187        compiler
188            .open_file(file_name)
189            .with_context(|| format!("compile {proto}"))?;
190    }
191
192    let fds = compiler.file_descriptor_set();
193    Ok(prost::Message::encode_to_vec(&fds))
194}
195
196/// Parse a proto file for `import` statements and find each imported file
197/// within the project root. Adds the containing directory to include_paths.
198/// Recurses into found imports.
199fn resolve_imports(
200    proto_path: &Path,
201    project_root: &Path,
202    include_paths: &mut BTreeSet<PathBuf>,
203    visited: &mut BTreeSet<PathBuf>,
204) {
205    if !proto_path.is_file() || !visited.insert(proto_path.to_path_buf()) {
206        return;
207    }
208
209    let content = match std::fs::read_to_string(proto_path) {
210        Ok(c) => c,
211        Err(_) => return,
212    };
213
214    for line in content.lines() {
215        let line = line.trim();
216        // Match: import "path/to/file.proto";
217        // Match: import public "path/to/file.proto";
218        if !line.starts_with("import ") {
219            continue;
220        }
221        let Some(start) = line.find('"') else {
222            continue;
223        };
224        let Some(end) = line[start + 1..].find('"') else {
225            continue;
226        };
227        let import_path = &line[start + 1..start + 1 + end];
228
229        // Skip well-known google types — protox has them built-in
230        if import_path.starts_with("google/protobuf/") {
231            continue;
232        }
233
234        // Find the best match for this import within the project
235        if let Some(found) = find_best_match(project_root, import_path) {
236            // Compute the include base: found = base / import_path
237            if let Some(base) = found
238                .to_str()
239                .and_then(|f| f.strip_suffix(import_path))
240                .map(PathBuf::from)
241            {
242                if !base.as_os_str().is_empty() {
243                    include_paths.insert(base);
244                }
245            } else if let Some(parent) = found.parent() {
246                include_paths.insert(parent.to_path_buf());
247            }
248
249            resolve_imports(&found, project_root, include_paths, visited);
250        }
251    }
252}
253
254/// Find the best match for a proto import within the project tree.
255/// If multiple files match, picks the largest one (avoids stubs/overlays).
256fn find_best_match(root: &Path, relative: &str) -> Option<PathBuf> {
257    let mut candidates = Vec::new();
258
259    // Direct from root
260    let direct = root.join(relative);
261    if direct.is_file() {
262        candidates.push(direct);
263    }
264
265    find_all_recursive(root, relative, &mut candidates);
266
267    if candidates.len() <= 1 {
268        return candidates.into_iter().next();
269    }
270
271    // Multiple matches — pick the largest file (real source, not a stub)
272    candidates
273        .into_iter()
274        .max_by_key(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
275}
276
/// Appends to `results` every existing file `<d>/<relative>`, where `<d>` is
/// `dir` itself or any non-hidden directory below it.
///
/// Guarantees:
/// - Paths already present in `results` are not added again.
/// - Directories whose name starts with `.` are not entered.
/// - Subdirectories are visited in sorted order, so the order of `results`
///   does not depend on the filesystem's directory iteration order.
/// - Symlinked directories are followed only while they resolve to a location
///   inside `dir`, and each real directory is visited once. Symlink cycles
///   therefore terminate and the walk cannot leave the tree.
///
/// Directories that cannot be resolved or read are skipped silently.
fn find_all_recursive(dir: &Path, relative: &str, results: &mut Vec<PathBuf>) {
    // The canonical root is the boundary every visited directory must stay within.
    let Ok(root) = dir.canonicalize() else {
        return;
    };
    let mut seen = BTreeSet::new();
    walk_for_matches(dir, &root, relative, results, &mut seen);
}

/// Recursive worker for `find_all_recursive`; `seen` holds the canonical paths
/// of the directories visited so far.
fn walk_for_matches(
    dir: &Path,
    root: &Path,
    relative: &str,
    results: &mut Vec<PathBuf>,
    seen: &mut BTreeSet<PathBuf>,
) {
    let Ok(real_dir) = dir.canonicalize() else {
        return;
    };
    // Skip directories outside the root (reached through a symlink) and
    // directories already visited through another path.
    if !real_dir.starts_with(root) || !seen.insert(real_dir) {
        return;
    }

    let candidate = dir.join(relative);
    if candidate.is_file() && !results.contains(&candidate) {
        results.push(candidate);
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    // Only hidden directories are skipped by name; names that are not valid
    // UTF-8 are still entered.
    let mut subdirs: Vec<PathBuf> = entries
        .flatten()
        .filter(|entry| {
            !entry
                .file_name()
                .to_str()
                .is_some_and(|name| name.starts_with('.'))
        })
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .collect();
    subdirs.sort();

    for subdir in &subdirs {
        walk_for_matches(subdir, root, relative, results, seen);
    }
}