use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use async_trait::async_trait;
use axum::Router;
use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
use tonic::transport::{Identity, Server, ServerTlsConfig};
use tracing::{debug, info, warn};
use crate::config::GrpcConfig;
use crate::metrics::GrpcMetrics;
use crate::service::{GrpcState, grpc_handler};
/// Server plugin that serves gRPC traffic according to a [`GrpcConfig`].
pub struct GrpcPlugin {
/// Settings consulted when the plugin runs (listen address, optional TLS,
/// timeout, stream limit, keepalive, message sizes, compression, proto files).
config: GrpcConfig,
}
impl GrpcPlugin {
pub fn new(config: GrpcConfig) -> Self {
Self { config }
}
}
#[async_trait]
impl ServerPlugin for GrpcPlugin {
    /// Returns the plugin identifier, `"grpc"`.
    fn name(&self) -> &'static str {
        "grpc"
    }

    /// Builds the gRPC server from the plugin configuration and serves it
    /// until the shutdown channel taken from `ctx` reports a change.
    ///
    /// The server is assembled in this order:
    /// 1. optional metrics (only when `ctx.metrics_registry` is present),
    /// 2. an axum router whose fallback is `grpc_handler`,
    /// 3. optional TLS, timeout, concurrent-stream limit and HTTP/2 keepalive,
    /// 4. the health service,
    /// 5. the reflection service (health descriptors plus, when configured
    ///    and compilable, the descriptors of `config.proto`).
    ///
    /// # Errors
    ///
    /// Returns an error when the TLS certificate or key cannot be read, when
    /// the TLS configuration is rejected, or when serving on the configured
    /// listen address fails. Reflection problems are logged as warnings and
    /// do not abort startup.
    async fn run(&self, ctx: PluginContext) -> Result<()> {
        let metrics = ctx
            .metrics_registry
            .as_ref()
            .map(|r| GrpcMetrics::new(r.as_ref()));
        if metrics.is_some() {
            info!("gRPC metrics registered");
        }

        let state = GrpcState {
            executor: ctx.executor.clone(),
            max_recv_message_size: self.config.max_recv_message_size,
            max_send_message_size: self.config.max_send_message_size,
            compression: self.config.compression,
            metrics,
        };
        // Every request not claimed by a registered service goes to `grpc_handler`.
        let router: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
        let mut sd = ctx.shutdown.clone();
        info!(listen = %self.config.listen, "gRPC server listening");

        let mut builder = if let Some(ref tls) = self.config.tls {
            let cert = std::fs::read(&tls.cert)
                .with_context(|| format!("read TLS cert: {}", tls.cert.display()))?;
            let key = std::fs::read(&tls.key)
                .with_context(|| format!("read TLS key: {}", tls.key.display()))?;
            let identity = Identity::from_pem(cert, key);
            let tls_config = ServerTlsConfig::new().identity(identity);
            info!(cert = %tls.cert.display(), "gRPC TLS enabled");
            Server::builder()
                .tls_config(tls_config)
                .context("TLS config")?
        } else {
            Server::builder()
        };

        if let Some(timeout) = self.config.timeout {
            builder = builder.timeout(timeout);
            info!(timeout = ?timeout, "gRPC server timeout configured");
        }
        if let Some(max_streams) = self.config.max_concurrent_streams {
            builder = builder.max_concurrent_streams(max_streams);
        }
        if let Some(ref ka) = self.config.keepalive {
            builder = builder
                .http2_keepalive_interval(Some(ka.interval))
                .http2_keepalive_timeout(Some(ka.timeout));
            info!(interval = ?ka.interval, timeout = ?ka.timeout, "gRPC keepalive configured");
        }

        let mut router = builder.add_routes(tonic::service::Routes::from(router));

        // Bound to a named variable (not `_`) so the reporter is not dropped
        // until this function returns.
        let (_health_reporter, health_service) = tonic_health::server::health_reporter();
        router = router.add_service(health_service);
        info!("gRPC health checking enabled");

        // Scoped so the encoded descriptor bytes are released before serving starts.
        {
            let proto_fds = if !self.config.proto.is_empty() {
                match build_reflection_descriptors(&self.config.proto) {
                    Ok(fds) => {
                        info!(
                            proto_files = self.config.proto.len(),
                            "gRPC reflection enabled"
                        );
                        Some(fds)
                    }
                    Err(e) => {
                        warn!(error = %e, "failed to build proto reflection; continuing without it");
                        None
                    }
                }
            } else {
                None
            };

            // Health descriptors are always registered; user protos are added on top.
            let mut reflection_builder = tonic_reflection::server::Builder::configure()
                .register_encoded_file_descriptor_set(tonic_health::pb::FILE_DESCRIPTOR_SET);
            if let Some(ref fds) = proto_fds {
                reflection_builder = reflection_builder.register_encoded_file_descriptor_set(fds);
            }
            match reflection_builder.build_v1() {
                Ok(reflection) => {
                    router = router.add_service(reflection);
                }
                Err(e) => {
                    warn!(error = %e, "failed to build reflection service");
                }
            }
        }

        // Attach the listen address to transport failures, matching the
        // context added to every other fallible step above.
        router
            .serve_with_shutdown(self.config.listen, async move {
                sd.changed().await.ok();
            })
            .await
            .with_context(|| format!("serve gRPC on {}", self.config.listen))?;
        Ok(())
    }

    /// Returns the RPC methods contributed by this plugin; always empty.
    fn rpc_methods(&self) -> Vec<RpcMethodDef> {
        vec![]
    }
}
fn build_reflection_descriptors(proto_files: &[String]) -> Result<Vec<u8>> {
let cwd = std::env::current_dir().context("get cwd")?;
let mut include_paths = BTreeSet::new();
let mut visited = BTreeSet::new();
for proto in proto_files {
let abs = cwd.join(proto);
resolve_imports(&abs, &cwd, &mut include_paths, &mut visited);
}
for proto in proto_files {
let abs = cwd.join(proto);
if let Some(parent) = abs.parent() {
include_paths.insert(parent.to_path_buf());
}
}
debug!(paths = include_paths.len(), "resolved proto include paths");
let paths: Vec<_> = include_paths.iter().collect();
let mut compiler = protox::Compiler::new(paths.iter().map(|p| p.as_path()))
.context("create protox compiler")?;
compiler.include_imports(true);
for proto in proto_files {
let path = Path::new(proto);
let file_name = path
.file_name()
.and_then(|n| n.to_str())
.context("invalid proto file name")?;
compiler
.open_file(file_name)
.with_context(|| format!("compile {proto}"))?;
}
let fds = compiler.file_descriptor_set();
Ok(prost::Message::encode_to_vec(&fds))
}
/// Follows the `import "..."` statements of `proto_path` recursively and
/// records the directories needed to resolve them in `include_paths`.
///
/// * `proto_path` - file to scan; ignored unless it is an existing file that
///   has not been scanned yet.
/// * `project_root` - directory tree searched for each imported path.
/// * `include_paths` - receives, for every import that is found, the part of
///   the found path that precedes the import string (or the file's parent
///   when that prefix cannot be computed).
/// * `visited` - files already scanned; prevents rescanning and import cycles.
///
/// Imports under `google/protobuf/` are skipped. Unreadable files and
/// imports that cannot be located are silently ignored.
fn resolve_imports(
    proto_path: &Path,
    project_root: &Path,
    include_paths: &mut BTreeSet<PathBuf>,
    visited: &mut BTreeSet<PathBuf>,
) {
    // Order matters: a path is only marked as visited once it is known to be a file.
    if !proto_path.is_file() {
        return;
    }
    if !visited.insert(proto_path.to_path_buf()) {
        return;
    }
    let Ok(source) = std::fs::read_to_string(proto_path) else {
        return;
    };

    // Text between the first two double quotes of every `import ` line.
    let imports = source
        .lines()
        .map(str::trim)
        .filter(|stmt| stmt.starts_with("import "))
        .filter_map(|stmt| {
            let (_, after_open) = stmt.split_once('"')?;
            let (target, _) = after_open.split_once('"')?;
            Some(target)
        })
        .filter(|target| !target.starts_with("google/protobuf/"));

    for target in imports {
        let Some(resolved) = find_best_match(project_root, target) else {
            continue;
        };

        let prefix = resolved
            .to_str()
            .and_then(|full| full.strip_suffix(target));
        match prefix {
            // An empty prefix is not a usable include directory; skip it.
            Some(base) => {
                if !base.is_empty() {
                    include_paths.insert(PathBuf::from(base));
                }
            }
            None => {
                if let Some(parent) = resolved.parent() {
                    include_paths.insert(parent.to_path_buf());
                }
            }
        }

        resolve_imports(&resolved, project_root, include_paths, visited);
    }
}
/// Locates `relative` somewhere under `root` and returns the chosen file.
///
/// Candidates are `root/relative` itself plus every `<dir>/relative` found
/// while walking the tree. With zero or one candidate that result is
/// returned directly; otherwise the candidate with the largest file size
/// wins (a file whose metadata cannot be read counts as size 0).
fn find_best_match(root: &Path, relative: &str) -> Option<PathBuf> {
    let mut matches: Vec<PathBuf> = Vec::new();

    let at_root = root.join(relative);
    if at_root.is_file() {
        matches.push(at_root);
    }
    find_all_recursive(root, relative, &mut matches);

    match matches.len() {
        0 | 1 => matches.pop(),
        _ => matches.into_iter().max_by_key(|path| match std::fs::metadata(path) {
            Ok(meta) => meta.len(),
            Err(_) => 0,
        }),
    }
}
/// Appends to `results` every `<d>/relative` that is a file, where `<d>` is
/// `dir` or any directory below it.
///
/// Directories whose name starts with `.` are not descended into, paths
/// already present in `results` are not added again, and directories that
/// cannot be listed are skipped without error.
fn find_all_recursive(dir: &Path, relative: &str, results: &mut Vec<PathBuf>) {
    let here = dir.join(relative);
    if here.is_file() && !results.iter().any(|known| *known == here) {
        results.push(here);
    }

    let Ok(listing) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in listing.filter_map(Result::ok) {
        let child = entry.path();
        if !child.is_dir() {
            continue;
        }
        // Names that are not valid UTF-8 are treated as not hidden.
        let hidden = entry
            .file_name()
            .to_str()
            .map_or(false, |name| name.starts_with('.'));
        if !hidden {
            find_all_recursive(&child, relative, results);
        }
    }
}