use std::net::SocketAddr;
use std::path::PathBuf;
use std::time::Duration;
use super::routes::{
dav_handler, delete_entry,
disable_users::{disable_user, enable_user},
generate_signup_token, info, root, user_quota,
};
use super::trace::with_trace_layer;
use super::{app_state::AppState, auth_middleware::AdminAuthLayer};
use crate::AppContext;
#[cfg(any(test, feature = "testing"))]
use crate::MockDataDir;
use crate::{AppContextConversionError, PersistentDataDir};
use axum::routing::{any, delete, post};
use axum::{routing::get, Router};
use axum_server::Handle;
use tokio::task::JoinHandle;
use tower_http::cors::CorsLayer;
/// Builds the admin-only part of the router.
///
/// Every route registered here sits behind an [`AdminAuthLayer`] constructed from
/// `password`, so none of them are reachable without it:
///
/// - `GET`/`POST /generate_signup_token` — issue a signup token (POST variant takes limits)
/// - `GET /info`
/// - `DELETE /webdav/{*entry_path}`
/// - `POST /users/{pubkey}/disable` and `POST /users/{pubkey}/enable`
/// - `GET`/`PATCH /users/{pubkey}/quota`
fn create_protected_router(password: &str) -> Router<AppState> {
    let signup_token_methods = get(generate_signup_token::generate_signup_token)
        .post(generate_signup_token::generate_signup_token_with_limits);
    let quota_methods = get(user_quota::get_user_quota).patch(user_quota::patch_user_quota);
    let admin_auth = AdminAuthLayer::new(password.to_owned());

    Router::new()
        .route("/generate_signup_token", signup_token_methods)
        .route("/info", get(info::info))
        .route("/webdav/{*entry_path}", delete(delete_entry::delete_entry))
        .route("/users/{pubkey}/disable", post(disable_user))
        .route("/users/{pubkey}/enable", post(enable_user))
        .route("/users/{pubkey}/quota", quota_methods)
        // Applied last so it wraps every route registered above.
        .layer(admin_auth)
}
/// Builds the routes that need no authentication — currently only the `/` landing handler.
fn create_public_router() -> Router<AppState> {
    let landing = get(root::handler);
    Router::new().route("/", landing)
}
/// Assembles the complete admin application and turns it into a make-service for serving.
///
/// The result combines the password-protected admin routes, the public routes, and the
/// WebDAV handler mounted at `/dav{*path}`. It then attaches `state`, a CORS layer and the
/// tracing layer.
///
/// Note that the `/dav` route is *not* part of the protected router, so it is not wrapped
/// by `AdminAuthLayer`. Presumably `dav_handler` performs its own authentication (the
/// tests send HTTP Basic credentials) — confirm.
///
/// NOTE(review): `CorsLayer::very_permissive()` mirrors arbitrary request origins and
/// allows credentials (per tower-http docs). Consider whether that is appropriate for an
/// admin surface.
pub(crate) fn create_app(
    state: AppState,
    password: &str,
) -> axum::routing::IntoMakeService<Router> {
    let routes = Router::new()
        .merge(create_protected_router(password))
        .merge(create_public_router())
        .route("/dav{*path}", any(dav_handler::dav_handler));

    // `layer` wraps every route registered before it, so CORS covers all of the above.
    let stateful_app = routes
        .with_state(state)
        .layer(CorsLayer::very_permissive());

    with_trace_layer(stateful_app).into_make_service()
}
#[derive(thiserror::Error, Debug)]
pub enum AdminServerBuildError {
#[error("Failed to create admin server: {0}")]
Server(anyhow::Error),
#[error("Failed to boostrap from the data directory: {0}")]
DataDir(AppContextConversionError),
}
/// A running admin HTTP server.
///
/// Created by [`AdminServer::start`] (or one of the `from_*` constructors), which spawns
/// the server on a Tokio task. Dropping this value shuts that task down (see the `Drop`
/// impl).
pub struct AdminServer {
    /// `axum_server` control handle, used on drop to request a graceful shutdown.
    http_handle: Handle<SocketAddr>,
    /// The spawned task running the server; aborted on drop.
    join_handle: JoinHandle<()>,
    /// Address the listener is actually bound to (taken from `local_addr()` after binding).
    socket: SocketAddr,
    /// Admin password; sent as `X-Admin-Password` by [`AdminServer::create_signup_token`].
    password: String,
}
impl AdminServer {
pub async fn from_data_dir(data_dir: PersistentDataDir) -> Result<Self, AdminServerBuildError> {
let context = AppContext::read_from(data_dir)
.await
.map_err(AdminServerBuildError::DataDir)?;
Self::start(&context).await
}
pub async fn from_data_dir_path(data_dir_path: PathBuf) -> Result<Self, AdminServerBuildError> {
let data_dir = PersistentDataDir::new(data_dir_path);
Self::from_data_dir(data_dir).await
}
#[cfg(any(test, feature = "testing"))]
pub async fn from_mock_dir(mock_dir: MockDataDir) -> Result<Self, AdminServerBuildError> {
let context = AppContext::read_from(mock_dir)
.await
.map_err(AdminServerBuildError::DataDir)?;
Self::start(&context).await
}
pub async fn start(context: &AppContext) -> Result<Self, AdminServerBuildError> {
let password = context.config_toml.admin.admin_password.clone();
let state = AppState::new(
context.sql_db.clone(),
context.file_service.clone(),
&password,
context.user_service.clone(),
)
.with_metadata_from_config(
context.keypair.public_key().z32(),
&context.config_toml,
env!("CARGO_PKG_VERSION"),
);
let socket = context.config_toml.admin.listen_socket;
let app = create_app(state, password.as_str());
let listener = std::net::TcpListener::bind(socket)
.map_err(|e| AdminServerBuildError::Server(e.into()))?;
listener
.set_nonblocking(true)
.map_err(|e| AdminServerBuildError::Server(e.into()))?;
let socket = listener
.local_addr()
.map_err(|e| AdminServerBuildError::Server(e.into()))?;
let http_handle = Handle::new();
let inner_http_handle = http_handle.clone();
let server =
axum_server::from_tcp(listener).map_err(|e| AdminServerBuildError::Server(e.into()))?;
let join_handle = tokio::spawn(async move {
server
.handle(inner_http_handle)
.serve(app)
.await
.unwrap_or_else(|e| tracing::error!("Admin server error: {}", e));
});
Ok(Self {
http_handle,
socket,
join_handle,
password,
})
}
pub fn listen_socket(&self) -> SocketAddr {
self.socket
}
pub async fn create_signup_token(&self) -> anyhow::Result<String> {
let admin_socket = self.listen_socket();
let url = format!("http://{}/generate_signup_token", admin_socket);
let response = reqwest::Client::new()
.get(url)
.header("X-Admin-Password", &self.password)
.send()
.await?;
let response = response.error_for_status()?;
let body = response.text().await?;
Ok(body)
}
}
impl Drop for AdminServer {
    /// Tears the server down when the `AdminServer` goes out of scope.
    ///
    /// `drop` cannot `.await`, so this first signals `axum_server` to shut down gracefully
    /// with a 5-second grace window, then aborts the spawned serve task outright.
    /// NOTE(review): because the abort follows immediately, in-flight requests are likely
    /// cut off well before the grace window would have elapsed.
    fn drop(&mut self) {
        let grace_period = Duration::from_secs(5);
        self.http_handle.graceful_shutdown(Some(grace_period));
        self.join_handle.abort();
    }
}
// In-process tests for the admin router built by `create_app`. They drive it through
// `axum_test::TestServer` without binding a real socket.
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use axum::http::Method;
    use axum_test::TestServer;
    use base64::Engine;

    use crate::data_directory::quota_config::BandwidthQuota;
    use crate::persistence::files::FileService;

    use super::*;

    // Parses a bandwidth quota literal such as "200mb/m". It panics on invalid input,
    // which is acceptable in tests.
    fn bw(s: &str) -> BandwidthQuota {
        BandwidthQuota::from_str(s).unwrap()
    }

    // Builds an in-process `TestServer` around `create_app`.
    //
    // Two different passwords are in play:
    // - `"test"` goes to `create_app` and guards the admin routes; the tests send it as
    //   `X-Admin-Password`.
    // - `""` goes to `AppState::new`. The DAV tests authenticate with Basic `admin:` (empty
    //   password), so presumably the DAV handler checks the `AppState` password — TODO
    //   confirm.
    //
    // NOTE(review): a fresh `FileService` is built from the context instead of cloning
    // `context.file_service`. Presumably this is so that config tweaks made after
    // `AppContext::test()` take effect (e.g. `default_quota_mb` in the quota test) —
    // confirm.
    fn create_test_server(context: &AppContext) -> TestServer {
        TestServer::new(create_app(
            AppState::new(
                context.sql_db.clone(),
                FileService::new_from_context(context).unwrap(),
                "",
                context.user_service.clone(),
            ),
            "test",
        ))
        .unwrap()
    }

    // The public landing route `/` responds 200 without any credentials.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_root() {
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let response = server.get("/").expect_success().await;
        response.assert_status_ok();
    }

    // The signup-token endpoint answers 401 both when `X-Admin-Password` is missing and
    // when it is wrong.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_generate_signup_token_fail() {
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        // No admin header at all.
        let response = server.get("/generate_signup_token").expect_failure().await;
        response.assert_status_unauthorized();
        // Header present, but with the wrong password.
        let response = server
            .get("/generate_signup_token")
            .add_header("X-Admin-Password", "wrongpassword")
            .expect_failure()
            .await;
        response.assert_status_unauthorized();
    }

    // With the password passed to `create_app` ("test"), the endpoint answers 200.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_generate_signup_token_success() {
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let response = server
            .get("/generate_signup_token")
            .add_header("X-Admin-Password", "test")
            .expect_success()
            .await;
        response.assert_status_ok();
    }

    // HTTP Basic `Authorization` header value for user `admin` with an empty password.
    fn auth_header() -> String {
        let auth = base64::engine::general_purpose::STANDARD.encode("admin:");
        format!("Basic {auth}")
    }

    // On the DAV root, PROPFIND (Depth: 1) answers 207 Multi-Status and GET answers 200.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_dav_root_propfind_and_get() {
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let auth_value = auth_header();
        // PROPFIND is a WebDAV extension method, so it is built from raw bytes.
        let propfind = Method::from_bytes(b"PROPFIND").unwrap();
        let response = server
            .method(propfind, "/dav/")
            .add_header("Authorization", auth_value.as_str())
            .add_header("Depth", "1")
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::MULTI_STATUS);
        let response = server
            .get("/dav/")
            .add_header("Authorization", auth_value.as_str())
            .expect_success()
            .await;
        response.assert_status_ok();
    }

    // Full file lifecycle under a user's `/pub/` directory over WebDAV:
    // PUT (201) -> GET returns the same bytes -> PROPFIND lists it (207) -> DELETE (204)
    // -> GET (404).
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_dav_put_get_delete_file() {
        use crate::persistence::sql::user::UserRepository;
        use pubky_common::crypto::Keypair;
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let auth_value = auth_header();
        // Deterministic keypair; the user row must exist before files can be written.
        let keypair = Keypair::from_secret(&[0; 32]);
        let pubkey = keypair.public_key();
        UserRepository::create(&pubkey, &mut context.sql_db.pool().into())
            .await
            .unwrap();
        let file_content = b"hello webdav";
        let file_url = format!("/dav/{}/pub/test.txt", pubkey.z32());
        let response = server
            .put(&file_url)
            .add_header("Authorization", auth_value.as_str())
            .bytes(file_content.to_vec().into())
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::CREATED);
        let response = server
            .get(&file_url)
            .add_header("Authorization", auth_value.as_str())
            .expect_success()
            .await;
        response.assert_status_ok();
        assert_eq!(response.as_bytes().as_ref(), file_content);
        let propfind = Method::from_bytes(b"PROPFIND").unwrap();
        let dir_url = format!("/dav/{}/pub/", pubkey.z32());
        let response = server
            .method(propfind, &dir_url)
            .add_header("Authorization", auth_value.as_str())
            .add_header("Depth", "1")
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::MULTI_STATUS);
        let body = response.text();
        assert!(body.contains("test.txt"), "PROPFIND should list the file");
        let response = server
            .delete(&file_url)
            .add_header("Authorization", auth_value.as_str())
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::NO_CONTENT);
        // After deletion the file must be gone.
        let response = server
            .get(&file_url)
            .add_header("Authorization", auth_value.as_str())
            .expect_failure()
            .await;
        response.assert_status(axum::http::StatusCode::NOT_FOUND);
    }

    // With the default quota set to 1 MB, a first 600 000-byte upload fits. A second one
    // would exceed the quota and is currently rejected with 500 Internal Server Error.
    // NOTE(review): this pins the current status code; a more specific status (e.g. 507
    // Insufficient Storage) may be preferable — update this test if that changes.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_dav_put_quota_overflow_returns_500() {
        use crate::persistence::sql::user::UserRepository;
        use pubky_common::crypto::Keypair;
        let mut context = AppContext::test().await;
        // Must be set before `create_test_server`, which builds the file service from the
        // context.
        context.config_toml.storage.default_quota_mb = Some(1);
        let server = create_test_server(&context);
        let auth_value = auth_header();
        let keypair = Keypair::from_secret(&[0; 32]);
        let pubkey = keypair.public_key();
        UserRepository::create(&pubkey, &mut context.sql_db.pool().into())
            .await
            .unwrap();
        let pubkey = keypair.public_key().z32();
        let file1_url = format!("/dav/{pubkey}/pub/one.bin");
        let file2_url = format!("/dav/{pubkey}/pub/two.bin");
        let file_content = vec![0u8; 600_000];
        let response = server
            .put(&file1_url)
            .add_header("Authorization", auth_value.as_str())
            .bytes(file_content.clone().into())
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::CREATED);
        // 2 x 600 000 bytes exceeds 1 MB, so this upload must fail.
        let response = server
            .put(&file2_url)
            .add_header("Authorization", auth_value.as_str())
            .bytes(file_content.into())
            .expect_failure()
            .await;
        response.assert_status(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    // POSTing limits to `/generate_signup_token` stores them on the created signup code.
    // Fields given in the JSON become `QuotaOverride::Value`; omitted ones (here
    // `rate_write`) stay `QuotaOverride::Default`.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_generate_signup_token_with_limits() {
        use crate::persistence::sql::signup_code::{SignupCodeId, SignupCodeRepository};
        use crate::shared::user_quota::QuotaOverride;
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let body = serde_json::json!({
            "storage_quota_mb": 1024,
            "rate_read": "200mb/m"
        });
        let response = server
            .post("/generate_signup_token")
            .add_header("X-Admin-Password", "test")
            .content_type("application/json")
            .bytes(serde_json::to_vec(&body).unwrap().into())
            .expect_success()
            .await;
        response.assert_status_ok();
        // The response body is the token itself; look it up directly in the database.
        let token_str = response.text();
        let code_id = SignupCodeId::new(token_str).unwrap();
        let code = SignupCodeRepository::get(&code_id, &mut context.sql_db.pool().into())
            .await
            .unwrap();
        let limits = code.quota();
        assert_eq!(limits.storage_quota_mb, QuotaOverride::Value(1024));
        assert_eq!(limits.rate_read, QuotaOverride::Value(bw("200mb/m")));
        assert_eq!(limits.rate_write, QuotaOverride::Default);
    }
}