1use serde::{Deserialize, Serialize};
7use std::fmt;
8use thiserror::Error;
9
10pub type Result<T> = std::result::Result<T, AlpacaError>;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17#[repr(u32)]
18pub enum ApiErrorCode {
19 MalformedRequest = 40010000,
21 InvalidCredentials = 40110000,
23 Forbidden = 40310000,
25 NotFound = 40410000,
27 UnprocessableEntity = 42210000,
29 RateLimitExceeded = 42910000,
31 InternalServerError = 50010000,
33 Unknown = 0,
35}
36
37impl ApiErrorCode {
38 #[must_use]
40 pub fn from_code(code: u32) -> Self {
41 match code {
42 40010000 => Self::MalformedRequest,
43 40110000 => Self::InvalidCredentials,
44 40310000 => Self::Forbidden,
45 40410000 => Self::NotFound,
46 42210000 => Self::UnprocessableEntity,
47 42910000 => Self::RateLimitExceeded,
48 50010000 => Self::InternalServerError,
49 _ => Self::Unknown,
50 }
51 }
52
53 #[must_use]
55 pub fn as_code(&self) -> u32 {
56 *self as u32
57 }
58
59 #[must_use]
61 pub fn is_client_error(&self) -> bool {
62 let code = self.as_code();
63 (40000000..50000000).contains(&code)
64 }
65
66 #[must_use]
68 pub fn is_server_error(&self) -> bool {
69 let code = self.as_code();
70 code >= 50000000
71 }
72
73 #[must_use]
75 pub fn is_retryable(&self) -> bool {
76 matches!(self, Self::RateLimitExceeded | Self::InternalServerError)
77 }
78}
79
80impl fmt::Display for ApiErrorCode {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 match self {
83 Self::MalformedRequest => write!(f, "malformed request"),
84 Self::InvalidCredentials => write!(f, "invalid credentials"),
85 Self::Forbidden => write!(f, "forbidden"),
86 Self::NotFound => write!(f, "not found"),
87 Self::UnprocessableEntity => write!(f, "unprocessable entity"),
88 Self::RateLimitExceeded => write!(f, "rate limit exceeded"),
89 Self::InternalServerError => write!(f, "internal server error"),
90 Self::Unknown => write!(f, "unknown error"),
91 }
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct ApiErrorResponse {
98 #[serde(default)]
100 pub code: u32,
101 #[serde(default)]
103 pub message: String,
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub request_id: Option<String>,
107}
108
109impl ApiErrorResponse {
110 #[must_use]
112 pub fn new(code: u32, message: impl Into<String>) -> Self {
113 Self {
114 code,
115 message: message.into(),
116 request_id: None,
117 }
118 }
119
120 #[must_use]
122 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
123 self.request_id = Some(request_id.into());
124 self
125 }
126
127 #[must_use]
129 pub fn error_code(&self) -> ApiErrorCode {
130 ApiErrorCode::from_code(self.code)
131 }
132}
133
/// Optional rate-limit details; any of the fields may be unknown.
///
/// `PartialEq`/`Eq` are derived so instances can be compared in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RateLimitInfo {
    /// Requests still allowed (presumably in the current window), if known.
    pub remaining: Option<u32>,
    /// Total request allowance (presumably per window), if known.
    pub limit: Option<u32>,
    /// Seconds to wait before retrying, if known.
    pub retry_after: Option<u64>,
}
144
145impl RateLimitInfo {
146 #[must_use]
148 pub fn new() -> Self {
149 Self::default()
150 }
151
152 #[must_use]
154 pub fn with_remaining(mut self, remaining: u32) -> Self {
155 self.remaining = Some(remaining);
156 self
157 }
158
159 #[must_use]
161 pub fn with_limit(mut self, limit: u32) -> Self {
162 self.limit = Some(limit);
163 self
164 }
165
166 #[must_use]
168 pub fn with_retry_after(mut self, seconds: u64) -> Self {
169 self.retry_after = Some(seconds);
170 self
171 }
172
173 #[must_use]
175 pub fn is_limited(&self) -> bool {
176 self.remaining == Some(0)
177 }
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct ValidationError {
183 pub field: String,
185 pub message: String,
187}
188
189impl ValidationError {
190 #[must_use]
192 pub fn new(field: impl Into<String>, message: impl Into<String>) -> Self {
193 Self {
194 field: field.into(),
195 message: message.into(),
196 }
197 }
198}
199
200impl fmt::Display for ValidationError {
201 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202 write!(f, "{}: {}", self.field, self.message)
203 }
204}
205
206#[derive(Error, Debug)]
208pub enum AlpacaError {
209 #[error("http error: {0}")]
211 Http(String),
212
213 #[error("json error: {0}")]
215 Json(String),
216
217 #[error("api error {status}: {message}")]
219 Api {
220 status: u16,
222 message: String,
224 #[source]
226 error_code: Option<ApiErrorCode>,
227 request_id: Option<String>,
229 },
230
231 #[error("authentication error: {0}")]
233 Auth(String),
234
235 #[error("configuration error: {0}")]
237 Config(String),
238
239 #[error("websocket error: {0}")]
241 WebSocket(String),
242
243 #[error("rate limit exceeded, retry after {retry_after_secs} seconds")]
245 RateLimit {
246 retry_after_secs: u64,
248 info: RateLimitInfo,
250 },
251
252 #[error("network error: {0}")]
254 Network(String),
255
256 #[error("timeout error: {0}")]
258 Timeout(String),
259
260 #[error("invalid data format: {0}")]
262 InvalidData(String),
263
264 #[error("validation error: {0}")]
266 Validation(String),
267
268 #[error("validation errors: {}", .0.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(", "))]
270 ValidationErrors(Vec<ValidationError>),
271}
272
273impl AlpacaError {
274 #[must_use]
276 pub fn api(status: u16, message: impl Into<String>) -> Self {
277 Self::Api {
278 status,
279 message: message.into(),
280 error_code: None,
281 request_id: None,
282 }
283 }
284
285 #[must_use]
287 pub fn api_with_details(
288 status: u16,
289 message: impl Into<String>,
290 error_code: ApiErrorCode,
291 request_id: Option<String>,
292 ) -> Self {
293 Self::Api {
294 status,
295 message: message.into(),
296 error_code: Some(error_code),
297 request_id,
298 }
299 }
300
301 #[must_use]
303 pub fn rate_limit(retry_after_secs: u64) -> Self {
304 Self::RateLimit {
305 retry_after_secs,
306 info: RateLimitInfo::new().with_retry_after(retry_after_secs),
307 }
308 }
309
310 #[must_use]
312 pub fn rate_limit_with_info(info: RateLimitInfo) -> Self {
313 Self::RateLimit {
314 retry_after_secs: info.retry_after.unwrap_or(60),
315 info,
316 }
317 }
318
319 #[must_use]
321 pub fn is_retryable(&self) -> bool {
322 match self {
323 Self::RateLimit { .. } => true,
324 Self::Network(_) => true,
325 Self::Timeout(_) => true,
326 Self::Api {
327 status, error_code, ..
328 } => {
329 if *status >= 500 {
331 return true;
332 }
333 error_code.is_some_and(|code| code.is_retryable())
335 }
336 _ => false,
337 }
338 }
339
340 #[must_use]
342 pub fn retry_after(&self) -> Option<u64> {
343 match self {
344 Self::RateLimit {
345 retry_after_secs, ..
346 } => Some(*retry_after_secs),
347 _ => None,
348 }
349 }
350
351 #[must_use]
353 pub fn request_id(&self) -> Option<&str> {
354 match self {
355 Self::Api { request_id, .. } => request_id.as_deref(),
356 _ => None,
357 }
358 }
359
360 #[must_use]
362 pub fn status_code(&self) -> Option<u16> {
363 match self {
364 Self::Api { status, .. } => Some(*status),
365 _ => None,
366 }
367 }
368}
369
370impl From<serde_json::Error> for AlpacaError {
371 fn from(err: serde_json::Error) -> Self {
372 AlpacaError::Json(err.to_string())
373 }
374}
375
376impl From<tokio_tungstenite::tungstenite::Error> for AlpacaError {
377 fn from(err: tokio_tungstenite::tungstenite::Error) -> Self {
378 AlpacaError::WebSocket(err.to_string())
379 }
380}
381
// `AlpacaError::Api` marks its `error_code: Option<ApiErrorCode>` field with
// `#[source]`, which requires `ApiErrorCode: std::error::Error`. The derived
// `Debug` and the hand-written `Display` impl satisfy the supertraits, so an
// empty impl is enough.
impl std::error::Error for ApiErrorCode {}
384
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_error_code_from_code() {
        let cases: [(u32, ApiErrorCode); 4] = [
            (40_010_000, ApiErrorCode::MalformedRequest),
            (40_110_000, ApiErrorCode::InvalidCredentials),
            (42_910_000, ApiErrorCode::RateLimitExceeded),
            (99_999_999, ApiErrorCode::Unknown),
        ];
        for (raw, expected) in cases {
            assert_eq!(ApiErrorCode::from_code(raw), expected, "code {raw}");
        }
    }

    #[test]
    fn test_api_error_code_is_retryable() {
        for code in [
            ApiErrorCode::RateLimitExceeded,
            ApiErrorCode::InternalServerError,
        ] {
            assert!(code.is_retryable(), "{code:?} should be retryable");
        }
        for code in [ApiErrorCode::NotFound, ApiErrorCode::InvalidCredentials] {
            assert!(!code.is_retryable(), "{code:?} should not be retryable");
        }
    }

    #[test]
    fn test_api_error_code_is_client_error() {
        for code in [ApiErrorCode::MalformedRequest, ApiErrorCode::NotFound] {
            assert!(code.is_client_error(), "{code:?} should be a client error");
        }
        assert!(!ApiErrorCode::InternalServerError.is_client_error());
    }

    #[test]
    fn test_api_error_code_is_server_error() {
        assert!(ApiErrorCode::InternalServerError.is_server_error());
        assert!(!ApiErrorCode::NotFound.is_server_error());
    }

    #[test]
    fn test_rate_limit_info() {
        let limited = RateLimitInfo::new()
            .with_remaining(0)
            .with_limit(200)
            .with_retry_after(60);

        assert!(limited.is_limited());
        assert_eq!(
            (limited.remaining, limited.limit, limited.retry_after),
            (Some(0), Some(200), Some(60))
        );
    }

    #[test]
    fn test_validation_error() {
        let failure = ValidationError::new("qty", "must be positive");
        assert_eq!(failure.field, "qty");
        assert_eq!(failure.message, "must be positive");
        assert_eq!(format!("{failure}"), "qty: must be positive");
    }

    #[test]
    fn test_alpaca_error_is_retryable() {
        let throttled = AlpacaError::rate_limit(60);
        assert!(throttled.is_retryable());
        assert_eq!(throttled.retry_after(), Some(60));

        assert!(AlpacaError::Network("connection reset".to_string()).is_retryable());
        assert!(!AlpacaError::Auth("invalid key".to_string()).is_retryable());
    }

    #[test]
    fn test_alpaca_error_api_with_details() {
        let missing = AlpacaError::api_with_details(
            404,
            "order not found",
            ApiErrorCode::NotFound,
            Some("req-123".to_string()),
        );

        assert_eq!(missing.status_code(), Some(404));
        assert_eq!(missing.request_id(), Some("req-123"));
        assert!(!missing.is_retryable());
    }

    #[test]
    fn test_api_error_response() {
        let body = ApiErrorResponse::new(40_410_000, "not found").with_request_id("req-456");

        assert_eq!(body.error_code(), ApiErrorCode::NotFound);
        assert_eq!(body.request_id.as_deref(), Some("req-456"));
    }
}
482}