use crate::auth::AuthSession;
use crate::auth::token_endpoint::{check_instance_url, exchange};
use crate::error::{CirrusError, CirrusResult};
use async_trait::async_trait;
use camino::Utf8PathBuf;
use jsonwebtoken::{Algorithm, EncodingKey, Header};
use serde::Serialize;
use std::borrow::Cow;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// Login host for production orgs; used when no login URL is configured.
pub const PRODUCTION_LOGIN_URL: &str = "https://login.salesforce.com";
/// Login host for sandbox orgs.
pub const SANDBOX_LOGIN_URL: &str = "https://test.salesforce.com";
/// How long a minted access token is served from cache when no TTL is configured (30 min).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(1_800);
/// Lifetime of each signed JWT assertion, in seconds: its `exp` claim is "now" plus this.
const JWT_VALIDITY_SECS: i64 = 180;
/// Claim set of the JWT bearer assertion posted to the token endpoint.
///
/// Field order is kept stable because serde serializes fields in declaration order.
#[derive(Serialize, Debug)]
struct JwtClaims {
    /// Issuer: the connected app's consumer key.
    iss: String,
    /// Subject: the username being authenticated.
    sub: String,
    /// Audience: the configured login URL.
    aud: String,
    /// Expiry of the assertion, in seconds since the UNIX epoch.
    exp: i64,
}
/// An access token together with the local instant after which it is re-minted.
#[derive(Clone)]
struct CachedToken {
    /// Bearer token returned by the token endpoint. Treat as a secret.
    access_token: String,
    /// Local deadline; the token is reused only while `expires_at > Instant::now()`.
    expires_at: Instant,
}

// Hand-written so the bearer token never ends up in logs or panic messages via
// `{:?}`, matching the redacting `Debug` impls used for the other credential holders.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("access_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
/// Salesforce OAuth 2.0 JWT bearer-flow authenticator.
///
/// Construct one with [`JwtAuth::builder`]. Access tokens are minted on demand and
/// reused from an in-memory cache for the configured TTL.
pub struct JwtAuth {
    // Assertion inputs.
    /// Connected app consumer key, sent as the JWT `iss` claim.
    consumer_key: String,
    /// Username to authenticate as, sent as the JWT `sub` claim.
    username: String,
    /// RSA key used to sign the assertion (RS256).
    encoding_key: EncodingKey,

    // Endpoints (stored without trailing slashes).
    /// Base URL the token exchange is sent to; also used as the JWT `aud` claim.
    login_url: String,
    /// Expected org instance URL; token responses must agree with it.
    instance_url: String,

    // Caching and transport.
    /// How long a minted token is reused before a new one is minted.
    token_ttl: Duration,
    /// HTTP client used for the token exchange.
    http: reqwest::Client,
    /// Most recently minted token, if any.
    cached: RwLock<Option<CachedToken>>,
}
impl std::fmt::Debug for JwtAuth {
    /// Prints only the non-credential configuration (URLs and TTL). The consumer key,
    /// username, signing key, HTTP client and cached token are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            login_url,
            instance_url,
            token_ttl,
            ..
        } = self;
        let mut out = f.debug_struct("JwtAuth");
        out.field("login_url", login_url);
        out.field("instance_url", instance_url);
        out.field("token_ttl", token_ttl);
        out.finish_non_exhaustive()
    }
}
impl JwtAuth {
pub fn builder() -> JwtAuthBuilder {
JwtAuthBuilder::default()
}
async fn mint_token(&self) -> CirrusResult<CachedToken> {
tracing::info!(
target: "cirrus::auth",
flow = "jwt-bearer",
login_url = %self.login_url,
"minting fresh access token",
);
let now_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.map_err(|e| CirrusError::Auth(format!("system clock before UNIX epoch: {e}")))?;
let claims = JwtClaims {
iss: self.consumer_key.clone(),
sub: self.username.clone(),
aud: self.login_url.clone(),
exp: now_secs + JWT_VALIDITY_SECS,
};
let header = Header::new(Algorithm::RS256);
let assertion = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
.map_err(|e| CirrusError::Auth(format!("JWT signing failed: {e}")))?;
let body = [
("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
("assertion", assertion.as_str()),
];
let token = exchange(&self.http, &self.login_url, &body).await?;
check_instance_url(&self.instance_url, &token)?;
Ok(CachedToken {
access_token: token.access_token,
expires_at: Instant::now() + self.token_ttl,
})
}
}
#[async_trait]
impl AuthSession for JwtAuth {
async fn access_token(&self) -> CirrusResult<Cow<'_, str>> {
{
let guard = self.cached.read().await;
if let Some(cached) = guard.as_ref()
&& cached.expires_at > Instant::now()
{
return Ok(Cow::Owned(cached.access_token.clone()));
}
}
let mut guard = self.cached.write().await;
if let Some(cached) = guard.as_ref()
&& cached.expires_at > Instant::now()
{
return Ok(Cow::Owned(cached.access_token.clone()));
}
let new_token = self.mint_token().await?;
let token_str = new_token.access_token.clone();
*guard = Some(new_token);
Ok(Cow::Owned(token_str))
}
fn instance_url(&self) -> &str {
&self.instance_url
}
async fn invalidate(&self, stale_token: &str) {
let mut guard = self.cached.write().await;
if let Some(cached) = guard.as_ref()
&& cached.access_token == stale_token
{
tracing::debug!(
target: "cirrus::auth",
flow = "jwt-bearer",
"invalidating cached token (CAS matched)",
);
*guard = None;
} else {
tracing::trace!(
target: "cirrus::auth",
flow = "jwt-bearer",
"invalidate called but cached token differs (concurrent refresh?); no-op",
);
}
}
}
/// Step-by-step configuration for [`JwtAuth`]; obtain one via [`JwtAuth::builder`].
///
/// `consumer_key`, `username`, a private key and `instance_url` are required.
/// Everything else has a default applied in `build`.
#[derive(Default)]
pub struct JwtAuthBuilder {
    /// Connected app consumer key (required).
    consumer_key: Option<String>,
    /// Username to authenticate as (required).
    username: Option<String>,
    /// Parsed RSA signing key (required; reported as `private_key` when missing).
    encoding_key: Option<EncodingKey>,
    /// Token endpoint base URL; defaults to the production login URL.
    login_url: Option<String>,
    /// Expected org instance URL (required).
    instance_url: Option<String>,
    /// Cache lifetime for minted tokens; defaults to 30 minutes.
    token_ttl: Option<Duration>,
    /// Custom HTTP client; a default `reqwest::Client` is used when unset.
    http_client: Option<reqwest::Client>,
}
impl std::fmt::Debug for JwtAuthBuilder {
    /// Shows credentials only as "is it set?" flags, so the consumer key, username and
    /// signing key never appear in the output. URLs and TTL are printed as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_consumer_key = self.consumer_key.is_some();
        let has_username = self.username.is_some();
        let has_private_key = self.encoding_key.is_some();
        let mut out = f.debug_struct("JwtAuthBuilder");
        out.field("consumer_key", &has_consumer_key);
        out.field("username", &has_username);
        out.field("private_key", &has_private_key);
        out.field("login_url", &self.login_url);
        out.field("instance_url", &self.instance_url);
        out.field("token_ttl", &self.token_ttl);
        out.finish_non_exhaustive()
    }
}
impl JwtAuthBuilder {
    /// Sets the connected app's consumer key (sent as the JWT `iss` claim). Required.
    pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
        self.consumer_key = Some(key.into());
        self
    }

    /// Sets the username to authenticate as (sent as the JWT `sub` claim). Required.
    pub fn username(mut self, username: impl Into<String>) -> Self {
        self.username = Some(username.into());
        self
    }

    /// Loads the RSA signing key from a PEM file at `path`. A key is required.
    ///
    /// # Errors
    ///
    /// Returns [`CirrusError::Auth`] if the file cannot be read or does not contain a
    /// valid RSA PEM key.
    pub fn private_key_pem_file(self, path: impl Into<Utf8PathBuf>) -> CirrusResult<Self> {
        let path = path.into();
        let bytes = fs_err::read(path.as_std_path())
            .map_err(|e| CirrusError::Auth(format!("failed to read private key: {e}")))?;
        // Share parsing (and its error message) with the in-memory variant.
        self.private_key_pem_bytes(&bytes)
    }

    /// Parses the RSA signing key from in-memory PEM bytes. A key is required.
    ///
    /// # Errors
    ///
    /// Returns [`CirrusError::Auth`] if `bytes` is not a valid RSA PEM key.
    pub fn private_key_pem_bytes(mut self, bytes: &[u8]) -> CirrusResult<Self> {
        let key = EncodingKey::from_rsa_pem(bytes)
            .map_err(|e| CirrusError::Auth(format!("invalid RSA PEM key: {e}")))?;
        self.encoding_key = Some(key);
        Ok(self)
    }

    /// Overrides the token endpoint base URL (e.g. [`SANDBOX_LOGIN_URL`]).
    ///
    /// Defaults to [`PRODUCTION_LOGIN_URL`]. The value is also used as the JWT `aud` claim.
    pub fn login_url(mut self, url: impl Into<String>) -> Self {
        self.login_url = Some(url.into());
        self
    }

    /// Sets the org instance URL that token responses must match. Required.
    pub fn instance_url(mut self, url: impl Into<String>) -> Self {
        self.instance_url = Some(url.into());
        self
    }

    /// Sets how long a minted token is reused before a new one is minted.
    ///
    /// Defaults to 30 minutes.
    pub fn token_ttl(mut self, ttl: Duration) -> Self {
        self.token_ttl = Some(ttl);
        self
    }

    /// Supplies the HTTP client used for token exchanges.
    ///
    /// Defaults to `reqwest::Client::default()`.
    pub fn http_client(mut self, client: reqwest::Client) -> Self {
        self.http_client = Some(client);
        self
    }

    /// Validates the configuration and produces a [`JwtAuth`] with an empty token cache.
    ///
    /// All trailing slashes are removed from both URLs.
    ///
    /// # Errors
    ///
    /// Returns [`CirrusError::MissingField`] naming the first unset required field. The
    /// fields are checked in this order: `consumer_key`, `username`, `private_key`,
    /// `instance_url`.
    pub fn build(self) -> CirrusResult<JwtAuth> {
        let consumer_key = self
            .consumer_key
            .ok_or(CirrusError::MissingField("consumer_key"))?;
        let username = self.username.ok_or(CirrusError::MissingField("username"))?;
        let encoding_key = self
            .encoding_key
            .ok_or(CirrusError::MissingField("private_key"))?;
        let instance_url = Self::without_trailing_slashes(
            self.instance_url
                .ok_or(CirrusError::MissingField("instance_url"))?,
        );
        let login_url = Self::without_trailing_slashes(
            self.login_url
                .unwrap_or_else(|| PRODUCTION_LOGIN_URL.to_string()),
        );
        let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
        let http = self.http_client.unwrap_or_default();
        Ok(JwtAuth {
            consumer_key,
            username,
            encoding_key,
            login_url,
            instance_url,
            token_ttl,
            http,
            cached: RwLock::new(None),
        })
    }

    /// Strips every trailing `/` from `url` in place (not just the last one), reusing
    /// its allocation. For example, `"https://x//"` becomes `"https://x"`.
    fn without_trailing_slashes(mut url: String) -> String {
        let kept = url.trim_end_matches('/').len();
        url.truncate(kept);
        url
    }
}
#[cfg(test)]
// Tests are allowed to unwrap/expect/panic: a failure there is the test failing.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    /// Throwaway RSA key used only to sign test assertions.
    const TEST_PEM: &[u8] = include_bytes!("../tests/fixtures/test_rsa_key.pem");

    /// A builder with every required field set; individual tests override what they need.
    fn builder_with_required_fields() -> JwtAuthBuilder {
        JwtAuth::builder()
            .consumer_key("consumer-key-123")
            .username("integration@example.com")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://my-org.my.salesforce.com")
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = JwtAuth::builder()
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, CirrusError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_username() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, CirrusError::MissingField("username")));
    }

    #[test]
    fn builder_requires_private_key() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, CirrusError::MissingField("private_key")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = JwtAuth::builder()
            .consumer_key("k")
            .username("u")
            .private_key_pem_bytes(TEST_PEM)
            .unwrap()
            .build()
            .unwrap_err();
        assert!(matches!(err, CirrusError::MissingField("instance_url")));
    }

    #[test]
    fn invalid_pem_is_surfaced_as_auth_error() {
        let err = JwtAuth::builder()
            .private_key_pem_bytes(b"not a pem")
            .unwrap_err();
        assert!(matches!(err, CirrusError::Auth(_)));
    }

    #[test]
    fn unreadable_key_file_is_surfaced_as_auth_error() {
        let err = JwtAuth::builder()
            .private_key_pem_file("this/path/does/not/exist/key.pem")
            .unwrap_err();
        assert!(matches!(err, CirrusError::Auth(_)));
    }

    #[test]
    fn builder_strips_trailing_slashes_and_defaults_login_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, PRODUCTION_LOGIN_URL);
    }

    #[test]
    fn builder_honours_custom_login_url() {
        let auth = builder_with_required_fields()
            .login_url(format!("{SANDBOX_LOGIN_URL}/"))
            .build()
            .unwrap();
        assert_eq!(auth.login_url, SANDBOX_LOGIN_URL);
    }

    #[test]
    fn debug_output_omits_credentials() {
        let builder = builder_with_required_fields();
        let rendered = format!("{builder:?}");
        assert!(!rendered.contains("consumer-key-123"));
        assert!(!rendered.contains("integration@example.com"));

        let auth = builder.build().unwrap();
        let rendered = format!("{auth:?}");
        assert!(!rendered.contains("consumer-key-123"));
        assert!(!rendered.contains("integration@example.com"));
    }

    #[tokio::test]
    async fn mint_token_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "00DXX!ACCESS",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
            "scope": "api",
            "id": "https://login.salesforce.com/id/00DXX/005XX",
        });
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=urn"))
            .and(body_string_contains("assertion="))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(body),
            })
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");
        // The second call must be served from cache.
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "tok",
                    "instance_url": "https://my-org.my.salesforce.com"
                })),
            })
            .mount(&server)
            .await;
        // A zero TTL means every cached token is already expired.
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();

        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn oauth_error_response_is_surfaced() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "user hasn't approved this consumer"
            })))
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        match err {
            CirrusError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://different-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, CirrusError::Auth(_)));
    }

    #[tokio::test]
    async fn invalidate_clears_cache_only_when_stale_token_matches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        let body = serde_json::json!({
            "access_token": "T1",
            "instance_url": "https://my-org.my.salesforce.com",
            "token_type": "Bearer",
        });
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: hits.clone(),
                response: ResponseTemplate::new(200).set_body_json(body),
            })
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();

        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // A non-matching token must leave the cache alone.
        auth.invalidate("not-the-cached-token").await;
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        drop(t);

        // A matching token clears the cache, forcing a second exchange.
        auth.invalidate("T1").await;
        let t = auth.access_token().await.unwrap();
        assert_eq!(&*t, "T1");
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    /// Wiremock responder that counts how many token requests it served.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }
}