1use std::path::{Path, PathBuf};
9
10use serde::{Deserialize, Serialize};
11
12use crate::cache::atomic_write;
13use crate::error::{AppError, Result};
14
15#[cfg(target_os = "macos")]
16use super::keychain;
17
18#[derive(Debug, Clone, Deserialize, Serialize)]
20pub struct CredentialsFile {
21 #[serde(rename = "claudeAiOauth")]
22 pub claude_ai_oauth: OauthCreds,
23}
24
25#[derive(Debug, Clone, Deserialize, Serialize)]
26pub struct OauthCreds {
27 #[serde(rename = "accessToken")]
28 pub access_token: String,
29 #[serde(rename = "refreshToken")]
30 pub refresh_token: String,
31 #[serde(rename = "expiresAt", deserialize_with = "de_ms_epoch")]
35 pub expires_at_ms: i64,
36 #[serde(rename = "subscriptionType", default)]
37 pub subscription_type: String,
38 #[serde(rename = "rateLimitTier", default)]
39 pub rate_limit_tier: String,
40 #[serde(default, skip_serializing_if = "Option::is_none")]
43 pub scopes: Option<serde_json::Value>,
44}
45
46fn de_ms_epoch<'de, D>(d: D) -> std::result::Result<i64, D::Error>
47where
48 D: serde::Deserializer<'de>,
49{
50 let v = serde_json::Value::deserialize(d)?;
52 match v {
53 serde_json::Value::Number(n) => {
54 if let Some(i) = n.as_i64() {
55 Ok(i)
56 } else if let Some(f) = n.as_f64() {
57 Ok(f as i64)
58 } else {
59 Err(serde::de::Error::custom("expiresAt not numeric"))
60 }
61 }
62 _ => Err(serde::de::Error::custom("expiresAt must be a number")),
63 }
64}
65
66impl OauthCreds {
67 pub fn plan_label(&self) -> String {
70 let mut name = capitalize_first(&self.subscription_type);
71 if name.is_empty() {
72 name = "Unknown".into();
73 }
74 if self.rate_limit_tier.contains("5x") {
75 name.push_str(" 5x");
76 } else if self.rate_limit_tier.contains("20x") {
77 name.push_str(" 20x");
78 }
79 name
80 }
81
82 pub fn expires_at_secs(&self) -> i64 {
83 self.expires_at_ms / 1000
84 }
85}
86
/// Returns `s` with its first character uppercased and the rest unchanged.
///
/// Uppercasing uses Unicode case mapping, so one character may become several (`ß` → `SS`).
/// An empty input yields an empty string.
fn capitalize_first(s: &str) -> String {
    let mut rest = s.chars();
    rest.next()
        .map(|head| head.to_uppercase().chain(rest).collect::<String>())
        .unwrap_or_default()
}
101
102pub fn default_path() -> Result<PathBuf> {
108 Ok(crate::cache::home_dir()?
109 .join(".claude")
110 .join(".credentials.json"))
111}
112
113pub fn read_from(path: &Path) -> Result<CredentialsFile> {
114 match std::fs::read_to_string(path) {
115 Ok(raw) => parse(&raw, &path.display().to_string()),
116 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
117 #[cfg(target_os = "macos")]
119 if let Some(raw) = keychain::read_raw()? {
120 return parse(&raw, "macOS Keychain (Claude Code-credentials)");
121 }
122 Err(AppError::io_at(path, e))
123 }
124 Err(e) => Err(AppError::io_at(path, e)),
125 }
126}
127
128fn parse(raw: &str, source: &str) -> Result<CredentialsFile> {
130 serde_json::from_str(raw).map_err(|e| {
131 AppError::Credentials(format!(
132 "could not parse {source}: {e}. Run `claude` to re-authenticate."
133 ))
134 })
135}
136
137fn merge_oauth(existing: Option<&str>, new_oauth: &OauthCreds) -> Result<serde_json::Value> {
142 let mut doc: serde_json::Value = existing
143 .and_then(|s| serde_json::from_str(s).ok())
144 .unwrap_or_else(|| serde_json::json!({}));
145 if !doc.is_object() {
146 doc = serde_json::json!({});
147 }
148 doc.as_object_mut().expect("just ensured object").insert(
149 "claudeAiOauth".into(),
150 serde_json::to_value(new_oauth).map_err(AppError::Json)?,
151 );
152 Ok(doc)
153}
154
155pub fn write_back(path: &Path, new_oauth: &OauthCreds) -> Result<()> {
161 #[cfg(target_os = "macos")]
162 if !path.exists() {
163 if let Some(existing) = keychain::read_raw()? {
164 let doc = merge_oauth(Some(&existing), new_oauth)?;
165 let json = serde_json::to_string(&doc).map_err(AppError::Json)?;
166 return keychain::write_raw(&json);
167 }
168 }
169
170 let existing = std::fs::read_to_string(path).ok();
171 let doc = merge_oauth(existing.as_deref(), new_oauth)?;
172 let bytes = serde_json::to_vec_pretty(&doc).map_err(AppError::Json)?;
173 atomic_write(path, &bytes)
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::{NamedTempFile, TempDir};

    /// Writes `contents` to a new temp file and returns the still-open handle.
    ///
    /// The file is deleted when the handle is dropped.
    fn write_creds(contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(contents.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// Writes `contents` to `credentials.json` inside a fresh temp dir, holding no open handle.
    ///
    /// Keep the returned `TempDir` alive while using the path; dropping it removes the directory.
    fn write_creds_closed(contents: &str) -> (TempDir, std::path::PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("credentials.json");
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    /// Writes `json` to a temp file and loads it with `read_from`, panicking on any error.
    fn load(json: &str) -> CredentialsFile {
        let file = write_creds(json);
        read_from(file.path()).unwrap()
    }

    /// Builds an `OauthCreds` with an empty rate-limit tier and no scopes.
    fn creds(access: &str, refresh: &str, expires_at_ms: i64, subscription: &str) -> OauthCreds {
        OauthCreds {
            access_token: access.into(),
            refresh_token: refresh.into(),
            expires_at_ms,
            subscription_type: subscription.into(),
            rate_limit_tier: String::new(),
            scopes: None,
        }
    }

    #[test]
    fn parses_canonical_shape() {
        let oauth = load(
            r#"{"claudeAiOauth":{
                "accessToken":"AT",
                "refreshToken":"RT",
                "expiresAt": 1735000000000,
                "subscriptionType":"max",
                "rateLimitTier":"default_claude_max_5x"
            }}"#,
        )
        .claude_ai_oauth;
        assert_eq!(oauth.access_token, "AT");
        assert_eq!(oauth.expires_at_ms, 1735000000000);
        assert_eq!(oauth.plan_label(), "Max 5x");
    }

    #[test]
    fn accepts_float_expires_at() {
        let oauth = load(
            r#"{"claudeAiOauth":{
                "accessToken":"A","refreshToken":"R",
                "expiresAt": 5000.0,
                "subscriptionType":"pro","rateLimitTier":""
            }}"#,
        )
        .claude_ai_oauth;
        assert_eq!(oauth.expires_at_ms, 5000);
    }

    #[test]
    fn plan_label_pro_no_tier() {
        let oauth = load(
            r#"{"claudeAiOauth":{
                "accessToken":"A","refreshToken":"R","expiresAt": 0,
                "subscriptionType":"pro","rateLimitTier":""
            }}"#,
        )
        .claude_ai_oauth;
        assert_eq!(oauth.plan_label(), "Pro");
    }

    #[test]
    fn plan_label_max_20x() {
        let oauth = load(
            r#"{"claudeAiOauth":{
                "accessToken":"A","refreshToken":"R","expiresAt": 0,
                "subscriptionType":"max","rateLimitTier":"default_claude_max_20x"
            }}"#,
        )
        .claude_ai_oauth;
        assert_eq!(oauth.plan_label(), "Max 20x");
    }

    #[test]
    fn plan_label_empty_subscription_falls_back() {
        let oauth = load(
            r#"{"claudeAiOauth":{
                "accessToken":"A","refreshToken":"R","expiresAt": 0,
                "subscriptionType":"","rateLimitTier":""
            }}"#,
        )
        .claude_ai_oauth;
        assert_eq!(oauth.plan_label(), "Unknown");
    }

    #[test]
    fn malformed_file_returns_credentials_error() {
        let file = write_creds("not json");
        let result = read_from(file.path());
        assert!(matches!(result, Err(AppError::Credentials(_))));
    }

    // Not run on macOS, where a missing file falls back to the Keychain instead of erroring.
    #[cfg(not(target_os = "macos"))]
    #[test]
    fn read_from_missing_file_is_io_error() {
        let missing = std::path::Path::new("/nonexistent/ai-usagebar/.credentials.json");
        let result = read_from(missing);
        assert!(matches!(result, Err(AppError::Io { .. })));
    }

    #[test]
    fn default_path_ends_with_claude_credentials() {
        let expected_tail = std::path::Path::new(".claude").join(".credentials.json");
        assert!(default_path().unwrap().ends_with(expected_tail));
    }

    #[cfg(windows)]
    #[test]
    fn default_path_uses_userprofile_on_windows() {
        // Compare case-insensitively and with a single separator style.
        let normalize = |s: &str| s.to_lowercase().replace('/', "\\");
        let path = default_path().unwrap();
        let userprofile = std::env::var("USERPROFILE").expect("USERPROFILE set on Windows");
        let profile_prefix = normalize(&userprofile);
        assert!(
            normalize(&path.to_string_lossy()).starts_with(profile_prefix.as_str()),
            "{} should live under {}",
            path.display(),
            userprofile
        );
    }

    #[test]
    fn merge_oauth_preserves_unknown_top_level_fields() {
        let existing = r#"{"claudeAiOauth":{"accessToken":"OLD"},"mcpOAuth":{"x":1}}"#;
        let merged = merge_oauth(Some(existing), &creds("NEW", "RT", 99, "max")).unwrap();
        assert_eq!(merged["mcpOAuth"]["x"], 1);
        assert_eq!(merged["claudeAiOauth"]["accessToken"], "NEW");
        assert_eq!(merged["claudeAiOauth"]["expiresAt"], 99);
    }

    #[test]
    fn merge_oauth_handles_empty_and_non_object_input() {
        let fresh = creds("A", "R", 0, "pro");
        // Absent, unparseable and non-object inputs must all yield a fresh object.
        for existing in [None, Some("not json"), Some("[1,2,3]")] {
            let merged = merge_oauth(existing, &fresh).unwrap();
            assert_eq!(merged["claudeAiOauth"]["accessToken"], "A");
        }
    }

    #[test]
    fn write_back_round_trips_and_preserves_unknown_fields() {
        let (_dir, path) = write_creds_closed(
            r#"{"claudeAiOauth":{
                "accessToken":"OLD","refreshToken":"OLD","expiresAt": 0,
                "subscriptionType":"pro","rateLimitTier":""
            },"someOtherField":"keep me"}"#,
        );
        let before = read_from(&path).unwrap();
        let updated = OauthCreds {
            scopes: before.claude_ai_oauth.scopes,
            ..creds("NEW", "NEW_RT", 1234, "pro")
        };
        write_back(&path, &updated).unwrap();

        let on_disk = std::fs::read_to_string(&path).unwrap();
        let written: serde_json::Value = serde_json::from_str(&on_disk).unwrap();
        assert_eq!(written["someOtherField"], "keep me");
        assert_eq!(written["claudeAiOauth"]["accessToken"], "NEW");
        assert_eq!(written["claudeAiOauth"]["expiresAt"], 1234);
    }
}