use std::path::Path;
use anyhow::{Context, Result};
use async_trait::async_trait;
use axum::Router;
use folk_api::{PluginContext, RpcMethodDef, ServerPlugin};
use tonic::transport::Server;
use tracing::{info, warn};
use crate::config::GrpcConfig;
use crate::service::{GrpcState, grpc_handler};
/// Server plugin that exposes the executor over gRPC.
///
/// Constructed with [`GrpcPlugin::new`]; the server itself is started by the
/// `ServerPlugin::run` implementation.
pub struct GrpcPlugin {
/// Plugin configuration; `run` reads its `listen` and `proto` fields.
config: GrpcConfig,
}
impl GrpcPlugin {
pub fn new(config: GrpcConfig) -> Self {
Self { config }
}
}
#[async_trait]
impl ServerPlugin for GrpcPlugin {
    /// Returns the fixed plugin identifier, `"grpc"`.
    fn name(&self) -> &'static str {
        "grpc"
    }

    /// Serves gRPC on `config.listen` until the shutdown signal fires.
    ///
    /// Every request is routed to `grpc_handler` through an axum fallback
    /// route. When `config.proto` is non-empty, the listed files are compiled
    /// and a v1 reflection service is registered next to the fallback route.
    /// If compiling the proto files fails, a warning is logged and the server
    /// starts without reflection.
    ///
    /// # Errors
    ///
    /// Returns an error if the reflection service cannot be built from the
    /// compiled descriptor set, or if the tonic server itself fails.
    async fn run(&self, ctx: PluginContext) -> Result<()> {
        let state = GrpcState {
            executor: ctx.executor.clone(),
        };
        let app: axum::Router = Router::new().fallback(grpc_handler).with_state(state);
        let mut shutdown = ctx.shutdown.clone();
        info!(listen = %self.config.listen, "gRPC server listening");
        let routes = tonic::service::Routes::from(app);

        // Compile the descriptor set first; a failure here only disables reflection.
        let descriptor_set = if self.config.proto.is_empty() {
            None
        } else {
            match build_reflection_service(&self.config.proto) {
                Ok(encoded_fds) => Some(encoded_fds),
                Err(e) => {
                    warn!(error = %e, "failed to build reflection; starting without it");
                    None
                }
            }
        };

        // A failure to build the reflection service itself is fatal.
        let reflection = match &descriptor_set {
            Some(encoded_fds) => {
                info!(proto_files = self.config.proto.len(), "gRPC reflection enabled");
                let service = tonic_reflection::server::Builder::configure()
                    .register_encoded_file_descriptor_set(encoded_fds)
                    .build_v1()
                    .context("build reflection service")?;
                Some(service)
            }
            None => None,
        };

        let server = Server::builder().add_routes(routes);
        let server = match reflection {
            Some(service) => server.add_service(service),
            None => server,
        };

        // The server stops once `changed()` resolves; its result is deliberately ignored.
        let shutdown_signal = async move {
            shutdown.changed().await.ok();
        };
        server
            .serve_with_shutdown(self.config.listen, shutdown_signal)
            .await?;
        Ok(())
    }

    /// Lists the RPC methods this plugin advertises: a single `grpc.services` entry.
    fn rpc_methods(&self) -> Vec<RpcMethodDef> {
        let services = RpcMethodDef::new(
            "grpc.services",
            "list registered gRPC service names",
        );
        vec![services]
    }
}
/// Compiles `proto_files` with protox and returns the encoded `FileDescriptorSet`.
///
/// Include paths are derived from each proto file's location: every ancestor
/// directory, plus any `third_party`, `vendor` or `include` directory found in
/// an ancestor, plus the immediate subdirectories of those.
///
/// Include paths are searched in order, so they are kept nearest-first: the
/// proto file's own directory comes before outer ancestors. Sorting the whole
/// list would put the filesystem root first and let a same-named file in an
/// outer directory shadow the intended one.
///
/// Each file is opened by its base name only and resolved through the include
/// paths. NOTE(review): two inputs with the same base name in different
/// directories therefore resolve to the same file — confirm this is acceptable.
///
/// # Errors
///
/// Returns an error if the compiler cannot be created, if a path has no valid
/// UTF-8 file name, or if any proto file fails to compile.
fn build_reflection_service(proto_files: &[String]) -> Result<Vec<u8>> {
    let mut include_paths: Vec<std::path::PathBuf> = Vec::new();
    let mut seen = std::collections::HashSet::new();
    // De-duplicate while preserving insertion order (first occurrence wins).
    let mut add_include = |dir: std::path::PathBuf| {
        if seen.insert(dir.clone()) {
            include_paths.push(dir);
        }
    };

    for proto in proto_files {
        let path = std::path::absolute(Path::new(proto)).unwrap_or_else(|_| proto.into());
        // `ancestors()` yields the path itself first; skip it to start at the parent.
        for dir in path.ancestors().skip(1) {
            add_include(dir.to_path_buf());
            for subdir in ["third_party", "vendor", "include"] {
                let vendored = dir.join(subdir);
                if !vendored.is_dir() {
                    continue;
                }
                add_include(vendored.clone());
                // Best effort: unreadable directories and entries are skipped.
                let mut children: Vec<_> = std::fs::read_dir(&vendored)
                    .into_iter()
                    .flatten()
                    .flatten()
                    .map(|entry| entry.path())
                    .filter(|child| child.is_dir())
                    .collect();
                // `read_dir` order is platform-dependent; sort for a stable search order.
                children.sort();
                for child in children {
                    add_include(child);
                }
            }
        }
    }

    let mut compiler = protox::Compiler::new(include_paths.iter().map(|p| p.as_path()))
        .context("create protox compiler")?;
    for proto in proto_files {
        let file_name = Path::new(proto)
            .file_name()
            .and_then(|n| n.to_str())
            .context("invalid proto file name")?;
        compiler
            .open_file(file_name)
            .with_context(|| format!("compile {proto}"))?;
    }

    let fds = compiler.file_descriptor_set();
    Ok(prost::Message::encode_to_vec(&fds))
}