prover_engine/
lib.rs

1use std::{convert::Infallible, future::IntoFuture, net::SocketAddr};
2
3use http::{Request, Response};
4use prover_telemetry::ServerBuilder as MetricsBuilder;
5use tokio::{net::TcpListener, runtime::Runtime};
6use tokio_util::sync::CancellationToken;
7use tonic::{
8    body::{boxed, BoxBody},
9    server::NamedService,
10};
11use tower::{Service, ServiceExt};
12use tracing::{debug, info};
13
14pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
15
/// Hosts tonic gRPC services on an axum router together with the gRPC health
/// and reflection services, plus a separate metrics server.
///
/// The RPC server and the metrics server each run on their own Tokio runtime.
/// Create one with [`ProverEngine::builder`], configure it with the `set_*` /
/// `add_*` methods, then call [`ProverEngine::start`].
///
// NOTE: fields are dropped in declaration order, so when the engine is dropped
// without being started, `rpc_runtime` is shut down before `metrics_runtime`.
pub struct ProverEngine {
    /// Router onto which every registered gRPC service is mounted.
    rpc_server: axum::Router,
    /// Runtime for the RPC server; `start` builds a multi-threaded one when unset.
    rpc_runtime: Option<Runtime>,
    /// Runtime for the metrics server; `start` builds a two-worker one when unset.
    metrics_runtime: Option<Runtime>,
    /// Encoded file descriptor sets exposed through the reflection services.
    reflection: Vec<&'static [u8]>,
    /// Names of the registered services, reported as `Serving` on start.
    healthy_service: Vec<&'static str>,
    /// Token whose cancellation shuts the servers down; `start` creates one when unset.
    cancellation_token: Option<CancellationToken>,
    /// Metrics server address; `start` falls back to `[::1]:10001` when unset.
    metric_socket_addr: Option<SocketAddr>,
    /// RPC server address; `start` falls back to `[::1]:10000` when unset.
    rpc_socket_addr: Option<SocketAddr>,
}
26
27impl ProverEngine {
28    pub fn builder() -> Self {
29        Self {
30            rpc_server: axum::Router::new(),
31            reflection: vec![tonic_health::pb::FILE_DESCRIPTOR_SET],
32            healthy_service: vec![],
33            rpc_runtime: None,
34            metrics_runtime: None,
35            cancellation_token: None,
36            metric_socket_addr: None,
37            rpc_socket_addr: None,
38        }
39    }
40
41    pub fn set_rpc_runtime(mut self, rpc_runtime: Runtime) -> Self {
42        self.rpc_runtime = Some(rpc_runtime);
43
44        self
45    }
46
47    pub fn set_metrics_runtime(mut self, metrics_runtime: Runtime) -> Self {
48        self.metrics_runtime = Some(metrics_runtime);
49
50        self
51    }
52
53    pub fn set_metric_socket_addr(mut self, metric_socket_addr: SocketAddr) -> Self {
54        self.metric_socket_addr = Some(metric_socket_addr);
55
56        self
57    }
58
59    pub fn set_rpc_socket_addr(mut self, rpc_socket_addr: SocketAddr) -> Self {
60        self.rpc_socket_addr = Some(rpc_socket_addr);
61
62        self
63    }
64
65    pub fn set_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self {
66        self.cancellation_token = Some(cancellation_token);
67        self
68    }
69
70    pub fn add_rpc_reflection(mut self, reflection: &'static [u8]) -> Self {
71        self.reflection.push(reflection);
72
73        self
74    }
75    pub fn add_rpc_service<S>(mut self, rpc_service: S) -> Self
76    where
77        S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
78            + NamedService
79            + Clone
80            + Sync
81            + Send
82            + 'static,
83        S::Future: Send + 'static,
84        S::Error: Into<BoxError> + Send,
85    {
86        self.rpc_server = add_rpc_service(self.rpc_server, rpc_service);
87        self.healthy_service.push(S::NAME);
88
89        self
90    }
91
92    pub fn add_reflection_service(mut self, descriptor: &'static [u8]) -> Self {
93        self.reflection.push(descriptor);
94
95        self
96    }
97
98    pub fn start(mut self) -> anyhow::Result<()> {
99        info!("Starting the prover engine");
100        let cancellation_token = self.cancellation_token.take().unwrap_or_default();
101
102        let metrics_runtime = self
103            .metrics_runtime
104            .take()
105            .map(Result::Ok)
106            .unwrap_or_else(|| {
107                tokio::runtime::Builder::new_multi_thread()
108                    .thread_name("metrics-runtime")
109                    .worker_threads(2)
110                    .enable_all()
111                    .build()
112            })?;
113
114        let prover_runtime = self.rpc_runtime.take().map(Result::Ok).unwrap_or_else(|| {
115            tokio::runtime::Builder::new_multi_thread()
116                .thread_name("prover-runtime")
117                .enable_all()
118                .build()
119        })?;
120
121        let addr = self.rpc_socket_addr.take().unwrap_or_else(|| {
122            "[::1]:10000"
123                .parse()
124                .expect("Unable to parse the RPC socket address")
125        });
126        let telemetry_addr = self.metric_socket_addr.take().unwrap_or_else(|| {
127            "[::1]:10001"
128                .parse()
129                .expect("Unable to parse the telemetry socket address")
130        });
131
132        debug!("Starting the metrics server..");
133        // Create the metrics server.
134        let metric_server = metrics_runtime.block_on(
135            MetricsBuilder::builder()
136                .addr(telemetry_addr)
137                .cancellation_token(cancellation_token.clone())
138                .build(),
139        )?;
140
141        // Spawn the metrics server into the metrics runtime.
142        let metrics_handle = {
143            // This guard is used to ensure that the metrics runtime is entered
144            // before the server is spawned. This is necessary because the `into_future`
145            // of `WithGracefulShutdown` is spawning various tasks before returning the
146            // actual server instance to spawn.
147            let _guard = metrics_runtime.enter();
148            // Spawn the metrics server
149            metrics_runtime.spawn(metric_server.into_future())
150        };
151        let tcp_listener = prover_runtime.block_on(TcpListener::bind(addr))?;
152
153        let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
154
155        let (reflection_v1, reflection_v1alpha) = self.reflection.iter().fold(
156            (
157                tonic_reflection::server::Builder::configure(),
158                tonic_reflection::server::Builder::configure(),
159            ),
160            |(v1, v1alpha), descriptor| {
161                (
162                    v1.register_encoded_file_descriptor_set(descriptor),
163                    v1alpha.register_encoded_file_descriptor_set(descriptor),
164                )
165            },
166        );
167
168        let (reflection_v1, reflection_v1alpha) = self.reflection.iter().fold(
169            (reflection_v1, reflection_v1alpha),
170            |(reflection_v1, reflection_v1alpha), descriptor| {
171                (
172                    reflection_v1.register_encoded_file_descriptor_set(descriptor),
173                    reflection_v1alpha.register_encoded_file_descriptor_set(descriptor),
174                )
175            },
176        );
177
178        let reflection_v1 = reflection_v1.build_v1().unwrap();
179        let reflection_v1alpha = reflection_v1alpha.build_v1alpha().unwrap();
180
181        debug!("Setting the health status of the services to healthy");
182        prover_runtime.block_on(async {
183            for service_name in self.healthy_service.iter() {
184                health_reporter
185                    .set_service_status(service_name, tonic_health::ServingStatus::Serving)
186                    .await;
187            }
188        });
189
190        debug!("Adding the reflection and health services to the RPC server");
191        // Adding the reflection and health services to the RPC server
192        let rpc_server = add_rpc_service(self.rpc_server, reflection_v1);
193        let rpc_server = add_rpc_service(rpc_server, reflection_v1alpha);
194        let rpc_server = add_rpc_service(rpc_server, health_service);
195
196        let token = cancellation_token.clone();
197        let prover_handle = prover_runtime.spawn(
198            axum::serve(tcp_listener, rpc_server)
199                .with_graceful_shutdown(async move { token.cancelled().await })
200                .into_future(),
201        );
202
203        info!("Metrics server started on {}", telemetry_addr);
204        info!("RPC server started on {}", addr);
205        let terminate_signal = async {
206            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
207                .expect("Fail to setup SIGTERM signal")
208                .recv()
209                .await;
210        };
211
212        tokio::runtime::Builder::new_current_thread()
213            .enable_all()
214            .build()?
215            .block_on(async {
216                tokio::select! {
217                    _ = terminate_signal => {
218                        info!("Received SIGTERM, shutting down...");
219                        // Cancel the global cancellation token to start the shutdown process.
220                        cancellation_token.cancel();
221                        // Wait for the prover to shutdown.
222                        _ = prover_handle.await;
223                        // Wait for the metrics server to shutdown.
224                        _ = metrics_handle.await;
225                    }
226                    _ = tokio::signal::ctrl_c() => {
227                        info!("Received SIGINT (ctrl-c), shutting down...");
228                        // Cancel the global cancellation token to start the shutdown process.
229                        cancellation_token.cancel();
230                        // Wait for the prover to shutdown.
231                        _ = prover_handle.await;
232                        // Wait for the metrics server to shutdown.
233                        _ = metrics_handle.await;
234                    }
235                }
236            });
237
238        // prover_runtime.shutdown_timeout(config.shutdown.runtime_timeout);
239        // metrics_runtime.shutdown_timeout(config.shutdown.runtime_timeout);
240
241        Ok(())
242    }
243}
244
245fn add_rpc_service<S>(rpc_server: axum::Router, rpc_service: S) -> axum::Router
246where
247    S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
248        + NamedService
249        + Clone
250        + Sync
251        + Send
252        + 'static,
253    S::Future: Send + 'static,
254    S::Error: Into<BoxError> + Send,
255{
256    rpc_server.route_service(
257        &format!("/{}/{{*rest}}", S::NAME),
258        rpc_service.map_request(|r: Request<axum::body::Body>| r.map(boxed)),
259    )
260}