Skip to main content

ai_usagebar/anthropic/
creds.rs

//! Read and write `~/.claude/.credentials.json` — the OAuth state the Claude
//! CLI maintains. Mirrors claudebar:330-333 (read) and claudebar:447-452 (write).
3
4use std::path::{Path, PathBuf};
5
6use serde::{Deserialize, Serialize};
7
8use crate::cache::atomic_write;
9use crate::error::{AppError, Result};
10
11/// Disk shape (matches claudebar's jq paths).
12#[derive(Debug, Clone, Deserialize, Serialize)]
13pub struct CredentialsFile {
14    #[serde(rename = "claudeAiOauth")]
15    pub claude_ai_oauth: OauthCreds,
16}
17
18#[derive(Debug, Clone, Deserialize, Serialize)]
19pub struct OauthCreds {
20    #[serde(rename = "accessToken")]
21    pub access_token: String,
22    #[serde(rename = "refreshToken")]
23    pub refresh_token: String,
24    /// Unix epoch in **milliseconds** (claudebar:445 multiplies seconds × 1000).
25    /// May arrive as a float in the wild — claudebar truncates with `%%.*`,
26    /// so we accept both.
27    #[serde(rename = "expiresAt", deserialize_with = "de_ms_epoch")]
28    pub expires_at_ms: i64,
29    #[serde(rename = "subscriptionType", default)]
30    pub subscription_type: String,
31    #[serde(rename = "rateLimitTier", default)]
32    pub rate_limit_tier: String,
33    /// Optional `scopes` array — preserved through round-trips so we don't
34    /// drop information when we write back after a refresh.
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub scopes: Option<serde_json::Value>,
37}
38
39fn de_ms_epoch<'de, D>(d: D) -> std::result::Result<i64, D::Error>
40where
41    D: serde::Deserializer<'de>,
42{
43    // Accept int or float — float values like 5000.0 are truncated.
44    let v = serde_json::Value::deserialize(d)?;
45    match v {
46        serde_json::Value::Number(n) => {
47            if let Some(i) = n.as_i64() {
48                Ok(i)
49            } else if let Some(f) = n.as_f64() {
50                Ok(f as i64)
51            } else {
52                Err(serde::de::Error::custom("expiresAt not numeric"))
53            }
54        }
55        _ => Err(serde::de::Error::custom("expiresAt must be a number")),
56    }
57}
58
59impl OauthCreds {
60    /// Plan label rendered the way claudebar does (claudebar:547-550):
61    ///   "${sub_type^} [5x|20x]" (first letter capitalized, optional tier suffix).
62    pub fn plan_label(&self) -> String {
63        let mut name = capitalize_first(&self.subscription_type);
64        if name.is_empty() {
65            name = "Unknown".into();
66        }
67        if self.rate_limit_tier.contains("5x") {
68            name.push_str(" 5x");
69        } else if self.rate_limit_tier.contains("20x") {
70            name.push_str(" 20x");
71        }
72        name
73    }
74
75    pub fn expires_at_secs(&self) -> i64 {
76        self.expires_at_ms / 1000
77    }
78}
79
/// Returns `s` with its first character upper-cased via `char::to_uppercase`
/// (which may expand to several chars, e.g. `ß` → `SS`); the rest is unchanged.
/// An empty input yields an empty string.
fn capitalize_first(s: &str) -> String {
    let mut rest = s.chars();
    rest.next()
        .map(|head| head.to_uppercase().chain(rest).collect::<String>())
        .unwrap_or_default()
}
94
95/// Default location: `~/.claude/.credentials.json`.
96pub fn default_path() -> Result<PathBuf> {
97    let home = std::env::var_os("HOME").ok_or_else(|| AppError::Other("HOME not set".into()))?;
98    Ok(PathBuf::from(home).join(".claude/.credentials.json"))
99}
100
101pub fn read_from(path: &Path) -> Result<CredentialsFile> {
102    let raw = std::fs::read_to_string(path).map_err(|e| AppError::io_at(path, e))?;
103    serde_json::from_str(&raw).map_err(|e| {
104        AppError::Credentials(format!(
105            "could not parse {}: {e}. Run `claude` to re-authenticate.",
106            path.display()
107        ))
108    })
109}
110
111/// Persist updated credentials, preserving any unknown top-level fields the
112/// Claude CLI might have added. Reads the existing file, merges our updates
113/// into the `claudeAiOauth` object, and atomically writes it back.
114pub fn write_back(path: &Path, new_oauth: &OauthCreds) -> Result<()> {
115    let mut doc: serde_json::Value = std::fs::read_to_string(path)
116        .map_err(|e| AppError::io_at(path, e))
117        .and_then(|s| serde_json::from_str(&s).map_err(AppError::Json))
118        .unwrap_or_else(|_| serde_json::json!({}));
119
120    let obj = match doc.as_object_mut() {
121        Some(o) => o,
122        None => {
123            doc = serde_json::json!({});
124            doc.as_object_mut().expect("just constructed object")
125        }
126    };
127    obj.insert(
128        "claudeAiOauth".into(),
129        serde_json::to_value(new_oauth).map_err(AppError::Json)?,
130    );
131
132    let bytes = serde_json::to_vec_pretty(&doc).map_err(AppError::Json)?;
133    atomic_write(path, &bytes)
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `json` into a new temp file. The returned handle owns the file and
    /// deletes it on drop, so keep it alive while its path is in use.
    fn write_creds(json: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(json.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// Round-trips `json` through a temp file and `read_from`, returning the OAuth block.
    fn load_oauth(json: &str) -> OauthCreds {
        let file = write_creds(json);
        read_from(file.path()).unwrap().claude_ai_oauth
    }

    /// Plan label for a minimal credentials file with the given subscription and tier.
    fn label_for(subscription: &str, tier: &str) -> String {
        let json = format!(
            r#"{{"claudeAiOauth":{{
                "accessToken":"A","refreshToken":"R","expiresAt": 0,
                "subscriptionType":"{subscription}","rateLimitTier":"{tier}"
            }}}}"#
        );
        load_oauth(&json).plan_label()
    }

    #[test]
    fn parses_canonical_shape() {
        let oauth = load_oauth(
            r#"{"claudeAiOauth":{
                "accessToken":"AT",
                "refreshToken":"RT",
                "expiresAt": 1735000000000,
                "subscriptionType":"max",
                "rateLimitTier":"default_claude_max_5x"
            }}"#,
        );
        assert_eq!(oauth.access_token, "AT");
        assert_eq!(oauth.expires_at_ms, 1735000000000);
        assert_eq!(oauth.plan_label(), "Max 5x");
    }

    #[test]
    fn accepts_float_expires_at() {
        // claudebar truncates `5000.0 → 5000`; we do the same.
        let oauth = load_oauth(
            r#"{"claudeAiOauth":{
                "accessToken":"A","refreshToken":"R",
                "expiresAt": 5000.0,
                "subscriptionType":"pro","rateLimitTier":""
            }}"#,
        );
        assert_eq!(oauth.expires_at_ms, 5000);
    }

    #[test]
    fn plan_label_pro_no_tier() {
        assert_eq!(label_for("pro", ""), "Pro");
    }

    #[test]
    fn plan_label_max_20x() {
        assert_eq!(label_for("max", "default_claude_max_20x"), "Max 20x");
    }

    #[test]
    fn plan_label_empty_subscription_falls_back() {
        assert_eq!(label_for("", ""), "Unknown");
    }

    #[test]
    fn malformed_file_returns_credentials_error() {
        let file = write_creds("not json");
        let result = read_from(file.path());
        assert!(matches!(result, Err(AppError::Credentials(_))));
    }

    #[test]
    fn write_back_round_trips_and_preserves_unknown_fields() {
        let file = write_creds(
            r#"{"claudeAiOauth":{
                "accessToken":"OLD","refreshToken":"OLD","expiresAt": 0,
                "subscriptionType":"pro","rateLimitTier":""
            },"someOtherField":"keep me"}"#,
        );
        let previous = read_from(file.path()).unwrap().claude_ai_oauth;
        let refreshed = OauthCreds {
            access_token: "NEW".into(),
            refresh_token: "NEW_RT".into(),
            expires_at_ms: 1234,
            subscription_type: "pro".into(),
            rate_limit_tier: String::new(),
            scopes: previous.scopes,
        };
        write_back(file.path(), &refreshed).unwrap();

        // Re-read as raw JSON so fields outside our structs are visible too.
        let written: serde_json::Value =
            serde_json::from_str(&std::fs::read_to_string(file.path()).unwrap()).unwrap();
        assert_eq!(written["someOtherField"], "keep me");
        assert_eq!(written["claudeAiOauth"]["accessToken"], "NEW");
        assert_eq!(written["claudeAiOauth"]["expiresAt"], 1234);
    }
}