Skip to main content

byokey_proxy/
lib.rs

//! HTTP proxy layer — axum router, route handlers, and error mapping.
//!
//! Exposes an OpenAI-compatible `/v1/chat/completions` endpoint, an Anthropic-native
//! `/v1/messages` passthrough, a `/v1/models` listing, an Amp CLI compatibility layer
//! under `/amp/*`, and `AmpCode` provider routes under `/api/*`.
5
6mod amp;
7mod amp_provider;
8mod chat;
9mod error;
10mod messages;
11mod models;
12
13pub use error::ApiError;
14
15use axum::{
16    Router,
17    extract::Request,
18    middleware::{self, Next},
19    response::Response,
20    routing::{any, get, post},
21};
22use byokey_auth::AuthManager;
23use byokey_config::Config;
24use std::sync::Arc;
25
26/// Shared application state passed to all route handlers.
27pub struct AppState {
28    /// Server configuration (providers, listen address, etc.).
29    pub config: Arc<Config>,
30    /// Token manager for OAuth-based providers.
31    pub auth: Arc<AuthManager>,
32    /// HTTP client for upstream requests.
33    pub http: rquest::Client,
34}
35
36impl AppState {
37    /// Creates a new shared application state wrapped in an `Arc`.
38    pub fn new(config: Config, auth: Arc<AuthManager>) -> Arc<Self> {
39        Arc::new(Self {
40            config: Arc::new(config),
41            auth,
42            http: rquest::Client::new(),
43        })
44    }
45}
46
47async fn log_request(req: Request, next: Next) -> Response {
48    let method = req.method().clone();
49    let uri = req.uri().clone();
50    let resp = next.run(req).await;
51    eprintln!("{} {} -> {}", method, uri, resp.status());
52    resp
53}
54
55/// Build the full axum router.
56///
57/// Routes:
58/// - POST /v1/chat/completions                          OpenAI-compatible
59/// - POST /v1/messages                                  Anthropic native passthrough
60/// - GET  /v1/models
61/// - GET  /amp/v1/login
62/// - ANY  /amp/v0/management/{*path}
63/// - POST /amp/v1/chat/completions
64///
65/// `AmpCode` provider routes:
66/// - POST /api/provider/anthropic/v1/messages           Anthropic native (`AmpCode`)
67/// - POST /api/provider/openai/v1/chat/completions      `OpenAI`-compatible (`AmpCode`)
68/// - POST /api/provider/openai/v1/responses             Codex Responses API (`AmpCode`)
69/// - POST /api/provider/google/v1beta/models/{action}   Gemini native (`AmpCode`)
70/// - ANY  /api/{*path}                                  `ampcode.com` management proxy
71pub fn make_router(state: Arc<AppState>) -> Router {
72    Router::new()
73        // Standard routes
74        .route("/v1/chat/completions", post(chat::chat_completions))
75        .route("/v1/messages", post(messages::anthropic_messages))
76        .route("/v1/models", get(models::list_models))
77        // Amp CLI routes
78        .route("/amp/auth/cli-login", get(amp::cli_login_redirect))
79        .route("/amp/v1/login", get(amp::login_redirect))
80        .route("/amp/v0/management/{*path}", any(amp::management_proxy))
81        .route("/amp/v1/chat/completions", post(chat::chat_completions))
82        // AmpCode provider-specific routes (must be registered before the catch-all)
83        .route(
84            "/api/provider/anthropic/v1/messages",
85            post(messages::anthropic_messages),
86        )
87        .route(
88            "/api/provider/openai/v1/chat/completions",
89            post(chat::chat_completions),
90        )
91        .route(
92            "/api/provider/openai/v1/responses",
93            post(amp_provider::codex_responses_passthrough),
94        )
95        .route(
96            "/api/provider/google/v1beta/models/{action}",
97            post(amp_provider::gemini_native_passthrough),
98        )
99        // Catch-all: forward remaining /api/* routes to ampcode.com
100        .route("/api/{*path}", any(amp_provider::amp_management_proxy))
101        .with_state(state)
102        .layer(middleware::from_fn(log_request))
103}
104
#[cfg(test)]
mod tests {
    //! Router-level smoke tests: each test drives one request through a freshly
    //! built router using default config and an in-memory token store.

    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::Response,
    };
    use byokey_store::InMemoryTokenStore;
    use http_body_util::BodyExt as _;
    use serde_json::{Value, json};
    use tower::ServiceExt as _;

    /// Application state backed by default config and an empty in-memory token store.
    fn make_state() -> Arc<AppState> {
        let token_store = Arc::new(InMemoryTokenStore::new());
        let auth = Arc::new(AuthManager::new(token_store));
        AppState::new(Config::default(), auth)
    }

    /// Runs `req` through a freshly built router and returns the response.
    async fn send(req: Request<Body>) -> Response {
        make_router(make_state()).oneshot(req).await.unwrap()
    }

    /// Builds a bodiless GET request for `uri`.
    fn get(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Builds a POST request for `uri` carrying `payload` as a JSON body.
    fn post_json(uri: &str, payload: &Value) -> Request<Body> {
        let bytes = serde_json::to_vec(payload).unwrap();
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(bytes))
            .unwrap()
    }

    /// Value of the `location` header, if present and readable as a string.
    fn location(resp: &Response) -> Option<&str> {
        resp.headers().get("location").and_then(|v| v.to_str().ok())
    }

    /// Collects the whole response body and parses it as JSON.
    async fn body_json(resp: Response) -> Value {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn test_list_models_empty_config() {
        let resp = send(get("/v1/models")).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let json = body_json(resp).await;
        assert_eq!(json["object"], "list");
        let models = json["data"].as_array().expect("`data` must be an array");
        // All providers are enabled by default even without explicit config.
        assert!(!models.is_empty());
    }

    #[tokio::test]
    async fn test_amp_login_redirect() {
        let resp = send(get("/amp/v1/login")).await;
        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(location(&resp), Some("https://ampcode.com/login"));
    }

    #[tokio::test]
    async fn test_amp_cli_login_redirect() {
        let resp = send(get("/amp/auth/cli-login?authToken=abc123&callbackPort=35789")).await;
        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(
            location(&resp),
            Some("https://ampcode.com/auth/cli-login?authToken=abc123&callbackPort=35789")
        );
    }

    #[tokio::test]
    async fn test_chat_unknown_model_returns_400() {
        let payload = json!({"model": "nonexistent-model-xyz", "messages": []});
        let resp = send(post_json("/v1/chat/completions", &payload)).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = body_json(resp).await;
        let message = json["error"]["message"].as_str().unwrap_or("");
        assert!(message.contains("nonexistent-model-xyz"));
    }

    #[tokio::test]
    async fn test_chat_missing_model_returns_400() {
        let payload = json!({"messages": [{"role": "user", "content": "hi"}]});
        let resp = send(post_json("/v1/chat/completions", &payload)).await;
        // Empty model string → UnsupportedModel → 400
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_amp_chat_route_exists() {
        let payload = json!({"model": "nonexistent", "messages": []});
        let resp = send(post_json("/amp/v1/chat/completions", &payload)).await;
        // Route exists (not 404), even though model is invalid
        assert_ne!(resp.status(), StatusCode::NOT_FOUND);
    }
}