palladium_http/
listener.rs1use crate::error::HttpError;
2use crate::message::{HttpRequest, HttpResponse};
3use axum::{
4 body::Body,
5 extract::{Request, State},
6 response::Response,
7 Router,
8};
9use http::StatusCode;
10use palladium_actor::{AskError, SendError, StableAddr};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::oneshot;
15
/// Tunables for an [`HttpListener`].
///
/// Construct with [`Default::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `ListenerConfig { max_body_bytes: 1024, ..Default::default() }`.
/// The type is `Clone` so a single configuration can be reused for several
/// listeners (`bind` takes it by value).
#[derive(Debug, Clone)]
pub struct ListenerConfig {
    /// How long to wait for the handler actor to reply to a request.
    /// When it elapses the client receives `504 Gateway Timeout`.
    pub request_timeout: Duration,
    /// Upper bound on the number of request-body bytes buffered in memory.
    /// A body that cannot be read within this limit is answered with
    /// `413 Payload Too Large`.
    pub max_body_bytes: usize,
}

impl Default for ListenerConfig {
    /// A 30 second request timeout and a 4 MiB body limit.
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(30),
            max_body_bytes: 4 * 1024 * 1024,
        }
    }
}
33
34pub struct HttpListener {
43 local_addr: SocketAddr,
44 shutdown_tx: oneshot::Sender<()>,
45}
46
47impl HttpListener {
48 pub async fn bind(
52 addr: SocketAddr,
53 handler: StableAddr<HttpRequest>,
54 config: ListenerConfig,
55 ) -> Result<Self, HttpError> {
56 Self::bind_inner(addr, handler, config).await
57 }
58
59 pub async fn bind_plain(
64 addr: SocketAddr,
65 handler: palladium_actor::Addr<HttpRequest>,
66 config: ListenerConfig,
67 ) -> Result<Self, HttpError> {
68 Self::bind_inner(addr, StableAddr::from_addr(handler), config).await
69 }
70
71 pub fn local_addr(&self) -> SocketAddr {
73 self.local_addr
74 }
75
76 pub async fn shutdown(self) {
80 let _ = self.shutdown_tx.send(());
81 }
82
83 async fn bind_inner(
84 addr: SocketAddr,
85 handler: StableAddr<HttpRequest>,
86 config: ListenerConfig,
87 ) -> Result<Self, HttpError> {
88 let state = ListenerState {
89 handler,
90 config: Arc::new(config),
91 };
92
93 let app = Router::new().fallback(catch_all).with_state(state);
94
95 let listener = tokio::net::TcpListener::bind(addr)
96 .await
97 .map_err(HttpError::Bind)?;
98 let local_addr = listener.local_addr().map_err(HttpError::Bind)?;
99 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
100
101 tokio::spawn(async move {
102 axum::serve(listener, app)
103 .with_graceful_shutdown(async {
104 shutdown_rx.await.ok();
105 })
106 .await
107 .ok();
108 });
109
110 Ok(Self {
111 local_addr,
112 shutdown_tx,
113 })
114 }
115}
116
117#[derive(Clone)]
118struct ListenerState {
119 handler: StableAddr<HttpRequest>,
120 config: Arc<ListenerConfig>,
121}
122
123async fn catch_all(State(state): State<ListenerState>, req: Request) -> Response {
124 let (parts, body) = req.into_parts();
125
126 let body_bytes = match axum::body::to_bytes(body, state.config.max_body_bytes).await {
127 Ok(b) => b,
128 Err(_) => return status_response(StatusCode::PAYLOAD_TOO_LARGE),
129 };
130
131 let http_req = HttpRequest {
132 method: parts.method,
133 uri: parts.uri,
134 headers: parts.headers,
135 body: body_bytes,
136 };
137
138 let ask_result =
139 tokio::time::timeout(state.config.request_timeout, state.handler.ask(http_req)).await;
140
141 match ask_result {
142 Err(_elapsed) => status_response(StatusCode::GATEWAY_TIMEOUT),
143 Ok(Err(AskError::Timeout)) => status_response(StatusCode::GATEWAY_TIMEOUT),
144 Ok(Err(AskError::Send(SendError::MailboxFull))) => {
145 status_response(StatusCode::TOO_MANY_REQUESTS)
146 }
147 Ok(Err(_)) => status_response(StatusCode::SERVICE_UNAVAILABLE),
148 Ok(Ok(http_resp)) => build_response(http_resp),
149 }
150}
151
152fn build_response(resp: HttpResponse) -> Response {
153 let (mut parts, _) = http::Response::new(()).into_parts();
154 parts.status = resp.status;
155 parts.headers = resp.headers;
156 http::Response::from_parts(parts, Body::from(resp.body))
157}
158
159fn status_response(status: StatusCode) -> Response {
160 http::Response::builder()
161 .status(status)
162 .body(Body::empty())
163 .expect("status response is always valid")
164}