starknet_devnet_server/
server.rs1use std::time::Duration;
2
3use axum::Router;
4use axum::body::{Body, Bytes};
5use axum::extract::{DefaultBodyLimit, Request};
6use axum::http::{HeaderValue, StatusCode};
7use axum::middleware::Next;
8use axum::response::{IntoResponse, Response};
9use axum::routing::{IntoMakeService, get, post};
10use http_body_util::BodyExt;
11use reqwest::{Method, header};
12use tokio::net::TcpListener;
13use tower_http::cors::CorsLayer;
14use tower_http::timeout::TimeoutLayer;
15use tower_http::trace::TraceLayer;
16
17use crate::api::JsonRpcHandler;
18use crate::rpc_handler::RpcHandler;
19use crate::{ServerConfig, rpc_handler};
20pub type StarknetDevnetServer = axum::serve::Serve<TcpListener, IntoMakeService<Router>, Router>;
21
22fn json_rpc_routes<TJsonRpcHandler: RpcHandler>(json_rpc_handler: TJsonRpcHandler) -> Router {
23 Router::new()
24 .route("/", post(rpc_handler::handle::<TJsonRpcHandler>))
25 .route("/rpc", post(rpc_handler::handle::<TJsonRpcHandler>))
26 .route("/ws", get(rpc_handler::handle_socket::<TJsonRpcHandler>))
27 .with_state(json_rpc_handler)
28}
29
30pub async fn serve_http_json_rpc(
32 tcp_listener: TcpListener,
33 server_config: &ServerConfig,
34 json_rpc_handler: JsonRpcHandler,
35) -> StarknetDevnetServer {
36 let mut routes = Router::new()
37 .route("/is_alive", get(|| async { "Alive!!!" })) .merge(json_rpc_routes(json_rpc_handler.clone()))
39 .layer(TraceLayer::new_for_http());
40
41 if server_config.log_response {
42 routes = routes.layer(axum::middleware::from_fn(response_logging_middleware));
43 };
44
45 routes = routes
46 .layer(TimeoutLayer::with_status_code(
47 StatusCode::REQUEST_TIMEOUT,
48 Duration::from_secs(server_config.timeout.into()),
49 ))
50 .layer(DefaultBodyLimit::disable())
51 .layer(
52 CorsLayer::new()
54 .allow_origin(HeaderValue::from_static("*"))
55 .allow_headers(vec![header::CONTENT_TYPE])
56 .allow_methods(vec![Method::GET, Method::POST]),
57 );
58
59 if server_config.log_request {
60 routes = routes.layer(axum::middleware::from_fn(request_logging_middleware));
61 }
62
63 axum::serve(tcp_listener, routes.into_make_service())
64}
65
66async fn log_body_and_path<T>(
67 body: T,
68 uri_option: Option<axum::http::Uri>,
69) -> Result<axum::body::Body, (StatusCode, String)>
70where
71 T: axum::body::HttpBody<Data = Bytes>,
72 T::Error: std::fmt::Display,
73{
74 let bytes = match body.collect().await {
75 Ok(collected) => collected.to_bytes(),
76 Err(err) => {
77 return Err((StatusCode::INTERNAL_SERVER_ERROR, err.to_string()));
78 }
79 };
80
81 if let Ok(body_str) = std::str::from_utf8(&bytes) {
82 if let Some(uri) = uri_option {
83 tracing::info!("{} {}", uri, body_str);
84 } else {
85 tracing::info!("{}", body_str);
86 }
87 } else {
88 tracing::error!("Failed to convert body to string");
89 }
90
91 Ok(Body::from(bytes))
92}
93
94async fn request_logging_middleware(
95 request: Request,
96 next: Next,
97) -> Result<impl IntoResponse, (StatusCode, String)> {
98 let (parts, body) = request.into_parts();
99
100 let body = log_body_and_path(body, Some(parts.uri.clone())).await?;
101 Ok(next.run(Request::from_parts(parts, body)).await)
102}
103
104async fn response_logging_middleware(
105 request: Request,
106 next: Next,
107) -> Result<impl IntoResponse, (StatusCode, String)> {
108 let response = next.run(request).await;
109
110 let (parts, body) = response.into_parts();
111
112 let body = log_body_and_path(body, None).await?;
113
114 let response = Response::from_parts(parts, body);
115 Ok(response)
116}