Skip to main content

pict_rs_error/
lib.rs

1mod code;
2
3use exn::Exn;
4use tracing_error::SpanTrace;
5
6pub use self::code::{ErrorCode, OwnedErrorCode};
7
/// Error signalling that the repo is disconnected.
///
/// `ApplicationError::is_disconnected` searches error trees for a value of this type.
#[derive(Debug)]
pub struct Disconnected;

impl std::error::Error for Disconnected {}

impl std::fmt::Display for Disconnected {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "Repo is disconnected")
    }
}
18
/// Category of failure attached to an [`ApplicationError`].
///
/// When no status is attached, `ApplicationError::get_status` falls back to [`Status::Server`].
/// The derived `Ord` follows declaration order, so the variant order below is significant.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum Status {
    Size,
    Server,
    Client,
    RangeNotSatisfiable,
    Forbidden,
    Unauthorized,
    NotFound,
}
29
30pub struct ApplicationError<E>
31where
32    E: std::error::Error + Send + Sync + 'static,
33{
34    status: Option<Status>,
35    code: Option<ErrorCode>,
36    spantrace: SpanTrace,
37    exn: Exn<E>,
38}
39
40impl<E> std::fmt::Debug for ApplicationError<E>
41where
42    E: std::error::Error + Send + Sync + 'static,
43{
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.write_str("Error: ")?;
46        std::fmt::Display::fmt(&self.root_cause(), f)?;
47        f.write_str("\nCode: ")?;
48        std::fmt::Display::fmt(&self.get_error_code(), f)?;
49        f.write_str("\n\n")?;
50        std::fmt::Debug::fmt(&self.exn, f)?;
51        f.write_str("\n\n")?;
52        std::fmt::Display::fmt(&color_spantrace::colorize(&self.spantrace), f)?;
53        f.write_str("\n")
54    }
55}
56
57pub trait ApplicationResultExt<T, E>
58where
59    E: std::error::Error + Send + Sync + 'static,
60{
61    fn or_raise<F, G>(self, err: F) -> Result<T, ApplicationError<G>>
62    where
63        F: FnOnce() -> G,
64        G: std::error::Error + Send + Sync + 'static;
65
66    fn or_code(self, code: ErrorCode) -> Result<T, ApplicationError<E>>;
67
68    fn or_status(self, status: Status) -> Result<T, ApplicationError<E>>;
69}
70
71impl<T, E> ApplicationResultExt<T, E> for Result<T, ApplicationError<E>>
72where
73    E: std::error::Error + Send + Sync + 'static,
74{
75    #[track_caller]
76    fn or_raise<F, G>(self, err: F) -> Result<T, ApplicationError<G>>
77    where
78        F: FnOnce() -> G,
79        G: std::error::Error + Send + Sync + 'static,
80    {
81        match self {
82            Ok(t) => Ok(t),
83            Err(e) => Err(e.raise(err())),
84        }
85    }
86
87    fn or_code(self, code: ErrorCode) -> Result<T, ApplicationError<E>> {
88        self.map_err(|e| e.error_code(code))
89    }
90
91    fn or_status(self, status: Status) -> Result<T, ApplicationError<E>> {
92        self.map_err(|e| e.status(status))
93    }
94}
95
96impl<T, E> ApplicationResultExt<T, E> for Result<T, E>
97where
98    E: std::error::Error + Send + Sync + 'static,
99{
100    #[track_caller]
101    fn or_raise<F, G>(self, err: F) -> Result<T, ApplicationError<G>>
102    where
103        F: FnOnce() -> G,
104        G: std::error::Error + Send + Sync + 'static,
105    {
106        match self {
107            Ok(t) => Ok(t),
108            Err(e) => Err(ApplicationError::new(e).raise(err())),
109        }
110    }
111
112    #[track_caller]
113    fn or_code(self, code: ErrorCode) -> Result<T, ApplicationError<E>> {
114        match self {
115            Ok(t) => Ok(t),
116            Err(e) => Err(ApplicationError::new(e).error_code(code)),
117        }
118    }
119
120    #[track_caller]
121    fn or_status(self, status: Status) -> Result<T, ApplicationError<E>> {
122        match self {
123            Ok(t) => Ok(t),
124            Err(e) => Err(ApplicationError::new(e).status(status)),
125        }
126    }
127}
128
129impl<T, E> ApplicationResultExt<T, E> for Result<T, Exn<E>>
130where
131    E: std::error::Error + Send + Sync + 'static,
132{
133    #[track_caller]
134    fn or_raise<F, G>(self, err: F) -> Result<T, ApplicationError<G>>
135    where
136        F: FnOnce() -> G,
137        G: std::error::Error + Send + Sync + 'static,
138    {
139        match self {
140            Ok(t) => Ok(t),
141            Err(e) => Err(ApplicationError::wrap(e).raise(err())),
142        }
143    }
144
145    fn or_code(self, code: ErrorCode) -> Result<T, ApplicationError<E>> {
146        match self {
147            Ok(t) => Ok(t),
148            Err(e) => Err(ApplicationError::wrap(e).error_code(code)),
149        }
150    }
151
152    fn or_status(self, status: Status) -> Result<T, ApplicationError<E>> {
153        match self {
154            Ok(t) => Ok(t),
155            Err(e) => Err(ApplicationError::wrap(e).status(status)),
156        }
157    }
158}
159
160impl<E> std::fmt::Display for ApplicationError<E>
161where
162    E: std::error::Error + Send + Sync + 'static,
163{
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        std::fmt::Display::fmt(&*self.exn, f)
166    }
167}
168
/// Outermost error raised when pict-rs fails to run; wrapped by `PictRsError`.
#[derive(Debug)]
pub struct RunError;

impl std::error::Error for RunError {}

impl std::fmt::Display for RunError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "Failed to run pict-rs")
    }
}
179
180#[derive(Debug)]
181pub struct PictRsError {
182    error: ApplicationError<RunError>,
183}
184
185impl std::fmt::Display for PictRsError {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        f.write_str("Error in pict-rs")
188    }
189}
190
191impl std::error::Error for PictRsError {
192    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
193        Some(self.error.error())
194    }
195}
196
197impl PictRsError {
198    /// Get the [`exn::Frame`] representing this error
199    ///
200    /// pict-rs uses [`exn`] to keep track of error nesting and sources, and the frame API can be
201    /// useful for traversing error trees
202    pub fn frame(&self) -> &exn::Frame {
203        self.error.frame()
204    }
205}
206
207impl From<ApplicationError<RunError>> for PictRsError {
208    fn from(value: ApplicationError<RunError>) -> Self {
209        Self { error: value }
210    }
211}
212
213impl From<std::io::Error> for PictRsError {
214    fn from(value: std::io::Error) -> Self {
215        ApplicationError::new(value).raise(RunError).into()
216    }
217}
218
/// Marker error raised by `ApplicationError::io_error` when an error is converted into a
/// [`std::io::Error`]. Lookups such as `find_io_error` search for the resulting
/// `ApplicationError<IoError>`.
#[derive(Debug)]
struct IoError;

impl std::error::Error for IoError {}

impl std::fmt::Display for IoError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "IO Error")
    }
}
229
230impl std::error::Error for ApplicationError<IoError> {
231    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
232        Some(self.error() as _)
233    }
234}
235
236impl<E> ApplicationError<E>
237where
238    E: std::error::Error + Send + Sync + 'static,
239{
240    #[track_caller]
241    pub fn new(error: E) -> Self {
242        Self {
243            status: None,
244            code: None,
245            spantrace: SpanTrace::capture(),
246            exn: Exn::new(error),
247        }
248    }
249
250    #[track_caller]
251    pub fn wrap(exn: Exn<E>) -> Self {
252        Self {
253            status: None,
254            code: None,
255            spantrace: SpanTrace::capture(),
256            exn,
257        }
258    }
259
260    pub fn status(self, status: Status) -> Self {
261        Self {
262            status: Some(status),
263            ..self
264        }
265    }
266
267    pub fn get_status(&self) -> Status {
268        if let Some(status) = self.status {
269            status
270        } else if let Some(io_error) = self.find_io_error() {
271            io_error.get_status()
272        } else {
273            Status::Server
274        }
275    }
276
277    pub fn get_error_code(&self) -> ErrorCode {
278        if let Some(code) = self.code {
279            code
280        } else if let Some(io_error) = self.find_io_error() {
281            io_error.get_error_code()
282        } else {
283            ErrorCode::UNKNOWN_ERROR
284        }
285    }
286
287    pub fn error_code(self, code: ErrorCode) -> Self {
288        Self {
289            code: Some(code),
290            ..self
291        }
292    }
293
294    pub fn io_error(self) -> std::io::Error {
295        fn walk_chain(mut error: &(dyn std::error::Error + 'static)) -> Option<std::io::ErrorKind> {
296            loop {
297                if let Some(e) = error.downcast_ref::<std::io::Error>() {
298                    return Some(e.kind());
299                }
300
301                if let Some(src) = error.source() {
302                    error = src
303                } else {
304                    return None;
305                }
306            }
307        }
308
309        fn walk_tree(frame: &exn::Frame) -> Option<std::io::ErrorKind> {
310            if let Some(kind) = walk_chain(frame.error()) {
311                Some(kind)
312            } else {
313                frame.children().iter().find_map(walk_tree)
314            }
315        }
316
317        let kind = walk_tree(self.exn.frame()).unwrap_or(std::io::ErrorKind::Other);
318
319        std::io::Error::new(kind, self.raise(IoError))
320    }
321
322    #[track_caller]
323    pub fn raise<G>(self, error: G) -> ApplicationError<G>
324    where
325        G: std::error::Error + Send + Sync + 'static,
326    {
327        let ApplicationError {
328            status,
329            code,
330            spantrace,
331            exn,
332        } = self;
333
334        ApplicationError {
335            status,
336            code,
337            spantrace,
338            exn: exn.raise(error),
339        }
340    }
341
342    pub fn root_cause(&self) -> &(dyn std::error::Error + 'static) {
343        fn get_first_root(mut frame: &exn::Frame) -> &(dyn std::error::Error + 'static) {
344            loop {
345                if frame.children().is_empty() {
346                    if let Some(e) = frame.error().downcast_ref::<ApplicationError<IoError>>() {
347                        return get_first_root(e.frame());
348                    } else {
349                        return frame.error();
350                    }
351                } else {
352                    if let Some(e) = frame.error().downcast_ref::<ApplicationError<IoError>>() {
353                        return get_first_root(e.frame());
354                    }
355
356                    frame = &frame.children()[0]
357                }
358            }
359        }
360
361        if let Some(io_error) = self.find_io_error() {
362            io_error.root_cause()
363        } else {
364            get_first_root(self.exn.frame())
365        }
366    }
367
368    fn find_io_error(&self) -> Option<&ApplicationError<IoError>> {
369        fn walk_chain<'a>(
370            mut error: &'a (dyn std::error::Error + 'static),
371        ) -> Option<&'a ApplicationError<IoError>> {
372            loop {
373                if let Some(std_io_error) = error.downcast_ref::<std::io::Error>()
374                    && let Some(io_error) = std_io_error
375                        .get_ref()
376                        .and_then(|e| e.downcast_ref::<ApplicationError<IoError>>())
377                {
378                    return Some(io_error);
379                }
380
381                if let Some(src) = error.source() {
382                    error = src
383                } else {
384                    return None;
385                }
386            }
387        }
388
389        fn walk_tree(frame: &exn::Frame) -> Option<&ApplicationError<IoError>> {
390            if let Some(e) = walk_chain(frame.error()) {
391                Some(e)
392            } else {
393                frame.children().iter().find_map(walk_tree)
394            }
395        }
396
397        walk_chain(&*self.exn).or_else(|| walk_tree(self.exn.frame()))
398    }
399
400    pub fn is_not_found(&self) -> bool {
401        matches!(self.get_status(), Status::NotFound)
402    }
403
404    pub fn error(&self) -> &E {
405        &self.exn
406    }
407
408    pub fn frame(&self) -> &exn::Frame {
409        self.exn.frame()
410    }
411
412    pub fn is_disconnected(&self) -> bool {
413        fn walk_chain(mut error: &(dyn std::error::Error + 'static)) -> bool {
414            loop {
415                if error.is::<Disconnected>() {
416                    return true;
417                }
418
419                if let Some(e) = error.downcast_ref::<ApplicationError<IoError>>()
420                    && walk_tree(e.frame())
421                {
422                    return true;
423                }
424
425                if let Some(src) = error.source() {
426                    error = src
427                } else {
428                    return false;
429                }
430            }
431        }
432
433        fn walk_tree(frame: &exn::Frame) -> bool {
434            if walk_chain(frame.error()) {
435                true
436            } else {
437                frame.children().iter().any(walk_tree)
438            }
439        }
440
441        walk_tree(self.exn.frame())
442    }
443}
444
#[test]
fn finds_io_error() {
    #[derive(Debug)]
    struct RootCause;

    impl std::error::Error for RootCause {}

    impl std::fmt::Display for RootCause {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Root Cause")
        }
    }

    // Round-trip through `std::io::Error`, then wrap again as an application error.
    let converted = ApplicationError::new(RootCause).io_error();
    let outer_error = ApplicationError::new(converted);

    outer_error.find_io_error().expect("IO Error Found");
}
466
#[test]
fn retrieves_nested_root_cause() {
    #[derive(Debug)]
    struct RootCause;

    impl std::error::Error for RootCause {}

    impl std::fmt::Display for RootCause {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Root Cause")
        }
    }

    // The innermost error must survive a round trip through `std::io::Error`.
    let converted = ApplicationError::new(RootCause).io_error();
    let outer_error = ApplicationError::new(converted);

    let expected = RootCause.to_string();
    assert_eq!(outer_error.root_cause().to_string(), expected);
}
488
#[test]
fn retrieves_nested_code() {
    #[derive(Debug)]
    struct RootCause;

    impl std::error::Error for RootCause {}

    impl std::fmt::Display for RootCause {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Root Cause")
        }
    }

    // The outer error has no code, so the inner one must be reported.
    let inner = ApplicationError::new(RootCause).error_code(ErrorCode::PANIC);
    let outer_error = ApplicationError::new(inner.io_error());

    assert_eq!(outer_error.get_error_code(), ErrorCode::PANIC);
}
510
#[test]
fn prefers_outer_code() {
    #[derive(Debug)]
    struct RootCause;

    impl std::error::Error for RootCause {}

    impl std::fmt::Display for RootCause {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Root Cause")
        }
    }

    // An explicitly attached outer code wins over the nested one.
    let inner = ApplicationError::new(RootCause).error_code(ErrorCode::PANIC);
    let outer_error = ApplicationError::new(inner.io_error()).error_code(ErrorCode::PUSH_JOB);

    assert_eq!(outer_error.get_error_code(), ErrorCode::PUSH_JOB);
}
532
#[test]
fn retrieves_nested_status() {
    #[derive(Debug)]
    struct RootCause;

    impl std::error::Error for RootCause {}

    impl std::fmt::Display for RootCause {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Root Cause")
        }
    }

    // The outer error has no status, so the inner one must be reported.
    let inner = ApplicationError::new(RootCause).status(Status::Client);
    let outer_error = ApplicationError::new(inner.io_error());

    assert_eq!(outer_error.get_status(), Status::Client);
}
554
#[test]
fn prefers_outer_status() {
    #[derive(Debug)]
    struct RootCause;

    impl std::error::Error for RootCause {}

    impl std::fmt::Display for RootCause {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Root Cause")
        }
    }

    // An explicitly attached outer status wins over the nested one.
    let inner = ApplicationError::new(RootCause).status(Status::Client);
    let outer_error = ApplicationError::new(inner.io_error()).status(Status::Size);

    assert_eq!(outer_error.get_status(), Status::Size);
}