1use crate::AuthSession;
35use crate::error::{AuthError, AuthResult};
36use crate::token_endpoint::{check_instance_url, exchange};
37use async_trait::async_trait;
38use std::borrow::Cow;
39use std::time::{Duration, Instant};
40use tokio::sync::RwLock;
41
/// Token lifetime used when the builder is not given an explicit `token_ttl`
/// (30 minutes).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);
44
/// An access token paired with the instant after which it is no longer served
/// from the cache.
#[derive(Clone)]
struct CachedToken {
    /// Access token string obtained from the token exchange.
    access_token: String,
    /// Cache deadline: the token is reused only while `expires_at > Instant::now()`.
    expires_at: Instant,
}

/// Hand-written `Debug` so that `{:?}` formatting never prints the token value;
/// a `"[redacted]"` placeholder is shown in its place.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CachedToken");
        out.field("access_token", &"[redacted]");
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}
59
60pub struct ClientCredentialsAuth {
64 consumer_key: String,
65 consumer_secret: String,
66 login_url: String,
67 instance_url: String,
68 token_ttl: Duration,
69 http: reqwest::Client,
70 cached: RwLock<Option<CachedToken>>,
71}
72
73impl std::fmt::Debug for ClientCredentialsAuth {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct("ClientCredentialsAuth")
77 .field("login_url", &self.login_url)
78 .field("instance_url", &self.instance_url)
79 .field("token_ttl", &self.token_ttl)
80 .finish_non_exhaustive()
81 }
82}
83
84impl ClientCredentialsAuth {
85 pub fn builder() -> ClientCredentialsAuthBuilder {
111 ClientCredentialsAuthBuilder::default()
112 }
113
114 async fn mint_token(&self) -> AuthResult<CachedToken> {
115 tracing::info!(
116 target: "cirrus::auth",
117 flow = "client-credentials",
118 login_url = %self.login_url,
119 "minting fresh access token",
120 );
121 let body = [
122 ("grant_type", "client_credentials"),
123 ("client_id", self.consumer_key.as_str()),
124 ("client_secret", self.consumer_secret.as_str()),
125 ];
126
127 let token = exchange(&self.http, &self.login_url, &body).await?;
128 check_instance_url(&self.instance_url, &token)?;
129
130 Ok(CachedToken {
131 access_token: token.access_token,
132 expires_at: Instant::now() + self.token_ttl,
133 })
134 }
135}
136
137#[async_trait]
138impl AuthSession for ClientCredentialsAuth {
139 async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
140 {
142 let guard = self.cached.read().await;
143 if let Some(cached) = guard.as_ref()
144 && cached.expires_at > Instant::now()
145 {
146 return Ok(Cow::Owned(cached.access_token.clone()));
147 }
148 }
149
150 let mut guard = self.cached.write().await;
152 if let Some(cached) = guard.as_ref()
153 && cached.expires_at > Instant::now()
154 {
155 return Ok(Cow::Owned(cached.access_token.clone()));
156 }
157 let new_token = self.mint_token().await?;
158 let token_str = new_token.access_token.clone();
159 *guard = Some(new_token);
160 Ok(Cow::Owned(token_str))
161 }
162
163 fn instance_url(&self) -> &str {
164 &self.instance_url
165 }
166
167 async fn invalidate(&self, stale_token: &str) {
168 let mut guard = self.cached.write().await;
172 if let Some(cached) = guard.as_ref()
173 && cached.access_token == stale_token
174 {
175 tracing::debug!(
176 target: "cirrus::auth",
177 flow = "client-credentials",
178 "invalidating cached token (CAS matched)",
179 );
180 *guard = None;
181 } else {
182 tracing::trace!(
183 target: "cirrus::auth",
184 flow = "client-credentials",
185 "invalidate called but cached token differs (concurrent refresh?); no-op",
186 );
187 }
188 }
189}
190
191#[derive(Default)]
193pub struct ClientCredentialsAuthBuilder {
194 consumer_key: Option<String>,
195 consumer_secret: Option<String>,
196 login_url: Option<String>,
197 instance_url: Option<String>,
198 token_ttl: Option<Duration>,
199 http_client: Option<reqwest::Client>,
200}
201
202impl std::fmt::Debug for ClientCredentialsAuthBuilder {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("ClientCredentialsAuthBuilder")
205 .field("consumer_key", &self.consumer_key.is_some())
206 .field("consumer_secret", &self.consumer_secret.is_some())
207 .field("login_url", &self.login_url)
208 .field("instance_url", &self.instance_url)
209 .field("token_ttl", &self.token_ttl)
210 .finish_non_exhaustive()
211 }
212}
213
214impl ClientCredentialsAuthBuilder {
215 pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
217 self.consumer_key = Some(key.into());
218 self
219 }
220
221 pub fn consumer_secret(mut self, secret: impl Into<String>) -> Self {
224 self.consumer_secret = Some(secret.into());
225 self
226 }
227
228 pub fn login_url(mut self, url: impl Into<String>) -> Self {
234 self.login_url = Some(url.into());
235 self
236 }
237
238 pub fn instance_url(mut self, url: impl Into<String>) -> Self {
241 self.instance_url = Some(url.into());
242 self
243 }
244
245 pub fn token_ttl(mut self, ttl: Duration) -> Self {
248 self.token_ttl = Some(ttl);
249 self
250 }
251
252 pub fn http_client(mut self, client: reqwest::Client) -> Self {
255 self.http_client = Some(client);
256 self
257 }
258
259 pub fn build(self) -> AuthResult<ClientCredentialsAuth> {
261 let consumer_key = self
262 .consumer_key
263 .ok_or(AuthError::MissingField("consumer_key"))?;
264 let consumer_secret = self
265 .consumer_secret
266 .ok_or(AuthError::MissingField("consumer_secret"))?;
267 let mut instance_url = self
268 .instance_url
269 .ok_or(AuthError::MissingField("instance_url"))?;
270 if instance_url.ends_with('/') {
271 instance_url.pop();
272 }
273 let mut login_url = self.login_url.ok_or(AuthError::MissingField("login_url"))?;
274 if login_url.ends_with('/') {
275 login_url.pop();
276 }
277 let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
278 let http = self.http_client.unwrap_or_default();
279
280 Ok(ClientCredentialsAuth {
281 consumer_key,
282 consumer_secret,
283 login_url,
284 instance_url,
285 token_ttl,
286 http,
287 cached: RwLock::new(None),
288 })
289 }
290}
291
#[cfg(test)]
// Tests may unwrap/expect/panic freely: a failure here is a test failure, not a
// production crash.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    /// Builder with every required field set. Network tests override
    /// `login_url` with the mock server's URI.
    fn builder_with_required_fields() -> ClientCredentialsAuthBuilder {
        ClientCredentialsAuth::builder()
            .consumer_key("consumer-key-123")
            .consumer_secret("top-secret")
            .instance_url("https://my-org.my.salesforce.com")
            .login_url("https://my-org.my.salesforce.com")
    }

    // `login_url` is also missing here; the assertion pins that `consumer_key`
    // is the first field `build` checks.
    #[test]
    fn builder_requires_consumer_key() {
        let err = ClientCredentialsAuth::builder()
            .consumer_secret("s")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    // `login_url` is also missing; `consumer_secret` is checked before it.
    #[test]
    fn builder_requires_consumer_secret() {
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_secret")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .consumer_secret("s")
            .login_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    #[test]
    fn builder_requires_login_url() {
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .consumer_secret("s")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("login_url")));
    }

    #[test]
    fn builder_strips_trailing_slashes_on_login_and_instance_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .login_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, "https://my-org.my.salesforce.com");
    }

    // The mock only matches a POST carrying the client-credentials form fields.
    // Two `access_token` calls must hit the token endpoint exactly once; the
    // second call is served from the cache.
    #[tokio::test]
    async fn mint_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=client_credentials"))
            .and(body_string_contains("client_id=consumer-key-123"))
            .and(body_string_contains("client_secret=top-secret"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "00DXX!ACCESS",
                    "instance_url": "https://my-org.my.salesforce.com",
                    "token_type": "Bearer",
                    "id": "https://login.salesforce.com/id/00DXX/005XX",
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    // With a zero TTL, `expires_at` equals the mint instant. The freshness check
    // is a strict `expires_at > Instant::now()`, so every call re-mints.
    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    // A 400 response with an OAuth error body must surface as `AuthError::OAuth`,
    // with the `error` code and description preserved.
    #[tokio::test]
    async fn invalid_client_surfaces_oauth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_client",
                "error_description": "client identifier invalid"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_client");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    // A successful exchange whose `instance_url` differs from the configured one
    // must be rejected. This asserts the rejection arrives as `AuthError::Other`
    // (the variant produced by `check_instance_url`).
    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://wrong-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// wiremock responder that replays a fixed response and counts how many
    /// requests it served. Tests use the count to see how many token mints
    /// reached the server.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }
}