1use std::collections::HashMap;
2use std::time::Duration;
3
4use rand::RngCore;
5use serde::Deserialize;
6
7use crate::auth::OAuthCredential;
8use crate::error::{Error, Result};
9
/// OAuth client id this CLI presents to the Kimi auth server on every request.
const CLIENT_ID: &str = "17e5f671-d194-4dfb-9706-5516cb48c098";
/// Endpoint that starts the device authorization flow (issues user/device codes).
const DEVICE_AUTH_URL: &str = "https://auth.kimi.com/api/oauth/device_authorization";
/// Token endpoint used for device-code polling, code exchange, and token refresh.
const TOKEN_URL: &str = "https://auth.kimi.com/api/oauth/token";
13
/// Client for the Kimi OAuth endpoints (device flow, code exchange, refresh).
///
/// Holds only the client id and the two endpoint URLs, so it is cheap to clone.
#[derive(Debug, Clone)]
pub struct KimiCodeOAuth {
    /// OAuth client id sent as the `client_id` form field.
    client_id: String,
    /// Token endpoint (device-code polling, code exchange, refresh).
    token_url: String,
    /// Device authorization endpoint.
    device_auth_url: String,
}
20
21impl Default for KimiCodeOAuth {
22 fn default() -> Self {
23 Self {
24 client_id: CLIENT_ID.to_string(),
25 token_url: TOKEN_URL.to_string(),
26 device_auth_url: DEVICE_AUTH_URL.to_string(),
27 }
28 }
29}
30
31impl KimiCodeOAuth {
32 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn with_endpoints(device_auth_url: String, token_url: String) -> Self {
39 Self {
40 client_id: CLIENT_ID.to_string(),
41 token_url,
42 device_auth_url,
43 }
44 }
45
46 pub async fn request_device_authorization(&self) -> Result<DeviceAuthorization> {
48 let client = reqwest::Client::new();
49 let response = client
50 .post(&self.device_auth_url)
51 .form(&[("client_id", self.client_id.as_str())])
52 .headers(common_headers())
53 .send()
54 .await?;
55
56 let status = response.status();
57 let data: serde_json::Value = response.json().await?;
58
59 if !status.is_success() {
60 return Err(Error::Auth(format!(
61 "Device authorization failed ({status}): {data}"
62 )));
63 }
64
65 Ok(DeviceAuthorization {
66 user_code: data["user_code"].as_str().unwrap_or("").to_string(),
67 device_code: data["device_code"].as_str().unwrap_or("").to_string(),
68 verification_uri: data["verification_uri"].as_str().unwrap_or("").to_string(),
69 verification_uri_complete: data["verification_uri_complete"]
70 .as_str()
71 .unwrap_or("")
72 .to_string(),
73 expires_in: data["expires_in"].as_u64(),
74 interval: data["interval"].as_u64().unwrap_or(5).max(1),
75 })
76 }
77
78 pub async fn request_device_token(
80 &self,
81 device_code: &str,
82 ) -> Result<(u16, HashMap<String, serde_json::Value>)> {
83 let client = reqwest::Client::new();
84 let response = client
85 .post(&self.token_url)
86 .form(&[
87 ("client_id", self.client_id.as_str()),
88 ("device_code", device_code),
89 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
90 ])
91 .headers(common_headers())
92 .send()
93 .await?;
94
95 let status = response.status();
96 let data: HashMap<String, serde_json::Value> = response.json().await?;
97 Ok((status.as_u16(), data))
98 }
99
100 pub async fn exchange_code(&self, code: &str) -> Result<OAuthCredential> {
102 let client = reqwest::Client::new();
103 let response = client
104 .post(&self.token_url)
105 .form(&[
106 ("grant_type", "authorization_code"),
107 ("client_id", self.client_id.as_str()),
108 ("code", code),
109 ])
110 .headers(common_headers())
111 .send()
112 .await?;
113
114 if !response.status().is_success() {
115 let status = response.status();
116 let body = response.text().await.unwrap_or_default();
117 return Err(Error::Auth(format!(
118 "Token exchange failed ({status}): {body}"
119 )));
120 }
121
122 let token: TokenResponse = response.json().await?;
123 Ok(to_oauth_credential(token))
124 }
125
126 pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthCredential> {
128 let client = reqwest::Client::new();
129 let response = client
130 .post(&self.token_url)
131 .form(&[
132 ("grant_type", "refresh_token"),
133 ("client_id", self.client_id.as_str()),
134 ("refresh_token", refresh_token),
135 ])
136 .headers(common_headers())
137 .send()
138 .await?;
139
140 if !response.status().is_success() {
141 let status = response.status();
142 let body = response.text().await.unwrap_or_default();
143 return Err(Error::Auth(format!(
144 "Token refresh failed ({status}): {body}"
145 )));
146 }
147
148 let token: TokenResponse = response.json().await?;
149 Ok(to_oauth_credential(token))
150 }
151
152 pub async fn login<F, G>(&self, open_url: F, mut print_message: G) -> Result<OAuthCredential>
157 where
158 F: FnOnce(&str),
159 G: FnMut(&str),
160 {
161 let auth = self.request_device_authorization().await?;
162
163 print_message("Please visit the following URL to finish authorization:");
164 print_message(&format!(
165 "Verification URL: {}",
166 auth.verification_uri_complete
167 ));
168 open_url(&auth.verification_uri_complete);
169
170 let interval = Duration::from_secs(auth.interval);
171 let max_duration = auth
172 .expires_in
173 .map(Duration::from_secs)
174 .unwrap_or_else(|| Duration::from_secs(600));
175 let start = std::time::Instant::now();
176 let mut printed_wait = false;
177
178 while start.elapsed() < max_duration {
179 let (status, data) = self.request_device_token(&auth.device_code).await?;
180
181 if status == 200 && data.contains_key("access_token") {
182 let token: TokenResponse = serde_json::from_value(
183 serde_json::to_value(&data).map_err(|e| Error::Auth(e.to_string()))?,
184 )?;
185 return Ok(to_oauth_credential(token));
186 }
187
188 if let Some(error) = data.get("error").and_then(|v| v.as_str()) {
189 if error == "expired_token" {
190 return Err(Error::Auth(
191 "Device authorization expired. Please try again.".into(),
192 ));
193 }
194 if error == "authorization_pending" {
195 if !printed_wait {
196 print_message("Waiting for user authorization...");
197 printed_wait = true;
198 }
199 } else {
200 let desc = data
201 .get("error_description")
202 .and_then(|v| v.as_str())
203 .unwrap_or(error);
204 return Err(Error::Auth(format!("OAuth error: {desc}")));
205 }
206 }
207
208 tokio::time::sleep(interval).await;
209 }
210
211 Err(Error::Auth(
212 "Device authorization timed out. Please try again.".into(),
213 ))
214 }
215}
216
/// Result of a device authorization request, as returned by
/// `KimiCodeOAuth::request_device_authorization`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceAuthorization {
    /// Code the user confirms on the verification page.
    pub user_code: String,
    /// Opaque code used to poll the token endpoint.
    pub device_code: String,
    /// Verification page URL (the user enters `user_code` there).
    pub verification_uri: String,
    /// Verification URL with the user code embedded; empty if the server omitted it.
    pub verification_uri_complete: String,
    /// Lifetime of the device code in seconds, if the server reported one.
    pub expires_in: Option<u64>,
    /// Minimum seconds between token polls.
    pub interval: u64,
}
226
227#[derive(Debug, Deserialize)]
228struct TokenResponse {
229 access_token: String,
230 #[serde(default)]
231 refresh_token: Option<String>,
232 #[serde(default)]
233 expires_in: f64,
234 #[allow(dead_code)]
235 #[serde(default)]
236 scope: String,
237 #[allow(dead_code)]
238 #[serde(default)]
239 token_type: String,
240}
241
242fn to_oauth_credential(token: TokenResponse) -> OAuthCredential {
243 let expires_in = token.expires_in as u64;
244 let expires_at = crate::now() + expires_in.saturating_sub(300);
245 OAuthCredential {
246 access_token: token.access_token,
247 refresh_token: token.refresh_token.unwrap_or_default(),
248 expires_at,
249 }
250}
251
252pub fn common_headers() -> reqwest::header::HeaderMap {
257 let mut headers = reqwest::header::HeaderMap::new();
258 headers.insert(
259 reqwest::header::USER_AGENT,
260 reqwest::header::HeaderValue::from_static("KimiCLI/1.39.0"),
261 );
262 headers.insert(
263 "X-Msh-Platform",
264 reqwest::header::HeaderValue::from_static("kimi_cli"),
265 );
266 headers.insert(
267 "X-Msh-Version",
268 reqwest::header::HeaderValue::from_static("1.39.0"),
269 );
270 headers.insert(
271 "X-Msh-Device-Name",
272 reqwest::header::HeaderValue::from_str(&hostname())
273 .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("unknown")),
274 );
275 headers.insert(
276 "X-Msh-Device-Model",
277 reqwest::header::HeaderValue::from_str(&device_model())
278 .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("unknown")),
279 );
280 headers.insert(
281 "X-Msh-Os-Version",
282 reqwest::header::HeaderValue::from_str(&os_version())
283 .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("unknown")),
284 );
285 headers.insert(
286 "X-Msh-Device-Id",
287 reqwest::header::HeaderValue::from_str(&device_id())
288 .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("unknown")),
289 );
290 headers
291}
292
/// Best-effort machine name for the `X-Msh-Device-Name` header.
///
/// On Unix this runs the `hostname` command. Elsewhere it reads the
/// `COMPUTERNAME` environment variable (normally set on Windows). Returns
/// `"unknown"` if the lookup fails, the command exits unsuccessfully, or the
/// name is empty after trimming.
fn hostname() -> String {
    #[cfg(unix)]
    let name = std::process::Command::new("hostname")
        .output()
        .ok()
        // A failing command prints nothing useful on stdout; don't report "".
        .filter(|o| o.status.success())
        .and_then(|o| String::from_utf8(o.stdout).ok());
    #[cfg(not(unix))]
    let name = std::env::var("COMPUTERNAME").ok();

    name.map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .unwrap_or_else(|| "unknown".to_string())
}
308
/// Coarse device description for the `X-Msh-Device-Model` header, formatted
/// as `"<os> <arch>"` (for example `"linux x86_64"`).
fn device_model() -> String {
    use std::env::consts::{ARCH, OS};
    [OS, ARCH].join(" ")
}
314
/// Operating-system version for the `X-Msh-Os-Version` header.
///
/// On macOS this is the output of `sw_vers -productVersion`. It falls back to
/// the OS name (`std::env::consts::OS`) when the command fails, exits
/// unsuccessfully, or prints nothing. Other platforms always report the OS name.
fn os_version() -> String {
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("sw_vers")
            .arg("-productVersion")
            .output()
            .ok()
            .filter(|o| o.status.success())
            .and_then(|o| String::from_utf8(o.stdout).ok())
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .unwrap_or_else(|| std::env::consts::OS.to_string())
    }
    #[cfg(not(target_os = "macos"))]
    {
        std::env::consts::OS.to_string()
    }
}
331
332fn device_id() -> String {
333 if let Some(ref p) = std::env::var_os("HOME")
337 .map(|h| std::path::PathBuf::from(h).join(".kimi").join("device_id"))
338 {
339 if let Ok(id) = std::fs::read_to_string(p) {
340 let trimmed = id.trim();
341 if !trimmed.is_empty() {
342 return trimmed.to_string();
343 }
344 }
345 }
346
347 if let Some(ref p) =
349 std::env::var_os("HOME").map(|h| std::path::PathBuf::from(h).join(".imp").join("device_id"))
350 {
351 if let Ok(id) = std::fs::read_to_string(p) {
352 let trimmed = id.trim();
353 if !trimmed.is_empty() {
354 return trimmed.to_string();
355 }
356 }
357 }
358
359 let mut bytes = [0u8; 16];
360 rand::thread_rng().fill_bytes(&mut bytes);
361 let id = bytes.iter().map(|b| format!("{b:02x}")).collect::<String>();
362
363 if let Some(ref p) =
365 std::env::var_os("HOME").map(|h| std::path::PathBuf::from(h).join(".imp").join("device_id"))
366 {
367 if let Some(parent) = p.parent() {
368 let _ = std::fs::create_dir_all(parent);
369 }
370 let _ = std::fs::write(p, &id);
371 }
372
373 id
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener as TokioListener, TcpStream};

    /// Binds a listener on an ephemeral localhost port and returns it with that port.
    async fn start_mock_listener() -> (TokioListener, u16) {
        let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Builds a client whose device and token endpoints both point at the mock server.
    fn mock_oauth(port: u16) -> KimiCodeOAuth {
        KimiCodeOAuth::with_endpoints(
            format!("http://127.0.0.1:{port}/device"),
            format!("http://127.0.0.1:{port}/token"),
        )
    }

    /// Reads one complete HTTP request: the headers up to the blank line, then
    /// `Content-Length` bytes of body.
    ///
    /// Draining the whole request matters. If a socket is closed while its
    /// receive buffer still holds unread bytes, the kernel sends RST instead of
    /// FIN, and the client may report a connection reset instead of reading the
    /// mock response.
    async fn read_request(stream: &mut TcpStream) {
        let mut buf = Vec::new();
        let mut chunk = [0u8; 4096];
        let mut total_len: Option<usize> = None;
        loop {
            if let Some(total) = total_len {
                if buf.len() >= total {
                    return;
                }
            } else if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
                let head = String::from_utf8_lossy(&buf[..pos]).to_ascii_lowercase();
                let body_len = head
                    .lines()
                    .find_map(|line| line.strip_prefix("content-length:"))
                    .and_then(|v| v.trim().parse::<usize>().ok())
                    .unwrap_or(0);
                total_len = Some(pos + 4 + body_len);
                continue;
            }
            let n = stream.read(&mut chunk).await.unwrap();
            if n == 0 {
                return;
            }
            buf.extend_from_slice(&chunk[..n]);
        }
    }

    /// Accepts a single connection, consumes the request, replies with `status`
    /// and the JSON `body`, then closes the connection.
    async fn serve_once(listener: TokioListener, status: u16, body: String) {
        let (mut stream, _) = listener.accept().await.unwrap();
        read_request(&mut stream).await;
        let status_text = if status == 200 { "OK" } else { "Error" };
        let response = format!(
            "HTTP/1.1 {status} {status_text}\r\n\
             Content-Type: application/json\r\n\
             Content-Length: {}\r\n\
             Connection: close\r\n\r\n\
             {body}",
            body.len()
        );
        stream.write_all(response.as_bytes()).await.unwrap();
        stream.flush().await.unwrap();
    }

    #[tokio::test]
    async fn test_request_device_authorization() {
        let body = serde_json::json!({
            "user_code": "ABCD-EFGH",
            "device_code": "dev-123",
            "verification_uri": "https://auth.kimi.com/verify",
            "verification_uri_complete": "https://auth.kimi.com/verify?code=ABCD-EFGH",
            "expires_in": 600,
            "interval": 5
        })
        .to_string();

        let (listener, port) = start_mock_listener().await;
        tokio::spawn(serve_once(listener, 200, body));

        let auth = mock_oauth(port).request_device_authorization().await.unwrap();
        assert_eq!(auth.user_code, "ABCD-EFGH");
        assert_eq!(auth.device_code, "dev-123");
        assert_eq!(auth.interval, 5);
    }

    #[tokio::test]
    async fn test_request_device_authorization_error_status() {
        let body = serde_json::json!({ "error": "invalid_client" }).to_string();

        let (listener, port) = start_mock_listener().await;
        tokio::spawn(serve_once(listener, 400, body));

        assert!(mock_oauth(port).request_device_authorization().await.is_err());
    }

    #[tokio::test]
    async fn test_request_device_token_pending() {
        let body = serde_json::json!({ "error": "authorization_pending" }).to_string();

        let (listener, port) = start_mock_listener().await;
        tokio::spawn(serve_once(listener, 400, body));

        let (status, data) = mock_oauth(port)
            .request_device_token("dev-123")
            .await
            .unwrap();
        assert_eq!(status, 400);
        assert_eq!(
            data.get("error").and_then(|v| v.as_str()),
            Some("authorization_pending")
        );
    }

    #[tokio::test]
    async fn test_refresh_token() {
        let body = serde_json::json!({
            "access_token": "new-access-token",
            "refresh_token": "new-refresh-token",
            "expires_in": 3600,
            "scope": "kimi-code",
            "token_type": "Bearer"
        })
        .to_string();

        let (listener, port) = start_mock_listener().await;
        tokio::spawn(serve_once(listener, 200, body));

        let cred = mock_oauth(port).refresh_token("old-refresh").await.unwrap();
        assert_eq!(cred.access_token, "new-access-token");
        assert_eq!(cred.refresh_token, "new-refresh-token");
    }

    #[tokio::test]
    async fn test_refresh_token_error_status() {
        let body = serde_json::json!({ "error": "invalid_grant" }).to_string();

        let (listener, port) = start_mock_listener().await;
        tokio::spawn(serve_once(listener, 401, body));

        assert!(mock_oauth(port).refresh_token("old-refresh").await.is_err());
    }

    #[tokio::test]
    async fn test_token_response_with_float_expires_in() {
        let body = serde_json::json!({
            "access_token": "test-token",
            "refresh_token": "test-refresh",
            "expires_in": 900.0,
            "scope": "kimi-code",
            "token_type": "Bearer"
        })
        .to_string();

        let (listener, port) = start_mock_listener().await;
        tokio::spawn(serve_once(listener, 200, body));

        let cred = mock_oauth(port).refresh_token("old-refresh").await.unwrap();
        assert_eq!(cred.access_token, "test-token");
        assert_eq!(cred.refresh_token, "test-refresh");
        // 900 s lifetime minus the 300 s margin is about now + 600; allow slack.
        let expected_min = crate::now() + 500;
        let expected_max = crate::now() + 700;
        assert!(
            cred.expires_at >= expected_min && cred.expires_at <= expected_max,
            "expires_at {} not in range [{}, {}]",
            cred.expires_at,
            expected_min,
            expected_max
        );
    }
}