volt-server 0.2.1

Server for crates.io/crates/volt-cli
use axum::{
    Router,
    body::Body,
    extract::{Path, State},
    http::{HeaderMap, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::{get, post},
};

use tokio::{
    fs::{self, File, create_dir_all},
    io::{AsyncWriteExt, BufWriter},
    net::TcpListener,
};

use anyhow::{Context, Result};
use futures::StreamExt;
use serde::Deserialize;
use std::{net::SocketAddr, path::PathBuf, process::ExitCode, sync::Arc};
use tokio_util::io::ReaderStream;
use tracing::{error, info, warn};

/// Shared application state handed to the router and to the auth middleware.
///
/// `main` wraps it in an `Arc` so every handler and middleware sees the same
/// instance.
#[derive(Clone)]
struct AppState {
    /// Server settings deserialized from `config.toml` at startup.
    config: ServerConfig,
}

/// Server settings, deserialized from the TOML text of `config.toml`.
#[derive(Clone, Deserialize)]
struct ServerConfig {
    /// Secret that clients must present as `Authorization: Bearer <token>`.
    auth_token: String,
    /// Directory in which the `<volt_id>.zst` and `<volt_id>.hash` files are kept.
    cache_dir: PathBuf,
    /// Listen address; `main` parses it as a `SocketAddr` (e.g. `127.0.0.1:8080`).
    address: String,
}

/// Rejects any request that does not carry the configured bearer token.
///
/// # Errors
///
/// * `401 Unauthorized` when the `Authorization` header is absent, is not
///   valid visible ASCII, or does not start with `Bearer `.
/// * `403 Forbidden` when a token is presented but differs from
///   `config.auth_token`.
///
/// On success the request is forwarded to the next layer and its response is
/// returned unchanged.
async fn auth_middleware(State(state): State<Arc<AppState>>, request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    let presented_token = request
        .headers()
        .get("Authorization")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
        .ok_or_else(|| {
            warn!("Missing or malformed Authorization header");
            StatusCode::UNAUTHORIZED
        })?;

    // A plain `!=` on strings stops at the first differing byte, which lets a
    // remote caller learn the token prefix by prefix from response timing.
    if !tokens_match(presented_token.as_bytes(), state.config.auth_token.as_bytes()) {
        warn!("Invalid authentication token provided");
        return Err(StatusCode::FORBIDDEN);
    }

    Ok(next.run(request).await)
}

/// Compares two byte strings without returning early on the first mismatch.
///
/// Every byte pair is XOR-ed and OR-folded into one accumulator, so for inputs
/// of equal length the amount of work does not depend on where they differ.
/// A length mismatch is still reported immediately, which reveals only the
/// length of the expected token.
///
/// This is a best-effort mitigation: `black_box` discourages the optimizer
/// from reintroducing a short-circuit, but it is not a formal guarantee.
fn tokens_match(provided: &[u8], expected: &[u8]) -> bool {
    if provided.len() != expected.len() {
        return false;
    }

    let difference = provided.iter().zip(expected).fold(0u8, |acc, (a, b)| acc | (a ^ b));

    std::hint::black_box(difference) == 0
}

/// Emits one log record when a request arrives and another when its response
/// is ready, the latter carrying the status code and elapsed milliseconds.
///
/// The response itself is passed through untouched.
async fn logging_middleware(request: Request<Body>, next: Next) -> Response {
    // Method and URI are captured up front because `next.run` consumes the request.
    let method = request.method().to_string();
    let uri = request.uri().to_string();
    let started_at = std::time::Instant::now();

    info!(%method, %uri, "Request started");

    let response = next.run(request).await;

    let elapsed = started_at.elapsed();
    let status = response.status().as_u16();
    info!(%method, %uri, %status, duration_ms = elapsed.as_millis(), "Request completed");

    response
}

#[tokio::main]
async fn main() -> Result<ExitCode> {
    tracing_subscriber::fmt().with_max_level(tracing::Level::INFO).with_target(false).init();

    let config: ServerConfig = toml::from_str(&tokio::fs::read_to_string("config.toml").await?)?;
    let state = Arc::new(AppState { config: config.clone() });
    let addr = config.address.parse::<SocketAddr>().with_context(|| format!("Failed to parse address: {}", config.address))?;

    print_startup_message(&addr, &config);

    let app = Router::new()
        .route("/health/{volt_id}", get(health))
        .route("/push/{volt_id}", post(push))
        .route("/pull/{volt_id}", get(pull))
        .route("/check/{volt_id}", get(check_hash))
        .layer(middleware::from_fn(logging_middleware))
        .layer(middleware::from_fn_with_state(state.clone(), auth_middleware))
        .with_state(state);

    let listener = TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;

    Ok(ExitCode::SUCCESS)
}

/// Logs a boxed startup banner showing the listen address, the cache
/// directory and the authentication mode.
///
/// Rows are padded to `BOX_WIDTH` characters; longer content is not truncated
/// and will extend past the right border.
fn print_startup_message(addr: &SocketAddr, config: &ServerConfig) {
    const BOX_WIDTH: usize = 60;

    // Left-aligns `text` inside one banner row between the vertical borders.
    fn boxed_row(text: &str) -> String {
        format!("║ {:<BOX_WIDTH$} ║", text)
    }

    let listening_row = boxed_row(&format!("listening on:     {}", addr));
    let cache_row = boxed_row(&format!("cache directory:  {:?}", config.cache_dir));
    let auth_row = boxed_row("authentication:   always on");

    info!(
        r#"
╔══════════════════════════════════════════════════════════════╗
║ started volt server :3                                       ║
╠══════════════════════════════════════════════════════════════╣
║                                                              ║
{}
{}
{}
║                                                              ║
╚══════════════════════════════════════════════════════════════╝
        "#,
        listening_row,
        cache_row,
        auth_row,
    );
}

/// Health probe: responds with the `volt_id` path segment, unchanged, as the body.
async fn health(Path(volt_id): Path<String>) -> String {
    volt_id
}

/// Compares the client's `X-Volt-Hash` header with the hash stored for `volt_id`.
///
/// Responses:
/// * `304 Not Modified` — the hashes are equal (stored hash is trimmed first).
/// * `200 OK` — both hashes exist but differ.
/// * `404 Not Found` — no stored hash could be read, whether or not the
///   client sent a header.
/// * `400 Bad Request` — `volt_id` is not a UUID, or a stored hash exists but
///   the header is missing or not valid text.
async fn check_hash(Path(volt_id): Path<String>, State(state): State<Arc<AppState>>, headers: HeaderMap) -> Result<impl IntoResponse, StatusCode> {
    if let Err(e) = uuid::Uuid::parse_str(&volt_id) {
        warn!("Invalid UUID format: {}", e);
        return Err(StatusCode::BAD_REQUEST);
    }

    let client_hash = headers.get("X-Volt-Hash").and_then(|h| h.to_str().ok());
    let server_hash_path = state.config.cache_dir.join(format!("{volt_id}.hash"));
    // Any read failure (not only "file missing") is treated as "no stored hash".
    let server_hash = tokio::fs::read_to_string(&server_hash_path).await.ok();

    info!("Hash check: client={client_hash:?} server={server_hash:?}");

    // A missing server-side hash takes priority over a missing client header.
    let Some(server_hash) = server_hash else {
        return Ok(StatusCode::NOT_FOUND.into_response());
    };
    let Some(client_hash) = client_hash else {
        warn!("Missing X-Volt-Hash header");
        return Err(StatusCode::BAD_REQUEST);
    };

    let status = if client_hash == server_hash.trim() {
        StatusCode::NOT_MODIFIED
    } else {
        StatusCode::OK
    };

    Ok(status.into_response())
}

async fn push(Path(volt_id): Path<String>, State(state): State<Arc<AppState>>, headers: HeaderMap, body: Body) -> Result<(), StatusCode> {
    uuid::Uuid::parse_str(&volt_id).map_err(|e| {
        warn!("Invalid UUID format: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    create_dir_all(&state.config.cache_dir).await.map_err(|e| {
        error!("Failed to create cache directory: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let file_path = state.config.cache_dir.join(format!("{}.zst", volt_id));
    let file = File::create(&file_path).await.map_err(|e| {
        error!("Failed to create file: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let mut writer = BufWriter::new(file);
    let mut stream = body.into_data_stream();

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| {
            error!("Stream error: {}", e);
            StatusCode::BAD_REQUEST
        })?;

        writer.write_all(&chunk).await.map_err(|e| {
            error!("Write error: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    }

    writer.flush().await.map_err(|e| {
        error!("Flush error: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let hash = headers.get("X-Volt-Hash").and_then(|h| h.to_str().ok()).unwrap_or_default();
    let hash_path = state.config.cache_dir.join(format!("{}.hash", volt_id));

    fs::write(hash_path, hash).await.map_err(|e| {
        error!("Failed to write hash file: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    Ok(())
}

/// Streams the cached archive for `volt_id` back to the client.
///
/// Responses:
/// * `304 Not Modified` — the client's `X-Volt-Hash` equals the stored hash
///   (trimmed); the archive itself is not opened in this case.
/// * `200 OK` with `Content-Encoding: zstd` — the body is the contents of
///   `<cache_dir>/<volt_id>.zst`.
/// * `400 Bad Request` — `volt_id` is not a UUID.
/// * `404 Not Found` — the archive file does not exist.
/// * `500 Internal Server Error` — the archive could not be opened for any
///   other reason.
async fn pull(Path(volt_id): Path<String>, State(state): State<Arc<AppState>>, headers: HeaderMap) -> Result<impl IntoResponse, StatusCode> {
    if let Err(e) = uuid::Uuid::parse_str(&volt_id) {
        warn!("Invalid UUID format: {}", e);
        return Err(StatusCode::BAD_REQUEST);
    }

    let cache_dir = &state.config.cache_dir;

    let client_hash = headers.get("X-Volt-Hash").and_then(|h| h.to_str().ok());
    // Any failure to read the hash file is treated as "no stored hash".
    let server_hash = tokio::fs::read_to_string(cache_dir.join(format!("{}.hash", volt_id))).await.ok();

    info!("{client_hash:?} to {server_hash:?}");

    // Only skip the download when both sides have a hash and they agree.
    let up_to_date = matches!(
        (client_hash, server_hash.as_deref()),
        (Some(theirs), Some(ours)) if theirs == ours.trim()
    );
    if up_to_date {
        return Ok(StatusCode::NOT_MODIFIED.into_response());
    }

    let archive = match File::open(cache_dir.join(format!("{}.zst", volt_id))).await {
        Ok(file) => file,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            warn!("File not found: {}", volt_id);
            return Err(StatusCode::NOT_FOUND);
        }
        Err(e) => {
            error!("File open error: {}", e);
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    let mut response_headers = HeaderMap::new();
    // The literal is a valid header value, so this parse cannot fail.
    response_headers.insert("Content-Encoding", "zstd".parse().unwrap());

    Ok((response_headers, Body::from_stream(ReaderStream::new(archive))).into_response())
}