Skip to main content

cirrus_auth/
jwt.rs

1//! OAuth 2.0 JWT Bearer flow for Salesforce server-to-server auth.
2//!
3//! The caller pre-authorizes a Connected App by uploading a public X.509
4//! certificate; this auth implementation holds the corresponding RSA private
5//! key and mints fresh access tokens on demand by signing a short-lived JWT
6//! and exchanging it at the OAuth token endpoint.
7//!
8//! ## `instance_url`
9//!
10//! `instance_url` is required at builder time and verified against the
11//! value returned in the token response.
12//!
13//! ## Caching
14//!
15//! Each successful token exchange caches the access token for a configurable
16//! TTL (default 30 minutes). Salesforce does not return an explicit expiry
17//! in the token response — the connected app's session policy controls
18//! actual expiration — so the TTL is a conservative caller-controlled knob,
19//! not a claim about the token's true lifetime. After the TTL elapses, the
20//! next call mints a new token regardless of whether the previous one would
21//! still have worked.
22
23use crate::AuthSession;
24use crate::error::{AuthError, AuthResult};
25use crate::token_endpoint::{check_instance_url, exchange};
26use async_trait::async_trait;
27use camino::Utf8PathBuf;
28use jsonwebtoken::{Algorithm, EncodingKey, Header};
29use serde::Serialize;
30use std::borrow::Cow;
31use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
32use tokio::sync::RwLock;
33
/// Production Salesforce login host.
///
/// This is the builder's default `login_url`: it is both where the JWT
/// assertion is exchanged and the value placed in the JWT `aud` claim.
pub const PRODUCTION_LOGIN_URL: &str = "https://login.salesforce.com";

/// Sandbox Salesforce login host; pass to `login_url` for sandbox orgs.
pub const SANDBOX_LOGIN_URL: &str = "https://test.salesforce.com";
40
/// How long a freshly minted access token is served from the cache when the
/// caller does not override it (30 minutes).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);

/// Lifetime of each signed JWT assertion, in seconds: `exp` is set this far
/// past the current UNIX time.
///
/// Salesforce's help docs describe a 3-minute clock-skew allowance on the
/// `exp` claim. Three minutes sits within the documented authorization-server
/// bounds and limits how long a leaked assertion stays usable.
const JWT_VALIDITY_SECS: i64 = 3 * 60;
49
50#[derive(Debug, Serialize)]
51struct JwtClaims {
52    iss: String,
53    sub: String,
54    aud: String,
55    exp: i64,
56}
57
/// A minted access token plus the local deadline after which it is no
/// longer served from the cache.
///
/// `Debug` is implemented by hand so the live access token never ends up in
/// logs or panic messages. This mirrors the redaction done by the `JwtAuth`
/// and `JwtAuthBuilder` `Debug` impls.
#[derive(Clone)]
struct CachedToken {
    /// Access token string returned by the token exchange.
    access_token: String,
    /// Monotonic deadline; once `Instant::now()` reaches it, the entry is stale.
    expires_at: Instant,
}

impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("access_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
63
64/// JWT Bearer flow auth session.
65///
66/// Construct via [`JwtAuth::builder`].
67pub struct JwtAuth {
68    consumer_key: String,
69    username: String,
70    encoding_key: EncodingKey,
71    login_url: String,
72    instance_url: String,
73    token_ttl: Duration,
74    http: reqwest::Client,
75    cached: RwLock<Option<CachedToken>>,
76}
77
78impl std::fmt::Debug for JwtAuth {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        // Deliberately omit consumer_key, username, and the encoding key —
81        // all carry secrets or PII.
82        f.debug_struct("JwtAuth")
83            .field("login_url", &self.login_url)
84            .field("instance_url", &self.instance_url)
85            .field("token_ttl", &self.token_ttl)
86            .finish_non_exhaustive()
87    }
88}
89
90impl JwtAuth {
91    /// Begins constructing a [`JwtAuth`].
92    ///
93    /// JWT bearer flow (RFC 7523): the SDK signs a JWT assertion with
94    /// your connected app's private key and exchanges it for an access
95    /// token at the configured login URL. Cached access tokens are
96    /// refreshed transparently on 401.
97    ///
98    /// # Example
99    ///
100    /// ```no_run
101    /// use cirrus_auth::JwtAuth;
102    /// use std::sync::Arc;
103    ///
104    /// # fn example() -> Result<(), cirrus_auth::AuthError> {
105    /// let auth = JwtAuth::builder()
106    ///     .consumer_key("3MVG9...")
107    ///     .username("integration-user@example.com")
108    ///     .login_url("https://login.salesforce.com")
109    ///     .instance_url("https://my-org.my.salesforce.com")
110    ///     .private_key_pem_file("./private.pem")?
111    ///     .build()?;
112    /// // Wrap as Arc<dyn AuthSession> and hand to a Cirrus client.
113    /// let _shared = Arc::new(auth);
114    /// # Ok(())
115    /// # }
116    /// ```
117    pub fn builder() -> JwtAuthBuilder {
118        JwtAuthBuilder::default()
119    }
120
121    async fn mint_token(&self) -> AuthResult<CachedToken> {
122        tracing::info!(
123            target: "cirrus::auth",
124            flow = "jwt-bearer",
125            login_url = %self.login_url,
126            "minting fresh access token",
127        );
128        let now_secs = SystemTime::now()
129            .duration_since(UNIX_EPOCH)
130            .map(|d| d.as_secs() as i64)
131            .map_err(|e| AuthError::Other(format!("system clock before UNIX epoch: {e}")))?;
132
133        let claims = JwtClaims {
134            iss: self.consumer_key.clone(),
135            sub: self.username.clone(),
136            aud: self.login_url.clone(),
137            exp: now_secs + JWT_VALIDITY_SECS,
138        };
139
140        let header = Header::new(Algorithm::RS256);
141        let assertion = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
142            .map_err(|e| AuthError::Other(format!("JWT signing failed: {e}")))?;
143
144        let body = [
145            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
146            ("assertion", assertion.as_str()),
147        ];
148
149        let token = exchange(&self.http, &self.login_url, &body).await?;
150        check_instance_url(&self.instance_url, &token)?;
151
152        Ok(CachedToken {
153            access_token: token.access_token,
154            expires_at: Instant::now() + self.token_ttl,
155        })
156    }
157}
158
159#[async_trait]
160impl AuthSession for JwtAuth {
161    async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
162        // Fast path — read lock, return clone of cached token if still valid.
163        {
164            let guard = self.cached.read().await;
165            if let Some(cached) = guard.as_ref()
166                && cached.expires_at > Instant::now()
167            {
168                return Ok(Cow::Owned(cached.access_token.clone()));
169            }
170        }
171
172        // Slow path — write lock, double-check, mint.
173        let mut guard = self.cached.write().await;
174        if let Some(cached) = guard.as_ref()
175            && cached.expires_at > Instant::now()
176        {
177            return Ok(Cow::Owned(cached.access_token.clone()));
178        }
179        let new_token = self.mint_token().await?;
180        let token_str = new_token.access_token.clone();
181        *guard = Some(new_token);
182        Ok(Cow::Owned(token_str))
183    }
184
185    fn instance_url(&self) -> &str {
186        &self.instance_url
187    }
188
189    async fn invalidate(&self, stale_token: &str) {
190        // Compare-and-swap: only clear the cached token if it still
191        // matches what the failing request used. Avoids racing with a
192        // concurrent task that already refreshed.
193        let mut guard = self.cached.write().await;
194        if let Some(cached) = guard.as_ref()
195            && cached.access_token == stale_token
196        {
197            tracing::debug!(
198                target: "cirrus::auth",
199                flow = "jwt-bearer",
200                "invalidating cached token (CAS matched)",
201            );
202            *guard = None;
203        } else {
204            tracing::trace!(
205                target: "cirrus::auth",
206                flow = "jwt-bearer",
207                "invalidate called but cached token differs (concurrent refresh?); no-op",
208            );
209        }
210    }
211}
212
213/// Builder for [`JwtAuth`].
214#[derive(Default)]
215pub struct JwtAuthBuilder {
216    consumer_key: Option<String>,
217    username: Option<String>,
218    encoding_key: Option<EncodingKey>,
219    login_url: Option<String>,
220    instance_url: Option<String>,
221    token_ttl: Option<Duration>,
222    http_client: Option<reqwest::Client>,
223}
224
225impl std::fmt::Debug for JwtAuthBuilder {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        // Show which fields have been set without leaking secret-bearing values.
228        f.debug_struct("JwtAuthBuilder")
229            .field("consumer_key", &self.consumer_key.is_some())
230            .field("username", &self.username.is_some())
231            .field("private_key", &self.encoding_key.is_some())
232            .field("login_url", &self.login_url)
233            .field("instance_url", &self.instance_url)
234            .field("token_ttl", &self.token_ttl)
235            .finish_non_exhaustive()
236    }
237}
238
239impl JwtAuthBuilder {
240    /// Connected App's Consumer Key (a.k.a. Client ID) — used as the JWT
241    /// `iss` claim.
242    pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
243        self.consumer_key = Some(key.into());
244        self
245    }
246
247    /// Salesforce username to authenticate as — used as the JWT `sub` claim.
248    pub fn username(mut self, username: impl Into<String>) -> Self {
249        self.username = Some(username.into());
250        self
251    }
252
253    /// Loads the RSA private key from a PEM file at the given path.
254    pub fn private_key_pem_file(mut self, path: impl Into<Utf8PathBuf>) -> AuthResult<Self> {
255        let path = path.into();
256        let bytes = fs_err::read(path.as_std_path())
257            .map_err(|e| AuthError::Other(format!("failed to read private key: {e}")))?;
258        self.encoding_key = Some(
259            EncodingKey::from_rsa_pem(&bytes)
260                .map_err(|e| AuthError::Other(format!("invalid RSA PEM key: {e}")))?,
261        );
262        Ok(self)
263    }
264
265    /// Loads the RSA private key directly from PEM-encoded bytes. Useful
266    /// when the key is held in memory (e.g. fetched from a secret manager).
267    pub fn private_key_pem_bytes(mut self, bytes: &[u8]) -> AuthResult<Self> {
268        self.encoding_key = Some(
269            EncodingKey::from_rsa_pem(bytes)
270                .map_err(|e| AuthError::Other(format!("invalid RSA PEM key: {e}")))?,
271        );
272        Ok(self)
273    }
274
275    /// Login URL — the host that receives the JWT, also used as the JWT
276    /// `aud` claim. Defaults to [`PRODUCTION_LOGIN_URL`]. Use
277    /// [`SANDBOX_LOGIN_URL`] for sandboxes.
278    ///
279    /// Per Salesforce docs ("OAuth 2.0 JWT Bearer Flow ... aud"), valid
280    /// audience values are `https://login.salesforce.com`,
281    /// `https://test.salesforce.com`, or an Experience Cloud site URL —
282    /// **not** the org's My Domain. The `instance_url` is what points at
283    /// the org; this URL identifies the authorization server.
284    pub fn login_url(mut self, url: impl Into<String>) -> Self {
285        self.login_url = Some(url.into());
286        self
287    }
288
289    /// REST instance URL — the org's My Domain (e.g.
290    /// `https://my-org.my.salesforce.com`). Required. Must match the
291    /// `instance_url` that Salesforce returns from the token exchange.
292    pub fn instance_url(mut self, url: impl Into<String>) -> Self {
293        self.instance_url = Some(url.into());
294        self
295    }
296
297    /// How long to cache an access token before re-minting. Defaults to 30
298    /// minutes. Set lower to refresh more aggressively, or higher if your
299    /// connected app's session policy allows.
300    pub fn token_ttl(mut self, ttl: Duration) -> Self {
301        self.token_ttl = Some(ttl);
302        self
303    }
304
305    /// Supplies a pre-configured `reqwest::Client` for the token-exchange
306    /// requests. Useful for sharing a connection pool across multiple SDK
307    /// clients.
308    pub fn http_client(mut self, client: reqwest::Client) -> Self {
309        self.http_client = Some(client);
310        self
311    }
312
313    /// Finalizes the builder.
314    pub fn build(self) -> AuthResult<JwtAuth> {
315        let consumer_key = self
316            .consumer_key
317            .ok_or(AuthError::MissingField("consumer_key"))?;
318        let username = self.username.ok_or(AuthError::MissingField("username"))?;
319        let encoding_key = self
320            .encoding_key
321            .ok_or(AuthError::MissingField("private_key"))?;
322        let mut instance_url = self
323            .instance_url
324            .ok_or(AuthError::MissingField("instance_url"))?;
325        if instance_url.ends_with('/') {
326            instance_url.pop();
327        }
328        let mut login_url = self
329            .login_url
330            .unwrap_or_else(|| PRODUCTION_LOGIN_URL.to_string());
331        if login_url.ends_with('/') {
332            login_url.pop();
333        }
334        let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
335        let http = self.http_client.unwrap_or_default();
336
337        Ok(JwtAuth {
338            consumer_key,
339            username,
340            encoding_key,
341            login_url,
342            instance_url,
343            token_ttl,
344            http,
345            cached: RwLock::new(None),
346        })
347    }
348}
349
// Unit tests for the JWT bearer flow: builder validation, token minting
// against a wiremock token endpoint, caching, and invalidation.
#[cfg(test)]
// Tests fail by panicking, so unwrap/expect/panic are the intended idiom here.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    /// Throwaway test-only RSA private key. No security value.
    /// See `tests/fixtures/test_rsa_key.pem`.
    const TEST_PEM: &[u8] = include_bytes!("../tests/fixtures/test_rsa_key.pem");

    /// Returns a builder with every required field set. `login_url` is left
    /// unset, so tests either rely on the default or point it at a mock server.
    fn builder_with_required_fields() -> JwtAuthBuilder {
        JwtAuth::builder()
            .consumer_key("consumer-key-123")
            .username("integration@example.com")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://my-org.my.salesforce.com")
    }

    /// `build()` reports a missing consumer key as `MissingField("consumer_key")`.
    #[test]
    fn builder_requires_consumer_key() {
        let err = JwtAuth::builder()
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    /// `build()` reports a missing username as `MissingField("username")`.
    #[test]
    fn builder_requires_username() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("username")));
    }

    /// `build()` reports a missing signing key as `MissingField("private_key")`.
    #[test]
    fn builder_requires_private_key() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("private_key")));
    }

    /// `build()` reports a missing instance URL as `MissingField("instance_url")`.
    #[test]
    fn builder_requires_instance_url() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    /// Non-PEM bytes fail at the setter, before `build()`, as `AuthError::Other`.
    #[test]
    fn invalid_pem_is_surfaced_as_auth_error() {
        let err = JwtAuth::builder()
            .private_key_pem_bytes(b"not a pem")
            .unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// A trailing slash on `instance_url` is removed, and an unset
    /// `login_url` falls back to the production login URL.
    #[test]
    fn builder_strips_trailing_slashes_and_defaults_login_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, PRODUCTION_LOGIN_URL);
    }

    /// A successful exchange returns the token, and a second call within the
    /// TTL is served from cache (the endpoint is hit exactly once).
    #[tokio::test]
    async fn mint_token_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "00DXX!ACCESS",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
            "scope": "api",
            "id": "https://login.salesforce.com/id/00DXX/005XX",
        });

        // Only matches a form body carrying the JWT-bearer grant and an assertion.
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=urn"))
            .and(body_string_contains("assertion="))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(body),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");

        // Second call must reuse the cached token, not call the endpoint again.
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    /// With a zero TTL, every call finds the cache already expired and re-mints.
    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO) // every call re-mints
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();

        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    /// An OAuth error body on a 400 surfaces as `AuthError::OAuth` with its
    /// `error` code and description.
    #[tokio::test]
    async fn oauth_error_response_is_surfaced() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "user hasn't approved this consumer"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    /// A token response naming a different org than the configured
    /// `instance_url` is rejected as `AuthError::Other`.
    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://different-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// `invalidate(stale_token)` is a compare-and-swap: it should
    /// only clear the cached token when the cached value matches
    /// `stale_token`. This is the contract for all three flows
    /// (Jwt, Refresh, ClientCredentials); we test it here as the
    /// canonical example since the impls are identical.
    #[tokio::test]
    async fn invalidate_clears_cache_only_when_stale_token_matches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "T1",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
        });

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(body),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        // First call mints T1; cache populated.
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // Invalidate with a *non-matching* stale_token — should be a
        // no-op, cache stays populated.
        auth.invalidate("not-the-cached-token").await;
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        // No re-mint — the cache wasn't cleared.
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // Invalidate with the *matching* stale_token — clears cache.
        auth.invalidate("T1").await;
        // Next access call must re-mint.
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1"); // mock still returns T1
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    /// Wraps a [`ResponseTemplate`] and counts invocations. Wiremock's
    /// `expect()` would also work, but counting lets us assert post-hoc.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        // Bumps the shared counter, then replies with a copy of the template.
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }
}