Skip to main content

byokey_proxy/
lib.rs

1//! HTTP proxy layer — axum router, route handlers, and error mapping.
2//!
3//! Exposes an OpenAI-compatible `/v1/chat/completions` endpoint, a `/v1/models`
4//! listing, and an Amp CLI compatibility layer under `/amp/*`.
5
6mod amp;
7mod chat;
8mod error;
9mod models;
10
11pub use error::ApiError;
12
13use axum::{
14    Router,
15    routing::{any, get, post},
16};
17use byokey_auth::AuthManager;
18use byokey_config::Config;
19use std::sync::Arc;
20
21/// Shared application state passed to all route handlers.
22pub struct AppState {
23    /// Server configuration (providers, listen address, etc.).
24    pub config: Arc<Config>,
25    /// Token manager for OAuth-based providers.
26    pub auth: Arc<AuthManager>,
27    /// HTTP client for upstream requests.
28    pub http: rquest::Client,
29}
30
31impl AppState {
32    /// Creates a new shared application state wrapped in an `Arc`.
33    pub fn new(config: Config, auth: Arc<AuthManager>) -> Arc<Self> {
34        Arc::new(Self {
35            config: Arc::new(config),
36            auth,
37            http: rquest::Client::new(),
38        })
39    }
40}
41
42/// Build the full axum router.
43///
44/// Routes:
45/// - POST /v1/chat/completions
46/// - GET  /v1/models
47/// - GET  /amp/v1/login
48/// - ANY  /amp/v0/management/{*path}
49/// - POST /amp/v1/chat/completions
50pub fn make_router(state: Arc<AppState>) -> Router {
51    Router::new()
52        .route("/v1/chat/completions", post(chat::chat_completions))
53        .route("/v1/models", get(models::list_models))
54        .route("/amp/v1/login", get(amp::login_redirect))
55        .route("/amp/v0/management/{*path}", any(amp::management_proxy))
56        .route("/amp/v1/chat/completions", post(chat::chat_completions))
57        .with_state(state)
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::Response,
    };
    use byokey_store::InMemoryTokenStore;
    use http_body_util::BodyExt as _;
    use serde_json::{Value, json};
    use tower::ServiceExt as _;

    /// Fresh application state: default `Config` plus an empty in-memory token store.
    fn make_state() -> Arc<AppState> {
        let token_store = Arc::new(InMemoryTokenStore::new());
        let manager = Arc::new(AuthManager::new(token_store));
        AppState::new(Config::default(), manager)
    }

    /// Sends `req` through a newly built router and returns the response.
    async fn send(req: Request<Body>) -> Response {
        make_router(make_state()).oneshot(req).await.unwrap()
    }

    /// A bodiless request for `uri` using the builder's default method (GET).
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// A POST request for `uri` carrying `payload` serialized as JSON.
    fn post_json(uri: &str, payload: &Value) -> Request<Body> {
        let bytes = serde_json::to_vec(payload).unwrap();
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(bytes))
            .unwrap()
    }

    /// Collects the whole response body and parses it as JSON.
    async fn body_json(resp: Response) -> Value {
        let collected = resp.into_body().collect().await.unwrap();
        serde_json::from_slice(&collected.to_bytes()).unwrap()
    }

    #[tokio::test]
    async fn test_list_models_empty_config() {
        let resp = send(get_request("/v1/models")).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let payload = body_json(resp).await;
        assert_eq!(payload["object"], "list");
        let models = payload["data"]
            .as_array()
            .expect("`data` should be a JSON array");
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn test_amp_login_redirect() {
        let resp = send(get_request("/amp/v1/login")).await;
        assert_eq!(resp.status(), StatusCode::FOUND);

        let location = resp
            .headers()
            .get("location")
            .and_then(|value| value.to_str().ok());
        assert_eq!(location, Some("https://ampcode.com/login"));
    }

    #[tokio::test]
    async fn test_chat_unknown_model_returns_400() {
        let request_body = json!({"model": "nonexistent-model-xyz", "messages": []});
        let resp = send(post_json("/v1/chat/completions", &request_body)).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let payload = body_json(resp).await;
        let message = payload["error"]["message"].as_str().unwrap_or("");
        assert!(message.contains("nonexistent-model-xyz"));
    }

    #[tokio::test]
    async fn test_chat_missing_model_returns_400() {
        let request_body = json!({"messages": [{"role": "user", "content": "hi"}]});
        let resp = send(post_json("/v1/chat/completions", &request_body)).await;

        // A request without `model` must be rejected with 400; the handler is
        // expected to treat the empty model as unsupported.
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_amp_chat_route_exists() {
        let request_body = json!({"model": "nonexistent", "messages": []});
        let resp = send(post_json("/amp/v1/chat/completions", &request_body)).await;

        // The route is registered (not 404) even though the model is invalid.
        assert_ne!(resp.status(), StatusCode::NOT_FOUND);
    }
}