use crate::auth::AuthSession;
use crate::auth::token_endpoint::{check_instance_url, exchange};
use crate::error::{CirrusError, CirrusResult};
use async_trait::async_trait;
use std::borrow::Cow;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Salesforce login host for production orgs; the builder uses this when no
/// `login_url` is configured.
pub const PRODUCTION_LOGIN_URL: &str = "https://login.salesforce.com";
/// Salesforce login host for sandbox orgs; pass it to the builder's `login_url`.
pub const SANDBOX_LOGIN_URL: &str = "https://test.salesforce.com";
/// Cache lifetime applied to minted access tokens when the builder's
/// `token_ttl` is not set (30 minutes).
// NOTE(review): this is a local cache lifetime, not a value read from the token
// response — confirm it stays below the org's session timeout.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(30 * 60);
/// An access token held in the cache, together with the instant after which it
/// is no longer handed out.
#[derive(Clone)]
struct CachedToken {
    /// Access token returned by the token endpoint.
    access_token: String,
    /// Local expiry: the token is reused only while `Instant::now()` is before this.
    expires_at: Instant,
}

// Hand-written so that debug output never contains the live access token; this
// matches the redacting `Debug` impls of the other types in this module.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("access_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
/// [`AuthSession`] that obtains access tokens through the OAuth refresh-token
/// grant and caches them for `token_ttl`.
///
/// Construct with [`RefreshTokenAuth::builder`].
pub struct RefreshTokenAuth {
    /// Consumer key, sent as `client_id`.
    consumer_key: String,
    /// Consumer secret, sent as `client_secret` only when present.
    consumer_secret: Option<String>,
    /// Refresh token exchanged for new access tokens.
    refresh_token: String,
    /// Base URL of the login host the token exchange is sent to (trailing `/`
    /// stripped by the builder).
    login_url: String,
    /// Expected org instance URL (trailing `/` stripped by the builder); each token
    /// response is checked against it via `check_instance_url`.
    instance_url: String,
    /// How long a minted access token is served from the cache before re-minting.
    token_ttl: Duration,
    /// HTTP client used for token exchanges.
    http: reqwest::Client,
    /// Most recently minted token, or `None` before the first mint / after invalidation.
    cached: RwLock<Option<CachedToken>>,
}
impl std::fmt::Debug for RefreshTokenAuth {
    // Prints only non-credential configuration. The consumer key, refresh token,
    // and cached token are omitted; the secret appears solely as a presence flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let confidential = self.consumer_secret.is_some();
        let mut out = f.debug_struct("RefreshTokenAuth");
        out.field("login_url", &self.login_url);
        out.field("instance_url", &self.instance_url);
        out.field("token_ttl", &self.token_ttl);
        out.field("confidential", &confidential);
        out.finish_non_exhaustive()
    }
}
impl RefreshTokenAuth {
pub fn builder() -> RefreshTokenAuthBuilder {
RefreshTokenAuthBuilder::default()
}
async fn mint_token(&self) -> CirrusResult<CachedToken> {
tracing::info!(
target: "cirrus::auth",
flow = "refresh-token",
login_url = %self.login_url,
"minting fresh access token",
);
let mut body: Vec<(&str, &str)> = vec![
("grant_type", "refresh_token"),
("client_id", self.consumer_key.as_str()),
("refresh_token", self.refresh_token.as_str()),
];
if let Some(secret) = self.consumer_secret.as_deref() {
body.push(("client_secret", secret));
}
let token = exchange(&self.http, &self.login_url, &body).await?;
check_instance_url(&self.instance_url, &token)?;
Ok(CachedToken {
access_token: token.access_token,
expires_at: Instant::now() + self.token_ttl,
})
}
}
#[async_trait]
impl AuthSession for RefreshTokenAuth {
async fn access_token(&self) -> CirrusResult<Cow<'_, str>> {
{
let guard = self.cached.read().await;
if let Some(cached) = guard.as_ref()
&& cached.expires_at > Instant::now()
{
return Ok(Cow::Owned(cached.access_token.clone()));
}
}
let mut guard = self.cached.write().await;
if let Some(cached) = guard.as_ref()
&& cached.expires_at > Instant::now()
{
return Ok(Cow::Owned(cached.access_token.clone()));
}
let new_token = self.mint_token().await?;
let token_str = new_token.access_token.clone();
*guard = Some(new_token);
Ok(Cow::Owned(token_str))
}
fn instance_url(&self) -> &str {
&self.instance_url
}
async fn invalidate(&self, stale_token: &str) {
let mut guard = self.cached.write().await;
if let Some(cached) = guard.as_ref()
&& cached.access_token == stale_token
{
tracing::debug!(
target: "cirrus::auth",
flow = "refresh-token",
"invalidating cached token (CAS matched)",
);
*guard = None;
} else {
tracing::trace!(
target: "cirrus::auth",
flow = "refresh-token",
"invalidate called but cached token differs (concurrent refresh?); no-op",
);
}
}
}
/// Builder for [`RefreshTokenAuth`]; obtain one with [`RefreshTokenAuth::builder`].
///
/// `consumer_key`, `refresh_token`, and `instance_url` are required. Everything
/// else has a default applied in `build`.
#[derive(Default)]
pub struct RefreshTokenAuthBuilder {
    // Required: sent as `client_id`.
    consumer_key: Option<String>,
    // Optional: sent as `client_secret` when set.
    consumer_secret: Option<String>,
    // Required.
    refresh_token: Option<String>,
    // Optional: defaults to `PRODUCTION_LOGIN_URL`.
    login_url: Option<String>,
    // Required.
    instance_url: Option<String>,
    // Optional: defaults to `DEFAULT_TOKEN_TTL`.
    token_ttl: Option<Duration>,
    // Optional: defaults to `reqwest::Client::default()`.
    http_client: Option<reqwest::Client>,
}
impl std::fmt::Debug for RefreshTokenAuthBuilder {
    // Credential fields are reported only as set / not set. The URLs and TTL
    // are printed as-is, and the HTTP client is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RefreshTokenAuthBuilder");
        let presence = [
            ("consumer_key", self.consumer_key.is_some()),
            ("consumer_secret", self.consumer_secret.is_some()),
            ("refresh_token", self.refresh_token.is_some()),
        ];
        for (name, is_set) in presence {
            out.field(name, &is_set);
        }
        out.field("login_url", &self.login_url);
        out.field("instance_url", &self.instance_url);
        out.field("token_ttl", &self.token_ttl);
        out.finish_non_exhaustive()
    }
}
impl RefreshTokenAuthBuilder {
    /// Sets the consumer key, sent as `client_id`. Required.
    #[must_use]
    pub fn consumer_key(mut self, key: impl Into<String>) -> Self {
        self.consumer_key = Some(key.into());
        self
    }

    /// Sets the consumer secret, sent as `client_secret`.
    ///
    /// Optional. When unset, no `client_secret` is sent (public client).
    #[must_use]
    pub fn consumer_secret(mut self, secret: impl Into<String>) -> Self {
        self.consumer_secret = Some(secret.into());
        self
    }

    /// Sets the refresh token exchanged for access tokens. Required.
    #[must_use]
    pub fn refresh_token(mut self, token: impl Into<String>) -> Self {
        self.refresh_token = Some(token.into());
        self
    }

    /// Sets the login host base URL.
    ///
    /// Defaults to [`PRODUCTION_LOGIN_URL`]; use [`SANDBOX_LOGIN_URL`] for sandboxes.
    #[must_use]
    pub fn login_url(mut self, url: impl Into<String>) -> Self {
        self.login_url = Some(url.into());
        self
    }

    /// Sets the instance URL the token responses must match. Required.
    #[must_use]
    pub fn instance_url(mut self, url: impl Into<String>) -> Self {
        self.instance_url = Some(url.into());
        self
    }

    /// Sets how long a minted access token is served from the cache.
    ///
    /// Defaults to `DEFAULT_TOKEN_TTL`.
    #[must_use]
    pub fn token_ttl(mut self, ttl: Duration) -> Self {
        self.token_ttl = Some(ttl);
        self
    }

    /// Supplies the HTTP client used for token exchanges.
    ///
    /// Defaults to `reqwest::Client::default()`.
    #[must_use]
    pub fn http_client(mut self, client: reqwest::Client) -> Self {
        self.http_client = Some(client);
        self
    }

    /// Validates the configuration and builds a [`RefreshTokenAuth`].
    ///
    /// All trailing `/` characters are removed from both the instance URL and
    /// the login URL.
    ///
    /// # Errors
    ///
    /// Returns [`CirrusError::MissingField`] when `consumer_key`,
    /// `refresh_token`, or `instance_url` was not set. The same error is
    /// returned when `instance_url` is empty after its trailing slashes are
    /// removed.
    pub fn build(self) -> CirrusResult<RefreshTokenAuth> {
        let consumer_key = self
            .consumer_key
            .ok_or(CirrusError::MissingField("consumer_key"))?;
        let refresh_token = self
            .refresh_token
            .ok_or(CirrusError::MissingField("refresh_token"))?;
        // An empty instance URL can never match a token response, so reject it
        // here instead of failing on every mint later.
        let instance_url = self
            .instance_url
            .map(Self::strip_trailing_slashes)
            .filter(|url| !url.is_empty())
            .ok_or(CirrusError::MissingField("instance_url"))?;
        let login_url = Self::strip_trailing_slashes(
            self.login_url
                .unwrap_or_else(|| PRODUCTION_LOGIN_URL.to_string()),
        );
        let token_ttl = self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL);
        let http = self.http_client.unwrap_or_default();
        Ok(RefreshTokenAuth {
            consumer_key,
            consumer_secret: self.consumer_secret,
            refresh_token,
            login_url,
            instance_url,
            token_ttl,
            http,
            cached: RwLock::new(None),
        })
    }

    /// Removes every trailing `/` from `url`, reusing its allocation.
    fn strip_trailing_slashes(mut url: String) -> String {
        let kept = url.trim_end_matches('/').len();
        url.truncate(kept);
        url
    }
}
#[cfg(test)]
// A panic is how a test reports failure, so unwrap/expect/panic are allowed here.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};

    fn builder_with_required_fields() -> RefreshTokenAuthBuilder {
        RefreshTokenAuth::builder()
            .consumer_key("consumer-key-123")
            .refresh_token("5Aep861KIwKdekr...refresh")
            .instance_url("https://my-org.my.salesforce.com")
    }

    /// Successful token response for the org configured by
    /// `builder_with_required_fields`.
    fn ok_token(access_token: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "access_token": access_token,
            "instance_url": "https://my-org.my.salesforce.com"
        }))
    }

    #[test]
    fn builder_requires_consumer_key() {
        let err = RefreshTokenAuth::builder()
            .refresh_token("r")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, CirrusError::MissingField("consumer_key")));
    }

    #[test]
    fn builder_requires_refresh_token() {
        let err = RefreshTokenAuth::builder()
            .consumer_key("k")
            .instance_url("https://x")
            .build()
            .unwrap_err();
        assert!(matches!(err, CirrusError::MissingField("refresh_token")));
    }

    #[test]
    fn builder_requires_instance_url() {
        let err = RefreshTokenAuth::builder()
            .consumer_key("k")
            .refresh_token("r")
            .build()
            .unwrap_err();
        assert!(matches!(err, CirrusError::MissingField("instance_url")));
    }

    #[test]
    fn builder_strips_trailing_slashes_and_defaults_login_url() {
        let auth = builder_with_required_fields()
            .instance_url("https://my-org.my.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.instance_url(), "https://my-org.my.salesforce.com");
        assert_eq!(auth.login_url, PRODUCTION_LOGIN_URL);
    }

    #[test]
    fn builder_strips_trailing_slash_from_login_url() {
        let auth = builder_with_required_fields()
            .login_url("https://test.salesforce.com/")
            .build()
            .unwrap();
        assert_eq!(auth.login_url, SANDBOX_LOGIN_URL);
    }

    #[tokio::test]
    async fn refresh_succeeds_and_caches() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("grant_type=refresh_token"))
            .and(body_string_contains("client_id=consumer-key-123"))
            .and(body_string_contains("refresh_token=5Aep861KIwKdekr"))
            .respond_with(CountingResponder {
                hits: Arc::clone(&hits),
                response: ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "00DXX!ACCESS",
                    "instance_url": "https://my-org.my.salesforce.com",
                    "token_type": "Bearer",
                    "id": "https://login.salesforce.com/id/00DXX/005XX",
                })),
            })
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();
        let t1 = auth.access_token().await.unwrap();
        assert_eq!(&*t1, "00DXX!ACCESS");
        let t2 = auth.access_token().await.unwrap();
        assert_eq!(&*t2, "00DXX!ACCESS");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confidential_client_includes_consumer_secret() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .and(body_string_contains("client_secret=top-secret"))
            .respond_with(ok_token("tok"))
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .consumer_secret("top-secret")
            .login_url(server.uri())
            .build()
            .unwrap();
        auth.access_token().await.unwrap();
    }

    #[tokio::test]
    async fn public_client_omits_consumer_secret() {
        let server = MockServer::start().await;
        let received_body: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let captured = Arc::clone(&received_body);
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(BodyCapturingResponder {
                captured,
                response: ok_token("tok"),
            })
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();
        auth.access_token().await.unwrap();
        // Require an actual captured body so the negative assertion below
        // cannot pass vacuously on an empty string.
        let body = received_body
            .lock()
            .unwrap()
            .clone()
            .expect("token endpoint was never called");
        assert!(
            body.contains("grant_type=refresh_token"),
            "unexpected token request body: {body}"
        );
        assert!(
            !body.contains("client_secret"),
            "public client should not send client_secret, got: {body}"
        );
    }

    #[tokio::test]
    async fn expired_cache_remints_token() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: Arc::clone(&hits),
                response: ok_token("tok"),
            })
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .token_ttl(Duration::ZERO)
            .build()
            .unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn invalidate_with_current_token_forces_remint() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: Arc::clone(&hits),
                response: ok_token("tok"),
            })
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();
        let first = auth.access_token().await.unwrap().into_owned();
        auth.invalidate(&first).await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn invalidate_with_stale_token_is_a_noop() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(CountingResponder {
                hits: Arc::clone(&hits),
                response: ok_token("tok"),
            })
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();
        let _ = auth.access_token().await.unwrap();
        auth.invalidate("some-older-token").await;
        let _ = auth.access_token().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn revoked_refresh_token_surfaces_oauth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "expired authorization code"
            })))
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();
        let err = auth.access_token().await.unwrap_err();
        match err {
            CirrusError::OAuth {
                error,
                error_description,
            } => {
                assert_eq!(error, "invalid_grant");
                assert!(error_description.is_some());
            }
            other => panic!("expected OAuth error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn instance_url_mismatch_is_an_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/services/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "tok",
                "instance_url": "https://wrong-org.my.salesforce.com"
            })))
            .mount(&server)
            .await;
        let auth = builder_with_required_fields()
            .login_url(server.uri())
            .build()
            .unwrap();
        let err = auth.access_token().await.unwrap_err();
        assert!(matches!(err, CirrusError::Auth(_)));
    }

    /// Responder that counts how many times the token endpoint was hit.
    struct CountingResponder {
        hits: Arc<AtomicUsize>,
        response: ResponseTemplate,
    }

    impl Respond for CountingResponder {
        fn respond(&self, _: &Request) -> ResponseTemplate {
            self.hits.fetch_add(1, Ordering::SeqCst);
            self.response.clone()
        }
    }

    /// Responder that records the most recent request body.
    struct BodyCapturingResponder {
        captured: Arc<Mutex<Option<String>>>,
        response: ResponseTemplate,
    }

    impl Respond for BodyCapturingResponder {
        fn respond(&self, request: &Request) -> ResponseTemplate {
            let body = String::from_utf8_lossy(&request.body).into_owned();
            // A blocking lock is fine here because the guard is never held across
            // an `.await`. Unlike `try_lock`, it cannot silently drop the capture
            // under contention.
            *self.captured.lock().unwrap() = Some(body);
            self.response.clone()
        }
    }
}