amareleo_node_rest/
lib.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#![forbid(unsafe_code)]
17
18#[macro_use]
19extern crate tracing;
20
21mod helpers;
22pub use helpers::*;
23
24mod routes;
25
26use amareleo_node_consensus::Consensus;
27use snarkvm::{
28    console::{program::ProgramID, types::Field},
29    prelude::{Ledger, Network, cfg_into_iter, store::ConsensusStorage},
30};
31
32use anyhow::Result;
33use axum::{
34    Json,
35    body::Body,
36    extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State},
37    http::{Method, Request, StatusCode, header::CONTENT_TYPE},
38    middleware,
39    middleware::Next,
40    response::Response,
41    routing::{get, post},
42};
43use axum_extra::response::ErasedJson;
44use parking_lot::Mutex;
45use std::{net::SocketAddr, sync::Arc};
46use tokio::{net::TcpListener, task::JoinHandle};
47use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
48use tower_http::{
49    cors::{Any, CorsLayer},
50    trace::TraceLayer,
51};
52
53/// A REST API server for the ledger.
54#[derive(Clone)]
55pub struct Rest<N: Network, C: ConsensusStorage<N>> {
56    /// The consensus module.
57    consensus: Option<Consensus<N>>,
58    /// The ledger.
59    ledger: Ledger<N, C>,
60    /// The server handles.
61    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
62}
63
64impl<N: Network, C: 'static + ConsensusStorage<N>> Rest<N, C> {
65    /// Initializes a new instance of the server.
66    pub async fn start(
67        rest_ip: SocketAddr,
68        rest_rps: u32,
69        consensus: Option<Consensus<N>>,
70        ledger: Ledger<N, C>,
71    ) -> Result<Self> {
72        // Initialize the server.
73        let mut server = Self { consensus, ledger, handles: Default::default() };
74        // Spawn the server.
75        server.spawn_server(rest_ip, rest_rps).await;
76        // Return the server.
77        Ok(server)
78    }
79}
80
81impl<N: Network, C: ConsensusStorage<N>> Rest<N, C> {
82    /// Returns the ledger.
83    pub const fn ledger(&self) -> &Ledger<N, C> {
84        &self.ledger
85    }
86
87    /// Returns the handles.
88    pub const fn handles(&self) -> &Arc<Mutex<Vec<JoinHandle<()>>>> {
89        &self.handles
90    }
91}
92
93impl<N: Network, C: ConsensusStorage<N>> Rest<N, C> {
94    async fn spawn_server(&mut self, rest_ip: SocketAddr, rest_rps: u32) {
95        let cors = CorsLayer::new()
96            .allow_origin(Any)
97            .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
98            .allow_headers([CONTENT_TYPE]);
99
100        // Log the REST rate limit per IP.
101        debug!("REST rate limit per IP - {rest_rps} RPS");
102
103        // Prepare the rate limiting setup.
104        let governor_config = Box::new(
105            GovernorConfigBuilder::default()
106                .per_nanosecond((1_000_000_000 / rest_rps) as u64)
107                .burst_size(rest_rps)
108                .error_handler(|error| {
109                    // Properly return a 429 Too Many Requests error
110                    let error_message = error.to_string();
111                    let mut response = Response::new(error_message.clone().into());
112                    *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
113                    if error_message.contains("Too Many Requests") {
114                        *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
115                    }
116                    response
117                })
118                .finish()
119                .expect("Couldn't set up rate limiting for the REST server!"),
120        );
121
122        // Get the network being used.
123        let network = match N::ID {
124            snarkvm::console::network::MainnetV0::ID => "mainnet",
125            snarkvm::console::network::TestnetV0::ID => "testnet",
126            snarkvm::console::network::CanaryV0::ID => "canary",
127            unknown_id => {
128                eprintln!("Unknown network ID ({unknown_id})");
129                return;
130            }
131        };
132
133        let router = {
134            let routes = axum::Router::new()
135                // All the endpoints before the call to `route_layer` are protected with JWT auth.
136                .route(
137                    &format!("/{network}/node/address"),
138                    get(Self::get_node_address),
139                )
140                .route(
141                    &format!("/{network}/program/:id/mapping/:name"),
142                    get(Self::get_mapping_values),
143                )
144                .route_layer(middleware::from_fn(auth_middleware))
145                // GET ../block/..
146                .route(
147                    &format!("/{network}/block/height/latest"),
148                    get(Self::get_block_height_latest),
149                )
150                .route(
151                    &format!("/{network}/block/hash/latest"),
152                    get(Self::get_block_hash_latest),
153                )
154                .route(
155                    &format!("/{network}/block/latest"),
156                    get(Self::get_block_latest),
157                )
158                .route(
159                    &format!("/{network}/block/:height_or_hash"),
160                    get(Self::get_block),
161                )
162                // The path param here is actually only the height, but the name must match the route
163                // above, otherwise there'll be a conflict at runtime.
164                .route(
165                    &format!("/{network}/block/:height_or_hash/transactions"),
166                    get(Self::get_block_transactions),
167                )
168                // GET and POST ../transaction/..
169                .route(
170                    &format!("/{network}/transaction/:id"),
171                    get(Self::get_transaction),
172                )
173                .route(
174                    &format!("/{network}/transaction/confirmed/:id"),
175                    get(Self::get_confirmed_transaction),
176                )
177                .route(
178                    &format!("/{network}/transaction/broadcast"),
179                    post(Self::transaction_broadcast),
180                )
181                // POST ../solution/broadcast
182                .route(
183                    &format!("/{network}/solution/broadcast"),
184                    post(Self::solution_broadcast),
185                )
186                // GET ../find/..
187                .route(
188                    &format!("/{network}/find/blockHash/:tx_id"),
189                    get(Self::find_block_hash),
190                )
191                .route(
192                    &format!("/{network}/find/blockHeight/:state_root"),
193                    get(Self::find_block_height_from_state_root),
194                )
195                .route(
196                    &format!("/{network}/find/transactionID/deployment/:program_id"),
197                    get(Self::find_transaction_id_from_program_id),
198                )
199                .route(
200                    &format!("/{network}/find/transactionID/:transition_id"),
201                    get(Self::find_transaction_id_from_transition_id),
202                )
203                .route(
204                    &format!("/{network}/find/transitionID/:input_or_output_id"),
205                    get(Self::find_transition_id),
206                )
207                // GET ../peers/..
208                .route(
209                    &format!("/{network}/peers/count"),
210                    get(Self::get_peers_count),
211                )
212                .route(&format!("/{network}/peers/all"), get(Self::get_peers_all))
213                .route(
214                    &format!("/{network}/peers/all/metrics"),
215                    get(Self::get_peers_all_metrics),
216                )
217                // GET ../program/..
218                .route(&format!("/{network}/program/:id"), get(Self::get_program))
219                .route(
220                    &format!("/{network}/program/:id/mappings"),
221                    get(Self::get_mapping_names),
222                )
223                .route(
224                    &format!("/{network}/program/:id/mapping/:name/:key"),
225                    get(Self::get_mapping_value),
226                )
227                // GET misc endpoints.
228                .route(&format!("/{network}/blocks"), get(Self::get_blocks))
229                .route(&format!("/{network}/height/:hash"), get(Self::get_height))
230                .route(
231                    &format!("/{network}/memoryPool/transmissions"),
232                    get(Self::get_memory_pool_transmissions),
233                )
234                .route(
235                    &format!("/{network}/memoryPool/solutions"),
236                    get(Self::get_memory_pool_solutions),
237                )
238                .route(
239                    &format!("/{network}/memoryPool/transactions"),
240                    get(Self::get_memory_pool_transactions),
241                )
242                .route(
243                    &format!("/{network}/statePath/:commitment"),
244                    get(Self::get_state_path_for_commitment),
245                )
246                .route(
247                    &format!("/{network}/stateRoot/latest"),
248                    get(Self::get_state_root_latest),
249                )
250                .route(
251                    &format!("/{network}/stateRoot/:height"),
252                    get(Self::get_state_root),
253                )
254                .route(
255                    &format!("/{network}/committee/latest"),
256                    get(Self::get_committee_latest),
257                )
258                .route(
259                    &format!("/{network}/committee/:height"),
260                    get(Self::get_committee),
261                )
262                .route(
263                    &format!("/{network}/delegators/:validator"),
264                    get(Self::get_delegators_for_validator),
265                );
266
267            // If the `history` feature is enabled, enable the additional endpoint.
268            #[cfg(feature = "history")]
269            let routes =
270                routes.route(&format!("/{network}/block/:blockHeight/history/:mapping"), get(Self::get_history));
271
272            routes
273                // Pass in `Rest` to make things convenient.
274                .with_state(self.clone())
275                // Enable tower-http tracing.
276                .layer(TraceLayer::new_for_http())
277                // Custom logging.
278                .layer(middleware::from_fn(log_middleware))
279                // Enable CORS.
280                .layer(cors)
281                // Cap body size at 512KiB.
282                .layer(DefaultBodyLimit::max(512 * 1024))
283                .layer(GovernorLayer {
284                    // We can leak this because it is created only once and it persists.
285                    config: Box::leak(governor_config),
286                })
287        };
288
289        let rest_listener = TcpListener::bind(rest_ip).await.unwrap();
290        self.handles.lock().push(tokio::spawn(async move {
291            axum::serve(rest_listener, router.into_make_service_with_connect_info::<SocketAddr>())
292                .await
293                .expect("couldn't start rest server");
294        }))
295    }
296}
297
298async fn log_middleware(
299    ConnectInfo(addr): ConnectInfo<SocketAddr>,
300    request: Request<Body>,
301    next: Next,
302) -> Result<Response, StatusCode> {
303    info!("Received '{} {}' from '{addr}'", request.method(), request.uri());
304
305    Ok(next.run(request).await)
306}
307
/// Formats an ID into a truncated identifier (for logging purposes).
///
/// Keeps the first 16 characters (Unicode scalar values, not bytes) of `id.to_string()` and
/// appends `..` when anything was cut off. Shorter IDs are returned unchanged.
pub fn fmt_id(id: impl ToString) -> String {
    /// Number of characters kept before truncation.
    const MAX_CHARS: usize = 16;

    let mut id = id.to_string();
    // `nth(MAX_CHARS)` yields the byte offset of the first character past the limit, if any,
    // without walking the remainder of the string. The offset is always a char boundary.
    if let Some((cut, _)) = id.char_indices().nth(MAX_CHARS) {
        id.truncate(cut);
        id.push_str("..");
    }
    id
}