amareleo_node_rest/
lib.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#![forbid(unsafe_code)]
17
18#[macro_use]
19extern crate tracing;
20
21#[macro_use]
22extern crate amareleo_chain_tracing;
23
24mod helpers;
25pub use helpers::*;
26
27mod routes;
28
29use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
30use amareleo_node_consensus::Consensus;
31use snarkvm::{
32    console::{program::ProgramID, types::Field},
33    prelude::{Ledger, Network, cfg_into_iter, store::ConsensusStorage},
34};
35
36use anyhow::{Result, anyhow, bail};
37use axum::{
38    Json,
39    body::Body,
40    extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State},
41    http::{Method, Request, StatusCode, header::CONTENT_TYPE},
42    middleware,
43    middleware::Next,
44    response::Response,
45    routing::{get, post},
46};
47use axum_extra::response::ErasedJson;
48use parking_lot::Mutex;
49use std::{net::SocketAddr, sync::Arc};
50use tokio::{
51    net::TcpListener,
52    sync::{oneshot, watch},
53    task::JoinHandle,
54};
55use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
56use tower_http::{
57    cors::{Any, CorsLayer},
58    trace::TraceLayer,
59};
60use tracing::subscriber::DefaultGuard;
61
62/// A REST API server for the ledger.
63#[derive(Clone)]
64pub struct Rest<N: Network, C: ConsensusStorage<N>> {
65    /// The consensus module.
66    consensus: Option<Consensus<N>>,
67    /// The ledger.
68    ledger: Ledger<N, C>,
69    /// Tracing handle
70    tracing: Option<TracingHandler>,
71    /// signal to initiate shutdown
72    shutdown_trigger_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
73    /// signal of completed shutdown
74    shutdown_complete_rx: Arc<Mutex<Option<watch::Receiver<bool>>>>,
75    /// The server handles.
76    rest_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
77}
78
79impl<N: Network, C: 'static + ConsensusStorage<N>> Rest<N, C> {
80    /// Initializes a new instance of the server.
81    pub async fn start(
82        rest_ip: SocketAddr,
83        rest_rps: u32,
84        consensus: Option<Consensus<N>>,
85        ledger: Ledger<N, C>,
86        tracing: Option<TracingHandler>,
87    ) -> Result<Self> {
88        // Initialize the server.
89        let mut server = Self {
90            consensus,
91            ledger,
92            tracing,
93            shutdown_trigger_tx: Arc::new(Mutex::new(None)),
94            shutdown_complete_rx: Arc::new(Mutex::new(None)),
95            rest_handle: Arc::new(Mutex::new(None)),
96        };
97        // Spawn the server.
98        server.spawn_server(rest_ip, rest_rps).await?;
99        // Return the server.
100        Ok(server)
101    }
102
103    pub fn is_finished(&self) -> bool {
104        let lock = self.rest_handle.lock();
105        if let Some(handle) = lock.as_ref() { handle.is_finished() } else { true }
106    }
107
108    pub async fn wait_finish(&self) -> Result<()> {
109        if self.is_finished() {
110            guard_info!(self, "REST server already shutdown.");
111            return Ok(());
112        }
113
114        // Clone the shutdown complete signal receiver
115        let rx_option = {
116            let lock = self.shutdown_complete_rx.lock();
117            lock.as_ref().map(|opt| opt.clone())
118        };
119
120        if let Some(mut rx) = rx_option {
121            // Wait for completion
122            while !*rx.borrow() {
123                if rx.changed().await.is_err() {
124                    bail!("REST shutdown completed signal errored!");
125                }
126            }
127
128            guard_info!(self, "REST shutdown completed signal received.");
129        } else {
130            bail!("REST shutdown completed signal NOT found!");
131        }
132
133        Ok(())
134    }
135
136    pub async fn shut_down(&self) {
137        // Extract and replace with None
138        let shutdown_option = self.shutdown_trigger_tx.lock().take();
139        if let Some(tx) = shutdown_option {
140            let _ = tx.send(()); // Send shutdown signal
141        }
142
143        // Await for the server to shutdown
144        let _ = self.wait_finish().await;
145    }
146}
147
148impl<N: Network, C: ConsensusStorage<N>> Rest<N, C> {
149    /// Returns the ledger.
150    pub const fn ledger(&self) -> &Ledger<N, C> {
151        &self.ledger
152    }
153
154    /// Retruns tracing guard
155    pub fn get_tracing_guard(&self) -> Option<DefaultGuard> {
156        self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
157    }
158}
159
160impl<N: Network, C: ConsensusStorage<N>> Rest<N, C> {
161    async fn spawn_server(&mut self, rest_ip: SocketAddr, rest_rps: u32) -> Result<()> {
162        let cors = CorsLayer::new()
163            .allow_origin(Any)
164            .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
165            .allow_headers([CONTENT_TYPE]);
166
167        // Log the REST rate limit per IP.
168        guard_debug!(self, "REST rate limit per IP - {rest_rps} RPS");
169
170        // Prepare the rate limiting setup.
171        let governor_config = match GovernorConfigBuilder::default()
172            .per_nanosecond((1_000_000_000 / rest_rps) as u64)
173            .burst_size(rest_rps)
174            .error_handler(|error| {
175                // Properly return a 429 Too Many Requests error
176                let error_message = error.to_string();
177                let mut response = Response::new(error_message.clone().into());
178                *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
179                if error_message.contains("Too Many Requests") {
180                    *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
181                }
182                response
183            })
184            .finish()
185        {
186            Some(config) => Box::new(config),
187            None => bail!("Couldn't set up rate limiting for the REST server"),
188        };
189
190        // Get the network being used.
191        let network = match N::ID {
192            snarkvm::console::network::MainnetV0::ID => "mainnet",
193            snarkvm::console::network::TestnetV0::ID => "testnet",
194            snarkvm::console::network::CanaryV0::ID => "canary",
195            unknown_id => bail!("Unknown network ID ({unknown_id})"),
196        };
197
198        let router = {
199            let routes = axum::Router::new()
200                // All the endpoints before the call to `route_layer` are protected with JWT auth.
201                .route(
202                    &format!("/{network}/node/address"),
203                    get(Self::get_node_address),
204                )
205                .route(
206                    &format!("/{network}/program/:id/mapping/:name"),
207                    get(Self::get_mapping_values),
208                )
209                .route_layer(middleware::from_fn(auth_middleware))
210                // GET ../block/..
211                .route(
212                    &format!("/{network}/block/height/latest"),
213                    get(Self::get_block_height_latest),
214                )
215                .route(
216                    &format!("/{network}/block/hash/latest"),
217                    get(Self::get_block_hash_latest),
218                )
219                .route(
220                    &format!("/{network}/block/latest"),
221                    get(Self::get_block_latest),
222                )
223                .route(
224                    &format!("/{network}/block/:height_or_hash"),
225                    get(Self::get_block),
226                )
227                // The path param here is actually only the height, but the name must match the route
228                // above, otherwise there'll be a conflict at runtime.
229                .route(&format!("/{network}/block/:height_or_hash/header"), get(Self::get_block_header))
230                .route(&format!("/{network}/block/:height_or_hash/transactions"), get(Self::get_block_transactions))
231                // GET and POST ../transaction/..
232                .route(
233                    &format!("/{network}/transaction/:id"),
234                    get(Self::get_transaction),
235                )
236                .route(
237                    &format!("/{network}/transaction/confirmed/:id"),
238                    get(Self::get_confirmed_transaction),
239                )
240                .route(
241                    &format!("/{network}/transaction/broadcast"),
242                    post(Self::transaction_broadcast),
243                )
244                // POST ../solution/broadcast
245                .route(
246                    &format!("/{network}/solution/broadcast"),
247                    post(Self::solution_broadcast),
248                )
249                // GET ../find/..
250                .route(
251                    &format!("/{network}/find/blockHash/:tx_id"),
252                    get(Self::find_block_hash),
253                )
254                .route(
255                    &format!("/{network}/find/blockHeight/:state_root"),
256                    get(Self::find_block_height_from_state_root),
257                )
258                .route(
259                    &format!("/{network}/find/transactionID/deployment/:program_id"),
260                    get(Self::find_transaction_id_from_program_id),
261                )
262                .route(
263                    &format!("/{network}/find/transactionID/:transition_id"),
264                    get(Self::find_transaction_id_from_transition_id),
265                )
266                .route(
267                    &format!("/{network}/find/transitionID/:input_or_output_id"),
268                    get(Self::find_transition_id),
269                )
270                // GET ../peers/..
271                .route(
272                    &format!("/{network}/peers/count"),
273                    get(Self::get_peers_count),
274                )
275                .route(&format!("/{network}/peers/all"), get(Self::get_peers_all))
276                .route(
277                    &format!("/{network}/peers/all/metrics"),
278                    get(Self::get_peers_all_metrics),
279                )
280                // GET ../program/..
281                .route(&format!("/{network}/program/:id"), get(Self::get_program))
282                .route(
283                    &format!("/{network}/program/:id/mappings"),
284                    get(Self::get_mapping_names),
285                )
286                .route(
287                    &format!("/{network}/program/:id/mapping/:name/:key"),
288                    get(Self::get_mapping_value),
289                )
290                // GET misc endpoints.
291                .route(&format!("/{network}/blocks"), get(Self::get_blocks))
292                .route(&format!("/{network}/height/:hash"), get(Self::get_height))
293                .route(
294                    &format!("/{network}/memoryPool/transmissions"),
295                    get(Self::get_memory_pool_transmissions),
296                )
297                .route(
298                    &format!("/{network}/memoryPool/solutions"),
299                    get(Self::get_memory_pool_solutions),
300                )
301                .route(
302                    &format!("/{network}/memoryPool/transactions"),
303                    get(Self::get_memory_pool_transactions),
304                )
305                .route(
306                    &format!("/{network}/statePath/:commitment"),
307                    get(Self::get_state_path_for_commitment),
308                )
309                .route(
310                    &format!("/{network}/stateRoot/latest"),
311                    get(Self::get_state_root_latest),
312                )
313                .route(
314                    &format!("/{network}/stateRoot/:height"),
315                    get(Self::get_state_root),
316                )
317                .route(
318                    &format!("/{network}/committee/latest"),
319                    get(Self::get_committee_latest),
320                )
321                .route(
322                    &format!("/{network}/committee/:height"),
323                    get(Self::get_committee),
324                )
325                .route(
326                    &format!("/{network}/delegators/:validator"),
327                    get(Self::get_delegators_for_validator),
328                );
329
330            // If the `history` feature is enabled, enable the additional endpoint.
331            #[cfg(feature = "history")]
332            let routes =
333                routes.route(&format!("/{network}/block/:blockHeight/history/:mapping"), get(Self::get_history));
334
335            routes
336                // Pass in `Rest` to make things convenient.
337                .with_state(self.clone())
338                // Enable tower-http tracing.
339                .layer(TraceLayer::new_for_http())
340                // Custom logging.
341                .layer(middleware::from_fn_with_state(self.clone(), Self::log_middleware))
342                // Enable CORS.
343                .layer(cors)
344                // Cap body size at 512KiB.
345                .layer(DefaultBodyLimit::max(512 * 1024))
346                .layer(GovernorLayer {
347                    // We can leak this because it is created only once and it persists.
348                    config: Box::leak(governor_config),
349                })
350        };
351
352        // Create channels to signal the server to shutdown, and to signal when the server has shutdown.
353        let (shutdown_trigger_tx, shutdown_trigger_rx) = oneshot::channel::<()>();
354        let (shutdown_complete_tx, shutdown_complete_rx) = watch::channel::<bool>(false);
355        let tracing_: TracingHandler = self.tracing.clone().into();
356
357        // Bind the REST server and catch port conflict errors
358        let rest_listener =
359            TcpListener::bind(rest_ip).await.map_err(|err| anyhow!("Failed to bind to {}: {}", rest_ip, err))?;
360
361        let serve_handle = tokio::spawn(async move {
362            let result = axum::serve(rest_listener, router.into_make_service_with_connect_info::<SocketAddr>())
363                .with_graceful_shutdown(Self::shutdown_wait(tracing_.clone(), shutdown_trigger_rx))
364                .await;
365
366            if let Err(error) = result {
367                guard_error!(tracing_, "Couldn't start REST server: {}", error);
368            }
369
370            let _ = shutdown_complete_tx.send(true);
371        });
372
373        *self.rest_handle.lock() = Some(serve_handle);
374        *self.shutdown_trigger_tx.lock() = Some(shutdown_trigger_tx);
375        *self.shutdown_complete_rx.lock() = Some(shutdown_complete_rx);
376
377        Ok(())
378    }
379
380    async fn log_middleware(
381        State(rest): State<Self>,
382        ConnectInfo(addr): ConnectInfo<SocketAddr>,
383        request: Request<Body>,
384        next: Next,
385    ) -> Result<Response, StatusCode> {
386        guard_info!(rest, "Received '{} {}' from '{addr}'", request.method(), request.uri());
387        Ok(next.run(request).await)
388    }
389
390    async fn shutdown_wait(tracing: TracingHandler, shutdown_rx: oneshot::Receiver<()>) {
391        if let Err(error) = shutdown_rx.await {
392            guard_error!(tracing, "REST server shutdown signaling error: {}", error);
393        }
394
395        guard_info!(tracing, "REST server shutdown signal recieved...");
396    }
397}
398
/// Shortens the textual form of `id` for log output.
///
/// Keeps at most the first 16 characters, counted as Unicode scalar values rather than
/// bytes. If anything was cut off, `..` is appended; shorter IDs are returned whole.
pub fn fmt_id(id: impl ToString) -> String {
    const MAX_CHARS: usize = 16;
    let text = id.to_string();
    match text.char_indices().nth(MAX_CHARS) {
        // A 17th character exists: keep everything before it and mark the truncation.
        Some((cut, _)) => format!("{}..", &text[..cut]),
        None => text,
    }
}