Skip to main content

rab/
auth.rs

1//! Auth storage — read/write `~/.rab/agent/auth.json`.
2//!
3//! Pi-compatible credential store with file locking and OAuth auto-refresh.
4//!
5//! Format (pi-compatible):
6//! ```json
7//! { "opencode-go": { "type": "api_key", "key": "sk-..." } }
8//! ```
9
10use anyhow::Context;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::io::Read;
14use std::path::PathBuf;
15use std::time::Duration;
16
17/// Credential for a provider (mirrors pi's auth.json schema).
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum AuthCredential {
21    #[serde(rename = "api_key")]
22    ApiKey { key: String },
23    #[serde(rename = "oauth")]
24    Oauth {
25        access: String,
26        refresh: Option<String>,
27        expires: Option<i64>,
28        #[serde(rename = "enterpriseUrl")]
29        enterprise_url: Option<String>,
30    },
31}
32
33/// Auth storage loaded from ~/.rab/auth.json.
34#[derive(Debug, Clone, Default, Deserialize)]
35pub struct AuthStorage(HashMap<String, AuthCredential>);
36
37impl AuthStorage {
38    /// Load auth from `~/.rab/agent/auth.json`. Returns empty if file doesn't exist.
39    pub fn load() -> anyhow::Result<Self> {
40        Self::load_from(Self::path()?)
41    }
42
43    /// Load auth from an explicit path (for testing).
44    pub fn load_from(path: std::path::PathBuf) -> anyhow::Result<Self> {
45        let content = read_json_file(&path)?;
46        match content {
47            Some(c) => serde_json::from_str(&c)
48                .with_context(|| format!("Failed to parse {}", path.display())),
49            None => Ok(Self::default()),
50        }
51    }
52
53    /// Get the path to the auth file.
54    pub fn path() -> anyhow::Result<PathBuf> {
55        let dir = directories::BaseDirs::new().context("Could not determine home directory")?;
56        Ok(dir.home_dir().join(".rab").join("agent").join("auth.json"))
57    }
58
59    /// Get the API key for a provider. Returns None if not configured or if OAuth.
60    pub fn api_key(&self, provider: &str) -> Option<String> {
61        self.0.get(provider).and_then(|cred| match cred {
62            AuthCredential::ApiKey { key } => Some(key.clone()),
63            AuthCredential::Oauth { .. } => None,
64        })
65    }
66
67    /// Get the OAuth access token for a provider.
68    /// Returns None if not configured, if API key, or if the token is expired.
69    pub fn oauth_token(&self, provider: &str) -> Option<String> {
70        self.0.get(provider).and_then(|cred| match cred {
71            AuthCredential::Oauth {
72                access, expires, ..
73            } => {
74                if is_expired(*expires) {
75                    return None;
76                }
77                Some(access.clone())
78            }
79            AuthCredential::ApiKey { .. } => None,
80        })
81    }
82
83    /// Get the stored credential for a provider, if it's an OAuth credential.
84    /// Returns None for API key credentials or missing entries.
85    pub fn oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
86        self.0.get(provider).cloned().and_then(|cred| match cred {
87            AuthCredential::Oauth { .. } => Some(cred),
88            AuthCredential::ApiKey { .. } => None,
89        })
90    }
91
92    /// Get all stored credentials.
93    pub fn all_credentials(&self) -> &HashMap<String, AuthCredential> {
94        &self.0
95    }
96}
97
98// ── File locking helpers ────────────────────────────────────────
99
100/// Acquire an exclusive file lock (blocking with retry) and run the closure.
101/// Uses `fs2::FileExt::try_lock_exclusive` for cross-process safety,
102/// matching pi's `proper-lockfile` pattern.
103fn with_exclusive_lock<T>(path: &PathBuf, f: impl FnOnce() -> T) -> T {
104    use fs2::FileExt;
105
106    // Ensure parent dir exists
107    if let Some(parent) = path.parent() {
108        let _ = std::fs::create_dir_all(parent);
109    }
110
111    // Open or create the auth file itself (no truncate — we're just getting
112    // a fd for locking; actual reads/writes are done separately).
113    let file = std::fs::OpenOptions::new()
114        .create(true)
115        .truncate(false)
116        .write(true)
117        .read(true)
118        .open(path)
119        .expect("Failed to open auth file");
120
121    // Retry loop for lock acquisition (pi-compatible)
122    let mut attempts = 0;
123    loop {
124        match file.try_lock_exclusive() {
125            Ok(()) => break,
126            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
127                attempts += 1;
128                // Stale lock detection: if the lock is held for >10s, it's stale
129                if attempts >= 200 {
130                    break; // Give up and proceed anyway
131                }
132                if attempts > 5
133                    && let Ok(metadata) = path.metadata()
134                    && let Ok(modified) = metadata.modified()
135                    && let Ok(elapsed) = modified.elapsed()
136                    && elapsed > Duration::from_secs(10)
137                {
138                    // Stale lock — break it by unlocking and retrying
139                    let _ = file.unlock();
140                    continue;
141                }
142                std::thread::sleep(Duration::from_millis(50));
143            }
144            Err(e) => panic!("Failed to lock auth file: {}", e),
145        }
146    }
147
148    let result = f();
149    let _ = file.unlock();
150    result
151}
152
153/// Read JSON from a file (no locking — caller should use exclusive lock for writes).
154/// Returns None if the file doesn't exist.
155fn read_json_file(path: &PathBuf) -> anyhow::Result<Option<String>> {
156    if !path.exists() {
157        return Ok(None);
158    }
159    let mut s = String::new();
160    let mut file =
161        std::fs::File::open(path).with_context(|| format!("Failed to open {}", path.display()))?;
162    file.read_to_string(&mut s)
163        .with_context(|| format!("Failed to read {}", path.display()))?;
164    Ok(Some(s))
165}
166
167/// Read-modify-write the auth file under an exclusive lock.
168fn modify_auth_file(
169    path: &PathBuf,
170    f: impl FnOnce(HashMap<String, AuthCredential>) -> (HashMap<String, AuthCredential>, bool),
171) -> anyhow::Result<()> {
172    with_exclusive_lock(path, || {
173        let auth: HashMap<String, AuthCredential> = match read_json_file(path) {
174            Ok(Some(c)) => serde_json::from_str(&c).unwrap_or_default(),
175            _ => HashMap::new(),
176        };
177
178        let (result, changed) = f(auth);
179        if changed {
180            if let Some(parent) = path.parent() {
181                let _ = std::fs::create_dir_all(parent);
182            }
183            if let Ok(content) = serde_json::to_string_pretty(&result) {
184                let _ = std::fs::write(path, &content);
185            }
186        }
187    });
188    Ok(())
189}
190
191// ── Helper ────────────────────────────────────────────────────
192
/// Whether an expiry timestamp (Unix epoch milliseconds) is at or before the current time.
///
/// `None` means the credential carries no expiry and is never considered expired. If the
/// system clock reads earlier than the epoch, "now" is taken as 0.
fn is_expired(expires: Option<i64>) -> bool {
    let Some(deadline_ms) = expires else {
        return false;
    };
    let now_ms = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis() as i64)
        .unwrap_or(0);
    deadline_ms <= now_ms
}
205
206// ── Write operations ─────────────────────────────────────────────
207
208/// Login a provider by storing its API key in auth.json.
209pub fn login(provider: &str, api_key: &str) -> anyhow::Result<()> {
210    let path = AuthStorage::path()?;
211    let p = provider.to_string();
212    let k = api_key.to_string();
213    modify_auth_file(&path, |mut auth| {
214        auth.insert(p, AuthCredential::ApiKey { key: k });
215        (auth, true)
216    })
217}
218
219/// Login a provider by storing its OAuth credentials in auth.json.
220pub fn login_oauth(provider: &str, cred: &AuthCredential) -> anyhow::Result<()> {
221    let path = AuthStorage::path()?;
222    let p = provider.to_string();
223    let c = cred.clone();
224    modify_auth_file(&path, |mut auth| {
225        auth.insert(p, c);
226        (auth, true)
227    })
228}
229
230/// Logout a provider by removing its credential from auth.json.
231/// If `provider` is `None`, clears all credentials.
232/// Returns true if something was actually removed.
233pub fn logout(provider: Option<&str>) -> anyhow::Result<bool> {
234    let path = AuthStorage::path()?;
235    if !path.exists() {
236        return Ok(false);
237    }
238
239    let result = with_exclusive_lock(&path, || -> bool {
240        let auth: HashMap<String, AuthCredential> = match read_json_file(&path) {
241            Ok(Some(c)) => serde_json::from_str(&c).unwrap_or_default(),
242            _ => return false,
243        };
244
245        let (new_auth, removed) = match provider {
246            Some(prov) => {
247                let mut a = auth;
248                let removed = a.remove(prov).is_some();
249                (a, removed)
250            }
251            None => {
252                let removed = !auth.is_empty();
253                (HashMap::new(), removed)
254            }
255        };
256
257        if removed {
258            if let Some(parent) = path.parent() {
259                let _ = std::fs::create_dir_all(parent);
260            }
261            if let Ok(content) = serde_json::to_string_pretty(&new_auth) {
262                let _ = std::fs::write(&path, &content);
263            }
264        }
265        removed
266    });
267
268    Ok(result)
269}
270
271/// List all providers that have credentials stored.
272pub fn list_logged_in() -> anyhow::Result<Vec<String>> {
273    let path = AuthStorage::path()?;
274    let content = read_json_file(&path)?;
275    match content {
276        Some(c) => {
277            let auth: HashMap<String, AuthCredential> = serde_json::from_str(&c)
278                .with_context(|| format!("Failed to parse {}", path.display()))?;
279            Ok(auth.keys().cloned().collect())
280        }
281        None => Ok(Vec::new()),
282    }
283}
284
285// ── Enhanced credential read ──────────────────────────────────────
286
287/// Read a credential from auth.json. Returns None if the provider has no stored credential.
288pub fn read_credential(provider: &str) -> anyhow::Result<Option<AuthCredential>> {
289    let path = AuthStorage::path()?;
290    let content = read_json_file(&path)?;
291    match content {
292        Some(c) => {
293            let auth: HashMap<String, AuthCredential> = serde_json::from_str(&c)
294                .with_context(|| format!("Failed to parse {}", path.display()))?;
295            Ok(auth.get(provider).cloned())
296        }
297        None => Ok(None),
298    }
299}
300
301/// Atomically modify a single provider's credential (pi-compatible `CredentialStore.modify()`).
302/// `f` receives the current credential (None if missing), returns the new
303/// credential, or None to delete the entry.
304pub fn modify_credential(
305    provider: &str,
306    f: impl FnOnce(Option<AuthCredential>) -> Option<AuthCredential>,
307) -> anyhow::Result<()> {
308    let path = AuthStorage::path()?;
309    let p = provider.to_string();
310    modify_auth_file(&path, |auth| {
311        let current = auth.get(&p).cloned();
312        let next = f(current);
313        let mut updated = auth;
314        match next {
315            Some(cred) => {
316                updated.insert(p, cred);
317            }
318            None => {
319                updated.remove(&p);
320            }
321        }
322        (updated, true)
323    })
324}
325
326/// Refresh an expired OAuth token using the registered OAuth provider.
327/// Returns the new access token string, or None if refresh fails.
328/// Matching pi's `AuthStorage.refreshOAuthTokenWithLock()` pattern.
329pub async fn refresh_oauth_token(provider: &str) -> Option<String> {
330    let credential = read_credential(provider).ok()??;
331    let oauth_cred = match &credential {
332        AuthCredential::Oauth { .. } => credential,
333        _ => return None,
334    };
335    let expires = match &oauth_cred {
336        AuthCredential::Oauth { expires, .. } => *expires,
337        _ => return None,
338    };
339
340    // If token is still valid for more than 5 minutes, return current access token
341    if !is_expired(Some(expires.unwrap_or(i64::MAX))) {
342        let buffer_ms = 300_000;
343        if let AuthCredential::Oauth { access, .. } = &oauth_cred {
344            let now = std::time::SystemTime::now()
345                .duration_since(std::time::UNIX_EPOCH)
346                .unwrap_or_default()
347                .as_millis() as i64;
348            if now < expires.unwrap_or(i64::MAX) - buffer_ms {
349                return Some(access.clone());
350            }
351        }
352    }
353
354    let oauth_provider = crate::provider::oauth::get(provider)?;
355
356    // Build OAuthCredentials for the refresh call
357    let oauth_creds = match &oauth_cred {
358        AuthCredential::Oauth {
359            access,
360            refresh,
361            expires,
362            enterprise_url,
363            ..
364        } => crate::provider::oauth::OAuthCredentials {
365            access: access.clone(),
366            refresh: refresh.clone().unwrap_or_default(),
367            expires: expires.unwrap_or(0),
368            enterprise_url: enterprise_url.clone(),
369            extra: std::collections::HashMap::new(),
370        },
371        _ => return None,
372    };
373
374    let new_creds = oauth_provider.refresh_token(&oauth_creds).await.ok()?;
375    let new_access = new_creds.access.clone();
376
377    // Store updated credentials under file lock
378    let result = modify_credential(provider, |_| {
379        Some(AuthCredential::Oauth {
380            access: new_creds.access.clone(),
381            refresh: Some(new_creds.refresh),
382            expires: Some(new_creds.expires),
383            enterprise_url: new_creds.enterprise_url,
384        })
385    });
386
387    match result {
388        Ok(_) => Some(new_access),
389        Err(_) => None,
390    }
391}