1use crate::AuthSession;
24use crate::error::{AuthError, AuthResult};
25use crate::token_endpoint::{check_instance_url, exchange};
26use async_trait::async_trait;
27use camino::Utf8PathBuf;
28use jsonwebtoken::{Algorithm, EncodingKey, Header};
29use serde::Serialize;
30use std::borrow::Cow;
31use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
32use tokio::sync::RwLock;
33
/// Login host for production orgs; [`JwtAuthBuilder::build`] uses this when
/// no login URL is configured.
pub const PRODUCTION_LOGIN_URL: &str = "https://login.salesforce.com";

/// Login host for sandbox orgs; pass it to [`JwtAuthBuilder::login_url`].
pub const SANDBOX_LOGIN_URL: &str = "https://test.salesforce.com";

/// Default client-side cache lifetime for a minted access token (30 minutes).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);

/// Seconds added to "now" to form the JWT assertion's `exp` claim.
///
/// NOTE(review): Salesforce's JWT bearer flow documents a 3-minute maximum
/// for `exp`; keep this at or below 180.
const JWT_VALIDITY_SECS: i64 = 3 * 60;
49
/// Claims serialized into the RS256-signed JWT assertion that is exchanged
/// for an access token.
///
/// Field order is also the JSON serialization order of the payload.
#[derive(Debug, Serialize)]
struct JwtClaims {
    /// Issuer: the connected app's consumer key.
    iss: String,
    /// Subject: the username being authenticated as.
    sub: String,
    /// Audience: the configured login URL.
    aud: String,
    /// Expiry, in whole seconds since the UNIX epoch.
    exp: i64,
}
57
/// An access token together with the local deadline after which the cache
/// stops serving it and a new token is minted.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Bearer token returned by the token endpoint.
    access_token: String,
    /// Client-side expiry derived from the configured token TTL (not from
    /// anything the server reports).
    expires_at: Instant,
}
63
/// Salesforce OAuth 2.0 JWT bearer-flow authenticator.
///
/// Signs a short-lived RS256 assertion, exchanges it at `login_url` for an
/// access token, and serves that token from an in-memory cache for
/// `token_ttl`. Construct one with [`JwtAuth::builder`].
pub struct JwtAuth {
    /// Connected-app consumer key; becomes the JWT `iss` claim.
    consumer_key: String,
    /// Username to authenticate as; becomes the JWT `sub` claim.
    username: String,
    /// RSA private key used to sign assertions (parsed from PEM by the builder).
    encoding_key: EncodingKey,
    /// Token-endpoint base URL (trailing `/` removed by the builder); also
    /// used as the JWT `aud` claim.
    login_url: String,
    /// Expected org instance URL (trailing `/` removed by the builder); every
    /// token response is checked against it and a mismatch is an error.
    instance_url: String,
    /// How long a minted token is served from cache before a new one is minted.
    token_ttl: Duration,
    /// HTTP client used for token exchanges.
    http: reqwest::Client,
    /// Most recently minted token, if any. Readers share the lock; minting and
    /// invalidation take it exclusively.
    cached: RwLock<Option<CachedToken>>,
}
77
78impl std::fmt::Debug for JwtAuth {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("JwtAuth")
83 .field("login_url", &self.login_url)
84 .field("instance_url", &self.instance_url)
85 .field("token_ttl", &self.token_ttl)
86 .finish_non_exhaustive()
87 }
88}
89
90impl JwtAuth {
91 pub fn builder() -> JwtAuthBuilder {
118 JwtAuthBuilder::default()
119 }
120
121 async fn mint_token(&self) -> AuthResult<CachedToken> {
122 tracing::info!(
123 target: "cirrus::auth",
124 flow = "jwt-bearer",
125 login_url = %self.login_url,
126 "minting fresh access token",
127 );
128 let now_secs = SystemTime::now()
129 .duration_since(UNIX_EPOCH)
130 .map(|d| d.as_secs() as i64)
131 .map_err(|e| AuthError::Other(format!("system clock before UNIX epoch: {e}")))?;
132
133 let claims = JwtClaims {
134 iss: self.consumer_key.clone(),
135 sub: self.username.clone(),
136 aud: self.login_url.clone(),
137 exp: now_secs + JWT_VALIDITY_SECS,
138 };
139
140 let header = Header::new(Algorithm::RS256);
141 let assertion = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
142 .map_err(|e| AuthError::Other(format!("JWT signing failed: {e}")))?;
143
144 let body = [
145 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
146 ("assertion", assertion.as_str()),
147 ];
148
149 let token = exchange(&self.http, &self.login_url, &body).await?;
150 check_instance_url(&self.instance_url, &token)?;
151
152 Ok(CachedToken {
153 access_token: token.access_token,
154 expires_at: Instant::now() + self.token_ttl,
155 })
156 }
157}
158
159#[async_trait]
160impl AuthSession for JwtAuth {
161 async fn access_token(&self) -> AuthResult<Cow<'_, str>> {
162 {
164 let guard = self.cached.read().await;
165 if let Some(cached) = guard.as_ref()
166 && cached.expires_at > Instant::now()
167 {
168 return Ok(Cow::Owned(cached.access_token.clone()));
169 }
170 }
171
172 let mut guard = self.cached.write().await;
174 if let Some(cached) = guard.as_ref()
175 && cached.expires_at > Instant::now()
176 {
177 return Ok(Cow::Owned(cached.access_token.clone()));
178 }
179 let new_token = self.mint_token().await?;
180 let token_str = new_token.access_token.clone();
181 *guard = Some(new_token);
182 Ok(Cow::Owned(token_str))
183 }
184
185 fn instance_url(&self) -> &str {
186 &self.instance_url
187 }
188
189 async fn invalidate(&self, stale_token: &str) {
190 let mut guard = self.cached.write().await;
194 if let Some(cached) = guard.as_ref()
195 && cached.access_token == stale_token
196 {
197 tracing::debug!(
198 target: "cirrus::auth",
199 flow = "jwt-bearer",
200 "invalidating cached token (CAS matched)",
201 );
202 *guard = None;
203 } else {
204 tracing::trace!(
205 target: "cirrus::auth",
206 flow = "jwt-bearer",
207 "invalidate called but cached token differs (concurrent refresh?); no-op",
208 );
209 }
210 }
211}
212
/// Builder for [`JwtAuth`], obtained from [`JwtAuth::builder`].
///
/// [`build`](JwtAuthBuilder::build) requires `consumer_key`, `username`, a
/// private key and `instance_url`. `login_url`, `token_ttl` and `http_client`
/// fall back to defaults when unset.
#[derive(Default)]
pub struct JwtAuthBuilder {
    /// Connected-app consumer key (JWT `iss`).
    consumer_key: Option<String>,
    /// Username to authenticate as (JWT `sub`).
    username: Option<String>,
    /// RSA signing key parsed from PEM by `private_key_pem_file` /
    /// `private_key_pem_bytes`.
    encoding_key: Option<EncodingKey>,
    /// Token-endpoint base URL; `build` falls back to [`PRODUCTION_LOGIN_URL`].
    login_url: Option<String>,
    /// Instance URL that token responses are checked against.
    instance_url: Option<String>,
    /// Cache lifetime for minted tokens; `build` falls back to 30 minutes.
    token_ttl: Option<Duration>,
    /// HTTP client for token requests; `build` falls back to
    /// `reqwest::Client::default()`.
    http_client: Option<reqwest::Client>,
}
224
225impl std::fmt::Debug for JwtAuthBuilder {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 f.debug_struct("JwtAuthBuilder")
229 .field("consumer_key", &self.consumer_key.is_some())
230 .field("username", &self.username.is_some())
231 .field("private_key", &self.encoding_key.is_some())
232 .field("login_url", &self.login_url)
233 .field("instance_url", &self.instance_url)
234 .field("token_ttl", &self.token_ttl)
235 .finish_non_exhaustive()
236 }
237}
238
239impl JwtAuthBuilder {
240 pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
243 self.consumer_key = Some(key.into());
244 self
245 }
246
247 pub fn username(mut self, username: impl Into<String>) -> Self {
249 self.username = Some(username.into());
250 self
251 }
252
253 pub fn private_key_pem_file(mut self, path: impl Into<Utf8PathBuf>) -> AuthResult<Self> {
255 let path = path.into();
256 let bytes = fs_err::read(path.as_std_path())
257 .map_err(|e| AuthError::Other(format!("failed to read private key: {e}")))?;
258 self.encoding_key = Some(
259 EncodingKey::from_rsa_pem(&bytes)
260 .map_err(|e| AuthError::Other(format!("invalid RSA PEM key: {e}")))?,
261 );
262 Ok(self)
263 }
264
265 pub fn private_key_pem_bytes(mut self, bytes: &[u8]) -> AuthResult<Self> {
268 self.encoding_key = Some(
269 EncodingKey::from_rsa_pem(bytes)
270 .map_err(|e| AuthError::Other(format!("invalid RSA PEM key: {e}")))?,
271 );
272 Ok(self)
273 }
274
275 pub fn login_url(mut self, url: impl Into<String>) -> Self {
285 self.login_url = Some(url.into());
286 self
287 }
288
289 pub fn instance_url(mut self, url: impl Into<String>) -> Self {
293 self.instance_url = Some(url.into());
294 self
295 }
296
297 pub fn token_ttl(mut self, ttl: Duration) -> Self {
301 self.token_ttl = Some(ttl);
302 self
303 }
304
305 pub fn http_client(mut self, client: reqwest::Client) -> Self {
309 self.http_client = Some(client);
310 self
311 }
312
313 pub fn build(self) -> AuthResult<JwtAuth> {
315 let consumer_key = self
316 .consumer_key
317 .ok_or(AuthError::MissingField("consumer_key"))?;
318 let username = self.username.ok_or(AuthError::MissingField("username"))?;
319 let encoding_key = self
320 .encoding_key
321 .ok_or(AuthError::MissingField("private_key"))?;
322 let mut instance_url = self
323 .instance_url
324 .ok_or(AuthError::MissingField("instance_url"))?;
325 if instance_url.ends_with('/') {
326 instance_url.pop();
327 }
328 let mut login_url = self
329 .login_url
330 .unwrap_or_else(|| PRODUCTION_LOGIN_URL.to_string());
331 if login_url.ends_with('/') {
332 login_url.pop();
333 }
334 let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
335 let http = self.http_client.unwrap_or_default();
336
337 Ok(JwtAuth {
338 consumer_key,
339 username,
340 encoding_key,
341 login_url,
342 instance_url,
343 token_ttl,
344 http,
345 cached: RwLock::new(None),
346 })
347 }
348}
349
#[cfg(test)]
// Panicking helpers are the normal way for a test to fail.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    // RSA private key (PEM) used only to exercise assertion signing in tests.
    const TEST_PEM: &[u8] = include_bytes!("../tests/fixtures/test_rsa_key.pem");

    /// Builder with every required field set; individual tests layer
    /// overrides (login URL, TTL, ...) on top.
    fn builder_with_required_fields() -> JwtAuthBuilder {
        JwtAuth::builder()
            .consumer_key("consumer-key-123")
            .username("integration@example.com")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://my-org.my.salesforce.com")
    }

    // Each `builder_requires_*` test omits exactly one required field and
    // expects `build` to name it in `MissingField`.

    #[test]
    fn builder_requires_consumer_key() {
        let err = JwtAuth::builder()
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_username() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("username")));
    }

    #[test]
    fn builder_requires_private_key() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("private_key")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .build()
            .unwrap_err();
        assert!(matches!(err, AuthError::MissingField("instance_url")));
    }

    // Malformed key material is reported at builder time, not on first use.
    #[test]
    fn invalid_pem_is_surfaced_as_auth_error() {
        let err = JwtAuth::builder()
            .private_key_pem_bytes(b"not a pem")
            .unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    #[test]
    fn builder_strips_trailing_slashes_and_defaults_login_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, PRODUCTION_LOGIN_URL);
    }

    // Two calls within the TTL must hit the token endpoint only once; the
    // matchers also pin the jwt-bearer form fields in the request body.
    #[tokio::test]
    async fn mint_token_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "00DXX!ACCESS",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
            "scope": "api",
            "id": "https://login.salesforce.com/id/00DXX/005XX",
        });

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=urn"))
            .and(body_string_contains("assertion="))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(body),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");

        // The second call was served from cache.
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    // With a zero TTL a freshly cached token is never strictly later than
    // `Instant::now()`, so every call re-mints.
    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();

        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    // An OAuth error body on a 400 surfaces as `AuthError::OAuth` carrying
    // the server's `error` code and description.
    #[tokio::test]
    async fn oauth_error_response_is_surfaced() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "user hasn't approved this consumer"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            AuthError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    // A token issued for a different org's instance URL is rejected.
    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://different-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, AuthError::Other(_)));
    }

    // `invalidate` is compare-and-swap: it clears the cache only when given
    // the token that is currently cached.
    #[tokio::test]
    async fn invalidate_clears_cache_only_when_stale_token_matches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "T1",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
        });

        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(body),
            })
            .mount(&server)
            .await;

        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        // First call mints and caches T1.
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // Non-matching token: no-op, so the next call is still a cache hit.
        auth.invalidate("not-the-cached-token").await;
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // Matching token: cache cleared, so the next call re-mints. The mock
        // returns the same string, so only the hit count distinguishes it.
        auth.invalidate("T1").await;
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    /// Wiremock responder that counts how many requests it has answered.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }
}