Skip to main content

codineer_runtime/
oauth.rs

1use std::collections::BTreeMap;
2use std::fs;
3use std::io;
4use std::path::PathBuf;
5
6use serde::{Deserialize, Serialize};
7use serde_json::{Map, Value};
8use sha2::{Digest, Sha256};
9
10use crate::config::OAuthConfig;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13pub struct OAuthTokenSet {
14    pub access_token: String,
15    pub refresh_token: Option<String>,
16    pub expires_at: Option<u64>,
17    pub scopes: Vec<String>,
18}
19
/// A PKCE verifier together with the challenge derived from it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PkceCodePair {
    /// Secret value later sent as `code_verifier` in the token exchange.
    pub verifier: String,
    /// Value derived from `verifier`, sent as `code_challenge` when authorizing.
    pub challenge: String,
    /// Transformation used to derive `challenge` from `verifier`.
    pub challenge_method: PkceChallengeMethod,
}

/// PKCE code-challenge methods supported by this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PkceChallengeMethod {
    /// SHA-256 of the verifier, base64url-encoded without padding.
    S256,
}

impl PkceChallengeMethod {
    /// Wire name of the method, as used for the `code_challenge_method` parameter.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            PkceChallengeMethod::S256 => "S256",
        }
    }
}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
42pub struct OAuthAuthorizationRequest {
43    pub authorize_url: String,
44    pub client_id: String,
45    pub redirect_uri: String,
46    pub scopes: Vec<String>,
47    pub state: String,
48    pub code_challenge: String,
49    pub code_challenge_method: PkceChallengeMethod,
50    pub extra_params: BTreeMap<String, String>,
51}
52
53#[derive(Debug, Clone, PartialEq, Eq)]
54pub struct OAuthTokenExchangeRequest {
55    pub grant_type: &'static str,
56    pub code: String,
57    pub redirect_uri: String,
58    pub client_id: String,
59    pub code_verifier: String,
60    pub state: String,
61}
62
63#[derive(Debug, Clone, PartialEq, Eq)]
64pub struct OAuthRefreshRequest {
65    pub grant_type: &'static str,
66    pub refresh_token: String,
67    pub client_id: String,
68    pub scopes: Vec<String>,
69}
70
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct OAuthCallbackParams {
73    pub code: Option<String>,
74    pub state: Option<String>,
75    pub error: Option<String>,
76    pub error_description: Option<String>,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
80#[serde(rename_all = "camelCase")]
81struct StoredOAuthCredentials {
82    access_token: String,
83    #[serde(default)]
84    refresh_token: Option<String>,
85    #[serde(default)]
86    expires_at: Option<u64>,
87    #[serde(default)]
88    scopes: Vec<String>,
89}
90
91impl From<OAuthTokenSet> for StoredOAuthCredentials {
92    fn from(value: OAuthTokenSet) -> Self {
93        Self {
94            access_token: value.access_token,
95            refresh_token: value.refresh_token,
96            expires_at: value.expires_at,
97            scopes: value.scopes,
98        }
99    }
100}
101
102impl From<StoredOAuthCredentials> for OAuthTokenSet {
103    fn from(value: StoredOAuthCredentials) -> Self {
104        Self {
105            access_token: value.access_token,
106            refresh_token: value.refresh_token,
107            expires_at: value.expires_at,
108            scopes: value.scopes,
109        }
110    }
111}
112
113impl OAuthAuthorizationRequest {
114    #[must_use]
115    pub fn from_config(
116        config: &OAuthConfig,
117        redirect_uri: impl Into<String>,
118        state: impl Into<String>,
119        pkce: &PkceCodePair,
120    ) -> Self {
121        Self {
122            authorize_url: config.authorize_url.clone(),
123            client_id: config.client_id.clone(),
124            redirect_uri: redirect_uri.into(),
125            scopes: config.scopes.clone(),
126            state: state.into(),
127            code_challenge: pkce.challenge.clone(),
128            code_challenge_method: pkce.challenge_method,
129            extra_params: BTreeMap::new(),
130        }
131    }
132
133    #[must_use]
134    pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
135        self.extra_params.insert(key.into(), value.into());
136        self
137    }
138
139    #[must_use]
140    pub fn build_url(&self) -> String {
141        let mut params = vec![
142            ("response_type", "code".to_string()),
143            ("client_id", self.client_id.clone()),
144            ("redirect_uri", self.redirect_uri.clone()),
145            ("scope", self.scopes.join(" ")),
146            ("state", self.state.clone()),
147            ("code_challenge", self.code_challenge.clone()),
148            (
149                "code_challenge_method",
150                self.code_challenge_method.as_str().to_string(),
151            ),
152        ];
153        params.extend(
154            self.extra_params
155                .iter()
156                .map(|(key, value)| (key.as_str(), value.clone())),
157        );
158        let query = params
159            .into_iter()
160            .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
161            .collect::<Vec<_>>()
162            .join("&");
163        format!(
164            "{}{}{}",
165            self.authorize_url,
166            if self.authorize_url.contains('?') {
167                '&'
168            } else {
169                '?'
170            },
171            query
172        )
173    }
174}
175
176impl OAuthTokenExchangeRequest {
177    #[must_use]
178    pub fn from_config(
179        config: &OAuthConfig,
180        code: impl Into<String>,
181        state: impl Into<String>,
182        verifier: impl Into<String>,
183        redirect_uri: impl Into<String>,
184    ) -> Self {
185        Self {
186            grant_type: "authorization_code",
187            code: code.into(),
188            redirect_uri: redirect_uri.into(),
189            client_id: config.client_id.clone(),
190            code_verifier: verifier.into(),
191            state: state.into(),
192        }
193    }
194
195    #[must_use]
196    pub fn form_params(&self) -> BTreeMap<&str, String> {
197        BTreeMap::from([
198            ("grant_type", self.grant_type.to_string()),
199            ("code", self.code.clone()),
200            ("redirect_uri", self.redirect_uri.clone()),
201            ("client_id", self.client_id.clone()),
202            ("code_verifier", self.code_verifier.clone()),
203            ("state", self.state.clone()),
204        ])
205    }
206}
207
208impl OAuthRefreshRequest {
209    #[must_use]
210    pub fn from_config(
211        config: &OAuthConfig,
212        refresh_token: impl Into<String>,
213        scopes: Option<Vec<String>>,
214    ) -> Self {
215        Self {
216            grant_type: "refresh_token",
217            refresh_token: refresh_token.into(),
218            client_id: config.client_id.clone(),
219            scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
220        }
221    }
222
223    #[must_use]
224    pub fn form_params(&self) -> BTreeMap<&str, String> {
225        BTreeMap::from([
226            ("grant_type", self.grant_type.to_string()),
227            ("refresh_token", self.refresh_token.clone()),
228            ("client_id", self.client_id.clone()),
229            ("scope", self.scopes.join(" ")),
230        ])
231    }
232}
233
234pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
235    let verifier = generate_random_token(32)?;
236    Ok(PkceCodePair {
237        challenge: code_challenge_s256(&verifier),
238        verifier,
239        challenge_method: PkceChallengeMethod::S256,
240    })
241}
242
243pub fn generate_state() -> io::Result<String> {
244    generate_random_token(32)
245}
246
247#[must_use]
248pub fn code_challenge_s256(verifier: &str) -> String {
249    let digest = Sha256::digest(verifier.as_bytes());
250    base64url_encode(&digest)
251}
252
/// Builds the local redirect URI `http://localhost:<port>/callback`.
#[must_use]
pub fn loopback_redirect_uri(port: u16) -> String {
    let mut uri = String::from("http://localhost:");
    uri.push_str(&port.to_string());
    uri.push_str("/callback");
    uri
}
257
258pub fn credentials_path() -> io::Result<PathBuf> {
259    Ok(credentials_home_dir()?.join("credentials.json"))
260}
261
262const KEYRING_SERVICE: &str = "codineer";
263const KEYRING_USER: &str = "oauth";
264
265fn keyring_entry() -> Option<keyring::Entry> {
266    keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).ok()
267}
268
269fn load_from_keyring() -> Option<OAuthTokenSet> {
270    let entry = keyring_entry()?;
271    let json = entry.get_password().ok()?;
272    let stored: StoredOAuthCredentials = serde_json::from_str(&json).ok()?;
273    Some(stored.into())
274}
275
276fn save_to_keyring(token_set: &OAuthTokenSet) -> bool {
277    let Some(entry) = keyring_entry() else {
278        return false;
279    };
280    let stored = StoredOAuthCredentials::from(token_set.clone());
281    let Ok(json) = serde_json::to_string(&stored) else {
282        return false;
283    };
284    entry.set_password(&json).is_ok()
285}
286
287fn clear_from_keyring() {
288    if let Some(entry) = keyring_entry() {
289        let _ = entry.delete_credential();
290    }
291}
292
293pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
294    if let Some(token_set) = load_from_keyring() {
295        return Ok(Some(token_set));
296    }
297
298    let path = match credentials_path() {
299        Ok(path) => path,
300        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
301        Err(error) => return Err(error),
302    };
303    let root = read_credentials_root(&path)?;
304    let Some(oauth) = root.get("oauth") else {
305        return Ok(None);
306    };
307    if oauth.is_null() {
308        return Ok(None);
309    }
310    let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
311        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
312    let token_set: OAuthTokenSet = stored.into();
313
314    if save_to_keyring(&token_set) {
315        let mut migrated_root = root;
316        migrated_root.remove("oauth");
317        let _ = write_credentials_root(&path, &migrated_root);
318    }
319
320    Ok(Some(token_set))
321}
322
323pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
324    if save_to_keyring(token_set) {
325        return Ok(());
326    }
327
328    let path = credentials_path()?;
329    let mut root = read_credentials_root(&path)?;
330    root.insert(
331        "oauth".to_string(),
332        serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
333            .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
334    );
335    write_credentials_root(&path, &root)
336}
337
338pub fn clear_oauth_credentials() -> io::Result<()> {
339    clear_from_keyring();
340
341    let path = credentials_path()?;
342    let mut root = read_credentials_root(&path)?;
343    root.remove("oauth");
344    write_credentials_root(&path, &root)
345}
346
347pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
348    let (path, query) = target
349        .split_once('?')
350        .map_or((target, ""), |(path, query)| (path, query));
351    if path != "/callback" {
352        return Err(format!("unexpected callback path: {path}"));
353    }
354    parse_oauth_callback_query(query)
355}
356
357pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
358    let mut params = BTreeMap::new();
359    for pair in query.split('&').filter(|pair| !pair.is_empty()) {
360        let (key, value) = pair
361            .split_once('=')
362            .map_or((pair, ""), |(key, value)| (key, value));
363        params.insert(percent_decode(key)?, percent_decode(value)?);
364    }
365    Ok(OAuthCallbackParams {
366        code: params.get("code").cloned(),
367        state: params.get("state").cloned(),
368        error: params.get("error").cloned(),
369        error_description: params.get("error_description").cloned(),
370    })
371}
372
373fn generate_random_token(bytes: usize) -> io::Result<String> {
374    let mut buffer = vec![0_u8; bytes];
375    getrandom::getrandom(&mut buffer).map_err(|e| io::Error::other(e.to_string()))?;
376    Ok(base64url_encode(&buffer))
377}
378
/// Resolves the directory that holds `credentials.json`.
///
/// Precedence: `CODINEER_CONFIG_HOME` as-is, then `$HOME/.codineer`, then
/// `%USERPROFILE%/.codineer`. Variables that are set but empty are treated as
/// unset. An empty base path would otherwise place the token file relative to the
/// current working directory.
///
/// # Errors
/// Returns `NotFound` when none of the variables yields a usable path.
fn credentials_home_dir() -> io::Result<PathBuf> {
    let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

    if let Some(path) = non_empty("CODINEER_CONFIG_HOME") {
        return Ok(PathBuf::from(path));
    }
    ["HOME", "USERPROFILE"]
        .iter()
        .copied()
        .find_map(non_empty)
        .map(|home| PathBuf::from(home).join(".codineer"))
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                "home directory not found (neither HOME nor USERPROFILE is set)",
            )
        })
}
393
394fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
395    match fs::read_to_string(path) {
396        Ok(contents) => {
397            if contents.trim().is_empty() {
398                return Ok(Map::new());
399            }
400            serde_json::from_str::<Value>(&contents)
401                .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
402                .as_object()
403                .cloned()
404                .ok_or_else(|| {
405                    io::Error::new(
406                        io::ErrorKind::InvalidData,
407                        "credentials file must contain a JSON object",
408                    )
409                })
410        }
411        Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
412        Err(error) => Err(error),
413    }
414}
415
416fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
417    if let Some(parent) = path.parent() {
418        fs::create_dir_all(parent)?;
419    }
420    let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
421        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
422    let temp_path = path.with_extension("json.tmp");
423    fs::write(&temp_path, format!("{rendered}\n"))?;
424    set_file_permissions_owner_only(&temp_path);
425    fs::rename(temp_path, path)
426}
427
/// Best-effort: restricts `path` to owner read/write (`0o600`). Errors are ignored.
#[cfg(unix)]
fn set_file_permissions_owner_only(path: &std::path::Path) {
    use std::os::unix::fs::PermissionsExt;
    let owner_only = fs::Permissions::from_mode(0o600);
    let _outcome = fs::set_permissions(path, owner_only);
}

/// No-op on non-Unix platforms, which have no Unix permission bits.
#[cfg(not(unix))]
fn set_file_permissions_owner_only(_path: &std::path::Path) {}
436
/// Encodes `bytes` as unpadded base64url (alphabet `A-Z a-z 0-9 - _`).
fn base64url_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    let mut encoded = String::with_capacity(bytes.len() / 3 * 4 + 4);
    for chunk in bytes.chunks(3) {
        let first = u32::from(chunk[0]);
        let second = chunk.get(1).copied().map_or(0, u32::from);
        let third = chunk.get(2).copied().map_or(0, u32::from);
        let group = (first << 16) | (second << 8) | third;
        // A chunk of n bytes carries n + 1 significant 6-bit groups; no padding is emitted.
        for sextet in 0..=chunk.len() {
            let shift = 18 - 6 * sextet;
            encoded.push(char::from(ALPHABET[((group >> shift) & 0x3F) as usize]));
        }
    }
    encoded
}
467
/// Percent-encodes every byte of `value` except the unreserved set
/// `A-Z a-z 0-9 - _ . ~`, using uppercase hex (a space becomes `%20`).
fn percent_encode(value: &str) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(value.len());
    for byte in value.bytes() {
        let unreserved = byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            let _ = write!(out, "%{byte:02X}");
        }
    }
    out
}
483
/// Decodes `%XX` escapes and `+` (as space) and requires the result to be UTF-8.
///
/// A `%` without two following bytes is kept literally. A `%` followed by two bytes
/// that are not both hex digits is an error.
fn percent_decode(value: &str) -> Result<String, String> {
    let input = value.as_bytes();
    let mut output = Vec::with_capacity(input.len());
    let mut cursor = 0;
    while let Some(&current) = input.get(cursor) {
        if current == b'%' && cursor + 2 < input.len() {
            let high = decode_hex(input[cursor + 1])?;
            let low = decode_hex(input[cursor + 2])?;
            output.push((high << 4) | low);
            cursor += 3;
            continue;
        }
        output.push(if current == b'+' { b' ' } else { current });
        cursor += 1;
    }
    String::from_utf8(output).map_err(|error| error.to_string())
}

/// Value of a single ASCII hex digit (either case).
fn decode_hex(byte: u8) -> Result<u8, String> {
    char::from(byte)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
        .ok_or_else(|| format!("invalid percent-encoding byte: {byte}"))
}
517
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use super::{
        clear_from_keyring, clear_oauth_credentials, code_challenge_s256, credentials_path,
        generate_pkce_pair, generate_state, load_from_keyring, load_oauth_credentials,
        loopback_redirect_uri, parse_oauth_callback_query, parse_oauth_callback_request_target,
        save_oauth_credentials, OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest,
        OAuthTokenExchangeRequest, OAuthTokenSet,
    };

    /// Fixed OAuth configuration pointing at a fake `console.test` host.
    fn sample_config() -> OAuthConfig {
        OAuthConfig {
            client_id: "runtime-client".to_string(),
            authorize_url: "https://console.test/oauth/authorize".to_string(),
            token_url: "https://console.test/oauth/token".to_string(),
            callback_port: Some(4545),
            manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
            scopes: vec!["org:read".to_string(), "user:write".to_string()],
        }
    }

    /// Crate-wide lock serializing tests that mutate process environment variables.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        crate::test_env_lock()
    }

    /// Unique temp directory path (pid + nanosecond timestamp); not created here.
    fn temp_config_home() -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "runtime-oauth-test-{}-{}",
            std::process::id(),
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("time")
                .as_nanos()
        ))
    }

    // Verifier/challenge pair appears to be the RFC 7636 Appendix B example vector.
    #[test]
    fn s256_challenge_matches_expected_vector() {
        assert_eq!(
            code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn generates_pkce_pair_and_state() {
        let pair = generate_pkce_pair().expect("pkce pair");
        let state = generate_state().expect("state");
        assert!(!pair.verifier.is_empty());
        assert!(!pair.challenge.is_empty());
        assert!(!state.is_empty());
    }

    // Checks URL encoding of scopes/extra params and the grant types / scope
    // joining of the token-exchange and refresh form bodies.
    #[test]
    fn builds_authorize_url_and_form_requests() {
        let config = sample_config();
        let pair = generate_pkce_pair().expect("pkce");
        let url = OAuthAuthorizationRequest::from_config(
            &config,
            loopback_redirect_uri(4545),
            "state-123",
            &pair,
        )
        .with_extra_param("login_hint", "user@example.com")
        .build_url();
        assert!(url.starts_with("https://console.test/oauth/authorize?"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=runtime-client"));
        assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
        assert!(url.contains("login_hint=user%40example.com"));

        let exchange = OAuthTokenExchangeRequest::from_config(
            &config,
            "auth-code",
            "state-123",
            pair.verifier,
            loopback_redirect_uri(4545),
        );
        assert_eq!(
            exchange.form_params().get("grant_type").map(String::as_str),
            Some("authorization_code")
        );

        let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
        assert_eq!(
            refresh.form_params().get("scope").map(String::as_str),
            Some("org:read user:write")
        );
    }

    // Round-trips credentials through whichever backend is available (keyring
    // or file) and verifies unrelated file keys survive save and clear.
    // NOTE(review): on assertion failure the env var and temp dir are not cleaned up.
    #[test]
    fn oauth_credentials_round_trip_and_clear() {
        let _guard = env_lock();
        let config_home = temp_config_home();
        std::env::set_var("CODINEER_CONFIG_HOME", &config_home);
        let path = credentials_path().expect("credentials path");
        std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
        // Seed an unrelated key that must be preserved by every rewrite.
        std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");

        let token_set = OAuthTokenSet {
            access_token: "access-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            expires_at: Some(123),
            scopes: vec!["scope:a".to_string()],
        };
        save_oauth_credentials(&token_set).expect("save credentials");
        assert_eq!(
            load_oauth_credentials().expect("load credentials"),
            Some(token_set)
        );

        // Without a working keyring the token set must have landed in the file.
        let keyring_available = load_from_keyring().is_some();
        let saved = std::fs::read_to_string(&path).expect("read saved file");
        assert!(saved.contains("\"other\""));
        if !keyring_available {
            assert!(saved.contains("\"oauth\""));
        }

        clear_oauth_credentials().expect("clear credentials");
        assert_eq!(load_oauth_credentials().expect("load cleared"), None);
        let cleared = std::fs::read_to_string(&path).expect("read cleared file");
        assert!(cleared.contains("\"other\""));
        assert!(!cleared.contains("\"oauth\""));

        // Cleanup of shared state: keyring entry, env var, temp directory.
        clear_from_keyring();
        std::env::remove_var("CODINEER_CONFIG_HOME");
        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
    }

    #[test]
    fn parses_callback_query_and_target() {
        let params =
            parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
                .expect("parse query");
        assert_eq!(params.code.as_deref(), Some("abc123"));
        assert_eq!(params.state.as_deref(), Some("state-1"));
        assert_eq!(params.error_description.as_deref(), Some("needs login"));

        let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
            .expect("parse callback target");
        assert_eq!(params.code.as_deref(), Some("abc"));
        assert_eq!(params.state.as_deref(), Some("xyz"));
        // Any path other than `/callback` is rejected.
        assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
    }
}