pub mod hello;
use std::sync::Arc;
use async_trait::async_trait;
use rsketch_base::readable_size::ReadableSize;
use rsketch_error::{ParseAddressSnafu, Result};
use serde::{Deserialize, Serialize};
use smart_default::SmartDefault;
use snafu::ResultExt;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tonic::{service::RoutesBuilder, transport::Server};
use tonic_health::server::HealthReporter;
use tonic_reflection::server::v1::{ServerReflection, ServerReflectionServer};
use tonic_tracing_opentelemetry::middleware::server::OtelGrpcLayer;
use tracing::info;
use crate::ServiceHandler;
/// Default value of [`GrpcServerConfig::max_recv_message_size`]
/// (`ReadableSize::mb(512)`).
///
/// NOTE(review): presumably the per-message receive limit, judging by the
/// name; this file does not show where it is enforced — confirm.
pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
/// Default value of [`GrpcServerConfig::max_send_message_size`]
/// (`ReadableSize::mb(512)`).
///
/// NOTE(review): presumably the per-message send limit, judging by the
/// name; this file does not show where it is enforced — confirm.
pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
/// Configuration consumed by [`start_grpc_server`].
///
/// Can be (de)serialized with serde, created with `Default` (using the
/// `#[default]` values on each field), or assembled with the
/// `bon`-generated builder.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, SmartDefault, bon::Builder)]
pub struct GrpcServerConfig {
    /// Address the server listens on.
    ///
    /// Must parse as a [`std::net::SocketAddr`] (e.g. `"127.0.0.1:50051"`);
    /// otherwise [`start_grpc_server`] returns a parse-address error.
    #[default = "127.0.0.1:50051"]
    pub bind_address: String,
    /// Server address as a string; defaults to the same value as
    /// `bind_address`.
    ///
    /// NOTE(review): not read by `start_grpc_server`; presumably the address
    /// advertised to clients (which may differ from the bind address) —
    /// confirm against its users.
    #[default = "127.0.0.1:50051"]
    pub server_address: String,
    /// Receive-side message size setting.
    ///
    /// NOTE(review): `start_grpc_server` does not apply this value; confirm
    /// where (if anywhere) the inbound message limit is enforced.
    #[default(DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE)]
    pub max_recv_message_size: ReadableSize,
    /// Send-side message size setting.
    ///
    /// NOTE(review): `start_grpc_server` does not apply this value; confirm
    /// where (if anywhere) the outbound message limit is enforced.
    #[default(DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE)]
    pub max_send_message_size: ReadableSize,
}
/// A gRPC service that [`start_grpc_server`] can host.
///
/// Methods that need shared ownership of the service take
/// `self: &Arc<Self>`, letting the server hand clones of the `Arc` to
/// background tasks.
#[async_trait]
pub trait GrpcServiceHandler: Send + Sync + 'static {
    /// Name of the service, used in the server's log messages.
    fn service_name(&self) -> &'static str;

    /// Encoded protobuf `FileDescriptorSet` describing this service; it is
    /// registered with the server's reflection service.
    fn file_descriptor_set(&self) -> &'static [u8];

    /// Adds this service's routes to `builder`.
    fn register_service(self: &Arc<Self>, builder: &mut RoutesBuilder);

    /// Publishes health status through `health_reporter`.
    ///
    /// [`start_grpc_server`] runs this in a dedicated spawned task and passes
    /// the server's cancellation token. The default implementation ignores
    /// the token, marks the empty service name (the whole server, per the
    /// gRPC health checking protocol) as `Serving`, and returns.
    async fn readiness_reporting(
        self: &Arc<Self>,
        _cancellation_token: CancellationToken,
        health_reporter: HealthReporter,
    ) {
        let whole_server = "";
        let status = tonic_health::ServingStatus::Serving;
        health_reporter.set_service_status(whole_server, status).await;
    }
}
pub fn start_grpc_server(
config: &GrpcServerConfig,
services: &[Arc<impl GrpcServiceHandler>],
) -> Result<ServiceHandler> {
let bind_addr = config
.bind_address
.parse::<std::net::SocketAddr>()
.context(ParseAddressSnafu {
addr: config.bind_address.clone(),
})?;
let reflection_service = {
let mut file_descriptor_sets = Vec::new();
for service in services {
file_descriptor_sets.push(service.file_descriptor_set());
}
file_descriptor_sets.push(tonic_reflection::pb::v1::FILE_DESCRIPTOR_SET);
build_reflection_service(&file_descriptor_sets)
};
let (reporter, health_service) = tonic_health::server::health_reporter();
let mut routes_builder = RoutesBuilder::default();
routes_builder
.add_service(health_service)
.add_service(reflection_service);
for service in services {
let service = service.clone();
service.register_service(&mut routes_builder);
}
let cancellation_token = CancellationToken::new();
let (join_handle, started_rx) = {
let (started_tx, started_rx) = oneshot::channel::<()>();
let cancellation_token_clone = cancellation_token.clone();
let join_handle = tokio::spawn(async move {
let result = Server::builder()
.layer(OtelGrpcLayer::default())
.accept_http1(true)
.add_routes(routes_builder.routes())
.serve_with_shutdown(bind_addr, async move {
info!("gRPC server (on {}) starting", bind_addr);
let _ = started_tx.send(());
info!("gRPC server (on {}) started", bind_addr);
cancellation_token_clone.cancelled().await;
info!("gRPC server (on {}) received shutdown signal", bind_addr);
})
.await;
info!(
"gRPC server (on {}) task completed: {:?}",
bind_addr, result
);
});
(join_handle, started_rx)
};
let reporter_handlers = {
let mut handlers = Vec::new();
for service in services {
info!(
"spawning readiness reporting task for {}",
service.service_name()
);
let service = service.clone();
let reporter = reporter.clone();
let cancellation_token_clone = cancellation_token.clone();
let handle = tokio::spawn(async move {
service
.readiness_reporting(cancellation_token_clone, reporter)
.await;
info!(
"readiness reporting task for {} completed",
service.service_name()
);
});
handlers.push(handle);
}
handlers
};
let handle = ServiceHandler {
join_handle,
cancellation_token,
started_rx: Some(started_rx),
reporter_handles: reporter_handlers,
};
Ok(handle)
}
/// Builds a v1 gRPC reflection service that serves every encoded
/// `FileDescriptorSet` in `file_descriptor_sets`.
///
/// # Panics
///
/// Panics if `tonic_reflection`'s `build_v1` returns an error for the
/// supplied descriptor sets.
fn build_reflection_service(
    file_descriptor_sets: &[&[u8]],
) -> ServerReflectionServer<impl ServerReflection> {
    let configured = file_descriptor_sets.iter().copied().fold(
        tonic_reflection::server::Builder::configure(),
        |acc, encoded_set| acc.register_encoded_file_descriptor_set(encoded_set),
    );
    configured
        .build_v1()
        .expect("failed to build reflection service")
}