Skip to main content

shift_proxy/routes/
passthrough.rs

1//! Catch-all passthrough handler.
2//!
3//! Forwards requests to the upstream provider detected from the request
4//! path. Used for routes not explicitly matched by the provider-specific
5//! handlers (e.g., OpenAI batch endpoints, Anthropic beta paths, GET
6//! /v1/models, etc.).
7
8use crate::body::extract_body;
9use crate::forward::forward_request;
10use crate::ProxyState;
11use axum::body::Bytes;
12use axum::extract::State;
13use axum::http::{HeaderMap, Method, StatusCode, Uri};
14use axum::response::{IntoResponse, Response};
15
16/// Catch-all handler — detect provider from path and forward unchanged.
17/// Handles all HTTP methods (GET, POST, PUT, PATCH, DELETE).
18pub async fn passthrough_handler(
19    State(state): State<ProxyState>,
20    method: Method,
21    uri: Uri,
22    headers: HeaderMap,
23    body: Bytes,
24) -> Response {
25    let has_body = !matches!(method, Method::GET | Method::HEAD);
26
27    // Only decompress for methods that carry a body — avoids pointless
28    // work on empty GET/HEAD payloads.
29    let body = if has_body {
30        match extract_body(&headers, body) {
31            Ok(s) => s,
32            Err(e) => {
33                return (
34                    StatusCode::BAD_REQUEST,
35                    axum::Json(serde_json::json!({"error": e})),
36                )
37                    .into_response();
38            }
39        }
40    } else {
41        String::new()
42    };
43
44    let path = uri.path();
45    let provider = detect_provider_from_route(path);
46
47    let base_url = match provider {
48        Some("anthropic") => &state.config.providers.anthropic,
49        Some("openai") => &state.config.providers.openai,
50        Some("google") => &state.config.providers.google,
51        _ => {
52            return (
53                StatusCode::NOT_FOUND,
54                axum::Json(serde_json::json!({
55                    "error": "Unknown route — cannot determine upstream provider"
56                })),
57            )
58                .into_response();
59        }
60    };
61
62    let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
63    let target_url = format!("{}{}{}", base_url, path, query);
64
65    if state.config.verbose {
66        tracing::info!("Passthrough: {} {} → {}{}", method, path, base_url, path);
67    }
68
69    let body = if has_body { Some(body) } else { None };
70
71    forward_request(
72        &state.http_client,
73        method.as_str(),
74        &target_url,
75        &headers,
76        body,
77    )
78    .await
79}
80
/// Detect which upstream provider a request path belongs to.
///
/// Returns `Some("anthropic")`, `Some("openai")` or `Some("google")`, or
/// `None` when the path matches no known provider. Checks run in order, so
/// provider-specific routes take precedence over the generic `/v1/`
/// fallback, which is attributed to OpenAI.
///
/// Provider prefixes are matched on whole path segments: `/v1/messages`
/// matches `/v1/messages` and `/v1/messages/batches`, but not
/// `/v1/messagesfoo` (which falls through to the `/v1/` fallback). The bare
/// `/messages` and `/responses` forms are only recognised as exact paths.
fn detect_provider_from_route(path: &str) -> Option<&'static str> {
    if has_segment_prefix(path, "/v1/messages") || path == "/messages" {
        Some("anthropic")
    } else if path.starts_with("/v1/chat/")
        || has_segment_prefix(path, "/v1/embeddings")
        || has_segment_prefix(path, "/v1/responses")
        || path == "/responses"
    {
        Some("openai")
    } else if path.starts_with("/v1beta/") || path.starts_with("/v1/models/gemini") {
        // `/v1/models/gemini` is deliberately a partial-segment match so it
        // covers model ids such as `gemini-2.5-pro:generateContent`.
        Some("google")
    } else if path.starts_with("/v1/") {
        // Default to OpenAI for any other /v1/* path (most common).
        Some("openai")
    } else {
        None
    }
}

/// Returns `true` when `path` equals `prefix` or continues it with `/`,
/// i.e. `prefix` only matches on a whole-segment boundary.
fn has_segment_prefix(path: &str, prefix: &str) -> bool {
    matches!(path.strip_prefix(prefix), Some(rest) if rest.is_empty() || rest.starts_with('/'))
}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detect_anthropic() {
        assert_eq!(
            detect_provider_from_route("/v1/messages"),
            Some("anthropic")
        );
        assert_eq!(
            detect_provider_from_route("/v1/messages/batches"),
            Some("anthropic")
        );
        assert_eq!(
            detect_provider_from_route("/v1/messages/count_tokens"),
            Some("anthropic")
        );
        // Bare /messages (without /v1 prefix) should also route to Anthropic
        assert_eq!(detect_provider_from_route("/messages"), Some("anthropic"));
    }

    #[test]
    fn detect_openai() {
        assert_eq!(
            detect_provider_from_route("/v1/chat/completions"),
            Some("openai")
        );
        assert_eq!(detect_provider_from_route("/v1/embeddings"), Some("openai"));
        assert_eq!(detect_provider_from_route("/v1/responses"), Some("openai"));
        assert_eq!(
            detect_provider_from_route("/v1/responses/resp_123"),
            Some("openai")
        );
        // Bare /responses (without /v1 prefix) should also route to OpenAI
        assert_eq!(detect_provider_from_route("/responses"), Some("openai"));
    }

    #[test]
    fn detect_openai_fallback_for_other_v1_paths() {
        // Any /v1/* path without a more specific match defaults to OpenAI,
        // including the model listing and batch endpoints.
        assert_eq!(detect_provider_from_route("/v1/models"), Some("openai"));
        assert_eq!(detect_provider_from_route("/v1/batches"), Some("openai"));
        assert_eq!(
            detect_provider_from_route("/v1/files/file-abc"),
            Some("openai")
        );
    }

    #[test]
    fn detect_google() {
        assert_eq!(
            detect_provider_from_route("/v1beta/models/gemini-2.5-pro:generateContent"),
            Some("google")
        );
        assert_eq!(
            detect_provider_from_route("/v1/models/gemini-1.5-flash:generateContent"),
            Some("google")
        );
        // Google's OpenAI-compatible surface lives under /v1beta/ and must
        // not be mistaken for OpenAI.
        assert_eq!(
            detect_provider_from_route("/v1beta/openai/chat/completions"),
            Some("google")
        );
    }

    #[test]
    fn detect_unknown() {
        assert_eq!(detect_provider_from_route("/unknown"), None);
        assert_eq!(detect_provider_from_route("/"), None);
        assert_eq!(detect_provider_from_route(""), None);
        // Only the exact bare paths are recognised without the /v1 prefix.
        assert_eq!(detect_provider_from_route("/messages/batches"), None);
    }
}