Skip to main content

gvm_auth/
oauth2.rs

1// SPDX-FileCopyrightText: 2026 Greenbone AG
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5use crate::clock::{Clock, SystemClock};
6use oauth2::basic::BasicClient;
7use oauth2::reqwest;
8use oauth2::{
9    AuthUrl, ClientId, ClientSecret, EndpointNotSet, EndpointSet, Scope, TokenResponse, TokenUrl,
10};
11use thiserror::Error;
12use url;
13
/// Seconds before expiry at which a cached token is treated as stale when
/// `ClientCredentialsConfig::refresh_skew_seconds` is `None`.
const DEFAULT_REFRESH_SKEW_SECONDS: u64 = 30;
15
16#[derive(Debug, Error)]
17pub enum OAuth2TokenProviderError {
18    #[error("invalid config: {0}")]
19    InvalidConfig(&'static str),
20
21    #[error("invalid token_url: {0}")]
22    InvalidTokenUrl(#[from] url::ParseError),
23
24    #[error("failed to build http client: {0}")]
25    HttpClientBuild(String),
26
27    #[error("token request failed: {0}")]
28    TokenRequest(String),
29
30    #[error("token response missing expires_in")]
31    MissingExpiresIn,
32}
33
/// Settings for obtaining tokens via the OAuth2 client-credentials grant.
///
/// `Debug` is implemented by hand (below) so that `client_secret` is redacted and never
/// ends up in logs or panic messages.
#[derive(Clone)]
pub struct ClientCredentialsConfig {
    /// Token endpoint URL. Must be non-empty and parseable as a URL.
    pub token_url: String,
    /// OAuth2 client identifier. Must be non-empty.
    pub client_id: String,
    /// OAuth2 client secret. Must be non-empty. Redacted in `Debug` output.
    pub client_secret: String,
    /// Scopes to request. Entries are trimmed, and blank entries are skipped.
    pub scopes: Vec<String>,
    /// Seconds before expiry at which a cached token is refreshed. `None` selects the
    /// built-in default.
    /// NOTE(review): `Some(0)` currently disables expiry checks entirely, so a cached
    /// token is never refreshed — confirm this is the intended meaning.
    pub refresh_skew_seconds: Option<u64>,
}

impl std::fmt::Debug for ClientCredentialsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientCredentialsConfig")
            .field("token_url", &self.token_url)
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"<redacted>")
            .field("scopes", &self.scopes)
            .field("refresh_skew_seconds", &self.refresh_skew_seconds)
            .finish()
    }
}
42
/// An access token paired with the clock reading at which it expires.
///
/// `Debug` is implemented by hand (below) so the bearer token itself is never printed.
#[derive(Clone)]
struct CachedToken {
    /// Access token returned by the token endpoint.
    access_token: String,
    /// Expiry instant, on the same time base as the provider's `Clock::now()`.
    expired_at: u64,
}

impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            // The token grants access; keep it out of logs.
            .field("access_token", &"<redacted>")
            .field("expired_at", &self.expired_at)
            .finish()
    }
}
48
49type ConfiguredBasicClient =
50    BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
51
52#[derive(Debug)]
53pub struct OAuth2TokenProvider<C: Clock = SystemClock> {
54    config: ClientCredentialsConfig,
55    client: ConfiguredBasicClient,
56    cache: tokio::sync::RwLock<Option<CachedToken>>,
57    clock: C,
58}
59
60pub type Result<T> = std::result::Result<T, OAuth2TokenProviderError>;
61
62impl OAuth2TokenProvider<SystemClock> {
63    pub fn new(config: ClientCredentialsConfig) -> Result<Self> {
64        Self::with_clock(config, SystemClock)
65    }
66}
67
68impl<C: Clock> OAuth2TokenProvider<C> {
69    pub fn with_clock(config: ClientCredentialsConfig, clock: C) -> Result<Self> {
70        Self::validate_config(&config)?;
71
72        let token_url = TokenUrl::new(config.token_url.clone())?;
73        let auth_url = AuthUrl::new("https://invalid.local/authorize".to_string())
74            .expect("hardcoded url must be valid");
75
76        let client = BasicClient::new(ClientId::new(config.client_id.clone()))
77            .set_client_secret(ClientSecret::new(config.client_secret.clone()))
78            .set_auth_uri(auth_url)
79            .set_token_uri(token_url);
80
81        Ok(Self {
82            config,
83            client,
84            cache: tokio::sync::RwLock::new(None),
85            clock,
86        })
87    }
88
89    fn validate_config(config: &ClientCredentialsConfig) -> Result<()> {
90        if config.token_url.trim().is_empty() {
91            return Err(OAuth2TokenProviderError::InvalidConfig(
92                "token_url must not be empty",
93            ));
94        }
95        if config.client_id.trim().is_empty() {
96            return Err(OAuth2TokenProviderError::InvalidConfig(
97                "client_id must not be empty",
98            ));
99        }
100        if config.client_secret.trim().is_empty() {
101            return Err(OAuth2TokenProviderError::InvalidConfig(
102                "client_secret must not be empty",
103            ));
104        }
105        Ok(())
106    }
107
108    pub fn refresh_skew(&self) -> u64 {
109        match self.config.refresh_skew_seconds {
110            Some(0) => 0,
111            Some(seconds) => seconds,
112            None => DEFAULT_REFRESH_SKEW_SECONDS,
113        }
114    }
115
116    pub fn get_token(&self) -> Result<String> {
117        let guard = self.cache.blocking_read();
118        if let Some(token) = guard.as_ref() {
119            let now = self.clock.now();
120            let skew = self.refresh_skew();
121            if skew == 0 || token.expired_at > now.saturating_add(skew) {
122                return Ok(token.access_token.clone());
123            }
124        }
125        drop(guard); // read lock dropped
126
127        let http_client = reqwest::blocking::ClientBuilder::new()
128            .redirect(reqwest::redirect::Policy::none())
129            .build()
130            .map_err(|e| OAuth2TokenProviderError::HttpClientBuild(e.to_string()))?;
131
132        let mut req = self.client.exchange_client_credentials();
133        for s in &self.config.scopes {
134            let scope = s.trim();
135            if !scope.is_empty() {
136                req = req.add_scope(Scope::new(scope.to_string()));
137            }
138        }
139
140        let token = req
141            .request(&http_client)
142            .map_err(|e| OAuth2TokenProviderError::TokenRequest(e.to_string()))?;
143
144        let access_token = token.access_token().secret().to_string();
145        let expires_in = token
146            .expires_in()
147            .ok_or(OAuth2TokenProviderError::MissingExpiresIn)?
148            .as_secs();
149
150        let expired_at = self.clock.now().saturating_add(expires_in);
151
152        let mut guard = self.cache.blocking_write();
153        *guard = Some(CachedToken {
154            access_token: access_token.clone(),
155            expired_at,
156        });
157        drop(guard); // write lock dropped
158
159        Ok(access_token)
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clock::ManualClock;
    use httpmock::prelude::*;
    use std::sync::Arc;

    fn cfg(token_url: String) -> ClientCredentialsConfig {
        ClientCredentialsConfig {
            token_url,
            client_id: "client-id".to_string(),
            client_secret: "client-secret".to_string(),
            scopes: vec!["scope-a".into(), "scope-b".into()],
            refresh_skew_seconds: Some(30),
        }
    }

    #[test]
    fn new_rejects_empty_fields() {
        let base = ClientCredentialsConfig {
            token_url: "http://localhost/token".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            scopes: vec![],
            refresh_skew_seconds: None,
        };

        let mut c = base.clone();
        c.token_url = "   ".into();
        assert!(matches!(
            OAuth2TokenProvider::new(c).unwrap_err(),
            OAuth2TokenProviderError::InvalidConfig(_)
        ));

        let mut c = base.clone();
        c.client_id = "".into();
        assert!(matches!(
            OAuth2TokenProvider::new(c).unwrap_err(),
            OAuth2TokenProviderError::InvalidConfig(_)
        ));

        let mut c = base.clone();
        c.client_secret = " ".into();
        assert!(matches!(
            OAuth2TokenProvider::new(c).unwrap_err(),
            OAuth2TokenProviderError::InvalidConfig(_)
        ));
    }

    #[test]
    fn refresh_skew_none_uses_default() {
        let config = ClientCredentialsConfig {
            token_url: "http://localhost/token".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            scopes: vec![],
            refresh_skew_seconds: None,
        };

        let provider = OAuth2TokenProvider::new(config).unwrap();
        assert_eq!(provider.refresh_skew(), DEFAULT_REFRESH_SKEW_SECONDS);
    }

    #[test]
    fn refresh_skew_uses_configured_value() {
        let provider =
            OAuth2TokenProvider::new(cfg("http://localhost/token".to_string())).unwrap();
        assert_eq!(provider.refresh_skew(), 30);
    }

    #[test]
    fn get_token_fetches_and_caches_token() {
        let server = MockServer::start();

        // The request must be a form-encoded client-credentials grant that carries both
        // configured scopes. Scopes are space-separated, and a space form-encodes as `+`.
        let token_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/token")
                .header("content-type", "application/x-www-form-urlencoded")
                .body_includes("grant_type=client_credentials")
                .body_includes("scope=scope-a+scope-b");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"t1","token_type":"bearer","expires_in":3600}"#);
        });

        let clock = Arc::new(ManualClock::new(1000));
        let provider =
            OAuth2TokenProvider::with_clock(cfg(format!("{}/token", server.base_url())), clock)
                .unwrap();

        let t1 = provider.get_token().unwrap();
        let t2 = provider.get_token().unwrap();

        assert_eq!(t1, "t1");
        assert_eq!(t2, "t1");
        token_mock.assert_calls(1);
    }

    #[test]
    fn get_token_refreshes_when_skew_window_reached() {
        let server = MockServer::start();

        let mock_server = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"t1","token_type":"bearer","expires_in":10}"#);
        });

        let clock = Arc::new(ManualClock::new(1000));

        let config = ClientCredentialsConfig {
            token_url: format!("{}/token", server.base_url()),
            client_id: "client-id".into(),
            client_secret: "client-secret".into(),
            scopes: vec![],
            refresh_skew_seconds: Some(30),
        };

        let provider = OAuth2TokenProvider::with_clock(config, clock.clone()).unwrap();

        let token = provider.get_token().unwrap();
        assert_eq!(token, "t1");

        // Move time forward a bit
        clock.advance(300);

        provider.get_token().unwrap();

        mock_server.assert_calls(2);
    }

    #[test]
    fn get_token_does_not_refresh_when_skew_is_zero() {
        let server = MockServer::start();

        let mock_server = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"t1","token_type":"bearer","expires_in":1}"#);
        });

        let clock = Arc::new(ManualClock::new(1000));

        let config = ClientCredentialsConfig {
            token_url: format!("{}/token", server.base_url()),
            client_id: "client-id".into(),
            client_secret: "client-secret".into(),
            scopes: vec![],
            refresh_skew_seconds: Some(0),
        };

        let provider = OAuth2TokenProvider::with_clock(config, clock.clone()).unwrap();

        let token = provider.get_token().unwrap();
        // Move time forward a bit
        clock.advance(1000);
        provider.get_token().unwrap();

        assert_eq!(token, "t1");
        mock_server.assert_calls(1);
    }

    #[test]
    fn missing_expires_in_returns_error() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"t1","token_type":"bearer"}"#);
        });

        let clock = Arc::new(ManualClock::new(1000));
        let provider =
            OAuth2TokenProvider::with_clock(cfg(format!("{}/token", server.base_url())), clock)
                .unwrap();

        let err = provider.get_token().unwrap_err();
        assert!(matches!(err, OAuth2TokenProviderError::MissingExpiresIn));
    }

    #[test]
    fn token_request_error_is_mapped() {
        let clock = Arc::new(ManualClock::new(1000));
        let provider = OAuth2TokenProvider::with_clock(
            ClientCredentialsConfig {
                token_url: "http://127.0.0.1:9/token".into(),
                client_id: "client-id".into(),
                client_secret: "client-secret".into(),
                scopes: vec![],
                refresh_skew_seconds: None,
            },
            clock,
        )
        .unwrap();

        let err = provider.get_token().unwrap_err();
        assert!(matches!(err, OAuth2TokenProviderError::TokenRequest(_)));
    }
}