http_relay/
lib.rs

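//! A Rust implementation of _some_ of the [HTTP relay spec](https://httprelay.io/).
//!
//! Currently only the `link` endpoint (`/link/{id}`) is served: a POST and a GET on the
//! same id are paired, and the POST body is returned to the GET.
//!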
3
4#![deny(missing_docs)]
5#![deny(rustdoc::broken_intra_doc_links)]
6#![cfg_attr(any(), deny(clippy::unwrap_used))]
7
8use std::{
9    collections::HashMap,
10    net::{SocketAddr, TcpListener},
11    sync::Arc,
12    time::Duration,
13};
14
15use anyhow::Result;
16
17use axum::{
18    body::Bytes,
19    extract::{Path, State},
20    response::IntoResponse,
21    routing::get,
22    Router,
23};
24use axum_server::Handle;
25use tokio::sync::{oneshot, Mutex};
26
27use tower_http::{cors::CorsLayer, trace::TraceLayer};
28use url::Url;
29
30// Shared state to store GET requests and their notifications
31type SharedState = Arc<Mutex<HashMap<String, ChannelState>>>;
32
33enum ChannelState {
34    ProducerWaiting {
35        body: Bytes,
36        completion: oneshot::Sender<()>,
37    },
38    ConsumerWaiting {
39        message_sender: oneshot::Sender<Bytes>,
40    },
41}
42
/// Settings collected by [`HttpRelayBuilder`] and used when the relay is started.
#[derive(Debug, Default)]
struct Config {
    /// TCP port for the HTTP server. The derived default of `0` asks the operating
    /// system to pick a free port when the listener is bound.
    http_port: u16,
}
47
48#[derive(Debug, Default)]
49/// Builder for [HttpRelay].
50pub struct HttpRelayBuilder(Config);
51
52impl HttpRelayBuilder {
53    /// Configure the port used for HTTP server.
54    pub fn http_port(mut self, port: u16) -> Self {
55        self.0.http_port = port;
56
57        self
58    }
59
60    /// Start running an HTTP relay.
61    pub async fn run(self) -> Result<HttpRelay> {
62        HttpRelay::start(self.0).await
63    }
64}
65
66/// An implementation of _some_ of [Http relay spec](https://httprelay.io/).
67pub struct HttpRelay {
68    pub(crate) http_handle: Handle,
69    http_address: SocketAddr,
70}
71
72impl HttpRelay {
73    async fn start(config: Config) -> Result<Self> {
74        let shared_state: SharedState = Arc::new(Mutex::new(HashMap::new()));
75
76        let app = Router::new()
77            .route("/link/{id}", get(link::get).post(link::post))
78            .layer(CorsLayer::very_permissive())
79            .layer(TraceLayer::new_for_http())
80            .with_state(shared_state);
81
82        let http_handle = Handle::new();
83        let shutdown_handle = http_handle.clone();
84
85        let http_listener = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], config.http_port)))?;
86        let http_address = http_listener.local_addr()?;
87
88        tokio::spawn(async move {
89            axum_server::from_tcp(http_listener)
90                .handle(http_handle.clone())
91                .serve(app.into_make_service())
92                .await
93                .map_err(|error| tracing::error!(?error, "HttpRelay http server error"))
94        });
95
96        Ok(Self {
97            http_handle: shutdown_handle,
98            http_address,
99        })
100    }
101
102    /// Create [HttpRelayBuilder].
103    pub fn builder() -> HttpRelayBuilder {
104        HttpRelayBuilder::default()
105    }
106
107    /// Returns the HTTP address of this http relay.
108    pub fn http_address(&self) -> SocketAddr {
109        self.http_address
110    }
111
112    /// Returns the localhost Url of this server.
113    pub fn local_url(&self) -> Url {
114        Url::parse(&format!("http://localhost:{}", self.http_address.port()))
115            .expect("local_url should be formatted fine")
116    }
117
118    /// Returns the localhost URL of Link endpoints
119    pub fn local_link_url(&self) -> Url {
120        let mut url = self.local_url();
121
122        let mut segments = url
123            .path_segments_mut()
124            .expect("HttpRelay::local_link_url path_segments_mut");
125
126        segments.push("link");
127
128        drop(segments);
129
130        url
131    }
132
133    /// Gracefully shuts down the HTTP relay.
134    pub async fn shutdown(self) -> anyhow::Result<()> {
135        self.http_handle
136            .graceful_shutdown(Some(Duration::from_secs(1)));
137        Ok(())
138    }
139}
140
141impl Drop for HttpRelay {
142    fn drop(&mut self) {
143        self.http_handle.shutdown();
144    }
145}
146
147mod link {
148    use axum::http::StatusCode;
149
150    use super::*;
151
152    pub async fn get(
153        Path(id): Path<String>,
154        State(state): State<SharedState>,
155    ) -> impl IntoResponse {
156        let mut channels = state.lock().await;
157
158        match channels.remove(&id) {
159            Some(ChannelState::ProducerWaiting { body, completion }) => {
160                let _ = completion.send(());
161
162                (StatusCode::OK, body)
163            }
164            _ => {
165                let (message_sender, message_receiver) = oneshot::channel();
166                channels.insert(id, ChannelState::ConsumerWaiting { message_sender });
167                drop(channels);
168
169                match message_receiver.await {
170                    Ok(message) => (StatusCode::OK, message),
171                    Err(_) => (StatusCode::NOT_FOUND, "Not Found".into()),
172                }
173            }
174        }
175    }
176
177    pub async fn post(
178        Path(channel): Path<String>,
179        State(state): State<SharedState>,
180        body: Bytes,
181    ) -> impl IntoResponse {
182        let mut channels = state.lock().await;
183
184        match channels.remove(&channel) {
185            Some(ChannelState::ConsumerWaiting { message_sender }) => {
186                let _ = message_sender.send(body);
187                (StatusCode::OK, ())
188            }
189            _ => {
190                let (completion_sender, completion_receiver) = oneshot::channel();
191                channels.insert(
192                    channel,
193                    ChannelState::ProducerWaiting {
194                        body,
195                        completion: completion_sender,
196                    },
197                );
198                drop(channels);
199                let _ = completion_receiver.await;
200                (StatusCode::OK, ())
201            }
202        }
203    }
204}