use std::sync::Arc;
use http_body_util::{BodyExt, Full, Limited};
use hyper::body::{Bytes, Incoming};
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as ConnBuilder;
use rustls::ServerConfig;
use tokio_rustls::TlsAcceptor;
use crate::auth::{self, AuthConfig};
use crate::exposed::fast_routes::FastRoutes;
use crate::exposed::tls::leaf_common_name;
/// Shared configuration for the HTTP/OCI gateway, handed to every connection
/// and request via `Arc`.
pub struct GatewaySettings {
/// Routing table used to pick the repository backend for a request path.
pub routes: FastRoutes,
/// Passed to `auth::validate_request` for every request whose method is not
/// `GET` or `HEAD`.
pub auth_config: Arc<AuthConfig>,
/// When `Some`, every accepted connection is TLS-terminated with this config
/// and the peer certificate's common name is forwarded to request handling;
/// when `None`, the gateway serves cleartext HTTP and logs a security warning.
pub tls: Option<Arc<ServerConfig>>,
/// Maximum number of request-body bytes buffered before the request is
/// rejected with `413 Payload too large`.
pub max_body_bytes: usize,
}
pub fn start_http_gateway(
addr: std::net::SocketAddr,
settings: Arc<GatewaySettings>,
) -> anyhow::Result<()> {
tokio::spawn(async move {
let listener = match tokio::net::TcpListener::bind(addr).await {
Ok(l) => l,
Err(e) => {
eprintln!("HTTP/OCI gateway failed to bind {}: {}", addr, e);
return;
}
};
let scheme = if settings.tls.is_some() { "https" } else { "http (cleartext)" };
println!("HTTP/OCI gateway listening on {} [{}]", addr, scheme);
if settings.tls.is_none() {
log::warn!(
"SECURITY: HTTP/OCI gateway on {} runs WITHOUT TLS — credentials and \
artifacts travel in cleartext. Set ron_tls for any non-loopback use.",
addr
);
}
let acceptor = settings.tls.clone().map(TlsAcceptor::from);
loop {
let (stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
log::debug!("gateway accept error: {}", e);
continue;
}
};
let settings = settings.clone();
let acceptor = acceptor.clone();
tokio::spawn(async move {
match acceptor {
Some(acceptor) => {
let tls_stream = match acceptor.accept(stream).await {
Ok(s) => s,
Err(e) => {
log::debug!("TLS handshake failed: {}", e);
return;
}
};
let client_cn = tls_stream
.get_ref()
.1
.peer_certificates()
.and_then(leaf_common_name);
serve(TokioIo::new(tls_stream), settings, client_cn).await;
}
None => serve(TokioIo::new(stream), settings, None).await,
}
});
}
});
Ok(())
}
async fn serve<I>(io: TokioIo<I>, settings: Arc<GatewaySettings>, client_cn: Option<String>)
where
I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
{
let svc = service_fn(move |req| {
let settings = settings.clone();
let client_cn = client_cn.clone();
async move { handle(req, settings, client_cn).await }
});
if let Err(e) = ConnBuilder::new(TokioExecutor::new())
.serve_connection(io, svc)
.await
{
log::debug!("gateway connection error: {}", e);
}
}
async fn handle(
req: Request<Incoming>,
settings: Arc<GatewaySettings>,
client_cn: Option<String>,
) -> Result<Response<Full<Bytes>>, std::convert::Infallible> {
let method = req.method().as_str().to_owned();
let path = req.uri().path().to_owned();
let path_and_query = req
.uri()
.path_and_query()
.map(|pq| pq.as_str().to_owned())
.unwrap_or_else(|| path.clone());
if method == "GET" {
if path == "/healthz" || path == "/readyz" {
return Ok(text(200, "ok"));
}
if path == "/v2" || path == "/v2/" {
return Ok(Response::builder()
.status(200)
.header("Docker-Distribution-Api-Version", "registry/2.0")
.body(Full::new(Bytes::from_static(b"{}")))
.expect("static response"));
}
}
if method != "GET" && method != "HEAD" {
let bearer = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if auth::validate_request(&settings.auth_config, bearer, client_cn.as_deref())
.await
.is_err()
{
return Ok(text(401, "Unauthorized"));
}
}
let backend = match routes_lookup(&settings.routes, &path) {
Some(b) => b,
None => return Ok(text(404, "Unknown repository")),
};
let body = match Limited::new(req.into_body(), settings.max_body_bytes)
.collect()
.await
{
Ok(collected) => collected.to_bytes(),
Err(_) => return Ok(text(413, "Payload too large")),
};
let dispatch = tokio::task::spawn_blocking(move || {
backend.handle_http2_request(&method, &path_and_query, &body)
})
.await;
match dispatch {
Ok(Ok((status, headers, body))) => {
let mut builder = Response::builder().status(status);
for (k, v) in headers {
builder = builder.header(k, v);
}
Ok(builder
.body(Full::new(Bytes::from(body)))
.unwrap_or_else(|_| text(500, "Internal error")))
}
Ok(Err(e)) => {
log::warn!("backend error for {}: {:#}", path, e);
Ok(text(500, "Internal error"))
}
Err(_) => Ok(text(500, "Internal error")),
}
}
/// Resolves the repository backend responsible for a request `path`.
///
/// The path is split on `/` and empty segments are ignored. If the first
/// segment is `v2` and a second segment exists, the second segment is the
/// lookup key (OCI-style `/v2/<name>/...`). Otherwise the first segment is the
/// key, so a bare `/v2` looks up `"v2"` itself. Returns `None` for a path with
/// no segments or when `routes` has no entry for the key.
fn routes_lookup(
    routes: &FastRoutes,
    path: &str,
) -> Option<std::sync::Arc<dyn traits::RepositoryBackendTrait>> {
    let mut segments = path.split('/').filter(|segment| !segment.is_empty());
    let head = segments.next()?;
    let key = if head == "v2" {
        segments.next().unwrap_or(head)
    } else {
        head
    };
    routes.lookup(key).cloned()
}
/// Builds a `text/plain` response carrying `msg` with the given status code.
///
/// # Panics
/// Panics if `status` is not a valid HTTP status code. All callers in this
/// file pass fixed, valid codes.
fn text(status: u16, msg: &str) -> Response<Full<Bytes>> {
    let payload = Bytes::copy_from_slice(msg.as_bytes());
    Response::builder()
        .header("Content-Type", "text/plain")
        .status(status)
        .body(Full::new(payload))
        .expect("static response")
}