Skip to main content

tycho_rpc/endpoint/
mod.rs

1use std::time::Duration;
2
3use anyhow::Result;
4use axum::RequestExt;
5use axum::extract::{DefaultBodyLimit, FromRef, Request, State};
6use axum::http::StatusCode;
7use axum::response::{IntoResponse, Response};
8use axum::routing::{get, post};
9use tokio::net::TcpListener;
10
11pub use self::jrpc::JrpcEndpointCache;
12pub use self::proto::ProtoEndpointCache;
13use crate::state::RpcState;
14use crate::util::mime::{APPLICATION_JSON, APPLICATION_PROTOBUF, get_mime_type};
15
16pub mod jrpc;
17pub mod proto;
18
19pub struct RpcEndpointBuilder<C = ()> {
20    common: RpcEndpointBuilderCommon,
21    custom_routes: C,
22}
23
24impl Default for RpcEndpointBuilder {
25    #[inline]
26    fn default() -> Self {
27        Self {
28            common: Default::default(),
29            custom_routes: (),
30        }
31    }
32}
33
34impl RpcEndpointBuilder<()> {
35    pub fn empty() -> Self {
36        Self {
37            common: RpcEndpointBuilderCommon::empty(),
38            custom_routes: (),
39        }
40    }
41
42    pub fn with_custom_routes<S>(
43        self,
44        routes: axum::Router<S>,
45    ) -> RpcEndpointBuilder<axum::Router<S>>
46    where
47        RpcState: FromRef<S>,
48        S: Send + Sync,
49    {
50        RpcEndpointBuilder {
51            common: self.common,
52            custom_routes: routes,
53        }
54    }
55
56    pub async fn bind(self, state: RpcState) -> Result<RpcEndpoint> {
57        let listener = state.bind_socket().await?;
58        Ok(RpcEndpoint::from_parts(
59            listener,
60            self.common.build(),
61            state,
62        ))
63    }
64}
65
66impl<C> RpcEndpointBuilder<C> {
67    pub fn with_healthcheck_route<T: Into<String>>(mut self, route: T) -> Self {
68        self.common.healthcheck_route = Some(route.into());
69        self
70    }
71
72    pub fn with_base_routes<I, T>(mut self, routes: I) -> Self
73    where
74        I: IntoIterator<Item = T>,
75        T: Into<String>,
76    {
77        self.common.base_routes = routes.into_iter().map(Into::into).collect();
78        self
79    }
80}
81
82impl<S> RpcEndpointBuilder<axum::Router<S>>
83where
84    RpcState: FromRef<S>,
85    S: Send + Sync + Clone + 'static,
86{
87    pub async fn bind(self, state: S) -> Result<RpcEndpoint> {
88        let listener = RpcState::from_ref(&state).bind_socket().await?;
89        Ok(RpcEndpoint::from_parts(
90            listener,
91            self.common.build::<S>().merge(self.custom_routes),
92            state,
93        ))
94    }
95}
96
/// Route configuration shared by all [`RpcEndpointBuilder`] variants.
struct RpcEndpointBuilderCommon {
    /// Path registered for the healthcheck handler, if any.
    healthcheck_route: Option<String>,
    /// Paths registered for the common RPC handler.
    base_routes: Vec<String>,
}
101
102impl Default for RpcEndpointBuilderCommon {
103    fn default() -> Self {
104        Self {
105            healthcheck_route: Some("/".to_owned()),
106            base_routes: vec!["/".to_owned(), "/rpc".to_owned(), "/proto".to_owned()],
107        }
108    }
109}
110
111impl RpcEndpointBuilderCommon {
112    pub fn empty() -> Self {
113        Self {
114            healthcheck_route: None,
115            base_routes: Vec::new(),
116        }
117    }
118
119    fn build<S>(self) -> axum::Router<S>
120    where
121        RpcState: FromRef<S>,
122        S: Clone + Send + Sync + 'static,
123    {
124        let mut router = axum::Router::new();
125
126        if let Some(route) = self.healthcheck_route {
127            router = router.route(&route, get(health_check));
128        }
129        for route in self.base_routes {
130            router = router.route(&route, post(common_route));
131        }
132        router = router.merge(jrpc::stream_router::<S>());
133
134        router
135    }
136}
137
138pub struct RpcEndpoint {
139    listener: TcpListener,
140    router: axum::Router<()>,
141}
142
143impl RpcEndpoint {
144    pub fn builder() -> RpcEndpointBuilder {
145        RpcEndpointBuilder::default()
146    }
147
148    pub fn empty_builder() -> RpcEndpointBuilder {
149        RpcEndpointBuilder::empty()
150    }
151
152    pub fn from_parts<S>(listener: TcpListener, router: axum::Router<S>, state: S) -> Self
153    where
154        S: Clone + Send + Sync + 'static,
155    {
156        use tower::ServiceBuilder;
157        use tower_http::cors::CorsLayer;
158        use tower_http::timeout::TimeoutLayer;
159
160        // Prepare middleware
161        let service = ServiceBuilder::new()
162            .layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
163            .layer(CorsLayer::permissive())
164            .layer(TimeoutLayer::with_status_code(
165                StatusCode::REQUEST_TIMEOUT,
166                Duration::from_secs(25),
167            ));
168
169        #[cfg(feature = "compression")]
170        let service = service.layer(tower_http::compression::CompressionLayer::new().gzip(true));
171
172        // Prepare routes
173        let router = router.layer(service).with_state(state);
174
175        // Done
176        Self { listener, router }
177    }
178
179    pub async fn serve(self) -> std::io::Result<()> {
180        axum::serve(self.listener, self.router).await
181    }
182}
183
184fn health_check() -> futures_util::future::Ready<impl IntoResponse> {
185    futures_util::future::ready(
186        std::time::SystemTime::now()
187            .duration_since(std::time::UNIX_EPOCH)
188            .expect("system time before Unix epoch")
189            .as_millis()
190            .to_string(),
191    )
192}
193
194async fn common_route(state: State<RpcState>, req: Request) -> Response {
195    use axum::http::StatusCode;
196
197    match get_mime_type(&req) {
198        Some(mime) if mime.starts_with(APPLICATION_JSON) => match req.extract().await {
199            Ok(method) => jrpc::route(state, method).await,
200            Err(e) => e.into_response(),
201        },
202        Some(mime) if mime.starts_with(APPLICATION_PROTOBUF) => match req.extract().await {
203            Ok(request) => proto::route(state, request).await,
204            Err(e) => e.into_response(),
205        },
206        _ => StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response(),
207    }
208}
209
/// Maximum accepted request body size in bytes (256 KiB).
const MAX_REQUEST_SIZE: usize = 256 * 1024;