1#![deny(missing_docs)]
5#![deny(rustdoc::broken_intra_doc_links)]
6#![cfg_attr(any(), deny(clippy::unwrap_used))]
7
8use std::{
9 collections::HashMap,
10 net::{SocketAddr, TcpListener},
11 sync::Arc,
12 time::Duration,
13};
14
15use anyhow::Result;
16
17use axum::{
18 body::Bytes,
19 extract::{Path, State},
20 response::IntoResponse,
21 routing::get,
22 Router,
23};
24use axum_server::Handle;
25use tokio::sync::{oneshot, Mutex};
26
27use tower_http::{cors::CorsLayer, trace::TraceLayer};
28use url::Url;
29
30type SharedState = Arc<Mutex<HashMap<String, ChannelState>>>;
32
33enum ChannelState {
34 ProducerWaiting {
35 body: Bytes,
36 completion: oneshot::Sender<()>,
37 },
38 ConsumerWaiting {
39 message_sender: oneshot::Sender<Bytes>,
40 },
41}
42
/// Settings collected by [`HttpRelayBuilder`] before the relay is started.
#[derive(Default, Debug)]
struct Config {
    /// TCP port for the HTTP listener. The default `0` lets the OS pick one.
    pub http_port: u16,
}
47
48#[derive(Debug, Default)]
49pub struct HttpRelayBuilder(Config);
51
52impl HttpRelayBuilder {
53 pub fn http_port(mut self, port: u16) -> Self {
55 self.0.http_port = port;
56
57 self
58 }
59
60 pub async fn run(self) -> Result<HttpRelay> {
62 HttpRelay::start(self.0).await
63 }
64}
65
66pub struct HttpRelay {
68 pub(crate) http_handle: Handle,
69 http_address: SocketAddr,
70}
71
72impl HttpRelay {
73 async fn start(config: Config) -> Result<Self> {
74 let shared_state: SharedState = Arc::new(Mutex::new(HashMap::new()));
75
76 let app = Router::new()
77 .route("/link/{id}", get(link::get).post(link::post))
78 .layer(CorsLayer::very_permissive())
79 .layer(TraceLayer::new_for_http())
80 .with_state(shared_state);
81
82 let http_handle = Handle::new();
83 let shutdown_handle = http_handle.clone();
84
85 let http_listener = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], config.http_port)))?;
86 let http_address = http_listener.local_addr()?;
87
88 tokio::spawn(async move {
89 axum_server::from_tcp(http_listener)
90 .handle(http_handle.clone())
91 .serve(app.into_make_service())
92 .await
93 .map_err(|error| tracing::error!(?error, "HttpRelay http server error"))
94 });
95
96 Ok(Self {
97 http_handle: shutdown_handle,
98 http_address,
99 })
100 }
101
102 pub fn builder() -> HttpRelayBuilder {
104 HttpRelayBuilder::default()
105 }
106
107 pub fn http_address(&self) -> SocketAddr {
109 self.http_address
110 }
111
112 pub fn local_url(&self) -> Url {
114 Url::parse(&format!("http://localhost:{}", self.http_address.port()))
115 .expect("local_url should be formatted fine")
116 }
117
118 pub fn local_link_url(&self) -> Url {
120 let mut url = self.local_url();
121
122 let mut segments = url
123 .path_segments_mut()
124 .expect("HttpRelay::local_link_url path_segments_mut");
125
126 segments.push("link");
127
128 drop(segments);
129
130 url
131 }
132
133 pub async fn shutdown(self) -> anyhow::Result<()> {
135 self.http_handle
136 .graceful_shutdown(Some(Duration::from_secs(1)));
137 Ok(())
138 }
139}
140
141impl Drop for HttpRelay {
142 fn drop(&mut self) {
143 self.http_handle.shutdown();
144 }
145}
146
147mod link {
148 use axum::http::StatusCode;
149
150 use super::*;
151
152 pub async fn get(
153 Path(id): Path<String>,
154 State(state): State<SharedState>,
155 ) -> impl IntoResponse {
156 let mut channels = state.lock().await;
157
158 match channels.remove(&id) {
159 Some(ChannelState::ProducerWaiting { body, completion }) => {
160 let _ = completion.send(());
161
162 (StatusCode::OK, body)
163 }
164 _ => {
165 let (message_sender, message_receiver) = oneshot::channel();
166 channels.insert(id, ChannelState::ConsumerWaiting { message_sender });
167 drop(channels);
168
169 match message_receiver.await {
170 Ok(message) => (StatusCode::OK, message),
171 Err(_) => (StatusCode::NOT_FOUND, "Not Found".into()),
172 }
173 }
174 }
175 }
176
177 pub async fn post(
178 Path(channel): Path<String>,
179 State(state): State<SharedState>,
180 body: Bytes,
181 ) -> impl IntoResponse {
182 let mut channels = state.lock().await;
183
184 match channels.remove(&channel) {
185 Some(ChannelState::ConsumerWaiting { message_sender }) => {
186 let _ = message_sender.send(body);
187 (StatusCode::OK, ())
188 }
189 _ => {
190 let (completion_sender, completion_receiver) = oneshot::channel();
191 channels.insert(
192 channel,
193 ChannelState::ProducerWaiting {
194 body,
195 completion: completion_sender,
196 },
197 );
198 drop(channels);
199 let _ = completion_receiver.await;
200 (StatusCode::OK, ())
201 }
202 }
203 }
204}