use std::time::Duration;
use thiserror::Error;
/// Fine-grained cause of a WebSocket failure, carried in the `kind` field of
/// `MarketDataError::WebSocketError`.
///
/// `MarketDataError::source_kind` and `MarketDataError::is_retryable` both
/// branch on this value, so every variant notes how it is classified.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WebSocketErrorKind {
/// WebSocket protocol violation. Classified as `ErrorKind::Protocol`; not retryable.
Protocol,
/// A capacity limit was exceeded. Classified as `ErrorKind::Protocol`; not retryable.
Capacity,
/// UTF-8 decode failure. Classified as `ErrorKind::Protocol`; not retryable.
Utf8,
/// TLS/certificate failure. Classified as `ErrorKind::Auth`; not retryable.
Tls,
/// Transport-level failure. The `tungstenite::Error` conversion also uses this
/// kind for `ConnectionClosed` and `AlreadyClosed`.
/// Classified as `ErrorKind::Network`; retryable.
Io,
/// HTTP status code returned during the WebSocket handshake.
///
/// | status        | `source_kind()`        | `is_retryable()` |
/// |---------------|------------------------|------------------|
/// | 401, 403      | `ErrorKind::Auth`      | `false`          |
/// | 429           | `ErrorKind::RateLimit` | `true`           |
/// | 500..=599     | `ErrorKind::Network`   | `true`           |
/// | anything else | `ErrorKind::Client`    | `false`          |
Http(u16),
/// Any WebSocket error not covered by the variants above.
/// Classified as `ErrorKind::Protocol`, but (unlike `Protocol`) retryable.
Other,
}
/// Coarse error category returned by `MarketDataError::source_kind`.
///
/// The per-variant notes below list which errors `source_kind` maps to each
/// category.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
/// Connection errors, timeouts, heartbeat timeouts, WebSocket `Io` failures,
/// and HTTP 500..=599 statuses.
Network,
/// WebSocket `Protocol`, `Capacity`, `Utf8` and `Other` failures.
Protocol,
/// Authentication errors, WebSocket `Tls` failures, and HTTP 401/403 statuses.
Auth,
/// HTTP 429 statuses.
RateLimit,
/// Everything else: invalid symbol/parameter, configuration, deserialization,
/// runtime errors, a closed client, wrapped `anyhow` errors, and any HTTP
/// status not listed above.
Client,
}
/// Error type for this module's market-data operations.
///
/// `Display` text comes from the `#[error(...)]` template on each variant.
/// Use `source_kind()` for a coarse category, `to_error_code()` for a numeric
/// code, and `is_retryable()` to decide whether a retry is worthwhile.
#[derive(Error, Debug)]
pub enum MarketDataError {
/// The given symbol was rejected as invalid.
#[error("Invalid symbol: {symbol}")]
InvalidSymbol {
/// The offending symbol, echoed in the message.
symbol: String,
},
/// A named parameter failed validation.
#[error("Invalid parameter '{name}': {reason}")]
InvalidParameter {
/// Name of the parameter that was rejected.
name: String,
/// Why the value was rejected.
reason: String,
},
/// JSON deserialization failed. `#[from]` lets `?` convert a
/// `serde_json::Error` into this variant automatically.
#[error("Deserialization failed: {source}")]
DeserializationError {
/// The underlying `serde_json` error (also exposed as the error source).
#[from]
source: serde_json::Error,
},
/// A runtime failure described only by a message.
#[error("Runtime error: {msg}")]
RuntimeError {
/// Free-form description of the failure.
msg: String,
},
/// A configuration problem; the payload is the description.
#[error("Configuration error: {0}")]
ConfigError(
String,
),
/// A connection-level failure.
#[error("Connection error: {msg}")]
ConnectionError {
/// Free-form description of the failure.
msg: String,
},
/// Authentication failed.
#[error("Authentication error: {msg}")]
AuthError {
/// Free-form description of the failure.
msg: String,
},
/// An API call returned an error status.
#[error("API error (status {status}): {message}")]
ApiError {
/// Status code; `source_kind()` and `is_retryable()` interpret it as an
/// HTTP status (401/403, 429, 500..=599 are special-cased).
status: u16,
/// Message accompanying the status.
message: String,
},
/// An operation timed out.
#[error("Timeout error: {operation}")]
TimeoutError {
/// Description of the operation that timed out.
operation: String,
},
/// A WebSocket failure, tagged with a fine-grained `WebSocketErrorKind`.
#[error("WebSocket error ({kind:?}): {msg}")]
WebSocketError {
/// Fine-grained cause; drives `source_kind()` and `is_retryable()`.
kind: WebSocketErrorKind,
/// Human-readable description.
msg: String,
},
/// No inbound frames arrived within the heartbeat window.
#[error("Heartbeat timeout: no inbound frames for {elapsed:?}")]
HeartbeatTimeout {
/// How long the connection had been silent.
elapsed: Duration,
},
/// The client has already been closed.
#[error("Client already closed")]
ClientClosed,
/// Any other error, wrapped as `anyhow::Error`. `transparent` forwards
/// `Display` and `source()` to the wrapped error.
#[error(transparent)]
Other(
#[from]
anyhow::Error,
),
}
impl From<tungstenite::Error> for MarketDataError {
fn from(err: tungstenite::Error) -> Self {
use tungstenite::Error as WsError;
let (kind, msg) = match err {
WsError::ConnectionClosed | WsError::AlreadyClosed | WsError::Io(_) => (
WebSocketErrorKind::Io,
format!("WebSocket transport error: {}", err),
),
WsError::Protocol(_) => (
WebSocketErrorKind::Protocol,
format!("WebSocket protocol violation: {}", err),
),
WsError::Capacity(_) => (
WebSocketErrorKind::Capacity,
format!("WebSocket capacity exceeded: {}", err),
),
WsError::Utf8(_) => (
WebSocketErrorKind::Utf8,
format!("WebSocket UTF-8 decode failure: {}", err),
),
WsError::Tls(_) => (
WebSocketErrorKind::Tls,
format!("TLS/certificate error: {}", err),
),
WsError::Http(response) => {
let status = response.status().as_u16();
(
WebSocketErrorKind::Http(status),
format!("HTTP {} during WebSocket handshake", status),
)
}
_ => (
WebSocketErrorKind::Other,
format!("WebSocket error: {}", err),
),
};
Self::WebSocketError { kind, msg }
}
}
impl MarketDataError {
    /// Maps an HTTP status code to its coarse [`ErrorKind`].
    ///
    /// Single source of truth for both [`Self::ApiError`] and
    /// [`WebSocketErrorKind::Http`], so the two mappings cannot drift apart:
    /// 401/403 → `Auth`, 429 → `RateLimit`, 500..=599 → `Network`, else `Client`.
    fn http_status_kind(status: u16) -> ErrorKind {
        match status {
            401 | 403 => ErrorKind::Auth,
            429 => ErrorKind::RateLimit,
            500..=599 => ErrorKind::Network,
            _ => ErrorKind::Client,
        }
    }

    /// Returns `true` for HTTP statuses worth retrying: 429 and 500..=599.
    fn http_status_is_retryable(status: u16) -> bool {
        matches!(status, 429 | 500..=599)
    }

    /// Returns the coarse [`ErrorKind`] category for this error.
    ///
    /// WebSocket errors are categorised by their [`WebSocketErrorKind`]; HTTP
    /// statuses (from either `ApiError` or a WebSocket handshake) share one
    /// status-to-kind mapping.
    #[must_use]
    pub fn source_kind(&self) -> ErrorKind {
        match self {
            Self::ConnectionError { .. }
            | Self::TimeoutError { .. }
            | Self::HeartbeatTimeout { .. } => ErrorKind::Network,
            Self::WebSocketError { kind, .. } => match *kind {
                WebSocketErrorKind::Protocol
                | WebSocketErrorKind::Capacity
                | WebSocketErrorKind::Utf8
                | WebSocketErrorKind::Other => ErrorKind::Protocol,
                WebSocketErrorKind::Tls => ErrorKind::Auth,
                WebSocketErrorKind::Io => ErrorKind::Network,
                WebSocketErrorKind::Http(status) => Self::http_status_kind(status),
            },
            Self::AuthError { .. } => ErrorKind::Auth,
            Self::ApiError { status, .. } => Self::http_status_kind(*status),
            Self::InvalidSymbol { .. }
            | Self::InvalidParameter { .. }
            | Self::ConfigError(_)
            | Self::DeserializationError { .. }
            | Self::ClientClosed
            | Self::RuntimeError { .. }
            | Self::Other(_) => ErrorKind::Client,
        }
    }

    /// Returns the numeric code associated with this error variant.
    ///
    /// The code depends only on the variant, never on its payload.
    #[must_use]
    pub fn to_error_code(&self) -> i32 {
        match self {
            Self::InvalidSymbol { .. } => 1001,
            Self::DeserializationError { .. } => 1002,
            Self::RuntimeError { .. } => 1003,
            Self::ConfigError(_) => 1004,
            Self::InvalidParameter { .. } => 1005,
            Self::ConnectionError { .. } => 2001,
            Self::AuthError { .. } => 2002,
            Self::ApiError { .. } => 2003,
            Self::ClientClosed => 2010,
            Self::TimeoutError { .. } => 3001,
            Self::WebSocketError { .. } => 3002,
            Self::HeartbeatTimeout { .. } => 3003,
            Self::Other(_) => 9999,
        }
    }

    /// Returns `true` when retrying the failed operation may succeed.
    ///
    /// Connection errors, timeouts and heartbeat timeouts are retryable, as
    /// are WebSocket `Io`/`Other` failures and HTTP 429 / 500..=599 statuses.
    /// Everything else is not.
    #[must_use]
    pub fn is_retryable(&self) -> bool {
        // Deliberately exhaustive (no `_` arm): adding a variant must force a
        // conscious retry decision here instead of silently defaulting.
        match self {
            Self::ConnectionError { .. }
            | Self::TimeoutError { .. }
            | Self::HeartbeatTimeout { .. } => true,
            Self::WebSocketError { kind, .. } => match *kind {
                WebSocketErrorKind::Io | WebSocketErrorKind::Other => true,
                WebSocketErrorKind::Http(status) => Self::http_status_is_retryable(status),
                WebSocketErrorKind::Protocol
                | WebSocketErrorKind::Capacity
                | WebSocketErrorKind::Utf8
                | WebSocketErrorKind::Tls => false,
            },
            Self::ApiError { status, .. } => Self::http_status_is_retryable(*status),
            Self::InvalidSymbol { .. }
            | Self::InvalidParameter { .. }
            | Self::DeserializationError { .. }
            | Self::RuntimeError { .. }
            | Self::ConfigError(_)
            | Self::AuthError { .. }
            | Self::ClientClosed
            | Self::Other(_) => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` output must match the `#[error(...)]` templates verbatim.
    #[test]
    fn test_error_display() {
        let err = MarketDataError::InvalidSymbol {
            symbol: "INVALID".to_string(),
        };
        assert_eq!(err.to_string(), "Invalid symbol: INVALID");

        let err = MarketDataError::InvalidParameter {
            name: "limit".to_string(),
            reason: "must be positive".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid parameter 'limit': must be positive"
        );

        let err = MarketDataError::RuntimeError {
            msg: "test message".to_string(),
        };
        assert_eq!(err.to_string(), "Runtime error: test message");

        let err = MarketDataError::ConfigError("missing key".to_string());
        assert_eq!(err.to_string(), "Configuration error: missing key");

        let err = MarketDataError::ConnectionError {
            msg: "refused".to_string(),
        };
        assert_eq!(err.to_string(), "Connection error: refused");

        let err = MarketDataError::AuthError {
            msg: "bad token".to_string(),
        };
        assert_eq!(err.to_string(), "Authentication error: bad token");

        let err = MarketDataError::ApiError {
            status: 404,
            message: "not found".to_string(),
        };
        assert_eq!(err.to_string(), "API error (status 404): not found");

        let err = MarketDataError::TimeoutError {
            operation: "connect".to_string(),
        };
        assert_eq!(err.to_string(), "Timeout error: connect");

        // The kind is rendered with `Debug`, including the HTTP status payload.
        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Protocol,
            msg: "bad frame".to_string(),
        };
        assert_eq!(err.to_string(), "WebSocket error (Protocol): bad frame");
        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Http(503),
            msg: "unavailable".to_string(),
        };
        assert_eq!(err.to_string(), "WebSocket error (Http(503)): unavailable");

        let err = MarketDataError::ClientClosed;
        assert_eq!(err.to_string(), "Client already closed");
    }

    /// Every variant constructible without external input maps to its code.
    #[test]
    fn test_error_codes() {
        let err = MarketDataError::InvalidSymbol {
            symbol: "test".to_string(),
        };
        assert_eq!(err.to_error_code(), 1001);

        let err = MarketDataError::RuntimeError {
            msg: "test".to_string(),
        };
        assert_eq!(err.to_error_code(), 1003);

        let err = MarketDataError::ConfigError("test".to_string());
        assert_eq!(err.to_error_code(), 1004);

        let err = MarketDataError::InvalidParameter {
            name: "test".to_string(),
            reason: "test".to_string(),
        };
        assert_eq!(err.to_error_code(), 1005);

        let err = MarketDataError::ConnectionError {
            msg: "test".to_string(),
        };
        assert_eq!(err.to_error_code(), 2001);

        let err = MarketDataError::AuthError {
            msg: "test".to_string(),
        };
        assert_eq!(err.to_error_code(), 2002);

        let err = MarketDataError::ApiError {
            status: 500,
            message: "test".to_string(),
        };
        assert_eq!(err.to_error_code(), 2003);

        let err = MarketDataError::TimeoutError {
            operation: "test".to_string(),
        };
        assert_eq!(err.to_error_code(), 3001);

        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Protocol,
            msg: "test".to_string(),
        };
        assert_eq!(err.to_error_code(), 3002);

        let err = MarketDataError::HeartbeatTimeout {
            elapsed: Duration::from_secs(35),
        };
        assert_eq!(err.to_error_code(), 3003);

        let err = MarketDataError::ClientClosed;
        assert_eq!(err.to_error_code(), 2010);

        let err = MarketDataError::Other(anyhow::anyhow!("test"));
        assert_eq!(err.to_error_code(), 9999);
    }

    #[test]
    fn test_retryable_classification() {
        let err = MarketDataError::ConnectionError {
            msg: "test".to_string(),
        };
        assert!(err.is_retryable());

        let err = MarketDataError::TimeoutError {
            operation: "test".to_string(),
        };
        assert!(err.is_retryable());

        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Io,
            msg: "reset".to_string(),
        };
        assert!(err.is_retryable());

        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Protocol,
            msg: "frame".to_string(),
        };
        assert!(!err.is_retryable());

        // Remaining non-HTTP WebSocket kinds: only `Other` is retryable.
        for kind in [
            WebSocketErrorKind::Capacity,
            WebSocketErrorKind::Utf8,
            WebSocketErrorKind::Tls,
        ] {
            let err = MarketDataError::WebSocketError {
                kind,
                msg: "test".to_string(),
            };
            assert!(!err.is_retryable(), "{kind:?} must not be retryable");
        }
        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Other,
            msg: "test".to_string(),
        };
        assert!(err.is_retryable());

        let err = MarketDataError::HeartbeatTimeout {
            elapsed: Duration::from_secs(35),
        };
        assert!(err.is_retryable());

        let err = MarketDataError::InvalidSymbol {
            symbol: "test".to_string(),
        };
        assert!(!err.is_retryable());

        let err = MarketDataError::InvalidParameter {
            name: "test".to_string(),
            reason: "test".to_string(),
        };
        assert!(!err.is_retryable());

        let err = MarketDataError::RuntimeError {
            msg: "test".to_string(),
        };
        assert!(!err.is_retryable());

        let err = MarketDataError::ConfigError("test".to_string());
        assert!(!err.is_retryable());

        let err = MarketDataError::AuthError {
            msg: "test".to_string(),
        };
        assert!(!err.is_retryable());

        let err = MarketDataError::ApiError {
            status: 400,
            message: "test".to_string(),
        };
        assert!(!err.is_retryable());

        let err = MarketDataError::ApiError {
            status: 429,
            message: "rate limit".to_string(),
        };
        assert!(err.is_retryable());

        let err = MarketDataError::ApiError {
            status: 503,
            message: "service unavailable".to_string(),
        };
        assert!(err.is_retryable());

        let err = MarketDataError::ClientClosed;
        assert!(!err.is_retryable());

        let err = MarketDataError::Other(anyhow::anyhow!("test"));
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_heartbeat_timeout_display() {
        let err = MarketDataError::HeartbeatTimeout {
            elapsed: Duration::from_secs(35),
        };
        assert!(err.to_string().contains("35s"));
        assert!(err.to_string().starts_with("Heartbeat timeout"));
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("{invalid json").unwrap_err();
        let err: MarketDataError = json_err.into();
        assert_eq!(err.to_error_code(), 1002);
        assert!(matches!(err, MarketDataError::DeserializationError { .. }));
    }

    #[test]
    fn test_from_anyhow_error() {
        let anyhow_err = anyhow::anyhow!("test error");
        let err: MarketDataError = anyhow_err.into();
        assert_eq!(err.to_error_code(), 9999);
        assert!(matches!(err, MarketDataError::Other(_)));
    }

    // The conversion tests below build errors through the same `tungstenite`
    // path the `From` impl names, so they exercise exactly the converted type.

    #[test]
    fn test_from_tungstenite_connection_closed() {
        use tungstenite::Error as WsError;
        let ws_err = WsError::ConnectionClosed;
        let err: MarketDataError = ws_err.into();
        assert_eq!(err.to_error_code(), 3002);
        assert!(matches!(
            err,
            MarketDataError::WebSocketError {
                kind: WebSocketErrorKind::Io,
                ..
            }
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn test_from_tungstenite_protocol_error() {
        use tungstenite::error::ProtocolError;
        use tungstenite::Error as WsError;
        let ws_err = WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake);
        let err: MarketDataError = ws_err.into();
        assert_eq!(err.to_error_code(), 3002);
        assert!(matches!(
            err,
            MarketDataError::WebSocketError {
                kind: WebSocketErrorKind::Protocol,
                ..
            }
        ));
        assert!(
            !err.is_retryable(),
            "Protocol violations must not retry (0.6.0+); retry against the same SDK + server combo will keep failing"
        );
    }

    #[test]
    fn test_from_tungstenite_already_closed() {
        use tungstenite::Error as WsError;
        let ws_err = WsError::AlreadyClosed;
        let err: MarketDataError = ws_err.into();
        assert_eq!(err.to_error_code(), 3002);
        // `AlreadyClosed` shares the `Io` kind with `ConnectionClosed`.
        assert!(matches!(
            err,
            MarketDataError::WebSocketError {
                kind: WebSocketErrorKind::Io,
                ..
            }
        ));
    }

    #[test]
    fn source_kind_network_for_transport_failures() {
        let err = MarketDataError::ConnectionError {
            msg: "reset".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Network);

        let err = MarketDataError::TimeoutError {
            operation: "read".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Network);

        let err = MarketDataError::HeartbeatTimeout {
            elapsed: Duration::from_secs(35),
        };
        assert_eq!(err.source_kind(), ErrorKind::Network);
    }

    #[test]
    fn source_kind_for_websocket_protocol_kind() {
        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Protocol,
            msg: "frame".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Protocol);
    }

    /// `Capacity`, `Utf8` and `Other` share the `Protocol` category.
    #[test]
    fn source_kind_for_websocket_capacity_utf8_other_routes_to_protocol() {
        for kind in [
            WebSocketErrorKind::Capacity,
            WebSocketErrorKind::Utf8,
            WebSocketErrorKind::Other,
        ] {
            let err = MarketDataError::WebSocketError {
                kind,
                msg: "test".to_string(),
            };
            assert_eq!(err.source_kind(), ErrorKind::Protocol, "kind {kind:?}");
        }
    }

    #[test]
    fn source_kind_for_websocket_io_routes_to_network() {
        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Io,
            msg: "reset".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Network);
    }

    #[test]
    fn source_kind_for_websocket_tls_routes_to_auth() {
        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Tls,
            msg: "cert".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Auth);
    }

    #[test]
    fn source_kind_for_websocket_http_401_routes_to_auth() {
        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Http(401),
            msg: "unauthorized".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Auth);
    }

    #[test]
    fn source_kind_for_websocket_http_429_routes_to_rate_limit() {
        let err = MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Http(429),
            msg: "throttle".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::RateLimit);
    }

    #[test]
    fn tungstenite_protocol_routes_to_protocol_kind() {
        use tungstenite::error::ProtocolError;
        use tungstenite::Error as WsError;
        let ws_err = WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake);
        let err: MarketDataError = ws_err.into();
        match err {
            MarketDataError::WebSocketError { kind, .. } => {
                assert_eq!(kind, WebSocketErrorKind::Protocol);
            }
            other => panic!("expected WebSocketError, got {other:?}"),
        }
    }

    #[test]
    fn tungstenite_io_routes_to_io_kind() {
        use std::io;
        use tungstenite::Error as WsError;
        let ws_err = WsError::Io(io::Error::new(io::ErrorKind::ConnectionReset, "reset"));
        let err: MarketDataError = ws_err.into();
        match err {
            MarketDataError::WebSocketError { kind, .. } => {
                assert_eq!(kind, WebSocketErrorKind::Io);
            }
            other => panic!("expected WebSocketError, got {other:?}"),
        }
    }

    #[test]
    fn source_kind_auth_for_401_403_api_errors() {
        let err = MarketDataError::ApiError {
            status: 401,
            message: "unauthorized".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Auth);

        let err = MarketDataError::ApiError {
            status: 403,
            message: "forbidden".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Auth);

        let err = MarketDataError::AuthError {
            msg: "bad token".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Auth);
    }

    #[test]
    fn source_kind_network_for_5xx() {
        let err = MarketDataError::ApiError {
            status: 503,
            message: "service unavailable".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Network);

        let err = MarketDataError::ApiError {
            status: 500,
            message: "internal".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Network);
    }

    #[test]
    fn source_kind_rate_limit_for_429() {
        let err = MarketDataError::ApiError {
            status: 429,
            message: "rate limit".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::RateLimit);
    }

    #[test]
    fn source_kind_client_for_validation_failures() {
        let err = MarketDataError::InvalidParameter {
            name: "symbol".to_string(),
            reason: "empty".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Client);

        let err = MarketDataError::InvalidSymbol {
            symbol: "?".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Client);

        let err = MarketDataError::ConfigError("bad".to_string());
        assert_eq!(err.source_kind(), ErrorKind::Client);

        let err = MarketDataError::ClientClosed;
        assert_eq!(err.source_kind(), ErrorKind::Client);

        let err = MarketDataError::RuntimeError {
            msg: "boom".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Client);

        let err = MarketDataError::Other(anyhow::anyhow!("test"));
        assert_eq!(err.source_kind(), ErrorKind::Client);
    }

    #[test]
    fn source_kind_client_for_4xx_excl_auth() {
        let err = MarketDataError::ApiError {
            status: 404,
            message: "not found".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Client);

        let err = MarketDataError::ApiError {
            status: 400,
            message: "bad request".to_string(),
        };
        assert_eq!(err.source_kind(), ErrorKind::Client);
    }

    /// Compile-time guard: the wildcard-free match stops compiling if an
    /// `ErrorKind` variant is added or removed without updating this test.
    #[test]
    fn error_kind_variants_exist() {
        fn classify(k: ErrorKind) -> u8 {
            match k {
                ErrorKind::Network => 1,
                ErrorKind::Protocol => 2,
                ErrorKind::Auth => 3,
                ErrorKind::RateLimit => 4,
                ErrorKind::Client => 5,
            }
        }
        assert_eq!(classify(ErrorKind::Network), 1);
        assert_eq!(classify(ErrorKind::Protocol), 2);
        assert_eq!(classify(ErrorKind::Auth), 3);
        assert_eq!(classify(ErrorKind::RateLimit), 4);
        assert_eq!(classify(ErrorKind::Client), 5);
    }
}
/// Table-driven checks that HTTP statuses are classified the same way by
/// `source_kind()` and `is_retryable()`, for both error variants that carry one.
#[cfg(test)]
mod http_mapping_consistency {
    use super::{ErrorKind, MarketDataError, WebSocketErrorKind};

    /// Builds the WebSocket handshake error for `status`.
    fn ws_err(status: u16) -> MarketDataError {
        MarketDataError::WebSocketError {
            kind: WebSocketErrorKind::Http(status),
            msg: format!("HTTP {} during WebSocket handshake", status),
        }
    }

    /// Builds the REST-style API error for `status`.
    fn api_err(status: u16) -> MarketDataError {
        MarketDataError::ApiError {
            status,
            message: format!("HTTP {status}"),
        }
    }

    /// Representative rows: `(status, expected source_kind, expected is_retryable)`.
    const HTTP_TABLE: &[(u16, ErrorKind, bool)] = &[
        (401, ErrorKind::Auth, false),
        (403, ErrorKind::Auth, false),
        (404, ErrorKind::Client, false),
        (429, ErrorKind::RateLimit, true),
        (500, ErrorKind::Network, true),
        (503, ErrorKind::Network, true),
        (999, ErrorKind::Client, false),
    ];

    /// Edges of every special-cased status or range, to catch off-by-one
    /// mistakes in the 500..=599 range and neighbours of 401/403/429.
    const BOUNDARY_TABLE: &[(u16, ErrorKind, bool)] = &[
        (0, ErrorKind::Client, false),
        (400, ErrorKind::Client, false),
        (402, ErrorKind::Client, false),
        (428, ErrorKind::Client, false),
        (430, ErrorKind::Client, false),
        (499, ErrorKind::Client, false),
        (599, ErrorKind::Network, true),
        (600, ErrorKind::Client, false),
        (u16::MAX, ErrorKind::Client, false),
    ];

    #[test]
    fn http_status_mapping_matches_doc_table() {
        for &(status, expected_kind, expected_retryable) in HTTP_TABLE {
            let err = ws_err(status);
            assert_eq!(
                err.source_kind(),
                expected_kind,
                "HTTP {status}: source_kind() mismatch with documented table on WebSocketErrorKind::Http"
            );
            assert_eq!(
                err.is_retryable(),
                expected_retryable,
                "HTTP {status}: is_retryable() mismatch with documented table on WebSocketErrorKind::Http"
            );
        }
    }

    #[test]
    fn http_status_mapping_holds_at_range_boundaries() {
        for &(status, expected_kind, expected_retryable) in BOUNDARY_TABLE {
            let err = ws_err(status);
            assert_eq!(
                err.source_kind(),
                expected_kind,
                "HTTP {status}: source_kind() wrong at range boundary"
            );
            assert_eq!(
                err.is_retryable(),
                expected_retryable,
                "HTTP {status}: is_retryable() wrong at range boundary"
            );
        }
    }

    /// `ApiError` and `WebSocketErrorKind::Http` must classify every status
    /// identically; checked over all statuses up to 999.
    #[test]
    fn api_error_mapping_agrees_with_websocket_http() {
        for status in 0..=999u16 {
            let ws = ws_err(status);
            let api = api_err(status);
            assert_eq!(
                api.source_kind(),
                ws.source_kind(),
                "HTTP {status}: ApiError and WebSocket Http disagree on source_kind()"
            );
            assert_eq!(
                api.is_retryable(),
                ws.is_retryable(),
                "HTTP {status}: ApiError and WebSocket Http disagree on is_retryable()"
            );
        }
    }
}