Skip to main content

byokey_proxy/
lib.rs

1//! HTTP proxy layer — axum router, route handlers, and error mapping.
2//!
//! Exposes an OpenAI-compatible `/v1/chat/completions` endpoint, an Anthropic-native
//! `/v1/messages` passthrough, a `/v1/models` listing, an Amp CLI compatibility layer
//! under `/amp/*`, and `AmpCode` provider routes under `/api/*`.
5
6mod amp;
7mod amp_provider;
8mod chat;
9mod error;
10mod messages;
11mod models;
12
13pub use error::ApiError;
14
15use arc_swap::ArcSwap;
16use axum::{
17    Router,
18    routing::{any, get, post},
19};
20use byokey_auth::AuthManager;
21use byokey_config::Config;
22use std::sync::Arc;
23use tower_http::trace::TraceLayer;
24
25/// Shared application state passed to all route handlers.
26pub struct AppState {
27    /// Server configuration (providers, listen address, etc.).
28    /// Atomically swappable for hot-reloading.
29    pub config: Arc<ArcSwap<Config>>,
30    /// Token manager for OAuth-based providers.
31    pub auth: Arc<AuthManager>,
32    /// HTTP client for upstream requests.
33    pub http: rquest::Client,
34}
35
36impl AppState {
37    /// Creates a new shared application state wrapped in an `Arc`.
38    pub fn new(config: Arc<ArcSwap<Config>>, auth: Arc<AuthManager>) -> Arc<Self> {
39        Arc::new(Self {
40            config,
41            auth,
42            http: rquest::Client::new(),
43        })
44    }
45}
46
47/// Build the full axum router.
48///
49/// Routes:
50/// - POST /v1/chat/completions                          OpenAI-compatible
51/// - POST /v1/messages                                  Anthropic native passthrough
52/// - GET  /v1/models
53/// - GET  /amp/v1/login
54/// - ANY  /amp/v0/management/{*path}
55/// - POST /amp/v1/chat/completions
56///
57/// `AmpCode` provider routes:
58/// - POST /api/provider/anthropic/v1/messages           Anthropic native (`AmpCode`)
59/// - POST /api/provider/openai/v1/chat/completions      `OpenAI`-compatible (`AmpCode`)
60/// - POST /api/provider/openai/v1/responses             Codex Responses API (`AmpCode`)
61/// - POST /api/provider/google/v1beta/models/{action}   Gemini native (`AmpCode`)
62/// - ANY  /api/{*path}                                  `ampcode.com` management proxy
63pub fn make_router(state: Arc<AppState>) -> Router {
64    Router::new()
65        // Standard routes
66        .route("/v1/chat/completions", post(chat::chat_completions))
67        .route("/v1/messages", post(messages::anthropic_messages))
68        .route("/v1/models", get(models::list_models))
69        // Amp CLI routes
70        .route("/amp/auth/cli-login", get(amp::cli_login_redirect))
71        .route("/amp/v1/login", get(amp::login_redirect))
72        .route("/amp/v0/management/{*path}", any(amp::management_proxy))
73        .route("/amp/v1/chat/completions", post(chat::chat_completions))
74        // AmpCode provider-specific routes (must be registered before the catch-all)
75        .route(
76            "/api/provider/anthropic/v1/messages",
77            post(messages::anthropic_messages),
78        )
79        .route(
80            "/api/provider/openai/v1/chat/completions",
81            post(chat::chat_completions),
82        )
83        .route(
84            "/api/provider/openai/v1/responses",
85            post(amp_provider::codex_responses_passthrough),
86        )
87        .route(
88            "/api/provider/google/v1beta/models/{action}",
89            post(amp_provider::gemini_native_passthrough),
90        )
91        // Catch-all: forward remaining /api/* routes to ampcode.com
92        .route("/api/{*path}", any(amp_provider::amp_management_proxy))
93        .with_state(state)
94        .layer(TraceLayer::new_for_http())
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::Response,
    };
    use byokey_store::InMemoryTokenStore;
    use http_body_util::BodyExt as _;
    use serde_json::{Value, json};
    use tower::ServiceExt as _;

    /// Builds an `AppState` backed by an in-memory token store and the default config.
    fn make_state() -> Arc<AppState> {
        let store = Arc::new(InMemoryTokenStore::new());
        let auth = Arc::new(AuthManager::new(store, rquest::Client::new()));
        let config = Arc::new(ArcSwap::from_pointee(Config::default()));
        AppState::new(config, auth)
    }

    /// Drives a single request through a freshly built router.
    async fn send(req: Request<Body>) -> Response {
        make_router(make_state())
            .oneshot(req)
            .await
            .expect("axum Router's service error type is Infallible")
    }

    /// Sends a request with the builder's default method (GET) to `uri`.
    async fn send_get(uri: &str) -> Response {
        let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
        send(req).await
    }

    /// Sends a POST to `uri` with `body` serialized as a JSON request body.
    async fn send_json(uri: &str, body: &Value) -> Response {
        let req = Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_vec(body).unwrap()))
            .unwrap();
        send(req).await
    }

    /// Collects the full response body and parses it as JSON.
    async fn body_json(resp: Response) -> Value {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Returns the `Location` header as a string, if present and valid text.
    fn location(resp: &Response) -> Option<&str> {
        resp.headers().get("location").and_then(|v| v.to_str().ok())
    }

    #[tokio::test]
    async fn test_list_models_empty_config() {
        let resp = send_get("/v1/models").await;

        assert_eq!(resp.status(), StatusCode::OK);
        let payload = body_json(resp).await;
        assert_eq!(payload["object"], "list");
        assert!(payload["data"].is_array());
        // All providers are enabled by default even without explicit config.
        assert!(!payload["data"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_amp_login_redirect() {
        let resp = send_get("/amp/v1/login").await;

        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(location(&resp), Some("https://ampcode.com/login"));
    }

    #[tokio::test]
    async fn test_amp_cli_login_redirect() {
        let resp = send_get("/amp/auth/cli-login?authToken=abc123&callbackPort=35789").await;

        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(
            location(&resp),
            Some("https://ampcode.com/auth/cli-login?authToken=abc123&callbackPort=35789")
        );
    }

    #[tokio::test]
    async fn test_chat_unknown_model_returns_400() {
        let body = json!({"model": "nonexistent-model-xyz", "messages": []});
        let resp = send_json("/v1/chat/completions", &body).await;

        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let payload = body_json(resp).await;
        assert!(
            payload["error"]["message"]
                .as_str()
                .unwrap_or("")
                .contains("nonexistent-model-xyz")
        );
    }

    #[tokio::test]
    async fn test_chat_missing_model_returns_422() {
        let body = json!({"messages": [{"role": "user", "content": "hi"}]});
        let resp = send_json("/v1/chat/completions", &body).await;

        // Missing required `model` field → axum JSON rejection → 422
        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn test_amp_chat_route_exists() {
        let body = json!({"model": "nonexistent", "messages": []});
        let resp = send_json("/amp/v1/chat/completions", &body).await;

        // Route exists (not 404), even though model is invalid
        assert_ne!(resp.status(), StatusCode::NOT_FOUND);
        // A 405 would mean the path is registered but not for POST.
        assert_ne!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
    }
}