1pub mod hello;
16
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use rsketch_base::readable_size::ReadableSize;
21use rsketch_error::{ParseAddressSnafu, Result};
22use serde::{Deserialize, Serialize};
23use smart_default::SmartDefault;
24use snafu::ResultExt;
25use tokio::sync::oneshot;
26use tokio_util::sync::CancellationToken;
27use tonic::{service::RoutesBuilder, transport::Server};
28use tonic_health::server::HealthReporter;
29use tonic_reflection::server::v1::{ServerReflection, ServerReflectionServer};
30use tonic_tracing_opentelemetry::middleware::server::OtelGrpcLayer;
31use tracing::info;
32
33use crate::ServiceHandler;
34
35pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
37pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
39
40#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, SmartDefault, bon::Builder)]
42pub struct GrpcServerConfig {
43 #[default = "127.0.0.1:50051"]
45 pub bind_address: String,
46 #[default = "127.0.0.1:50051"]
48 pub server_address: String,
49 #[default(DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE)]
51 pub max_recv_message_size: ReadableSize,
52 #[default(DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE)]
54 pub max_send_message_size: ReadableSize,
55}
56
57#[async_trait]
74pub trait GrpcServiceHandler: Send + Sync + 'static {
75 fn service_name(&self) -> &'static str;
77 fn file_descriptor_set(&self) -> &'static [u8];
80 fn register_service(self: &Arc<Self>, builder: &mut RoutesBuilder);
84 async fn readiness_reporting(
87 self: &Arc<Self>,
88 _cancellation_token: CancellationToken,
89 health_reporter: HealthReporter,
90 ) {
91 health_reporter
94 .set_service_status("", tonic_health::ServingStatus::Serving)
95 .await;
96 }
97}
98
99pub fn start_grpc_server(
116 config: &GrpcServerConfig,
117 services: &[Arc<impl GrpcServiceHandler>],
118) -> Result<ServiceHandler> {
119 let bind_addr = config
121 .bind_address
122 .parse::<std::net::SocketAddr>()
123 .context(ParseAddressSnafu {
124 addr: config.bind_address.clone(),
125 })?;
126
127 let reflection_service = {
128 let mut file_descriptor_sets = Vec::new();
129 for service in services {
130 file_descriptor_sets.push(service.file_descriptor_set());
131 }
132 file_descriptor_sets.push(tonic_reflection::pb::v1::FILE_DESCRIPTOR_SET);
133 build_reflection_service(&file_descriptor_sets)
134 };
135
136 let (reporter, health_service) = tonic_health::server::health_reporter();
137 let mut routes_builder = RoutesBuilder::default();
138 routes_builder
139 .add_service(health_service)
140 .add_service(reflection_service);
141
142 for service in services {
144 let service = service.clone();
145 service.register_service(&mut routes_builder);
146 }
147
148 let cancellation_token = CancellationToken::new();
150 let (join_handle, started_rx) = {
151 let (started_tx, started_rx) = oneshot::channel::<()>();
152 let cancellation_token_clone = cancellation_token.clone();
153 let join_handle = tokio::spawn(async move {
154 let result = Server::builder()
155 .layer(OtelGrpcLayer::default())
156 .accept_http1(true)
157 .add_routes(routes_builder.routes())
158 .serve_with_shutdown(bind_addr, async move {
159 info!("gRPC server (on {}) starting", bind_addr);
160 let _ = started_tx.send(());
161 info!("gRPC server (on {}) started", bind_addr);
162 cancellation_token_clone.cancelled().await;
163 info!("gRPC server (on {}) received shutdown signal", bind_addr);
164 })
165 .await;
166
167 info!(
168 "gRPC server (on {}) task completed: {:?}",
169 bind_addr, result
170 );
171 });
172 (join_handle, started_rx)
173 };
174
175 let reporter_handlers = {
176 let mut handlers = Vec::new();
177 for service in services {
178 info!(
179 "spawning readiness reporting task for {}",
180 service.service_name()
181 );
182 let service = service.clone();
183 let reporter = reporter.clone();
184 let cancellation_token_clone = cancellation_token.clone();
185 let handle = tokio::spawn(async move {
186 service
187 .readiness_reporting(cancellation_token_clone, reporter)
188 .await;
189 info!(
190 "readiness reporting task for {} completed",
191 service.service_name()
192 );
193 });
194 handlers.push(handle);
195 }
196 handlers
197 };
198
199 let handle = ServiceHandler {
200 join_handle,
201 cancellation_token,
202 started_rx: Some(started_rx),
203 reporter_handles: reporter_handlers,
204 };
205 Ok(handle)
206}
207
208fn build_reflection_service(
209 file_descriptor_sets: &[&[u8]],
210) -> ServerReflectionServer<impl ServerReflection> {
211 let mut builder = tonic_reflection::server::Builder::configure();
212
213 for file_descriptor_set in file_descriptor_sets {
214 builder = builder.register_encoded_file_descriptor_set(file_descriptor_set);
215 }
216 builder
217 .build_v1()
218 .expect("failed to build reflection service")
219}