amareleo_node_rest/
lib.rs

// Copyright 2024 Aleo Network Foundation
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#![forbid(unsafe_code)]
17
18#[macro_use]
19extern crate tracing;
20
21mod helpers;
22pub use helpers::*;
23
24mod routes;
25
26use amareleo_chain_tracing::TracingHandler;
27use amareleo_node_consensus::Consensus;
28use snarkvm::{
29    console::{program::ProgramID, types::Field},
30    prelude::{Ledger, Network, cfg_into_iter, store::ConsensusStorage},
31};
32
33use anyhow::{Result, anyhow, bail};
34use axum::{
35    Json,
36    body::Body,
37    extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State},
38    http::{Method, Request, StatusCode, header::CONTENT_TYPE},
39    middleware,
40    middleware::Next,
41    response::Response,
42    routing::{get, post},
43};
44use axum_extra::response::ErasedJson;
45use parking_lot::Mutex;
46use std::{net::SocketAddr, sync::Arc};
47use tokio::{
48    net::TcpListener,
49    sync::{
50        oneshot,
51        oneshot::{Receiver, Sender},
52    },
53    task::JoinHandle,
54};
55use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
56use tower_http::{
57    cors::{Any, CorsLayer},
58    trace::TraceLayer,
59};
60use tracing::subscriber::DefaultGuard;
61
62/// A REST API server for the ledger.
63#[derive(Clone)]
64pub struct Rest<N: Network, C: ConsensusStorage<N>> {
65    /// The consensus module.
66    consensus: Option<Consensus<N>>,
67    /// The ledger.
68    ledger: Ledger<N, C>,
69    /// Tracing handle
70    tracing: Option<TracingHandler>,
71    /// shutdown signal
72    shutdown_tx: Arc<Mutex<Option<Sender<()>>>>,
73    /// The server handles.
74    rest_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
75}
76
77impl<N: Network, C: 'static + ConsensusStorage<N>> Rest<N, C> {
78    /// Initializes a new instance of the server.
79    pub async fn start(
80        rest_ip: SocketAddr,
81        rest_rps: u32,
82        consensus: Option<Consensus<N>>,
83        ledger: Ledger<N, C>,
84        tracing: Option<TracingHandler>,
85    ) -> Result<Self> {
86        // Initialize the server.
87        let mut server = Self {
88            consensus,
89            ledger,
90            tracing,
91            shutdown_tx: Arc::new(Mutex::new(None)),
92            rest_handle: Arc::new(Mutex::new(None)),
93        };
94        // Spawn the server.
95        server.spawn_server(rest_ip, rest_rps).await?;
96        // Return the server.
97        Ok(server)
98    }
99
100    pub async fn shut_down(&self) {
101        // Extract and replace with None
102        let shutdown_option = self.shutdown_tx.lock().take();
103        if let Some(tx) = shutdown_option {
104            let _ = tx.send(()); // Send shutdown signal
105        }
106
107        // Extract and replace with None
108        let handle_option = self.rest_handle.lock().take();
109        if let Some(handle) = handle_option {
110            let _ = handle.await;
111        }
112
113        info!("REST server shutdown completed.");
114    }
115}
116
117impl<N: Network, C: ConsensusStorage<N>> Rest<N, C> {
118    /// Returns the ledger.
119    pub const fn ledger(&self) -> &Ledger<N, C> {
120        &self.ledger
121    }
122
123    /// Retruns tracing guard
124    pub fn get_tracing_guard(&self) -> Option<DefaultGuard> {
125        self.tracing.clone().map(|trace_handle| trace_handle.subscribe_thread())
126    }
127}
128
129impl<N: Network, C: ConsensusStorage<N>> Rest<N, C> {
130    async fn spawn_server(&mut self, rest_ip: SocketAddr, rest_rps: u32) -> Result<()> {
131        let _guard = self.get_tracing_guard();
132
133        let cors = CorsLayer::new()
134            .allow_origin(Any)
135            .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
136            .allow_headers([CONTENT_TYPE]);
137
138        // Log the REST rate limit per IP.
139        debug!("REST rate limit per IP - {rest_rps} RPS");
140
141        // Prepare the rate limiting setup.
142        let governor_config = Box::new(
143            GovernorConfigBuilder::default()
144                .per_nanosecond((1_000_000_000 / rest_rps) as u64)
145                .burst_size(rest_rps)
146                .error_handler(|error| {
147                    // Properly return a 429 Too Many Requests error
148                    let error_message = error.to_string();
149                    let mut response = Response::new(error_message.clone().into());
150                    *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
151                    if error_message.contains("Too Many Requests") {
152                        *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
153                    }
154                    response
155                })
156                .finish()
157                .expect("Couldn't set up rate limiting for the REST server!"),
158        );
159
160        // Get the network being used.
161        let network = match N::ID {
162            snarkvm::console::network::MainnetV0::ID => "mainnet",
163            snarkvm::console::network::TestnetV0::ID => "testnet",
164            snarkvm::console::network::CanaryV0::ID => "canary",
165            unknown_id => bail!("Unknown network ID ({unknown_id})"),
166        };
167
168        let router = {
169            let routes = axum::Router::new()
170                // All the endpoints before the call to `route_layer` are protected with JWT auth.
171                .route(
172                    &format!("/{network}/node/address"),
173                    get(Self::get_node_address),
174                )
175                .route(
176                    &format!("/{network}/program/:id/mapping/:name"),
177                    get(Self::get_mapping_values),
178                )
179                .route_layer(middleware::from_fn(auth_middleware))
180                // GET ../block/..
181                .route(
182                    &format!("/{network}/block/height/latest"),
183                    get(Self::get_block_height_latest),
184                )
185                .route(
186                    &format!("/{network}/block/hash/latest"),
187                    get(Self::get_block_hash_latest),
188                )
189                .route(
190                    &format!("/{network}/block/latest"),
191                    get(Self::get_block_latest),
192                )
193                .route(
194                    &format!("/{network}/block/:height_or_hash"),
195                    get(Self::get_block),
196                )
197                // The path param here is actually only the height, but the name must match the route
198                // above, otherwise there'll be a conflict at runtime.
199                .route(&format!("/{network}/block/:height_or_hash/header"), get(Self::get_block_header))
200                .route(&format!("/{network}/block/:height_or_hash/transactions"), get(Self::get_block_transactions))
201                // GET and POST ../transaction/..
202                .route(
203                    &format!("/{network}/transaction/:id"),
204                    get(Self::get_transaction),
205                )
206                .route(
207                    &format!("/{network}/transaction/confirmed/:id"),
208                    get(Self::get_confirmed_transaction),
209                )
210                .route(
211                    &format!("/{network}/transaction/broadcast"),
212                    post(Self::transaction_broadcast),
213                )
214                // POST ../solution/broadcast
215                .route(
216                    &format!("/{network}/solution/broadcast"),
217                    post(Self::solution_broadcast),
218                )
219                // GET ../find/..
220                .route(
221                    &format!("/{network}/find/blockHash/:tx_id"),
222                    get(Self::find_block_hash),
223                )
224                .route(
225                    &format!("/{network}/find/blockHeight/:state_root"),
226                    get(Self::find_block_height_from_state_root),
227                )
228                .route(
229                    &format!("/{network}/find/transactionID/deployment/:program_id"),
230                    get(Self::find_transaction_id_from_program_id),
231                )
232                .route(
233                    &format!("/{network}/find/transactionID/:transition_id"),
234                    get(Self::find_transaction_id_from_transition_id),
235                )
236                .route(
237                    &format!("/{network}/find/transitionID/:input_or_output_id"),
238                    get(Self::find_transition_id),
239                )
240                // GET ../peers/..
241                .route(
242                    &format!("/{network}/peers/count"),
243                    get(Self::get_peers_count),
244                )
245                .route(&format!("/{network}/peers/all"), get(Self::get_peers_all))
246                .route(
247                    &format!("/{network}/peers/all/metrics"),
248                    get(Self::get_peers_all_metrics),
249                )
250                // GET ../program/..
251                .route(&format!("/{network}/program/:id"), get(Self::get_program))
252                .route(
253                    &format!("/{network}/program/:id/mappings"),
254                    get(Self::get_mapping_names),
255                )
256                .route(
257                    &format!("/{network}/program/:id/mapping/:name/:key"),
258                    get(Self::get_mapping_value),
259                )
260                // GET misc endpoints.
261                .route(&format!("/{network}/blocks"), get(Self::get_blocks))
262                .route(&format!("/{network}/height/:hash"), get(Self::get_height))
263                .route(
264                    &format!("/{network}/memoryPool/transmissions"),
265                    get(Self::get_memory_pool_transmissions),
266                )
267                .route(
268                    &format!("/{network}/memoryPool/solutions"),
269                    get(Self::get_memory_pool_solutions),
270                )
271                .route(
272                    &format!("/{network}/memoryPool/transactions"),
273                    get(Self::get_memory_pool_transactions),
274                )
275                .route(
276                    &format!("/{network}/statePath/:commitment"),
277                    get(Self::get_state_path_for_commitment),
278                )
279                .route(
280                    &format!("/{network}/stateRoot/latest"),
281                    get(Self::get_state_root_latest),
282                )
283                .route(
284                    &format!("/{network}/stateRoot/:height"),
285                    get(Self::get_state_root),
286                )
287                .route(
288                    &format!("/{network}/committee/latest"),
289                    get(Self::get_committee_latest),
290                )
291                .route(
292                    &format!("/{network}/committee/:height"),
293                    get(Self::get_committee),
294                )
295                .route(
296                    &format!("/{network}/delegators/:validator"),
297                    get(Self::get_delegators_for_validator),
298                );
299
300            // If the `history` feature is enabled, enable the additional endpoint.
301            #[cfg(feature = "history")]
302            let routes =
303                routes.route(&format!("/{network}/block/:blockHeight/history/:mapping"), get(Self::get_history));
304
305            routes
306                // Pass in `Rest` to make things convenient.
307                .with_state(self.clone())
308                // Enable tower-http tracing.
309                .layer(TraceLayer::new_for_http())
310                // Custom logging.
311                .layer(middleware::from_fn_with_state(self.clone(), Self::log_middleware))
312                // Enable CORS.
313                .layer(cors)
314                // Cap body size at 512KiB.
315                .layer(DefaultBodyLimit::max(512 * 1024))
316                .layer(GovernorLayer {
317                    // We can leak this because it is created only once and it persists.
318                    config: Box::leak(governor_config),
319                })
320        };
321
322        // Create a oneshot channel to signal when the server is done
323        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
324
325        let rest_listener =
326            TcpListener::bind(rest_ip).await.map_err(|err| anyhow!("Failed to bind to {}: {}", rest_ip, err))?;
327
328        let serve_handle = tokio::spawn(async move {
329            axum::serve(rest_listener, router.into_make_service_with_connect_info::<SocketAddr>())
330                .with_graceful_shutdown(Self::shutdown_wait(shutdown_rx))
331                .await
332                .expect("couldn't start rest server");
333        });
334
335        *self.rest_handle.lock() = Some(serve_handle);
336        *self.shutdown_tx.lock() = Some(shutdown_tx);
337
338        Ok(())
339    }
340
341    async fn log_middleware(
342        State(rest): State<Self>,
343        ConnectInfo(addr): ConnectInfo<SocketAddr>,
344        request: Request<Body>,
345        next: Next,
346    ) -> Result<Response, StatusCode> {
347        let _guard = rest.get_tracing_guard();
348        info!("Received '{} {}' from '{addr}'", request.method(), request.uri());
349
350        Ok(next.run(request).await)
351    }
352
353    async fn shutdown_wait(shutdown_rx: Receiver<()>) {
354        if let Err(error) = shutdown_rx.await {
355            error!("REST server shutdown signaling error: {}", error);
356        }
357
358        info!("REST server shutdown signal recieved...");
359    }
360}
361
/// Renders `id` as a string and shortens it for log output.
///
/// If the rendered ID has more than 16 characters, only the first 16 are kept and `..` is
/// appended to mark the truncation. Shorter IDs are returned unchanged. Characters are counted
/// as Unicode scalar values (`char`s), so a multi-byte character is never split.
pub fn fmt_id(id: impl ToString) -> String {
    const MAX_CHARS: usize = 16;

    let rendered = id.to_string();
    match rendered.char_indices().nth(MAX_CHARS) {
        // A 17th character exists: keep everything before its byte offset.
        Some((cut, _)) => format!("{}..", &rendered[..cut]),
        None => rendered,
    }
}