1mod amp;
7mod amp_provider;
8mod chat;
9mod error;
10mod messages;
11mod models;
12pub mod usage;
13
14pub use error::ApiError;
15pub use usage::UsageStats;
16
17use arc_swap::ArcSwap;
18use axum::{
19 Json, Router,
20 extract::State,
21 routing::{any, get, post},
22};
23use byokey_auth::AuthManager;
24use byokey_config::Config;
25use std::sync::Arc;
26use tower_http::trace::TraceLayer;
27
28pub struct AppState {
30 pub config: Arc<ArcSwap<Config>>,
33 pub auth: Arc<AuthManager>,
35 pub http: rquest::Client,
37 pub usage: Arc<UsageStats>,
39}
40
41impl AppState {
42 pub fn new(config: Arc<ArcSwap<Config>>, auth: Arc<AuthManager>) -> Arc<Self> {
46 let snapshot = config.load();
47 let http = build_http_client(snapshot.proxy_url.as_deref());
48 Arc::new(Self {
49 config,
50 auth,
51 http,
52 usage: Arc::new(UsageStats::new()),
53 })
54 }
55}
56
57fn build_http_client(proxy_url: Option<&str>) -> rquest::Client {
59 if let Some(url) = proxy_url {
60 match rquest::Proxy::all(url) {
61 Ok(proxy) => {
62 return rquest::Client::builder()
63 .proxy(proxy)
64 .build()
65 .unwrap_or_else(|_| rquest::Client::new());
66 }
67 Err(e) => {
68 tracing::warn!(url = url, error = %e, "invalid proxy_url, using direct connection");
69 }
70 }
71 }
72 rquest::Client::new()
73}
74
75pub fn make_router(state: Arc<AppState>) -> Router {
92 Router::new()
93 .route("/v1/chat/completions", post(chat::chat_completions))
95 .route("/v1/messages", post(messages::anthropic_messages))
96 .route("/v1/models", get(models::list_models))
97 .route("/amp/auth/cli-login", get(amp::cli_login_redirect))
99 .route("/amp/v1/login", get(amp::login_redirect))
100 .route("/amp/v0/management/{*path}", any(amp::management_proxy))
101 .route("/amp/v1/chat/completions", post(chat::chat_completions))
102 .route(
104 "/api/provider/anthropic/v1/messages",
105 post(messages::anthropic_messages),
106 )
107 .route(
108 "/api/provider/openai/v1/chat/completions",
109 post(chat::chat_completions),
110 )
111 .route(
112 "/api/provider/openai/v1/responses",
113 post(amp_provider::codex_responses_passthrough),
114 )
115 .route(
116 "/api/provider/google/v1beta/models/{action}",
117 post(amp_provider::gemini_native_passthrough),
118 )
119 .route("/api/{*path}", any(amp_provider::amp_management_proxy))
121 .route("/v0/management/usage", get(usage_handler))
123 .with_state(state)
124 .layer(TraceLayer::new_for_http())
125}
126
127async fn usage_handler(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
128 let snap = state.usage.snapshot();
129 Json(serde_json::to_value(snap).unwrap_or_default())
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::Response,
    };
    use byokey_store::InMemoryTokenStore;
    use http_body_util::BodyExt as _;
    use serde_json::{Value, json};
    use tower::ServiceExt as _;

    /// Fresh application state: in-memory token store and default configuration.
    fn make_state() -> Arc<AppState> {
        let auth = Arc::new(AuthManager::new(
            Arc::new(InMemoryTokenStore::new()),
            rquest::Client::new(),
        ));
        let config = Arc::new(ArcSwap::from_pointee(Config::default()));
        AppState::new(config, auth)
    }

    /// Routes a single request through a newly built router.
    async fn send(request: Request<Body>) -> Response {
        make_router(make_state()).oneshot(request).await.unwrap()
    }

    /// A body-less GET request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// A POST request for `uri` carrying `payload` as a JSON body.
    fn post_json(uri: &str, payload: &Value) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_vec(payload).unwrap()))
            .unwrap()
    }

    /// The `location` header of `resp`, if present and representable as `&str`.
    fn location(resp: &Response) -> Option<&str> {
        resp.headers().get("location").and_then(|v| v.to_str().ok())
    }

    /// Collects the whole response body and parses it as JSON.
    async fn body_json(resp: Response) -> Value {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn test_list_models_empty_config() {
        let resp = send(get_request("/v1/models")).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let json = body_json(resp).await;
        assert_eq!(json["object"], "list");
        let data = json["data"].as_array().expect("`data` should be an array");
        assert!(!data.is_empty());
    }

    #[tokio::test]
    async fn test_amp_login_redirect() {
        let resp = send(get_request("/amp/v1/login")).await;

        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(location(&resp), Some("https://ampcode.com/login"));
    }

    #[tokio::test]
    async fn test_amp_cli_login_redirect() {
        let resp = send(get_request(
            "/amp/auth/cli-login?authToken=abc123&callbackPort=35789",
        ))
        .await;

        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(
            location(&resp),
            Some("https://ampcode.com/auth/cli-login?authToken=abc123&callbackPort=35789")
        );
    }

    #[tokio::test]
    async fn test_chat_unknown_model_returns_400() {
        let payload = json!({"model": "nonexistent-model-xyz", "messages": []});
        let resp = send(post_json("/v1/chat/completions", &payload)).await;

        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let json = body_json(resp).await;
        let message = json["error"]["message"].as_str().unwrap_or("");
        assert!(message.contains("nonexistent-model-xyz"));
    }

    #[tokio::test]
    async fn test_chat_missing_model_returns_422() {
        let payload = json!({"messages": [{"role": "user", "content": "hi"}]});
        let resp = send(post_json("/v1/chat/completions", &payload)).await;

        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn test_amp_chat_route_exists() {
        let payload = json!({"model": "nonexistent", "messages": []});
        let resp = send(post_json("/amp/v1/chat/completions", &payload)).await;

        // Any status other than 404 proves the route is registered.
        assert_ne!(resp.status(), StatusCode::NOT_FOUND);
    }
}