1mod amp;
7mod chat;
8mod error;
9mod models;
10
11pub use error::ApiError;
12
13use axum::{
14 Router,
15 routing::{any, get, post},
16};
17use byokey_auth::AuthManager;
18use byokey_config::Config;
19use std::sync::Arc;
20
21pub struct AppState {
23 pub config: Arc<Config>,
25 pub auth: Arc<AuthManager>,
27 pub http: rquest::Client,
29}
30
31impl AppState {
32 pub fn new(config: Config, auth: Arc<AuthManager>) -> Arc<Self> {
34 Arc::new(Self {
35 config: Arc::new(config),
36 auth,
37 http: rquest::Client::new(),
38 })
39 }
40}
41
42pub fn make_router(state: Arc<AppState>) -> Router {
51 Router::new()
52 .route("/v1/chat/completions", post(chat::chat_completions))
53 .route("/v1/models", get(models::list_models))
54 .route("/amp/v1/login", get(amp::login_redirect))
55 .route("/amp/v0/management/{*path}", any(amp::management_proxy))
56 .route("/amp/v1/chat/completions", post(chat::chat_completions))
57 .with_state(state)
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::Response,
    };
    use byokey_store::InMemoryTokenStore;
    use http_body_util::BodyExt as _;
    use serde_json::{Value, json};
    use tower::ServiceExt as _;

    /// Builds state with the default configuration and an empty in-memory token store.
    fn make_state() -> Arc<AppState> {
        let store = Arc::new(InMemoryTokenStore::new());
        let auth = Arc::new(AuthManager::new(store));
        AppState::new(Config::default(), auth)
    }

    /// Runs one request through a freshly built router and returns its response.
    async fn send(request: Request<Body>) -> Response {
        make_router(make_state()).oneshot(request).await.unwrap()
    }

    /// Sends a bodiless `GET` request to `uri`.
    ///
    /// Not named `get`, to avoid shadowing the `axum::routing::get` glob-imported
    /// from the parent module.
    async fn send_get(uri: &str) -> Response {
        send(Request::builder().uri(uri).body(Body::empty()).unwrap()).await
    }

    /// Sends a `POST` request to `uri` with `body` serialized as JSON.
    async fn send_post_json(uri: &str, body: &Value) -> Response {
        let request = Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_vec(body).unwrap()))
            .unwrap();
        send(request).await
    }

    /// Collects the whole response body and parses it as JSON.
    async fn body_json(resp: Response) -> Value {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn test_list_models_empty_config() {
        let resp = send_get("/v1/models").await;

        assert_eq!(resp.status(), StatusCode::OK);
        let json = body_json(resp).await;
        assert_eq!(json["object"], "list");
        assert!(json["data"].is_array());
        assert_eq!(json["data"].as_array().unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_amp_login_redirect() {
        let resp = send_get("/amp/v1/login").await;

        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(
            resp.headers().get("location").and_then(|v| v.to_str().ok()),
            Some("https://ampcode.com/login")
        );
    }

    #[tokio::test]
    async fn test_chat_unknown_model_returns_400() {
        let body = json!({"model": "nonexistent-model-xyz", "messages": []});
        let resp = send_post_json("/v1/chat/completions", &body).await;

        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let json = body_json(resp).await;
        assert!(
            json["error"]["message"]
                .as_str()
                .unwrap_or("")
                .contains("nonexistent-model-xyz")
        );
    }

    #[tokio::test]
    async fn test_chat_missing_model_returns_400() {
        let body = json!({"messages": [{"role": "user", "content": "hi"}]});
        let resp = send_post_json("/v1/chat/completions", &body).await;

        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_amp_chat_route_exists() {
        let body = json!({"model": "nonexistent", "messages": []});
        let resp = send_post_json("/amp/v1/chat/completions", &body).await;

        // 404 means the path is not routed at all. 405 means it is routed, but not
        // for POST. Either one means the Amp chat endpoint is miswired.
        let status = resp.status();
        assert!(
            status != StatusCode::NOT_FOUND && status != StatusCode::METHOD_NOT_ALLOWED,
            "POST /amp/v1/chat/completions is not routed (got {status})"
        );
    }
}