use std::{net::SocketAddr, sync::Arc, thread};
use axum::{body::Body, routing::RouterIntoService};
use tokio::runtime::Runtime;
use tower_http::trace::TraceLayer;
use super::{
auth::{self, Permissions},
ContainerRegistry, ContainerRegistryBuilder,
};
/// A [`ContainerRegistry`] prepared for use in tests, together with the settings used to
/// serve it.
///
/// Typically obtained from `ContainerRegistryBuilder::build_for_testing`.
// NOTE: field order matters for drop order — `registry` is dropped before `temp_storage`,
// so the backing directory outlives this handle to the registry.
pub struct TestingContainerRegistry {
/// The registry instance being served.
pub registry: Arc<ContainerRegistry>,
/// Temporary directory backing the registry's storage, if one was created for it.
/// Held here so the directory stays alive as long as this value (or the
/// `RunningRegistry` it is moved into).
pub temp_storage: Option<tempdir::TempDir>,
/// Maximum request body size, in bytes, passed to axum's `DefaultBodyLimit`.
pub body_limit: usize,
/// Address the background server binds to.
pub bind_addr: SocketAddr,
}
/// Handle to a registry served on a background thread by
/// `TestingContainerRegistry::run_in_background`.
///
/// Dropping the handle shuts the server down and waits for its thread to finish.
pub struct RunningRegistry {
/// Actual address of the listening socket (resolves port `0` to the OS-assigned port).
bound_addr: SocketAddr,
/// Thread running the server; an `Option` so `Drop` can take and join it.
join_handle: Option<thread::JoinHandle<()>>,
/// Keeps the temporary storage directory (if any) alive. Fields are only dropped after
/// `Drop::drop` has joined the server thread, so the directory outlives the server.
_temp_storage: Option<tempdir::TempDir>,
/// Shutdown signal: dropping the sender closes the channel, which the server thread
/// treats as the trigger for a graceful shutdown.
shutdown: Option<tokio::sync::mpsc::Sender<()>>,
}
impl RunningRegistry {
    /// Returns the socket address the background server is actually listening on.
    ///
    /// When the registry was configured to bind port `0`, this carries the port that the
    /// operating system picked.
    pub fn bound_addr(&self) -> SocketAddr {
        let Self { bound_addr, .. } = self;
        *bound_addr
    }
}
impl Drop for RunningRegistry {
fn drop(&mut self) {
drop(self.shutdown.take());
if let Some(join_handle) = self.join_handle.take() {
join_handle.join().expect("failed to join");
}
}
}
impl TestingContainerRegistry {
    /// Builds the application served by this registry: the registry's router with the
    /// configured request body limit and HTTP request tracing applied.
    ///
    /// Shared by `make_service` and `run_in_background` so both entry points serve an
    /// identically configured application.
    fn router(&self) -> axum::Router {
        self.registry
            .clone()
            .make_router()
            .layer(axum::extract::DefaultBodyLimit::max(self.body_limit))
            .layer(TraceLayer::new_for_http())
    }

    /// Returns the registry as a tower service, for driving requests in-process without
    /// binding a socket.
    ///
    /// The service enforces the configured `body_limit`, exactly like the server started by
    /// `run_in_background`.
    pub fn make_service(&self) -> RouterIntoService<Body> {
        self.router().into_service::<Body>()
    }

    /// Sets the address `run_in_background` will bind to.
    pub fn bind(&mut self, addr: SocketAddr) -> &mut Self {
        self.bind_addr = addr;
        self
    }

    /// Sets the maximum accepted request body size, in bytes.
    pub fn body_limit(&mut self, body_limit: usize) -> &mut Self {
        self.body_limit = body_limit;
        self
    }

    /// Binds a listener and serves the registry on a dedicated background thread.
    ///
    /// The listener is bound synchronously, so the returned [`RunningRegistry`] already
    /// knows the actual address (useful when binding port `0`). The server runs on its own
    /// tokio runtime until the returned handle is dropped. The temporary storage directory,
    /// if any, is handed over to the handle so it stays alive while the server runs.
    ///
    /// # Panics
    /// Panics if the address cannot be bound, the listener cannot be configured, or the
    /// tokio runtime cannot be created.
    pub fn run_in_background(mut self) -> RunningRegistry {
        let app = self.router();

        let listener =
            std::net::TcpListener::bind(self.bind_addr).expect("could not bind listener");
        // `tokio::net::TcpListener::from_std` expects the socket to already be non-blocking.
        listener
            .set_nonblocking(true)
            .expect("could not set listener to non-blocking");
        let bound_addr = listener.local_addr().expect("failed to get local address");

        // Dropping the sender (done in `RunningRegistry`'s `Drop`) closes the channel, which
        // makes `recv` return and starts the graceful shutdown.
        let (shutdown_sender, mut shutdown_receiver) = tokio::sync::mpsc::channel::<()>(1);
        let rt = Runtime::new().expect("could not create tokio runtime");
        let join_handle = thread::spawn(move || {
            rt.block_on(async move {
                // Converted inside the runtime so the socket registers with its reactor.
                let listener = tokio::net::TcpListener::from_std(listener)
                    .expect("could not create tokio listener");
                axum::serve(listener, app)
                    .with_graceful_shutdown(async move {
                        shutdown_receiver.recv().await;
                    })
                    .await
                    .expect("axum io error");
            })
        });

        RunningRegistry {
            bound_addr,
            join_handle: Some(join_handle),
            shutdown: Some(shutdown_sender),
            _temp_storage: self.temp_storage.take(),
        }
    }

    /// Returns a reference to the underlying registry.
    pub fn registry(&self) -> &ContainerRegistry {
        &self.registry
    }
}
impl ContainerRegistryBuilder {
pub fn build_for_testing(mut self) -> TestingContainerRegistry {
let temp_storage = if self.storage.is_none() {
let temp_storage = tempdir::TempDir::new("container-registry-for-testing").expect(
"could not create temporary directory to host testing container registry instance",
);
self = self.storage(temp_storage.path());
Some(temp_storage)
} else {
None
};
if self.auth_provider.is_none() {
self = self.auth_provider(Arc::new(auth::Anonymous::new(
Permissions::ReadWrite,
Permissions::ReadWrite,
)));
}
let registry = self.build().expect("could not create registry");
TestingContainerRegistry {
registry,
temp_storage,
bind_addr: ([127, 0, 0, 1], 0).into(),
body_limit: 100 * 1024 * 1024,
}
}
}