Skip to main content

byokey_proxy/
lib.rs

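//! HTTP proxy layer — axum router, route handlers, and error mapping.
//!
//! Exposes an OpenAI-compatible `/v1/chat/completions` endpoint, an Anthropic-style
//! `/v1/messages` endpoint, a `/v1/models` listing, an Amp CLI compatibility layer
//! under `/amp/*`, `AmpCode` provider routes under `/api/*`, and a usage report at
//! `/v0/management/usage`.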
5
6mod amp;
7mod amp_provider;
8mod chat;
9mod error;
10mod messages;
11mod models;
12pub mod usage;
13
14pub use error::ApiError;
15pub use usage::UsageStats;
16
17use arc_swap::ArcSwap;
18use axum::{
19    Json, Router,
20    extract::State,
21    routing::{any, get, post},
22};
23use byokey_auth::AuthManager;
24use byokey_config::Config;
25use std::sync::Arc;
26use tower_http::trace::TraceLayer;
27
28/// Shared application state passed to all route handlers.
29pub struct AppState {
30    /// Server configuration (providers, listen address, etc.).
31    /// Atomically swappable for hot-reloading.
32    pub config: Arc<ArcSwap<Config>>,
33    /// Token manager for OAuth-based providers.
34    pub auth: Arc<AuthManager>,
35    /// HTTP client for upstream requests.
36    pub http: rquest::Client,
37    /// In-memory usage statistics.
38    pub usage: Arc<UsageStats>,
39}
40
41impl AppState {
42    /// Creates a new shared application state wrapped in an `Arc`.
43    ///
44    /// If the config specifies a `proxy_url`, the HTTP client is built with that proxy.
45    pub fn new(config: Arc<ArcSwap<Config>>, auth: Arc<AuthManager>) -> Arc<Self> {
46        let snapshot = config.load();
47        let http = build_http_client(snapshot.proxy_url.as_deref());
48        Arc::new(Self {
49            config,
50            auth,
51            http,
52            usage: Arc::new(UsageStats::new()),
53        })
54    }
55}
56
57/// Build an HTTP client, optionally configured with a proxy URL.
58fn build_http_client(proxy_url: Option<&str>) -> rquest::Client {
59    if let Some(url) = proxy_url {
60        match rquest::Proxy::all(url) {
61            Ok(proxy) => {
62                return rquest::Client::builder()
63                    .proxy(proxy)
64                    .build()
65                    .unwrap_or_else(|_| rquest::Client::new());
66            }
67            Err(e) => {
68                tracing::warn!(url = url, error = %e, "invalid proxy_url, using direct connection");
69            }
70        }
71    }
72    rquest::Client::new()
73}
74
75/// Build the full axum router.
76///
77/// Routes:
78/// - POST /v1/chat/completions                          OpenAI-compatible
79/// - POST /v1/messages                                  Anthropic native passthrough
80/// - GET  /v1/models
81/// - GET  /amp/v1/login
82/// - ANY  /amp/v0/management/{*path}
83/// - POST /amp/v1/chat/completions
84///
85/// `AmpCode` provider routes:
86/// - POST /api/provider/anthropic/v1/messages           Anthropic native (`AmpCode`)
87/// - POST /api/provider/openai/v1/chat/completions      `OpenAI`-compatible (`AmpCode`)
88/// - POST /api/provider/openai/v1/responses             Codex Responses API (`AmpCode`)
89/// - POST /api/provider/google/v1beta/models/{action}   Gemini native (`AmpCode`)
90/// - ANY  /api/{*path}                                  `ampcode.com` management proxy
91pub fn make_router(state: Arc<AppState>) -> Router {
92    Router::new()
93        // Standard routes
94        .route("/v1/chat/completions", post(chat::chat_completions))
95        .route("/v1/messages", post(messages::anthropic_messages))
96        .route("/v1/models", get(models::list_models))
97        // Amp CLI routes
98        .route("/amp/auth/cli-login", get(amp::cli_login_redirect))
99        .route("/amp/v1/login", get(amp::login_redirect))
100        .route("/amp/v0/management/{*path}", any(amp::management_proxy))
101        .route("/amp/v1/chat/completions", post(chat::chat_completions))
102        // AmpCode provider-specific routes (must be registered before the catch-all)
103        .route(
104            "/api/provider/anthropic/v1/messages",
105            post(messages::anthropic_messages),
106        )
107        .route(
108            "/api/provider/openai/v1/chat/completions",
109            post(chat::chat_completions),
110        )
111        .route(
112            "/api/provider/openai/v1/responses",
113            post(amp_provider::codex_responses_passthrough),
114        )
115        .route(
116            "/api/provider/google/v1beta/models/{action}",
117            post(amp_provider::gemini_native_passthrough),
118        )
119        // Catch-all: forward remaining /api/* routes to ampcode.com
120        .route("/api/{*path}", any(amp_provider::amp_management_proxy))
121        // Management API
122        .route("/v0/management/usage", get(usage_handler))
123        .with_state(state)
124        .layer(TraceLayer::new_for_http())
125}
126
127async fn usage_handler(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
128    let snap = state.usage.snapshot();
129    Json(serde_json::to_value(snap).unwrap_or_default())
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::Response,
    };
    use byokey_store::InMemoryTokenStore;
    use http_body_util::BodyExt as _;
    use serde_json::{Value, json};
    use tower::ServiceExt as _;

    /// Application state backed by an in-memory token store and the default config.
    fn make_state() -> Arc<AppState> {
        let store = Arc::new(InMemoryTokenStore::new());
        let auth = AuthManager::new(store, rquest::Client::new());
        let config = ArcSwap::from_pointee(Config::default());
        AppState::new(Arc::new(config), Arc::new(auth))
    }

    /// Sends a body-less request (default method) for `uri` through a fresh router.
    async fn send_get(uri: &str) -> Response {
        let request = Request::builder().uri(uri).body(Body::empty()).unwrap();
        make_router(make_state()).oneshot(request).await.unwrap()
    }

    /// POSTs `payload` as JSON to `uri` through a fresh router.
    async fn send_json_post(uri: &str, payload: &Value) -> Response {
        let request = Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_vec(payload).unwrap()))
            .unwrap();
        make_router(make_state()).oneshot(request).await.unwrap()
    }

    /// Collects the whole response body and parses it as JSON.
    async fn body_json(resp: Response) -> Value {
        let collected = resp.into_body().collect().await.unwrap();
        serde_json::from_slice(&collected.to_bytes()).unwrap()
    }

    /// The `location` header of `resp`, if present and valid as a string.
    fn location(resp: &Response) -> Option<&str> {
        resp.headers().get("location").and_then(|v| v.to_str().ok())
    }

    #[tokio::test]
    async fn test_list_models_empty_config() {
        let resp = send_get("/v1/models").await;
        assert_eq!(resp.status(), StatusCode::OK);

        let json = body_json(resp).await;
        assert_eq!(json["object"], "list");
        assert!(json["data"].is_array());
        // All providers are enabled by default even without explicit config.
        assert!(!json["data"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_amp_login_redirect() {
        let resp = send_get("/amp/v1/login").await;

        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(location(&resp), Some("https://ampcode.com/login"));
    }

    #[tokio::test]
    async fn test_amp_cli_login_redirect() {
        let resp = send_get("/amp/auth/cli-login?authToken=abc123&callbackPort=35789").await;

        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(
            location(&resp),
            Some("https://ampcode.com/auth/cli-login?authToken=abc123&callbackPort=35789")
        );
    }

    #[tokio::test]
    async fn test_chat_unknown_model_returns_400() {
        let payload = json!({"model": "nonexistent-model-xyz", "messages": []});
        let resp = send_json_post("/v1/chat/completions", &payload).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // The error message is expected to name the offending model.
        let json = body_json(resp).await;
        let names_model = json["error"]["message"]
            .as_str()
            .is_some_and(|msg| msg.contains("nonexistent-model-xyz"));
        assert!(names_model);
    }

    #[tokio::test]
    async fn test_chat_missing_model_returns_422() {
        let payload = json!({"messages": [{"role": "user", "content": "hi"}]});
        let resp = send_json_post("/v1/chat/completions", &payload).await;

        // Missing required `model` field → axum JSON rejection → 422
        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn test_amp_chat_route_exists() {
        let payload = json!({"model": "nonexistent", "messages": []});
        let resp = send_json_post("/amp/v1/chat/completions", &payload).await;

        // Route exists (not 404), even though model is invalid
        assert_ne!(resp.status(), StatusCode::NOT_FOUND);
    }
}