gephyr 1.16.18

Gephyr is a headless local AI relay/proxy that handles OpenAI-, Claude-, and Gemini-compatible APIs
Documentation
use axum::http::{header, HeaderValue, Method};
use tower_http::cors::{AllowOrigin, Any, CorsLayer};

use crate::proxy::config::{CorsConfig, CorsMode};

/// Builds the [`CorsLayer`] described by `config`.
///
/// Every mode shares the same baseline: the common HTTP methods are allowed,
/// credentials are disabled, and preflight results may be cached for one hour.
///
/// * [`CorsMode::Permissive`] allows any origin and any request header.
/// * Otherwise only the origins listed in `config.allowed_origins` are allowed,
///   together with the fixed header set (`Authorization`, `Content-Type`,
///   `x-api-key`, `x-goog-api-key`).
///
/// Allow-list entries are trimmed before use. Blank entries are skipped, and
/// entries that are the `*` wildcard or are not valid header values are skipped
/// with a warning. If no usable origin remains, no origin is allowed at all.
pub fn cors_layer(config: &CorsConfig) -> CorsLayer {
    let base = CorsLayer::new()
        .allow_methods([
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::OPTIONS,
            Method::PATCH,
        ])
        .allow_headers([
            header::AUTHORIZATION,
            header::CONTENT_TYPE,
            header::HeaderName::from_static("x-api-key"),
            header::HeaderName::from_static("x-goog-api-key"),
        ])
        .allow_credentials(false)
        .max_age(std::time::Duration::from_secs(3600));

    if matches!(config.mode, CorsMode::Permissive) {
        // Permissive mode replaces the fixed header list with a wildcard as well.
        return base.allow_origin(Any).allow_headers(Any);
    }

    let allowed_origins: Vec<HeaderValue> = config
        .allowed_origins
        .iter()
        .filter_map(|origin| {
            let trimmed = origin.trim();

            // A blank entry (e.g. from a trailing separator in the config) is a
            // valid header value but can never name a real origin.
            if trimmed.is_empty() {
                return None;
            }

            // `AllowOrigin::list` panics on a wildcard entry, so a stray `*` in the
            // allow-list must not reach it. It is ignored rather than honoured so a
            // restricted configuration never silently widens to "any origin".
            if trimmed == "*" {
                tracing::warn!(
                    "Ignoring wildcard CORS origin {:?}: use permissive mode to allow any origin",
                    origin
                );
                return None;
            }

            match HeaderValue::from_str(trimmed) {
                Ok(value) => Some(value),
                Err(e) => {
                    tracing::warn!("Ignoring invalid CORS origin {:?}: {}", origin, e);
                    None
                }
            }
        })
        .collect();

    if allowed_origins.is_empty() {
        // Without an explicit allow-origin setting the layer allows no origin.
        base
    } else {
        base.allow_origin(AllowOrigin::list(allowed_origins))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Response header that carries the origin the CORS layer chose to allow.
    const ALLOW_ORIGIN: &str = "access-control-allow-origin";

    /// Builds a single-route router (`GET /ping`) wrapped in the CORS layer for `config`.
    fn cors_test_router(config: CorsConfig) -> Router {
        let cors = cors_layer(&config);
        Router::new().route("/ping", get(|| async { "ok" })).layer(cors)
    }

    /// Sends a CORS preflight (`OPTIONS /ping`, requesting `GET`) from `origin`
    /// and returns the response status plus the allow-origin header, if any.
    async fn preflight(config: CorsConfig, origin: &str) -> (StatusCode, Option<HeaderValue>) {
        let request = Request::builder()
            .method(Method::OPTIONS)
            .uri("/ping")
            .header("origin", origin)
            .header("access-control-request-method", "GET")
            .body(Body::empty())
            .expect("request");

        let response = cors_test_router(config)
            .oneshot(request)
            .await
            .expect("response");

        let allowed = response.headers().get(ALLOW_ORIGIN).cloned();
        (response.status(), allowed)
    }

    #[test]
    fn test_cors_layer_creation() {
        // Smoke test: building the layer from the default config must not panic.
        let default_config = CorsConfig::default();
        let _layer = cors_layer(&default_config);
    }

    #[tokio::test]
    async fn strict_mode_allows_configured_origin() {
        let config = CorsConfig {
            mode: CorsMode::Strict,
            allowed_origins: vec!["http://localhost:3000".to_string()],
        };

        let (status, allowed) = preflight(config, "http://localhost:3000").await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(
            allowed,
            Some(HeaderValue::from_static("http://localhost:3000"))
        );
    }

    #[tokio::test]
    async fn strict_mode_blocks_unlisted_origin() {
        let config = CorsConfig {
            mode: CorsMode::Strict,
            allowed_origins: vec!["http://localhost:3000".to_string()],
        };

        let (status, allowed) = preflight(config, "http://evil.example").await;

        // The preflight itself succeeds; the origin is simply not echoed back.
        assert_eq!(status, StatusCode::OK);
        assert!(allowed.is_none());
    }

    #[tokio::test]
    async fn strict_mode_with_empty_allowlist_blocks_all_cross_origin_requests() {
        let config = CorsConfig {
            mode: CorsMode::Strict,
            allowed_origins: vec![],
        };

        let (status, allowed) = preflight(config, "http://any.example").await;

        assert_eq!(status, StatusCode::OK);
        assert!(allowed.is_none());
    }

    #[tokio::test]
    async fn permissive_mode_allows_any_origin() {
        let config = CorsConfig {
            mode: CorsMode::Permissive,
            allowed_origins: vec![],
        };

        let (status, allowed) = preflight(config, "http://any.example").await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(allowed, Some(HeaderValue::from_static("*")));
    }
}