use std::{net::SocketAddr, time::Duration};
use super::tls::TlsConfig;
use crate::shutdown::GracefulShutdown;
/// Errors produced while configuring or running the gRPC server.
///
/// Every variant carries a free-form detail message. The `Display` output is a
/// short category label followed by `": "` and that detail.
#[derive(Debug)]
pub enum GrpcServerError {
    /// Binding the listening address failed.
    Bind(String),
    /// A TLS-related failure.
    Tls(String),
    /// An error reported while the server was serving.
    Server(String),
    /// Invalid configuration (e.g. a reflection descriptor set that could not be built).
    Config(String),
}

impl std::fmt::Display for GrpcServerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the category label once, then format label and detail uniformly.
        let (category, detail) = match self {
            Self::Bind(detail) => ("Failed to bind", detail),
            Self::Tls(detail) => ("TLS error", detail),
            Self::Server(detail) => ("Server error", detail),
            Self::Config(detail) => ("Configuration error", detail),
        };
        write!(f, "{}: {}", category, detail)
    }
}

impl std::error::Error for GrpcServerError {}
/// Builder that collects gRPC server settings (address, TLS, reflection,
/// health checking, graceful shutdown).
///
/// The builder methods consume and return `self`, so discarding a result
/// silently loses the configuration. That is why the type is `#[must_use]`.
#[must_use = "a GrpcServerBuilder does nothing unless it is kept and used"]
pub struct GrpcServerBuilder {
    /// Socket address the server is served on.
    addr: SocketAddr,
    /// TLS configuration, or `None` when TLS has not been configured.
    tls_config: Option<TlsConfig>,
    /// Encoded file descriptor set used to build the reflection service, if enabled.
    reflection_descriptor: Option<&'static [u8]>,
    /// Whether a health reporter/service pair should be created.
    health_check: bool,
    /// Optional shutdown coordinator whose token stops `serve_router`.
    shutdown: Option<GracefulShutdown>,
    /// Configured shutdown timeout.
    /// NOTE(review): `serve_router` does not read this value — confirm where it is enforced.
    shutdown_timeout: Duration,
}

impl Default for GrpcServerBuilder {
    /// Equivalent to `GrpcServerBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}
impl GrpcServerBuilder {
pub fn new() -> Self {
Self {
addr: "[::1]:50051".parse().unwrap(),
tls_config: None,
reflection_descriptor: None,
health_check: false,
shutdown: None,
shutdown_timeout: Duration::from_secs(30),
}
}
pub fn addr(mut self, addr: impl Into<SocketAddrInput>) -> Self {
self.addr = addr.into().0;
self
}
pub fn port(mut self, port: u16) -> Self {
self.addr = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], port));
self
}
pub fn tls(mut self, config: TlsConfig) -> Self {
self.tls_config = Some(config);
self
}
pub fn tls_from_env(mut self) -> Self {
self.tls_config = TlsConfig::from_env();
self
}
pub fn reflection(mut self, file_descriptor_set: &'static [u8]) -> Self {
self.reflection_descriptor = Some(file_descriptor_set);
self
}
pub fn health_check(mut self) -> Self {
self.health_check = true;
self
}
pub fn graceful_shutdown(mut self, shutdown: GracefulShutdown) -> Self {
self.shutdown = Some(shutdown);
self
}
pub fn shutdown_timeout(mut self, timeout: Duration) -> Self {
self.shutdown_timeout = timeout;
self
}
pub fn get_addr(&self) -> SocketAddr {
self.addr
}
pub fn has_tls(&self) -> bool {
self.tls_config.is_some()
}
pub fn has_reflection(&self) -> bool {
self.reflection_descriptor.is_some()
}
pub fn has_health_check(&self) -> bool {
self.health_check
}
#[cfg(feature = "router-grpc")]
pub async fn serve_router(
self,
router: tonic::transport::server::Router,
) -> Result<(), GrpcServerError> {
if let Some(shutdown) = self.shutdown {
let mut token = shutdown.token();
router
.serve_with_shutdown(self.addr, async move {
token.cancelled().await;
})
.await
.map_err(|e| GrpcServerError::Server(e.to_string()))?;
} else {
router
.serve(self.addr)
.await
.map_err(|e| GrpcServerError::Server(e.to_string()))?;
}
Ok(())
}
#[cfg(feature = "router-grpc")]
pub fn server_builder(&self) -> tonic::transport::Server {
tonic::transport::Server::builder()
}
#[cfg(feature = "router-grpc")]
pub fn reflection_service(
&self,
) -> Result<
Option<
tonic_reflection::server::v1::ServerReflectionServer<
impl tonic_reflection::server::v1::ServerReflection,
>,
>,
GrpcServerError,
> {
if let Some(fds) = self.reflection_descriptor {
let service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(fds)
.build_v1()
.map_err(|e| GrpcServerError::Config(format!("Reflection setup failed: {}", e)))?;
Ok(Some(service))
} else {
Ok(None)
}
}
#[cfg(feature = "router-grpc")]
pub fn create_health_reporter(
&self,
) -> Option<(
tonic_health::server::HealthReporter,
impl tonic::codegen::Service<
hyper::Request<hyper::body::Incoming>,
Response = hyper::Response<tonic::body::Body>,
Error = std::convert::Infallible,
> + Clone
+ Send
+ 'static,
)> {
if self.health_check {
Some(tonic_health::server::health_reporter())
} else {
None
}
}
}
/// A socket address accepted by `GrpcServerBuilder::addr`.
///
/// Conversions:
/// - `From<SocketAddr>`
/// - `From<&str>` and `From<String>`, which panic on invalid input
/// - `From<(ip, port)>` for any `ip` convertible into [`std::net::IpAddr`]
///   (`[u8; 4]`, `[u16; 8]`, `[u8; 16]`, `Ipv4Addr`, `Ipv6Addr`, `IpAddr`)
///
/// Use the [`std::str::FromStr`] impl (`"…".parse::<SocketAddrInput>()`) to
/// handle invalid strings as an error instead of a panic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SocketAddrInput(SocketAddr);

impl From<SocketAddr> for SocketAddrInput {
    fn from(addr: SocketAddr) -> Self {
        Self(addr)
    }
}

impl std::str::FromStr for SocketAddrInput {
    type Err = std::net::AddrParseError;

    /// Parses a socket address such as `"127.0.0.1:9000"` or `"[::1]:50051"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse::<SocketAddr>().map(Self)
    }
}

impl From<&str> for SocketAddrInput {
    /// Parses `s` as a socket address.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid socket address" if `s` cannot be parsed.
    fn from(s: &str) -> Self {
        s.parse::<Self>().expect("Invalid socket address")
    }
}

impl From<String> for SocketAddrInput {
    /// Parses `s` as a socket address.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid socket address" if `s` cannot be parsed.
    fn from(s: String) -> Self {
        Self::from(s.as_str())
    }
}

// Mirrors std's `impl<I: Into<IpAddr>> From<(I, u16)> for SocketAddr`. This
// keeps literal tuples like `([127, 0, 0, 1], 8080)` inferring as before while
// also accepting `Ipv4Addr`/`Ipv6Addr`/`IpAddr` hosts.
impl<I: Into<std::net::IpAddr>> From<(I, u16)> for SocketAddrInput {
    fn from((ip, port): (I, u16)) -> Self {
        Self(SocketAddr::new(ip.into(), port))
    }
}
/// Alias for [`GrpcServerBuilder`].
pub type GrpcServer = GrpcServerBuilder;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_default() {
        let builder = GrpcServerBuilder::new();
        assert_eq!(builder.get_addr().port(), 50051);
        assert!(!builder.has_tls());
        assert!(!builder.has_reflection());
        assert!(!builder.has_health_check());
        assert!(builder.shutdown.is_none());
    }

    #[test]
    fn test_builder_default_values() {
        let builder = GrpcServerBuilder::new();
        let expected: SocketAddr = "[::1]:50051".parse().unwrap();
        assert_eq!(builder.get_addr(), expected);
        assert_eq!(builder.shutdown_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_default_matches_new() {
        let from_default = GrpcServerBuilder::default();
        let from_new = GrpcServerBuilder::new();
        assert_eq!(from_default.get_addr(), from_new.get_addr());
        assert_eq!(from_default.shutdown_timeout, from_new.shutdown_timeout);
        assert_eq!(from_default.has_tls(), from_new.has_tls());
        assert_eq!(from_default.has_health_check(), from_new.has_health_check());
    }

    #[test]
    fn test_builder_addr_string() {
        let builder = GrpcServerBuilder::new().addr("127.0.0.1:9000");
        assert_eq!(builder.get_addr().port(), 9000);
    }

    #[test]
    fn test_builder_addr_owned_string_and_socket_addr() {
        let owned = GrpcServerBuilder::new().addr(String::from("10.1.2.3:7000"));
        assert_eq!(owned.get_addr(), "10.1.2.3:7000".parse::<SocketAddr>().unwrap());

        let direct: SocketAddr = "[::]:7001".parse().unwrap();
        let builder = GrpcServerBuilder::new().addr(direct);
        assert_eq!(builder.get_addr(), direct);
    }

    #[test]
    fn test_builder_addr_tuples() {
        let v4 = GrpcServerBuilder::new().addr(([127, 0, 0, 1], 9001));
        assert_eq!(v4.get_addr(), "127.0.0.1:9001".parse::<SocketAddr>().unwrap());

        let v6 = GrpcServerBuilder::new().addr(([0, 0, 0, 0, 0, 0, 0, 1], 9002));
        assert_eq!(v6.get_addr(), "[::1]:9002".parse::<SocketAddr>().unwrap());
    }

    #[test]
    #[should_panic(expected = "Invalid socket address")]
    fn test_builder_addr_invalid_panics() {
        let _ = GrpcServerBuilder::new().addr("not an address");
    }

    #[test]
    fn test_builder_port() {
        let builder = GrpcServerBuilder::new().port(8080);
        assert_eq!(builder.get_addr().port(), 8080);
        // `port` keeps the IPv6 loopback host of the default address.
        assert_eq!(builder.get_addr(), "[::1]:8080".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn test_builder_tls() {
        let tls = TlsConfig::new("/path/to/cert.pem", "/path/to/key.pem");
        let builder = GrpcServerBuilder::new().tls(tls);
        assert!(builder.has_tls());
    }

    #[test]
    fn test_builder_reflection() {
        static FDS: &[u8] = b"fake descriptor";
        let builder = GrpcServerBuilder::new().reflection(FDS);
        assert!(builder.has_reflection());
    }

    #[test]
    fn test_builder_health_check() {
        let builder = GrpcServerBuilder::new().health_check();
        assert!(builder.has_health_check());
    }

    #[test]
    fn test_builder_shutdown_timeout() {
        let builder = GrpcServerBuilder::new().shutdown_timeout(Duration::from_secs(60));
        assert_eq!(builder.shutdown_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_grpc_server_error_display() {
        // Pin the exact rendering of every variant, not just that the detail appears.
        assert_eq!(
            GrpcServerError::Bind("address in use".to_string()).to_string(),
            "Failed to bind: address in use"
        );
        assert_eq!(
            GrpcServerError::Tls("invalid cert".to_string()).to_string(),
            "TLS error: invalid cert"
        );
        assert_eq!(
            GrpcServerError::Server("boom".to_string()).to_string(),
            "Server error: boom"
        );
        assert_eq!(
            GrpcServerError::Config("bad".to_string()).to_string(),
            "Configuration error: bad"
        );
    }
}