patlite-beacon-serv 0.1.0

RESTful API server for controlling PATLITE USB beacons with comprehensive light patterns, sequences, and buzzer control
use axum::{
    extract::{Query, Request, State},
    http::StatusCode,
    middleware::Next,
    response::Response,
};
use crate::{api::ApiKeyQuery, AppState};

/// Axum middleware enforcing API-key authentication via the request's query string.
///
/// Behaviour:
/// - If `state.api_key` is `None`, authentication is disabled and every request is
///   forwarded to `next` unchanged.
/// - Otherwise the key parsed into `ApiKeyQuery::api_key` must be present and match the
///   configured key exactly, or the request is rejected.
///
/// # Errors
///
/// Returns `Err(StatusCode::UNAUTHORIZED)` when a key is configured and the provided key
/// is missing or does not match.
pub async fn auth_middleware(
    State(state): State<AppState>,
    Query(query): Query<ApiKeyQuery>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let Some(expected_key) = state.api_key.as_ref() else {
        // No key configured: authentication is turned off.
        return Ok(next.run(request).await);
    };

    let authorized = match query.api_key.as_ref() {
        Some(provided_key) => keys_match(provided_key.as_bytes(), expected_key.as_bytes()),
        None => false,
    };

    if authorized {
        Ok(next.run(request).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}

/// Compares a caller-supplied key against the secret without an early exit on mismatch.
///
/// A plain `==` on strings stops at the first differing byte. That lets a remote caller
/// learn how much of a guessed key is correct by measuring response latency. This helper
/// instead folds over every byte of equal-length inputs before deciding.
///
/// The lengths are compared up front, so the *length* of the secret is not hidden. The
/// mitigation is also best-effort: the standard library provides no guaranteed
/// constant-time primitive, and `black_box` only discourages the optimizer from
/// reintroducing a short-circuit.
fn keys_match(provided: &[u8], expected: &[u8]) -> bool {
    if provided.len() != expected.len() {
        return false;
    }
    let diff = provided
        .iter()
        .zip(expected)
        .fold(0u8, |acc, (a, b)| std::hint::black_box(acc | (a ^ b)));
    diff == 0
}

#[cfg(test)]
mod tests {
    //! Tests for `auth_middleware`. Each test builds a one-route router (`/test`) behind the
    //! middleware, sends a single request, and checks the resulting HTTP status.

    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use std::sync::{Arc, RwLock};
    use tower::ServiceExt;
    use crate::beacon::BeaconController;

    /// Handler behind the middleware; reaching it yields `200 OK`.
    async fn dummy_handler() -> &'static str {
        "OK"
    }

    /// Builds a router whose state requires `api_key` (authentication disabled when `None`).
    fn app_with_key(api_key: Option<&str>) -> Router {
        let state = AppState {
            beacon: Arc::new(RwLock::new(BeaconController::mock())),
            api_key: api_key.map(str::to_owned),
        };

        Router::new()
            .route("/test", get(dummy_handler))
            .layer(middleware::from_fn_with_state(state.clone(), auth_middleware))
            .with_state(state)
    }

    /// Sends one GET request to `uri` through a fresh router and returns the response status.
    async fn status_for(api_key: Option<&str>, uri: &str) -> StatusCode {
        app_with_key(api_key)
            .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn test_auth_middleware_without_api_key() {
        assert_eq!(status_for(None, "/test").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_without_api_key_ignores_provided_key() {
        // With no key configured, a supplied key must not cause a rejection.
        assert_eq!(status_for(None, "/test?apiKey=anything").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_with_correct_api_key() {
        let api_key = "test-key-123";
        let uri = format!("/test?apiKey={api_key}");
        assert_eq!(status_for(Some(api_key), &uri).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_with_incorrect_api_key() {
        assert_eq!(
            status_for(Some("correct-key"), "/test?apiKey=wrong-key").await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn test_auth_middleware_missing_api_key() {
        assert_eq!(
            status_for(Some("required-key"), "/test").await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn test_auth_middleware_empty_api_key() {
        assert_eq!(
            status_for(Some("required-key"), "/test?apiKey=").await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_key_prefix() {
        // A correct prefix of the real key must not be accepted.
        assert_eq!(
            status_for(Some("correct-key"), "/test?apiKey=correct").await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn test_auth_middleware_key_is_case_sensitive() {
        assert_eq!(
            status_for(Some("correct-key"), "/test?apiKey=CORRECT-KEY").await,
            StatusCode::UNAUTHORIZED
        );
    }
}