1use anyhow::Context;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::io::Read;
14use std::path::PathBuf;
15use std::time::Duration;
16
/// A credential stored for a single provider.
///
/// Serialized as an internally tagged object: the `type` field carries
/// `"api_key"` or `"oauth"` and selects the variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum AuthCredential {
    /// A plain API key (`"type": "api_key"`).
    #[serde(rename = "api_key")]
    ApiKey { key: String },
    /// An OAuth token set (`"type": "oauth"`).
    #[serde(rename = "oauth")]
    Oauth {
        /// The access token.
        access: String,
        /// The refresh token, if one was stored.
        refresh: Option<String>,
        /// Expiry instant. Elsewhere in this file it is compared against the
        /// current time in milliseconds since the Unix epoch; `None` is
        /// treated there as "never expires".
        expires: Option<i64>,
        /// Stored under the JSON key `enterpriseUrl`.
        // NOTE(review): presumably the base URL of an enterprise deployment — confirm.
        #[serde(rename = "enterpriseUrl")]
        enterprise_url: Option<String>,
    },
}
32
33#[derive(Debug, Clone, Default, Deserialize)]
35pub struct AuthStorage(HashMap<String, AuthCredential>);
36
37impl AuthStorage {
38 pub fn load() -> anyhow::Result<Self> {
40 Self::load_from(Self::path()?)
41 }
42
43 pub fn load_from(path: std::path::PathBuf) -> anyhow::Result<Self> {
45 let content = read_json_file(&path)?;
46 match content {
47 Some(c) => serde_json::from_str(&c)
48 .with_context(|| format!("Failed to parse {}", path.display())),
49 None => Ok(Self::default()),
50 }
51 }
52
53 pub fn path() -> anyhow::Result<PathBuf> {
55 let dir = directories::BaseDirs::new().context("Could not determine home directory")?;
56 Ok(dir.home_dir().join(".rab").join("agent").join("auth.json"))
57 }
58
59 pub fn api_key(&self, provider: &str) -> Option<String> {
61 self.0.get(provider).and_then(|cred| match cred {
62 AuthCredential::ApiKey { key } => Some(key.clone()),
63 AuthCredential::Oauth { .. } => None,
64 })
65 }
66
67 pub fn oauth_token(&self, provider: &str) -> Option<String> {
70 self.0.get(provider).and_then(|cred| match cred {
71 AuthCredential::Oauth {
72 access, expires, ..
73 } => {
74 if is_expired(*expires) {
75 return None;
76 }
77 Some(access.clone())
78 }
79 AuthCredential::ApiKey { .. } => None,
80 })
81 }
82
83 pub fn oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
86 self.0.get(provider).cloned().and_then(|cred| match cred {
87 AuthCredential::Oauth { .. } => Some(cred),
88 AuthCredential::ApiKey { .. } => None,
89 })
90 }
91
92 pub fn all_credentials(&self) -> &HashMap<String, AuthCredential> {
94 &self.0
95 }
96}
97
98fn with_exclusive_lock<T>(path: &PathBuf, f: impl FnOnce() -> T) -> T {
104 use fs2::FileExt;
105
106 if let Some(parent) = path.parent() {
108 let _ = std::fs::create_dir_all(parent);
109 }
110
111 let file = std::fs::OpenOptions::new()
114 .create(true)
115 .truncate(false)
116 .write(true)
117 .read(true)
118 .open(path)
119 .expect("Failed to open auth file");
120
121 let mut attempts = 0;
123 loop {
124 match file.try_lock_exclusive() {
125 Ok(()) => break,
126 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
127 attempts += 1;
128 if attempts >= 200 {
130 break; }
132 if attempts > 5
133 && let Ok(metadata) = path.metadata()
134 && let Ok(modified) = metadata.modified()
135 && let Ok(elapsed) = modified.elapsed()
136 && elapsed > Duration::from_secs(10)
137 {
138 let _ = file.unlock();
140 continue;
141 }
142 std::thread::sleep(Duration::from_millis(50));
143 }
144 Err(e) => panic!("Failed to lock auth file: {}", e),
145 }
146 }
147
148 let result = f();
149 let _ = file.unlock();
150 result
151}
152
153fn read_json_file(path: &PathBuf) -> anyhow::Result<Option<String>> {
156 if !path.exists() {
157 return Ok(None);
158 }
159 let mut s = String::new();
160 let mut file =
161 std::fs::File::open(path).with_context(|| format!("Failed to open {}", path.display()))?;
162 file.read_to_string(&mut s)
163 .with_context(|| format!("Failed to read {}", path.display()))?;
164 Ok(Some(s))
165}
166
167fn modify_auth_file(
169 path: &PathBuf,
170 f: impl FnOnce(HashMap<String, AuthCredential>) -> (HashMap<String, AuthCredential>, bool),
171) -> anyhow::Result<()> {
172 with_exclusive_lock(path, || {
173 let auth: HashMap<String, AuthCredential> = match read_json_file(path) {
174 Ok(Some(c)) => serde_json::from_str(&c).unwrap_or_default(),
175 _ => HashMap::new(),
176 };
177
178 let (result, changed) = f(auth);
179 if changed {
180 if let Some(parent) = path.parent() {
181 let _ = std::fs::create_dir_all(parent);
182 }
183 if let Ok(content) = serde_json::to_string_pretty(&result) {
184 let _ = std::fs::write(path, &content);
185 }
186 }
187 });
188 Ok(())
189}
190
/// Reports whether the expiry instant `expires` (milliseconds since the Unix
/// epoch) has been reached.
///
/// `None` means "no expiry recorded" and is never considered expired.
fn is_expired(expires: Option<i64>) -> bool {
    let Some(deadline) = expires else {
        return false;
    };
    // A clock set before the epoch falls back to a zero duration.
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    deadline <= since_epoch.as_millis() as i64
}
205
206pub fn login(provider: &str, api_key: &str) -> anyhow::Result<()> {
210 let path = AuthStorage::path()?;
211 let p = provider.to_string();
212 let k = api_key.to_string();
213 modify_auth_file(&path, |mut auth| {
214 auth.insert(p, AuthCredential::ApiKey { key: k });
215 (auth, true)
216 })
217}
218
219pub fn login_oauth(provider: &str, cred: &AuthCredential) -> anyhow::Result<()> {
221 let path = AuthStorage::path()?;
222 let p = provider.to_string();
223 let c = cred.clone();
224 modify_auth_file(&path, |mut auth| {
225 auth.insert(p, c);
226 (auth, true)
227 })
228}
229
230pub fn logout(provider: Option<&str>) -> anyhow::Result<bool> {
234 let path = AuthStorage::path()?;
235 if !path.exists() {
236 return Ok(false);
237 }
238
239 let result = with_exclusive_lock(&path, || -> bool {
240 let auth: HashMap<String, AuthCredential> = match read_json_file(&path) {
241 Ok(Some(c)) => serde_json::from_str(&c).unwrap_or_default(),
242 _ => return false,
243 };
244
245 let (new_auth, removed) = match provider {
246 Some(prov) => {
247 let mut a = auth;
248 let removed = a.remove(prov).is_some();
249 (a, removed)
250 }
251 None => {
252 let removed = !auth.is_empty();
253 (HashMap::new(), removed)
254 }
255 };
256
257 if removed {
258 if let Some(parent) = path.parent() {
259 let _ = std::fs::create_dir_all(parent);
260 }
261 if let Ok(content) = serde_json::to_string_pretty(&new_auth) {
262 let _ = std::fs::write(&path, &content);
263 }
264 }
265 removed
266 });
267
268 Ok(result)
269}
270
271pub fn list_logged_in() -> anyhow::Result<Vec<String>> {
273 let path = AuthStorage::path()?;
274 let content = read_json_file(&path)?;
275 match content {
276 Some(c) => {
277 let auth: HashMap<String, AuthCredential> = serde_json::from_str(&c)
278 .with_context(|| format!("Failed to parse {}", path.display()))?;
279 Ok(auth.keys().cloned().collect())
280 }
281 None => Ok(Vec::new()),
282 }
283}
284
285pub fn read_credential(provider: &str) -> anyhow::Result<Option<AuthCredential>> {
289 let path = AuthStorage::path()?;
290 let content = read_json_file(&path)?;
291 match content {
292 Some(c) => {
293 let auth: HashMap<String, AuthCredential> = serde_json::from_str(&c)
294 .with_context(|| format!("Failed to parse {}", path.display()))?;
295 Ok(auth.get(provider).cloned())
296 }
297 None => Ok(None),
298 }
299}
300
301pub fn modify_credential(
305 provider: &str,
306 f: impl FnOnce(Option<AuthCredential>) -> Option<AuthCredential>,
307) -> anyhow::Result<()> {
308 let path = AuthStorage::path()?;
309 let p = provider.to_string();
310 modify_auth_file(&path, |auth| {
311 let current = auth.get(&p).cloned();
312 let next = f(current);
313 let mut updated = auth;
314 match next {
315 Some(cred) => {
316 updated.insert(p, cred);
317 }
318 None => {
319 updated.remove(&p);
320 }
321 }
322 (updated, true)
323 })
324}
325
326pub async fn refresh_oauth_token(provider: &str) -> Option<String> {
330 let credential = read_credential(provider).ok()??;
331 let oauth_cred = match &credential {
332 AuthCredential::Oauth { .. } => credential,
333 _ => return None,
334 };
335 let expires = match &oauth_cred {
336 AuthCredential::Oauth { expires, .. } => *expires,
337 _ => return None,
338 };
339
340 if !is_expired(Some(expires.unwrap_or(i64::MAX))) {
342 let buffer_ms = 300_000;
343 if let AuthCredential::Oauth { access, .. } = &oauth_cred {
344 let now = std::time::SystemTime::now()
345 .duration_since(std::time::UNIX_EPOCH)
346 .unwrap_or_default()
347 .as_millis() as i64;
348 if now < expires.unwrap_or(i64::MAX) - buffer_ms {
349 return Some(access.clone());
350 }
351 }
352 }
353
354 let oauth_provider = crate::provider::oauth::get(provider)?;
355
356 let oauth_creds = match &oauth_cred {
358 AuthCredential::Oauth {
359 access,
360 refresh,
361 expires,
362 enterprise_url,
363 ..
364 } => crate::provider::oauth::OAuthCredentials {
365 access: access.clone(),
366 refresh: refresh.clone().unwrap_or_default(),
367 expires: expires.unwrap_or(0),
368 enterprise_url: enterprise_url.clone(),
369 extra: std::collections::HashMap::new(),
370 },
371 _ => return None,
372 };
373
374 let new_creds = oauth_provider.refresh_token(&oauth_creds).await.ok()?;
375 let new_access = new_creds.access.clone();
376
377 let result = modify_credential(provider, |_| {
379 Some(AuthCredential::Oauth {
380 access: new_creds.access.clone(),
381 refresh: Some(new_creds.refresh),
382 expires: Some(new_creds.expires),
383 enterprise_url: new_creds.enterprise_url,
384 })
385 });
386
387 match result {
388 Ok(_) => Some(new_access),
389 Err(_) => None,
390 }
391}