use crate::error::{HttpError, Result};
use chrono::{DateTime, Duration, Utc};
use reqwest::header::HeaderValue;
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
/// Token payload returned by the token endpoint.
///
/// Fields absent from the JSON fall back to serde defaults: `token_type` becomes
/// `"Bearer"`, `issued_at` becomes the current time in epoch milliseconds,
/// `signature` becomes an empty string, and the `Option` fields become `None`.
///
/// `Debug` is implemented by hand so that `access_token` and `refresh_token`
/// are redacted instead of being printed verbatim (e.g. into logs).
#[derive(Clone, Deserialize)]
pub struct TokenResponse {
    /// Access token; combined with `token_type` to build the `Authorization` header.
    pub access_token: String,
    /// Instance URL returned alongside the token.
    pub instance_url: String,
    /// Authorization scheme prefixed to the token (defaults to `"Bearer"`).
    #[serde(default = "default_token_type")]
    pub token_type: String,
    /// Issue time as a decimal string of epoch milliseconds.
    #[serde(default = "default_issued_at")]
    pub issued_at: String,
    /// Signature returned with the token; empty if the response had none.
    #[serde(default)]
    pub signature: String,
    /// Token lifetime in seconds, counted from `issued_at`, if provided.
    #[serde(default)]
    pub expires_in: Option<u64>,
    /// Refresh token, if the response included one.
    #[serde(default)]
    pub refresh_token: Option<String>,
}

impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        f.debug_struct("TokenResponse")
            .field("access_token", &REDACTED)
            .field("instance_url", &self.instance_url)
            .field("token_type", &self.token_type)
            .field("issued_at", &self.issued_at)
            .field("signature", &self.signature)
            .field("expires_in", &self.expires_in)
            // Reveal only whether a refresh token is present, never its value.
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
/// Serde default for `TokenResponse::token_type`: the `Bearer` scheme.
pub fn default_token_type() -> String {
    String::from("Bearer")
}
/// Serde default for `TokenResponse::issued_at`: the current time expressed as
/// epoch milliseconds, rendered as a decimal string.
fn default_issued_at() -> String {
    let millis_since_epoch = Utc::now().timestamp_millis();
    millis_since_epoch.to_string()
}
/// An access token plus the metadata needed to authorize requests with it and
/// to track its expiry.
#[derive(Debug, Clone)]
pub struct AccessToken {
    /// Raw token, held as a `SecretString` and only exposed through `as_str`.
    token: SecretString,
    /// When the token was issued; `from_response` falls back to the
    /// construction time if the response's `issued_at` could not be parsed.
    issued_at: DateTime<Utc>,
    /// Absolute expiry instant. `None` means unknown, in which case the
    /// expiry checks always report "not expired".
    expires_at: Option<DateTime<Utc>>,
    /// Instance URL returned alongside the token.
    instance_url: String,
    /// Scheme used as the `Authorization` header prefix.
    token_type: String,
    /// Precomputed `"<token_type> <token>"` header value, marked sensitive;
    /// `None` if that string could not be encoded as a header value.
    auth_header: Option<HeaderValue>,
}
impl AccessToken {
    /// Safety margin applied by [`AccessToken::is_soft_expired`]: the token is
    /// reported as expired this many seconds before its actual `expires_at`.
    const SOFT_EXPIRY_BUFFER_SECS: i64 = 60;

    /// Builds an access token from a token-endpoint response.
    ///
    /// * `issued_at` is parsed as epoch milliseconds; if it is not a valid
    ///   timestamp, the current time is used instead.
    /// * The expiry is `issued_at + expires_in` seconds. A missing or
    ///   out-of-range `expires_in` leaves the expiry unknown (`None`), and such
    ///   a token never reports itself as expired.
    /// * The `Authorization` header value (`"<token_type> <access_token>"`) is
    ///   precomputed and marked sensitive. If it cannot be encoded as a header
    ///   value, construction still succeeds and [`AccessToken::auth_header`]
    ///   returns an error.
    #[must_use]
    pub fn from_response(response: TokenResponse) -> Self {
        let issued_at = parse_issued_at(&response.issued_at).unwrap_or_else(|_| Utc::now());
        let expires_at = calculate_expiration(issued_at, response.expires_in);
        // Build the header before `access_token` is moved into the secret wrapper.
        let auth_header = create_auth_header(&response.token_type, &response.access_token);
        Self {
            token: SecretString::new(response.access_token.into()),
            issued_at,
            expires_at,
            instance_url: response.instance_url,
            token_type: response.token_type,
            auth_header,
        }
    }

    /// Test-only constructor: a `Bearer` token issued now with the given expiry.
    #[cfg(test)]
    pub fn new(token: String, instance_url: String, expires_at: Option<DateTime<Utc>>) -> Self {
        let auth_header = create_auth_header("Bearer", &token);
        Self {
            token: SecretString::new(token.into()),
            issued_at: Utc::now(),
            expires_at,
            instance_url,
            token_type: "Bearer".to_string(),
            auth_header,
        }
    }

    /// Returns the raw access token. Avoid logging the returned value.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.token.expose_secret()
    }

    /// Returns the instance URL that accompanied the token.
    #[must_use]
    pub fn instance_url(&self) -> &str {
        &self.instance_url
    }

    /// Returns the token type (the `Authorization` scheme, e.g. `Bearer`).
    #[must_use]
    pub fn token_type(&self) -> &str {
        &self.token_type
    }

    /// Returns `true` once the token is inside the soft-expiry window.
    ///
    /// Equivalent to [`AccessToken::is_soft_expired`].
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.is_soft_expired()
    }

    /// Returns `true` only once the expiry instant itself has been reached.
    #[must_use]
    pub fn is_hard_expired(&self) -> bool {
        self.is_expired_with_buffer(Duration::zero())
    }

    /// Returns `true` once the token is within
    /// `SOFT_EXPIRY_BUFFER_SECS` seconds of its expiry.
    #[must_use]
    pub fn is_soft_expired(&self) -> bool {
        self.is_expired_with_buffer(Duration::seconds(Self::SOFT_EXPIRY_BUFFER_SECS))
    }

    /// Returns `true` if the token is expired `buffer` from now, i.e.
    /// `now + buffer >= expires_at`.
    ///
    /// A token with no known expiry is never considered expired. If `buffer` is
    /// so large that `now + buffer` cannot be represented, the result saturates
    /// instead of panicking. A positive overflow counts as expired, and a
    /// negative overflow counts as not expired.
    #[must_use]
    pub fn is_expired_with_buffer(&self, buffer: Duration) -> bool {
        let Some(expires_at) = self.expires_at else {
            return false;
        };
        match Utc::now().checked_add_signed(buffer) {
            Some(deadline) => deadline >= expires_at,
            // Out of range: a positive overflow lies beyond every representable
            // expiry; a negative one lies before every representable expiry.
            None => buffer > Duration::zero(),
        }
    }

    /// Returns the instant the token was issued.
    #[must_use]
    pub const fn issued_at(&self) -> DateTime<Utc> {
        self.issued_at
    }

    /// Returns the expiry instant, or `None` if it is unknown.
    #[must_use]
    pub const fn expires_at(&self) -> Option<DateTime<Utc>> {
        self.expires_at
    }

    /// Returns the precomputed `Authorization` header value.
    ///
    /// # Errors
    ///
    /// Returns `HttpError::InvalidUrl` if the header value could not be built
    /// when the token was constructed, for example because the token contains
    /// characters that are not allowed in HTTP header values.
    pub fn auth_header(&self) -> std::result::Result<&HeaderValue, HttpError> {
        // NOTE(review): `InvalidUrl` is an odd variant for a header failure; it is
        // kept because callers may already match on it.
        self.auth_header
            .as_ref()
            .ok_or_else(|| HttpError::InvalidUrl("invalid authorization header".to_string()))
    }
}
/// Computes the absolute expiry instant as `issued_at + expires_in` seconds.
///
/// Returns `None` when `expires_in` is absent, when it exceeds the cap of
/// 3,000,000,000 seconds (about 95 years), or when the sum falls outside the
/// range `DateTime<Utc>` can represent.
fn calculate_expiration(
    issued_at: DateTime<Utc>,
    expires_in: Option<u64>,
) -> Option<DateTime<Utc>> {
    // Lifetimes above this are ignored rather than trusted (presumably bogus values).
    const MAX_EXPIRES_IN_SECS: u64 = 3_000_000_000;

    let lifetime_secs = expires_in.filter(|&secs| secs <= MAX_EXPIRES_IN_SECS)?;
    let lifetime_secs = i64::try_from(lifetime_secs).ok()?;
    issued_at.checked_add_signed(Duration::seconds(lifetime_secs))
}
/// Builds the `Authorization` header value `"<token_type> <access_token>"` and
/// marks it sensitive via `set_sensitive(true)`.
///
/// Returns `None` if the combined string is not a valid header value, for
/// example because it contains a newline.
fn create_auth_header(token_type: &str, access_token: &str) -> Option<HeaderValue> {
    let credentials = format!("{token_type} {access_token}");
    HeaderValue::from_str(&credentials).ok().map(|mut value| {
        value.set_sensitive(true);
        value
    })
}
fn parse_issued_at(issued_at: &str) -> Result<DateTime<Utc>> {
let timestamp_ms = issued_at.parse::<i64>().map_err(|_| {
crate::error::ForceError::Serialization(crate::error::SerializationError::InvalidFormat(
format!("invalid issued_at timestamp: {issued_at}"),
))
})?;
DateTime::from_timestamp_millis(timestamp_ms).ok_or_else(|| {
crate::error::ForceError::Serialization(crate::error::SerializationError::InvalidFormat(
format!("timestamp out of range: {timestamp_ms}"),
))
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::Must;

    const INSTANCE_URL: &str = "https://example.salesforce.com";
    const ISSUED_AT_MS: &str = "1704067200000";

    /// Builds a `Bearer` token response for `INSTANCE_URL` with the given
    /// access token, `issued_at` string and `expires_in`.
    fn sample_response(
        access_token: &str,
        issued_at: &str,
        expires_in: Option<u64>,
    ) -> TokenResponse {
        TokenResponse {
            access_token: access_token.to_string(),
            instance_url: INSTANCE_URL.to_string(),
            token_type: "Bearer".to_string(),
            issued_at: issued_at.to_string(),
            signature: String::new(),
            expires_in,
            refresh_token: None,
        }
    }

    /// Builds a test token with the given expiry.
    fn token_expiring_at(expires_at: Option<DateTime<Utc>>) -> AccessToken {
        AccessToken::new(
            "token".to_string(),
            "https://test.salesforce.com".to_string(),
            expires_at,
        )
    }

    #[test]
    fn test_token_response_deserialization() {
        let json = r#"{
            "access_token": "00D123456789!token",
            "instance_url": "https://example.my.salesforce.com",
            "token_type": "Bearer",
            "issued_at": "1704067200000",
            "signature": "signature_value"
        }"#;
        let response: TokenResponse = serde_json::from_str(json).must();
        assert_eq!(response.access_token, "00D123456789!token");
        assert_eq!(response.instance_url, "https://example.my.salesforce.com");
        assert_eq!(response.token_type, "Bearer");
        assert_eq!(response.issued_at, "1704067200000");
        assert_eq!(response.signature, "signature_value");
        assert_eq!(response.expires_in, None);
        assert_eq!(response.refresh_token, None);
    }

    #[test]
    fn test_token_response_with_expires_in() {
        let json = r#"{
            "access_token": "token123",
            "instance_url": "https://test.salesforce.com",
            "issued_at": "1704067200000",
            "expires_in": 7200
        }"#;
        let response: TokenResponse = serde_json::from_str(json).must();
        assert_eq!(response.expires_in, Some(7200));
        // `token_type` and `signature` are absent, so the serde defaults apply.
        assert_eq!(response.token_type, "Bearer");
        assert!(response.signature.is_empty());
    }

    #[test]
    fn test_token_response_defaults_issued_at_to_now() {
        let json = r#"{
            "access_token": "token123",
            "instance_url": "https://test.salesforce.com"
        }"#;
        let response: TokenResponse = serde_json::from_str(json).must();
        let issued_at = parse_issued_at(&response.issued_at).must();
        assert!((Utc::now() - issued_at).num_seconds().abs() < 5);
    }

    #[test]
    fn test_access_token_from_response() {
        let response = sample_response("test_token", ISSUED_AT_MS, Some(3600));
        let token = AccessToken::from_response(response);
        assert_eq!(token.as_str(), "test_token");
        assert_eq!(token.instance_url(), INSTANCE_URL);
        assert_eq!(token.token_type(), "Bearer");
        assert_eq!(token.issued_at().timestamp_millis(), 1_704_067_200_000);
        assert_eq!(
            token.expires_at(),
            Some(token.issued_at() + Duration::seconds(3600))
        );
        assert!(token
            .auth_header()
            .is_ok_and(|h| h.as_bytes() == b"Bearer test_token" && h.is_sensitive()));
    }

    #[test]
    fn test_access_token_is_expired() {
        let token = token_expiring_at(Some(Utc::now() - Duration::hours(1)));
        assert!(token.is_expired());
    }

    #[test]
    fn test_access_token_not_expired() {
        let token = token_expiring_at(Some(Utc::now() + Duration::hours(2)));
        assert!(!token.is_expired());
    }

    #[test]
    fn test_access_token_expiring_soon() {
        let token = token_expiring_at(Some(Utc::now() + Duration::seconds(30)));
        assert!(
            token.is_expired(),
            "a token inside the soft-expiry buffer should count as expired"
        );
    }

    #[test]
    fn test_access_token_no_expiration() {
        let token = token_expiring_at(None);
        assert!(!token.is_expired());
        assert!(!token.is_hard_expired());
    }

    #[test]
    fn test_access_token_custom_buffer() {
        let token = token_expiring_at(Some(Utc::now() + Duration::minutes(5)));
        assert!(!token.is_expired_with_buffer(Duration::minutes(1)));
        assert!(token.is_expired_with_buffer(Duration::minutes(10)));
    }

    #[test]
    fn test_access_token_hard_vs_soft_expiry() {
        let expiring = token_expiring_at(Some(Utc::now() + Duration::seconds(30)));
        assert!(expiring.is_soft_expired());
        assert!(!expiring.is_hard_expired());

        let expired = token_expiring_at(Some(Utc::now() - Duration::seconds(1)));
        assert!(expired.is_soft_expired());
        assert!(expired.is_hard_expired());
    }

    #[test]
    fn test_parse_issued_at_valid() {
        let dt = parse_issued_at(ISSUED_AT_MS).must();
        assert_eq!(dt.timestamp_millis(), 1_704_067_200_000);
    }

    #[test]
    fn test_parse_issued_at_invalid() {
        assert!(matches!(
            parse_issued_at("not_a_number"),
            Err(crate::error::ForceError::Serialization(_))
        ));
    }

    #[test]
    fn test_parse_issued_at_negative() {
        let dt = parse_issued_at("-1000").must();
        assert_eq!(dt.timestamp(), -1);
    }

    #[test]
    fn test_parse_issued_at_empty() {
        assert!(matches!(
            parse_issued_at(""),
            Err(crate::error::ForceError::Serialization(_))
        ));
    }

    #[test]
    fn test_parse_issued_at_out_of_range() {
        assert!(matches!(
            parse_issued_at(&i64::MAX.to_string()),
            Err(crate::error::ForceError::Serialization(_))
        ));
    }

    #[test]
    fn test_parse_issued_at_milliseconds_precision() {
        let dt = parse_issued_at("1704067200500").must();
        assert_eq!(dt.timestamp_subsec_millis(), 500);
    }

    #[test]
    fn test_access_token_expires_in_overflow() {
        let response = sample_response("test_token", ISSUED_AT_MS, Some(u64::MAX));
        let token = AccessToken::from_response(response);
        assert!(token.expires_at().is_none());
    }

    #[test]
    fn test_access_token_expires_in_cap() {
        let response = sample_response("test_token", ISSUED_AT_MS, Some(4_000_000_000));
        let token = AccessToken::from_response(response);
        assert!(token.expires_at().is_none());
    }

    #[test]
    fn test_access_token_expires_in_cap_boundary() {
        let response = sample_response("test_token", ISSUED_AT_MS, Some(3_000_000_000));
        let token = AccessToken::from_response(response);
        assert!(token.expires_at().is_some());
    }

    #[test]
    fn test_access_token_from_response_invalid_issued_at() {
        let response = sample_response("test_token", "garbage", Some(3600));
        let token = AccessToken::from_response(response);
        let diff = (Utc::now() - token.issued_at()).num_seconds().abs();
        assert!(
            diff < 5,
            "Should fallback to current time when issued_at is invalid"
        );
    }

    #[test]
    fn test_access_token_invalid_header_chars() {
        let response = sample_response("token\nwith\nnewlines", ISSUED_AT_MS, Some(3600));
        let token = AccessToken::from_response(response);
        assert!(token.auth_header().is_err());
    }
}