ohttp_relay/
lib.rs

1use std::net::SocketAddr;
2use std::str::FromStr;
3use std::sync::Arc;
4
5pub(crate) use gateway_prober::Prober;
6pub use gateway_uri::GatewayUri;
7use http::uri::Authority;
8use http_body_util::combinators::BoxBody;
9use http_body_util::{BodyExt, Empty, Full};
10use hyper::body::{Bytes, Incoming};
11use hyper::header::{
12    HeaderValue, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
13    ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_LENGTH, CONTENT_TYPE,
14};
15use hyper::server::conn::http1;
16use hyper::service::service_fn;
17use hyper::{Method, Request, Response};
18use hyper_rustls::builderstates::WantsSchemes;
19use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
20use hyper_util::client::legacy::connect::HttpConnector;
21use hyper_util::client::legacy::Client;
22use hyper_util::rt::{TokioExecutor, TokioIo};
23use tokio::io::{AsyncRead, AsyncWrite};
24use tokio::net::{TcpListener, UnixListener};
25use tokio_util::net::Listener;
26use tracing::{error, info, instrument};
27
28pub mod error;
29#[cfg(not(feature = "_test-util"))]
30mod gateway_prober;
31#[cfg(feature = "_test-util")]
32pub mod gateway_prober;
33mod gateway_uri;
34use crate::error::{BoxError, Error};
35
36#[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))]
37pub mod bootstrap;
38
39pub const DEFAULT_PORT: u16 = 3000;
40pub const OHTTP_RELAY_HOST: HeaderValue = HeaderValue::from_static("0.0.0.0");
41pub const EXPECTED_MEDIA_TYPE: HeaderValue = HeaderValue::from_static("message/ohttp-req");
42
43#[instrument]
44pub async fn listen_tcp(
45    port: u16,
46    gateway_origin: GatewayUri,
47) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
48    let addr = SocketAddr::from(([0, 0, 0, 0], port));
49    let listener = TcpListener::bind(addr).await?;
50    println!("OHTTP relay listening on tcp://{}", addr);
51    ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin)).await
52}
53
54#[instrument]
55pub async fn listen_socket(
56    socket_path: &str,
57    gateway_origin: GatewayUri,
58) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
59    let listener = UnixListener::bind(socket_path)?;
60    info!("OHTTP relay listening on socket: {}", socket_path);
61    ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin)).await
62}
63
/// Test helper: bind a relay to an OS-assigned port on `[::]`, trusting only
/// the certificates in `root_store` when connecting to gateways.
///
/// Returns the chosen port together with the accept-loop handle.
#[cfg(feature = "_test-util")]
pub async fn listen_tcp_on_free_port(
    default_gateway: GatewayUri,
    root_store: rustls::RootCertStore,
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
    let listener = tokio::net::TcpListener::bind("[::]:0").await?;
    let bound_addr = listener.local_addr()?;
    println!("OHTTP relay binding to port {}", bound_addr);
    let relay = ohttp_relay(listener, RelayConfig::new(default_gateway, root_store)).await?;
    Ok((bound_addr.port(), relay))
}
76
77#[derive(Debug)]
78struct RelayConfig {
79    default_gateway: GatewayUri,
80    client: HttpClient,
81    prober: Prober,
82}
83
84impl RelayConfig {
85    fn new_with_default_client(default_gateway: GatewayUri) -> Self {
86        Self::new(default_gateway, HttpClient::default())
87    }
88
89    fn new(default_gateway: GatewayUri, into_client: impl Into<HttpClient>) -> Self {
90        let client = into_client.into();
91        let prober = Prober::new_with_client(client.clone());
92        RelayConfig { default_gateway, client, prober }
93    }
94}
95
96#[derive(Debug, Clone)]
97pub(crate) struct HttpClient(
98    hyper_util::client::legacy::Client<HttpsConnector<HttpConnector>, BoxBody<Bytes, hyper::Error>>,
99);
100
101impl std::ops::Deref for HttpClient {
102    type Target = hyper_util::client::legacy::Client<
103        HttpsConnector<HttpConnector>,
104        BoxBody<Bytes, hyper::Error>,
105    >;
106    fn deref(&self) -> &Self::Target { &self.0 }
107}
108
109impl From<HttpsConnectorBuilder<WantsSchemes>> for HttpClient {
110    fn from(builder: HttpsConnectorBuilder<WantsSchemes>) -> Self {
111        let https = builder.https_or_http().enable_http1().build();
112        Self(Client::builder(TokioExecutor::new()).build(https))
113    }
114}
115
116impl Default for HttpClient {
117    fn default() -> Self { HttpsConnectorBuilder::new().with_webpki_roots().into() }
118}
119
120impl From<rustls::RootCertStore> for HttpClient {
121    fn from(root_store: rustls::RootCertStore) -> Self {
122        HttpsConnectorBuilder::new()
123            .with_tls_config(
124                rustls::ClientConfig::builder()
125                    .with_root_certificates(root_store)
126                    .with_no_client_auth(),
127            )
128            .into()
129    }
130}
131
132#[instrument(skip(listener))]
133async fn ohttp_relay<L>(
134    mut listener: L,
135    config: RelayConfig,
136) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError>
137where
138    L: Listener + Unpin + Send + 'static,
139    L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static,
140{
141    config.prober.assert_opt_in(&config.default_gateway).await;
142
143    let config = Arc::new(config);
144
145    let handle = tokio::spawn(async move {
146        while let Ok((stream, _)) = listener.accept().await {
147            let config = config.clone();
148            let io = TokioIo::new(stream);
149            tokio::spawn(async move {
150                if let Err(err) = http1::Builder::new()
151                    .serve_connection(io, service_fn(|req| serve_ohttp_relay(req, &config)))
152                    .with_upgrades()
153                    .await
154                {
155                    error!("Error serving connection: {:?}", err);
156                }
157            });
158        }
159        Ok(())
160    });
161
162    Ok(handle)
163}
164
165#[instrument]
166async fn serve_ohttp_relay(
167    req: Request<Incoming>,
168    config: &RelayConfig,
169) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
170    let mut res = match (req.method(), req.uri().path()) {
171        (&Method::OPTIONS, _) => Ok(handle_preflight()),
172        (&Method::GET, "/health") => Ok(health_check().await),
173        (&Method::POST, _) => match parse_gateway_uri(&req, config).await {
174            Ok(gateway_uri) => handle_ohttp_relay(req, config, gateway_uri).await,
175            Err(e) => Err(e),
176        },
177        #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))]
178        (&Method::GET, _) | (&Method::CONNECT, _) => match parse_gateway_uri(&req, config).await {
179            Ok(gateway_uri) => crate::bootstrap::handle_ohttp_keys(req, gateway_uri).await,
180            Err(e) => Err(e),
181        },
182        _ => Err(Error::NotFound),
183    }
184    .unwrap_or_else(|e| e.to_response());
185    res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
186    Ok(res)
187}
188
189async fn parse_gateway_uri(
190    req: &Request<Incoming>,
191    config: &RelayConfig,
192) -> Result<GatewayUri, Error> {
193    // for POST and GET (websockets), the gateway URI is provided in the path
194    // for CONNECT requests, just an authority is provided, and we assume HTTPS
195    let gateway_uri = match req.method() {
196        &Method::CONNECT => req.uri().authority().cloned().map(GatewayUri::from),
197        _ => parse_gateway_uri_from_path(req.uri().path(), &config.default_gateway).ok(),
198    }
199    .ok_or_else(|| Error::BadRequest("Invalid gateway".to_string()))?;
200
201    let policy = match config.prober.check_opt_in(&gateway_uri).await {
202        Some(policy) => Ok(policy),
203        None => Err(Error::Unavailable(config.prober.unavailable_for().await)),
204    }?;
205
206    if policy.bip77_allowed {
207        Ok(gateway_uri)
208    } else {
209        // TODO Cache-Control header for error based on policy.expires
210        // is not found the right error? maybe forbidden or bad gateway?
211        // prober policy judgement can be an enum instead of a bool to
212        // distinguish 4xx vs. 5xx failures, 4xx being an explicit opt out and
213        // 5xx for IO errors etc
214        Err(Error::NotFound)
215    }
216}
217
218fn parse_gateway_uri_from_path(path: &str, default: &GatewayUri) -> Result<GatewayUri, BoxError> {
219    if path.is_empty() || path == "/" {
220        return Ok(default.clone());
221    }
222
223    let path = &path[1..];
224
225    if "http://" == &path[..7] || "https://" == &path[..8] {
226        GatewayUri::from_str(path)
227    } else {
228        Ok(Authority::from_str(path)?.into())
229    }
230}
231
232fn handle_preflight() -> Response<BoxBody<Bytes, hyper::Error>> {
233    let mut res = Response::new(empty());
234    *res.status_mut() = hyper::StatusCode::NO_CONTENT;
235    res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
236    res.headers_mut().insert(
237        ACCESS_CONTROL_ALLOW_METHODS,
238        HeaderValue::from_static("CONNECT, GET, OPTIONS, POST"),
239    );
240    res.headers_mut().insert(
241        ACCESS_CONTROL_ALLOW_HEADERS,
242        HeaderValue::from_static("Content-Type, Content-Length"),
243    );
244    res
245}
246
247async fn health_check() -> Response<BoxBody<Bytes, hyper::Error>> { Response::new(empty()) }
248
249#[instrument]
250async fn handle_ohttp_relay(
251    req: Request<Incoming>,
252    config: &RelayConfig,
253    gateway: GatewayUri,
254) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
255    let fwd_req = into_forward_req(req, gateway)?;
256    forward_request(fwd_req, config).await.map(|res| {
257        let (parts, body) = res.into_parts();
258        let boxed_body = BoxBody::new(body);
259        Response::from_parts(parts, boxed_body)
260    })
261}
262
263/// Convert an incoming request into a request to forward to the target gateway server.
264#[instrument]
265fn into_forward_req(
266    req: Request<Incoming>,
267    gateway_origin: GatewayUri,
268) -> Result<Request<BoxBody<Bytes, hyper::Error>>, Error> {
269    let (head, body) = req.into_parts();
270
271    if head.method != hyper::Method::POST {
272        return Err(Error::MethodNotAllowed);
273    }
274
275    if head.headers.get(CONTENT_TYPE) != Some(&EXPECTED_MEDIA_TYPE) {
276        return Err(Error::UnsupportedMediaType);
277    }
278
279    let mut builder = Request::builder()
280        .method(hyper::Method::POST)
281        .uri(gateway_origin.rfc_9540_url())
282        .header(CONTENT_TYPE, EXPECTED_MEDIA_TYPE);
283
284    if let Some(content_length) = head.headers.get(CONTENT_LENGTH) {
285        builder = builder.header(CONTENT_LENGTH, content_length);
286    }
287
288    builder.body(BoxBody::new(body)).map_err(|e| Error::InternalServerError(Box::new(e)))
289}
290
291#[instrument]
292async fn forward_request(
293    req: Request<BoxBody<Bytes, hyper::Error>>,
294    config: &RelayConfig,
295) -> Result<Response<Incoming>, Error> {
296    config.client.request(req).await.map_err(|_| Error::BadGateway)
297}
298
299pub(crate) fn empty() -> BoxBody<Bytes, hyper::Error> {
300    Empty::<Bytes>::new().map_err(|never| match never {}).boxed()
301}
302
303pub(crate) fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
304    Full::new(chunk.into()).map_err(|never| match never {}).boxed()
305}