1use crate::AuthSession;
24use crate::error::{AuthError, AuthResult};
25use crate::token_endpoint::{check_instance_url, exchange};
26use async_trait::async_trait;
27use camino::Utf8PathBuf;
28use jsonwebtoken::{Algorithm, EncodingKey, Header};
29use serde::Serialize;
30use std::borrow::Cow;
31use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
32use tokio::sync::RwLock;
33
/// Salesforce login host for production orgs; used as the builder's default
/// `login_url`.
pub const PRODUCTION_LOGIN_URL: &str = "https://login.salesforce.com";

/// Salesforce login host for sandbox orgs; pass it to
/// `JwtAuthBuilder::login_url` when authenticating against a sandbox.
pub const SANDBOX_LOGIN_URL: &str = "https://test.salesforce.com";

/// How long a minted access token is reused before a new one is minted, when
/// the builder is not given an explicit `token_ttl` (30 minutes).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);

/// Seconds added to "now" to form each signed assertion's `exp` claim.
/// Kept just under three minutes (180 s minus a 10 s margin), the maximum
/// lifetime Salesforce accepts for a JWT-bearer assertion.
const JWT_VALIDITY_SECS: i64 = 3 * 60 - 10;
52
53#[derive(Serialize)]
54struct JwtClaims {
55 iss: String,
56 sub: String,
57 aud: String,
58 exp: i64,
59}
60
61impl std::fmt::Debug for JwtClaims {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("JwtClaims")
67 .field("iss", &"[redacted]")
68 .field("sub", &"[redacted]")
69 .field("aud", &self.aud)
70 .field("exp", &self.exp)
71 .finish()
72 }
73}
74
/// An access token held in the session cache together with the moment after
/// which it is no longer reused.
#[derive(Clone)]
struct CachedToken {
    /// Bearer token returned by the token endpoint (a secret).
    access_token: String,
    /// Monotonic instant after which the token is treated as expired.
    expires_at: Instant,
}

impl std::fmt::Debug for CachedToken {
    /// Formats the entry without revealing the bearer token; only the expiry
    /// instant is useful (and safe) for diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CachedToken");
        out.field("access_token", &"[redacted]");
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}
89
90pub struct JwtAuth {
94 consumer_key: String,
95 username: String,
96 encoding_key: EncodingKey,
97 login_url: String,
98 instance_url: String,
99 token_ttl: Duration,
100 http: reqwest::Client,
101 cached: RwLock<Option<CachedToken>>,
102}
103
104impl std::fmt::Debug for JwtAuth {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("JwtAuth")
109 .field("login_url", &self.login_url)
110 .field("instance_url", &self.instance_url)
111 .field("token_ttl", &self.token_ttl)
112 .finish_non_exhaustive()
113 }
114}
115
116impl JwtAuth {
117 pub fn builder() -> JwtAuthBuilder {
144 JwtAuthBuilder::default()
145 }
146
147 async fn mint_token(&self) -> AuthResult<CachedToken> {
148 tracing::info!(
149 target: "cirrus::auth",
150 flow = "jwt-bearer",
151 login_url = %self.login_url,
152 "minting fresh access token",
153 );
154 let now_secs = SystemTime::now()
155 .duration_since(UNIX_EPOCH)
156 .map(|d| d.as_secs() as i64)
157 .map_err(|e| AuthError::Other(format!("system clock before UNIX epoch: {e}")))?;
158
159 let claims = JwtClaims {
160 iss: self.consumer_key.clone(),
161 sub: self.username.clone(),
162 aud: self.login_url.clone(),
163 exp: now_secs + JWT_VALIDITY_SECS,
164 };
165
166 let header = Header::new(Algorithm::RS256);
167 let assertion = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
168 .map_err(|e| AuthError::Other(format!("JWT signing failed: {e}")))?;
169
170 let body = [
171 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
172 ("assertion", assertion.as_str()),
173 ];
174
175 let token = exchange(&self.http, &self.login_url, &body).await?;
176 check_instance_url(&self.instance_url, &token)?;
177
178 Ok(CachedToken {
179 access_token: token.access_token,
180 expires_at: Instant::now() + self.token_ttl,
181 })
182 }
183}
184
185#[async_trait]
186impl AuthSession for JwtAuth {
187 async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
188 {
190 let guard = self.cached.read().await;
191 if let Some(cached) = guard.as_ref()
192 && cached.expires_at > Instant::now()
193 {
194 return Ok(Cow::Owned(cached.access_token.clone()));
195 }
196 }
197
198 let mut guard = self.cached.write().await;
200 if let Some(cached) = guard.as_ref()
201 && cached.expires_at > Instant::now()
202 {
203 return Ok(Cow::Owned(cached.access_token.clone()));
204 }
205 let new_token = self.mint_token().await?;
206 let token_str = new_token.access_token.clone();
207 *guard = Some(new_token);
208 Ok(Cow::Owned(token_str))
209 }
210
211 fn instance_url(&self) -> &str {
212 &self.instance_url
213 }
214
215 async fn invalidate(&self, stale_token: &str) {
216 let mut guard = self.cached.write().await;
220 if let Some(cached) = guard.as_ref()
221 && cached.access_token == stale_token
222 {
223 tracing::debug!(
224 target: "cirrus::auth",
225 flow = "jwt-bearer",
226 "invalidating cached token (CAS matched)",
227 );
228 *guard = None;
229 } else {
230 tracing::trace!(
231 target: "cirrus::auth",
232 flow = "jwt-bearer",
233 "invalidate called but cached token differs (concurrent refresh?); no-op",
234 );
235 }
236 }
237}
238
239#[derive(Default)]
241pub struct JwtAuthBuilder {
242 consumer_key: Option<String>,
243 username: Option<String>,
244 encoding_key: Option<EncodingKey>,
245 login_url: Option<String>,
246 instance_url: Option<String>,
247 token_ttl: Option<Duration>,
248 http_client: Option<reqwest::Client>,
249}
250
251impl std::fmt::Debug for JwtAuthBuilder {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 f.debug_struct("JwtAuthBuilder")
255 .field("consumer_key", &self.consumer_key.is_some())
256 .field("username", &self.username.is_some())
257 .field("private_key", &self.encoding_key.is_some())
258 .field("login_url", &self.login_url)
259 .field("instance_url", &self.instance_url)
260 .field("token_ttl", &self.token_ttl)
261 .finish_non_exhaustive()
262 }
263}
264
265impl JwtAuthBuilder {
266 pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
269 self.consumer_key = Some(key.into());
270 self
271 }
272
273 pub fn username(mut self, username: impl Into<String>) -> Self {
275 self.username = Some(username.into());
276 self
277 }
278
279 pub fn private_key_pem_file(mut self, path: impl Into<Utf8PathBuf>) -> AuthResult<Self> {
281 let path = path.into();
282 let bytes = fs_err::read(path.as_std_path())
283 .map_err(|e| AuthError::Other(format!("failed to read private key: {e}")))?;
284 self.encoding_key = Some(
285 EncodingKey::from_rsa_pem(&bytes)
286 .map_err(|e| AuthError::Other(format!("invalid RSA PEM key: {e}")))?,
287 );
288 Ok(self)
289 }
290
291 pub fn private_key_pem_bytes(mut self, bytes: &[u8]) -> AuthResult<Self> {
294 self.encoding_key = Some(
295 EncodingKey::from_rsa_pem(bytes)
296 .map_err(|e| AuthError::Other(format!("invalid RSA PEM key: {e}")))?,
297 );
298 Ok(self)
299 }
300
301 pub fn login_url(mut self, url: impl Into<String>) -> Self {
311 self.login_url = Some(url.into());
312 self
313 }
314
315 pub fn instance_url(mut self, url: impl Into<String>) -> Self {
319 self.instance_url = Some(url.into());
320 self
321 }
322
323 pub fn token_ttl(mut self, ttl: Duration) -> Self {
327 self.token_ttl = Some(ttl);
328 self
329 }
330
331 pub fn http_client(mut self, client: reqwest::Client) -> Self {
335 self.http_client = Some(client);
336 self
337 }
338
339 pub fn build(self) -> AuthResult<JwtAuth> {
341 let consumer_key = self
342 .consumer_key
343 .ok_or(AuthError::MissingField("consumer_key"))?;
344 let username = self.username.ok_or(AuthError::MissingField("username"))?;
345 let encoding_key = self
346 .encoding_key
347 .ok_or(AuthError::MissingField("private_key"))?;
348 let mut instance_url = self
349 .instance_url
350 .ok_or(AuthError::MissingField("instance_url"))?;
351 if instance_url.ends_with('/') {
352 instance_url.pop();
353 }
354 let mut login_url = self
355 .login_url
356 .unwrap_or_else(|| PRODUCTION_LOGIN_URL.to_string());
357 if login_url.ends_with('/') {
358 login_url.pop();
359 }
360 let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
361 let http = self.http_client.unwrap_or_default();
362
363 Ok(JwtAuth {
364 consumer_key,
365 username,
366 encoding_key,
367 login_url,
368 instance_url,
369 token_ttl,
370 http,
371 cached: RwLock::new(None),
372 })
373 }
374}
375
#[cfg(test)]
// Test code: panicking on an unexpected result is the intended failure mode.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    // RSA private key fixture used to sign test assertions.
    const TEST_PEM: &[u8] = include_bytes!("../tests/fixtures/test_rsa_key.pem");

    /// Builder with every required field populated; individual tests layer
    /// `login_url`, `token_ttl`, etc. on top.
    fn builder_with_required_fields() -> JwtAuthBuilder {
        JwtAuth::builder()
            .consumer_key("consumer-key-123")
            .username("integration@example.com")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://my-org.my.salesforce.com")
    }

    // --- Builder validation: each required field is reported by name. ---

    #[test]
    fn builder_requires_consumer_key() {
        let err = JwtAuth::builder()
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_username() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("username")));
    }

    #[test]
    fn builder_requires_private_key() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("private_key")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    // Key parsing fails eagerly, at the builder call, not at `build()`.
    #[test]
    fn invalid_pem_is_surfaced_as_auth_error() {
        let err = JwtAuth::builder()
            .private_key_pem_bytes(b"not a pem")
            .unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    #[test]
    fn builder_strips_trailing_slashes_and_defaults_login_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, PRODUCTION_LOGIN_URL);
    }

    // --- Token minting against a mock token endpoint. ---

    // Two calls within the TTL must hit the token endpoint only once.
    #[tokio::test]
    async fn mint_token_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "00DXX!ACCESS",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
            "scope": "api",
            "id": "https://login.salesforce.com/id/00DXX/005XX",
        });

        // Also checks that the form body carries the JWT-bearer grant fields.
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=urn"))
            .and(body_string_contains("assertion="))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(body),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    // With a zero TTL every cached entry is already expired, so each call
    // mints again.
    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();

        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    // An OAuth error body on a 400 response surfaces as `AuthError::OAuth`.
    #[tokio::test]
    async fn oauth_error_response_is_surfaced() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "user hasn't approved this consumer"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    // A token issued for a different org than the configured instance URL
    // is rejected.
    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://different-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    // `invalidate` is compare-and-clear: a non-matching token leaves the
    // cache intact, and only the currently cached token forces a re-mint.
    #[tokio::test]
    async fn invalidate_clears_cache_only_when_stale_token_matches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "T1",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
        });

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(body),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        // First call mints T1.
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // A mismatched token is a no-op: T1 is still served from cache.
        auth.invalidate("not-the-cached-token").await;
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // Invalidating the cached token forces a new exchange. The mock always
        // answers T1, so the hit count (not the value) proves the re-mint.
        auth.invalidate("T1").await;
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    /// wiremock responder that replays a fixed response and counts how many
    /// requests it has served.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }
}