use serde::de::DeserializeOwned;
use std::time::Duration;
use crate::error::{Error, ErrorKind, Result};
/// Thin wrapper around a [`reqwest::Response`] that adds Salesforce-specific
/// header accessors and body helpers mapped onto this crate's [`Result`].
#[derive(Debug)]
pub struct Response {
    inner: reqwest::Response,
}

impl Response {
    /// Wraps a raw `reqwest` response.
    pub(crate) fn new(inner: reqwest::Response) -> Self {
        Self { inner }
    }

    /// Returns the numeric HTTP status code.
    pub fn status(&self) -> u16 {
        self.inner.status().as_u16()
    }

    /// Returns `true` for any 2xx status code.
    pub fn is_success(&self) -> bool {
        self.inner.status().is_success()
    }

    /// Returns `true` when the status is `304 Not Modified`.
    pub fn is_not_modified(&self) -> bool {
        self.inner.status() == reqwest::StatusCode::NOT_MODIFIED
    }

    /// Returns the first value of header `name` as a string slice.
    ///
    /// Yields `None` when the header is absent or its value is not visible
    /// ASCII (the condition under which `HeaderValue::to_str` fails).
    pub fn header(&self, name: &str) -> Option<&str> {
        self.inner.headers().get(name)?.to_str().ok()
    }

    /// Returns the `ETag` header value, if present.
    pub fn etag(&self) -> Option<&str> {
        self.header("etag")
    }

    /// Returns the `Last-Modified` header value, if present.
    pub fn last_modified(&self) -> Option<&str> {
        self.header("last-modified")
    }

    /// Returns the `Retry-After` delay when the header holds a whole number
    /// of seconds.
    ///
    /// The HTTP-date form of `Retry-After` is not parsed and yields `None`.
    pub fn retry_after(&self) -> Option<Duration> {
        self.header("retry-after")?
            .trim()
            .parse::<u64>()
            .ok()
            .map(Duration::from_secs)
    }

    /// Returns the `Sforce-Locator` header value, if present.
    pub fn sforce_locator(&self) -> Option<&str> {
        self.header("sforce-locator")
    }

    /// Returns the `Content-Type` header value, if present.
    pub fn content_type(&self) -> Option<&str> {
        self.header("content-type")
    }

    /// Consumes the response and returns its body decoded as text.
    ///
    /// # Errors
    /// Propagates any `reqwest` error from reading or decoding the body.
    pub async fn text(self) -> Result<String> {
        self.inner.text().await.map_err(Into::into)
    }

    /// Consumes the response and returns the raw body bytes.
    ///
    /// # Errors
    /// Propagates any `reqwest` error from reading the body.
    pub async fn bytes(self) -> Result<bytes::Bytes> {
        self.inner.bytes().await.map_err(Into::into)
    }

    /// Consumes the response and deserializes the body as JSON into `T`.
    ///
    /// # Errors
    /// Propagates any `reqwest` error from reading or deserializing the body.
    pub async fn json<T: DeserializeOwned>(self) -> Result<T> {
        self.inner.json().await.map_err(Into::into)
    }

    /// Unwraps the underlying `reqwest` response.
    pub fn into_inner(self) -> reqwest::Response {
        self.inner
    }

    /// Parses the `api-usage=<used>/<limit>` entry of the
    /// `Sforce-Limit-Info` header.
    ///
    /// The header is treated as a comma-separated list; the first entry whose
    /// trimmed text starts with `api-usage=` is parsed. Returns `None` when the
    /// header or entry is missing, or when either number fails to parse.
    pub fn api_usage(&self) -> Option<ApiUsage> {
        let info = self.header("sforce-limit-info")?;
        // `strip_prefix` removes the prefix exactly once; the previous
        // `trim_start_matches` would strip it repeatedly.
        let usage = info
            .split(',')
            .find_map(|part| part.trim().strip_prefix("api-usage="))?;
        let (used, limit) = usage.split_once('/')?;
        Some(ApiUsage {
            used: used.trim().parse().ok()?,
            limit: limit.trim().parse().ok()?,
        })
    }
}
/// API request consumption parsed from the `api-usage=<used>/<limit>` entry
/// of the `Sforce-Limit-Info` response header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ApiUsage {
    /// The `<used>` half of the `used/limit` pair.
    pub used: u64,
    /// The `<limit>` half of the `used/limit` pair.
    pub limit: u64,
}

impl ApiUsage {
    /// Headroom left before `limit`; saturates at `0` when `used` exceeds
    /// `limit` instead of underflowing.
    pub fn remaining(&self) -> u64 {
        let Self { used, limit } = *self;
        limit.saturating_sub(used)
    }

    /// Share of `limit` already consumed, expressed as a percentage.
    ///
    /// A zero `limit` is reported as `100.0` rather than dividing by zero.
    /// Results above `100.0` occur when `used > limit`.
    pub fn percentage(&self) -> f64 {
        match self.limit {
            0 => 100.0,
            limit => (self.used as f64 / limit as f64) * 100.0,
        }
    }

    /// Returns `true` once [`percentage`](Self::percentage) is at or above
    /// `threshold_percent`.
    pub fn is_above_threshold(&self, threshold_percent: f64) -> bool {
        let consumed = self.percentage();
        threshold_percent <= consumed
    }
}
/// Extension trait for turning unsuccessful HTTP responses into typed
/// crate errors.
pub trait ResponseExt {
    /// Consumes `self` and resolves to the response itself when it should be
    /// treated as successful, or to an [`Error`] describing the failure.
    ///
    /// The returned future is required to be `Send`, so it can be awaited
    /// inside spawned tasks.
    fn check_salesforce_error(
        self,
    ) -> impl std::future::Future<Output = Result<Response>> + Send;
}
impl ResponseExt for Response {
async fn check_salesforce_error(self) -> Result<Response> {
let status = self.status();
if self.is_success() || self.is_not_modified() {
return Ok(self);
}
let body = self.text().await.unwrap_or_default();
if status == 429 {
return Err(Error::new(ErrorKind::RateLimited { retry_after: None }));
}
if let Ok(errors) = serde_json::from_str::<Vec<SalesforceErrorResponse>>(&body) {
if let Some(err) = errors.into_iter().next() {
return Err(Error::new(ErrorKind::SalesforceApi {
error_code: err.error_code,
message: sanitize_error_message(&err.message),
fields: err.fields.unwrap_or_default(),
}));
}
}
if let Ok(err) = serde_json::from_str::<SalesforceErrorResponse>(&body) {
return Err(Error::new(ErrorKind::SalesforceApi {
error_code: err.error_code,
message: sanitize_error_message(&err.message),
fields: err.fields.unwrap_or_default(),
}));
}
let sanitized = sanitize_error_message(&body);
let kind = match status {
401 => ErrorKind::Authentication(sanitized),
403 => ErrorKind::Authorization(sanitized),
404 => ErrorKind::NotFound(sanitized),
412 => ErrorKind::PreconditionFailed(sanitized),
_ => ErrorKind::Http {
status,
message: sanitized,
},
};
Err(Error::new(kind))
}
}
/// Scrubs credential-like substrings from an error message and caps its length.
///
/// Redactions, applied in order:
/// - `00` + 13 or more alphanumerics + `!` + `[A-Za-z0-9_.]+` is replaced
///   with `[REDACTED_TOKEN]`;
/// - `sid=` followed by 20 or more alphanumerics is replaced with
///   `sid=[REDACTED]`.
///
/// If the redacted text exceeds 500 bytes, it is cut at the last UTF-8
/// character boundary at or below 500 bytes and `...[truncated]` is appended.
fn sanitize_error_message(message: &str) -> String {
    const MAX_LENGTH: usize = 500;

    // Compile each pattern once per process instead of on every call.
    static TOKEN_PATTERN: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
    static SESSION_PATTERN: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();

    let token_pattern = TOKEN_PATTERN.get_or_init(|| {
        regex_lite::Regex::new(r"00[A-Za-z0-9]{13,}[!][A-Za-z0-9_.]+")
            .expect("token pattern is a valid regex literal")
    });
    let session_pattern = SESSION_PATTERN.get_or_init(|| {
        regex_lite::Regex::new(r"sid=[A-Za-z0-9]{20,}")
            .expect("session pattern is a valid regex literal")
    });

    let redacted = token_pattern.replace_all(message, "[REDACTED_TOKEN]");
    let mut sanitized = session_pattern
        .replace_all(&redacted, "sid=[REDACTED]")
        .into_owned();

    if sanitized.len() > MAX_LENGTH {
        // `String::truncate` panics when the cut point falls inside a
        // multi-byte character, so back off to the nearest char boundary.
        // Index 0 is always a boundary, so the fallback is never used.
        let cut = (0..=MAX_LENGTH)
            .rev()
            .find(|&idx| sanitized.is_char_boundary(idx))
            .unwrap_or(0);
        sanitized.truncate(cut);
        sanitized.push_str("...[truncated]");
    }
    sanitized
}
/// One Salesforce error entry as deserialized from an error response body.
///
/// `check_salesforce_error` attempts to parse the body first as a JSON array
/// of these entries, then as a single object.
#[derive(Debug, serde::Deserialize)]
struct SalesforceErrorResponse {
// Accepts both the field name `error_code` and the camelCase key `errorCode`.
#[serde(alias = "errorCode")]
error_code: String,
// Human-readable error text; sanitized before it is placed in an error.
message: String,
// Field names the error refers to; `None` when the key is absent.
fields: Option<Vec<String>>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_usage() {
        let usage = ApiUsage {
            used: 100,
            limit: 1000,
        };
        assert_eq!(usage.remaining(), 900);
        assert!((usage.percentage() - 10.0).abs() < 0.001);
        assert!(!usage.is_above_threshold(50.0));
        assert!(usage.is_above_threshold(5.0));
    }

    #[test]
    fn test_api_usage_edge_cases() {
        let usage = ApiUsage {
            used: 1000,
            limit: 1000,
        };
        assert_eq!(usage.remaining(), 0);
        assert!((usage.percentage() - 100.0).abs() < 0.001);
        let usage = ApiUsage { used: 0, limit: 0 };
        assert_eq!(usage.remaining(), 0);
        assert!((usage.percentage() - 100.0).abs() < 0.001);
    }

    // Usage beyond the limit must saturate `remaining` and report > 100%.
    #[test]
    fn test_api_usage_over_limit() {
        let usage = ApiUsage {
            used: 1200,
            limit: 1000,
        };
        assert_eq!(usage.remaining(), 0);
        assert!((usage.percentage() - 120.0).abs() < 0.001);
        assert!(usage.is_above_threshold(100.0));
    }

    // The threshold comparison is inclusive.
    #[test]
    fn test_api_usage_threshold_is_inclusive() {
        let usage = ApiUsage {
            used: 50,
            limit: 100,
        };
        assert!(usage.is_above_threshold(50.0));
        assert!(!usage.is_above_threshold(50.1));
    }

    #[test]
    fn test_sanitize_redacts_access_tokens() {
        let msg = "Session expired: 00Dxx0000001gEF!AQcAQH3k9s7LKbp_example_token_value.here";
        let sanitized = sanitize_error_message(msg);
        assert!(
            sanitized.contains("[REDACTED_TOKEN]"),
            "Should redact token: {sanitized}"
        );
        assert!(
            !sanitized.contains("AQcAQH3k9s7LKbp"),
            "Should not contain token value: {sanitized}"
        );
    }

    #[test]
    fn test_sanitize_redacts_session_ids() {
        let msg = "Invalid session: sid=abc123def456ghi789jkl012";
        let sanitized = sanitize_error_message(msg);
        assert!(
            sanitized.contains("sid=[REDACTED]"),
            "Should redact session ID: {sanitized}"
        );
        assert!(
            !sanitized.contains("abc123def456"),
            "Should not contain session ID value: {sanitized}"
        );
    }

    // Values shorter than the 20-character session-id threshold are left alone.
    #[test]
    fn test_sanitize_keeps_short_sid_values() {
        let msg = "sid=short";
        assert_eq!(sanitize_error_message(msg), msg);
    }

    #[test]
    fn test_sanitize_truncates_long_messages() {
        let long_msg = "x".repeat(600);
        let sanitized = sanitize_error_message(&long_msg);
        assert!(
            sanitized.len() < 600,
            "Should be truncated: len={}",
            sanitized.len()
        );
        assert!(
            sanitized.ends_with("...[truncated]"),
            "Should end with truncation marker: {sanitized}"
        );
    }

    // A message of exactly the maximum length is not truncated.
    #[test]
    fn test_sanitize_keeps_message_at_max_length() {
        let msg = "y".repeat(500);
        assert_eq!(sanitize_error_message(&msg), msg);
    }

    #[test]
    fn test_sanitize_passes_through_clean_messages() {
        let msg = "No such column 'foo' on entity 'Account'";
        assert_eq!(sanitize_error_message(msg), msg);
    }

    #[test]
    fn test_sanitize_redacts_multiple_tokens() {
        let msg = "Token1: 00Dxx0000001gEF!token1_value and Token2: 00Dyy0000002gEF!token2_value";
        let sanitized = sanitize_error_message(msg);
        assert!(
            !sanitized.contains("token1_value"),
            "Should redact first token"
        );
        assert!(
            !sanitized.contains("token2_value"),
            "Should redact second token"
        );
    }

    #[test]
    fn test_salesforce_error_response_array_format() {
        let json = r#"[{"errorCode":"INVALID_FIELD","message":"No such column","fields":["Foo"]}]"#;
        let errors: Vec<SalesforceErrorResponse> = serde_json::from_str(json).unwrap();
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].error_code, "INVALID_FIELD");
        assert_eq!(errors[0].message, "No such column");
        assert_eq!(errors[0].fields, Some(vec!["Foo".to_string()]));
    }

    #[test]
    fn test_salesforce_error_response_single_object() {
        let json = r#"{"errorCode":"NOT_FOUND","message":"The requested resource does not exist"}"#;
        let err: SalesforceErrorResponse = serde_json::from_str(json).unwrap();
        assert_eq!(err.error_code, "NOT_FOUND");
        assert_eq!(err.message, "The requested resource does not exist");
        assert!(err.fields.is_none());
    }

    #[test]
    fn test_salesforce_error_response_with_error_code_alias() {
        let json = r#"{"errorCode":"MALFORMED_QUERY","message":"unexpected token"}"#;
        let err: SalesforceErrorResponse = serde_json::from_str(json).unwrap();
        assert_eq!(err.error_code, "MALFORMED_QUERY");
    }

    // The canonical snake_case key must keep working alongside the alias.
    #[test]
    fn test_salesforce_error_response_with_snake_case_key() {
        let json = r#"{"error_code":"INVALID_SESSION_ID","message":"Session expired"}"#;
        let err: SalesforceErrorResponse = serde_json::from_str(json).unwrap();
        assert_eq!(err.error_code, "INVALID_SESSION_ID");
        assert_eq!(err.message, "Session expired");
    }

    #[test]
    fn test_salesforce_error_response_empty_array() {
        let json = "[]";
        let errors: Vec<SalesforceErrorResponse> = serde_json::from_str(json).unwrap();
        assert!(errors.is_empty());
    }

    #[test]
    fn test_salesforce_error_response_multiple_errors() {
        let json = r#"[
{"errorCode":"REQUIRED_FIELD_MISSING","message":"Required fields missing","fields":["Name","Email"]},
{"errorCode":"FIELD_CUSTOM_VALIDATION_EXCEPTION","message":"Must be positive"}
]"#;
        let errors: Vec<SalesforceErrorResponse> = serde_json::from_str(json).unwrap();
        assert_eq!(errors.len(), 2);
        assert_eq!(
            errors[0].fields,
            Some(vec!["Name".to_string(), "Email".to_string()])
        );
        assert!(errors[1].fields.is_none());
    }
}