amareleo_node_rest/
lib.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#![forbid(unsafe_code)]
17
18#[macro_use]
19extern crate tracing;
20
21#[macro_use]
22extern crate amareleo_chain_tracing;
23
24mod helpers;
25pub use helpers::*;
26
27mod routes;
28
29use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
30use amareleo_node_consensus::Consensus;
31use snarkvm::{
32    console::{program::ProgramID, types::Field},
33    prelude::{Ledger, Network, cfg_into_iter, store::ConsensusStorage},
34};
35
36use anyhow::{Result, anyhow, bail};
37use axum::{
38    Json,
39    body::Body,
40    extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State},
41    http::{Method, Request, StatusCode, header::CONTENT_TYPE},
42    middleware,
43    middleware::Next,
44    response::Response,
45    routing::{get, post},
46};
47use axum_extra::response::ErasedJson;
48#[cfg(feature = "locktick")]
49use locktick::parking_lot::Mutex;
50#[cfg(not(feature = "locktick"))]
51use parking_lot::Mutex;
52use std::{net::SocketAddr, sync::Arc};
53use tokio::{
54    net::TcpListener,
55    sync::{oneshot, watch},
56    task::JoinHandle,
57};
58use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
59use tower_http::{
60    cors::{Any, CorsLayer},
61    trace::TraceLayer,
62};
63use tracing::subscriber::DefaultGuard;
64
/// A REST API server for the ledger.
///
/// The shutdown channels and the server task handle are stored behind `Arc<Mutex<..>>`,
/// so every clone of a `Rest` observes and controls the same spawned server.
#[derive(Clone)]
pub struct Rest<N: Network, C: ConsensusStorage<N>> {
    /// The consensus module, if one was supplied.
    consensus: Option<Consensus<N>>,
    /// The ledger.
    ledger: Ledger<N, C>,
    /// Optional tracing handler; `get_tracing_guard` obtains its guard from here.
    tracing: Option<TracingHandler>,
    /// One-shot sender that triggers graceful shutdown; `shut_down` takes it, leaving `None`.
    shutdown_trigger_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    /// Receiver whose value becomes `true` once the server task has finished serving.
    shutdown_complete_rx: Arc<Mutex<Option<watch::Receiver<bool>>>>,
    /// Join handle of the spawned server task (`None` until the server is spawned).
    rest_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
81
82impl<N: Network, C: 'static + ConsensusStorage<N>> Rest<N, C> {
83    /// Initializes a new instance of the server.
84    pub async fn start(
85        rest_ip: SocketAddr,
86        rest_rps: u32,
87        consensus: Option<Consensus<N>>,
88        ledger: Ledger<N, C>,
89        tracing: Option<TracingHandler>,
90    ) -> Result<Self> {
91        // Initialize the server.
92        let mut server = Self {
93            consensus,
94            ledger,
95            tracing,
96            shutdown_trigger_tx: Arc::new(Mutex::new(None)),
97            shutdown_complete_rx: Arc::new(Mutex::new(None)),
98            rest_handle: Arc::new(Mutex::new(None)),
99        };
100        // Spawn the server.
101        server.spawn_server(rest_ip, rest_rps).await?;
102        // Return the server.
103        Ok(server)
104    }
105
106    pub fn is_finished(&self) -> bool {
107        let lock = self.rest_handle.lock();
108        if let Some(handle) = lock.as_ref() { handle.is_finished() } else { true }
109    }
110
111    pub async fn wait_finish(&self) -> Result<()> {
112        if self.is_finished() {
113            guard_info!(self, "REST server already shutdown.");
114            return Ok(());
115        }
116
117        // Clone the shutdown complete signal receiver
118        let rx_option = {
119            let lock = self.shutdown_complete_rx.lock();
120            lock.as_ref().map(|opt| opt.clone())
121        };
122
123        if let Some(mut rx) = rx_option {
124            // Wait for completion
125            while !*rx.borrow() {
126                if rx.changed().await.is_err() {
127                    bail!("REST shutdown completed signal errored!");
128                }
129            }
130
131            guard_info!(self, "REST shutdown completed signal received.");
132        } else {
133            bail!("REST shutdown completed signal NOT found!");
134        }
135
136        Ok(())
137    }
138
139    pub async fn shut_down(&self) {
140        // Extract and replace with None
141        let shutdown_option = self.shutdown_trigger_tx.lock().take();
142        if let Some(tx) = shutdown_option {
143            let _ = tx.send(()); // Send shutdown signal
144        }
145
146        // Await for the server to shutdown
147        let _ = self.wait_finish().await;
148    }
149}
150
151impl<N: Network, C: ConsensusStorage<N>> Rest<N, C> {
152    /// Returns the ledger.
153    pub const fn ledger(&self) -> &Ledger<N, C> {
154        &self.ledger
155    }
156
157    /// Retruns tracing guard
158    pub fn get_tracing_guard(&self) -> Option<DefaultGuard> {
159        self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
160    }
161}
162
163impl<N: Network, C: ConsensusStorage<N>> Rest<N, C> {
164    async fn spawn_server(&mut self, rest_ip: SocketAddr, rest_rps: u32) -> Result<()> {
165        let cors = CorsLayer::new()
166            .allow_origin(Any)
167            .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
168            .allow_headers([CONTENT_TYPE]);
169
170        // Log the REST rate limit per IP.
171        guard_debug!(self, "REST rate limit per IP - {rest_rps} RPS");
172
173        // Prepare the rate limiting setup.
174        let governor_config = match GovernorConfigBuilder::default()
175            .per_nanosecond((1_000_000_000 / rest_rps) as u64)
176            .burst_size(rest_rps)
177            .error_handler(|error| {
178                // Properly return a 429 Too Many Requests error
179                let error_message = error.to_string();
180                let mut response = Response::new(error_message.clone().into());
181                *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
182                if error_message.contains("Too Many Requests") {
183                    *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
184                }
185                response
186            })
187            .finish()
188        {
189            Some(config) => Box::new(config),
190            None => bail!("Couldn't set up rate limiting for the REST server"),
191        };
192
193        // Get the network being used.
194        let network = match N::ID {
195            snarkvm::console::network::MainnetV0::ID => "mainnet",
196            snarkvm::console::network::TestnetV0::ID => "testnet",
197            snarkvm::console::network::CanaryV0::ID => "canary",
198            unknown_id => bail!("Unknown network ID ({unknown_id})"),
199        };
200
201        let router = {
202            let routes = axum::Router::new()
203                // All the endpoints before the call to `route_layer` are protected with JWT auth.
204                .route(
205                    &format!("/{network}/node/address"),
206                    get(Self::get_node_address),
207                )
208                .route(
209                    &format!("/{network}/program/:id/mapping/:name"),
210                    get(Self::get_mapping_values),
211                )
212                .route_layer(middleware::from_fn(auth_middleware))
213                // GET ../block/..
214                .route(
215                    &format!("/{network}/block/height/latest"),
216                    get(Self::get_block_height_latest),
217                )
218                .route(
219                    &format!("/{network}/block/hash/latest"),
220                    get(Self::get_block_hash_latest),
221                )
222                .route(
223                    &format!("/{network}/block/latest"),
224                    get(Self::get_block_latest),
225                )
226                .route(
227                    &format!("/{network}/block/:height_or_hash"),
228                    get(Self::get_block),
229                )
230                // The path param here is actually only the height, but the name must match the route
231                // above, otherwise there'll be a conflict at runtime.
232                .route(&format!("/{network}/block/:height_or_hash/header"), get(Self::get_block_header))
233                .route(&format!("/{network}/block/:height_or_hash/transactions"), get(Self::get_block_transactions))
234                // GET and POST ../transaction/..
235                .route(
236                    &format!("/{network}/transaction/:id"),
237                    get(Self::get_transaction),
238                )
239                .route(
240                    &format!("/{network}/transaction/confirmed/:id"),
241                    get(Self::get_confirmed_transaction),
242                )
243                .route(
244                    &format!("/{network}/transaction/broadcast"),
245                    post(Self::transaction_broadcast),
246                )
247                // POST ../solution/broadcast
248                .route(
249                    &format!("/{network}/solution/broadcast"),
250                    post(Self::solution_broadcast),
251                )
252                // GET ../find/..
253                .route(
254                    &format!("/{network}/find/blockHash/:tx_id"),
255                    get(Self::find_block_hash),
256                )
257                .route(
258                    &format!("/{network}/find/blockHeight/:state_root"),
259                    get(Self::find_block_height_from_state_root),
260                )
261                .route(
262                    &format!("/{network}/find/transactionID/deployment/:program_id"),
263                    get(Self::find_transaction_id_from_program_id),
264                )
265                .route(
266                    &format!("/{network}/find/transactionID/:transition_id"),
267                    get(Self::find_transaction_id_from_transition_id),
268                )
269                .route(
270                    &format!("/{network}/find/transitionID/:input_or_output_id"),
271                    get(Self::find_transition_id),
272                )
273                // GET ../peers/..
274                .route(
275                    &format!("/{network}/peers/count"),
276                    get(Self::get_peers_count),
277                )
278                .route(&format!("/{network}/peers/all"), get(Self::get_peers_all))
279                .route(
280                    &format!("/{network}/peers/all/metrics"),
281                    get(Self::get_peers_all_metrics),
282                )
283                // GET ../program/..
284                .route(&format!("/{network}/program/:id"), get(Self::get_program))
285                .route(
286                    &format!("/{network}/program/:id/mappings"),
287                    get(Self::get_mapping_names),
288                )
289                .route(
290                    &format!("/{network}/program/:id/mapping/:name/:key"),
291                    get(Self::get_mapping_value),
292                )
293                // GET misc endpoints.
294                .route(&format!("/{network}/blocks"), get(Self::get_blocks))
295                .route(&format!("/{network}/height/:hash"), get(Self::get_height))
296                .route(
297                    &format!("/{network}/memoryPool/transmissions"),
298                    get(Self::get_memory_pool_transmissions),
299                )
300                .route(
301                    &format!("/{network}/memoryPool/solutions"),
302                    get(Self::get_memory_pool_solutions),
303                )
304                .route(
305                    &format!("/{network}/memoryPool/transactions"),
306                    get(Self::get_memory_pool_transactions),
307                )
308                .route(
309                    &format!("/{network}/statePath/:commitment"),
310                    get(Self::get_state_path_for_commitment),
311                )
312                .route(
313                    &format!("/{network}/stateRoot/latest"),
314                    get(Self::get_state_root_latest),
315                )
316                .route(
317                    &format!("/{network}/stateRoot/:height"),
318                    get(Self::get_state_root),
319                )
320                .route(
321                    &format!("/{network}/committee/latest"),
322                    get(Self::get_committee_latest),
323                )
324                .route(
325                    &format!("/{network}/committee/:height"),
326                    get(Self::get_committee),
327                )
328                .route(
329                    &format!("/{network}/delegators/:validator"),
330                    get(Self::get_delegators_for_validator),
331                );
332
333            // If the `history` feature is enabled, enable the additional endpoint.
334            #[cfg(feature = "history")]
335            let routes =
336                routes.route(&format!("/{network}/block/:blockHeight/history/:mapping"), get(Self::get_history));
337
338            routes
339                // Pass in `Rest` to make things convenient.
340                .with_state(self.clone())
341                // Enable tower-http tracing.
342                .layer(TraceLayer::new_for_http())
343                // Custom logging.
344                .layer(middleware::from_fn_with_state(self.clone(), Self::log_middleware))
345                // Enable CORS.
346                .layer(cors)
347                // Cap body size at 512KiB.
348                .layer(DefaultBodyLimit::max(512 * 1024))
349                .layer(GovernorLayer {
350                    // We can leak this because it is created only once and it persists.
351                    config: Box::leak(governor_config),
352                })
353        };
354
355        // Create channels to signal the server to shutdown, and to signal when the server has shutdown.
356        let (shutdown_trigger_tx, shutdown_trigger_rx) = oneshot::channel::<()>();
357        let (shutdown_complete_tx, shutdown_complete_rx) = watch::channel::<bool>(false);
358        let tracing_: TracingHandler = self.tracing.clone().into();
359
360        // Bind the REST server and catch port conflict errors
361        let rest_listener =
362            TcpListener::bind(rest_ip).await.map_err(|err| anyhow!("Failed to bind to {}: {}", rest_ip, err))?;
363
364        let serve_handle = tokio::spawn(async move {
365            let result = axum::serve(rest_listener, router.into_make_service_with_connect_info::<SocketAddr>())
366                .with_graceful_shutdown(Self::shutdown_wait(tracing_.clone(), shutdown_trigger_rx))
367                .await;
368
369            if let Err(error) = result {
370                guard_error!(tracing_, "Couldn't start REST server: {}", error);
371            }
372
373            let _ = shutdown_complete_tx.send(true);
374        });
375
376        *self.rest_handle.lock() = Some(serve_handle);
377        *self.shutdown_trigger_tx.lock() = Some(shutdown_trigger_tx);
378        *self.shutdown_complete_rx.lock() = Some(shutdown_complete_rx);
379
380        Ok(())
381    }
382
383    async fn log_middleware(
384        State(rest): State<Self>,
385        ConnectInfo(addr): ConnectInfo<SocketAddr>,
386        request: Request<Body>,
387        next: Next,
388    ) -> Result<Response, StatusCode> {
389        guard_info!(rest, "Received '{} {}' from '{addr}'", request.method(), request.uri());
390        Ok(next.run(request).await)
391    }
392
393    async fn shutdown_wait(tracing: TracingHandler, shutdown_rx: oneshot::Receiver<()>) {
394        if let Err(error) = shutdown_rx.await {
395            guard_error!(tracing, "REST server shutdown signaling error: {}", error);
396        }
397
398        guard_info!(tracing, "REST server shutdown signal recieved...");
399    }
400}
401
/// Shortens the string form of `id` for log output.
///
/// Keeps at most the first 16 characters (Unicode scalar values, not bytes) and appends `..`
/// when anything was cut off. Shorter IDs are returned unchanged.
pub fn fmt_id(id: impl ToString) -> String {
    let full = id.to_string();
    // The byte offset of the 17th character, if present, marks where truncation happens.
    match full.char_indices().nth(16) {
        Some((cut_at, _)) => {
            let mut shortened = String::with_capacity(cut_at + 2);
            shortened.push_str(&full[..cut_at]);
            shortened.push_str("..");
            shortened
        }
        None => full,
    }
}