Skip to main content

cirrus_auth/
jwt.rs

1//! OAuth 2.0 JWT Bearer flow for Salesforce server-to-server auth.
2//!
3//! The caller pre-authorizes a Connected App by uploading a public X.509
4//! certificate; this auth implementation holds the corresponding RSA private
5//! key and mints fresh access tokens on demand by signing a short-lived JWT
6//! and exchanging it at the OAuth token endpoint.
7//!
8//! ## `instance_url`
9//!
10//! `instance_url` is required at builder time and verified against the
11//! value returned in the token response.
12//!
13//! ## Caching
14//!
15//! Each successful token exchange caches the access token for a configurable
16//! TTL (default 30 minutes). Salesforce does not return an explicit expiry
17//! in the token response — the connected app's session policy controls
18//! actual expiration — so the TTL is a conservative caller-controlled knob,
19//! not a claim about the token's true lifetime. After the TTL elapses, the
20//! next call mints a new token regardless of whether the previous one would
21//! still have worked.
22
23use crate::AuthSession;
24use crate::error::{AuthError, AuthResult};
25use crate::token_endpoint::{check_instance_url, exchange};
26use async_trait::async_trait;
27use camino::Utf8PathBuf;
28use jsonwebtoken::{Algorithm, EncodingKey, Header};
29use serde::Serialize;
30use std::borrow::Cow;
31use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
32use tokio::sync::RwLock;
33
/// Salesforce production login URL — the default JWT audience and token
/// exchange host.
///
/// [`JwtAuthBuilder::build`] falls back to this value when no login URL
/// was supplied.
pub const PRODUCTION_LOGIN_URL: &str = "https://login.salesforce.com";

/// Salesforce sandbox login URL.
///
/// Pass it to [`JwtAuthBuilder::login_url`] to authenticate against a
/// sandbox instead of production.
pub const SANDBOX_LOGIN_URL: &str = "https://test.salesforce.com";

/// Default cache TTL (30 minutes) for an access token after it's issued.
///
/// Used when [`JwtAuthBuilder::token_ttl`] is never called.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(30 * 60);

/// JWT validity window, in seconds, added to the current Unix time to
/// produce the assertion's `exp` claim.
///
/// Salesforce rejects assertions whose `exp` is more than 3 minutes
/// ahead of *its* clock. Setting `exp = now + 180` leaves zero slack:
/// any clock skew where the local host runs even slightly ahead of
/// Salesforce will push `exp` over the ceiling and produce
/// `invalid_grant`. The 10-second buffer below keeps the assertion
/// short-lived while tolerating the kind of skew typical of NTP-synced
/// machines.
const JWT_VALIDITY_SECS: i64 = 170;
52
/// Claim set serialized as the payload of the signed JWT assertion.
#[derive(Serialize)]
struct JwtClaims {
    /// Issuer — filled with the Connected App consumer key.
    iss: String,
    /// Subject — filled with the Salesforce username.
    sub: String,
    /// Audience — filled with the configured login URL.
    aud: String,
    /// Expiry, in seconds since the Unix epoch.
    exp: i64,
}
60
61// `iss` is the Connected App consumer key (a credential identifier) and
62// `sub` is the Salesforce username (PII). Redact both so a stray
63// `{:?}` in error-handling code never leaks them.
64impl std::fmt::Debug for JwtClaims {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("JwtClaims")
67            .field("iss", &"[redacted]")
68            .field("sub", &"[redacted]")
69            .field("aud", &self.aud)
70            .field("exp", &self.exp)
71            .finish()
72    }
73}
74
/// An access token paired with the instant after which the cache must
/// stop serving it.
#[derive(Clone)]
struct CachedToken {
    /// Bearer token handed back to callers.
    access_token: String,
    /// Point on the monotonic clock at which this entry stops being served.
    expires_at: Instant,
}

// Hand-written `Debug` so the bearer token never ends up in logs; only
// the expiry instant is printed verbatim.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MASK: &str = "[redacted]";

        let mut out = f.debug_struct("CachedToken");
        out.field("access_token", &MASK);
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}
89
/// JWT Bearer flow auth session.
///
/// Construct via [`JwtAuth::builder`].
pub struct JwtAuth {
    /// Connected App consumer key; sent as the JWT `iss` claim.
    consumer_key: String,
    /// Salesforce username; sent as the JWT `sub` claim.
    username: String,
    /// RSA private key used to sign each assertion.
    encoding_key: EncodingKey,
    /// Login URL; used as the JWT `aud` claim and passed to the token
    /// exchange.
    login_url: String,
    /// Expected instance URL; checked against every token response.
    instance_url: String,
    /// How long a freshly minted token is served from the cache.
    token_ttl: Duration,
    /// HTTP client used for the token exchange.
    http: reqwest::Client,
    /// Most recently minted token, or `None` before the first mint and
    /// after an invalidation.
    cached: RwLock<Option<CachedToken>>,
}
103
104impl std::fmt::Debug for JwtAuth {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        // Deliberately omit consumer_key, username, and the encoding key —
107        // all carry secrets or PII.
108        f.debug_struct("JwtAuth")
109            .field("login_url", &self.login_url)
110            .field("instance_url", &self.instance_url)
111            .field("token_ttl", &self.token_ttl)
112            .finish_non_exhaustive()
113    }
114}
115
116impl JwtAuth {
117    /// Begins constructing a [`JwtAuth`].
118    ///
119    /// JWT bearer flow (RFC 7523): the SDK signs a JWT assertion with
120    /// your connected app's private key and exchanges it for an access
121    /// token at the configured login URL. Cached access tokens are
122    /// refreshed transparently on 401.
123    ///
124    /// # Example
125    ///
126    /// ```no_run
127    /// use cirrus_auth::JwtAuth;
128    /// use std::sync::Arc;
129    ///
130    /// # fn example() -> Result<(), cirrus_auth::AuthError> {
131    /// let auth = JwtAuth::builder()
132    ///     .consumer_key("3MVG9...")
133    ///     .username("integration-user@example.com")
134    ///     .login_url("https://login.salesforce.com")
135    ///     .instance_url("https://my-org.my.salesforce.com")
136    ///     .private_key_pem_file("./private.pem")?
137    ///     .build()?;
138    /// // Wrap as Arc<dyn AuthSession> and hand to a Cirrus client.
139    /// let _shared = Arc::new(auth);
140    /// # Ok(())
141    /// # }
142    /// ```
143    pub fn builder() -> JwtAuthBuilder {
144        JwtAuthBuilder::default()
145    }
146
147    async fn mint_token(&self) -> AuthResult<CachedToken> {
148        tracing::info!(
149            target: "cirrus::auth",
150            flow = "jwt-bearer",
151            login_url = %self.login_url,
152            "minting fresh access token",
153        );
154        let now_secs = SystemTime::now()
155            .duration_since(UNIX_EPOCH)
156            .map(|d| d.as_secs() as i64)
157            .map_err(|e| AuthError::Other(format!("system clock before UNIX epoch: {e}")))?;
158
159        let claims = JwtClaims {
160            iss: self.consumer_key.clone(),
161            sub: self.username.clone(),
162            aud: self.login_url.clone(),
163            exp: now_secs + JWT_VALIDITY_SECS,
164        };
165
166        let header = Header::new(Algorithm::RS256);
167        let assertion = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
168            .map_err(|e| AuthError::Other(format!("JWT signing failed: {e}")))?;
169
170        let body = [
171            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
172            ("assertion", assertion.as_str()),
173        ];
174
175        let token = exchange(&self.http, &self.login_url, &body).await?;
176        check_instance_url(&self.instance_url, &token)?;
177
178        Ok(CachedToken {
179            access_token: token.access_token,
180            expires_at: Instant::now() + self.token_ttl,
181        })
182    }
183}
184
185#[async_trait]
186impl AuthSession for JwtAuth {
187    async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
188        // Fast path — read lock, return clone of cached token if still valid.
189        {
190            let guard = self.cached.read().await;
191            if let Some(cached) = guard.as_ref()
192                && cached.expires_at > Instant::now()
193            {
194                return Ok(Cow::Owned(cached.access_token.clone()));
195            }
196        }
197
198        // Slow path — write lock, double-check, mint.
199        let mut guard = self.cached.write().await;
200        if let Some(cached) = guard.as_ref()
201            && cached.expires_at > Instant::now()
202        {
203            return Ok(Cow::Owned(cached.access_token.clone()));
204        }
205        let new_token = self.mint_token().await?;
206        let token_str = new_token.access_token.clone();
207        *guard = Some(new_token);
208        Ok(Cow::Owned(token_str))
209    }
210
211    fn instance_url(&self) -> &str {
212        &self.instance_url
213    }
214
215    async fn invalidate(&self, stale_token: &str) {
216        // Compare-and-swap: only clear the cached token if it still
217        // matches what the failing request used. Avoids racing with a
218        // concurrent task that already refreshed.
219        let mut guard = self.cached.write().await;
220        if let Some(cached) = guard.as_ref()
221            && cached.access_token == stale_token
222        {
223            tracing::debug!(
224                target: "cirrus::auth",
225                flow = "jwt-bearer",
226                "invalidating cached token (CAS matched)",
227            );
228            *guard = None;
229        } else {
230            tracing::trace!(
231                target: "cirrus::auth",
232                flow = "jwt-bearer",
233                "invalidate called but cached token differs (concurrent refresh?); no-op",
234            );
235        }
236    }
237}
238
/// Builder for [`JwtAuth`].
///
/// Every field starts out unset. [`JwtAuthBuilder::build`] reports the
/// required ones that are still missing and fills in defaults for the
/// optional ones.
#[derive(Default)]
pub struct JwtAuthBuilder {
    /// Required. Connected App consumer key (JWT `iss`).
    consumer_key: Option<String>,
    /// Required. Salesforce username (JWT `sub`).
    username: Option<String>,
    /// Required. Parsed RSA private key used for signing.
    encoding_key: Option<EncodingKey>,
    /// Optional. Defaults to [`PRODUCTION_LOGIN_URL`].
    login_url: Option<String>,
    /// Required. The org's instance URL.
    instance_url: Option<String>,
    /// Optional. Defaults to 30 minutes.
    token_ttl: Option<Duration>,
    /// Optional. Defaults to `reqwest::Client::default()`.
    http_client: Option<reqwest::Client>,
}
250
251impl std::fmt::Debug for JwtAuthBuilder {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        // Show which fields have been set without leaking secret-bearing values.
254        f.debug_struct("JwtAuthBuilder")
255            .field("consumer_key", &self.consumer_key.is_some())
256            .field("username", &self.username.is_some())
257            .field("private_key", &self.encoding_key.is_some())
258            .field("login_url", &self.login_url)
259            .field("instance_url", &self.instance_url)
260            .field("token_ttl", &self.token_ttl)
261            .finish_non_exhaustive()
262    }
263}
264
265impl JwtAuthBuilder {
266    /// Connected App's Consumer Key (a.k.a. Client ID) — used as the JWT
267    /// `iss` claim.
268    pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
269        self.consumer_key = Some(key.into());
270        self
271    }
272
273    /// Salesforce username to authenticate as — used as the JWT `sub` claim.
274    pub fn username(mut self, username: impl Into<String>) -> Self {
275        self.username = Some(username.into());
276        self
277    }
278
279    /// Loads the RSA private key from a PEM file at the given path.
280    pub fn private_key_pem_file(mut self, path: impl Into<Utf8PathBuf>) -> AuthResult<Self> {
281        let path = path.into();
282        let bytes = fs_err::read(path.as_std_path())
283            .map_err(|e| AuthError::Other(format!("failed to read private key: {e}")))?;
284        self.encoding_key = Some(
285            EncodingKey::from_rsa_pem(&bytes)
286                .map_err(|e| AuthError::Other(format!("invalid RSA PEM key: {e}")))?,
287        );
288        Ok(self)
289    }
290
291    /// Loads the RSA private key directly from PEM-encoded bytes. Useful
292    /// when the key is held in memory (e.g. fetched from a secret manager).
293    pub fn private_key_pem_bytes(mut self, bytes: &[u8]) -> AuthResult<Self> {
294        self.encoding_key = Some(
295            EncodingKey::from_rsa_pem(bytes)
296                .map_err(|e| AuthError::Other(format!("invalid RSA PEM key: {e}")))?,
297        );
298        Ok(self)
299    }
300
301    /// Login URL — the host that receives the JWT, also used as the JWT
302    /// `aud` claim. Defaults to [`PRODUCTION_LOGIN_URL`]. Use
303    /// [`SANDBOX_LOGIN_URL`] for sandboxes.
304    ///
305    /// Per Salesforce docs ("OAuth 2.0 JWT Bearer Flow ... aud"), valid
306    /// audience values are `https://login.salesforce.com`,
307    /// `https://test.salesforce.com`, or an Experience Cloud site URL —
308    /// **not** the org's My Domain. The `instance_url` is what points at
309    /// the org; this URL identifies the authorization server.
310    pub fn login_url(mut self, url: impl Into<String>) -> Self {
311        self.login_url = Some(url.into());
312        self
313    }
314
315    /// REST instance URL — the org's My Domain (e.g.
316    /// `https://my-org.my.salesforce.com`). Required. Must match the
317    /// `instance_url` that Salesforce returns from the token exchange.
318    pub fn instance_url(mut self, url: impl Into<String>) -> Self {
319        self.instance_url = Some(url.into());
320        self
321    }
322
323    /// How long to cache an access token before re-minting. Defaults to 30
324    /// minutes. Set lower to refresh more aggressively, or higher if your
325    /// connected app's session policy allows.
326    pub fn token_ttl(mut self, ttl: Duration) -> Self {
327        self.token_ttl = Some(ttl);
328        self
329    }
330
331    /// Supplies a pre-configured `reqwest::Client` for the token-exchange
332    /// requests. Useful for sharing a connection pool across multiple SDK
333    /// clients.
334    pub fn http_client(mut self, client: reqwest::Client) -> Self {
335        self.http_client = Some(client);
336        self
337    }
338
339    /// Finalizes the builder.
340    pub fn build(self) -> AuthResult<JwtAuth> {
341        let consumer_key = self
342            .consumer_key
343            .ok_or(AuthError::MissingField("consumer_key"))?;
344        let username = self.username.ok_or(AuthError::MissingField("username"))?;
345        let encoding_key = self
346            .encoding_key
347            .ok_or(AuthError::MissingField("private_key"))?;
348        let mut instance_url = self
349            .instance_url
350            .ok_or(AuthError::MissingField("instance_url"))?;
351        if instance_url.ends_with('/') {
352            instance_url.pop();
353        }
354        let mut login_url = self
355            .login_url
356            .unwrap_or_else(|| PRODUCTION_LOGIN_URL.to_string());
357        if login_url.ends_with('/') {
358            login_url.pop();
359        }
360        let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
361        let http = self.http_client.unwrap_or_default();
362
363        Ok(JwtAuth {
364            consumer_key,
365            username,
366            encoding_key,
367            login_url,
368            instance_url,
369            token_ttl,
370            http,
371            cached: RwLock::new(None),
372        })
373    }
374}
375
#[cfg(test)]
// Tests unwrap/expect/panic on purpose: a failure should abort the test
// immediately with the offending value.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockBuilder, MockServer, Request, Respond, ResponseTemplate};

    /// Throwaway test-only RSA private key. No security value.
    /// See `tests/fixtures/test_rsa_key.pem`.
    const TEST_PEM: &[u8] = include_bytes!("../tests/fixtures/test_rsa_key.pem");

    /// Consumer key used by [`builder_with_required_fields`].
    const TEST_CONSUMER_KEY: &str = "consumer-key-123";

    /// Username used by [`builder_with_required_fields`].
    const TEST_USERNAME: &str = "integration@example.com";

    /// Instance URL used by [`builder_with_required_fields`]; mock token
    /// responses must echo it to pass the instance-URL check.
    const TEST_INSTANCE_URL: &str = "https://my-org.my.salesforce.com";

    /// Path the token exchange is expected to POST to.
    const TOKEN_PATH: &str = "/services/oauth2/token";

    /// A builder with every required field populated.
    fn builder_with_required_fields() -> JwtAuthBuilder {
        JwtAuth::builder()
            .consumer_key(TEST_CONSUMER_KEY)
            .username(TEST_USERNAME)
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url(TEST_INSTANCE_URL)
    }

    /// A fully built session whose login URL points at the mock server.
    fn auth_against(server: &MockServer) -> JwtAuth {
        builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap()
    }

    /// Matcher for `POST /services/oauth2/token`; callers attach extra
    /// matchers and the responder.
    fn token_endpoint() -> MockBuilder {
        Mock::given(method("POST")).and(path(TOKEN_PATH))
    }

    /// Wraps a [`ResponseTemplate`] and counts invocations. Wiremock's
    /// `expect()` would also work, but counting lets us assert post-hoc.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl CountingResponder {
        /// A responder that answers 200 with `body` and bumps `hits` on
        /// every request.
        fn ok_json(hits: &Arc<AtomicUsize>, body: serde_json::Value) -> Self {
            Self {
                hits: Arc::clone(hits),
                response: ResponseTemplate::new(200).set_body_json(body),
            }
        }
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = JwtAuth::builder()
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_username() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("username")));
    }

    #[test]
    fn builder_requires_private_key() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("private_key")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    #[test]
    fn invalid_pem_is_surfaced_as_auth_error() {
        let err = JwtAuth::builder()
            .private_key_pem_bytes(b"not a pem")
            .unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    #[test]
    fn builder_strips_trailing_slashes_and_defaults_login_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), TEST_INSTANCE_URL);
        assert_eq!(auth.login_url, PRODUCTION_LOGIN_URL);
    }

    #[test]
    fn builder_strips_trailing_slash_from_login_url() {
        let auth = builder_with_required_fields()
            .login_url("https://test.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.login_url, SANDBOX_LOGIN_URL);
    }

    #[test]
    fn debug_output_omits_credentials() {
        let builder_repr = format!("{:?}", builder_with_required_fields());
        assert!(!builder_repr.contains(TEST_CONSUMER_KEY));
        assert!(!builder_repr.contains(TEST_USERNAME));

        let auth = builder_with_required_fields().build().unwrap();
        let auth_repr = format!("{auth:?}");
        assert!(!auth_repr.contains(TEST_CONSUMER_KEY));
        assert!(!auth_repr.contains(TEST_USERNAME));
        assert!(auth_repr.contains(TEST_INSTANCE_URL));
    }

    #[tokio::test]
    async fn mint_token_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "00DXX!ACCESS",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
            "scope": "api",
            "id": "https://login.salesforce.com/id/00DXX/005XX",
        });

        token_endpoint()
            .and(body_string_contains("grant_type=urn"))
            .and(body_string_contains("assertion="))
            .respond_with(CountingResponder::ok_json(&hits, body))
            .mount(&server)
            .await;

        let auth = auth_against(&server);

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");

        // Second call must reuse the cached token, not call the endpoint again.
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "tok",
            "instance_url": "https://my-org.my.salesforce.com"
        });

        token_endpoint()
            .respond_with(CountingResponder::ok_json(&hits, body))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO) // every call re-mints
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();

        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn oauth_error_response_is_surfaced() {
        let server = MockServer::start().await;
        token_endpoint()
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "user hasn't approved this consumer"
            })))
            .mount(&server)
            .await;

        let auth = auth_against(&server);

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        token_endpoint()
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://different-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = auth_against(&server);

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    /// `invalidate(stale_token)` is a compare-and-swap: it should
    /// only clear the cached token when the cached value matches
    /// `stale_token`. This is the contract for all three flows
    /// (Jwt, Refresh, ClientCredentials); we test it here as the
    /// canonical example since the impls are identical.
    #[tokio::test]
    async fn invalidate_clears_cache_only_when_stale_token_matches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "T1",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
        });

        token_endpoint()
            .respond_with(CountingResponder::ok_json(&hits, body))
            .mount(&server)
            .await;

        let auth = auth_against(&server);

        // First call mints T1; cache populated.
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // Invalidate with a *non-matching* stale_token — should be a
        // no-op, cache stays populated.
        auth.invalidate("not-the-cached-token").await;
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        // No re-mint — the cache wasn't cleared.
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // Invalidate with the *matching* stale_token — clears cache.
        auth.invalidate("T1").await;
        // Next access call must re-mint.
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1"); // mock still returns T1
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }
}
646}