Skip to main content

openwire_core/
error.rs

1use std::borrow::Cow;
2use std::error::Error as StdError;
3use std::fmt;
4use std::net::SocketAddr;
5use std::sync::Arc;
6
7use http::uri::Authority;
8use http::{StatusCode, Uri};
9
/// Shared, thread-safe handle to an arbitrary underlying error.
///
/// `Arc` (rather than `Box`) keeps the handle cheap to clone, which the
/// `#[derive(Clone)]` on `WireError` requires for its `source` field.
pub type BoxError = Arc<dyn StdError + Sync + Send>;
11
/// Coarse classification of what went wrong with a request.
///
/// Each kind maps to a default [`FailurePhase`] when an error is created;
/// the phase can be refined afterwards.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WireErrorKind {
    /// The request itself was malformed (bad URI, builder failure, ...).
    InvalidRequest,
    /// An operation exceeded its deadline.
    Timeout,
    /// The request was canceled before completing.
    Canceled,
    /// Host name resolution failed.
    Dns,
    /// A connection (or proxy tunnel / route) could not be established.
    Connect,
    /// TLS setup failed.
    Tls,
    /// The HTTP protocol exchange failed.
    Protocol,
    /// Redirect handling failed.
    Redirect,
    /// Reading or writing a message body failed.
    Body,
    /// A user-supplied interceptor reported an error.
    Interceptor,
    /// An unexpected internal failure.
    Internal,
}
26
27impl fmt::Display for WireErrorKind {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        let label = match self {
30            Self::InvalidRequest => "invalid request",
31            Self::Timeout => "timeout",
32            Self::Canceled => "canceled",
33            Self::Dns => "dns",
34            Self::Connect => "connect",
35            Self::Tls => "tls",
36            Self::Protocol => "protocol",
37            Self::Redirect => "redirect",
38            Self::Body => "body",
39            Self::Interceptor => "interceptor",
40            Self::Internal => "internal",
41        };
42
43        f.write_str(label)
44    }
45}
46
/// The point in a request's lifecycle at which a failure was observed.
///
/// Set from the error kind by default, overridden by establishment context,
/// or set explicitly with `WireError::with_phase`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FailurePhase {
    /// Validating / building the request before sending it.
    RequestValidation,
    /// Admission before dispatch (NOTE(review): exact meaning not visible here — confirm).
    Admission,
    /// Host name resolution.
    Dns,
    /// Opening the TCP connection.
    Tcp,
    /// Establishing a tunnel through a proxy.
    ProxyTunnel,
    /// TLS handshake.
    Tls,
    /// Binding the transport to an HTTP protocol.
    ProtocolBinding,
    /// Sending the request / exchanging messages.
    RequestExchange,
    /// Receiving the response headers.
    ResponseHeaders,
    /// Receiving the response body.
    ResponseBody,
    /// Client policy handling (the default phase for redirect errors).
    Policy,
    /// Running an interceptor.
    Interceptor,
    /// Internal library failure.
    Internal,
}
63
64impl fmt::Display for FailurePhase {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        let label = match self {
67            Self::RequestValidation => "request_validation",
68            Self::Admission => "admission",
69            Self::Dns => "dns",
70            Self::Tcp => "tcp",
71            Self::ProxyTunnel => "proxy_tunnel",
72            Self::Tls => "tls",
73            Self::ProtocolBinding => "protocol_binding",
74            Self::RequestExchange => "request_exchange",
75            Self::ResponseHeaders => "response_headers",
76            Self::ResponseBody => "response_body",
77            Self::Policy => "policy",
78            Self::Interceptor => "interceptor",
79            Self::Internal => "internal",
80        };
81
82        f.write_str(label)
83    }
84}
85
/// Connection-establishment step at which a failure occurred.
///
/// Attached to a `WireError` via `WireError::with_establishment` and exposed
/// through `WireError::establishment_stage`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EstablishmentStage {
    /// Resolving the target host name.
    Dns,
    /// Opening the TCP connection.
    Tcp,
    /// Performing the TLS handshake.
    Tls,
    /// Binding the established transport to an HTTP protocol.
    ProtocolBinding,
    /// Establishing a tunnel through a proxy.
    ProxyTunnel,
    /// No route left to try; reported against the TCP failure phase.
    RouteExhausted,
}

impl fmt::Display for EstablishmentStage {
    /// Writes a snake_case label, consistent with the `Display` output of the
    /// other public label enums in this module (e.g. `FailurePhase`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Dns => "dns",
            Self::Tcp => "tcp",
            Self::Tls => "tls",
            Self::ProtocolBinding => "protocol_binding",
            Self::ProxyTunnel => "proxy_tunnel",
            Self::RouteExhausted => "route_exhausted",
        };

        f.write_str(label)
    }
}
95
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
97struct EstablishmentContext {
98    stage: EstablishmentStage,
99    retryable: bool,
100    connect_timeout: bool,
101}
102
103#[derive(Debug, Clone, Default)]
104pub struct WireErrorDiagnostics {
105    authority: Option<Authority>,
106    proxy_addr: Option<SocketAddr>,
107    response_status: Option<StatusCode>,
108    request_committed: bool,
109}
110
111impl WireErrorDiagnostics {
112    pub fn authority(&self) -> Option<&Authority> {
113        self.authority.as_ref()
114    }
115
116    pub fn proxy_addr(&self) -> Option<SocketAddr> {
117        self.proxy_addr
118    }
119
120    pub fn response_status(&self) -> Option<StatusCode> {
121        self.response_status
122    }
123
124    pub fn request_committed(&self) -> bool {
125        self.request_committed
126    }
127}
128
129#[derive(Debug, Clone)]
130pub struct WireError {
131    kind: WireErrorKind,
132    phase: FailurePhase,
133    message: Cow<'static, str>,
134    diagnostics: WireErrorDiagnostics,
135    establishment: Option<EstablishmentContext>,
136    source: Option<BoxError>,
137}
138
139impl fmt::Display for WireError {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        write!(f, "{}: {}", self.kind, self.message)?;
142        if let Some(source) = &self.source {
143            write!(f, ": {source}")?;
144        }
145        Ok(())
146    }
147}
148
149impl StdError for WireError {
150    fn source(&self) -> Option<&(dyn StdError + 'static)> {
151        self.source
152            .as_deref()
153            .map(|source| source as &(dyn StdError + 'static))
154    }
155}
156
157impl WireError {
158    pub fn new(kind: WireErrorKind, message: impl Into<Cow<'static, str>>) -> Self {
159        Self {
160            kind,
161            phase: default_phase(kind),
162            message: message.into(),
163            diagnostics: WireErrorDiagnostics::default(),
164            establishment: None,
165            source: None,
166        }
167    }
168
169    pub fn with_source<E>(
170        kind: WireErrorKind,
171        message: impl Into<Cow<'static, str>>,
172        source: E,
173    ) -> Self
174    where
175        E: StdError + Send + Sync + 'static,
176    {
177        Self {
178            kind,
179            phase: default_phase(kind),
180            message: message.into(),
181            diagnostics: WireErrorDiagnostics::default(),
182            establishment: None,
183            source: Some(Arc::new(source)),
184        }
185    }
186
187    pub fn kind(&self) -> WireErrorKind {
188        self.kind
189    }
190
191    pub fn message(&self) -> &str {
192        self.message.as_ref()
193    }
194
195    pub fn phase(&self) -> FailurePhase {
196        self.phase
197    }
198
199    pub fn diagnostics(&self) -> &WireErrorDiagnostics {
200        &self.diagnostics
201    }
202
203    pub fn authority(&self) -> Option<&Authority> {
204        self.diagnostics.authority()
205    }
206
207    pub fn proxy_addr(&self) -> Option<SocketAddr> {
208        self.diagnostics.proxy_addr()
209    }
210
211    pub fn response_status(&self) -> Option<StatusCode> {
212        self.diagnostics.response_status()
213    }
214
215    pub fn request_committed(&self) -> bool {
216        self.diagnostics.request_committed()
217    }
218
219    pub fn establishment_stage(&self) -> Option<EstablishmentStage> {
220        self.establishment.map(|context| context.stage)
221    }
222
223    pub fn is_retryable_establishment(&self) -> bool {
224        self.establishment.is_some_and(|context| context.retryable)
225    }
226
227    pub fn is_connect_timeout(&self) -> bool {
228        self.establishment
229            .is_some_and(|context| context.connect_timeout)
230    }
231
232    pub fn is_non_retryable_connect(&self) -> bool {
233        self.establishment
234            .is_some_and(|context| context.stage == EstablishmentStage::Tcp && !context.retryable)
235    }
236
237    pub fn invalid_request(message: impl Into<Cow<'static, str>>) -> Self {
238        Self::new(WireErrorKind::InvalidRequest, message)
239    }
240
241    pub fn timeout(message: impl Into<Cow<'static, str>>) -> Self {
242        Self::new(WireErrorKind::Timeout, message)
243    }
244
245    pub fn body_timeout(message: impl Into<Cow<'static, str>>) -> Self {
246        Self::new(WireErrorKind::Timeout, message)
247            .with_phase(FailurePhase::ResponseBody)
248            .with_request_committed()
249    }
250
251    pub fn connect_timeout(message: impl Into<Cow<'static, str>>) -> Self {
252        Self::new(WireErrorKind::Timeout, message)
253            .with_establishment(EstablishmentStage::Tcp, true)
254            .with_connect_timeout()
255    }
256
257    pub fn canceled(message: impl Into<Cow<'static, str>>) -> Self {
258        Self::new(WireErrorKind::Canceled, message)
259    }
260
261    pub fn dns<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
262    where
263        E: StdError + Send + Sync + 'static,
264    {
265        Self::with_source(WireErrorKind::Dns, message, source)
266            .with_establishment(EstablishmentStage::Dns, true)
267    }
268
269    pub fn connect<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
270    where
271        E: StdError + Send + Sync + 'static,
272    {
273        Self::with_source(WireErrorKind::Connect, message, source)
274    }
275
276    pub fn tcp_connect<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
277    where
278        E: StdError + Send + Sync + 'static,
279    {
280        Self::with_source(WireErrorKind::Connect, message, source)
281            .with_establishment(EstablishmentStage::Tcp, true)
282    }
283
284    pub fn connect_non_retryable(message: impl Into<Cow<'static, str>>) -> Self {
285        Self::new(WireErrorKind::Connect, message)
286            .with_establishment(EstablishmentStage::Tcp, false)
287    }
288
289    pub fn tls<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
290    where
291        E: StdError + Send + Sync + 'static,
292    {
293        Self::with_source(WireErrorKind::Tls, message, source)
294            .with_establishment(EstablishmentStage::Tls, true)
295    }
296
297    pub fn tls_non_retryable<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
298    where
299        E: StdError + Send + Sync + 'static,
300    {
301        Self::with_source(WireErrorKind::Tls, message, source)
302            .with_establishment(EstablishmentStage::Tls, false)
303    }
304
305    pub fn protocol<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
306    where
307        E: StdError + Send + Sync + 'static,
308    {
309        Self::with_source(WireErrorKind::Protocol, message, source)
310    }
311
312    pub fn protocol_binding<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
313    where
314        E: StdError + Send + Sync + 'static,
315    {
316        Self::with_source(WireErrorKind::Protocol, message, source)
317            .with_establishment(EstablishmentStage::ProtocolBinding, true)
318    }
319
320    pub fn proxy_tunnel<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
321    where
322        E: StdError + Send + Sync + 'static,
323    {
324        Self::with_source(WireErrorKind::Connect, message, source)
325            .with_establishment(EstablishmentStage::ProxyTunnel, true)
326    }
327
328    pub fn proxy_tunnel_non_retryable(message: impl Into<Cow<'static, str>>) -> Self {
329        Self::new(WireErrorKind::Connect, message)
330            .with_establishment(EstablishmentStage::ProxyTunnel, false)
331    }
332
333    pub fn route_exhausted(message: impl Into<Cow<'static, str>>) -> Self {
334        Self::new(WireErrorKind::Connect, message)
335            .with_establishment(EstablishmentStage::RouteExhausted, true)
336    }
337
338    pub fn redirect(message: impl Into<Cow<'static, str>>) -> Self {
339        Self::new(WireErrorKind::Redirect, message)
340    }
341
342    pub fn body<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
343    where
344        E: StdError + Send + Sync + 'static,
345    {
346        Self::with_source(WireErrorKind::Body, message, source).with_request_committed()
347    }
348
349    pub fn interceptor<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
350    where
351        E: StdError + Send + Sync + 'static,
352    {
353        Self::with_source(WireErrorKind::Interceptor, message, source)
354    }
355
356    pub fn internal<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
357    where
358        E: StdError + Send + Sync + 'static,
359    {
360        Self::with_source(WireErrorKind::Internal, message, source)
361    }
362
363    pub fn with_phase(mut self, phase: FailurePhase) -> Self {
364        self.phase = phase;
365        self
366    }
367
368    pub fn with_authority(mut self, authority: Authority) -> Self {
369        self.diagnostics.authority = Some(authority);
370        self
371    }
372
373    pub fn with_authority_from_uri(mut self, uri: &Uri) -> Self {
374        if let Some(authority) = uri.authority().cloned() {
375            self.diagnostics.authority = Some(authority);
376        }
377        self
378    }
379
380    pub fn with_proxy_addr(mut self, proxy_addr: SocketAddr) -> Self {
381        self.diagnostics.proxy_addr = Some(proxy_addr);
382        self
383    }
384
385    pub fn with_response_status(mut self, response_status: StatusCode) -> Self {
386        self.diagnostics.response_status = Some(response_status);
387        self
388    }
389
390    pub fn with_request_committed(mut self) -> Self {
391        self.diagnostics.request_committed = true;
392        self
393    }
394
395    pub fn with_establishment(mut self, stage: EstablishmentStage, retryable: bool) -> Self {
396        self.phase = phase_for_establishment(stage);
397        self.establishment = Some(EstablishmentContext {
398            stage,
399            retryable,
400            connect_timeout: false,
401        });
402        self
403    }
404
405    pub fn with_connect_timeout(mut self) -> Self {
406        if let Some(establishment) = &mut self.establishment {
407            establishment.connect_timeout = true;
408        }
409        self
410    }
411}
412
413impl From<http::Error> for WireError {
414    fn from(source: http::Error) -> Self {
415        Self::with_source(
416            WireErrorKind::InvalidRequest,
417            "failed to build HTTP request",
418            source,
419        )
420    }
421}
422
423impl From<http::uri::InvalidUri> for WireError {
424    fn from(source: http::uri::InvalidUri) -> Self {
425        Self::with_source(WireErrorKind::InvalidRequest, "invalid URI", source)
426    }
427}
428
429impl From<hyper::Error> for WireError {
430    fn from(source: hyper::Error) -> Self {
431        if source.is_canceled() {
432            return Self::with_source(WireErrorKind::Canceled, "request canceled", source);
433        }
434
435        if source.is_timeout() {
436            return Self::with_source(WireErrorKind::Timeout, "request timed out", source);
437        }
438
439        Self::with_source(WireErrorKind::Protocol, "HTTP protocol error", source)
440    }
441}
442
443fn default_phase(kind: WireErrorKind) -> FailurePhase {
444    match kind {
445        WireErrorKind::InvalidRequest => FailurePhase::RequestValidation,
446        WireErrorKind::Timeout => FailurePhase::RequestExchange,
447        WireErrorKind::Canceled => FailurePhase::RequestExchange,
448        WireErrorKind::Dns => FailurePhase::Dns,
449        WireErrorKind::Connect => FailurePhase::Tcp,
450        WireErrorKind::Tls => FailurePhase::Tls,
451        WireErrorKind::Protocol => FailurePhase::RequestExchange,
452        WireErrorKind::Redirect => FailurePhase::Policy,
453        WireErrorKind::Body => FailurePhase::ResponseBody,
454        WireErrorKind::Interceptor => FailurePhase::Interceptor,
455        WireErrorKind::Internal => FailurePhase::Internal,
456    }
457}
458
459fn phase_for_establishment(stage: EstablishmentStage) -> FailurePhase {
460    match stage {
461        EstablishmentStage::Dns => FailurePhase::Dns,
462        EstablishmentStage::Tcp | EstablishmentStage::RouteExhausted => FailurePhase::Tcp,
463        EstablishmentStage::Tls => FailurePhase::Tls,
464        EstablishmentStage::ProtocolBinding => FailurePhase::ProtocolBinding,
465        EstablishmentStage::ProxyTunnel => FailurePhase::ProxyTunnel,
466    }
467}
468
#[cfg(test)]
mod tests {
    use std::io::{Error as IoError, ErrorKind as IoErrorKind};

    use http::StatusCode;

    use super::{FailurePhase, WireError};

    /// `Display` should append the source error's message after the kind and message.
    #[test]
    fn display_includes_underlying_source_when_present() {
        let refused = IoError::new(IoErrorKind::ConnectionRefused, "connection refused");
        let err = WireError::connect("TCP connect failed", refused);

        let rendered = err.to_string();
        assert_eq!(rendered, "connect: TCP connect failed: connection refused");
    }

    /// DNS failures carry retryable establishment context and the DNS phase.
    #[test]
    fn dns_errors_are_retryable_establishment_failures_with_dns_phase() {
        let missing = IoError::new(IoErrorKind::NotFound, "not found");
        let err = WireError::dns("DNS resolution failed", missing);

        assert!(err.is_retryable_establishment());
        assert_eq!(err.phase(), FailurePhase::Dns);
    }

    /// Body timeouts are attributed to the response body on a committed request,
    /// and keep any status attached afterwards.
    #[test]
    fn body_timeout_marks_response_body_phase_and_committed_request() {
        let err = WireError::body_timeout("body timed out")
            .with_response_status(StatusCode::OK);

        assert_eq!(err.response_status(), Some(StatusCode::OK));
        assert!(err.request_committed());
        assert_eq!(err.phase(), FailurePhase::ResponseBody);
    }
}