// venice-e2ee-proxy 0.1.1
//
// OpenAI-compatible proxy for Venice.ai E2EE models.
// Documentation
use std::time::Duration;

use axum::{
    Json, Router,
    body::Body,
    http::{Request, StatusCode, header::AUTHORIZATION},
    response::{IntoResponse, Response},
    routing::get,
};
use serde::de::DeserializeOwned;
use serde_json::json;
use tokio::net::TcpListener;
use tower::ServiceExt;
use venice_e2ee_proxy::{
    config::ProxyConfig,
    http::{
        self, HEADER_PROXY_ATTESTATION_MODE, HEADER_PROXY_E2EE, HEADER_PROXY_ERROR_CODE,
        HEADER_PROXY_TOOL_MODE,
    },
    openai::{ErrorResponse, ModelListResponse},
    venice::VeniceClient,
};

/// API key handed to the `VeniceClient` built by `proxy_app`; the
/// `successful_models` mock handler only accepts the matching
/// `Bearer test-api-key` authorization header.
const TEST_API_KEY: &str = "test-api-key";

/// Happy path: with the mock upstream serving `successful_models`, the proxy's
/// model listing must contain exactly one entry, carry the proxy metadata
/// headers, and expose the upstream capability flags on that entry.
#[tokio::test]
async fn get_models_returns_filtered_openai_model_list() {
    let upstream = Router::new().route("/api/v1/models", get(successful_models));
    let app = proxy_app(spawn_mock_venice(upstream).await, Duration::from_secs(1));

    let response = request_models(app).await;
    assert_eq!(response.status(), StatusCode::OK);

    // Headers are checked first because reading the body consumes the response.
    let headers = response.headers();
    assert_eq!(headers.get(HEADER_PROXY_ATTESTATION_MODE).unwrap(), "independent");
    assert_eq!(headers.get(HEADER_PROXY_TOOL_MODE).unwrap(), "emulated");
    assert!(headers.get(HEADER_PROXY_E2EE).is_none());

    let listing: ModelListResponse = response_json(response).await;
    assert_eq!(listing.object, "list");
    // The mock upstream returns four models; only one is expected to survive.
    assert_eq!(listing.data.len(), 1);

    let entry = &listing.data[0];
    assert_eq!(entry.id, "e2ee-qwen3-5-122b-a10b");
    assert_eq!(entry.object, "model");
    assert_eq!(entry.created, 1727966436);
    assert_eq!(entry.owned_by, "venice.ai");
    assert_eq!(entry.name, "e2ee-qwen3-5-122b-a10b");

    let capabilities = &entry.info.meta.capabilities;
    assert!(capabilities.function_calling);
    assert!(capabilities.builtin_tools);
    assert!(capabilities.web_search);
    assert!(capabilities.code_interpreter);
    assert!(!capabilities.vision);
    assert!(capabilities.reasoning);
    assert!(capabilities.reasoning_effort);

    let venice = &entry.venice;
    assert_eq!(venice.id, "e2ee-qwen3-5-122b-a10b");
    assert!(venice.supports_e2ee);
    assert!(venice.supports_tee_attestation);
    assert!(venice.supports_reasoning);
    assert!(venice.supports_reasoning_effort);
}

/// An upstream 401 or 403 must surface as a proxy error with type
/// `proxy_upstream_authentication_error` and code
/// `upstream_authentication_failed`.
#[tokio::test]
async fn get_models_fails_closed_on_upstream_authentication_errors() {
    let rejections = [StatusCode::UNAUTHORIZED, StatusCode::FORBIDDEN];

    for rejection in rejections {
        // A fresh mock upstream per status, answering with nothing but that status.
        let upstream =
            Router::new().route("/api/v1/models", get(move || async move { rejection }));
        let base_url = spawn_mock_venice(upstream).await;
        let app = proxy_app(base_url, Duration::from_secs(1));

        assert_proxy_error(
            request_models(app).await,
            "proxy_upstream_authentication_error",
            "upstream_authentication_failed",
        )
        .await;
    }
}

#[tokio::test]
async fn get_models_fails_closed_on_upstream_server_error() {
    let base_url = spawn_mock_venice(Router::new().route(
        "/api/v1/models",
        get(|| async { StatusCode::INTERNAL_SERVER_ERROR }),
    ))
    .await;
    let app = proxy_app(base_url, Duration::from_secs(1));

    let response = request_models(app).await;

    assert_proxy_error(response, "proxy_upstream_error", "upstream_status_error").await;
}

#[tokio::test]
async fn get_models_fails_closed_on_malformed_upstream_payload() {
    let base_url = spawn_mock_venice(Router::new().route(
        "/api/v1/models",
        get(|| async {
            Json(json!({
                "data": [
                    {
                        "id": "missing-required-attestation-flag",
                        "type": "text",
                        "model_spec": {
                            "capabilities": {
                                "supportsE2EE": true
                            }
                        }
                    }
                ]
            }))
        }),
    ))
    .await;
    let app = proxy_app(base_url, Duration::from_secs(1));

    let response = request_models(app).await;

    assert_proxy_error(
        response,
        "proxy_upstream_error",
        "upstream_malformed_response",
    )
    .await;
}

/// When the upstream takes longer than the client timeout, the proxy must
/// answer with a `proxy_upstream_error` / `upstream_timeout`.
#[tokio::test]
async fn get_models_fails_closed_on_upstream_timeout() {
    let upstream = Router::new().route("/api/v1/models", get(slow_models));
    let base_url = spawn_mock_venice(upstream).await;

    // 20 ms is well below the 200 ms delay built into `slow_models`.
    let client_timeout = Duration::from_millis(20);
    let app = proxy_app(base_url, client_timeout);

    assert_proxy_error(
        request_models(app).await,
        "proxy_upstream_error",
        "upstream_timeout",
    )
    .await;
}

/// Mock upstream handler for the models endpoint.
///
/// Responds with `401 Unauthorized` unless the request carries exactly the
/// `Bearer test-api-key` authorization header (the value of `TEST_API_KEY`
/// with a `Bearer ` prefix). Otherwise returns a four-model listing; the
/// happy-path test asserts that only the first entry survives in the proxy's
/// response.
async fn successful_models(headers: axum::http::HeaderMap) -> Response {
    // A missing header and a non-UTF-8 header both count as unauthorized.
    let authorized = matches!(
        headers.get(AUTHORIZATION).map(|value| value.to_str()),
        Some(Ok("Bearer test-api-key"))
    );
    if !authorized {
        return StatusCode::UNAUTHORIZED.into_response();
    }

    // Text model with every capability flag spelled out.
    let full_e2ee_text = json!({
        "id": "e2ee-qwen3-5-122b-a10b",
        "created": 1727966436,
        "owned_by": "venice.ai",
        "type": "text",
        "model_spec": {
            "capabilities": {
                "supportsE2EE": true,
                "supportsTeeAttestation": true,
                "supportsFunctionCalling": true,
                "supportsBuiltinTools": true,
                "supportsWebSearch": true,
                "supportsCodeInterpreter": true,
                "supportsVision": false,
                "supportsReasoning": true,
                "supportsReasoningEffort": true
            }
        }
    });

    // Text model without E2EE support.
    let text_without_e2ee = json!({
        "id": "non-e2ee-text",
        "type": "text",
        "model_spec": {
            "capabilities": {
                "supportsE2EE": false,
                "supportsTeeAttestation": true
            }
        }
    });

    // E2EE text model without TEE attestation support.
    let e2ee_without_attestation = json!({
        "id": "e2ee-without-attestation",
        "type": "text",
        "model_spec": {
            "capabilities": {
                "supportsE2EE": true,
                "supportsTeeAttestation": false
            }
        }
    });

    // E2EE + attestation, but an image model rather than a text model.
    let e2ee_image = json!({
        "id": "e2ee-image",
        "type": "image",
        "model_spec": {
            "capabilities": {
                "supportsE2EE": true,
                "supportsTeeAttestation": true
            }
        }
    });

    let listing = json!({
        "object": "list",
        "data": [
            full_e2ee_text,
            text_without_e2ee,
            e2ee_without_attestation,
            e2ee_image
        ]
    });
    Json(listing).into_response()
}

/// Mock upstream handler that waits 200 ms before answering with an empty
/// `data` array; used to trigger the client timeout.
async fn slow_models() -> impl IntoResponse {
    let delay = Duration::from_millis(200);
    tokio::time::sleep(delay).await;

    let empty_listing = json!({ "data": [] });
    Json(empty_listing)
}

/// Builds the proxy router under test, backed by a `VeniceClient` that targets
/// `base_url` with `TEST_API_KEY` and the given request `timeout`.
///
/// Panics if the client cannot be constructed.
fn proxy_app(base_url: String, timeout: Duration) -> Router {
    let venice = VeniceClient::new(base_url, TEST_API_KEY, timeout)
        .expect("test Venice client should build");
    let config = ProxyConfig::default();

    http::router_with_venice_client(config, venice)
}

/// Serves `app` from a background task on an OS-assigned loopback port and
/// returns the matching base URL in the form `http://<addr>/api/v1`.
///
/// The spawned task's join handle is dropped, so the server is not awaited
/// or shut down by this helper.
async fn spawn_mock_venice(app: Router) -> String {
    // Port 0 lets the OS pick a free port.
    let listener = TcpListener::bind(("127.0.0.1", 0))
        .await
        .expect("mock Venice listener should bind");

    // Capture the URL now: the listener is moved into the server task below.
    let base_url = listener
        .local_addr()
        .map(|addr| format!("http://{addr}/api/v1"))
        .expect("mock Venice listener should have local address");

    let server = async move {
        axum::serve(listener, app)
            .await
            .expect("mock Venice server should run");
    };
    tokio::spawn(server);

    base_url
}

/// Sends a single body-less request for `/v1/models` (builder-default method,
/// no extra headers) through `app` and returns the proxy's response.
async fn request_models(app: Router) -> Response {
    let request = Request::builder()
        .uri("/v1/models")
        .body(Body::empty())
        .expect("request should build");

    app.oneshot(request)
        .await
        .expect("request should complete")
}

/// Asserts that `response` is a well-formed proxy error.
///
/// Checks, in order:
/// 1. the status is `502 Bad Gateway`;
/// 2. the `HEADER_PROXY_ERROR_CODE` header is present and equals `expected_code`;
/// 3. the JSON body deserializes as `ErrorResponse` with `error.kind` equal to
///    `expected_type` and `error.code` equal to `expected_code`.
///
/// # Panics
///
/// Panics (failing the calling test) if any check does not hold, including
/// when the error-code header is missing altogether.
async fn assert_proxy_error(response: Response, expected_type: &str, expected_code: &str) {
    assert_eq!(response.status(), StatusCode::BAD_GATEWAY);

    // `expect` rather than `unwrap`, matching the other helpers in this file,
    // so a missing header yields a readable failure message.
    let header_code = response
        .headers()
        .get(HEADER_PROXY_ERROR_CODE)
        .expect("proxy error response should include the error code header");
    assert_eq!(header_code, expected_code);

    // The body is read last because buffering it consumes the response.
    let body: ErrorResponse = response_json(response).await;
    assert_eq!(body.error.kind, expected_type);
    assert_eq!(body.error.code, expected_code);
}

/// Buffers the whole response body (no size limit) and deserializes it as
/// JSON into `T`.
///
/// Panics if the body cannot be read or is not valid JSON for `T`.
async fn response_json<T: DeserializeOwned>(response: Response) -> T {
    let body = response.into_body();
    let raw = axum::body::to_bytes(body, usize::MAX)
        .await
        .expect("response body should buffer");

    serde_json::from_slice(&raw).expect("response should be JSON")
}