Skip to main content

via/auth/
oauth.rs

1use std::time::{SystemTime, UNIX_EPOCH};
2
3use reqwest::blocking::Client;
4use reqwest::header::CONTENT_TYPE;
5use ring::digest::{Context, SHA256};
6use serde::Deserialize;
7use serde_json::Value;
8
9use crate::error::ViaError;
10use crate::redaction::Redactor;
11use crate::secrets::SecretValue;
12
/// Safety margin, in seconds, applied when deciding whether a cached access token may still be
/// reused: a token whose expiry falls within this window of "now" is treated as expired.
const CACHE_EXPIRY_SKEW_SECONDS: i64 = 60;
/// The only credential bundle `type` value accepted by [`validate_credential_type`].
const SERVICE_OAUTH_TYPE: &str = "service_oauth";
15
16pub fn access_token(credential: &SecretValue, redactor: &mut Redactor) -> Result<String, ViaError> {
17    access_token_with_mode(credential, redactor, crate::daemon::OAuthTokenMode::Cached)
18}
19
20pub fn refresh_access_token(
21    credential: &SecretValue,
22    redactor: &mut Redactor,
23) -> Result<String, ViaError> {
24    access_token_with_mode(credential, redactor, crate::daemon::OAuthTokenMode::Refresh)
25}
26
27fn access_token_with_mode(
28    credential: &SecretValue,
29    redactor: &mut Redactor,
30    mode: crate::daemon::OAuthTokenMode,
31) -> Result<String, ViaError> {
32    redactor.add(credential.expose());
33    let bundle = CredentialBundle::parse(credential.expose())?;
34    register_bundle_secrets(&bundle, redactor);
35
36    let token = crate::daemon::oauth_access_token(credential.expose(), mode)?;
37    redactor.add(token.expose());
38    Ok(token.expose().to_owned())
39}
40
41pub fn validate_credential_bundle(raw: &str) -> Result<(), ViaError> {
42    CredentialBundle::parse(raw).map(|_| ())
43}
44
45pub(crate) fn exchange_access_token(
46    client: &Client,
47    bundle: &CredentialBundle,
48    cached: Option<&CachedOAuthToken>,
49    redactor: &mut Redactor,
50) -> Result<OAuthAccessToken, ViaError> {
51    match &bundle.grant {
52        OAuthGrant::RefreshToken { refresh_token } => {
53            let cached_refresh_token = cached.and_then(|cached| cached.refresh_token.as_deref());
54            let refresh_token_for_request = cached_refresh_token.unwrap_or(refresh_token);
55            match exchange_refresh_token(client, bundle, refresh_token_for_request, redactor) {
56                Ok(token) => Ok(token),
57                Err(_error)
58                    if cached_refresh_token.is_some_and(|cached| cached != refresh_token) =>
59                {
60                    crate::timing::event(
61                        "oauth refresh token fallback",
62                        "cached_refresh_token_failed",
63                    );
64                    exchange_refresh_token(client, bundle, refresh_token, redactor)
65                }
66                Err(error) => Err(error),
67            }
68        }
69        OAuthGrant::ClientCredentials { .. } => {
70            exchange_client_credentials(client, bundle, redactor)
71        }
72    }
73}
74
75fn exchange_refresh_token(
76    client: &Client,
77    bundle: &CredentialBundle,
78    refresh_token: &str,
79    redactor: &mut Redactor,
80) -> Result<OAuthAccessToken, ViaError> {
81    redactor.add(refresh_token);
82    let mut form = vec![
83        ("grant_type", "refresh_token"),
84        ("refresh_token", refresh_token),
85        ("client_id", bundle.client_id.as_str()),
86    ];
87    if let Some(client_secret) = bundle.client_secret.as_deref() {
88        form.push(("client_secret", client_secret));
89    }
90
91    exchange_token_form(
92        client,
93        bundle,
94        &form,
95        TokenResponseRefreshMode::PreserveRefreshToken(refresh_token),
96        redactor,
97    )
98}
99
100fn exchange_client_credentials(
101    client: &Client,
102    bundle: &CredentialBundle,
103    redactor: &mut Redactor,
104) -> Result<OAuthAccessToken, ViaError> {
105    let OAuthGrant::ClientCredentials { scope } = &bundle.grant else {
106        unreachable!("caller only passes client_credentials grants");
107    };
108    let client_secret = bundle.client_secret.as_deref().ok_or_else(|| {
109        ViaError::InvalidConfig(
110            "oauth client_credentials credential bundle must include `client_secret`".to_owned(),
111        )
112    })?;
113
114    let form = vec![
115        ("grant_type", "client_credentials"),
116        ("scope", scope.as_str()),
117        ("client_id", bundle.client_id.as_str()),
118        ("client_secret", client_secret),
119    ];
120
121    exchange_token_form(
122        client,
123        bundle,
124        &form,
125        TokenResponseRefreshMode::NoRefreshToken,
126        redactor,
127    )
128}
129
130fn exchange_token_form(
131    client: &Client,
132    bundle: &CredentialBundle,
133    form: &[(&str, &str)],
134    refresh_mode: TokenResponseRefreshMode<'_>,
135    redactor: &mut Redactor,
136) -> Result<OAuthAccessToken, ViaError> {
137    let body = form_encode(form);
138    let exchange_span = crate::timing::span("oauth token exchange");
139    let response = match client
140        .post(&bundle.token_url)
141        .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
142        .body(body)
143        .send()
144    {
145        Ok(response) => {
146            let status = response.status();
147            exchange_span.finish(format!("status={status}"));
148            response
149        }
150        Err(error) => {
151            exchange_span.finish("failed");
152            return Err(error.into());
153        }
154    };
155    let status = response.status();
156    let body_span = crate::timing::span("oauth token body");
157    let body = match response.text() {
158        Ok(body) => {
159            body_span.finish(format!("bytes={}", body.len()));
160            body
161        }
162        Err(error) => {
163            body_span.finish("failed");
164            return Err(error.into());
165        }
166    };
167
168    if !status.is_success() {
169        let body = redactor.redact(&body);
170        return Err(ViaError::InvalidArgument(format!(
171            "OAuth token exchange failed with status {status}: {body}"
172        )));
173    }
174
175    parse_token_response(&body, refresh_mode, redactor)
176}
177
178fn parse_token_response(
179    body: &str,
180    refresh_mode: TokenResponseRefreshMode<'_>,
181    redactor: &mut Redactor,
182) -> Result<OAuthAccessToken, ViaError> {
183    let response: TokenResponse = serde_json::from_str(body)?;
184    if let Some(token_type) = &response.token_type {
185        if !token_type.eq_ignore_ascii_case("bearer") {
186            return Err(ViaError::InvalidArgument(format!(
187                "OAuth token response had unsupported token_type `{token_type}`"
188            )));
189        }
190    }
191
192    let refresh_token = match refresh_mode {
193        TokenResponseRefreshMode::PreserveRefreshToken(refresh_token) => Some(
194            response
195                .refresh_token
196                .unwrap_or_else(|| refresh_token.to_owned()),
197        ),
198        TokenResponseRefreshMode::NoRefreshToken => response.refresh_token,
199    };
200    let expires_at = expires_at(response.expires_in)?;
201
202    redactor.add(&response.access_token);
203    if let Some(refresh_token) = &refresh_token {
204        redactor.add(refresh_token);
205    }
206
207    Ok(OAuthAccessToken {
208        access_token: response.access_token,
209        refresh_token,
210        expires_at,
211    })
212}
213
214fn expires_at(expires_in: u64) -> Result<i64, ViaError> {
215    let now = unix_timestamp()?;
216    let expires_in = i64::try_from(expires_in).map_err(|_| {
217        ViaError::InvalidArgument("OAuth token response expires_in is too large".to_owned())
218    })?;
219    now.checked_add(expires_in).ok_or_else(|| {
220        ViaError::InvalidArgument("OAuth token response expires_at is too large".to_owned())
221    })
222}
223
224pub(crate) fn register_bundle_secrets(bundle: &CredentialBundle, redactor: &mut Redactor) {
225    if let Some(client_secret) = &bundle.client_secret {
226        redactor.add(client_secret);
227    }
228    match &bundle.grant {
229        OAuthGrant::RefreshToken { refresh_token } => redactor.add(refresh_token),
230        OAuthGrant::ClientCredentials { .. } => {}
231    }
232}
233
234pub(crate) fn register_cached_secrets(cached: Option<&CachedOAuthToken>, redactor: &mut Redactor) {
235    if let Some(cached) = cached {
236        redactor.add(&cached.access_token);
237        if let Some(refresh_token) = &cached.refresh_token {
238            redactor.add(refresh_token);
239        }
240    }
241}
242
243#[derive(Debug, PartialEq, Eq)]
244pub(crate) struct CredentialBundle {
245    credential_type: String,
246    pub(crate) token_url: String,
247    pub(crate) client_id: String,
248    pub(crate) client_secret: Option<String>,
249    grant: OAuthGrant,
250}
251
252impl CredentialBundle {
253    pub(crate) fn parse(raw: &str) -> Result<Self, ViaError> {
254        let value: Value = serde_json::from_str(raw).map_err(credential_json_error)?;
255        let credential_type = required_string(&value, "type")?;
256        validate_credential_type(&credential_type)?;
257        let token_url = required_string(&value, "token_url")?;
258        let client_id = required_string(&value, "client_id")?;
259        let client_secret = optional_string(&value, "client_secret")?;
260        let configured_grant_type = optional_string(&value, "grant_type")?;
261        let configured_refresh_token = optional_string(&value, "refresh_token")?;
262        let grant = match configured_grant_type.as_deref() {
263            Some("refresh_token") => OAuthGrant::RefreshToken {
264                refresh_token: configured_refresh_token.ok_or_else(|| {
265                    ViaError::InvalidConfig(
266                        "oauth refresh_token credential bundle must include `refresh_token`"
267                            .to_owned(),
268                    )
269                })?,
270            },
271            Some("client_credentials") => OAuthGrant::ClientCredentials {
272                scope: required_string(&value, "scope")?,
273            },
274            Some(grant_type) => {
275                return Err(ViaError::InvalidConfig(format!(
276                    "unsupported oauth grant_type `{grant_type}`"
277                )));
278            }
279            None => match configured_refresh_token {
280                Some(refresh_token) => OAuthGrant::RefreshToken { refresh_token },
281                None => {
282                    return Err(ViaError::InvalidConfig(
283                        "oauth credential bundle must include `grant_type`".to_owned(),
284                    ));
285                }
286            },
287        };
288
289        Ok(Self {
290            credential_type,
291            token_url,
292            client_id,
293            client_secret,
294            grant,
295        })
296    }
297}
298
299fn validate_credential_type(value: &str) -> Result<(), ViaError> {
300    if value == SERVICE_OAUTH_TYPE {
301        return Ok(());
302    }
303
304    Err(ViaError::InvalidConfig(format!(
305        "unsupported oauth credential type `{value}`; expected `{SERVICE_OAUTH_TYPE}`"
306    )))
307}
308
/// How an access token is obtained for a credential bundle.
///
/// `Debug` is implemented by hand so the refresh token never appears in formatted output.
#[derive(PartialEq, Eq)]
enum OAuthGrant {
    /// Exchange a refresh token for an access token (RFC 6749 §6).
    RefreshToken { refresh_token: String },
    /// Authenticate as the client itself for `scope` (RFC 6749 §4.4).
    ClientCredentials { scope: String },
}

impl std::fmt::Debug for OAuthGrant {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RefreshToken { .. } => f
                .debug_struct("RefreshToken")
                .field("refresh_token", &"[REDACTED]")
                .finish(),
            Self::ClientCredentials { scope } => f
                .debug_struct("ClientCredentials")
                .field("scope", scope)
                .finish(),
        }
    }
}
314
/// Controls how `parse_token_response` fills in `OAuthAccessToken::refresh_token`.
#[derive(Copy, Clone)]
enum TokenResponseRefreshMode<'a> {
    /// Refresh-token grant: fall back to this token when the response does not rotate it.
    PreserveRefreshToken(&'a str),
    /// Client-credentials grant: use only what the response returns, if anything.
    NoRefreshToken,
}
320
321#[derive(Debug, Deserialize)]
322struct TokenResponse {
323    access_token: String,
324    #[serde(default)]
325    token_type: Option<String>,
326    expires_in: u64,
327    #[serde(default)]
328    refresh_token: Option<String>,
329}
330
/// Result of a successful token exchange.
///
/// `Debug` is implemented by hand so neither token is printed when the value is formatted.
pub(crate) struct OAuthAccessToken {
    pub(crate) access_token: String,
    /// Refresh token to use next time. This is the rotated token if the provider issued one;
    /// otherwise, for refresh-token grants, the token that was just used.
    pub(crate) refresh_token: Option<String>,
    /// Absolute expiry as a Unix timestamp in seconds.
    pub(crate) expires_at: i64,
}

impl std::fmt::Debug for OAuthAccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthAccessToken")
            .field("access_token", &"[REDACTED]")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
            )
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
337
338#[derive(Clone, Debug, Deserialize)]
339pub(crate) struct CachedOAuthToken {
340    pub(crate) access_token: String,
341    pub(crate) expires_at: i64,
342    #[serde(default)]
343    pub(crate) refresh_token: Option<String>,
344}
345
346pub(crate) fn cache_key(bundle: &CredentialBundle) -> String {
347    let mut context = Context::new(&SHA256);
348    context.update(bundle.credential_type.as_bytes());
349    context.update(b"\0");
350    context.update(bundle.token_url.as_bytes());
351    context.update(b"\0");
352    context.update(bundle.client_id.as_bytes());
353    context.update(b"\0");
354    match &bundle.grant {
355        OAuthGrant::RefreshToken { refresh_token } => {
356            context.update(b"refresh_token\0");
357            context.update(refresh_token.as_bytes());
358        }
359        OAuthGrant::ClientCredentials { scope } => {
360            context.update(b"client_credentials\0");
361            context.update(scope.as_bytes());
362        }
363    }
364    hex_encode(context.finish().as_ref())
365}
366
367pub(crate) fn cached_access_token(cached: Option<&CachedOAuthToken>, now: i64) -> Option<String> {
368    let cached = cached?;
369    if cached.expires_at <= now + CACHE_EXPIRY_SKEW_SECONDS {
370        return None;
371    }
372    Some(cached.access_token.clone())
373}
374
375pub(crate) fn unix_timestamp() -> Result<i64, ViaError> {
376    let duration = SystemTime::now()
377        .duration_since(UNIX_EPOCH)
378        .map_err(|_| ViaError::InvalidConfig("system clock is before UNIX epoch".to_owned()))?;
379    i64::try_from(duration.as_secs())
380        .map_err(|_| ViaError::InvalidConfig("system clock timestamp is too large".to_owned()))
381}
382
/// Renders `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
        .map(char::from)
        .collect()
}
392
393fn form_encode(fields: &[(&str, &str)]) -> String {
394    fields
395        .iter()
396        .map(|(name, value)| {
397            format!(
398                "{}={}",
399                form_percent_encode(name),
400                form_percent_encode(value)
401            )
402        })
403        .collect::<Vec<_>>()
404        .join("&")
405}
406
/// Escapes `value` for use as a name or value in an `application/x-www-form-urlencoded` body.
///
/// - ASCII alphanumerics and `-._~` pass through unchanged.
/// - A space becomes `+`.
/// - Every other byte of the UTF-8 encoding is written as `%XX`, with uppercase hex digits.
fn form_percent_encode(value: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";
    // The input length is a lower bound on the output. Hex digits are pushed directly rather than
    // allocating a `format!` string for every escaped byte.
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                encoded.push(char::from(byte));
            }
            b' ' => encoded.push('+'),
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0f)]));
            }
        }
    }
    encoded
}
420
421fn credential_json_error(error: serde_json::Error) -> ViaError {
422    ViaError::InvalidConfig(format!(
423        "oauth credential bundle must be valid JSON: {error}"
424    ))
425}
426
427fn required_string(value: &Value, field: &str) -> Result<String, ViaError> {
428    value
429        .get(field)
430        .and_then(Value::as_str)
431        .filter(|value| !value.trim().is_empty())
432        .map(str::to_owned)
433        .ok_or_else(|| {
434            ViaError::InvalidConfig(format!(
435                "oauth credential bundle must include non-empty `{field}`"
436            ))
437        })
438}
439
440fn optional_string(value: &Value, field: &str) -> Result<Option<String>, ViaError> {
441    match value.get(field) {
442        Some(Value::String(value)) if !value.trim().is_empty() => Ok(Some(value.to_owned())),
443        Some(Value::String(_)) | None => Ok(None),
444        Some(_) => Err(ViaError::InvalidConfig(format!(
445            "oauth credential bundle `{field}` must be a string"
446        ))),
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    const LINEAR_TOKEN_URL: &str = "https://api.linear.app/oauth/token";

    /// Serializes `bundle` and runs it through the production parser.
    fn parse_bundle(bundle: serde_json::Value) -> Result<CredentialBundle, ViaError> {
        CredentialBundle::parse(&bundle.to_string())
    }

    #[test]
    fn parses_service_refresh_token_bundle() {
        let bundle = parse_bundle(serde_json::json!({
            "type": "service_oauth",
            "token_url": LINEAR_TOKEN_URL,
            "grant_type": "refresh_token",
            "client_id": "client-id",
            "client_secret": "client-secret",
            "refresh_token": "refresh-token",
        }))
        .unwrap();

        let expected_grant = OAuthGrant::RefreshToken {
            refresh_token: "refresh-token".to_owned(),
        };
        assert_eq!(bundle.credential_type, SERVICE_OAUTH_TYPE);
        assert_eq!(bundle.token_url, LINEAR_TOKEN_URL);
        assert_eq!(bundle.client_id, "client-id");
        assert_eq!(bundle.client_secret.as_deref(), Some("client-secret"));
        assert_eq!(bundle.grant, expected_grant);
    }

    #[test]
    fn parses_service_client_credentials_bundle() {
        let bundle = parse_bundle(serde_json::json!({
            "type": "service_oauth",
            "token_url": LINEAR_TOKEN_URL,
            "grant_type": "client_credentials",
            "client_id": "client-id",
            "client_secret": "client-secret",
            "scope": "read,issues:create",
        }))
        .unwrap();

        let expected_grant = OAuthGrant::ClientCredentials {
            scope: "read,issues:create".to_owned(),
        };
        assert_eq!(bundle.grant, expected_grant);
    }

    #[test]
    fn rejects_unsupported_oauth_credential_type() {
        let error = parse_bundle(serde_json::json!({
            "type": "example_oauth",
            "token_url": LINEAR_TOKEN_URL,
            "grant_type": "refresh_token",
            "client_id": "client-id",
            "refresh_token": "refresh-token",
        }))
        .unwrap_err();

        let ViaError::InvalidConfig(message) = error else {
            panic!("expected ViaError::InvalidConfig");
        };
        assert!(message.contains("unsupported oauth credential type"));
    }

    #[test]
    fn validates_credential_bundle() {
        let raw = serde_json::json!({
            "type": "service_oauth",
            "token_url": LINEAR_TOKEN_URL,
            "client_id": "client-id",
            "refresh_token": "refresh-token",
        })
        .to_string();

        validate_credential_bundle(&raw).unwrap();
    }

    #[test]
    fn returns_unexpired_cached_oauth_token() {
        let now = unix_timestamp().unwrap();
        let cached = CachedOAuthToken {
            access_token: "cached-access-token".to_owned(),
            expires_at: now + 3_600,
            refresh_token: Some("cached-refresh-token".to_owned()),
        };

        let token = cached_access_token(Some(&cached), now);

        assert_eq!(token.as_deref(), Some("cached-access-token"));
    }

    #[test]
    fn refreshes_and_returns_rotated_refresh_token() {
        let (token_url, server) = token_server(
            serde_json::json!({
                "access_token": "fresh-access-token",
                "token_type": "Bearer",
                "expires_in": 3600,
                "refresh_token": "rotated-refresh-token",
                "scope": "read write",
            })
            .to_string(),
        );
        let bundle = test_refresh_bundle(&token_url);
        let client = Client::new();
        let mut redactor = Redactor::new();

        let token = exchange_access_token(&client, &bundle, None, &mut redactor).unwrap();
        let request = server.join().unwrap();

        assert_eq!(token.access_token, "fresh-access-token");
        assert_eq!(
            token.refresh_token.as_deref(),
            Some("rotated-refresh-token")
        );
        assert!(request.starts_with("POST /oauth/token "));
        assert!(request.contains("content-type: application/x-www-form-urlencoded"));
        assert!(request.contains("grant_type=refresh_token"));
        assert!(request.contains("refresh_token=configured-refresh-token"));
        assert_eq!(
            redactor.redact("fresh-access-token rotated-refresh-token configured-refresh-token"),
            "[REDACTED] [REDACTED] [REDACTED]"
        );
    }

    #[test]
    fn refreshes_and_preserves_current_refresh_token_when_response_omits_rotation() {
        let (token_url, server) = token_server(
            serde_json::json!({
                "access_token": "fresh-access-token",
                "token_type": "Bearer",
                "expires_in": 3600,
            })
            .to_string(),
        );
        let bundle = test_refresh_bundle(&token_url);
        let client = Client::new();
        let mut redactor = Redactor::new();

        let token = exchange_access_token(&client, &bundle, None, &mut redactor).unwrap();
        let request = server.join().unwrap();

        assert_eq!(token.access_token, "fresh-access-token");
        assert_eq!(
            token.refresh_token.as_deref(),
            Some("configured-refresh-token")
        );
        assert!(request.contains("grant_type=refresh_token"));
    }

    #[test]
    fn exchanges_client_credentials_and_returns_access_token() {
        let (token_url, server) = token_server(
            serde_json::json!({
                "access_token": "client-access-token",
                "token_type": "Bearer",
                "expires_in": 3600,
                "scope": "read issues:create",
            })
            .to_string(),
        );
        let bundle = CredentialBundle {
            credential_type: SERVICE_OAUTH_TYPE.to_owned(),
            token_url,
            client_id: "client-id".to_owned(),
            client_secret: Some("client-secret".to_owned()),
            grant: OAuthGrant::ClientCredentials {
                scope: "read,issues:create".to_owned(),
            },
        };
        let client = Client::new();
        let mut redactor = Redactor::new();

        let token = exchange_access_token(&client, &bundle, None, &mut redactor).unwrap();
        let request = server.join().unwrap();

        assert_eq!(token.access_token, "client-access-token");
        for expected in [
            "grant_type=client_credentials",
            "scope=read%2Cissues%3Acreate",
            "client_secret=client-secret",
        ] {
            assert!(request.contains(expected), "request is missing `{expected}`");
        }
    }

    #[test]
    fn rejects_non_bearer_token_response() {
        let body = serde_json::json!({
            "access_token": "access-token",
            "token_type": "mac",
            "expires_in": 3600,
            "refresh_token": "refresh-token",
        })
        .to_string();
        let mut redactor = Redactor::new();

        let error = parse_token_response(
            &body,
            TokenResponseRefreshMode::PreserveRefreshToken("refresh-token"),
            &mut redactor,
        )
        .unwrap_err();

        let ViaError::InvalidArgument(message) = error else {
            panic!("expected ViaError::InvalidArgument");
        };
        assert!(message.contains("token_type"));
    }

    /// A refresh-token bundle pointing at `token_url` with a fixed configured refresh token.
    fn test_refresh_bundle(token_url: &str) -> CredentialBundle {
        CredentialBundle {
            credential_type: SERVICE_OAUTH_TYPE.to_owned(),
            token_url: token_url.to_owned(),
            client_id: "client-id".to_owned(),
            client_secret: Some("client-secret".to_owned()),
            grant: OAuthGrant::RefreshToken {
                refresh_token: "configured-refresh-token".to_owned(),
            },
        }
    }

    /// Starts a one-shot HTTP server on an ephemeral localhost port.
    ///
    /// The server answers the first request with `response_body` as JSON. Returns the token URL
    /// and a handle whose join value is the raw text of the received request.
    fn token_server(response_body: String) -> (String, thread::JoinHandle<String>) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let token_url = format!("http://{}/oauth/token", listener.local_addr().unwrap());
        let handle = thread::spawn(move || {
            let (mut stream, _peer) = listener.accept().unwrap();
            let mut buffer = [0_u8; 8192];
            let received = stream.read(&mut buffer).unwrap();
            let request = String::from_utf8_lossy(&buffer[..received]).into_owned();
            let reply = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream.write_all(reply.as_bytes()).unwrap();
            request
        });

        (token_url, handle)
    }
}