Skip to main content

cirrus_auth/
client_credentials.rs

1//! OAuth 2.0 Client Credentials grant for server-to-server integrations.
2//!
3//! The client app trades its `consumer_key`/`consumer_secret` for an access
4//! token tied to a pre-configured integration user on the External Client
5//! App / Connected App. Per RFC 6749 §4.4 this grant is for confidential
6//! clients only — there is no public-client variant — so `consumer_secret`
7//! is mandatory.
8//!
9//! ## Salesforce-specific configuration
10//!
11//! Beyond the standard OAuth wire shape, Salesforce requires the connected
12//! app's admin to designate a "Run As" user. That happens entirely on the
13//! org side; the SDK has nothing to configure for it. If the connected app
14//! is not set up with a run-as user, the token endpoint returns
15//! `invalid_client` or `invalid_grant`, which surface as
16//! [`AuthError::OAuth`].
17//!
18//! ## My Domain URL is mandatory
19//!
20//! Per the Salesforce help docs ("OAuth 2.0 Client Credentials Flow for
21//! Server-to-Server Integration"): *"For this flow, requests to
22//! `https://login.salesforce.com` and `https://test.salesforce.com` aren't
23//! supported. Use your My Domain URL instead."* The builder therefore has
24//! no `PRODUCTION_LOGIN_URL`/`SANDBOX_LOGIN_URL` defaults — `login_url` is
25//! required and must be the org's My Domain (e.g.
26//! `https://my-org.my.salesforce.com`).
27//!
28//! ## No refresh token
29//!
30//! Per RFC 6749 §4.4.3, the Client Credentials grant does not issue a
31//! refresh token. Token rotation is handled by re-running the grant when
32//! the local TTL elapses; semantics match [`crate::jwt::JwtAuth`].
33
34use crate::AuthSession;
35use crate::error::{AuthError, AuthResult};
36use crate::token_endpoint::{check_instance_url, exchange};
37use async_trait::async_trait;
38use std::borrow::Cow;
39use std::time::{Duration, Instant};
40use tokio::sync::RwLock;
41
/// How long a freshly minted access token is served from the local cache
/// before the grant is re-run: 30 minutes. Overridable per session via
/// [`ClientCredentialsAuthBuilder::token_ttl`].
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(60 * 30);
44
/// An access token plus the local deadline after which it is treated as
/// stale and a new one is minted.
#[derive(Clone)]
struct CachedToken {
    /// Bearer token returned by the token endpoint. This is a live credential.
    access_token: String,
    /// Monotonic deadline; the token is served from cache only while
    /// `expires_at > Instant::now()`.
    expires_at: Instant,
}

impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never render the bearer token itself: a stray `{:?}` in a log or
        // panic message would otherwise leak a usable credential. This matches
        // how the session and builder `Debug` impls omit credentials.
        f.debug_struct("CachedToken")
            .field("access_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
50
51/// Client-credentials-grant auth session.
52///
53/// Construct via [`ClientCredentialsAuth::builder`].
54pub struct ClientCredentialsAuth {
55    consumer_key: String,
56    consumer_secret: String,
57    login_url: String,
58    instance_url: String,
59    token_ttl: Duration,
60    http: reqwest::Client,
61    cached: RwLock<Option<CachedToken>>,
62}
63
64impl std::fmt::Debug for ClientCredentialsAuth {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        // Omit consumer_key and consumer_secret — both are credentials.
67        f.debug_struct("ClientCredentialsAuth")
68            .field("login_url", &self.login_url)
69            .field("instance_url", &self.instance_url)
70            .field("token_ttl", &self.token_ttl)
71            .finish_non_exhaustive()
72    }
73}
74
75impl ClientCredentialsAuth {
76    /// Begins constructing a [`ClientCredentialsAuth`].
77    ///
78    /// Client-credentials grant (RFC 6749 §4.4): server-to-server flow
79    /// where the connected app's consumer key + secret are exchanged
80    /// directly for an access token, no user context. The connected
81    /// app's "Run As" user determines record-level visibility.
82    ///
83    /// # Example
84    ///
85    /// ```no_run
86    /// use cirrus_auth::ClientCredentialsAuth;
87    /// use std::sync::Arc;
88    ///
89    /// # fn example() -> Result<(), cirrus_auth::AuthError> {
90    /// let auth = ClientCredentialsAuth::builder()
91    ///     .consumer_key("3MVG9...")
92    ///     .consumer_secret("28A2...")
93    ///     .login_url("https://my-org.my.salesforce.com")
94    ///     .instance_url("https://my-org.my.salesforce.com")
95    ///     .build()?;
96    /// // Wrap as Arc<dyn AuthSession> and hand to a Cirrus client.
97    /// let _shared = Arc::new(auth);
98    /// # Ok(())
99    /// # }
100    /// ```
101    pub fn builder() -> ClientCredentialsAuthBuilder {
102        ClientCredentialsAuthBuilder::default()
103    }
104
105    async fn mint_token(&self) -> AuthResult<CachedToken> {
106        tracing::info!(
107            target: "cirrus::auth",
108            flow = "client-credentials",
109            login_url = %self.login_url,
110            "minting fresh access token",
111        );
112        let body = [
113            ("grant_type", "client_credentials"),
114            ("client_id", self.consumer_key.as_str()),
115            ("client_secret", self.consumer_secret.as_str()),
116        ];
117
118        let token = exchange(&self.http, &self.login_url, &body).await?;
119        check_instance_url(&self.instance_url, &token)?;
120
121        Ok(CachedToken {
122            access_token: token.access_token,
123            expires_at: Instant::now() + self.token_ttl,
124        })
125    }
126}
127
128#[async_trait]
129impl AuthSession for ClientCredentialsAuth {
130    async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
131        // Fast path — read lock, return clone of cached token if still valid.
132        {
133            let guard = self.cached.read().await;
134            if let Some(cached) = guard.as_ref()
135                && cached.expires_at > Instant::now()
136            {
137                return Ok(Cow::Owned(cached.access_token.clone()));
138            }
139        }
140
141        // Slow path — write lock, double-check, mint.
142        let mut guard = self.cached.write().await;
143        if let Some(cached) = guard.as_ref()
144            && cached.expires_at > Instant::now()
145        {
146            return Ok(Cow::Owned(cached.access_token.clone()));
147        }
148        let new_token = self.mint_token().await?;
149        let token_str = new_token.access_token.clone();
150        *guard = Some(new_token);
151        Ok(Cow::Owned(token_str))
152    }
153
154    fn instance_url(&self) -> &str {
155        &self.instance_url
156    }
157
158    async fn invalidate(&self, stale_token: &str) {
159        // Compare-and-swap: only clear the cached token if it still
160        // matches what the failing request used. Avoids racing with a
161        // concurrent task that already refreshed.
162        let mut guard = self.cached.write().await;
163        if let Some(cached) = guard.as_ref()
164            && cached.access_token == stale_token
165        {
166            tracing::debug!(
167                target: "cirrus::auth",
168                flow = "client-credentials",
169                "invalidating cached token (CAS matched)",
170            );
171            *guard = None;
172        } else {
173            tracing::trace!(
174                target: "cirrus::auth",
175                flow = "client-credentials",
176                "invalidate called but cached token differs (concurrent refresh?); no-op",
177            );
178        }
179    }
180}
181
182/// Builder for [`ClientCredentialsAuth`].
183#[derive(Default)]
184pub struct ClientCredentialsAuthBuilder {
185    consumer_key: Option<String>,
186    consumer_secret: Option<String>,
187    login_url: Option<String>,
188    instance_url: Option<String>,
189    token_ttl: Option<Duration>,
190    http_client: Option<reqwest::Client>,
191}
192
193impl std::fmt::Debug for ClientCredentialsAuthBuilder {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        f.debug_struct("ClientCredentialsAuthBuilder")
196            .field("consumer_key", &self.consumer_key.is_some())
197            .field("consumer_secret", &self.consumer_secret.is_some())
198            .field("login_url", &self.login_url)
199            .field("instance_url", &self.instance_url)
200            .field("token_ttl", &self.token_ttl)
201            .finish_non_exhaustive()
202    }
203}
204
205impl ClientCredentialsAuthBuilder {
206    /// Connected App's Consumer Key (Client ID). Required.
207    pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
208        self.consumer_key = Some(key.into());
209        self
210    }
211
212    /// Connected App's Consumer Secret (Client Secret). Required —
213    /// Client Credentials is a confidential-client-only grant.
214    pub fn consumer_secret(mut self, secret: impl Into<String>) -> Self {
215        self.consumer_secret = Some(secret.into());
216        self
217    }
218
219    /// Login URL — the host serving `/services/oauth2/token`. Required;
220    /// must be the org's My Domain URL (e.g.
221    /// `https://my-org.my.salesforce.com`). Salesforce explicitly rejects
222    /// this flow at `https://login.salesforce.com` and
223    /// `https://test.salesforce.com`.
224    pub fn login_url(mut self, url: impl Into<String>) -> Self {
225        self.login_url = Some(url.into());
226        self
227    }
228
229    /// REST instance URL — the org's My Domain. Required. Must match the
230    /// `instance_url` returned by the token-exchange response.
231    pub fn instance_url(mut self, url: impl Into<String>) -> Self {
232        self.instance_url = Some(url.into());
233        self
234    }
235
236    /// How long to cache an access token before re-minting. Defaults to 30
237    /// minutes.
238    pub fn token_ttl(mut self, ttl: Duration) -> Self {
239        self.token_ttl = Some(ttl);
240        self
241    }
242
243    /// Supplies a pre-configured `reqwest::Client`. Useful for sharing a
244    /// connection pool.
245    pub fn http_client(mut self, client: reqwest::Client) -> Self {
246        self.http_client = Some(client);
247        self
248    }
249
250    /// Finalizes the builder.
251    pub fn build(self) -> AuthResult<ClientCredentialsAuth> {
252        let consumer_key = self
253            .consumer_key
254            .ok_or(AuthError::MissingField("consumer_key"))?;
255        let consumer_secret = self
256            .consumer_secret
257            .ok_or(AuthError::MissingField("consumer_secret"))?;
258        let mut instance_url = self
259            .instance_url
260            .ok_or(AuthError::MissingField("instance_url"))?;
261        if instance_url.ends_with('/') {
262            instance_url.pop();
263        }
264        let mut login_url = self.login_url.ok_or(AuthError::MissingField("login_url"))?;
265        if login_url.ends_with('/') {
266            login_url.pop();
267        }
268        let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
269        let http = self.http_client.unwrap_or_default();
270
271        Ok(ClientCredentialsAuth {
272            consumer_key,
273            consumer_secret,
274            login_url,
275            instance_url,
276            token_ttl,
277            http,
278            cached: RwLock::new(None),
279        })
280    }
281}
282
#[cfg(test)]
// Tests intentionally unwrap/expect/panic: any such failure is a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    fn builder_with_required_fields() -> ClientCredentialsAuthBuilder {
        ClientCredentialsAuth::builder()
            .consumer_key("consumer-key-123")
            .consumer_secret("top-secret")
            .instance_url("https://my-org.my.salesforce.com")
            .login_url("https://my-org.my.salesforce.com")
    }

    /// Starts a mock token endpoint that always succeeds with
    /// `access_token` for the default test org, counting each hit.
    async fn counting_token_server(access_token: &str, hits: Arc<AtomicUsize>) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits,
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": access_token,
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;
        server
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = ClientCredentialsAuth::builder()
            .consumer_secret("s")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_consumer_secret() {
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_secret")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .consumer_secret("s")
            .login_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    #[test]
    fn builder_requires_login_url() {
        // Salesforce rejects Client Credentials at login.salesforce.com /
        // test.salesforce.com — there's no safe default, so the builder
        // must demand a My Domain URL up front.
        let err = ClientCredentialsAuth::builder()
            .consumer_key("k")
            .consumer_secret("s")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("login_url")));
    }

    #[test]
    fn builder_strips_trailing_slashes_on_login_and_instance_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .login_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, "https://my-org.my.salesforce.com");
    }

    #[tokio::test]
    async fn mint_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=client_credentials"))
            .and(body_string_contains("client_id=consumer-key-123"))
            .and(body_string_contains("client_secret=top-secret"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "00DXX!ACCESS",
                    "instance_url": "https://my-org.my.salesforce.com",
                    "token_type": "Bearer",
                    "id": "https://login.salesforce.com/id/00DXX/005XX",
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn expired_cache_remints_token() {
        let hits = Arc::new(AtomicUsize::new(0));
        let server = counting_token_server("tok", hits.clone()).await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn invalidate_with_current_token_forces_remint() {
        let hits = Arc::new(AtomicUsize::new(0));
        let server = counting_token_server("tok", hits.clone()).await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let first = auth.access_token().await.unwrap().into_owned();
        // CAS matches the cached token, so the cache is cleared...
        auth.invalidate(&first).await;
        // ...and the next call has to go back to the token endpoint.
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn invalidate_with_stale_token_is_a_noop() {
        let hits = Arc::new(AtomicUsize::new(0));
        let server = counting_token_server("tok", hits.clone()).await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        // A token that is no longer cached (e.g. one a concurrent refresh
        // already replaced) must not evict the current one.
        auth.invalidate("some-older-token").await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn invalid_client_surfaces_oauth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_client",
                "error_description": "client identifier invalid"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_client");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://wrong-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// Counts invocations and returns a fixed response. Same shape as the
    /// JWT/Refresh tests' helpers; duplicated to keep test modules
    /// self-contained.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }
}