1use crate::AuthSession;
35use crate::error::{AuthError, AuthResult};
36use crate::token_endpoint::{check_instance_url, exchange};
37use async_trait::async_trait;
38use std::borrow::Cow;
39use std::time::{Duration, Instant};
40use tokio::sync::RwLock;
41
/// Cache lifetime applied to minted tokens when the builder is not given an
/// explicit `token_ttl` (1 800 s = 30 minutes).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);
44
/// An access token together with the instant after which it must not be served
/// from cache any more.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Bearer token returned by the token endpoint.
    access_token: String,
    /// Cache deadline; the token is reused only while `expires_at > Instant::now()`.
    expires_at: Instant,
}
50
51pub struct ClientCredentialsAuth {
55 consumer_key: String,
56 consumer_secret: String,
57 login_url: String,
58 instance_url: String,
59 token_ttl: Duration,
60 http: reqwest::Client,
61 cached: RwLock<Option<CachedToken>>,
62}
63
64impl std::fmt::Debug for ClientCredentialsAuth {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("ClientCredentialsAuth")
68 .field("login_url", &self.login_url)
69 .field("instance_url", &self.instance_url)
70 .field("token_ttl", &self.token_ttl)
71 .finish_non_exhaustive()
72 }
73}
74
75impl ClientCredentialsAuth {
76 pub fn builder() -> ClientCredentialsAuthBuilder {
102 ClientCredentialsAuthBuilder::default()
103 }
104
105 async fn mint_token(&self) -> AuthResult<CachedToken> {
106 tracing::info!(
107 target: "cirrus::auth",
108 flow = "client-credentials",
109 login_url = %self.login_url,
110 "minting fresh access token",
111 );
112 let body = [
113 ("grant_type", "client_credentials"),
114 ("client_id", self.consumer_key.as_str()),
115 ("client_secret", self.consumer_secret.as_str()),
116 ];
117
118 let token = exchange(&self.http, &self.login_url, &body).await?;
119 check_instance_url(&self.instance_url, &token)?;
120
121 Ok(CachedToken {
122 access_token: token.access_token,
123 expires_at: Instant::now() + self.token_ttl,
124 })
125 }
126}
127
128#[async_trait]
129impl AuthSession for ClientCredentialsAuth {
130 async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
131 {
133 let guard = self.cached.read().await;
134 if let Some(cached) = guard.as_ref()
135 && cached.expires_at > Instant::now()
136 {
137 return Ok(Cow::Owned(cached.access_token.clone()));
138 }
139 }
140
141 let mut guard = self.cached.write().await;
143 if let Some(cached) = guard.as_ref()
144 && cached.expires_at > Instant::now()
145 {
146 return Ok(Cow::Owned(cached.access_token.clone()));
147 }
148 let new_token = self.mint_token().await?;
149 let token_str = new_token.access_token.clone();
150 *guard = Some(new_token);
151 Ok(Cow::Owned(token_str))
152 }
153
154 fn instance_url(&self) -> &str {
155 &self.instance_url
156 }
157
158 async fn invalidate(&self, stale_token: &str) {
159 let mut guard = self.cached.write().await;
163 if let Some(cached) = guard.as_ref()
164 && cached.access_token == stale_token
165 {
166 tracing::debug!(
167 target: "cirrus::auth",
168 flow = "client-credentials",
169 "invalidating cached token (CAS matched)",
170 );
171 *guard = None;
172 } else {
173 tracing::trace!(
174 target: "cirrus::auth",
175 flow = "client-credentials",
176 "invalidate called but cached token differs (concurrent refresh?); no-op",
177 );
178 }
179 }
180}
181
182#[derive(Default)]
184pub struct ClientCredentialsAuthBuilder {
185 consumer_key: Option<String>,
186 consumer_secret: Option<String>,
187 login_url: Option<String>,
188 instance_url: Option<String>,
189 token_ttl: Option<Duration>,
190 http_client: Option<reqwest::Client>,
191}
192
193impl std::fmt::Debug for ClientCredentialsAuthBuilder {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("ClientCredentialsAuthBuilder")
196 .field("consumer_key", &self.consumer_key.is_some())
197 .field("consumer_secret", &self.consumer_secret.is_some())
198 .field("login_url", &self.login_url)
199 .field("instance_url", &self.instance_url)
200 .field("token_ttl", &self.token_ttl)
201 .finish_non_exhaustive()
202 }
203}
204
205impl ClientCredentialsAuthBuilder {
206 pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
208 self.consumer_key = Some(key.into());
209 self
210 }
211
212 pub fn consumer_secret(mut self, secret: impl Into<String>) -> Self {
215 self.consumer_secret = Some(secret.into());
216 self
217 }
218
219 pub fn login_url(mut self, url: impl Into<String>) -> Self {
225 self.login_url = Some(url.into());
226 self
227 }
228
229 pub fn instance_url(mut self, url: impl Into<String>) -> Self {
232 self.instance_url = Some(url.into());
233 self
234 }
235
236 pub fn token_ttl(mut self, ttl: Duration) -> Self {
239 self.token_ttl = Some(ttl);
240 self
241 }
242
243 pub fn http_client(mut self, client: reqwest::Client) -> Self {
246 self.http_client = Some(client);
247 self
248 }
249
250 pub fn build(self) -> AuthResult<ClientCredentialsAuth> {
252 let consumer_key = self
253 .consumer_key
254 .ok_or(AuthError::MissingField("consumer_key"))?;
255 let consumer_secret = self
256 .consumer_secret
257 .ok_or(AuthError::MissingField("consumer_secret"))?;
258 let mut instance_url = self
259 .instance_url
260 .ok_or(AuthError::MissingField("instance_url"))?;
261 if instance_url.ends_with('/') {
262 instance_url.pop();
263 }
264 let mut login_url = self.login_url.ok_or(AuthError::MissingField("login_url"))?;
265 if login_url.ends_with('/') {
266 login_url.pop();
267 }
268 let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
269 let http = self.http_client.unwrap_or_default();
270
271 Ok(ClientCredentialsAuth {
272 consumer_key,
273 consumer_secret,
274 login_url,
275 instance_url,
276 token_ttl,
277 http,
278 cached: RwLock::new(None),
279 })
280 }
281}
282
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    fn builder_with_required_fields() -> ClientCredentialsAuthBuilder {
        ClientCredentialsAuth::builder()
            .consumer_key("consumer-key-123")
            .consumer_secret("top-secret")
            .instance_url("https://my-org.my.salesforce.com")
            .login_url("https://my-org.my.salesforce.com")
    }

    /// Mounts a token endpoint that always returns access token `tok` and
    /// counts how many times it was hit.
    async fn mount_counting_token_endpoint(server: &MockServer, hits: &Arc<AtomicUsize>) {
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(server)
            .await;
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = ClientCredentialsAuth::builder()
            .consumer_secret("s")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_consumer_secret() {
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_secret")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .consumer_secret("s")
            .login_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    #[test]
    fn builder_requires_login_url() {
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .consumer_secret("s")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("login_url")));
    }

    #[test]
    fn builder_strips_trailing_slashes_on_login_and_instance_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .login_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, "https://my-org.my.salesforce.com");
    }

    #[test]
    fn debug_output_does_not_leak_consumer_secret() {
        let auth = builder_with_required_fields().build().unwrap();
        assert!(!format!("{auth:?}").contains("top-secret"));
        assert!(!format!("{:?}", builder_with_required_fields()).contains("top-secret"));
    }

    #[tokio::test]
    async fn mint_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=client_credentials"))
            .and(body_string_contains("client_id=consumer-key-123"))
            .and(body_string_contains("client_secret=top-secret"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "00DXX!ACCESS",
                    "instance_url": "https://my-org.my.salesforce.com",
                    "token_type": "Bearer",
                    "id": "https://login.salesforce.com/id/00DXX/005XX",
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn invalidate_with_matching_token_forces_remint() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        mount_counting_token_endpoint(&server, &hits).await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let first = auth.access_token().await.unwrap().into_owned();
        auth.invalidate(&first).await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn invalidate_with_different_token_is_a_noop() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        mount_counting_token_endpoint(&server, &hits).await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        // A caller holding an older token must not evict the current one.
        auth.invalidate("some-older-token").await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn invalid_client_surfaces_oauth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_client",
                "error_description": "client identifier invalid"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_client");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://wrong-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// Wiremock responder that returns a fixed response and counts invocations.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }
}