1use std::net::SocketAddr;
2use std::str::FromStr;
3use std::sync::Arc;
4
5pub(crate) use gateway_prober::Prober;
6pub use gateway_uri::GatewayUri;
7use http::uri::Authority;
8use http_body_util::combinators::BoxBody;
9use http_body_util::{BodyExt, Empty, Full};
10use hyper::body::{Bytes, Incoming};
11use hyper::header::{
12 HeaderValue, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
13 ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_LENGTH, CONTENT_TYPE,
14};
15use hyper::server::conn::http1;
16use hyper::service::service_fn;
17use hyper::{Method, Request, Response};
18use hyper_rustls::builderstates::WantsSchemes;
19use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
20use hyper_util::client::legacy::connect::HttpConnector;
21use hyper_util::client::legacy::Client;
22use hyper_util::rt::{TokioExecutor, TokioIo};
23use tokio::io::{AsyncRead, AsyncWrite};
24use tokio::net::{TcpListener, UnixListener};
25use tokio_util::net::Listener;
26use tracing::{error, info, instrument};
27
28pub mod error;
29#[cfg(not(feature = "_test-util"))]
30mod gateway_prober;
31#[cfg(feature = "_test-util")]
32pub mod gateway_prober;
33mod gateway_uri;
34use crate::error::{BoxError, Error};
35
36#[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))]
37pub mod bootstrap;
38
/// Port number `3000`.
///
/// NOTE(review): presumably the port the relay listens on when none is
/// configured; it is not referenced in this part of the file — confirm against callers.
pub const DEFAULT_PORT: u16 = 3000;
/// Header value `0.0.0.0`.
///
/// NOTE(review): not referenced in this part of the file — confirm intended use
/// against callers.
pub const OHTTP_RELAY_HOST: HeaderValue = HeaderValue::from_static("0.0.0.0");
/// `Content-Type` value (`message/ohttp-req`) that `into_forward_req` requires on
/// incoming requests and sets on the request it forwards to the gateway.
pub const EXPECTED_MEDIA_TYPE: HeaderValue = HeaderValue::from_static("message/ohttp-req");
42
43#[instrument]
44pub async fn listen_tcp(
45 port: u16,
46 gateway_origin: GatewayUri,
47) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
48 let addr = SocketAddr::from(([0, 0, 0, 0], port));
49 let listener = TcpListener::bind(addr).await?;
50 println!("OHTTP relay listening on tcp://{}", addr);
51 ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin)).await
52}
53
54#[instrument]
55pub async fn listen_socket(
56 socket_path: &str,
57 gateway_origin: GatewayUri,
58) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
59 let listener = UnixListener::bind(socket_path)?;
60 info!("OHTTP relay listening on socket: {}", socket_path);
61 ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin)).await
62}
63
/// Test helper: starts the relay on an OS-assigned TCP port.
///
/// Binds `[::]:0`, builds the relay configuration from `default_gateway` and
/// an HTTP client derived from `root_store`, and returns the bound port
/// together with the handle of the spawned accept loop.
///
/// # Errors
/// Returns an error if binding or reading the local address fails, or
/// whatever `ohttp_relay` reports while starting up.
#[cfg(feature = "_test-util")]
pub async fn listen_tcp_on_free_port(
    default_gateway: GatewayUri,
    root_store: rustls::RootCertStore,
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
    let listener = TcpListener::bind("[::]:0").await?;
    let bound_addr = listener.local_addr()?;
    println!("OHTTP relay binding to port {}", bound_addr);

    let relay_config = RelayConfig::new(default_gateway, root_store);
    let handle = ohttp_relay(listener, relay_config).await?;
    Ok((bound_addr.port(), handle))
}
76
77#[derive(Debug)]
78struct RelayConfig {
79 default_gateway: GatewayUri,
80 client: HttpClient,
81 prober: Prober,
82}
83
84impl RelayConfig {
85 fn new_with_default_client(default_gateway: GatewayUri) -> Self {
86 Self::new(default_gateway, HttpClient::default())
87 }
88
89 fn new(default_gateway: GatewayUri, into_client: impl Into<HttpClient>) -> Self {
90 let client = into_client.into();
91 let prober = Prober::new_with_client(client.clone());
92 RelayConfig { default_gateway, client, prober }
93 }
94}
95
96#[derive(Debug, Clone)]
97pub(crate) struct HttpClient(
98 hyper_util::client::legacy::Client<HttpsConnector<HttpConnector>, BoxBody<Bytes, hyper::Error>>,
99);
100
101impl std::ops::Deref for HttpClient {
102 type Target = hyper_util::client::legacy::Client<
103 HttpsConnector<HttpConnector>,
104 BoxBody<Bytes, hyper::Error>,
105 >;
106 fn deref(&self) -> &Self::Target { &self.0 }
107}
108
109impl From<HttpsConnectorBuilder<WantsSchemes>> for HttpClient {
110 fn from(builder: HttpsConnectorBuilder<WantsSchemes>) -> Self {
111 let https = builder.https_or_http().enable_http1().build();
112 Self(Client::builder(TokioExecutor::new()).build(https))
113 }
114}
115
116impl Default for HttpClient {
117 fn default() -> Self { HttpsConnectorBuilder::new().with_webpki_roots().into() }
118}
119
120impl From<rustls::RootCertStore> for HttpClient {
121 fn from(root_store: rustls::RootCertStore) -> Self {
122 HttpsConnectorBuilder::new()
123 .with_tls_config(
124 rustls::ClientConfig::builder()
125 .with_root_certificates(root_store)
126 .with_no_client_auth(),
127 )
128 .into()
129 }
130}
131
132#[instrument(skip(listener))]
133async fn ohttp_relay<L>(
134 mut listener: L,
135 config: RelayConfig,
136) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError>
137where
138 L: Listener + Unpin + Send + 'static,
139 L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static,
140{
141 config.prober.assert_opt_in(&config.default_gateway).await;
142
143 let config = Arc::new(config);
144
145 let handle = tokio::spawn(async move {
146 while let Ok((stream, _)) = listener.accept().await {
147 let config = config.clone();
148 let io = TokioIo::new(stream);
149 tokio::spawn(async move {
150 if let Err(err) = http1::Builder::new()
151 .serve_connection(io, service_fn(|req| serve_ohttp_relay(req, &config)))
152 .with_upgrades()
153 .await
154 {
155 error!("Error serving connection: {:?}", err);
156 }
157 });
158 }
159 Ok(())
160 });
161
162 Ok(handle)
163}
164
165#[instrument]
166async fn serve_ohttp_relay(
167 req: Request<Incoming>,
168 config: &RelayConfig,
169) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
170 let mut res = match (req.method(), req.uri().path()) {
171 (&Method::OPTIONS, _) => Ok(handle_preflight()),
172 (&Method::GET, "/health") => Ok(health_check().await),
173 (&Method::POST, _) => match parse_gateway_uri(&req, config).await {
174 Ok(gateway_uri) => handle_ohttp_relay(req, config, gateway_uri).await,
175 Err(e) => Err(e),
176 },
177 #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))]
178 (&Method::GET, _) | (&Method::CONNECT, _) => match parse_gateway_uri(&req, config).await {
179 Ok(gateway_uri) => crate::bootstrap::handle_ohttp_keys(req, gateway_uri).await,
180 Err(e) => Err(e),
181 },
182 _ => Err(Error::NotFound),
183 }
184 .unwrap_or_else(|e| e.to_response());
185 res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
186 Ok(res)
187}
188
189async fn parse_gateway_uri(
190 req: &Request<Incoming>,
191 config: &RelayConfig,
192) -> Result<GatewayUri, Error> {
193 let gateway_uri = match req.method() {
196 &Method::CONNECT => req.uri().authority().cloned().map(GatewayUri::from),
197 _ => parse_gateway_uri_from_path(req.uri().path(), &config.default_gateway).ok(),
198 }
199 .ok_or_else(|| Error::BadRequest("Invalid gateway".to_string()))?;
200
201 let policy = match config.prober.check_opt_in(&gateway_uri).await {
202 Some(policy) => Ok(policy),
203 None => Err(Error::Unavailable(config.prober.unavailable_for().await)),
204 }?;
205
206 if policy.bip77_allowed {
207 Ok(gateway_uri)
208 } else {
209 Err(Error::NotFound)
215 }
216}
217
218fn parse_gateway_uri_from_path(path: &str, default: &GatewayUri) -> Result<GatewayUri, BoxError> {
219 if path.is_empty() || path == "/" {
220 return Ok(default.clone());
221 }
222
223 let path = &path[1..];
224
225 if "http://" == &path[..7] || "https://" == &path[..8] {
226 GatewayUri::from_str(path)
227 } else {
228 Ok(Authority::from_str(path)?.into())
229 }
230}
231
232fn handle_preflight() -> Response<BoxBody<Bytes, hyper::Error>> {
233 let mut res = Response::new(empty());
234 *res.status_mut() = hyper::StatusCode::NO_CONTENT;
235 res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
236 res.headers_mut().insert(
237 ACCESS_CONTROL_ALLOW_METHODS,
238 HeaderValue::from_static("CONNECT, GET, OPTIONS, POST"),
239 );
240 res.headers_mut().insert(
241 ACCESS_CONTROL_ALLOW_HEADERS,
242 HeaderValue::from_static("Content-Type, Content-Length"),
243 );
244 res
245}
246
247async fn health_check() -> Response<BoxBody<Bytes, hyper::Error>> { Response::new(empty()) }
248
249#[instrument]
250async fn handle_ohttp_relay(
251 req: Request<Incoming>,
252 config: &RelayConfig,
253 gateway: GatewayUri,
254) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
255 let fwd_req = into_forward_req(req, gateway)?;
256 forward_request(fwd_req, config).await.map(|res| {
257 let (parts, body) = res.into_parts();
258 let boxed_body = BoxBody::new(body);
259 Response::from_parts(parts, boxed_body)
260 })
261}
262
263#[instrument]
265fn into_forward_req(
266 req: Request<Incoming>,
267 gateway_origin: GatewayUri,
268) -> Result<Request<BoxBody<Bytes, hyper::Error>>, Error> {
269 let (head, body) = req.into_parts();
270
271 if head.method != hyper::Method::POST {
272 return Err(Error::MethodNotAllowed);
273 }
274
275 if head.headers.get(CONTENT_TYPE) != Some(&EXPECTED_MEDIA_TYPE) {
276 return Err(Error::UnsupportedMediaType);
277 }
278
279 let mut builder = Request::builder()
280 .method(hyper::Method::POST)
281 .uri(gateway_origin.rfc_9540_url())
282 .header(CONTENT_TYPE, EXPECTED_MEDIA_TYPE);
283
284 if let Some(content_length) = head.headers.get(CONTENT_LENGTH) {
285 builder = builder.header(CONTENT_LENGTH, content_length);
286 }
287
288 builder.body(BoxBody::new(body)).map_err(|e| Error::InternalServerError(Box::new(e)))
289}
290
291#[instrument]
292async fn forward_request(
293 req: Request<BoxBody<Bytes, hyper::Error>>,
294 config: &RelayConfig,
295) -> Result<Response<Incoming>, Error> {
296 config.client.request(req).await.map_err(|_| Error::BadGateway)
297}
298
299pub(crate) fn empty() -> BoxBody<Bytes, hyper::Error> {
300 Empty::<Bytes>::new().map_err(|never| match never {}).boxed()
301}
302
303pub(crate) fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
304 Full::new(chunk.into()).map_err(|never| match never {}).boxed()
305}