use crate::error::HttpError;
use futures::StreamExt;
use reqwest::Response;
/// One entry of the JSON error array returned by the Salesforce REST API.
///
/// JSON keys are camelCase (`errorCode`, `message`, `fields`).
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct SalesforceError {
    /// Machine-readable error code (`errorCode`); `None` when the key is absent.
    error_code: Option<String>,
    /// Human-readable description; an entry without it fails to deserialize.
    message: String,
    /// Names of the fields the error refers to; empty when the key is absent.
    #[serde(default)]
    fields: Vec<String>,
}
/// Converts an HTTP error response body into an [`HttpError::StatusError`].
///
/// If `body` parses as a non-empty JSON array of [`SalesforceError`] entries, the first
/// entry is rendered as `[CODE] message`, followed by ` (fields: a, b)` when it lists
/// any fields. A missing `errorCode` is shown as `UNKNOWN`.
///
/// Any other body is used verbatim as the message. That includes plain text, a single
/// JSON object, malformed JSON, an empty array, or entries without `message`.
pub fn parse_api_error(status_code: u16, body: &str) -> HttpError {
    // Take ownership of the first entry; the remaining entries are ignored.
    let first_error = serde_json::from_str::<Vec<SalesforceError>>(body)
        .ok()
        .and_then(|errors| errors.into_iter().next());

    let message = match first_error {
        Some(error) => {
            let code = error.error_code.as_deref().unwrap_or("UNKNOWN");
            if error.fields.is_empty() {
                format!("[{}] {}", code, error.message)
            } else {
                format!(
                    "[{}] {} (fields: {})",
                    code,
                    error.message,
                    error.fields.join(", ")
                )
            }
        }
        None => body.to_string(),
    };

    HttpError::StatusError {
        status_code,
        message,
    }
}
/// Buffers the response body in memory, refusing bodies larger than `limit_bytes`.
///
/// The body is consumed chunk by chunk, so an oversized payload is rejected without
/// buffering more than `limit_bytes` of it. When the response advertises a body size
/// (`Content-Length`) above the limit, the call fails before reading any body data.
/// A body of exactly `limit_bytes` bytes is accepted.
///
/// # Errors
///
/// Returns [`HttpError::PayloadTooLarge`] when the advertised or received body size
/// exceeds `limit_bytes`.
///
/// NOTE(review): a transport error while streaming ends the read early. The bytes
/// received so far are then returned as `Ok`, so callers cannot tell a truncated body
/// from a complete one. Consider surfacing the error once `HttpError` has a suitable
/// variant.
pub async fn read_capped_body_bytes(
    response: Response,
    limit_bytes: usize,
) -> Result<Vec<u8>, HttpError> {
    // Fail fast on an advertised size over the cap instead of downloading it first.
    let limit_u64 = u64::try_from(limit_bytes).unwrap_or(u64::MAX);
    if let Some(advertised_len) = response.content_length() {
        if advertised_len > limit_u64 {
            return Err(HttpError::PayloadTooLarge { limit_bytes });
        }
    }

    let mut stream = response.bytes_stream();
    let mut bytes = Vec::with_capacity(std::cmp::min(limit_bytes, 4096));
    while let Some(chunk) = stream.next().await {
        let chunk_bytes = match chunk {
            Ok(chunk_bytes) => chunk_bytes,
            // Best-effort: stop reading and keep what we have (see NOTE(review) above).
            Err(_) => break,
        };
        // Reject only a chunk that would push the total past the limit, so a body of
        // exactly `limit_bytes` followed by an empty chunk still succeeds.
        let remaining = limit_bytes.saturating_sub(bytes.len());
        if chunk_bytes.len() > remaining {
            return Err(HttpError::PayloadTooLarge { limit_bytes });
        }
        bytes.extend_from_slice(&chunk_bytes);
    }
    Ok(bytes)
}
/// Buffers the response body, enforcing `limit_bytes`.
///
/// This is a thin alias for [`read_capped_body_bytes`].
///
/// # Errors
///
/// Propagates any error returned by [`read_capped_body_bytes`].
pub async fn read_capped_bytes(
    response: Response,
    limit_bytes: usize,
) -> Result<Vec<u8>, HttpError> {
    let body = read_capped_body_bytes(response, limit_bytes).await?;
    Ok(body)
}
/// Reads the response body as text, enforcing `limit_bytes`.
///
/// Invalid UTF-8 is not an error: offending sequences are replaced with U+FFFD.
///
/// # Errors
///
/// Propagates any error from [`read_capped_bytes`] (e.g. `PayloadTooLarge`).
pub async fn read_capped_body(response: Response, limit_bytes: usize) -> Result<String, HttpError> {
    let raw = read_capped_bytes(response, limit_bytes).await?;
    let text = match String::from_utf8(raw) {
        Ok(valid) => valid,
        Err(invalid) => String::from_utf8_lossy(invalid.as_bytes()).into_owned(),
    };
    Ok(text)
}
pub async fn response_to_force_error(
response: Response,
fallback_message: &str,
) -> crate::error::ForceError {
let status_code = response.status().as_u16();
let body = match read_capped_body(response, 1024 * 1024).await {
Ok(body) => body,
Err(e) => return crate::error::ForceError::Http(e),
};
let payload = if body.trim().is_empty() {
fallback_message.to_string()
} else {
body
};
parse_api_error(status_code, &payload).into()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Destructures an `HttpError::StatusError` into `(status_code, message)`,
    /// panicking on any other variant.
    fn into_status_error(error: HttpError) -> (u16, String) {
        if let HttpError::StatusError {
            status_code,
            message,
        } = error
        {
            (status_code, message)
        } else {
            panic!("Expected StatusError");
        }
    }

    #[test]
    fn test_parse_api_error_with_salesforce_format() {
        let body =
            r#"[{"errorCode":"INVALID_FIELD","message":"Field does not exist","fields":["Name"]}]"#;
        let (status_code, message) = into_status_error(parse_api_error(400, body));
        assert_eq!(status_code, 400);
        assert_eq!(
            message,
            "[INVALID_FIELD] Field does not exist (fields: Name)"
        );
    }

    #[test]
    fn test_parse_api_error_fallback() {
        let body = "Some error text";
        let (status_code, message) = into_status_error(parse_api_error(500, body));
        assert_eq!(status_code, 500);
        assert_eq!(message, "Some error text");
    }

    #[test]
    fn test_parse_api_error_object() {
        // A single JSON object is not the expected array shape, so it passes through.
        let body = r#"{"errorCode":"INVALID_FIELD","message":"Field does not exist"}"#;
        let (status_code, message) = into_status_error(parse_api_error(400, body));
        assert_eq!(status_code, 400);
        assert_eq!(message, body);
    }

    #[test]
    fn test_parse_api_error_malformed() {
        let body = "{malformed_json}";
        let (status_code, message) = into_status_error(parse_api_error(500, body));
        assert_eq!(status_code, 500);
        assert_eq!(message, body);
    }

    #[test]
    fn test_parse_api_error_empty() {
        let body = "";
        let (status_code, message) = into_status_error(parse_api_error(500, body));
        assert_eq!(status_code, 500);
        assert_eq!(message, "");
    }

    #[test]
    fn test_parse_api_error_without_error_code() {
        let body = r#"[{"message":"Field does not exist","fields":["Name"]}]"#;
        let (status_code, message) = into_status_error(parse_api_error(400, body));
        assert_eq!(status_code, 400);
        assert_eq!(message, "[UNKNOWN] Field does not exist (fields: Name)");
    }
}
#[cfg(all(test, feature = "mock"))]
mod integration_tests {
    use super::*;
    use crate::test_support::Must;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock server that answers `GET route` with `status` and `body`, then
    /// sends that request.
    ///
    /// The server is returned together with the response. Callers must keep it bound
    /// (e.g. `_server`, not `_`) so it stays alive while the body is streamed.
    async fn fetch_mocked(
        route: &str,
        status: u16,
        body: String,
    ) -> (MockServer, reqwest::Response) {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path(route))
            .respond_with(ResponseTemplate::new(status).set_body_string(body))
            .mount(&server)
            .await;
        let url = format!("{}{}", server.uri(), route);
        let response = reqwest::Client::new().get(&url).send().await.must();
        (server, response)
    }

    #[tokio::test]
    async fn test_response_to_force_error_payload_too_large() {
        let (_server, response) =
            fetch_mocked("/error", 400, "A".repeat(1024 * 1024 + 1024)).await;
        let error = response_to_force_error(response, "fallback").await;
        match error {
            crate::error::ForceError::Http(HttpError::PayloadTooLarge { limit_bytes }) => {
                assert_eq!(limit_bytes, 1024 * 1024);
            }
            other => panic!("Expected PayloadTooLarge error, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_response_to_force_error_does_not_truncate_medium_body() {
        let medium_body = "A".repeat(5000);
        let (_server, response) = fetch_mocked("/error", 400, medium_body.clone()).await;
        let error = response_to_force_error(response, "fallback").await;
        let expected = format!("HTTP request failed: HTTP 400: {}", medium_body);
        assert_eq!(error.to_string(), expected);
    }

    #[tokio::test]
    async fn test_read_capped_body_bytes_payload_too_large() {
        let (_server, response) = fetch_mocked("/bytes", 200, "A".repeat(5000)).await;
        let result = read_capped_body_bytes(response, 4096).await;
        match result {
            Err(HttpError::PayloadTooLarge { limit_bytes }) => assert_eq!(limit_bytes, 4096),
            other => panic!("Expected PayloadTooLarge error, got: {:?}", other),
        }
    }
}