use std::time::Duration;
use http::StatusCode;
use serde_with::{DisplayFromStr, OneOrMany, PickFirst, serde_as};
use crate::streamer;
/// Shorthand for [`std::result::Result`] with this module's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Errors returned by this crate.
///
/// HTTP failures are classified by status in [`Error::from_status`]; [`Error::is_retryable`] and
/// [`Error::retry_after`] expose the retry hints.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// HTTP `401 Unauthorized` response.
#[error("unauthorized: {0}")]
Unauthorized(ErrorBody),
/// HTTP `404 Not Found` response.
#[error("not found: {0}")]
NotFound(ErrorBody),
/// HTTP `429 Too Many Requests` response.
#[error("rate limited: {body}")]
RateLimited {
/// Delay taken from the response's `Retry-After` header, when present and parseable.
retry_after: Option<Duration>,
/// Parsed response body.
body: ErrorBody,
},
/// Any HTTP status that has no dedicated variant above.
#[error("http {status}: {body}")]
Http {
/// Status code of the response.
status: StatusCode,
/// Delay taken from the response's `Retry-After` header, when present and parseable.
retry_after: Option<Duration>,
/// Parsed response body.
body: ErrorBody,
},
/// Error reported by the `reqwest` HTTP client (created via `From<reqwest::Error>`).
#[error("transport: {0}")]
Transport(#[from] reqwest::Error),
/// Error from the WebSocket streaming layer (created via `From<streamer::WebSocketError>`).
#[error("websocket: {0}")]
WebSocket(#[from] streamer::WebSocketError),
/// Codec failure; `context` says what was being processed and `reason` why it failed.
#[error("codec {context}: {reason}")]
Codec {
context: String,
reason: String,
},
/// A preference value was rejected; `field` names it and `reason` explains why.
#[error("invalid preference {field}: {reason}")]
InvalidPreference {
field: &'static str,
reason: String,
},
/// An order id could not be recovered; the payload describes the situation.
#[error("order id unrecoverable: {0}")]
OrderIdUnrecoverable(String),
/// An order response could not be turned back into an order request.
#[error("order response not representable as a request: {reason}")]
OrderResponseNotRepresentable {
reason: String,
},
/// Failure from a token provider; the boxed error is also exposed via `source()`.
#[error("token provider: {source}")]
TokenProvider {
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
/// A base URL was rejected as insecure; `url` is the offending value and `reason` explains why.
#[error("insecure base url {url}: {reason}")]
InsecureBaseUrl {
url: String,
reason: String,
},
}
impl Error {
pub(crate) fn from_status(
status: StatusCode,
retry_after: Option<Duration>,
body: ErrorBody,
) -> Error {
match status {
StatusCode::UNAUTHORIZED => Error::Unauthorized(body),
StatusCode::NOT_FOUND => Error::NotFound(body),
StatusCode::TOO_MANY_REQUESTS => Error::RateLimited { retry_after, body },
_ => Error::Http {
status,
retry_after,
body,
},
}
}
pub fn is_retryable(&self) -> bool {
match self {
Error::RateLimited { .. } => true,
Error::Http { status, .. } => status.is_server_error(),
Error::Transport(e) => e.is_timeout() || e.is_connect() || e.is_request(),
Error::WebSocket(e) => e.is_retryable(),
Error::Unauthorized(_)
| Error::NotFound(_)
| Error::Codec { .. }
| Error::InvalidPreference { .. }
| Error::OrderIdUnrecoverable(_)
| Error::OrderResponseNotRepresentable { .. }
| Error::TokenProvider { .. }
| Error::InsecureBaseUrl { .. } => false,
}
}
pub fn retry_after(&self) -> Option<Duration> {
match self {
Error::RateLimited { retry_after, .. } | Error::Http { retry_after, .. } => {
*retry_after
}
_ => None,
}
}
}
/// Body of an HTTP error response, classified by [`ErrorBody::parse`].
///
/// `parse` tries the two known JSON schemas in turn and otherwise keeps the raw text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ErrorBody {
/// Body deserialized as [`ServiceError`] (`message` plus optional string `errors`).
Trader(ServiceError),
/// Body deserialized as [`ErrorResponse`] (an `errors` array of [`ApiError`] objects).
MarketData(ErrorResponse),
/// Text `parse` did not classify as either schema, kept verbatim. When the response body could
/// not be read at all, this holds the `<error body unavailable: …>` placeholder instead.
Unrecognized(String),
}
impl ErrorBody {
    /// Classifies a raw HTTP error body.
    ///
    /// Tries, in order:
    /// 1. the [`ServiceError`] schema, which requires a `message` field;
    /// 2. the [`ErrorResponse`] schema, accepted only when it carries at least one `errors` entry;
    /// 3. otherwise the text is kept verbatim as [`ErrorBody::Unrecognized`].
    pub(crate) fn parse(raw: &str) -> Self {
        if let Ok(trader) = serde_json::from_str::<ServiceError>(raw) {
            return ErrorBody::Trader(trader);
        }
        match serde_json::from_str::<ErrorResponse>(raw) {
            // `ErrorResponse`'s only field is `#[serde(default)]`, so *any* JSON object (`{}`, or
            // a foreign shape such as `{"error": "invalid_token"}`) deserializes into it with an
            // empty `errors` list. Classifying that as market data would render "no error detail"
            // and discard the only diagnostic the server sent, so fall through to the raw text.
            Ok(market_data) if !market_data.errors.is_empty() => {
                ErrorBody::MarketData(market_data)
            }
            _ => ErrorBody::Unrecognized(raw.to_string()),
        }
    }
}
impl std::fmt::Display for ErrorBody {
    /// Renders the captured body: the parsed schema's own `Display`, or the raw text unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner: &dyn std::fmt::Display = match self {
            ErrorBody::Trader(service_error) => service_error,
            ErrorBody::MarketData(response) => response,
            ErrorBody::Unrecognized(raw) => raw,
        };
        write!(f, "{inner}")
    }
}
/// Error body shaped like `{"message": "...", "errors": ["...", ...]}`.
///
/// Displayed as `message`, followed by `: ` and the `errors` joined with `; ` when any exist.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct ServiceError {
/// Top-level message; required when deserializing.
pub message: String,
/// Additional detail strings; empty when the field is absent from the payload.
#[serde(default)]
pub errors: Vec<String>,
}
impl std::fmt::Display for ServiceError {
    /// Writes `message`, then each detail string: the first introduced by `": "`, the rest
    /// separated by `"; "` (e.g. `msg: a; b`). With no details only `message` is written.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)?;
        let mut separator = ": ";
        for detail in &self.errors {
            write!(f, "{separator}{detail}")?;
            separator = "; ";
        }
        Ok(())
    }
}
/// Error body shaped like `{"errors": [{...}, ...]}`, one [`ApiError`] per entry.
///
/// NOTE(review): `errors` is the only field and it is `#[serde(default)]`, so any JSON object
/// (even `{}`) deserializes into this type, yielding an empty `errors` list.
#[derive(Debug, Clone, serde::Deserialize, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct ErrorResponse {
/// Individual error entries; empty when the field is absent.
#[serde(default)]
pub errors: Vec<ApiError>,
}
impl std::fmt::Display for ErrorResponse {
    /// Joins every entry's rendering with `"; "`, or writes `no error detail` when there are none.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Some((first, rest)) = self.errors.split_first() else {
            return write!(f, "no error detail");
        };
        write!(f, "{first}")?;
        rest.iter().try_for_each(|entry| write!(f, "; {entry}"))
    }
}
/// One entry of an [`ErrorResponse`]'s `errors` array. Every field is optional and is `None`
/// when absent from the payload.
#[serde_as]
#[derive(Debug, Clone, serde::Deserialize, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct ApiError {
/// Identifier of the entry (a UUID in the sample payloads); rendered as `error {id}` when
/// neither `title` nor `detail` is present.
#[serde(default)]
pub id: Option<String>,
/// HTTP status code, accepted either as a JSON number (`400`) or a numeric string (`"400"`).
#[serde(default)]
#[serde_as(as = "Option<PickFirst<(_, DisplayFromStr)>>")]
pub status: Option<u16>,
/// Short summary, e.g. `Bad Request`.
#[serde(default)]
pub title: Option<String>,
/// Specific explanation, e.g. `Missing header`.
#[serde(default)]
pub detail: Option<String>,
/// Which part of the request the error refers to.
#[serde(default)]
pub source: Option<ErrorSource>,
}
impl std::fmt::Display for ApiError {
    /// Renders the most descriptive text available, in order of preference: `title: detail`,
    /// then whichever of `title` / `detail` is present, then `error {id}`, and finally the
    /// literal `unspecified error`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match (self.title.as_deref(), self.detail.as_deref(), self.id.as_deref()) {
            (Some(title), Some(detail), _) => write!(f, "{title}: {detail}"),
            (Some(only), None, _) | (None, Some(only), _) => f.write_str(only),
            (None, None, Some(id)) => write!(f, "error {id}"),
            (None, None, None) => f.write_str("unspecified error"),
        }
    }
}
#[derive(Debug, Clone, serde::Deserialize, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct ErrorSource {
#[serde(default)]
pub pointer: Vec<String>,
#[serde(default)]
pub parameter: Option<String>,
#[serde(default)]
pub header: Option<String>,
}
pub(crate) async fn map_response_to_error(response: reqwest::Response) -> Error {
let status = response.status();
let retry_after = parse_retry_after(response.headers());
let raw = response
.text()
.await
.unwrap_or_else(|e| format!("<error body unavailable: {e}>"));
Error::from_status(status, retry_after, ErrorBody::parse(&raw))
}
/// Extracts the server's back-off hint from a `Retry-After` response header.
///
/// RFC 9110 §10.2.3 allows two forms, both converted to a delay relative to now:
/// * `delay-seconds`, a non-negative integer such as `Retry-After: 120`;
/// * an `HTTP-date` such as `Retry-After: Fri, 31 Dec 1999 23:59:59 GMT`. A date that has
///   already passed yields [`Duration::ZERO`] (retry immediately).
///
/// Only the IMF-fixdate layout of `HTTP-date` (the one senders are required to generate) is
/// understood. The obsolete RFC 850 and asctime layouts yield `None`. `None` is also returned
/// when the header is absent, is not visible ASCII, or matches neither form.
fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
    let value = headers.get(reqwest::header::RETRY_AFTER)?.to_str().ok()?.trim();
    if let Ok(seconds) = value.parse::<u64>() {
        return Some(Duration::from_secs(seconds));
    }
    // Not a number: try the absolute-date form and express it relative to the local clock.
    let retry_at = parse_imf_fixdate(value)?;
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .ok()?
        .as_secs();
    Some(Duration::from_secs(retry_at.saturating_sub(now)))
}

/// Parses an IMF-fixdate (`Sun, 06 Nov 1994 08:49:37 GMT`) into whole seconds since the Unix
/// epoch.
///
/// Returns `None` for any other layout, a zone other than `GMT`, out-of-range fields, or a year
/// outside `1970..=9999`. Day-of-month is only range-checked against 1–31, not against the
/// month's actual length.
fn parse_imf_fixdate(value: &str) -> Option<u64> {
    const MONTHS: [&str; 12] = [
        "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    ];
    // The weekday name is redundant with the calendar date, so it is skipped rather than checked.
    let (_weekday, date_time) = value.split_once(", ")?;
    let mut fields = date_time.split(' ');
    let day: u64 = fields.next()?.parse().ok()?;
    let month_name = fields.next()?;
    let year: u64 = fields.next()?.parse().ok()?;
    let time_of_day = fields.next()?;
    if fields.next()? != "GMT" || fields.next().is_some() {
        return None;
    }
    let month = (1u64..)
        .zip(MONTHS)
        .find_map(|(number, name)| (name == month_name).then_some(number))?;
    let mut clock = time_of_day.split(':');
    let hour: u64 = clock.next()?.parse().ok()?;
    let minute: u64 = clock.next()?.parse().ok()?;
    let second: u64 = clock.next()?.parse().ok()?;
    // `second == 60` is allowed because the grammar admits a leap second. The year cap also
    // keeps the arithmetic below far from overflow.
    let in_range = clock.next().is_none()
        && (1..=31).contains(&day)
        && (1970..=9999).contains(&year)
        && hour <= 23
        && minute <= 59
        && second <= 60;
    in_range.then(|| {
        days_since_unix_epoch(year, month, day) * 86_400 + hour * 3_600 + minute * 60 + second
    })
}

/// Number of days from 1970-01-01 to the given proleptic Gregorian date.
///
/// Howard Hinnant's `days_from_civil`, specialised to unsigned arithmetic. Callers guarantee
/// `year >= 1970`, `1 <= month <= 12` and `day >= 1`, so no intermediate goes negative.
/// Counting years from March moves the leap day to the end of the year, which gives the
/// day-of-year a closed form.
fn days_since_unix_epoch(year: u64, month: u64, day: u64) -> u64 {
    let year = if month <= 2 { year - 1 } else { year };
    let era = year / 400;
    let year_of_era = year - era * 400; // 0..=399
    let month_from_march = if month > 2 { month - 3 } else { month + 9 }; // 0..=11
    let day_of_year = (153 * month_from_march + 2) / 5 + day - 1; // 0..=365
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    // 719_468 = days from 0000-03-01 (the algorithm's origin) to 1970-01-01.
    era * 146_097 + day_of_era - 719_468
}
// Unit tests for error-body classification/rendering and status-to-variant mapping.
#[cfg(test)]
mod tests {
use super::*;
// Trader-shaped body (`message` plus string `errors`) renders as `message: e1; e2`.
#[test]
fn trader_error_body_parses() {
let raw = r#"{
"message": "Order validation failed",
"errors": ["quantity must be positive", "symbol is required"]
}"#;
let ErrorBody::Trader(body) = ErrorBody::parse(raw) else {
panic!("expected Trader body");
};
assert_eq!(body.message, "Order validation failed");
assert_eq!(body.errors.len(), 2);
assert_eq!(
body.to_string(),
"Order validation failed: quantity must be positive; symbol is required"
);
}
// An empty `errors` list renders the message alone, with no dangling `: `.
#[test]
fn trader_error_body_without_errors_renders_message_only() {
let svc = ServiceError {
message: "Forbidden".to_string(),
errors: Vec::new(),
};
assert_eq!(svc.to_string(), "Forbidden");
}
// `errors` may be omitted from a trader body (`#[serde(default)]`).
#[test]
fn trader_error_body_without_errors_array_parses() {
let ErrorBody::Trader(body) = ErrorBody::parse(r#"{"message": "Forbidden"}"#) else {
panic!("expected Trader body");
};
assert_eq!(body.message, "Forbidden");
assert!(body.errors.is_empty());
}
// Market-data body with string statuses and each `source` form: header, pointer array, parameter.
#[test]
fn market_data_error_body_parses() {
let raw = r#"{
"errors": [
{
"id": "6808262e-52bb-4421-9d31-6c0e762e7dd5",
"status": "400",
"title": "Bad Request",
"detail": "Missing header",
"source": { "header": "Authorization" }
},
{
"id": "0be22ae7-efdf-44d9-99f4-f138049d76ca",
"status": "400",
"title": "Bad Request",
"detail": "Search combination should have min of 1.",
"source": { "pointer": ["/data/attributes/symbols", "/data/attributes/cusips"] }
},
{
"id": "28485414-290f-42e2-992b-58ea3e3203b1",
"status": "400",
"title": "Bad Request",
"detail": "valid fields should be any of all,fundamental,reference",
"source": { "parameter": "fields" }
}
]
}"#;
let ErrorBody::MarketData(body) = ErrorBody::parse(raw) else {
panic!("expected MarketData body");
};
assert_eq!(body.errors.len(), 3);
let first = &body.errors[0];
assert_eq!(first.status, Some(400));
assert_eq!(first.title.as_deref(), Some("Bad Request"));
assert_eq!(first.detail.as_deref(), Some("Missing header"));
assert_eq!(
first.source.as_ref().unwrap().header.as_deref(),
Some("Authorization")
);
assert_eq!(first.to_string(), "Bad Request: Missing header");
assert_eq!(body.errors[1].source.as_ref().unwrap().pointer.len(), 2);
assert_eq!(
body.errors[2].source.as_ref().unwrap().parameter.as_deref(),
Some("fields")
);
}
// `status` may also arrive as a JSON number rather than a numeric string.
#[test]
fn market_data_numeric_status_parses() {
let raw = r#"{
"errors": [
{ "id": "0be22ae7-efdf-44d9-99f4-f138049d76ca", "status": 401, "title": "Unauthorized" }
]
}"#;
let ErrorBody::MarketData(body) = ErrorBody::parse(raw) else {
panic!("expected MarketData body");
};
assert_eq!(body.errors[0].status, Some(401));
assert_eq!(body.errors[0].title.as_deref(), Some("Unauthorized"));
}
// Non-JSON text is kept verbatim as `Unrecognized`.
#[test]
fn unrecognized_body_is_preserved() {
let ErrorBody::Unrecognized(raw) = ErrorBody::parse("upstream request timeout") else {
panic!("expected Unrecognized body");
};
assert_eq!(raw, "upstream request timeout");
}
// These two particular payloads do not cross-deserialize.
// NOTE(review): this is not disjointness in general. `ErrorResponse`'s only field is
// `#[serde(default)]`, so any JSON object (e.g. `{}`) deserializes as an `ErrorResponse`.
#[test]
fn trader_and_market_data_schemas_are_disjoint() {
let trader = r#"{"message": "x", "errors": ["a"]}"#;
let market_data = r#"{"errors": [{"status": 400, "title": "Bad Request"}]}"#;
assert!(serde_json::from_str::<ErrorResponse>(trader).is_err());
assert!(serde_json::from_str::<ServiceError>(market_data).is_err());
}
// 401/404/429 get dedicated variants; every other status becomes `Http` with the status kept.
#[test]
fn from_status_maps_each_documented_status() {
let body = || ErrorBody::Unrecognized(String::new());
assert!(matches!(
Error::from_status(StatusCode::UNAUTHORIZED, None, body()),
Error::Unauthorized(_)
));
assert!(matches!(
Error::from_status(StatusCode::NOT_FOUND, None, body()),
Error::NotFound(_)
));
assert!(matches!(
Error::from_status(StatusCode::TOO_MANY_REQUESTS, None, body()),
Error::RateLimited { .. }
));
assert!(matches!(
Error::from_status(StatusCode::BAD_REQUEST, None, body()),
Error::Http { status, .. } if status == StatusCode::BAD_REQUEST
));
assert!(matches!(
Error::from_status(StatusCode::FORBIDDEN, None, body()),
Error::Http { status, .. } if status == StatusCode::FORBIDDEN
));
assert!(matches!(
Error::from_status(StatusCode::SERVICE_UNAVAILABLE, None, body()),
Error::Http { status, .. } if status == StatusCode::SERVICE_UNAVAILABLE
));
assert!(matches!(
Error::from_status(StatusCode::INTERNAL_SERVER_ERROR, None, body()),
Error::Http { status, .. } if status == StatusCode::INTERNAL_SERVER_ERROR
));
assert!(matches!(
Error::from_status(StatusCode::BAD_GATEWAY, None, body()),
Error::Http { status, .. } if status == StatusCode::BAD_GATEWAY
));
}
// 429 keeps the `Retry-After` delay and is retryable.
#[test]
fn rate_limited_carries_retry_after_and_is_retryable() {
let error = Error::from_status(
StatusCode::TOO_MANY_REQUESTS,
Some(Duration::from_secs(30)),
ErrorBody::Unrecognized(String::new()),
);
assert_eq!(error.retry_after(), Some(Duration::from_secs(30)));
assert!(error.is_retryable());
}
// `retry_after` is also kept on generic `Http` errors, and a 5xx status is retryable.
#[test]
fn http_503_with_retry_after_surfaces_through_accessor() {
let error = Error::from_status(
StatusCode::SERVICE_UNAVAILABLE,
Some(Duration::from_secs(15)),
ErrorBody::Unrecognized(String::new()),
);
assert!(matches!(error, Error::Http { .. }));
assert_eq!(error.retry_after(), Some(Duration::from_secs(15)));
assert!(error.is_retryable());
}
// No hint supplied means no hint reported.
#[test]
fn http_without_retry_after_returns_none() {
let error = Error::from_status(
StatusCode::INTERNAL_SERVER_ERROR,
None,
ErrorBody::Unrecognized(String::new()),
);
assert_eq!(error.retry_after(), None);
}
// 4xx statuses are not retryable; the last two asserts also pin 500/502 as retryable.
#[test]
fn client_errors_are_not_retryable() {
let body = || ErrorBody::Unrecognized(String::new());
assert!(!Error::from_status(StatusCode::BAD_REQUEST, None, body()).is_retryable());
assert!(!Error::from_status(StatusCode::NOT_FOUND, None, body()).is_retryable());
assert!(!Error::from_status(StatusCode::UNAUTHORIZED, None, body()).is_retryable());
assert!(Error::from_status(StatusCode::INTERNAL_SERVER_ERROR, None, body()).is_retryable());
assert!(Error::from_status(StatusCode::BAD_GATEWAY, None, body()).is_retryable());
}
}