essential_rest_server/
lib.rs

1#![deny(missing_docs)]
2//! # Server
3//!
4//! A simple REST server for the Essential platform.
5
6use anyhow::anyhow;
7use axum::{
8    extract::{Path, Query, State},
9    response::{
10        sse::{Event, KeepAlive},
11        IntoResponse, Sse,
12    },
13    routing::{get, post},
14    Json, Router,
15};
16use essential_server::{CheckSolutionOutput, Essential, SolutionOutcome, StateRead, Storage};
17use essential_server_types::{CheckSolution, QueryStateReads, QueryStateReadsOutput};
18use essential_types::{
19    contract::{Contract, SignedContract},
20    convert::word_from_bytes,
21    predicate::Predicate,
22    solution::Solution,
23    Block, ContentAddress, PredicateAddress, Word,
24};
25use futures::{Stream, StreamExt};
26use hyper::body::Incoming;
27use hyper_util::rt::{TokioExecutor, TokioIo};
28use serde::Deserialize;
29use std::{net::SocketAddr, time::Duration};
30use tokio::{
31    net::{TcpListener, ToSocketAddrs},
32    sync::oneshot,
33    task::JoinSet,
34};
35use tower::Service;
36use tower_http::cors::CorsLayer;
37
/// Soft cap on the number of connection tasks the accept loop keeps in flight
/// before it waits for one of them to finish.
const MAX_CONNECTIONS: usize = 2_000;
39
40#[derive(Debug, Clone)]
41/// Server configuration.
42pub struct Config {
43    /// Whether the rest server should build blocks
44    /// or just serve requests.
45    /// Default is `true`.
46    pub build_blocks: bool,
47    /// Essential server configuration.
48    pub server_config: essential_server::Config,
49}
50
51#[derive(Deserialize)]
52/// Type to deserialize a time range query parameters.
53struct TimeRange {
54    /// Start of the time range in seconds.
55    start: u64,
56    /// End of the time range in seconds.
57    end: u64,
58}
59
60#[derive(Deserialize)]
61/// Type to deserialize a time query parameters.
62struct Time {
63    /// Time in seconds.
64    time: u64,
65}
66
67#[derive(Deserialize)]
68/// Type to deserialize a page query parameter.
69struct Page {
70    /// The page number to start from.
71    page: u64,
72}
73
74#[derive(Deserialize)]
75/// Type to deserialize a block number query parameter.
76struct BlockNumber {
77    /// The block number to start from.
78    block: u64,
79}
80
81/// Run the server.
82///
83/// - Takes the essential library to run it.
84/// - Address to bind to.
85/// - A channel that returns the actual chosen local address.
86/// - An optional channel that can be used to shutdown the server.
87pub async fn run<S, A>(
88    essential: Essential<S>,
89    addr: A,
90    local_addr: oneshot::Sender<SocketAddr>,
91    shutdown_rx: Option<oneshot::Receiver<()>>,
92    config: Config,
93) -> anyhow::Result<()>
94where
95    A: ToSocketAddrs,
96    S: Storage + StateRead + Clone + Send + Sync + 'static,
97    <S as StateRead>::Future: Send,
98    <S as StateRead>::Error: Send,
99{
100    // Spawn essential and get the handle.
101    let handle = if config.build_blocks {
102        Some(essential.clone().spawn(config.server_config)?)
103    } else {
104        None
105    };
106
107    let cors = CorsLayer::new()
108        .allow_origin(tower_http::cors::Any)
109        .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
110        .allow_headers([http::header::CONTENT_TYPE]);
111
112    // Create all the endpoints.
113    let app = Router::new()
114        .route("/", get(health_check))
115        .route("/deploy-contract", post(deploy_contract))
116        .route("/get-contract/:address", get(get_contract))
117        .route("/get-predicate/:contract/:address", get(get_predicate))
118        .route("/list-contracts", get(list_contracts))
119        .route("/subscribe-contracts", get(subscribe_contracts))
120        .route("/submit-solution", post(submit_solution))
121        .route("/list-solutions-pool", get(list_solutions_pool))
122        .route("/query-state/:address/:key", get(query_state))
123        .route("/list-blocks", get(list_blocks))
124        .route("/subscribe-blocks", get(subscribe_blocks))
125        .route("/solution-outcome/:hash", get(solution_outcome))
126        .route("/check-solution", post(check_solution))
127        .route(
128            "/check-solution-with-contracts",
129            post(check_solution_with_contracts),
130        )
131        .route("/query-state-reads", post(query_state_reads))
132        .layer(cors)
133        .with_state(essential.clone());
134
135    // Bind to the address.
136    let listener = TcpListener::bind(addr).await?;
137
138    // Send the local address to the caller.
139    // This is useful when the address or port is chosen by the OS.
140    let addr = listener.local_addr()?;
141    local_addr
142        .send(addr)
143        .map_err(|_| anyhow::anyhow!("Failed to send local address"))?;
144
145    // Serve the app.
146    serve(app, listener, shutdown_rx).await;
147
148    // After the server is done, shutdown essential.
149    if let Some(handle) = handle {
150        handle.shutdown().await?;
151    }
152
153    Ok(())
154}
155
156async fn serve(app: Router, listener: TcpListener, shutdown_rx: Option<oneshot::Receiver<()>>) {
157    let shut = shutdown(shutdown_rx);
158    tokio::pin!(shut);
159
160    let mut conn_contract = JoinSet::new();
161    // Continuously accept new connections up to max connections.
162    loop {
163        // Accept a new connection or wait for a shutdown signal.
164        let (socket, remote_addr) = tokio::select! {
165            _ = &mut shut => {
166                break;
167            }
168            v = listener.accept() => {
169                match v {
170                    Ok(v) => v,
171                    Err(err) => {
172                        #[cfg(feature = "tracing")]
173                        tracing::trace!("Failed to accept connection {}", err);
174                        continue;
175                    }
176                }
177            }
178        };
179
180        #[cfg(feature = "tracing")]
181        tracing::trace!("Accepted new connection from: {}", remote_addr);
182
183        // We don't need to call `poll_ready` because `Router` is always ready.
184        let tower_service = app.clone();
185
186        // Spawn a task to handle the connection. That way we can handle multiple connections
187        // concurrently.
188
189        conn_contract.spawn(async move {
190            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
191            // `TokioIo` converts between them.
192            let socket = TokioIo::new(socket);
193
194            // Hyper also has its own `Service` trait and doesn't use tower. We can use
195            // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
196            // `tower::Service::call`.
197            let hyper_service =
198                hyper::service::service_fn(move |request: axum::extract::Request<Incoming>| {
199                    // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
200                    // tower's `Service` requires `&mut self`.
201                    //
202                    // We don't need to call `poll_ready` since `Router` is always ready.
203                    tower_service.clone().call(request)
204                });
205
206            // `TokioExecutor` tells hyper to use `tokio::spawn` to spawn tasks.
207            let conn =
208                hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()).http2_only();
209            let conn = conn.serve_connection(socket, hyper_service);
210            let _ = conn.await;
211        });
212
213        // Wait for existing connection to close or wait for a shutdown signal.
214        if conn_contract.len() > MAX_CONNECTIONS {
215            #[cfg(feature = "tracing")]
216            tracing::info!("Max number of connections reached: {}", MAX_CONNECTIONS);
217            tokio::select! {
218                _ = &mut shut => {
219                    break;
220                }
221                _ = conn_contract.join_next() => {},
222
223            }
224        }
225    }
226}
227
/// The health check get endpoint.
///
/// Takes no input and returns the unit value.
async fn health_check() {
    // Nothing to compute: completing is the whole health signal.
}
230
231/// The deploy contract post endpoint.
232///
233/// Takes a signed vector of contract as a json payload.
234async fn deploy_contract<S>(
235    State(essential): State<Essential<S>>,
236    Json(payload): Json<SignedContract>,
237) -> Result<Json<ContentAddress>, Error>
238where
239    S: Storage + StateRead + Clone + Send + Sync + 'static,
240    <S as StateRead>::Future: Send,
241    <S as StateRead>::Error: Send,
242{
243    let address = essential.deploy_contract(payload).await?;
244    Ok(Json(address))
245}
246
247/// The submit solution post endpoint.
248///
249/// Takes a signed solution as a json payload.
250async fn submit_solution<S>(
251    State(essential): State<Essential<S>>,
252    Json(payload): Json<Solution>,
253) -> Result<Json<ContentAddress>, Error>
254where
255    S: Storage + StateRead + Clone + Send + Sync + 'static,
256    <S as StateRead>::Future: Send,
257    <S as StateRead>::Error: Send,
258{
259    let hash = essential.submit_solution(payload).await?;
260    Ok(Json(hash))
261}
262
263/// The get contract get endpoint.
264///
265/// Takes a content address (encoded as hex) as a path parameter.
266async fn get_contract<S>(
267    State(essential): State<Essential<S>>,
268    Path(address): Path<String>,
269) -> Result<Json<Option<SignedContract>>, Error>
270where
271    S: Storage + StateRead + Clone + Send + Sync + 'static,
272    <S as StateRead>::Future: Send,
273    <S as StateRead>::Error: Send,
274{
275    let address: ContentAddress = address
276        .parse()
277        .map_err(|e| anyhow!("failed to parse contract content address: {e}"))?;
278    let contract = essential.get_contract(&address).await?;
279    Ok(Json(contract))
280}
281
282/// The get predicate get endpoint.
283///
284/// Takes a contract content address and a predicate content address as path parameters.
285/// Both are encoded as hex.
286async fn get_predicate<S>(
287    State(essential): State<Essential<S>>,
288    Path((contract, address)): Path<(String, String)>,
289) -> Result<Json<Option<Predicate>>, Error>
290where
291    S: Storage + StateRead + Clone + Send + Sync + 'static,
292    <S as StateRead>::Future: Send,
293    <S as StateRead>::Error: Send,
294{
295    let contract: ContentAddress = contract
296        .parse()
297        .map_err(|e| anyhow!("failed to parse contract content address: {e}"))?;
298    let predicate: ContentAddress = address
299        .parse()
300        .map_err(|e| anyhow!("failed to parse predicate content address: {e}"))?;
301    let predicate = essential
302        .get_predicate(&PredicateAddress {
303            contract,
304            predicate,
305        })
306        .await?;
307    Ok(Json(predicate))
308}
309
310/// The list contracts get endpoint.
311///
312/// Takes optional time range and page as query parameters.
313async fn list_contracts<S>(
314    State(essential): State<Essential<S>>,
315    time_range: Option<Query<TimeRange>>,
316    page: Option<Query<Page>>,
317) -> Result<Json<Vec<Contract>>, Error>
318where
319    S: Storage + StateRead + Clone + Send + Sync + 'static,
320    <S as StateRead>::Future: Send,
321    <S as StateRead>::Error: Send,
322{
323    let time_range =
324        time_range.map(|range| Duration::from_secs(range.start)..Duration::from_secs(range.end));
325
326    let contracts = essential
327        .list_contracts(time_range, page.map(|p| p.page as usize))
328        .await?;
329    Ok(Json(contracts))
330}
331
332/// The subscribe contracts get endpoint.
333///
334/// Takes optional time and page as query parameters.
335async fn subscribe_contracts<S>(
336    State(essential): State<Essential<S>>,
337    time: Option<Query<Time>>,
338    page: Option<Query<Page>>,
339) -> Sse<impl Stream<Item = Result<Event, StdError>>>
340where
341    S: Storage + StateRead + Clone + Send + Sync + 'static,
342    <S as StateRead>::Future: Send,
343    <S as StateRead>::Error: Send,
344{
345    let time = time.map(|t| Duration::from_secs(t.time));
346
347    let contracts = essential.subscribe_contracts(time, page.map(|p| p.page as usize));
348    Sse::new(
349        contracts
350            .map::<Result<_, Error>, _>(|contract| Ok(Event::default().json_data(contract?)?))
351            .map(|r| r.map_err(StdError)),
352    )
353    .keep_alive(KeepAlive::default())
354}
355
356/// The list blocks get endpoint.
357///
358/// Takes optional time range and page as query parameters.
359async fn list_blocks<S>(
360    State(essential): State<Essential<S>>,
361    time_range: Option<Query<TimeRange>>,
362    block: Option<Query<BlockNumber>>,
363    page: Option<Query<Page>>,
364) -> Result<Json<Vec<Block>>, Error>
365where
366    S: Storage + StateRead + Clone + Send + Sync + 'static,
367    <S as StateRead>::Future: Send,
368    <S as StateRead>::Error: Send,
369{
370    let time_range =
371        time_range.map(|range| Duration::from_secs(range.start)..Duration::from_secs(range.end));
372
373    let blocks = essential
374        .list_blocks(
375            time_range,
376            block.map(|b| b.block),
377            page.map(|p| p.page as usize),
378        )
379        .await?;
380    Ok(Json(blocks))
381}
382
383/// The subscribe blocks get endpoint.
384///
385/// Takes optional time and page as query parameters.
386async fn subscribe_blocks<S>(
387    State(essential): State<Essential<S>>,
388    time: Option<Query<Time>>,
389    block: Option<Query<BlockNumber>>,
390    page: Option<Query<Page>>,
391) -> Sse<impl Stream<Item = Result<Event, StdError>>>
392where
393    S: Storage + StateRead + Clone + Send + Sync + 'static,
394    <S as StateRead>::Future: Send,
395    <S as StateRead>::Error: Send,
396{
397    let time = time.map(|time| Duration::from_secs(time.time));
398
399    let blocks =
400        essential.subscribe_blocks(time, block.map(|b| b.block), page.map(|p| p.page as usize));
401    Sse::new(
402        blocks
403            .map::<Result<_, Error>, _>(|block| Ok(Event::default().json_data(block?)?))
404            .map(|r| r.map_err(StdError)),
405    )
406    .keep_alive(KeepAlive::default())
407}
408
409/// The list solutions pool get endpoint.
410async fn list_solutions_pool<S>(
411    State(essential): State<Essential<S>>,
412    page: Option<Query<Page>>,
413) -> Result<Json<Vec<Solution>>, Error>
414where
415    S: Storage + StateRead + Clone + Send + Sync + 'static,
416    <S as StateRead>::Future: Send,
417    <S as StateRead>::Error: Send,
418{
419    let solutions = essential
420        .list_solutions_pool(page.map(|p| p.page as usize))
421        .await?;
422    Ok(Json(solutions))
423}
424
425/// The query state get endpoint.
426///
427/// Takes a content address and a byte array key as path parameters.
428/// Both are encoded as hex.
429async fn query_state<S>(
430    State(essential): State<Essential<S>>,
431    Path((address, key)): Path<(String, String)>,
432) -> Result<Json<Vec<Word>>, Error>
433where
434    S: Storage + StateRead + Clone + Send + Sync + 'static,
435    <S as StateRead>::Future: Send,
436    <S as StateRead>::Error: Send,
437{
438    let address: ContentAddress = address
439        .parse()
440        .map_err(|e| anyhow!("failed to parse contract content address: {e}"))?;
441    let key: Vec<u8> = hex::decode(key).map_err(|e| anyhow!("failed to decode key: {e}"))?;
442
443    // Convert the key to words.
444    let key = key
445        .chunks_exact(8)
446        .map(|chunk| word_from_bytes(chunk.try_into().expect("Safe due to chunk size")))
447        .collect::<Vec<_>>();
448
449    let state = essential.query_state(&address, &key).await?;
450    Ok(Json(state))
451}
452
453/// The solution outcome get endpoint.
454///
455/// Takes a solution content address as a path parameter encoded hex.
456async fn solution_outcome<S>(
457    State(essential): State<Essential<S>>,
458    Path(address): Path<String>,
459) -> Result<Json<Vec<SolutionOutcome>>, Error>
460where
461    S: Storage + StateRead + Clone + Send + Sync + 'static,
462    <S as StateRead>::Future: Send,
463    <S as StateRead>::Error: Send,
464{
465    let address: ContentAddress = address
466        .parse()
467        .map_err(|e| anyhow!("failed to parse solution content address: {e}"))?;
468    let outcome = essential.solution_outcome(&address.0).await?;
469    Ok(Json(outcome))
470}
471
472/// The check solution post endpoint.
473///
474/// Takes a signed solution as a json payload.
475async fn check_solution<S>(
476    State(essential): State<Essential<S>>,
477    Json(payload): Json<Solution>,
478) -> Result<Json<CheckSolutionOutput>, Error>
479where
480    S: Storage + StateRead + Clone + Send + Sync + 'static,
481    <S as StateRead>::Future: Send,
482    <S as StateRead>::Error: Send,
483{
484    let outcome = essential.check_solution(payload).await?;
485    Ok(Json(outcome))
486}
487
488/// The check solution with data post endpoint.
489///
490/// Takes a signed solution and a list of contract as a json payload.
491async fn check_solution_with_contracts<S>(
492    State(essential): State<Essential<S>>,
493    Json(payload): Json<CheckSolution>,
494) -> Result<Json<CheckSolutionOutput>, Error>
495where
496    S: Storage + StateRead + Clone + Send + Sync + 'static,
497    <S as StateRead>::Future: Send,
498    <S as StateRead>::Error: Send,
499{
500    let outcome = essential
501        .check_solution_with_contracts(payload.solution, payload.contracts)
502        .await?;
503    Ok(Json(outcome))
504}
505
506/// The query state reads post endpoint.
507///
508/// Takes a json state read query and returns the outcome
509async fn query_state_reads<S>(
510    State(essential): State<Essential<S>>,
511    Json(payload): Json<QueryStateReads>,
512) -> Result<Json<QueryStateReadsOutput>, Error>
513where
514    S: Storage + StateRead + Clone + Send + Sync + 'static,
515    <S as StateRead>::Future: Send,
516    <S as StateRead>::Error: Send,
517{
518    let out = essential.query_state_reads(payload).await?;
519    Ok(Json(out))
520}
521
522/// Shutdown the server manually or on ctrl-c.
523async fn shutdown(rx: Option<oneshot::Receiver<()>>) {
524    // The manual signal is used to shutdown the server.
525    let manual = async {
526        match rx {
527            Some(rx) => {
528                rx.await.ok();
529            }
530            None => futures::future::pending().await,
531        }
532    };
533
534    // The ctrl-c signal is used to shutdown the server.
535    let ctrl_c = async {
536        tokio::signal::ctrl_c()
537            .await
538            .expect("Failed to listen for ctrl-c");
539    };
540
541    // Wait for either signal.
542    tokio::select! {
543        _ = manual => {},
544        _ = ctrl_c => {},
545    }
546}
547
548#[derive(Debug)]
549struct Error(anyhow::Error);
550
551#[derive(Debug)]
552struct StdError(Error);
553
554impl IntoResponse for Error {
555    fn into_response(self) -> axum::response::Response {
556        // Return an internal server error with the error message.
557        (
558            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
559            format!("{}", self.0),
560        )
561            .into_response()
562    }
563}
564
565impl<E> From<E> for Error
566where
567    E: Into<anyhow::Error>,
568{
569    fn from(err: E) -> Self {
570        Self(err.into())
571    }
572}
573
574impl From<Error> for StdError {
575    fn from(err: Error) -> Self {
576        Self(err)
577    }
578}
579
580impl std::error::Error for StdError {}
581
582impl std::fmt::Display for Error {
583    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
584        self.0.fmt(f)
585    }
586}
587
588impl std::fmt::Display for StdError {
589    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
590        self.0.fmt(f)
591    }
592}
593
594impl Default for Config {
595    fn default() -> Self {
596        Self {
597            build_blocks: true,
598            server_config: Default::default(),
599        }
600    }
601}