edc_connector_client/auth/
oauth.rs

1use std::{
2    sync::Arc,
3    time::{Duration, Instant},
4};
5
6use bon::Builder;
7use oauth2::{
8    basic::{BasicClient, BasicTokenType},
9    AccessToken, AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, RefreshToken, Scope,
10    StandardTokenResponse, TokenResponse, TokenUrl,
11};
12use reqwest::Client;
13use tokio::sync::Mutex;
14
15use crate::{EdcResult, Error};
16
17#[derive(Clone)]
18pub struct OAuth2(Arc<OAuth2Internal>);
19
20type OAuthErrorResponse = oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>;
21pub type OAuthTokenResponse = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
22type OAuthTokenIntrospection =
23    oauth2::StandardTokenIntrospectionResponse<EmptyExtraTokenFields, BasicTokenType>;
24type OAuthRevocableToken = oauth2::StandardRevocableToken;
25type OAuthRevocationError = oauth2::StandardErrorResponse<oauth2::RevocationErrorResponseType>;
26type OAuthClient = oauth2::Client<
27    OAuthErrorResponse,
28    OAuthTokenResponse,
29    OAuthTokenIntrospection,
30    OAuthRevocableToken,
31    OAuthRevocationError,
32    oauth2::EndpointSet,
33    oauth2::EndpointNotSet,
34    oauth2::EndpointNotSet,
35    oauth2::EndpointNotSet,
36    oauth2::EndpointSet,
37>;
38
39pub struct OAuth2Internal {
40    oauth_client: OAuthClient,
41    session: Mutex<Option<OAuthTokenSession>>,
42    http_client: Client,
43    scopes: Vec<String>,
44}
45
46pub struct OAuthTokenSession {
47    access_token: AccessToken,
48    refresh_token: Option<RefreshToken>,
49    expires_at: std::time::Instant,
50}
51
52impl OAuthTokenSession {
53    pub fn new(
54        access_token: AccessToken,
55        refresh_token: Option<RefreshToken>,
56        expires_at: std::time::Instant,
57    ) -> Self {
58        Self {
59            access_token,
60            refresh_token,
61            expires_at,
62        }
63    }
64
65    pub fn access_token(&self) -> &AccessToken {
66        &self.access_token
67    }
68
69    pub fn refresh_token(&self) -> Option<&RefreshToken> {
70        self.refresh_token.as_ref()
71    }
72
73    fn is_expired(&self) -> bool {
74        Instant::now() >= self.expires_at - (Duration::from_secs(30))
75    }
76}
77
78#[derive(Builder)]
79pub struct OAuth2Config {
80    #[builder(into)]
81    client_id: String,
82    #[builder(into)]
83    client_secret: String,
84    #[builder(into)]
85    token_url: String,
86    #[builder(default = vec!["management-api:read".to_string(), "management-api:write".to_string()])]
87    scopes: Vec<String>,
88}
89
90impl OAuth2 {
91    pub fn init(cfg: OAuth2Config) -> EdcResult<OAuth2> {
92        let client = BasicClient::new(ClientId::new(cfg.client_id))
93            .set_client_secret(ClientSecret::new(cfg.client_secret))
94            .set_auth_uri(
95                AuthUrl::new("http://authorize".to_string())
96                    .map_err(|e| Error::Auth(Box::new(e)))?,
97            )
98            .set_token_uri(TokenUrl::new(cfg.token_url).map_err(|e| Error::Auth(Box::new(e)))?);
99
100        Ok(OAuth2(Arc::new(OAuth2Internal {
101            oauth_client: client,
102            session: Mutex::default(),
103            http_client: Client::new(),
104            scopes: cfg.scopes,
105        })))
106    }
107
108    pub async fn token(&self) -> EdcResult<String> {
109        self.0.token().await
110    }
111}
112
113impl OAuth2Internal {
114    pub async fn token(&self) -> EdcResult<String> {
115        let mut session = self.session.lock().await;
116
117        match session.as_ref() {
118            Some(t) if !t.is_expired() => Ok(t.access_token().secret().to_string()),
119            Some(t) => {
120                let new_session = self.refresh_session(t).await?;
121                let access_token = new_session.access_token().secret().to_string();
122                *session = Some(new_session);
123                Ok(access_token)
124            }
125            _ => {
126                let new_session = self.new_session().await?;
127                let access_token = new_session.access_token().secret().to_string();
128                *session = Some(new_session);
129                Ok(access_token)
130            }
131        }
132    }
133
134    async fn new_session(&self) -> EdcResult<OAuthTokenSession> {
135        let scopes = self
136            .scopes
137            .iter()
138            .cloned()
139            .map(Scope::new)
140            .collect::<Vec<_>>();
141        let token_result = self
142            .oauth_client
143            .exchange_client_credentials()
144            .add_scopes(scopes)
145            .request_async(&self.http_client)
146            .await
147            .map_err(|e| Error::Auth(Box::new(e)))?;
148
149        let expires_at = Instant::now()
150            + token_result
151                .expires_in()
152                .unwrap_or(Duration::from_secs(3600));
153
154        Ok(OAuthTokenSession::new(
155            token_result.access_token().clone(),
156            token_result.refresh_token().cloned(),
157            expires_at,
158        ))
159    }
160
161    async fn refresh_session(&self, session: &OAuthTokenSession) -> EdcResult<OAuthTokenSession> {
162        if let Some(refresh) = session.refresh_token() {
163            let token_result = self
164                .oauth_client
165                .exchange_refresh_token(refresh)
166                .request_async(&self.http_client)
167                .await
168                .map_err(|e| Error::Auth(Box::new(e)))?;
169
170            let expires_at = Instant::now()
171                + token_result
172                    .expires_in()
173                    .unwrap_or(Duration::from_secs(3600));
174
175            let refresh_token = token_result
176                .refresh_token()
177                .cloned()
178                .or_else(|| Some(refresh.clone()));
179
180            Ok(OAuthTokenSession::new(
181                token_result.access_token().clone(),
182                refresh_token,
183                expires_at,
184            ))
185        } else {
186            self.new_session().await
187        }
188    }
189}