Skip to main content

roder_ext_kimi_code/
auth.rs

1use std::fs;
2use std::path::PathBuf;
3use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
4
5use reqwest::{Client, Url};
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
/// Default OAuth host; overridable via `KIMI_CODE_OAUTH_HOST` / `KIMI_OAUTH_HOST` (see `oauth_host`).
const DEFAULT_OAUTH_HOST: &str = "https://auth.kimi.com";
/// Default base URL of the managed Kimi Code inference API; overridable via
/// `KIMI_CODE_BASE_URL` / `RODER_KIMI_CODE_BASE_URL` (see `managed_base_url`).
pub const DEFAULT_MANAGED_BASE_URL: &str = "https://api.kimi.com/coding/v1";
/// Moonshot open-platform base URL.
/// NOTE(review): not referenced in this file; presumably consumed by callers of this module.
pub const DEFAULT_OPEN_PLATFORM_BASE_URL: &str = "https://api.moonshot.ai/v1";
/// OAuth `client_id` sent with device-authorization, device-token and refresh requests.
const CLIENT_ID: &str = "17e5f671-d194-4dfb-9706-5516cb48c098";
/// Value of the `X-Msh-Platform` device header.
const KIMI_CODE_PLATFORM: &str = "kimi_code_cli";
/// Product token at the start of the inference `User-Agent` (see `kimi_code_user_agent`).
const KIMI_CODE_USER_AGENT_PRODUCT: &str = "kimi-code-cli";
/// A stored access token is refreshed once it is within this many milliseconds (3 minutes)
/// of its expiry.
const REFRESH_EXPIRY_SKEW_MILLIS: i64 = 3 * 60 * 1000;
/// This crate's Cargo version; used in the `User-Agent` strings and the `X-Msh-Version` header.
const RODER_VERSION: &str = env!("CARGO_PKG_VERSION");
17
18fn auth_client() -> Client {
19    Client::builder()
20        .user_agent(format!("Roder/{RODER_VERSION} (+https://roder.sh)"))
21        .build()
22        .expect("failed to construct reqwest client for Kimi Code auth")
23}
24
25#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq)]
26pub struct Tokens {
27    #[serde(rename = "type", default = "default_token_type")]
28    pub token_type: String,
29    #[serde(default)]
30    pub refresh: String,
31    #[serde(default)]
32    pub access: String,
33    #[serde(default)]
34    pub expires: i64,
35    #[serde(default)]
36    pub scope: String,
37}
38
/// Raw JSON body returned by the OAuth token endpoint (device-code grant and refresh).
///
/// Every field defaults when absent, so a missing token surfaces as a descriptive error from
/// `tokens_from_response` rather than as a deserialization failure.
/// NOTE(review): the derived `Debug` would print `access_token`/`refresh_token` verbatim; it
/// is not formatted anywhere in this file, but avoid logging this struct.
#[derive(Debug, Default, Deserialize)]
struct TokenResponse {
    #[serde(default)]
    access_token: String,
    #[serde(default)]
    refresh_token: String,
    /// Access-token lifetime in seconds; `tokens_from_response` substitutes 3600 when `<= 0`.
    #[serde(default)]
    expires_in: i64,
    /// Empty values are replaced by `default_token_type()` downstream.
    #[serde(default)]
    token_type: String,
    #[serde(default)]
    scope: String,
}
52
53#[derive(Debug, Default, Deserialize)]
54struct DeviceAuthorizationResponse {
55    device_code: String,
56    user_code: String,
57    verification_uri: String,
58    verification_uri_complete: Option<String>,
59    expires_in: u64,
60    interval: u64,
61}
62
/// Error body returned by the token endpoint while polling a device code.
///
/// `poll_device_token` retries on `authorization_pending` and `slow_down`, and fails on any
/// other `error` value.
#[derive(Debug, Default, Deserialize)]
struct DeviceTokenErrorResponse {
    /// Machine-readable error code.
    error: String,
    /// Optional human-readable detail, used in the resulting error message when present.
    error_description: Option<String>,
}
68
/// Location of the Kimi Code credential files on disk.
///
/// Both files live under `<data_dir>/auth/`: `kimi-code.json` holds the tokens and
/// `kimi-code-device-id` holds the persistent device identifier.
pub struct Store {
    /// Root data directory, resolved by `roder_data_dir` in `Store::new`.
    data_dir: PathBuf,
}
72
73impl Default for Store {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl Store {
80    pub fn new() -> Self {
81        Self {
82            data_dir: roder_data_dir(),
83        }
84    }
85
86    pub fn load(&self) -> anyhow::Result<Tokens> {
87        load_tokens_from(&self.path())
88    }
89
90    pub fn save(&self, mut tokens: Tokens) -> anyhow::Result<()> {
91        normalize(&mut tokens);
92        let path = self.path();
93        if let Some(parent) = path.parent() {
94            fs::create_dir_all(parent)?;
95        }
96        let data = serde_json::to_vec_pretty(&tokens)?;
97        fs::write(path, [data, b"\n".to_vec()].concat())?;
98        Ok(())
99    }
100
101    pub fn delete(&self) -> anyhow::Result<()> {
102        let path = self.path();
103        match fs::remove_file(path) {
104            Ok(()) => Ok(()),
105            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
106            Err(err) => Err(err.into()),
107        }
108    }
109
110    fn path(&self) -> PathBuf {
111        self.data_dir.join("auth").join("kimi-code.json")
112    }
113
114    fn device_id_path(&self) -> PathBuf {
115        self.data_dir.join("auth").join("kimi-code-device-id")
116    }
117}
118
119pub fn managed_base_url() -> String {
120    std::env::var("KIMI_CODE_BASE_URL")
121        .ok()
122        .or_else(|| std::env::var("RODER_KIMI_CODE_BASE_URL").ok())
123        .filter(|value| !value.trim().is_empty())
124        .unwrap_or_else(|| DEFAULT_MANAGED_BASE_URL.to_string())
125}
126
127pub fn inference_headers() -> anyhow::Result<Vec<(String, String)>> {
128    let mut headers = device_headers()?
129        .into_iter()
130        .map(|(name, value)| (name.to_string(), value))
131        .collect::<Vec<_>>();
132    headers.push(("User-Agent".to_string(), kimi_code_user_agent()));
133    Ok(headers)
134}
135
136fn kimi_code_user_agent() -> String {
137    format!("{KIMI_CODE_USER_AGENT_PRODUCT}/{RODER_VERSION} (roder)")
138}
139
140pub fn oauth_host() -> String {
141    std::env::var("KIMI_CODE_OAUTH_HOST")
142        .ok()
143        .or_else(|| std::env::var("KIMI_OAUTH_HOST").ok())
144        .filter(|value| !value.trim().is_empty())
145        .unwrap_or_else(|| DEFAULT_OAUTH_HOST.to_string())
146}
147
148pub fn has_stored_tokens() -> bool {
149    Store::new()
150        .load()
151        .ok()
152        .is_some_and(|tokens| {
153            !tokens.refresh.trim().is_empty() || !tokens.access.trim().is_empty()
154        })
155}
156
157pub async fn access_token() -> anyhow::Result<Option<String>> {
158    let store = Store::new();
159    let tokens = store.load()?;
160    if tokens.refresh.trim().is_empty() && tokens.access.trim().is_empty() {
161        return Ok(None);
162    }
163    let now = now_millis();
164    if !tokens.access.trim().is_empty() && tokens.expires > now + REFRESH_EXPIRY_SKEW_MILLIS {
165        return Ok(Some(tokens.access));
166    }
167    if tokens.refresh.trim().is_empty() {
168        return Ok(None);
169    }
170    let mut refreshed = refresh(&tokens.refresh).await?;
171    if refreshed.refresh.is_empty() {
172        refreshed.refresh = tokens.refresh;
173    }
174    if refreshed.scope.is_empty() {
175        refreshed.scope = tokens.scope;
176    }
177    let access = refreshed.access.clone();
178    store.save(refreshed)?;
179    Ok(Some(access))
180}
181
182pub async fn device_flow() -> anyhow::Result<Tokens> {
183    let oauth_host = oauth_host();
184    let device_endpoint = device_authorization_url(&oauth_host)?;
185    let token_endpoint = token_url(&oauth_host)?;
186
187    let client = auth_client();
188    let response = client
189        .post(device_endpoint)
190        .headers(device_header_map()?)
191        .form(&[("client_id", CLIENT_ID)])
192        .send()
193        .await?;
194    let status = response.status();
195    let text = response.text().await?;
196    if !status.is_success() {
197        anyhow::bail!(
198            "kimi-code device authorization request failed: {status} {}",
199            text.trim()
200        );
201    }
202    let device_auth: DeviceAuthorizationResponse = serde_json::from_str(&text).map_err(|e| {
203        anyhow::anyhow!(
204            "kimi-code device authorization response was not valid JSON: {e}\n{text}"
205        )
206    })?;
207
208    let verification_uri = device_auth
209        .verification_uri_complete
210        .as_ref()
211        .unwrap_or(&device_auth.verification_uri);
212    eprintln!("Kimi Code device sign-in");
213    eprintln!("User code: {}", device_auth.user_code);
214    eprintln!("Open: {verification_uri}");
215    open_browser(verification_uri).or_else(|err| {
216        eprintln!("Could not open browser automatically: {err}");
217        Ok::<(), anyhow::Error>(())
218    })?;
219
220    let token = poll_device_token(
221        &token_endpoint,
222        &device_auth.device_code,
223        device_auth.interval,
224        device_auth.expires_in,
225    )
226    .await?;
227    Store::new().save(token.clone())?;
228    Ok(token)
229}
230
231async fn poll_device_token(
232    token_endpoint: &str,
233    device_code: &str,
234    mut interval: u64,
235    expires_in: u64,
236) -> anyhow::Result<Tokens> {
237    let max_interval = 60;
238    if interval == 0 {
239        interval = 5;
240    }
241    let expires_at = Instant::now() + Duration::from_secs(expires_in);
242    let client = auth_client();
243    loop {
244        if Instant::now() >= expires_at {
245            anyhow::bail!("kimi-code device sign-in expired");
246        }
247        let response = client
248            .post(token_endpoint)
249            .headers(device_header_map()?)
250            .form(&[
251                (
252                    "grant_type",
253                    "urn:ietf:params:oauth:grant-type:device_code",
254                ),
255                ("device_code", device_code),
256                ("client_id", CLIENT_ID),
257            ])
258            .send()
259            .await?;
260        let status = response.status();
261        let text = response.text().await?;
262        if status.is_success() {
263            let token_response = parse_token_response(&text)?;
264            return tokens_from_response(token_response);
265        }
266        let error: DeviceTokenErrorResponse = match serde_json::from_str(&text) {
267            Ok(error) => error,
268            Err(_) => {
269                anyhow::bail!("kimi-code device token request failed: {status} {}", text.trim());
270            }
271        };
272        match error.error.as_str() {
273            "authorization_pending" => {
274                tokio::time::sleep(Duration::from_secs(interval)).await;
275            }
276            "slow_down" => {
277                interval = (interval + 5).min(max_interval);
278                tokio::time::sleep(Duration::from_secs(interval)).await;
279            }
280            "expired_token" => {
281                anyhow::bail!("kimi-code device sign-in expired (expired_token)");
282            }
283            "access_denied" => {
284                let desc = error
285                    .error_description
286                    .as_deref()
287                    .unwrap_or("access denied");
288                anyhow::bail!("kimi-code device sign-in denied: {desc}");
289            }
290            other => {
291                let desc = error.error_description.as_deref().unwrap_or(other);
292                anyhow::bail!("kimi-code device sign-in error: {desc}");
293            }
294        }
295    }
296}
297
298pub async fn status() -> anyhow::Result<Option<Tokens>> {
299    let tokens = Store::new().load()?;
300    Ok((!tokens.refresh.trim().is_empty()).then_some(tokens))
301}
302
303pub fn logout() -> anyhow::Result<()> {
304    Store::new().delete()
305}
306
307async fn refresh(refresh_token: &str) -> anyhow::Result<Tokens> {
308    let token_endpoint = token_url(&oauth_host())?;
309    let params = [
310        ("grant_type", "refresh_token"),
311        ("refresh_token", refresh_token),
312        ("client_id", CLIENT_ID),
313    ];
314    token_request(&token_endpoint, &params).await
315}
316
317async fn token_request(token_endpoint: &str, params: &[(&str, &str)]) -> anyhow::Result<Tokens> {
318    validate_kimi_https_endpoint(token_endpoint)?;
319    let response = auth_client()
320        .post(token_endpoint)
321        .headers(device_header_map()?)
322        .form(params)
323        .send()
324        .await?;
325    let status = response.status();
326    let text = response.text().await?;
327    if !status.is_success() {
328        anyhow::bail!(
329            "kimi-code token request failed: {status} {}",
330            redacted_body_excerpt(&text)
331        );
332    }
333    let token_response = parse_token_response(&text)?;
334    tokens_from_response(token_response)
335}
336
337fn device_authorization_url(oauth_host: &str) -> anyhow::Result<String> {
338    let base = normalize_oauth_host(oauth_host)?;
339    validate_kimi_https_endpoint(&base)?;
340    Ok(format!("{base}/api/oauth/device_authorization"))
341}
342
343fn token_url(oauth_host: &str) -> anyhow::Result<String> {
344    let base = normalize_oauth_host(oauth_host)?;
345    validate_kimi_https_endpoint(&base)?;
346    Ok(format!("{base}/api/oauth/token"))
347}
348
349fn normalize_oauth_host(oauth_host: &str) -> anyhow::Result<String> {
350    Ok(oauth_host.trim_end_matches('/').to_string())
351}
352
353fn device_header_map() -> anyhow::Result<reqwest::header::HeaderMap> {
354    let mut headers = reqwest::header::HeaderMap::new();
355    for (name, value) in device_headers()? {
356        headers.insert(
357            reqwest::header::HeaderName::from_bytes(name.as_bytes())
358                .map_err(|err| anyhow::anyhow!("invalid kimi-code header name: {err}"))?,
359            reqwest::header::HeaderValue::from_str(&value)
360                .map_err(|err| anyhow::anyhow!("invalid kimi-code header value for {name}: {err}"))?,
361        );
362    }
363    Ok(headers)
364}
365
366fn device_headers() -> anyhow::Result<Vec<(&'static str, String)>> {
367    let store = Store::new();
368    Ok(vec![
369        ("X-Msh-Platform", KIMI_CODE_PLATFORM.to_string()),
370        ("X-Msh-Version", RODER_VERSION.to_string()),
371        ("X-Msh-Device-Name", ascii_header(hostname(), "unknown")),
372        ("X-Msh-Device-Model", ascii_header(device_model(), "unknown")),
373        ("X-Msh-Os-Version", ascii_header(os_version(), "unknown")),
374        ("X-Msh-Device-Id", read_or_create_device_id(&store)?),
375    ])
376}
377
378fn read_or_create_device_id(store: &Store) -> anyhow::Result<String> {
379    let path = store.device_id_path();
380    if let Ok(contents) = fs::read_to_string(&path) {
381        let id = contents.trim();
382        if !id.is_empty() {
383            return Ok(id.to_string());
384        }
385    }
386    let id = Uuid::new_v4().to_string();
387    if let Some(parent) = path.parent() {
388        fs::create_dir_all(parent)?;
389    }
390    fs::write(&path, format!("{id}\n"))?;
391    Ok(id)
392}
393
/// Best-effort machine name for the `X-Msh-Device-Name` header.
///
/// Lookup order:
/// 1. The `HOSTNAME` and `COMPUTERNAME` environment variables, skipping blank values.
/// 2. On Unix-like systems, the kernel/OS hostname files. `HOSTNAME` is often an unexported
///    shell variable, so step 1 frequently finds nothing there.
/// 3. `"unknown"` if nothing is found.
fn hostname() -> String {
    let from_env = ["HOSTNAME", "COMPUTERNAME"].iter().find_map(|name| {
        std::env::var(name)
            .ok()
            .filter(|value| !value.trim().is_empty())
    });
    if let Some(name) = from_env {
        return name;
    }
    #[cfg(unix)]
    {
        for path in ["/proc/sys/kernel/hostname", "/etc/hostname"] {
            if let Ok(contents) = fs::read_to_string(path) {
                let name = contents.trim();
                if !name.is_empty() {
                    return name.to_string();
                }
            }
        }
    }
    "unknown".to_string()
}
399
/// OS version string for the device headers.
///
/// On macOS this is the trimmed stdout of `/usr/bin/sw_vers -productVersion`, when non-empty.
/// Everywhere else, or if that command yields nothing, it falls back to the OS name from
/// `std::env::consts::OS`.
fn os_version() -> String {
    #[cfg(target_os = "macos")]
    {
        let product_version = std::process::Command::new("/usr/bin/sw_vers")
            .arg("-productVersion")
            .output()
            .ok()
            .map(|out| String::from_utf8_lossy(&out.stdout).trim().to_string())
            .filter(|version| !version.is_empty());
        if let Some(version) = product_version {
            return version;
        }
    }
    String::from(std::env::consts::OS)
}
415
416fn device_model() -> String {
417    let os = std::env::consts::OS;
418    let arch = std::env::consts::ARCH;
419    if os == "macos" {
420        format!("macOS {} {arch}", os_version())
421    } else if os == "windows" {
422        format!("Windows {arch}")
423    } else {
424        format!("{os} {arch}")
425    }
426}
427
/// Reduces `value` to a header-safe string.
///
/// Keeps only printable ASCII (space through `~`), trims surrounding spaces, and returns
/// `fallback` when nothing is left.
fn ascii_header(value: String, fallback: &str) -> String {
    let printable: String = value
        .chars()
        .filter(|c| *c == ' ' || c.is_ascii_graphic())
        .collect();
    match printable.trim() {
        "" => fallback.to_owned(),
        kept => kept.to_owned(),
    }
}
441
442fn validate_kimi_https_endpoint(endpoint: &str) -> anyhow::Result<()> {
443    let url = Url::parse(endpoint)?;
444    if url.scheme() != "https" {
445        anyhow::bail!("kimi-code oauth endpoint must use https");
446    }
447    let host = url.host_str().unwrap_or_default();
448    if host == "kimi.com" || host.ends_with(".kimi.com") {
449        return Ok(());
450    }
451    anyhow::bail!("kimi-code oauth endpoint must be hosted on kimi.com");
452}
453
454fn load_tokens_from(path: &PathBuf) -> anyhow::Result<Tokens> {
455    match fs::read_to_string(path) {
456        Ok(contents) if contents.trim().is_empty() => Ok(Tokens::default()),
457        Ok(contents) => {
458            let mut tokens: Tokens = serde_json::from_str(&contents)?;
459            normalize(&mut tokens);
460            Ok(tokens)
461        }
462        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(Tokens::default()),
463        Err(err) => Err(err.into()),
464    }
465}
466
467fn parse_token_response(text: &str) -> anyhow::Result<TokenResponse> {
468    serde_json::from_str(text).map_err(|err| {
469        anyhow::anyhow!(
470            "kimi-code token response was not valid JSON: {err}; body: {}",
471            redacted_body_excerpt(text)
472        )
473    })
474}
475
476fn tokens_from_response(response: TokenResponse) -> anyhow::Result<Tokens> {
477    if response.access_token.trim().is_empty() {
478        anyhow::bail!("kimi-code token response missing access_token");
479    }
480    if response.refresh_token.trim().is_empty() {
481        anyhow::bail!("kimi-code token response missing refresh_token");
482    }
483    let expires_in = if response.expires_in > 0 {
484        response.expires_in
485    } else {
486        3600
487    };
488    let mut tokens = Tokens {
489        token_type: if response.token_type.is_empty() {
490            default_token_type()
491        } else {
492            response.token_type
493        },
494        refresh: response.refresh_token,
495        access: response.access_token,
496        expires: now_millis() + expires_in * 1000,
497        scope: response.scope,
498    };
499    normalize(&mut tokens);
500    Ok(tokens)
501}
502
503fn open_browser(url: &str) -> anyhow::Result<()> {
504    let mut command = browser_command(url);
505    let status = command.status()?;
506    if !status.success() {
507        anyhow::bail!("failed to open browser");
508    }
509    Ok(())
510}
511
/// Builds the command that opens `url` in the user's default browser.
///
/// The URL is always passed as a single argument and never through a shell, so query strings
/// containing `&` are not split. The command per platform:
/// - macOS: `open`.
/// - Windows: `rundll32 url.dll,FileProtocolHandler`.
/// - Every other target (Linux and other Unix-likes such as the BSDs): `xdg-open`. This
///   fallback lets the function compile on targets beyond the three named ones.
fn browser_command(url: &str) -> std::process::Command {
    #[cfg(target_os = "macos")]
    let mut command = std::process::Command::new("open");
    #[cfg(target_os = "windows")]
    let mut command = {
        let mut command = std::process::Command::new("rundll32");
        command.arg("url.dll,FileProtocolHandler");
        command
    };
    #[cfg(not(any(target_os = "macos", target_os = "windows")))]
    let mut command = std::process::Command::new("xdg-open");
    command.arg(url);
    command
}
526
/// Returns a redacted, length-bounded copy of `body` for use in error messages.
///
/// Two steps, in this order:
/// 1. Redact: every JSON string value keyed by `access_token`, `refresh_token`, `access` or
///    `refresh` is replaced with `[redacted]`.
/// 2. Truncate: the result is cut to at most 1,000 characters, with ` ...` appended when
///    anything was dropped.
fn redacted_body_excerpt(body: &str) -> String {
    const MAX_ERROR_BODY_CHARS: usize = 1_000;
    // Redact *before* truncating. Cutting first can split a token value so that its closing
    // quote falls outside the excerpt, and the leading part of the secret would leak.
    let mut excerpt = body.to_string();
    for field in ["access_token", "refresh_token", "access", "refresh"] {
        redact_json_string_field(&mut excerpt, field);
    }
    if let Some((cut, _)) = excerpt.char_indices().nth(MAX_ERROR_BODY_CHARS) {
        excerpt.truncate(cut);
        excerpt.push_str(" ...");
    }
    excerpt
}

/// Replaces, in place, the string value following each `"<field>"` key in `body` with
/// `[redacted]`.
///
/// This is a lenient textual scan, not a JSON parser:
/// - After the key it takes the next `:` and then the next `"` as the start of the value.
///   It may therefore over-redact odd input, which is the safe direction.
/// - Backslash escapes inside the value are honoured when looking for the closing quote.
/// - An unterminated value (malformed or cut-off body) is redacted to the end of `body`
///   rather than left partially visible.
fn redact_json_string_field(body: &mut String, field: &str) {
    let pattern = format!("\"{field}\"");
    let mut search_from = 0;
    while let Some(relative_key_start) = body[search_from..].find(&pattern) {
        let key_start = search_from + relative_key_start;
        let Some(relative_colon) = body[key_start + pattern.len()..].find(':') else {
            return;
        };
        let value_scan_start = key_start + pattern.len() + relative_colon + 1;
        let Some(relative_quote) = body[value_scan_start..].find('"') else {
            search_from = value_scan_start;
            continue;
        };
        let value_start = value_scan_start + relative_quote;
        let mut escaped = false;
        let mut value_end = None;
        for (offset, ch) in body[value_start + 1..].char_indices() {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                value_end = Some(value_start + 1 + offset);
                break;
            }
        }
        match value_end {
            Some(value_end) => {
                body.replace_range(value_start + 1..value_end, "[redacted]");
                // Resume just past the closing quote of the redacted value.
                search_from = value_start + "\"[redacted]\"".len();
            }
            None => {
                // No closing quote: hide everything after the opening quote.
                body.replace_range(value_start + 1.., "[redacted]");
                return;
            }
        }
    }
}
572
573fn roder_data_dir() -> PathBuf {
574    std::env::var_os("RODER_DATA_DIR")
575        .or_else(|| std::env::var_os("RODER_CONFIG_DIR"))
576        .map(PathBuf::from)
577        .unwrap_or_else(|| {
578            dirs::home_dir()
579                .unwrap_or_else(|| PathBuf::from("."))
580                .join(".roder")
581        })
582}
583
584fn normalize(tokens: &mut Tokens) {
585    if tokens.token_type.is_empty() {
586        tokens.token_type = default_token_type();
587    }
588    tokens.refresh = tokens.refresh.trim().to_string();
589    tokens.access = tokens.access.trim().to_string();
590    tokens.scope = tokens.scope.trim().to_string();
591}
592
/// Token type assumed when none is stored or returned: `"Bearer"`.
fn default_token_type() -> String {
    String::from("Bearer")
}
596
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A clock set before 1970 yields 0.
fn now_millis() -> i64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_millis() as i64
}
603
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): environment-dependent. Assumes neither KIMI_CODE_OAUTH_HOST nor
    // KIMI_OAUTH_HOST is set in the test environment.
    #[test]
    fn oauth_host_defaults_to_auth_kimi_com() {
        assert_eq!(oauth_host(), DEFAULT_OAUTH_HOST);
    }

    // Both endpoints are built from the default host with the fixed `/api/oauth/...` paths.
    #[test]
    fn device_authorization_and_token_urls_use_kimi_host() {
        let device = device_authorization_url(DEFAULT_OAUTH_HOST).unwrap();
        let token = token_url(DEFAULT_OAUTH_HOST).unwrap();
        assert_eq!(
            device,
            "https://auth.kimi.com/api/oauth/device_authorization"
        );
        assert_eq!(token, "https://auth.kimi.com/api/oauth/token");
    }

    // Covers both rejection paths: a non-kimi.com host and a non-https scheme.
    #[test]
    fn endpoint_validation_rejects_non_kimi_hosts() {
        let err = validate_kimi_https_endpoint("https://example.com/oauth/token")
            .unwrap_err()
            .to_string();
        assert!(err.contains("kimi.com"));

        let insecure = validate_kimi_https_endpoint("http://auth.kimi.com/oauth/token")
            .unwrap_err()
            .to_string();
        assert!(insecure.contains("https"));
    }

    // A complete response maps field-for-field, and `expires` lands in the future.
    #[test]
    fn token_response_parses_required_fields() {
        let tokens = tokens_from_response(TokenResponse {
            access_token: "access".to_string(),
            refresh_token: "refresh".to_string(),
            expires_in: 120,
            token_type: "Bearer".to_string(),
            scope: "kimi".to_string(),
        })
        .unwrap();

        assert_eq!(tokens.access, "access");
        assert_eq!(tokens.refresh, "refresh");
        assert_eq!(tokens.scope, "kimi");
        assert!(tokens.expires > now_millis());
    }

    // The trailing `{"extra":true}` makes the body invalid JSON (trailing characters). The
    // parse error must quote the body with both token values redacted.
    #[test]
    fn token_response_parse_error_redacts_secret_material() {
        let raw = r#"{"access_token":"secret-access","refresh_token":"secret-refresh"}{"extra":true}"#;
        let err = parse_token_response(raw).unwrap_err().to_string();

        assert!(err.contains("kimi-code token response was not valid JSON"));
        assert!(err.contains("[redacted]"));
        assert!(!err.contains("secret-access"));
        assert!(!err.contains("secret-refresh"));
    }

    // NOTE(review): `device_headers()` calls `read_or_create_device_id`, so this test reads
    // and may create the device-id file under the real data directory.
    #[test]
    fn device_headers_include_platform_and_version() {
        let headers = device_headers().unwrap();
        let platform = headers
            .iter()
            .find(|(name, _)| *name == "X-Msh-Platform")
            .map(|(_, value)| value.as_str());
        let version = headers
            .iter()
            .find(|(name, _)| *name == "X-Msh-Version")
            .map(|(_, value)| value.as_str());
        assert_eq!(platform, Some(KIMI_CODE_PLATFORM));
        assert_eq!(version, Some(RODER_VERSION));
    }

    // Same data-directory side effect as above: `inference_headers()` builds the device
    // headers, including the device id.
    #[test]
    fn inference_headers_include_kimi_code_cli_user_agent() {
        let headers = inference_headers().unwrap();
        let user_agent = headers
            .iter()
            .find(|(name, _)| name == "User-Agent")
            .map(|(_, value)| value.as_str());
        assert_eq!(user_agent, Some(kimi_code_user_agent().as_str()));
        assert!(headers.iter().any(|(name, _)| name == "X-Msh-Device-Id"));
    }

    // NOTE(review): environment-dependent. Assumes neither KIMI_CODE_BASE_URL nor
    // RODER_KIMI_CODE_BASE_URL is set.
    #[test]
    fn managed_base_url_defaults_to_kimi_coding_api() {
        assert_eq!(managed_base_url(), DEFAULT_MANAGED_BASE_URL);
    }

    // On Windows the URL, including its query string, must reach rundll32 as one argument
    // rather than being split by a shell.
    #[test]
    #[cfg(windows)]
    fn windows_browser_command_does_not_shell_split_oauth_url() {
        let url = "https://auth.kimi.com/device?client_id=app";
        let command = browser_command(url);
        let args = command
            .get_args()
            .map(|arg| arg.to_string_lossy().into_owned())
            .collect::<Vec<_>>();

        assert_eq!(command.get_program().to_string_lossy(), "rundll32");
        assert_eq!(args, vec!["url.dll,FileProtocolHandler", url]);
    }
}