ares-server 0.7.5

A.R.E.S - Agentic Retrieval Enhanced Server: A production-grade agentic chatbot server with multi-provider LLM support, tool calling, RAG, and MCP integration
Documentation
use crate::db::tenants::TenantDb;
use axum::{
    extract::Request,
    http::StatusCode,
    middleware::Next,
    response::{IntoResponse, Response},
    Json,
};
use std::sync::Arc;

pub async fn api_key_auth_middleware(req: Request, next: Next) -> Response {
    let auth_header = match req.headers().get("authorization") {
        Some(h) => h,
        None => {
            return error_response(StatusCode::UNAUTHORIZED, "Missing Authorization header");
        }
    };

    let auth_str = match auth_header.to_str() {
        Ok(s) => s,
        Err(_) => {
            return error_response(StatusCode::UNAUTHORIZED, "Invalid Authorization header");
        }
    };

    let api_key = match auth_str.strip_prefix("Bearer ") {
        Some(k) => k,
        None => {
            return error_response(
                StatusCode::UNAUTHORIZED,
                "Invalid Authorization format. Expected: Bearer ares_...",
            );
        }
    };

    if !api_key.starts_with("ares_") {
        return error_response(
            StatusCode::UNAUTHORIZED,
            "Invalid API key format. Must start with ares_",
        );
    }

    let extensions = req.extensions();
    let tenant_db: Arc<TenantDb> = match extensions.get::<Arc<TenantDb>>() {
        Some(db) => db.clone(),
        None => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Tenant database not configured",
            );
        }
    };

    let tenant_ctx = match tenant_db.verify_api_key(api_key).await {
        Ok(Some(ctx)) => ctx,
        Ok(None) => {
            return error_response(StatusCode::UNAUTHORIZED, "Invalid API key");
        }
        Err(e) => {
            tracing::error!("API key verification error: {}", e);
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to verify API key",
            );
        }
    };

    let monthly_usage = match tenant_db.get_monthly_requests(&tenant_ctx.tenant_id).await {
        Ok(m) => m,
        Err(_) => {
            return error_response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to check usage");
        }
    };

    let daily_usage = match tenant_db.get_daily_requests(&tenant_ctx.tenant_id).await {
        Ok(d) => d,
        Err(_) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to check rate limit",
            );
        }
    };

    if !tenant_ctx.can_make_request(monthly_usage, daily_usage) {
        if monthly_usage >= tenant_ctx.quota.requests_per_month {
            return error_response(
                StatusCode::TOO_MANY_REQUESTS,
                "Monthly request quota exceeded",
            );
        }
        if daily_usage >= tenant_ctx.quota.requests_per_day {
            return error_response(StatusCode::TOO_MANY_REQUESTS, "Daily rate limit exceeded");
        }
    }

    let mut req = req;
    req.extensions_mut().insert(tenant_ctx);

    next.run(req).await
}

fn error_response(status: StatusCode, message: &str) -> Response {
    let body = Json(serde_json::json!({
        "error": message
    }));
    (status, body).into_response()
}

pub use crate::auth::middleware::AuthUser;

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{HeaderValue, Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn protected_handler() -> &'static str {
        "protected content"
    }

    /// Router with a single route behind `api_key_auth_middleware`.
    ///
    /// No `Arc<TenantDb>` extension is installed. A request that passes header
    /// validation therefore stops at the "Tenant database not configured" branch.
    fn app() -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(axum::middleware::from_fn(api_key_auth_middleware))
    }

    /// Sends `GET /protected` with the given `Authorization` value (if any)
    /// and returns the response status.
    async fn status_for(authorization: Option<HeaderValue>) -> StatusCode {
        let mut builder = Request::builder().uri("/protected");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();
        app().oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_middleware_no_auth_header() {
        assert_eq!(status_for(None).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_invalid_format() {
        let status = status_for(Some(HeaderValue::from_static("Basic abc123"))).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_missing_prefix() {
        let status = status_for(Some(HeaderValue::from_static("Bearer abc123"))).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_non_ascii_header() {
        // 0xFF is a legal header byte but not visible ASCII, so `to_str` fails.
        let value = HeaderValue::from_bytes(b"Bearer ares_\xff").unwrap();
        assert_eq!(status_for(Some(value)).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_missing_tenant_db() {
        // A well-formed key gets past header validation. Without an
        // `Arc<TenantDb>` extension the middleware must report a server error.
        let status = status_for(Some(HeaderValue::from_static("Bearer ares_test"))).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }
}