Skip to main content

imp_llm/oauth/
kimi_code.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use rand::RngCore;
5use serde::Deserialize;
6
7use crate::auth::OAuthCredential;
8use crate::error::{Error, Result};
9
/// OAuth client identifier sent as the `client_id` form field by every request in this module.
const CLIENT_ID: &str = "17e5f671-d194-4dfb-9706-5516cb48c098";
/// Production endpoint used by `request_device_authorization` to start the device flow.
const DEVICE_AUTH_URL: &str = "https://auth.kimi.com/api/oauth/device_authorization";
/// Production endpoint used for device-code polling, code exchange and token refresh.
const TOKEN_URL: &str = "https://auth.kimi.com/api/oauth/token";
13
/// Kimi Code OAuth handler using the OAuth 2.0 device flow.
///
/// The handler only stores the client id and the two endpoint URLs; every
/// request method builds its own HTTP client, so the value is cheap to clone
/// and safe to print (it contains no secrets, only the public client id).
#[derive(Debug, Clone)]
pub struct KimiCodeOAuth {
    /// Value sent as the `client_id` form field.
    client_id: String,
    /// Endpoint used for device-code polling, code exchange and refresh.
    token_url: String,
    /// Endpoint used to request a device authorization.
    device_auth_url: String,
}
20
21impl Default for KimiCodeOAuth {
22    fn default() -> Self {
23        Self {
24            client_id: CLIENT_ID.to_string(),
25            token_url: TOKEN_URL.to_string(),
26            device_auth_url: DEVICE_AUTH_URL.to_string(),
27        }
28    }
29}
30
31impl KimiCodeOAuth {
32    /// Create with production Kimi Code OAuth endpoints.
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Create with custom endpoints (for testing with a mock server).
38    pub fn with_endpoints(device_auth_url: String, token_url: String) -> Self {
39        Self {
40            client_id: CLIENT_ID.to_string(),
41            token_url,
42            device_auth_url,
43        }
44    }
45
46    /// Request a device authorization.
47    pub async fn request_device_authorization(&self) -> Result<DeviceAuthorization> {
48        let client = reqwest::Client::new();
49        let response = client
50            .post(&self.device_auth_url)
51            .form(&[("client_id", self.client_id.as_str())])
52            .headers(common_headers())
53            .send()
54            .await?;
55
56        let status = response.status();
57        let data: serde_json::Value = response.json().await?;
58
59        if !status.is_success() {
60            return Err(Error::Auth(format!(
61                "Device authorization failed ({status}): {data}"
62            )));
63        }
64
65        Ok(DeviceAuthorization {
66            user_code: data["user_code"].as_str().unwrap_or("").to_string(),
67            device_code: data["device_code"].as_str().unwrap_or("").to_string(),
68            verification_uri: data["verification_uri"].as_str().unwrap_or("").to_string(),
69            verification_uri_complete: data["verification_uri_complete"]
70                .as_str()
71                .unwrap_or("")
72                .to_string(),
73            expires_in: data["expires_in"].as_u64(),
74            interval: data["interval"].as_u64().unwrap_or(5).max(1),
75        })
76    }
77
78    /// Poll the token endpoint for a device code.
79    pub async fn request_device_token(
80        &self,
81        device_code: &str,
82    ) -> Result<(u16, HashMap<String, serde_json::Value>)> {
83        let client = reqwest::Client::new();
84        let response = client
85            .post(&self.token_url)
86            .form(&[
87                ("client_id", self.client_id.as_str()),
88                ("device_code", device_code),
89                ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
90            ])
91            .headers(common_headers())
92            .send()
93            .await?;
94
95        let status = response.status();
96        let data: HashMap<String, serde_json::Value> = response.json().await?;
97        Ok((status.as_u16(), data))
98    }
99
100    /// Exchange an authorization code for access + refresh tokens.
101    pub async fn exchange_code(&self, code: &str) -> Result<OAuthCredential> {
102        let client = reqwest::Client::new();
103        let response = client
104            .post(&self.token_url)
105            .form(&[
106                ("grant_type", "authorization_code"),
107                ("client_id", self.client_id.as_str()),
108                ("code", code),
109            ])
110            .headers(common_headers())
111            .send()
112            .await?;
113
114        if !response.status().is_success() {
115            let status = response.status();
116            let body = response.text().await.unwrap_or_default();
117            return Err(Error::Auth(format!(
118                "Token exchange failed ({status}): {body}"
119            )));
120        }
121
122        let token: TokenResponse = response.json().await?;
123        Ok(to_oauth_credential(token))
124    }
125
126    /// Refresh an expired OAuth token.
127    pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthCredential> {
128        let client = reqwest::Client::new();
129        let response = client
130            .post(&self.token_url)
131            .form(&[
132                ("grant_type", "refresh_token"),
133                ("client_id", self.client_id.as_str()),
134                ("refresh_token", refresh_token),
135            ])
136            .headers(common_headers())
137            .send()
138            .await?;
139
140        if !response.status().is_success() {
141            let status = response.status();
142            let body = response.text().await.unwrap_or_default();
143            return Err(Error::Auth(format!(
144                "Token refresh failed ({status}): {body}"
145            )));
146        }
147
148        let token: TokenResponse = response.json().await?;
149        Ok(to_oauth_credential(token))
150    }
151
152    /// Full device-flow login: get device code, open browser, poll for token.
153    ///
154    /// `open_url` is called with the verification URL to open in the browser.
155    /// `print_message` is called with status messages for the user.
156    pub async fn login<F, G>(&self, open_url: F, mut print_message: G) -> Result<OAuthCredential>
157    where
158        F: FnOnce(&str),
159        G: FnMut(&str),
160    {
161        let auth = self.request_device_authorization().await?;
162
163        print_message("Please visit the following URL to finish authorization:");
164        print_message(&format!(
165            "Verification URL: {}",
166            auth.verification_uri_complete
167        ));
168        open_url(&auth.verification_uri_complete);
169
170        let interval = Duration::from_secs(auth.interval);
171        let max_duration = auth
172            .expires_in
173            .map(Duration::from_secs)
174            .unwrap_or_else(|| Duration::from_secs(600));
175        let start = std::time::Instant::now();
176        let mut printed_wait = false;
177
178        while start.elapsed() < max_duration {
179            let (status, data) = self.request_device_token(&auth.device_code).await?;
180
181            if status == 200 && data.contains_key("access_token") {
182                let token: TokenResponse = serde_json::from_value(
183                    serde_json::to_value(&data).map_err(|e| Error::Auth(e.to_string()))?,
184                )?;
185                return Ok(to_oauth_credential(token));
186            }
187
188            if let Some(error) = data.get("error").and_then(|v| v.as_str()) {
189                if error == "expired_token" {
190                    return Err(Error::Auth(
191                        "Device authorization expired. Please try again.".into(),
192                    ));
193                }
194                if error == "authorization_pending" {
195                    if !printed_wait {
196                        print_message("Waiting for user authorization...");
197                        printed_wait = true;
198                    }
199                } else {
200                    let desc = data
201                        .get("error_description")
202                        .and_then(|v| v.as_str())
203                        .unwrap_or(error);
204                    return Err(Error::Auth(format!("OAuth error: {desc}")));
205                }
206            }
207
208            tokio::time::sleep(interval).await;
209        }
210
211        Err(Error::Auth(
212            "Device authorization timed out. Please try again.".into(),
213        ))
214    }
215}
216
/// Result of a device-authorization request, as returned by
/// `KimiCodeOAuth::request_device_authorization`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceAuthorization {
    /// Value of the `user_code` response field.
    pub user_code: String,
    /// Value of the `device_code` response field; sent back when polling the token endpoint.
    pub device_code: String,
    /// Value of the `verification_uri` response field.
    pub verification_uri: String,
    /// Value of the `verification_uri_complete` response field; the URL `login` shows and opens.
    pub verification_uri_complete: String,
    /// Value of the `expires_in` response field, if present; `login` uses it
    /// (in seconds) as the polling deadline.
    pub expires_in: Option<u64>,
    /// Delay between polls, in seconds, as used by `login`.
    pub interval: u64,
}
226
/// Fields of the token endpoint's JSON response that this module deserializes.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Access token; the only field without a serde default, so
    /// deserialization fails when it is absent.
    access_token: String,
    /// Refresh token; `None` when the response omits it.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Token lifetime (presumably seconds). Typed as `f64` so that both
    /// integer and floating-point JSON numbers deserialize; `0.0` when absent.
    #[serde(default)]
    expires_in: f64,
    // Deserialized but not read by the code visible in this file, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    #[serde(default)]
    scope: String,
    // Same as `scope`: accepted from the response but not read here.
    #[allow(dead_code)]
    #[serde(default)]
    token_type: String,
}
241
242fn to_oauth_credential(token: TokenResponse) -> OAuthCredential {
243    let expires_in = token.expires_in as u64;
244    let expires_at = crate::now() + expires_in.saturating_sub(300);
245    OAuthCredential {
246        access_token: token.access_token,
247        refresh_token: token.refresh_token.unwrap_or_default(),
248        expires_at,
249    }
250}
251
252/// Build the standard Kimi Code request headers.
253///
254/// These headers match what kimi-cli sends so that the API accepts
255/// requests from imp as a recognized coding agent.
256pub fn common_headers() -> reqwest::header::HeaderMap {
257    let mut headers = reqwest::header::HeaderMap::new();
258    headers.insert(
259        reqwest::header::USER_AGENT,
260        reqwest::header::HeaderValue::from_static("KimiCLI/1.39.0"),
261    );
262    headers.insert(
263        "X-Msh-Platform",
264        reqwest::header::HeaderValue::from_static("kimi_cli"),
265    );
266    headers.insert(
267        "X-Msh-Version",
268        reqwest::header::HeaderValue::from_static("1.39.0"),
269    );
270    headers.insert(
271        "X-Msh-Device-Name",
272        reqwest::header::HeaderValue::from_str(&hostname())
273            .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("unknown")),
274    );
275    headers.insert(
276        "X-Msh-Device-Model",
277        reqwest::header::HeaderValue::from_str(&device_model())
278            .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("unknown")),
279    );
280    headers.insert(
281        "X-Msh-Os-Version",
282        reqwest::header::HeaderValue::from_str(&os_version())
283            .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("unknown")),
284    );
285    headers.insert(
286        "X-Msh-Device-Id",
287        reqwest::header::HeaderValue::from_str(&device_id())
288            .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("unknown")),
289    );
290    headers
291}
292
/// Best-effort machine name.
///
/// On Unix this runs the `hostname` command and returns its trimmed output.
/// `"unknown"` is returned when the command cannot be run, exits with a
/// failure status, prints non-UTF-8 data, or prints nothing. On other
/// platforms the result is always `"unknown"`.
fn hostname() -> String {
    #[cfg(unix)]
    {
        std::process::Command::new("hostname")
            .output()
            .ok()
            // A failed command's stdout is not a trustworthy host name.
            .filter(|out| out.status.success())
            .and_then(|out| String::from_utf8(out.stdout).ok())
            .map(|name| name.trim().to_string())
            // Never report an empty name; fall through to the placeholder.
            .filter(|name| !name.is_empty())
            .unwrap_or_else(|| "unknown".into())
    }
    #[cfg(not(unix))]
    {
        "unknown".to_string()
    }
}
308
/// Describe the build target as `"<os> <arch>"`, using the compile-time
/// constants from `std::env::consts`.
fn device_model() -> String {
    use std::env::consts::{ARCH, OS};

    [OS, ARCH].join(" ")
}
314
/// Report the operating-system version string.
///
/// On macOS this is the trimmed output of `sw_vers -productVersion`; if the
/// command cannot be run or prints non-UTF-8 data, and on every other
/// platform, the OS name from `std::env::consts::OS` is returned instead.
fn os_version() -> String {
    let os_name = || std::env::consts::OS.to_string();

    if !cfg!(target_os = "macos") {
        return os_name();
    }

    match std::process::Command::new("sw_vers")
        .arg("-productVersion")
        .output()
    {
        Ok(out) => match String::from_utf8(out.stdout) {
            Ok(text) => text.trim().to_string(),
            Err(_) => os_name(),
        },
        Err(_) => os_name(),
    }
}
331
332fn device_id() -> String {
333    // If the user already has kimi-cli installed, reuse its device id so
334    // that imported tokens (whose JWT contains that device_id) match the
335    // headers we send to the API.
336    if let Some(ref p) = std::env::var_os("HOME")
337        .map(|h| std::path::PathBuf::from(h).join(".kimi").join("device_id"))
338    {
339        if let Ok(id) = std::fs::read_to_string(p) {
340            let trimmed = id.trim();
341            if !trimmed.is_empty() {
342                return trimmed.to_string();
343            }
344        }
345    }
346
347    // Fall back to imp's own device id.
348    if let Some(ref p) =
349        std::env::var_os("HOME").map(|h| std::path::PathBuf::from(h).join(".imp").join("device_id"))
350    {
351        if let Ok(id) = std::fs::read_to_string(p) {
352            let trimmed = id.trim();
353            if !trimmed.is_empty() {
354                return trimmed.to_string();
355            }
356        }
357    }
358
359    let mut bytes = [0u8; 16];
360    rand::thread_rng().fill_bytes(&mut bytes);
361    let id = bytes.iter().map(|b| format!("{b:02x}")).collect::<String>();
362
363    // Persist the newly generated id in imp's own directory.
364    if let Some(ref p) =
365        std::env::var_os("HOME").map(|h| std::path::PathBuf::from(h).join(".imp").join("device_id"))
366    {
367        if let Some(parent) = p.parent() {
368            let _ = std::fs::create_dir_all(parent);
369        }
370        let _ = std::fs::write(p, &id);
371    }
372
373    id
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener as TokioListener;

    /// Bind a listener on an OS-assigned localhost port and report that port.
    async fn start_mock_listener() -> (TokioListener, u16) {
        let socket = TokioListener::bind("127.0.0.1:0").await.unwrap();
        let bound_port = socket.local_addr().unwrap().port();
        (socket, bound_port)
    }

    /// Accept one connection, read (and discard) one chunk of the request,
    /// then answer with `status` and `body` as a JSON response.
    async fn serve_once(listener: TokioListener, status: u16, body: String) {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut request = vec![0u8; 8192];
        let _ = stream.read(&mut request).await.unwrap();

        let status_text = match status {
            200 => "OK",
            _ => "Error",
        };
        let response = format!(
            "HTTP/1.1 {status} {status_text}\r\n\
             Content-Type: application/json\r\n\
             Content-Length: {}\r\n\
             Connection: close\r\n\r\n\
             {body}",
            body.len()
        );
        stream.write_all(response.as_bytes()).await.unwrap();
        stream.flush().await.unwrap();
    }

    /// Start a one-shot mock server that replies 200 with `body`, and return
    /// a handler whose endpoints both point at it.
    async fn oauth_with_mock_reply(body: String) -> KimiCodeOAuth {
        let (listener, port) = start_mock_listener().await;
        tokio::spawn(serve_once(listener, 200, body));

        KimiCodeOAuth::with_endpoints(
            format!("http://127.0.0.1:{port}/device"),
            format!("http://127.0.0.1:{port}/token"),
        )
    }

    #[tokio::test]
    async fn test_request_device_authorization() {
        let reply = serde_json::json!({
            "user_code": "ABCD-EFGH",
            "device_code": "dev-123",
            "verification_uri": "https://auth.kimi.com/verify",
            "verification_uri_complete": "https://auth.kimi.com/verify?code=ABCD-EFGH",
            "expires_in": 600,
            "interval": 5
        });

        let oauth = oauth_with_mock_reply(reply.to_string()).await;
        let auth = oauth.request_device_authorization().await.unwrap();

        assert_eq!(auth.user_code, "ABCD-EFGH");
        assert_eq!(auth.device_code, "dev-123");
        assert_eq!(auth.interval, 5);
    }

    #[tokio::test]
    async fn test_refresh_token() {
        let reply = serde_json::json!({
            "access_token": "new-access-token",
            "refresh_token": "new-refresh-token",
            "expires_in": 3600,
            "scope": "kimi-code",
            "token_type": "Bearer"
        });

        let oauth = oauth_with_mock_reply(reply.to_string()).await;
        let cred = oauth.refresh_token("old-refresh").await.unwrap();

        assert_eq!(cred.access_token, "new-access-token");
        assert_eq!(cred.refresh_token, "new-refresh-token");
    }

    #[tokio::test]
    async fn test_token_response_with_float_expires_in() {
        let reply = serde_json::json!({
            "access_token": "test-token",
            "refresh_token": "test-refresh",
            "expires_in": 900.0,
            "scope": "kimi-code",
            "token_type": "Bearer"
        });

        let oauth = oauth_with_mock_reply(reply.to_string()).await;
        let cred = oauth.refresh_token("old-refresh").await.unwrap();

        assert_eq!(cred.access_token, "test-token");
        assert_eq!(cred.refresh_token, "test-refresh");

        // expires_at should be roughly now + 900 - 300 = now + 600
        let expected_min = crate::now() + 500;
        let expected_max = crate::now() + 700;
        assert!(
            cred.expires_at >= expected_min && cred.expires_at <= expected_max,
            "expires_at {} not in range [{}, {}]",
            cred.expires_at,
            expected_min,
            expected_max
        );
    }
}
485}