akas 2.4.18

AKAS: API Key Authorization Server
use crate::{
    models::{ApiKey, KeyFormat, LoadForm, StatusResponse},
    state::AppState,
    utils::{check_admin, parse_headers},
};
use actix_multipart::form::MultipartForm;
use actix_web::{get, post, web, HttpRequest, HttpResponse, Responder, Result};
use actix_web_httpauth::extractors::bearer::BearerAuth;
use chrono::Utc;
use sha2::{Digest, Sha256};
use std::io::{BufRead, BufReader, Read};
use tracing::{error, info, warn};

/// Authenticates a user based on a bearer token and checks key length and prefix.
///
/// This function extracts the bearer token from the request, calculates its SHA256 hash,
/// and checks if the hash exists in the shared `hash_key_set`.
/// It also validates the key length and prefix against the configured values in `AppState`.
///
/// # Parameters
///
/// * `req`: The HTTP request. Used to extract the `x-original-host` and `x-original-uri` headers.
/// * `bearer_auth`: The bearer token extractor.
/// * `data`: Shared application state containing the set of valid key hashes, key length, and prefix.
///
/// # Returns
///
/// * `HttpResponse::Ok()` if the key is valid.
/// * `HttpResponse::Unauthorized()` if the key is invalid, has an incorrect length, or an incorrect prefix.
///
#[get("/auth")]
pub async fn auth(
    req: HttpRequest,
    bearer_auth: BearerAuth,
    data: web::Data<AppState>,
) -> impl Responder {
    let headers = parse_headers(req.headers(), data.original_length, data.metadata_length);

    let key: &str = bearer_auth.token();

    let api_key = match ApiKey::new(key, data.key_length, &data.key_prefix) {
        Some(k) => k,
        None => {
            warn!(
                forwarded_for = headers.forwarded_for,
                original_host = headers.original_host,
                original_uri = headers.original_uri,
                metadata = headers.metadata,
                access = "unauthorized",
                "invalid key - length or prefix mismatch",
            );
            data.auth_counter
                .with_label_values(&[
                    "/auth",
                    "GET",
                    "401",
                    &headers.original_host,
                    &headers.metadata,
                ])
                .inc();
            return HttpResponse::Unauthorized();
        }
    };

    // Hash the key using SHA256
    let sha256_key = api_key.hash();

    // Read access to hash_key_set
    let hash_key_set = data.hash_key_set.read().unwrap();
    if hash_key_set.contains(&sha256_key) {
        // Only 12 first chars
        let sha256_key_short = &sha256_key[..12];
        info!(
            key_hash = sha256_key_short,
            forwarded_for = headers.forwarded_for,
            original_host = headers.original_host,
            original_uri = headers.original_uri,
            metadata = headers.metadata,
            access = "authorized",
            "access authorized",
        );
        data.auth_counter
            .with_label_values(&[
                "/auth",
                "GET",
                "200",
                &headers.original_host,
                &headers.metadata,
            ])
            .inc();
        HttpResponse::Ok()
    } else {
        warn!(
            key_hash = sha256_key,
            forwarded_for = headers.forwarded_for,
            original_host = headers.original_host,
            original_uri = headers.original_uri,
            metadata = headers.metadata,
            access = "unauthorized",
            "access unauthorized",
        );
        data.auth_counter
            .with_label_values(&[
                "/auth",
                "GET",
                "401",
                &headers.original_host,
                &headers.metadata,
            ])
            .inc();
        HttpResponse::Unauthorized()
    }
}

/// Handles unauthorized authentication requests.
///
/// This endpoint always returns an HTTP Unauthorized (401) status code.
/// It's primarily used for testing and demonstration purposes.
///
#[get("/auth-unauthorized")]
pub async fn auth_unauthorized() -> impl Responder {
    info!("access: auth-unauthorized");
    HttpResponse::Unauthorized()
}

/// Handles health requests.
///
/// This endpoint always returns an HTTP OK (200) status code.
///
#[get("/health")]
pub async fn health() -> impl Responder {
    info!("health check");
    HttpResponse::Ok()
}

/// Handles file uploads, extracts keys, and adds them to a shared hash set.
///
/// # Parameters
///
/// * `MultipartForm(form)`: A multipart form containing an uploaded file.
/// * `bearer_auth`: An extractor that provides the bearer key from the request.
/// * `data`: Shared application state with a `HashSet` to store keys.
///
/// # Returns
///
/// * `HttpResponse::Ok()` on success.
/// * `HttpResponse::BadRequest` or `HttpResponse::InternalServerError` on error.
///
#[post("/load")]
pub async fn load(
    MultipartForm(form): MultipartForm<LoadForm>,
    bearer_auth: BearerAuth,
    data: web::Data<AppState>,
) -> Result<HttpResponse, actix_web::Error> {
    // Check Auth bearer if admin key provided
    if !check_admin(bearer_auth.token(), &data) {
        error!("Invalid admin key");
        return Err(actix_web::error::ErrorUnauthorized("Invalid admin key"));
    }

    if form.file.size == 0 {
        error!("Empty file");
        return Err(actix_web::error::ErrorBadRequest("Empty file"));
    }

    let mut key_format = KeyFormat::Sha256;
    let mut hash_input_file = None;

    // Check if json data is present
    if let Some(json) = form.json {
        if let Some(fmt) = json.format.as_deref() {
            key_format = KeyFormat::from_str(fmt).ok_or_else(|| {
                error!("Format not supported: {}", fmt);
                actix_web::error::ErrorBadRequest("Format not supported")
            })?;
        }
        // If hash_input_file is present
        if let Some(hash) = json.hash_input_file.as_deref() {
            if hash.len() == 64 {
                hash_input_file = Some(hash.to_string());
            } else {
                warn!("{} is not a valid sha-256 hash", hash);
                return Err(actix_web::error::ErrorBadRequest(
                    "Input file hash is not valid",
                ));
            }
        }
    }
    // List of keys to be loaded
    let mut content_temp_file = String::new();
    form.file
        .file
        .as_file()
        .read_to_string(&mut content_temp_file)
        .map_err(|err| {
            error!("Error on reading temp file");
            actix_web::error::ErrorInternalServerError(err)
        })?;
    let content = content_temp_file.as_bytes();

    // Calculate hash of the file
    let mut hasher = Sha256::new();
    hasher.update(content);
    let hash_bytes = hasher.finalize();
    let hash_file = hex::encode(hash_bytes);

    // If hash is present in json, compare it with the calculated one
    if let Some(hash_input) = hash_input_file {
        if hash_file != hash_input {
            error!(
                "Hash mismatch - input hash: {}, calculate hash: {}",
                hash_input, hash_file
            );
            return Err(actix_web::error::ErrorBadRequest("Hash mismatch"));
        }
    }
    // Insert keys into tempory Vec
    let mut temp_hash_keys: Vec<String> = Vec::new(); // Stores the read hash keys
    let mut count_hash: u32 = 0; // Counter for valid hash keys
    let mut count_line: u32 = 0; // Counter for total lines processed
    let reader = BufReader::new(content);

    for line in reader.lines() {
        match line {
            Ok(content) => {
                // Not empty line
                if !content.is_empty() {
                    count_line += 1;
                    let mut words = content.split_whitespace(); // Split the line by whitespace
                    if let Some(key) = words.next() {
                        match key_format {
                            KeyFormat::Plain => {
                                // Check length of key if length arg is defined
                                if let Some(api_key) =
                                    ApiKey::new(key, data.key_length, &data.key_prefix)
                                {
                                    // Ok - hash the key
                                    temp_hash_keys.push(api_key.hash()); // Add hash key to the set
                                    count_hash += 1;
                                } else {
                                    // Optionally log why the key was skipped if you need more verbosity
                                    warn!(
                                        "Skipping plain key due to invalid key length or prefix: {}",
                                        key
                                    );
                                }
                            }
                            KeyFormat::Sha256 => {
                                // Format is sha256
                                if key.len() == 64 {
                                    temp_hash_keys.push(key.to_string()); // Add hash key to the set
                                    count_hash += 1;
                                }
                            }
                        }
                    }
                }
            }
            Err(e) => {
                error!("Error when reading line {}: {}", count_line, e);
                return Err(actix_web::error::ErrorInternalServerError(
                    "Error when reading file",
                )); // Return error on line reading error
            }
        }
    }
    info!("Lines count: {} - Keys count: {}", count_line, count_hash);
    if count_hash == 0 {
        return Err(actix_web::error::ErrorBadRequest(
            "No valid keys found in the file",
        ));
    }
    // Acquire a mutable lock on the HashSet
    let mut data_hash_key_set = data.hash_key_set.write().map_err(|err| {
        error!("Error acquiring hash_key_set lock: {}", err);
        actix_web::error::ErrorInternalServerError("Error acquiring hash_key_set lock")
    })?;
    // Copy the contents of the temporary Vec to the HashSet
    data_hash_key_set.clear();
    data_hash_key_set.extend(temp_hash_keys.drain(..));

    info!(
        "File loaded, format: {}, hash: {}",
        key_format.as_str(),
        hash_file
    );

    // Update metadata of uploaded file
    if let Ok(mut data_file_hash) = data.file_hash.write() {
        info!("File metadata - hash: {}", hash_file);
        *data_file_hash = hash_file;
    } else {
        error!("Error acquiring file_hash lock");
    }
    let now = Utc::now();
    let iso_now = now.to_rfc3339();
    if let Ok(mut data_file_date) = data.file_date.write() {
        info!("File metadata - date: {}", iso_now);
        *data_file_date = iso_now;
    } else {
        error!("Error acquiring file_date lock");
    }
    if let Ok(mut data_file_key_count) = data.file_key_count.write() {
        info!("File metadata - keys count: {}", count_hash);
        *data_file_key_count = count_hash;
    } else {
        error!("Error acquiring file_key_count lock");
    }

    Ok(HttpResponse::Ok().body("File loaded"))
}

/// Handles the GET `/status` route.
///
/// Returns the application's current configuration and the metadata of the last loaded key
/// file. It **requires administrator authentication** via a Bearer token checked with
/// `check_admin`.
///
/// # Arguments
///
/// * `bearer_auth` - A [`BearerAuth`] containing the authentication token sent in the headers.
///                   Used to verify administrator privileges.
/// * `data` - A [`web::Data<AppState>`] providing access to the application's shared state,
///            including information such as log level, file hashes, and other metadata.
///
/// # Returns
///
/// A JSON [`StatusResponse`] built from `data`.
///
/// # Errors
///
/// `401 Unauthorized` if the admin key is rejected.
///
/// A poisoned metadata lock does not fail or panic the request. Each of these values is
/// written by a single assignment in `load`, so the last stored value is still reported.
///
#[get("/status")]
pub async fn status(
    bearer_auth: BearerAuth,
    data: web::Data<AppState>,
) -> Result<actix_web::web::Json<StatusResponse>, actix_web::Error> {
    // Check Auth bearer if admin key provided
    if !check_admin(bearer_auth.token(), &data) {
        error!("Invalid admin key");
        return Err(actix_web::error::ErrorUnauthorized("Invalid admin key"));
    }
    let file_hash = data
        .file_hash
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clone();
    let file_date = data
        .file_date
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clone();
    let file_key_count = *data
        .file_key_count
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let response = StatusResponse {
        log_level: data.log_level.to_string(),
        no_admin_key: data.no_admin_key,
        local: data.local,
        enable_metrics: data.enable_metrics,
        original_length: data.original_length,
        metadata_length: data.metadata_length,
        file_hash,
        file_date,
        file_key_count,
        key_length: data.key_length,
        key_prefix: data.key_prefix.clone(),
    };
    info!("access: status");
    Ok(web::Json(response))
}

#[cfg(test)]
mod actix_unit_tests {
    use super::*;
    use actix_web::{http::StatusCode, test, App};

    /// `/health` must always answer 200.
    #[actix_web::test]
    async fn test_health() {
        // Mount only the handler under test, then issue a plain GET.
        let app = test::init_service(App::new().service(health)).await;
        let request = test::TestRequest::get().uri("/health").to_request();

        let response = test::call_service(&app, request).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    /// `/auth-unauthorized` must always answer 401.
    #[actix_web::test]
    async fn test_auth_unauthorized() {
        let app = test::init_service(App::new().service(auth_unauthorized)).await;
        let request = test::TestRequest::get()
            .uri("/auth-unauthorized")
            .to_request();

        let response = test::call_service(&app, request).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
}