Skip to main content

alpaca_data/
error.rs

1use std::fmt::{self, Display, Formatter};
2
3use crate::transport::meta::TransportErrorMeta;
4
/// Error type returned by this module's callers.
///
/// Every variant owns its context as plain strings and integers, so values can be
/// cloned and compared with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The configuration was rejected; the payload is the explanation.
    InvalidConfiguration(String),
    /// No credentials were available.
    MissingCredentials,
    /// A transport-level failure, carried as an already formatted message.
    Transport(String),
    /// A timeout, carried as an already formatted message.
    Timeout(String),
    /// The request was rate limited.
    RateLimited {
        /// Endpoint label of the failing request.
        endpoint: &'static str,
        /// Retry hint, when one was reported.
        // NOTE(review): presumably seconds (HTTP `Retry-After`) — confirm against the
        // transport layer that fills this in.
        retry_after: Option<u64>,
        /// Request identifier, when one was reported.
        request_id: Option<String>,
        /// Attempt counter recorded for the request.
        attempt_count: u32,
        /// Response body text, when one was captured.
        body: Option<String>,
    },
    /// The request failed with an HTTP status.
    HttpStatus {
        /// Endpoint label of the failing request.
        endpoint: &'static str,
        /// HTTP status code of the response.
        status: u16,
        /// Request identifier, when one was reported.
        request_id: Option<String>,
        /// Attempt counter recorded for the request.
        attempt_count: u32,
        /// Response body text, when one was captured.
        body: Option<String>,
    },
    /// A payload could not be deserialized; the payload is the explanation.
    Deserialize(String),
    /// The request was invalid; the payload is the explanation.
    InvalidRequest(String),
    /// Pagination failed; the payload is the explanation.
    Pagination(String),
    /// The named operation has no implementation.
    NotImplemented {
        /// Name of the operation.
        operation: &'static str,
    },
}
32
33impl Display for Error {
34    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35        match self {
36            Self::InvalidConfiguration(message) => {
37                write!(f, "invalid configuration: {message}")
38            }
39            Self::MissingCredentials => write!(f, "missing credentials"),
40            Self::Transport(message) => write!(f, "transport error: {message}"),
41            Self::Timeout(message) => write!(f, "timeout error: {message}"),
42            Self::RateLimited {
43                endpoint,
44                retry_after,
45                request_id,
46                attempt_count,
47                body,
48            } => write_transport_error(
49                f,
50                "rate limited",
51                *endpoint,
52                Some(("retry_after", retry_after.map(|value| value.to_string()))),
53                request_id.as_deref(),
54                *attempt_count,
55                body.as_deref(),
56            ),
57            Self::HttpStatus {
58                endpoint,
59                status,
60                request_id,
61                attempt_count,
62                body,
63            } => write_transport_error(
64                f,
65                "http status error",
66                *endpoint,
67                Some(("status", Some(status.to_string()))),
68                request_id.as_deref(),
69                *attempt_count,
70                body.as_deref(),
71            ),
72            Self::Deserialize(message) => write!(f, "deserialize error: {message}"),
73            Self::InvalidRequest(message) => write!(f, "invalid request: {message}"),
74            Self::Pagination(message) => write!(f, "pagination error: {message}"),
75            Self::NotImplemented { operation } => {
76                write!(f, "operation not implemented: {operation}")
77            }
78        }
79    }
80}
81
82impl std::error::Error for Error {}
83
84impl Error {
85    pub(crate) fn from_rate_limited(meta: TransportErrorMeta) -> Self {
86        Self::RateLimited {
87            endpoint: meta.endpoint,
88            retry_after: meta.retry_after,
89            request_id: meta.request_id,
90            attempt_count: meta.attempt_count,
91            body: meta.body,
92        }
93    }
94
95    pub(crate) fn from_http_status(meta: TransportErrorMeta) -> Self {
96        Self::HttpStatus {
97            endpoint: meta.endpoint,
98            status: meta.status,
99            request_id: meta.request_id,
100            attempt_count: meta.attempt_count,
101            body: meta.body,
102        }
103    }
104
105    pub(crate) fn from_reqwest(error: reqwest::Error) -> Self {
106        let message = sanitize_reqwest_error_message(&error.to_string());
107
108        if error.is_timeout() {
109            Self::Timeout(message)
110        } else {
111            Self::Transport(message)
112        }
113    }
114
115    pub fn endpoint(&self) -> Option<&str> {
116        match self {
117            Self::RateLimited { endpoint, .. } | Self::HttpStatus { endpoint, .. } => {
118                Some(endpoint)
119            }
120            _ => None,
121        }
122    }
123
124    pub fn request_id(&self) -> Option<&str> {
125        match self {
126            Self::RateLimited { request_id, .. } | Self::HttpStatus { request_id, .. } => {
127                request_id.as_deref()
128            }
129            _ => None,
130        }
131    }
132}
133
/// Writes a transport-style error as one comma-separated line.
///
/// The output always starts with `<label>: endpoint=<endpoint>` and always contains
/// `attempt_count=<n>`. The optional parts are emitted only when present, in this
/// order: the primary field (before the request id), `request_id`, then — after the
/// attempt count — `body`.
///
/// `primary_field` is a `(name, value)` pair; nothing is written for it when either
/// the pair or its value is `None`.
fn write_transport_error(
    f: &mut Formatter<'_>,
    label: &str,
    endpoint: &'static str,
    primary_field: Option<(&str, Option<String>)>,
    request_id: Option<&str>,
    attempt_count: u32,
    body: Option<&str>,
) -> fmt::Result {
    write!(f, "{label}: endpoint={endpoint}")?;

    match primary_field {
        Some((name, Some(value))) => write!(f, ", {name}={value}")?,
        // A named field without a value is skipped, exactly like a missing field.
        Some((_, None)) | None => {}
    }

    match request_id {
        Some(id) => write!(f, ", request_id={id}")?,
        None => {}
    }

    write!(f, ", attempt_count={attempt_count}")?;

    match body {
        Some(text) => write!(f, ", body={text}"),
        None => Ok(()),
    }
}
161
162fn sanitize_reqwest_error_message(message: &str) -> String {
163    if !message.contains('@') {
164        return message.to_string();
165    }
166
167    let mut sanitized = String::with_capacity(message.len());
168    let mut segment_start = 0;
169    let mut in_segment = false;
170
171    for (index, ch) in message.char_indices() {
172        if is_message_delimiter(ch) {
173            if in_segment {
174                sanitized.push_str(&redact_urlish_userinfo(&message[segment_start..index]));
175                in_segment = false;
176            }
177
178            sanitized.push(ch);
179        } else if !in_segment {
180            segment_start = index;
181            in_segment = true;
182        }
183    }
184
185    if in_segment {
186        sanitized.push_str(&redact_urlish_userinfo(&message[segment_start..]));
187    }
188
189    sanitized
190}
191
/// Reports whether `ch` is treated as a token boundary: any Unicode whitespace, or one
/// of the quote, bracket, comma, and semicolon characters listed below.
fn is_message_delimiter(ch: char) -> bool {
    const PUNCTUATION: [char; 12] = [
        '"', '\'', '(', ')', '[', ']', '{', '}', '<', '>', ',', ';',
    ];

    PUNCTUATION.contains(&ch) || ch.is_whitespace()
}
199
200fn redact_urlish_userinfo(segment: &str) -> String {
201    if let Ok(mut url) = reqwest::Url::parse(segment) {
202        if !url.username().is_empty() || url.password().is_some() {
203            let _ = url.set_username("");
204            let _ = url.set_password(None);
205            return url.to_string();
206        }
207    }
208
209    redact_urlish_userinfo_fallback(segment)
210}
211
/// Text-based userinfo removal for tokens handled without a URL parser.
///
/// The authority is looked for after a `scheme://` prefix, after a leading `//`, or
/// at the very start of `segment`. Everything in the authority up to and including
/// its last `@` is dropped; the prefix and the path/query/fragment are kept verbatim.
/// A token whose authority holds no `@` is returned unchanged.
///
/// A `://` only counts as a scheme separator when the text before it contains none of
/// `@`, `/`, `?`, `#`. Without that check a schemeless token whose query embeds a URL
/// (for example `user:secret@host/cb?next=https://example.com`) would be split at the
/// embedded `://`, no `@` would be found after it, and the leading credentials would
/// pass through unredacted.
fn redact_urlish_userinfo_fallback(segment: &str) -> String {
    let scheme_split = segment.split_once("://").filter(|(scheme, _)| {
        !scheme.contains(|ch: char| matches!(ch, '@' | '/' | '?' | '#'))
    });

    let (prefix, rest) = if let Some((scheme, rest)) = scheme_split {
        // Keep the scheme together with its `://` separator.
        (&segment[..scheme.len() + 3], rest)
    } else if let Some(rest) = segment.strip_prefix("//") {
        ("//", rest)
    } else {
        ("", segment)
    };

    let (authority, suffix) = split_authority_and_suffix(rest);

    match authority.rfind('@') {
        // `@` is a single byte, so `index + 1` is always a valid char boundary.
        Some(index) => format!("{prefix}{}{suffix}", &authority[index + 1..]),
        None => segment.to_string(),
    }
}

/// Splits `rest` at the first `/`, `?`, or `#`.
///
/// Returns `(authority, suffix)`, where `suffix` starts with the character that ended
/// the authority, or is empty when `rest` contains none of those characters.
fn split_authority_and_suffix(rest: &str) -> (&str, &str) {
    match rest.find(|ch: char| matches!(ch, '/' | '?' | '#')) {
        Some(index) => rest.split_at(index),
        None => (rest, ""),
    }
}