coproxy 0.6.1

OpenAI-compatible API proxy backed by GitHub Copilot
Documentation
use crate::openai::error::ApiError;
use crate::server::routes::auth;
use crate::state::AppState;
use axum::Json;
use axum::body::Body;
use axum::extract::{Path, RawQuery, State};
use axum::http::{HeaderMap, StatusCode, header};
use axum::response::Response;

/// Forwards a create-response request to the configured provider.
///
/// Authorization against the configured API key happens first. The payload's
/// `stream` and `model` fields are inspected only for tracing; the JSON body
/// itself is handed to the provider untouched, together with the configured
/// default model. Whatever the provider returns is streamed back to the
/// client through [`proxy_upstream_response`].
///
/// # Errors
///
/// Returns an [`ApiError`] when authorization fails, or when the provider
/// call fails (converted with [`ApiError::from_provider_error`]).
pub async fn create_response(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(payload): Json<serde_json::Value>,
) -> Result<Response, ApiError> {
    auth::authorize(&headers, state.api_key.as_deref())?;

    // Diagnostic-only views into the payload. They borrow `payload`, so every
    // use of them must happen before the payload is moved to the provider.
    let stream_flag = matches!(payload.get("stream"), Some(serde_json::Value::Bool(true)));
    let model = payload.get("model").and_then(serde_json::Value::as_str);

    // NOTE(review): `Span::record` only takes effect for fields the current
    // span declared up front (presumably via `#[instrument]` on the route —
    // not visible here).
    let span = tracing::Span::current();
    span.record("stream", stream_flag);
    if let Some(name) = model {
        span.record("model", name);
    }
    tracing::debug!(model = ?model, stream = stream_flag, "responses dispatch");

    let default_model = state.default_model.as_deref();
    let upstream = state
        .provider
        .create_response(payload, default_model)
        .await
        .map_err(ApiError::from_provider_error)?;

    tracing::debug!(status = %upstream.status(), "responses upstream returned");
    Ok(proxy_upstream_response(upstream))
}

/// Fetches a previously created response, identified by the `response_id`
/// path segment, from the configured provider.
///
/// The raw query string (if present) is passed to the provider verbatim, and
/// the provider's reply is streamed back through [`proxy_upstream_response`].
///
/// # Errors
///
/// Returns an [`ApiError`] when authorization fails, or when the provider
/// call fails (converted with [`ApiError::from_provider_error`]).
pub async fn get_response(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(response_id): Path<String>,
    RawQuery(raw_query): RawQuery,
) -> Result<Response, ApiError> {
    auth::authorize(&headers, state.api_key.as_deref())?;
    tracing::debug!(response_id = %response_id, "get response dispatch");

    let query = raw_query.as_deref();
    let outcome = state.provider.get_response(&response_id, query).await;
    let upstream = outcome.map_err(ApiError::from_provider_error)?;

    tracing::debug!(status = %upstream.status(), "get response upstream returned");
    Ok(proxy_upstream_response(upstream))
}

/// Converts an upstream `reqwest` response into an axum [`Response`] without
/// buffering the body.
///
/// - The status code is carried over. If it cannot be represented, the
///   response falls back to `502 Bad Gateway`.
/// - The body is forwarded as a stream, chunk by chunk.
/// - End-to-end headers are forwarded with all of their values. Multi-valued
///   headers such as repeated `Set-Cookie` lines are kept intact.
///
/// Hop-by-hop headers are dropped, since they describe the upstream
/// connection rather than the message (RFC 9110 §7.6.1). That covers the
/// standard set and any header named in the upstream `Connection` value.
/// `Content-Length` is dropped as well, because the body is re-streamed and
/// the framing is left to this server.
fn proxy_upstream_response(mut upstream: reqwest::Response) -> Response {
    // Lower-case header names that must never be forwarded by this proxy.
    const HOP_BY_HOP: [&str; 10] = [
        "connection",
        "content-length",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];

    let status =
        StatusCode::from_u16(upstream.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);

    // Take the header map by value instead of cloning it. The body stream
    // below does not depend on the headers.
    let mut headers = std::mem::take(upstream.headers_mut());

    // Tokens in `Connection` name additional hop-by-hop headers. Collect them
    // before `connection` itself is removed.
    let connection_listed: Vec<String> = headers
        .get_all(header::CONNECTION)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(|token| token.trim().to_ascii_lowercase())
        .filter(|token| !token.is_empty())
        .collect();

    // `HeaderMap::remove` drops every value stored under the name. Unknown or
    // syntactically invalid names are ignored.
    for name in HOP_BY_HOP {
        headers.remove(name);
    }
    for name in &connection_listed {
        headers.remove(name.as_str());
    }

    let mut response = Response::new(Body::from_stream(upstream.bytes_stream()));
    *response.status_mut() = status;
    // Move the filtered map in wholesale. Unlike per-entry `insert`, this keeps
    // every value of multi-valued headers.
    *response.headers_mut() = headers;
    response
}