use std::{num::ParseIntError, string::FromUtf8Error};
use thiserror::Error;
use crate::market_data::historical::HistoricalParseError;
use crate::messages::ResponseMessage;
use crate::orders::builder::ValidationError;
/// Error type returned by this crate's fallible operations.
///
/// Variants marked `#[error(transparent)]` forward both `Display` and
/// `source()` to the wrapped error. Their `#[from]` field attribute makes
/// `thiserror` generate a `From` impl, so `?` converts the underlying error
/// automatically.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it. That lets new variants be added without a
/// breaking change.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
/// Wraps a standard-library I/O error (transparent Display/source).
#[error(transparent)]
Io(#[from] std::io::Error),
/// Wraps an integer parsing failure (transparent Display/source).
#[error(transparent)]
ParseInt(#[from] ParseIntError),
/// Wraps an invalid-UTF-8 failure from `String::from_utf8` (transparent Display/source).
#[error(transparent)]
FromUtf8(#[from] FromUtf8Error),
/// Wraps a `time` crate parse failure (transparent Display/source).
#[error(transparent)]
ParseTime(#[from] time::error::Parse),
/// A lock was poisoned. The payload is a pre-formatted description, shown
/// verbatim by `Display`. The `From<std::sync::PoisonError<T>>` impl in this
/// module builds it.
#[error("{0}")]
Poison(String),
/// The requested functionality is not implemented.
#[error("not implemented")]
NotImplemented,
/// A value failed to parse.
///
/// Displayed as `parse error: {0} - {1} - {2}`.
/// NOTE(review): the fields appear to be (position/index, offending value,
/// message). Confirm against the call sites.
#[error("parse error: {0} - {1} - {2}")]
Parse(usize, String, String),
/// The server is too old for the requested feature.
///
/// The display format `server version {0} required, got {1}: {2}` gives the
/// field order: required version, actual version, description.
#[error("server version {0} required, got {1}: {2}")]
ServerVersion(i32, i32, String),
/// A generic error carrying a free-form message.
#[error("error occurred: {0}")]
Simple(String),
/// An argument was rejected. The `From<ValidationError>` impl in this module
/// maps every order-builder validation failure to this variant.
#[error("InvalidArgument: {0}")]
InvalidArgument(String),
/// Connection failure; no detail beyond the variant itself.
#[error("ConnectionFailed")]
ConnectionFailed,
/// Connection reset; no detail beyond the variant itself.
#[error("ConnectionReset")]
ConnectionReset,
/// The operation was cancelled.
#[error("Cancelled")]
Cancelled,
/// Shutdown signalled; no detail beyond the variant itself.
#[error("Shutdown")]
Shutdown,
/// End of stream; no detail beyond the variant itself.
#[error("EndOfStream")]
EndOfStream,
/// A response arrived that the caller did not expect. It is displayed using
/// the response's `Debug` representation.
#[error("UnexpectedResponse: {0:?}")]
UnexpectedResponse(ResponseMessage),
/// A stream ended before it was expected to.
#[error("UnexpectedEndOfStream")]
UnexpectedEndOfStream,
/// An error reported by the server, displayed as `[code] message`.
///
/// The `From<ResponseMessage>` impl in this module fills it from
/// `error_code()` and `error_message()`.
#[error("[{0}] {1}")]
Message(i32, String),
/// A subscription of this kind is already active.
#[error("AlreadySubscribed")]
AlreadySubscribed,
/// Failed to parse historical market data.
///
/// NOTE(review): the payload is not marked `#[from]` or `#[source]`. As a
/// result, `source()` returns `None` for this variant and no `From` impl is
/// generated here. Consider `#[source]` if `HistoricalParseError` implements
/// `std::error::Error`.
#[error("HistoricalParseError: {0}")]
HistoricalParseError(HistoricalParseError),
}
impl From<ResponseMessage> for Error {
fn from(err: ResponseMessage) -> Error {
let code = err.error_code();
let message = err.error_message();
Error::Message(code, message)
}
}
impl<T> From<std::sync::PoisonError<T>> for Error {
fn from(err: std::sync::PoisonError<T>) -> Error {
Error::Poison(format!("Mutex poison error: {err}"))
}
}
impl From<ValidationError> for Error {
fn from(err: ValidationError) -> Self {
match err {
ValidationError::InvalidQuantity(q) => Error::InvalidArgument(format!("Invalid quantity: {}", q)),
ValidationError::InvalidPrice(p) => Error::InvalidArgument(format!("Invalid price: {}", p)),
ValidationError::MissingRequiredField(field) => Error::InvalidArgument(format!("Missing required field: {}", field)),
ValidationError::InvalidCombination(msg) => Error::InvalidArgument(format!("Invalid combination: {}", msg)),
ValidationError::InvalidStopPrice { stop, current } => {
Error::InvalidArgument(format!("Invalid stop price {} for current price {}", stop, current))
}
ValidationError::InvalidLimitPrice { limit, current } => {
Error::InvalidArgument(format!("Invalid limit price {} for current price {}", limit, current))
}
ValidationError::InvalidBracketOrder(msg) => Error::InvalidArgument(format!("Invalid bracket order: {}", msg)),
ValidationError::InvalidPercentage { field, value, min, max } => {
Error::InvalidArgument(format!("Invalid {}: {} (must be between {} and {})", field, value, min, max))
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error as StdError;
    use std::io;
    use std::sync::{Mutex, PoisonError};
    use time::macros::format_description;
    use time::Time;

    #[test]
    fn test_error_debug() {
        let error = Error::Simple("test error".to_string());
        assert_eq!(format!("{error:?}"), "Simple(\"test error\")");
    }

    /// Pins the `Display` text of every variant that can be built without
    /// project-specific payloads (`ResponseMessage`, `HistoricalParseError`).
    #[test]
    fn test_error_display() {
        // An array is enough here: the cases are only iterated by value.
        let cases = [
            (Error::Io(io::Error::new(io::ErrorKind::NotFound, "file not found")), "file not found"),
            (Error::ParseInt("123x".parse::<i32>().unwrap_err()), "invalid digit found in string"),
            (
                Error::FromUtf8(String::from_utf8(vec![0, 159, 146, 150]).unwrap_err()),
                "invalid utf-8 sequence of 1 bytes from index 1",
            ),
            (
                Error::ParseTime(Time::parse("2021-13-01", format_description!("[year]-[month]-[day]")).unwrap_err()),
                "the 'month' component could not be parsed",
            ),
            (Error::Poison("test poison".to_string()), "test poison"),
            (Error::NotImplemented, "not implemented"),
            (
                Error::Parse(1, "value".to_string(), "message".to_string()),
                "parse error: 1 - value - message",
            ),
            (
                Error::ServerVersion(2, 1, "old version".to_string()),
                "server version 2 required, got 1: old version",
            ),
            (Error::Simple("simple error".to_string()), "error occurred: simple error"),
            (Error::InvalidArgument("bad quantity".to_string()), "InvalidArgument: bad quantity"),
            (Error::ConnectionFailed, "ConnectionFailed"),
            (Error::ConnectionReset, "ConnectionReset"),
            (Error::Cancelled, "Cancelled"),
            (Error::Shutdown, "Shutdown"),
            (Error::EndOfStream, "EndOfStream"),
            (Error::UnexpectedEndOfStream, "UnexpectedEndOfStream"),
            (Error::AlreadySubscribed, "AlreadySubscribed"),
            (
                Error::Message(200, "No security definition".to_string()),
                "[200] No security definition",
            ),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected, "unexpected Display for {error:?}");
        }
    }

    #[test]
    fn test_error_is_error() {
        let error = Error::Simple("test error".to_string());
        assert!(error.source().is_none());
    }

    #[test]
    fn test_from_io_error() {
        let io_error = io::Error::other("io error");
        let error: Error = io_error.into();
        assert!(matches!(error, Error::Io(_)));
    }

    #[test]
    fn test_from_parse_int_error() {
        let parse_error = "abc".parse::<i32>().unwrap_err();
        let error: Error = parse_error.into();
        assert!(matches!(error, Error::ParseInt(_)));
    }

    #[test]
    fn test_from_utf8_error() {
        let utf8_error = String::from_utf8(vec![0, 159, 146, 150]).unwrap_err();
        let error: Error = utf8_error.into();
        assert!(matches!(error, Error::FromUtf8(_)));
    }

    #[test]
    fn test_from_parse_time_error() {
        let time_error = Time::parse("2021-13-01", format_description!("[year]-[month]-[day]")).unwrap_err();
        let error: Error = time_error.into();
        assert!(matches!(error, Error::ParseTime(_)));
    }

    /// Checks both the variant and the message prefix the conversion adds.
    #[test]
    fn test_from_poison_error() {
        let mutex = Mutex::new(());
        let poison_error = PoisonError::new(mutex);
        let error: Error = poison_error.into();
        match error {
            Error::Poison(msg) => assert!(msg.starts_with("Mutex poison error: "), "got {msg}"),
            other => panic!("expected Error::Poison, got {other:?}"),
        }
    }

    /// Compile-time check that `Error` implements `std::error::Error`.
    ///
    /// Replaces the former `test_non_exhaustive`, which only ever checked
    /// this bound. `#[non_exhaustive]` cannot be observed from inside the
    /// defining crate.
    #[test]
    fn test_implements_std_error() {
        fn assert_std_error<T: StdError>() {}
        assert_std_error::<Error>();
    }
}