clawspec_core/client/oauth2/
config.rs1use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use url::Url;
8
9use super::error::OAuth2Error;
10use super::token::{OAuth2Token, TokenCache};
11use crate::client::SecureString;
12
/// Threshold handed to `TokenCache::should_refresh` when none is configured
/// (60 seconds). Override it with `OAuth2ConfigBuilder::with_refresh_threshold`.
const DEFAULT_REFRESH_THRESHOLD: Duration = Duration::from_millis(60_000);
15
/// OAuth2 flow that an [`OAuth2Config`] uses to obtain its access token.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OAuth2GrantType {
    /// Client-credentials grant. A client secret is mandatory for this grant,
    /// and `OAuth2ConfigBuilder::build` rejects configurations without one.
    ClientCredentials,
    /// An access token supplied up front (see `OAuth2Config::pre_acquired`)
    /// and placed directly into the token cache.
    PreAcquired,
}
24
25#[derive(Clone)]
29pub struct OAuth2Config {
30 pub(crate) client_id: String,
32 pub(crate) client_secret: Option<SecureString>,
34 pub(crate) token_url: Url,
36 pub(crate) auth_url: Option<Url>,
38 pub(crate) scopes: Vec<String>,
40 pub(crate) grant_type: OAuth2GrantType,
42 pub(crate) auto_refresh: bool,
44 pub(crate) refresh_threshold: Duration,
46 pub(crate) token_cache: TokenCache,
48}
49
50impl OAuth2Config {
51 pub fn client_credentials(
53 client_id: impl Into<String>,
54 client_secret: impl Into<SecureString>,
55 token_url: impl AsRef<str>,
56 ) -> Result<OAuth2ConfigBuilder, OAuth2Error> {
57 Ok(OAuth2ConfigBuilder::new(client_id, token_url)?
58 .with_client_secret(client_secret)
59 .with_grant_type(OAuth2GrantType::ClientCredentials))
60 }
61
62 pub fn pre_acquired(
64 client_id: impl Into<String>,
65 token_url: impl AsRef<str>,
66 access_token: impl Into<String>,
67 ) -> Result<OAuth2ConfigBuilder, OAuth2Error> {
68 let token = OAuth2Token::new(access_token);
69 Ok(OAuth2ConfigBuilder::new(client_id, token_url)?
70 .with_pre_acquired_token(token)
71 .with_grant_type(OAuth2GrantType::PreAcquired))
72 }
73
74 pub async fn needs_token(&self) -> bool {
76 self.token_cache
77 .should_refresh(self.refresh_threshold)
78 .await
79 }
80
81 pub async fn get_token(&self) -> Option<OAuth2Token> {
83 self.token_cache.get().await
84 }
85
86 pub async fn set_token(&self, token: OAuth2Token) {
88 self.token_cache.set(token).await;
89 }
90}
91
92impl fmt::Debug for OAuth2Config {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 f.debug_struct("OAuth2Config")
95 .field("client_id", &self.client_id)
96 .field(
97 "client_secret",
98 &self.client_secret.as_ref().map(|_| "[REDACTED]"),
99 )
100 .field("token_url", &self.token_url)
101 .field("auth_url", &self.auth_url)
102 .field("scopes", &self.scopes)
103 .field("grant_type", &self.grant_type)
104 .field("auto_refresh", &self.auto_refresh)
105 .field("refresh_threshold", &self.refresh_threshold)
106 .finish()
107 }
108}
109
110#[derive(Clone)]
112pub struct OAuth2ConfigBuilder {
113 client_id: String,
114 client_secret: Option<SecureString>,
115 token_url: Url,
116 auth_url: Option<Url>,
117 scopes: Vec<String>,
118 grant_type: OAuth2GrantType,
119 auto_refresh: bool,
120 refresh_threshold: Duration,
121 pre_acquired_token: Option<OAuth2Token>,
122}
123
124impl OAuth2ConfigBuilder {
125 pub fn new(
127 client_id: impl Into<String>,
128 token_url: impl AsRef<str>,
129 ) -> Result<Self, OAuth2Error> {
130 let token_url =
131 Url::parse(token_url.as_ref()).map_err(|e| OAuth2Error::InvalidTokenEndpoint {
132 url: token_url.as_ref().to_string(),
133 reason: e.to_string(),
134 })?;
135
136 Ok(Self {
137 client_id: client_id.into(),
138 client_secret: None,
139 token_url,
140 auth_url: None,
141 scopes: Vec::new(),
142 grant_type: OAuth2GrantType::ClientCredentials,
143 auto_refresh: true,
144 refresh_threshold: DEFAULT_REFRESH_THRESHOLD,
145 pre_acquired_token: None,
146 })
147 }
148
149 #[must_use]
151 pub fn with_client_secret(mut self, secret: impl Into<SecureString>) -> Self {
152 self.client_secret = Some(secret.into());
153 self
154 }
155
156 pub fn with_auth_url(mut self, auth_url: impl AsRef<str>) -> Result<Self, OAuth2Error> {
158 let url = Url::parse(auth_url.as_ref()).map_err(|e| OAuth2Error::ConfigurationError {
159 reason: format!("Invalid authorization URL: {e}"),
160 })?;
161 self.auth_url = Some(url);
162 Ok(self)
163 }
164
165 #[must_use]
167 pub fn add_scope(mut self, scope: impl Into<String>) -> Self {
168 self.scopes.push(scope.into());
169 self
170 }
171
172 #[must_use]
174 pub fn add_scopes(mut self, scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
175 self.scopes.extend(scopes.into_iter().map(Into::into));
176 self
177 }
178
179 #[must_use]
181 fn with_grant_type(mut self, grant_type: OAuth2GrantType) -> Self {
182 self.grant_type = grant_type;
183 self
184 }
185
186 #[must_use]
188 pub fn with_auto_refresh(mut self, auto_refresh: bool) -> Self {
189 self.auto_refresh = auto_refresh;
190 self
191 }
192
193 #[must_use]
195 pub fn with_refresh_threshold(mut self, threshold: Duration) -> Self {
196 self.refresh_threshold = threshold;
197 self
198 }
199
200 #[must_use]
202 fn with_pre_acquired_token(mut self, token: OAuth2Token) -> Self {
203 self.pre_acquired_token = Some(token);
204 self
205 }
206
207 pub fn build(self) -> Result<OAuth2Config, OAuth2Error> {
209 if self.grant_type == OAuth2GrantType::ClientCredentials && self.client_secret.is_none() {
211 return Err(OAuth2Error::ConfigurationError {
212 reason: "Client credentials flow requires a client secret".to_string(),
213 });
214 }
215
216 let token_cache = if let Some(token) = self.pre_acquired_token {
217 TokenCache::with_token(token)
218 } else {
219 TokenCache::new()
220 };
221
222 Ok(OAuth2Config {
223 client_id: self.client_id,
224 client_secret: self.client_secret,
225 token_url: self.token_url,
226 auth_url: self.auth_url,
227 scopes: self.scopes,
228 grant_type: self.grant_type,
229 auto_refresh: self.auto_refresh,
230 refresh_threshold: self.refresh_threshold,
231 token_cache,
232 })
233 }
234}
235
236impl fmt::Debug for OAuth2ConfigBuilder {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 f.debug_struct("OAuth2ConfigBuilder")
239 .field("client_id", &self.client_id)
240 .field(
241 "client_secret",
242 &self.client_secret.as_ref().map(|_| "[REDACTED]"),
243 )
244 .field("token_url", &self.token_url)
245 .field("scopes", &self.scopes)
246 .field("grant_type", &self.grant_type)
247 .finish()
248 }
249}
250
251#[derive(Debug, Clone)]
253pub struct SharedOAuth2Config(pub(crate) Arc<OAuth2Config>);
254
255impl SharedOAuth2Config {
256 pub fn new(config: OAuth2Config) -> Self {
258 Self(Arc::new(config))
259 }
260
261 pub fn inner(&self) -> &OAuth2Config {
263 &self.0
264 }
265}
266
267impl From<OAuth2Config> for SharedOAuth2Config {
268 fn from(config: OAuth2Config) -> Self {
269 Self::new(config)
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    const TOKEN_URL: &str = "https://auth.example.com/token";

    /// Builder for the client-credentials flow with a fixed id and token URL.
    fn client_credentials_builder(secret: &'static str) -> OAuth2ConfigBuilder {
        OAuth2Config::client_credentials("client-id", secret, TOKEN_URL)
            .expect("Should create builder")
    }

    #[test]
    fn should_create_client_credentials_config() {
        let config = client_credentials_builder("client-secret")
            .build()
            .expect("Should build config");

        assert_eq!(config.client_id, "client-id");
        assert!(config.client_secret.is_some());
        assert_eq!(config.token_url.as_str(), TOKEN_URL);
        assert_eq!(config.grant_type, OAuth2GrantType::ClientCredentials);
    }

    #[test]
    fn should_create_pre_acquired_config() {
        let config = OAuth2Config::pre_acquired("client-id", TOKEN_URL, "pre-acquired-token")
            .expect("Should create builder")
            .build()
            .expect("Should build config");

        assert_eq!(config.grant_type, OAuth2GrantType::PreAcquired);
    }

    #[test]
    fn should_not_require_secret_for_pre_acquired() {
        let config = OAuth2Config::pre_acquired("client-id", TOKEN_URL, "token")
            .expect("Should create builder")
            .build()
            .expect("Pre-acquired flow must build without a secret");

        assert!(config.client_secret.is_none());
    }

    #[test]
    fn should_reject_invalid_token_url() {
        let err = OAuth2ConfigBuilder::new("client-id", "not-a-url").expect_err("Should fail");
        match err {
            OAuth2Error::InvalidTokenEndpoint { url, .. } => {
                assert_eq!(url, "not-a-url");
            }
            other => panic!("Expected InvalidTokenEndpoint error, got {other:?}"),
        }
    }

    #[test]
    fn should_require_client_secret_for_client_credentials() {
        let err = OAuth2ConfigBuilder::new("client-id", TOKEN_URL)
            .expect("Should create builder")
            .with_grant_type(OAuth2GrantType::ClientCredentials)
            .build()
            .expect_err("Should fail");

        match err {
            OAuth2Error::ConfigurationError { reason } => {
                assert!(reason.contains("client secret"));
            }
            other => panic!("Expected ConfigurationError, got {other:?}"),
        }
    }

    #[test]
    fn should_apply_builder_defaults() {
        let config = client_credentials_builder("secret")
            .build()
            .expect("Should build config");

        assert!(config.auto_refresh);
        assert_eq!(config.refresh_threshold, DEFAULT_REFRESH_THRESHOLD);
        assert!(config.scopes.is_empty());
        assert!(config.auth_url.is_none());
    }

    #[test]
    fn should_disable_auto_refresh() {
        let config = client_credentials_builder("secret")
            .with_auto_refresh(false)
            .build()
            .expect("Should build config");

        assert!(!config.auto_refresh);
    }

    #[test]
    fn should_set_valid_auth_url() {
        let config = client_credentials_builder("secret")
            .with_auth_url("https://auth.example.com/authorize")
            .expect("Should accept auth URL")
            .build()
            .expect("Should build config");

        let auth_url = config.auth_url.expect("auth_url should be set");
        assert_eq!(auth_url.as_str(), "https://auth.example.com/authorize");
    }

    #[test]
    fn should_reject_invalid_auth_url() {
        let err = client_credentials_builder("secret")
            .with_auth_url("not a url")
            .expect_err("Should reject auth URL");

        match err {
            OAuth2Error::ConfigurationError { reason } => {
                assert!(reason.starts_with("Invalid authorization URL"));
            }
            other => panic!("Expected ConfigurationError, got {other:?}"),
        }
    }

    #[test]
    fn should_add_scopes() {
        let config = client_credentials_builder("secret")
            .add_scope("read:users")
            .add_scope("write:users")
            .build()
            .expect("Should build config");

        assert_eq!(config.scopes, vec!["read:users", "write:users"]);
    }

    #[test]
    fn should_add_multiple_scopes() {
        let config = client_credentials_builder("secret")
            .add_scopes(["scope1", "scope2", "scope3"])
            .build()
            .expect("Should build config");

        assert_eq!(config.scopes, vec!["scope1", "scope2", "scope3"]);
    }

    #[test]
    fn should_set_refresh_threshold() {
        let config = client_credentials_builder("secret")
            .with_refresh_threshold(Duration::from_secs(120))
            .build()
            .expect("Should build config");

        assert_eq!(config.refresh_threshold, Duration::from_secs(120));
    }

    #[test]
    fn should_redact_debug_output() {
        let config = client_credentials_builder("super-secret")
            .build()
            .expect("Should build config");

        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("[REDACTED]"));
        assert!(!debug_str.contains("super-secret"));
    }

    #[tokio::test]
    async fn should_cache_pre_acquired_token() {
        let config = OAuth2Config::pre_acquired("client-id", TOKEN_URL, "my-token")
            .expect("Should create builder")
            .build()
            .expect("Should build config");

        let token = config.get_token().await.expect("Should have cached token");
        assert_eq!(token.access_token(), "my-token");
    }

    #[tokio::test]
    async fn should_return_token_after_set() {
        let config = OAuth2Config::pre_acquired("client-id", TOKEN_URL, "old-token")
            .expect("Should create builder")
            .build()
            .expect("Should build config");

        config.set_token(OAuth2Token::new("new-token")).await;

        let token = config.get_token().await.expect("Should have cached token");
        assert_eq!(token.access_token(), "new-token");
    }

    #[tokio::test]
    async fn should_need_token_when_cache_empty() {
        let config = client_credentials_builder("secret")
            .build()
            .expect("Should build config");

        assert!(config.needs_token().await);
    }
}