Skip to main content

folk_plugin_grpc/
plugin.rs

1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use axum::Router;
7use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
8use tonic::transport::Server;
9use tracing::{debug, info, warn};
10
11use crate::config::GrpcConfig;
12use crate::service::{GrpcState, grpc_handler};
13
14pub struct GrpcPlugin {
15    config: GrpcConfig,
16}
17
18impl GrpcPlugin {
19    pub fn new(config: GrpcConfig) -> Self {
20        Self { config }
21    }
22}
23
24#[async_trait]
25impl ServerPlugin for GrpcPlugin {
26    fn name(&self) -> &'static str {
27        "grpc"
28    }
29
30    async fn run(&self, ctx: PluginContext) -> Result<()> {
31        let state = GrpcState {
32            executor: ctx.executor.clone(),
33        };
34
35        let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
36        let mut sd = ctx.shutdown.clone();
37
38        info!(listen = %self.config.listen, "gRPC server listening");
39
40        let routes = tonic::service::Routes::from(router);
41
42        if !self.config.proto.is_empty() {
43            match build_reflection_descriptors(&self.config.proto) {
44                Ok(encoded_fds) => {
45                    info!(
46                        proto_files = self.config.proto.len(),
47                        "gRPC reflection enabled"
48                    );
49                    let reflection = tonic_reflection::server::Builder::configure()
50                        .register_encoded_file_descriptor_set(&encoded_fds)
51                        .build_v1()
52                        .context("build reflection service")?;
53                    Server::builder()
54                        .add_routes(routes)
55                        .add_service(reflection)
56                        .serve_with_shutdown(self.config.listen, async move {
57                            sd.changed().await.ok();
58                        })
59                        .await?;
60                    return Ok(());
61                }
62                Err(e) => {
63                    warn!(error = %e, "failed to build reflection; starting without it");
64                }
65            }
66        }
67
68        Server::builder()
69            .add_routes(routes)
70            .serve_with_shutdown(self.config.listen, async move {
71                sd.changed().await.ok();
72            })
73            .await?;
74
75        Ok(())
76    }
77
78    fn rpc_methods(&self) -> Vec<RpcMethodDef> {
79        vec![RpcMethodDef::new(
80            "grpc.services",
81            "list registered gRPC service names",
82        )]
83    }
84}
85
86/// Compile proto files and return encoded FileDescriptorSet.
87///
88/// Import paths are resolved automatically:
89/// - Parses `import` statements from each proto file
90/// - Searches for imported files within the project root (cwd)
91/// - Never searches outside the working directory
92fn build_reflection_descriptors(proto_files: &[String]) -> Result<Vec<u8>> {
93    let cwd = std::env::current_dir().context("get cwd")?;
94    let mut include_paths = BTreeSet::new();
95
96    // Resolve imports starting from each proto file
97    let mut visited = BTreeSet::new();
98    for proto in proto_files {
99        let abs = cwd.join(proto);
100        resolve_imports(&abs, &cwd, &mut include_paths, &mut visited);
101    }
102
103    // Always include the directory of each proto file itself
104    for proto in proto_files {
105        let abs = cwd.join(proto);
106        if let Some(parent) = abs.parent() {
107            include_paths.insert(parent.to_path_buf());
108        }
109    }
110
111    debug!(paths = include_paths.len(), "resolved proto include paths");
112
113    let paths: Vec<_> = include_paths.iter().collect();
114    let mut compiler = protox::Compiler::new(paths.iter().map(|p| p.as_path()))
115        .context("create protox compiler")?;
116    compiler.include_imports(true);
117
118    for proto in proto_files {
119        let path = Path::new(proto);
120        let file_name = path
121            .file_name()
122            .and_then(|n| n.to_str())
123            .context("invalid proto file name")?;
124        compiler
125            .open_file(file_name)
126            .with_context(|| format!("compile {proto}"))?;
127    }
128
129    let fds = compiler.file_descriptor_set();
130    Ok(prost::Message::encode_to_vec(&fds))
131}
132
133/// Parse a proto file for `import` statements and find each imported file
134/// within the project root. Adds the containing directory to include_paths.
135/// Recurses into found imports.
136fn resolve_imports(
137    proto_path: &Path,
138    project_root: &Path,
139    include_paths: &mut BTreeSet<PathBuf>,
140    visited: &mut BTreeSet<PathBuf>,
141) {
142    if !proto_path.is_file() || !visited.insert(proto_path.to_path_buf()) {
143        return;
144    }
145
146    let content = match std::fs::read_to_string(proto_path) {
147        Ok(c) => c,
148        Err(_) => return,
149    };
150
151    for line in content.lines() {
152        let line = line.trim();
153        // Match: import "path/to/file.proto";
154        // Match: import public "path/to/file.proto";
155        if !line.starts_with("import ") {
156            continue;
157        }
158        let Some(start) = line.find('"') else {
159            continue;
160        };
161        let Some(end) = line[start + 1..].find('"') else {
162            continue;
163        };
164        let import_path = &line[start + 1..start + 1 + end];
165
166        // Skip well-known google types — protox has them built-in
167        if import_path.starts_with("google/protobuf/") {
168            continue;
169        }
170
171        // Find the best match for this import within the project
172        if let Some(found) = find_best_match(project_root, import_path) {
173            // Compute the include base: found = base / import_path
174            if let Some(base) = found
175                .to_str()
176                .and_then(|f| f.strip_suffix(import_path))
177                .map(PathBuf::from)
178            {
179                if !base.as_os_str().is_empty() {
180                    include_paths.insert(base);
181                }
182            } else if let Some(parent) = found.parent() {
183                include_paths.insert(parent.to_path_buf());
184            }
185
186            resolve_imports(&found, project_root, include_paths, visited);
187        }
188    }
189}
190
/// Finds the file that the import path `relative` refers to somewhere under
/// `root`.
///
/// Every searched directory `d` (starting with `root` itself) is a candidate
/// include base: `d/relative` is a match if it is a file. When several files
/// match, the largest wins, on the assumption that smaller copies are stubs
/// or overlays. Ties are broken by preferring the shallower path, then the
/// lexicographically smaller one, so the result does not depend on the order
/// in which the filesystem lists directory entries.
///
/// Returns `None` when nothing matches.
fn find_best_match(root: &Path, relative: &str) -> Option<PathBuf> {
    let mut candidates = Vec::new();
    // The walk checks `root.join(relative)` before anything else, so the
    // direct match is always included and always comes first.
    find_all_recursive(root, relative, &mut candidates);

    if candidates.len() <= 1 {
        return candidates.pop();
    }

    candidates
        .into_iter()
        .map(|path| {
            let size = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
            (size, path)
        })
        .max_by(|(size_a, a), (size_b, b)| {
            size_a
                .cmp(size_b)
                // Fewer components ranks higher on a size tie...
                .then_with(|| b.components().count().cmp(&a.components().count()))
                // ...then the lexicographically smaller path.
                .then_with(|| b.cmp(a))
        })
        .map(|(_, path)| path)
}

/// Appends to `results` every existing file `dir/relative` (skipping
/// duplicates), then recurses into the subdirectories of `dir`.
///
/// Directories whose UTF-8 name starts with `.` are skipped. Symbolic links to
/// directories are not followed. A link pointing back up the tree would
/// otherwise recurse without end, and a link pointing elsewhere would let the
/// search leave `dir`. Unreadable directories and entries are ignored.
fn find_all_recursive(dir: &Path, relative: &str, results: &mut Vec<PathBuf>) {
    let candidate = dir.join(relative);
    if candidate.is_file() && !results.contains(&candidate) {
        results.push(candidate);
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    for entry in entries.flatten() {
        // `DirEntry::file_type` does not traverse symlinks, unlike `Path::is_dir`.
        let is_real_dir = entry.file_type().is_ok_and(|ft| ft.is_dir());
        if !is_real_dir {
            continue;
        }
        let is_hidden = entry
            .file_name()
            .to_str()
            .is_some_and(|n| n.starts_with('.'));
        if is_hidden {
            continue;
        }
        find_all_recursive(&entry.path(), relative, results);
    }
}