Skip to main content

cirrus_auth/
client_credentials.rs

1//! OAuth 2.0 Client Credentials grant for server-to-server integrations.
2//!
3//! The client app trades its `consumer_key`/`consumer_secret` for an access
4//! token tied to a pre-configured integration user on the External Client
5//! App / Connected App. Per RFC 6749 §4.4 this grant is for confidential
6//! clients only — there is no public-client variant — so `consumer_secret`
7//! is mandatory.
8//!
9//! ## Salesforce-specific configuration
10//!
11//! Beyond the standard OAuth wire shape, Salesforce requires the connected
12//! app's admin to designate a "Run As" user. That happens entirely on the
13//! org side; the SDK has nothing to configure for it. If the connected app
14//! is not set up with a run-as user, the token endpoint returns
15//! `invalid_client` or `invalid_grant`, which surface as
16//! [`AuthError::OAuth`].
17//!
18//! ## My Domain URL is mandatory
19//!
20//! Per the Salesforce help docs ("OAuth 2.0 Client Credentials Flow for
21//! Server-to-Server Integration"): *"For this flow, requests to
22//! `https://login.salesforce.com` and `https://test.salesforce.com` aren't
23//! supported. Use your My Domain URL instead."* The builder therefore has
24//! no `PRODUCTION_LOGIN_URL`/`SANDBOX_LOGIN_URL` defaults — `login_url` is
25//! required and must be the org's My Domain (e.g.
26//! `https://my-org.my.salesforce.com`).
27//!
28//! ## No refresh token
29//!
30//! Per RFC 6749 §4.4.3, the Client Credentials grant does not issue a
31//! refresh token. Token rotation is handled by re-running the grant when
32//! the local TTL elapses; semantics match [`crate::jwt::JwtAuth`].
33
34use crate::AuthSession;
35use crate::error::{AuthError, AuthResult};
36use crate::token_endpoint::{check_instance_url, exchange};
37use async_trait::async_trait;
38use std::borrow::Cow;
39use std::time::{Duration, Instant};
40use tokio::sync::RwLock;
41
/// Local cache lifetime given to a freshly minted access token when the
/// builder is not handed an explicit `token_ttl`.
///
/// Thirty minutes, expressed in seconds.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);
44
/// An access token held in the session cache, paired with the local
/// deadline after which it is considered stale.
#[derive(Clone)]
struct CachedToken {
    /// Bearer token returned by the token endpoint.
    access_token: String,
    /// Monotonic deadline; the token is reused only while
    /// `Instant::now()` is strictly before this point.
    expires_at: Instant,
}

// Hand-written so the bearer token never ends up in logs or panic output.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";
        let mut builder = out.debug_struct("CachedToken");
        builder.field("access_token", &REDACTED);
        builder.field("expires_at", &self.expires_at);
        builder.finish()
    }
}
59
60/// Client-credentials-grant auth session.
61///
62/// Construct via [`ClientCredentialsAuth::builder`].
63pub struct ClientCredentialsAuth {
64    consumer_key: String,
65    consumer_secret: String,
66    login_url: String,
67    instance_url: String,
68    token_ttl: Duration,
69    http: reqwest::Client,
70    cached: RwLock<Option<CachedToken>>,
71}
72
73impl std::fmt::Debug for ClientCredentialsAuth {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        // Omit consumer_key and consumer_secret — both are credentials.
76        f.debug_struct("ClientCredentialsAuth")
77            .field("login_url", &self.login_url)
78            .field("instance_url", &self.instance_url)
79            .field("token_ttl", &self.token_ttl)
80            .finish_non_exhaustive()
81    }
82}
83
84impl ClientCredentialsAuth {
85    /// Begins constructing a [`ClientCredentialsAuth`].
86    ///
87    /// Client-credentials grant (RFC 6749 §4.4): server-to-server flow
88    /// where the connected app's consumer key + secret are exchanged
89    /// directly for an access token, no user context. The connected
90    /// app's "Run As" user determines record-level visibility.
91    ///
92    /// # Example
93    ///
94    /// ```no_run
95    /// use cirrus_auth::ClientCredentialsAuth;
96    /// use std::sync::Arc;
97    ///
98    /// # fn example() -> Result<(), cirrus_auth::AuthError> {
99    /// let auth = ClientCredentialsAuth::builder()
100    ///     .consumer_key("3MVG9...")
101    ///     .consumer_secret("28A2...")
102    ///     .login_url("https://my-org.my.salesforce.com")
103    ///     .instance_url("https://my-org.my.salesforce.com")
104    ///     .build()?;
105    /// // Wrap as Arc<dyn AuthSession> and hand to a Cirrus client.
106    /// let _shared = Arc::new(auth);
107    /// # Ok(())
108    /// # }
109    /// ```
110    pub fn builder() -> ClientCredentialsAuthBuilder {
111        ClientCredentialsAuthBuilder::default()
112    }
113
114    async fn mint_token(&self) -> AuthResult<CachedToken> {
115        tracing::info!(
116            target: "cirrus::auth",
117            flow = "client-credentials",
118            login_url = %self.login_url,
119            "minting fresh access token",
120        );
121        let body = [
122            ("grant_type", "client_credentials"),
123            ("client_id", self.consumer_key.as_str()),
124            ("client_secret", self.consumer_secret.as_str()),
125        ];
126
127        let token = exchange(&self.http, &self.login_url, &body).await?;
128        check_instance_url(&self.instance_url, &token)?;
129
130        Ok(CachedToken {
131            access_token: token.access_token,
132            expires_at: Instant::now() + self.token_ttl,
133        })
134    }
135}
136
137#[async_trait]
138impl AuthSession for ClientCredentialsAuth {
139    async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
140        // Fast path — read lock, return clone of cached token if still valid.
141        {
142            let guard = self.cached.read().await;
143            if let Some(cached) = guard.as_ref()
144                && cached.expires_at > Instant::now()
145            {
146                return Ok(Cow::Owned(cached.access_token.clone()));
147            }
148        }
149
150        // Slow path — write lock, double-check, mint.
151        let mut guard = self.cached.write().await;
152        if let Some(cached) = guard.as_ref()
153            && cached.expires_at > Instant::now()
154        {
155            return Ok(Cow::Owned(cached.access_token.clone()));
156        }
157        let new_token = self.mint_token().await?;
158        let token_str = new_token.access_token.clone();
159        *guard = Some(new_token);
160        Ok(Cow::Owned(token_str))
161    }
162
163    fn instance_url(&self) -> &str {
164        &self.instance_url
165    }
166
167    async fn invalidate(&self, stale_token: &str) {
168        // Compare-and-swap: only clear the cached token if it still
169        // matches what the failing request used. Avoids racing with a
170        // concurrent task that already refreshed.
171        let mut guard = self.cached.write().await;
172        if let Some(cached) = guard.as_ref()
173            && cached.access_token == stale_token
174        {
175            tracing::debug!(
176                target: "cirrus::auth",
177                flow = "client-credentials",
178                "invalidating cached token (CAS matched)",
179            );
180            *guard = None;
181        } else {
182            tracing::trace!(
183                target: "cirrus::auth",
184                flow = "client-credentials",
185                "invalidate called but cached token differs (concurrent refresh?); no-op",
186            );
187        }
188    }
189}
190
191/// Builder for [`ClientCredentialsAuth`].
192#[derive(Default)]
193pub struct ClientCredentialsAuthBuilder {
194    consumer_key: Option<String>,
195    consumer_secret: Option<String>,
196    login_url: Option<String>,
197    instance_url: Option<String>,
198    token_ttl: Option<Duration>,
199    http_client: Option<reqwest::Client>,
200}
201
202impl std::fmt::Debug for ClientCredentialsAuthBuilder {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.debug_struct("ClientCredentialsAuthBuilder")
205            .field("consumer_key", &self.consumer_key.is_some())
206            .field("consumer_secret", &self.consumer_secret.is_some())
207            .field("login_url", &self.login_url)
208            .field("instance_url", &self.instance_url)
209            .field("token_ttl", &self.token_ttl)
210            .finish_non_exhaustive()
211    }
212}
213
214impl ClientCredentialsAuthBuilder {
215    /// Connected App's Consumer Key (Client ID). Required.
216    pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
217        self.consumer_key = Some(key.into());
218        self
219    }
220
221    /// Connected App's Consumer Secret (Client Secret). Required —
222    /// Client Credentials is a confidential-client-only grant.
223    pub fn consumer_secret(mut self, secret: impl Into<String>) -> Self {
224        self.consumer_secret = Some(secret.into());
225        self
226    }
227
228    /// Login URL — the host serving `/services/oauth2/token`. Required;
229    /// must be the org's My Domain URL (e.g.
230    /// `https://my-org.my.salesforce.com`). Salesforce explicitly rejects
231    /// this flow at `https://login.salesforce.com` and
232    /// `https://test.salesforce.com`.
233    pub fn login_url(mut self, url: impl Into<String>) -> Self {
234        self.login_url = Some(url.into());
235        self
236    }
237
238    /// REST instance URL — the org's My Domain. Required. Must match the
239    /// `instance_url` returned by the token-exchange response.
240    pub fn instance_url(mut self, url: impl Into<String>) -> Self {
241        self.instance_url = Some(url.into());
242        self
243    }
244
245    /// How long to cache an access token before re-minting. Defaults to 30
246    /// minutes.
247    pub fn token_ttl(mut self, ttl: Duration) -> Self {
248        self.token_ttl = Some(ttl);
249        self
250    }
251
252    /// Supplies a pre-configured `reqwest::Client`. Useful for sharing a
253    /// connection pool.
254    pub fn http_client(mut self, client: reqwest::Client) -> Self {
255        self.http_client = Some(client);
256        self
257    }
258
259    /// Finalizes the builder.
260    pub fn build(self) -> AuthResult<ClientCredentialsAuth> {
261        let consumer_key = self
262            .consumer_key
263            .ok_or(AuthError::MissingField("consumer_key"))?;
264        let consumer_secret = self
265            .consumer_secret
266            .ok_or(AuthError::MissingField("consumer_secret"))?;
267        let mut instance_url = self
268            .instance_url
269            .ok_or(AuthError::MissingField("instance_url"))?;
270        if instance_url.ends_with('/') {
271            instance_url.pop();
272        }
273        let mut login_url = self.login_url.ok_or(AuthError::MissingField("login_url"))?;
274        if login_url.ends_with('/') {
275            login_url.pop();
276        }
277        let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
278        let http = self.http_client.unwrap_or_default();
279
280        Ok(ClientCredentialsAuth {
281            consumer_key,
282            consumer_secret,
283            login_url,
284            instance_url,
285            token_ttl,
286            http,
287            cached: RwLock::new(None),
288        })
289    }
290}
291
#[cfg(test)]
// Tests run against fixed fixtures; panicking on an unexpected `Err`/`None`
// is exactly the failure mode wanted here.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    fn builder_with_required_fields() -> ClientCredentialsAuthBuilder {
        ClientCredentialsAuth::builder()
            .consumer_key("consumer-key-123")
            .consumer_secret("top-secret")
            .instance_url("https://my-org.my.salesforce.com")
            .login_url("https://my-org.my.salesforce.com")
    }

    /// Builder with every required field set *except* `missing`.
    ///
    /// Each "requires X" test thus isolates exactly one absent field,
    /// instead of relying on the order in which `build()` checks them.
    fn builder_missing(missing: &str) -> ClientCredentialsAuthBuilder {
        let mut builder = ClientCredentialsAuth::builder();
        if missing != "consumer_key" {
            builder = builder.consumer_key("k");
        }
        if missing != "consumer_secret" {
            builder = builder.consumer_secret("s");
        }
        if missing != "instance_url" {
            builder = builder.instance_url("https://x");
        }
        if missing != "login_url" {
            builder = builder.login_url("https://x");
        }
        builder
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = builder_missing("consumer_key").build().unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_consumer_secret() {
        let err = builder_missing("consumer_secret").build().unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_secret")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = builder_missing("instance_url").build().unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    #[test]
    fn builder_requires_login_url() {
        // Salesforce rejects Client Credentials at login.salesforce.com /
        // test.salesforce.com — there's no safe default, so the builder
        // must demand a My Domain URL up front.
        let err = builder_missing("login_url").build().unwrap_err();
        assert!(matches!(err, AuthError::MissingField("login_url")));
    }

    #[test]
    fn builder_strips_trailing_slashes_on_login_and_instance_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .login_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, "https://my-org.my.salesforce.com");
    }

    #[test]
    fn debug_output_never_contains_credentials() {
        let builder = builder_with_required_fields();
        let rendered = format!("{builder:?}");
        assert!(!rendered.contains("consumer-key-123"));
        assert!(!rendered.contains("top-secret"));

        let auth = builder.build().unwrap();
        let rendered = format!("{auth:?}");
        assert!(!rendered.contains("consumer-key-123"));
        assert!(!rendered.contains("top-secret"));
    }

    #[tokio::test]
    async fn mint_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=client_credentials"))
            .and(body_string_contains("client_id=consumer-key-123"))
            .and(body_string_contains("client_secret=top-secret"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "00DXX!ACCESS",
                    "instance_url": "https://my-org.my.salesforce.com",
                    "token_type": "Bearer",
                    "id": "https://login.salesforce.com/id/00DXX/005XX",
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn invalidate_only_clears_matching_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let current = auth.access_token().await.unwrap().into_owned();
        assert_eq!(hits.load(Ordering::SeqCst), 1);

        // A token that no longer matches the cache must not evict it.
        auth.invalidate("some-older-token").await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 1);

        // Invalidating the token that is actually cached forces a re-mint.
        auth.invalidate(&current).await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn invalid_client_surfaces_oauth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_client",
                "error_description": "client identifier invalid"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_client");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://wrong-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// Counts invocations and returns a fixed response. Same shape as the
    /// JWT/Refresh tests' helpers; duplicated to keep test modules
    /// self-contained.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }
}