1use std::{convert::Infallible, future::IntoFuture, net::SocketAddr};
2
3use http::{Request, Response};
4use prover_telemetry::ServerBuilder as MetricsBuilder;
5use tokio::{net::TcpListener, runtime::Runtime};
6use tokio_util::sync::CancellationToken;
7use tonic::{
8 body::{boxed, BoxBody},
9 server::NamedService,
10};
11use tower::{Service, ServiceExt};
12use tracing::{debug, info};
13
/// Boxed, thread-safe, `'static` error trait object.
///
/// Used in the `S::Error: Into<BoxError>` bounds of the service-registration
/// functions in this file.
pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
15
/// Builder-style container for everything needed to launch the prover's RPC
/// server and its metrics server.
///
/// Created with `ProverEngine::builder()`, configured through the chained
/// `set_*` / `add_*` methods, and consumed by `start`.
pub struct ProverEngine {
 /// Router onto which every registered RPC service is mounted.
 rpc_server: axum::Router,
 /// Runtime that drives the RPC server; `start` builds a multi-threaded
 /// "prover-runtime" when this is `None`.
 rpc_runtime: Option<Runtime>,
 /// Runtime that drives the metrics server; `start` builds a two-worker
 /// "metrics-runtime" when this is `None`.
 metrics_runtime: Option<Runtime>,
 /// Encoded file descriptor sets handed to the reflection service builders.
 reflection: Vec<&'static [u8]>,
 /// Names (`NamedService::NAME`) of the services reported as `Serving` by
 /// the health service.
 healthy_service: Vec<&'static str>,
 /// Token used to request shutdown of both servers; a default token is
 /// created by `start` when this is `None`.
 cancellation_token: Option<CancellationToken>,
 /// Address of the metrics server; `start` falls back to `[::1]:10001`.
 metric_socket_addr: Option<SocketAddr>,
 /// Address of the RPC server; `start` falls back to `[::1]:10000`.
 rpc_socket_addr: Option<SocketAddr>,
}
26
27impl ProverEngine {
28 pub fn builder() -> Self {
29 Self {
30 rpc_server: axum::Router::new(),
31 reflection: vec![tonic_health::pb::FILE_DESCRIPTOR_SET],
32 healthy_service: vec![],
33 rpc_runtime: None,
34 metrics_runtime: None,
35 cancellation_token: None,
36 metric_socket_addr: None,
37 rpc_socket_addr: None,
38 }
39 }
40
41 pub fn set_rpc_runtime(mut self, rpc_runtime: Runtime) -> Self {
42 self.rpc_runtime = Some(rpc_runtime);
43
44 self
45 }
46
47 pub fn set_metrics_runtime(mut self, metrics_runtime: Runtime) -> Self {
48 self.metrics_runtime = Some(metrics_runtime);
49
50 self
51 }
52
53 pub fn set_metric_socket_addr(mut self, metric_socket_addr: SocketAddr) -> Self {
54 self.metric_socket_addr = Some(metric_socket_addr);
55
56 self
57 }
58
59 pub fn set_rpc_socket_addr(mut self, rpc_socket_addr: SocketAddr) -> Self {
60 self.rpc_socket_addr = Some(rpc_socket_addr);
61
62 self
63 }
64
65 pub fn set_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self {
66 self.cancellation_token = Some(cancellation_token);
67 self
68 }
69
70 pub fn add_rpc_reflection(mut self, reflection: &'static [u8]) -> Self {
71 self.reflection.push(reflection);
72
73 self
74 }
75 pub fn add_rpc_service<S>(mut self, rpc_service: S) -> Self
76 where
77 S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
78 + NamedService
79 + Clone
80 + Sync
81 + Send
82 + 'static,
83 S::Future: Send + 'static,
84 S::Error: Into<BoxError> + Send,
85 {
86 self.rpc_server = add_rpc_service(self.rpc_server, rpc_service);
87 self.healthy_service.push(S::NAME);
88
89 self
90 }
91
92 pub fn add_reflection_service(mut self, descriptor: &'static [u8]) -> Self {
93 self.reflection.push(descriptor);
94
95 self
96 }
97
98 pub fn start(mut self) -> anyhow::Result<()> {
99 info!("Starting the prover engine");
100 let cancellation_token = self.cancellation_token.take().unwrap_or_default();
101
102 let metrics_runtime = self
103 .metrics_runtime
104 .take()
105 .map(Result::Ok)
106 .unwrap_or_else(|| {
107 tokio::runtime::Builder::new_multi_thread()
108 .thread_name("metrics-runtime")
109 .worker_threads(2)
110 .enable_all()
111 .build()
112 })?;
113
114 let prover_runtime = self.rpc_runtime.take().map(Result::Ok).unwrap_or_else(|| {
115 tokio::runtime::Builder::new_multi_thread()
116 .thread_name("prover-runtime")
117 .enable_all()
118 .build()
119 })?;
120
121 let addr = self.rpc_socket_addr.take().unwrap_or_else(|| {
122 "[::1]:10000"
123 .parse()
124 .expect("Unable to parse the RPC socket address")
125 });
126 let telemetry_addr = self.metric_socket_addr.take().unwrap_or_else(|| {
127 "[::1]:10001"
128 .parse()
129 .expect("Unable to parse the telemetry socket address")
130 });
131
132 debug!("Starting the metrics server..");
133 let metric_server = metrics_runtime.block_on(
135 MetricsBuilder::builder()
136 .addr(telemetry_addr)
137 .cancellation_token(cancellation_token.clone())
138 .build(),
139 )?;
140
141 let metrics_handle = {
143 let _guard = metrics_runtime.enter();
148 metrics_runtime.spawn(metric_server.into_future())
150 };
151 let tcp_listener = prover_runtime.block_on(TcpListener::bind(addr))?;
152
153 let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
154
155 let (reflection_v1, reflection_v1alpha) = self.reflection.iter().fold(
156 (
157 tonic_reflection::server::Builder::configure(),
158 tonic_reflection::server::Builder::configure(),
159 ),
160 |(v1, v1alpha), descriptor| {
161 (
162 v1.register_encoded_file_descriptor_set(descriptor),
163 v1alpha.register_encoded_file_descriptor_set(descriptor),
164 )
165 },
166 );
167
168 let (reflection_v1, reflection_v1alpha) = self.reflection.iter().fold(
169 (reflection_v1, reflection_v1alpha),
170 |(reflection_v1, reflection_v1alpha), descriptor| {
171 (
172 reflection_v1.register_encoded_file_descriptor_set(descriptor),
173 reflection_v1alpha.register_encoded_file_descriptor_set(descriptor),
174 )
175 },
176 );
177
178 let reflection_v1 = reflection_v1.build_v1().unwrap();
179 let reflection_v1alpha = reflection_v1alpha.build_v1alpha().unwrap();
180
181 debug!("Setting the health status of the services to healthy");
182 prover_runtime.block_on(async {
183 for service_name in self.healthy_service.iter() {
184 health_reporter
185 .set_service_status(service_name, tonic_health::ServingStatus::Serving)
186 .await;
187 }
188 });
189
190 debug!("Adding the reflection and health services to the RPC server");
191 let rpc_server = add_rpc_service(self.rpc_server, reflection_v1);
193 let rpc_server = add_rpc_service(rpc_server, reflection_v1alpha);
194 let rpc_server = add_rpc_service(rpc_server, health_service);
195
196 let token = cancellation_token.clone();
197 let prover_handle = prover_runtime.spawn(
198 axum::serve(tcp_listener, rpc_server)
199 .with_graceful_shutdown(async move { token.cancelled().await })
200 .into_future(),
201 );
202
203 info!("Metrics server started on {}", telemetry_addr);
204 info!("RPC server started on {}", addr);
205 let terminate_signal = async {
206 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
207 .expect("Fail to setup SIGTERM signal")
208 .recv()
209 .await;
210 };
211
212 tokio::runtime::Builder::new_current_thread()
213 .enable_all()
214 .build()?
215 .block_on(async {
216 tokio::select! {
217 _ = terminate_signal => {
218 info!("Received SIGTERM, shutting down...");
219 cancellation_token.cancel();
221 _ = prover_handle.await;
223 _ = metrics_handle.await;
225 }
226 _ = tokio::signal::ctrl_c() => {
227 info!("Received SIGINT (ctrl-c), shutting down...");
228 cancellation_token.cancel();
230 _ = prover_handle.await;
232 _ = metrics_handle.await;
234 }
235 }
236 });
237
238 Ok(())
242 }
243}
244
245fn add_rpc_service<S>(rpc_server: axum::Router, rpc_service: S) -> axum::Router
246where
247 S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
248 + NamedService
249 + Clone
250 + Sync
251 + Send
252 + 'static,
253 S::Future: Send + 'static,
254 S::Error: Into<BoxError> + Send,
255{
256 rpc_server.route_service(
257 &format!("/{}/{{*rest}}", S::NAME),
258 rpc_service.map_request(|r: Request<axum::body::Body>| r.map(boxed)),
259 )
260}