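//! OAuth 2.0 client helpers (`rs_auth_core/oauth/client.rs`): authorization-URL
//! construction with PKCE, authorization-code exchange, and access-token refresh.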

1use oauth2::{
2    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointNotSet, EndpointSet,
3    PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
4    basic::BasicClient,
5};
6
7use crate::error::{AuthError, OAuthError};
8use crate::oauth::{OAuthProviderConfig, OAuthTokens};
9
10// Type alias for a fully configured OAuth client
11type ConfiguredClient = BasicClient<
12    EndpointSet,    // HasAuthUrl
13    EndpointNotSet, // HasDeviceAuthUrl
14    EndpointNotSet, // HasIntrospectionUrl
15    EndpointNotSet, // HasRevocationUrl
16    EndpointSet,    // HasTokenUrl
17>;
18
/// Output of [`build_authorization`]: the URL to send the user to, plus the
/// per-attempt secrets needed to finish the authorization-code flow.
///
/// Intentionally has no `Debug` derive: `csrf_state` and `pkce_verifier` are
/// secrets and should not leak into logs.
pub struct OAuthAuthorization {
    /// Provider authorization URL, carrying the PKCE challenge and requested scopes.
    pub authorize_url: String,
    /// Random CSRF `state` value embedded in `authorize_url`; callers should compare it
    /// with the `state` echoed back on the redirect.
    pub csrf_state: String,
    /// PKCE code verifier; hand it to [`exchange_code`] together with the returned code.
    pub pkce_verifier: String,
}
25
26/// Build an OAuth authorization URL with PKCE.
27pub fn build_authorization(config: &OAuthProviderConfig) -> Result<OAuthAuthorization, AuthError> {
28    let client = build_client(config)?;
29    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
30
31    let mut auth_request = client.authorize_url(CsrfToken::new_random);
32    auth_request = auth_request.set_pkce_challenge(pkce_challenge);
33    for scope in &config.scopes {
34        auth_request = auth_request.add_scope(Scope::new(scope.clone()));
35    }
36    let (auth_url, csrf_token) = auth_request.url();
37
38    Ok(OAuthAuthorization {
39        authorize_url: auth_url.to_string(),
40        csrf_state: csrf_token.secret().clone(),
41        pkce_verifier: pkce_verifier.secret().clone(),
42    })
43}
44
45/// Exchange an authorization code for access and refresh tokens.
46pub async fn exchange_code(
47    config: &OAuthProviderConfig,
48    code: &str,
49    pkce_verifier: &str,
50) -> Result<OAuthTokens, AuthError> {
51    let client = build_client(config)?;
52
53    // Create a reqwest client that doesn't follow redirects (to prevent SSRF)
54    let http_client = reqwest::Client::builder()
55        .redirect(reqwest::redirect::Policy::none())
56        .build()
57        .map_err(|_| AuthError::OAuth(OAuthError::ExchangeFailed))?;
58
59    let token_response = client
60        .exchange_code(AuthorizationCode::new(code.to_string()))
61        .set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_string()))
62        .request_async(&http_client)
63        .await
64        .map_err(|_| AuthError::OAuth(OAuthError::ExchangeFailed))?;
65
66    let access_token = token_response.access_token().secret().clone();
67    let refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
68
69    // Extract expires_in and convert from std::time::Duration to time::Duration
70    let expires_in = token_response
71        .expires_in()
72        .and_then(|d| time::Duration::try_from(d).ok());
73
74    // Extract scopes and join them into a single string
75    let scope = token_response.scopes().map(|scopes| {
76        scopes
77            .iter()
78            .map(|s| s.as_str())
79            .collect::<Vec<_>>()
80            .join(" ")
81    });
82
83    Ok(OAuthTokens {
84        access_token: Some(access_token),
85        refresh_token,
86        expires_in,
87        scope,
88    })
89}
90
91/// Refresh an access token using a stored refresh token.
92pub async fn refresh_access_token(
93    config: &OAuthProviderConfig,
94    refresh_token_str: &str,
95) -> Result<OAuthTokens, AuthError> {
96    let client = build_client(config)?;
97
98    let http_client = reqwest::Client::builder()
99        .redirect(reqwest::redirect::Policy::none())
100        .build()
101        .map_err(|_| AuthError::OAuth(OAuthError::RefreshFailed))?;
102
103    let token_response = client
104        .exchange_refresh_token(&RefreshToken::new(refresh_token_str.to_string()))
105        .request_async(&http_client)
106        .await
107        .map_err(|_| AuthError::OAuth(OAuthError::RefreshFailed))?;
108
109    let access_token = token_response.access_token().secret().clone();
110    let refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
111
112    let expires_in = token_response
113        .expires_in()
114        .and_then(|d| time::Duration::try_from(d).ok());
115
116    let scope = token_response.scopes().map(|scopes| {
117        scopes
118            .iter()
119            .map(|s| s.as_str())
120            .collect::<Vec<_>>()
121            .join(" ")
122    });
123
124    Ok(OAuthTokens {
125        access_token: Some(access_token),
126        refresh_token,
127        expires_in,
128        scope,
129    })
130}
131
132fn build_client(config: &OAuthProviderConfig) -> Result<ConfiguredClient, AuthError> {
133    let client = BasicClient::new(ClientId::new(config.client_id.clone()))
134        .set_client_secret(ClientSecret::new(config.client_secret.clone()))
135        .set_auth_uri(AuthUrl::new(config.auth_url.clone()).map_err(|e| {
136            AuthError::OAuth(OAuthError::Misconfigured {
137                message: format!("invalid auth_url: {e}"),
138            })
139        })?)
140        .set_token_uri(TokenUrl::new(config.token_url.clone()).map_err(|e| {
141            AuthError::OAuth(OAuthError::Misconfigured {
142                message: format!("invalid token_url: {e}"),
143            })
144        })?)
145        .set_redirect_uri(RedirectUrl::new(config.redirect_url.clone()).map_err(|e| {
146            AuthError::OAuth(OAuthError::Misconfigured {
147                message: format!("invalid redirect_url: {e}"),
148            })
149        })?);
150    Ok(client)
151}