Skip to main content

ntex_error/
lib.rs

1//! Error management.
2#![deny(clippy::pedantic)]
3#![allow(clippy::must_use_candidate)]
4use std::collections::HashMap;
5use std::hash::{BuildHasher, Hasher};
6use std::marker::PhantomData;
7use std::panic::Location;
8use std::{cell::RefCell, error, fmt, fmt::Write, ops, os, ptr, sync::Arc};
9
10use backtrace::{BacktraceFmt, BacktraceFrame, BytesOrWideString};
11
thread_local! {
    // Per-thread cache of resolved frames, keyed by instruction pointer, so a
    // frame's symbols are resolved at most once per thread (see `Backtrace::new`).
    // Nothing in this file removes entries, so the cache only grows.
    static FRAMES: RefCell<HashMap<*mut os::raw::c_void, BacktraceFrame>> = RefCell::new(HashMap::default());
    // Per-thread cache of rendered backtraces, keyed by a hash of the captured
    // instruction-pointer sequence; likewise never pruned in this file.
    static REPRS: RefCell<HashMap<u64, Arc<str>>> = RefCell::new(HashMap::default());
}
/// Optional `(file, line)` marker at which rendered backtraces are cut off on
/// the outer (caller) side. Written by [`set_backtrace_start`].
static mut START: Option<(&'static str, u32)> = None;

/// Registers the `(file, line)` location at which rendered backtraces should end.
///
/// When set, frames further out than the frame matching this location are
/// omitted from backtraces rendered afterwards. For example:
/// `set_backtrace_start(file!(), line!())`.
///
/// Note: the value is kept in a `static mut` that is read without
/// synchronization whenever a backtrace is rendered. Call this once during
/// start-up, before other threads may create errors. Backtraces already
/// rendered and cached are not re-rendered.
pub fn set_backtrace_start(file: &'static str, line: u32) {
    // SAFETY: a plain store of a `Copy` value; callers are expected to invoke
    // this before any concurrent reader exists (see the note above).
    unsafe {
        START = Some((file, line));
    }
}
24
/// Coarse classification of an error by the side responsible for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// Client-side error (named `"ClientError"`).
    Client,
    /// Service-side error (named `"ServiceError"`).
    Service,
}

impl ErrorType {
    /// Returns the static name of this error type.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Client => "ClientError",
            Self::Service => "ServiceError",
        }
    }
}
39
40pub trait ErrorKind: fmt::Display + fmt::Debug + 'static {
41    /// Defines type of the error
42    fn error_type(&self) -> ErrorType;
43}
44
45impl ErrorKind for ErrorType {
46    fn error_type(&self) -> ErrorType {
47        *self
48    }
49}
50
51pub trait ErrorDiagnostic: error::Error + 'static {
52    type Kind: ErrorKind;
53
54    /// Provides specific kind of the error
55    fn kind(&self) -> Self::Kind;
56
57    /// Provides a string to identify responsible service
58    fn service(&self) -> Option<&'static str> {
59        None
60    }
61
62    /// Check if error is service related
63    fn is_service(&self) -> bool {
64        self.kind().error_type() == ErrorType::Service
65    }
66
67    /// Provides a string to identify specific kind of the error
68    fn signature(&self) -> &'static str {
69        self.kind().error_type().as_str()
70    }
71
72    /// Provides error call location
73    fn backtrace(&self) -> Option<&Backtrace> {
74        None
75    }
76
77    #[track_caller]
78    fn chain(self) -> ErrorChain<Self::Kind>
79    where
80        Self: Sized,
81    {
82        ErrorChain::new(self)
83    }
84}
85
/// An error value of type `E`, annotated with an optional responsible service
/// and the backtrace captured when it was created.
///
/// The payload is boxed, so `Error<E>` is a single pointer wide regardless of
/// the size of `E`.
#[derive(Debug, Clone)]
pub struct Error<E> {
    inner: Box<ErrorInner<E>>,
}

/// Heap-allocated payload of [`Error`].
#[derive(Debug, Clone)]
struct ErrorInner<E> {
    /// The wrapped error value.
    error: E,
    /// Service override. When `None`, `ErrorDiagnostic::service` falls back to
    /// the wrapped error's own service.
    service: Option<&'static str>,
    /// Backtrace captured at construction, or copied from the wrapped error
    /// when it already carried one.
    backtrace: Backtrace,
}
97
98impl<E> Error<E> {
99    #[track_caller]
100    pub fn new<T>(error: T, service: &'static str) -> Self
101    where
102        E: ErrorDiagnostic,
103        E: From<T>,
104    {
105        let error = E::from(error);
106        let backtrace = if let Some(bt) = error.backtrace() {
107            bt.clone()
108        } else {
109            Backtrace::new(Location::caller())
110        };
111        Self {
112            inner: Box::new(ErrorInner {
113                error,
114                backtrace,
115                service: Some(service),
116            }),
117        }
118    }
119
120    #[must_use]
121    /// Set response service
122    pub fn set_service(mut self, name: &'static str) -> Self {
123        self.inner.service = Some(name);
124        self
125    }
126
127    /// Map inner error to new error
128    ///
129    /// Keep same `service` and `location`
130    pub fn map<U, F>(self, f: F) -> Error<U>
131    where
132        F: FnOnce(E) -> U,
133    {
134        Error {
135            inner: Box::new(ErrorInner {
136                error: f(self.inner.error),
137                service: self.inner.service,
138                backtrace: self.inner.backtrace,
139            }),
140        }
141    }
142
143    /// Get inner error value
144    pub fn into_error(self) -> E {
145        self.inner.error
146    }
147}
148
149impl<E: ErrorDiagnostic> From<E> for Error<E> {
150    #[track_caller]
151    fn from(error: E) -> Self {
152        let backtrace = if let Some(bt) = error.backtrace() {
153            bt.clone()
154        } else {
155            Backtrace::new(Location::caller())
156        };
157        Self {
158            inner: Box::new(ErrorInner {
159                error,
160                backtrace,
161                service: None,
162            }),
163        }
164    }
165}
166
167impl<E> Eq for Error<E> where E: Eq {}
168
169impl<E> PartialEq for Error<E>
170where
171    E: PartialEq,
172{
173    fn eq(&self, other: &Self) -> bool {
174        self.inner.error.eq(&other.inner.error)
175    }
176}
177
178impl<E> PartialEq<E> for Error<E>
179where
180    E: PartialEq,
181{
182    fn eq(&self, other: &E) -> bool {
183        self.inner.error.eq(other)
184    }
185}
186
187impl<E> ops::Deref for Error<E> {
188    type Target = E;
189
190    fn deref(&self) -> &E {
191        &self.inner.error
192    }
193}
194
195impl<E> error::Error for Error<E>
196where
197    E: ErrorDiagnostic,
198{
199    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
200        Some(&self.inner.error)
201    }
202}
203
204impl<E> ErrorDiagnostic for Error<E>
205where
206    E: ErrorDiagnostic,
207{
208    type Kind = E::Kind;
209
210    fn kind(&self) -> Self::Kind {
211        self.inner.error.kind()
212    }
213
214    fn service(&self) -> Option<&'static str> {
215        if self.inner.service.is_some() {
216            self.inner.service
217        } else {
218            self.inner.error.service()
219        }
220    }
221
222    fn signature(&self) -> &'static str {
223        self.inner.error.signature()
224    }
225
226    fn backtrace(&self) -> Option<&Backtrace> {
227        Some(&self.inner.backtrace)
228    }
229}
230
/// Type-erased error whose kind is exposed as `K`.
///
/// Built with [`ErrorChain::new`], [`ErrorDiagnostic::chain`] or
/// `From<Error<E>>`. The concrete error lives behind a shared `Arc`, so
/// cloning only bumps a reference count.
#[derive(Debug, Clone)]
pub struct ErrorChain<K: ErrorKind> {
    error: Arc<dyn ErrorDiagnostic<Kind = K>>,
}
235
236impl<K: ErrorKind> ErrorChain<K> {
237    #[track_caller]
238    pub fn new<E>(error: E) -> Self
239    where
240        E: ErrorDiagnostic + Sized,
241        E::Kind: Into<K>,
242    {
243        let service = error.service();
244        let backtrace = if let Some(bt) = error.backtrace() {
245            bt.clone()
246        } else {
247            Backtrace::new(Location::caller())
248        };
249
250        Self {
251            error: Arc::new(ErrorChainWrapper {
252                error,
253                service,
254                backtrace,
255                _k: PhantomData,
256            }),
257        }
258    }
259}
260
261impl<E, K> From<Error<E>> for ErrorChain<K>
262where
263    E: ErrorDiagnostic + Sized,
264    E::Kind: Into<K>,
265    K: ErrorKind,
266{
267    fn from(err: Error<E>) -> Self {
268        Self {
269            error: Arc::new(ErrorChainWrapper {
270                service: err.service(),
271                error: err.inner.error,
272                backtrace: err.inner.backtrace,
273                _k: PhantomData,
274            }),
275        }
276    }
277}
278
279impl<K> error::Error for ErrorChain<K>
280where
281    K: ErrorKind,
282{
283    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
284        self.error.source()
285    }
286}
287
288impl<K> ErrorDiagnostic for ErrorChain<K>
289where
290    K: ErrorKind,
291{
292    type Kind = K;
293
294    fn kind(&self) -> Self::Kind {
295        self.error.kind()
296    }
297
298    fn service(&self) -> Option<&'static str> {
299        self.error.service()
300    }
301
302    fn signature(&self) -> &'static str {
303        self.error.signature()
304    }
305
306    fn backtrace(&self) -> Option<&Backtrace> {
307        self.error.backtrace()
308    }
309}
310
311struct ErrorChainWrapper<E: Sized, K> {
312    error: E,
313    service: Option<&'static str>,
314    backtrace: Backtrace,
315    _k: PhantomData<K>,
316}
317
318impl<E, K> error::Error for ErrorChainWrapper<E, K>
319where
320    E: ErrorDiagnostic,
321{
322    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
323        Some(&self.error)
324    }
325}
326
327impl<E, K> ErrorDiagnostic for ErrorChainWrapper<E, K>
328where
329    E: ErrorDiagnostic,
330    E::Kind: Into<K>,
331    K: ErrorKind,
332{
333    type Kind = K;
334
335    fn kind(&self) -> Self::Kind {
336        self.error.kind().into()
337    }
338
339    fn service(&self) -> Option<&'static str> {
340        self.service
341    }
342
343    fn signature(&self) -> &'static str {
344        self.error.signature()
345    }
346
347    fn backtrace(&self) -> Option<&Backtrace> {
348        Some(&self.backtrace)
349    }
350}
351
#[derive(Clone)]
/// A rendered backtrace.
///
/// Holds only the text produced when the value was created. At that point the
/// stack is captured, trimmed so it begins at the caller location (and, if one
/// was registered via `set_backtrace_start`, ends at that start location), and
/// formatted once. Renderings are cached per thread. Cloning shares the same
/// string through an `Arc`.
pub struct Backtrace(Arc<str>);
359
impl Backtrace {
    /// Captures the current stack and returns its rendered representation.
    ///
    /// `loc` is the caller location. Frames preceding the first frame that
    /// resolves to `loc` are dropped (see `find_loc`). If a start location is
    /// registered in `START`, frames beyond it are dropped too (see
    /// `find_loc_start`). At most 128 frames are captured.
    ///
    /// Resolved frames are cached per thread in `FRAMES`, keyed by instruction
    /// pointer. Finished renderings are cached per thread in `REPRS`, keyed by a
    /// hash of the captured instruction-pointer sequence, so repeated captures
    /// from the same call stack share one `Arc<str>`.
    fn new(loc: &Location<'_>) -> Self {
        let repr = FRAMES.with(|c| {
            let mut cache = c.borrow_mut();
            let mut idx = 0;
            // Hash of the instruction-pointer sequence; becomes the `REPRS` key.
            let mut st = foldhash::fast::FixedState::default().build_hasher();
            // Instruction pointers of the captured frames; unused slots stay null.
            let mut idxs: [*mut os::raw::c_void; 128] = [ptr::null_mut(); 128];

            backtrace::trace(|frm| {
                let ip = frm.ip();
                st.write_usize(ip as usize);
                // Symbols are resolved only the first time an IP is seen on this thread.
                cache.entry(ip).or_insert_with(|| {
                    let mut f = BacktraceFrame::from(frm.clone());
                    f.resolve();
                    f
                });
                idxs[idx] = ip;
                idx += 1;

                // Returning `false` stops the walk once all 128 slots are filled.
                idx < 128
            });

            let id = st.finish();

            REPRS.with(|r| {
                let mut reprs = r.borrow_mut();
                if let Some(repr) = reprs.get(&id) {
                    repr.clone()
                } else {
                    // Cache miss: gather the resolved frames in capture order.
                    // Null slots (unused, or a frame whose IP is null) stay `None`.
                    let mut frames: [Option<&BacktraceFrame>; 128] = [None; 128];
                    for (idx, ip) in idxs.as_ref().iter().enumerate() {
                        if !ip.is_null() {
                            frames[idx] = Some(&cache[ip]);
                        }
                    }

                    // Drop the frames before the caller location (e.g. this
                    // function and the error constructors that called it).
                    find_loc(loc, &mut frames);

                    // NOTE(review): unsynchronized read of `START`, relying on
                    // `set_backtrace_start` not running concurrently. `START` is
                    // not part of the cache key, so renderings cached before it
                    // changes keep the old trimming.
                    #[allow(static_mut_refs)]
                    if let Some(start) = unsafe { START } {
                        find_loc_start(start, &mut frames);
                    }

                    let bt = Bt(&frames[..]);
                    let mut buf = String::new();
                    // `String` never fails as a sink; any error would come from
                    // `Bt`'s formatting and is deliberately ignored here.
                    let _ = write!(&mut buf, "\n{bt:?}");
                    let repr: Arc<str> = Arc::from(buf);
                    reprs.insert(id, repr.clone());
                    repr
                }
            })
        });

        Self(repr)
    }

    /// Returns the rendered backtrace text (it starts with a newline).
    pub fn repr(&self) -> &str {
        &self.0
    }
}
421
422fn find_loc(loc: &Location<'_>, frames: &mut [Option<&BacktraceFrame>]) {
423    for (idx, frm) in frames.iter_mut().enumerate() {
424        if let Some(f) = frm {
425            for sym in f.symbols() {
426                if let Some(fname) = sym.filename()
427                    && let Some(lineno) = sym.lineno()
428                    && fname.ends_with(loc.file())
429                    && lineno == loc.line()
430                {
431                    for f in frames.iter_mut().take(idx) {
432                        *f = None;
433                    }
434                    return;
435                }
436            }
437        } else {
438            break;
439        }
440    }
441}
442
443fn find_loc_start(loc: (&str, u32), frames: &mut [Option<&BacktraceFrame>]) {
444    let mut idx = frames.len();
445    while idx > 0 {
446        idx -= 1;
447        if let Some(frm) = &frames[idx] {
448            for sym in frm.symbols() {
449                if let Some(fname) = sym.filename()
450                    && let Some(lineno) = sym.lineno()
451                    && fname.ends_with(loc.0)
452                    && lineno == loc.1
453                {
454                    for f in frames.iter_mut().skip(idx + 1) {
455                        if f.is_some() {
456                            *f = None;
457                        } else {
458                            return;
459                        }
460                    }
461                }
462            }
463        }
464    }
465}
466
467struct Bt<'a>(&'a [Option<&'a BacktraceFrame>]);
468
469impl fmt::Debug for Bt<'_> {
470    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
471        let cwd = std::env::current_dir();
472        let mut print_path =
473            move |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
474                let path = path.into_path_buf();
475                if let Ok(cwd) = &cwd
476                    && let Ok(suffix) = path.strip_prefix(cwd)
477                {
478                    return fmt::Display::fmt(&suffix.display(), fmt);
479                }
480                fmt::Display::fmt(&path.display(), fmt)
481            };
482
483        let mut f = BacktraceFmt::new(fmt, backtrace::PrintFmt::Short, &mut print_path);
484        f.add_context()?;
485        for frm in self.0.iter().flatten() {
486            f.frame().backtrace_frame(frm)?;
487        }
488        f.finish()?;
489        Ok(())
490    }
491}
492
493impl fmt::Debug for Backtrace {
494    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
495        fmt::Display::fmt(&self.0, f)
496    }
497}
498
499impl fmt::Display for Backtrace {
500    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
501        fmt::Display::fmt(&self.0, f)
502    }
503}
504
505impl<E> fmt::Display for Error<E>
506where
507    E: ErrorDiagnostic,
508{
509    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510        fmt::Display::fmt(&self.inner.error, f)
511    }
512}
513
514impl<K> fmt::Display for ErrorChain<K>
515where
516    K: ErrorKind,
517{
518    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
519        fmt::Display::fmt(&self.error, f)
520    }
521}
522
523impl<E, K> fmt::Display for ErrorChainWrapper<E, K>
524where
525    E: error::Error,
526{
527    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
528        fmt::Display::fmt(&self.error, f)
529    }
530}
531
532impl<E, K> fmt::Debug for ErrorChainWrapper<E, K>
533where
534    E: error::Error,
535{
536    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
537        fmt::Debug::fmt(&self.error, f)
538    }
539}
540
541impl fmt::Display for ErrorType {
542    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
543        match self {
544            ErrorType::Client => write!(f, "ClientError"),
545            ErrorType::Service => write!(f, "ServiceError"),
546        }
547    }
548}
549
// NOTE(review): the reason for `dead_code` is not visible here; presumably some
// test fixture items are never read outside derived impls. Confirm whether it
// is still needed.
#[allow(dead_code)]
#[cfg(test)]
mod tests {
    use std::mem;

    use super::*;

    #[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
    enum TestKind {
        #[error("Connect")]
        Connect,
        #[error("Disconnect")]
        Disconnect,
        #[error("ServiceError")]
        ServiceError,
    }

    impl ErrorKind for TestKind {
        fn error_type(&self) -> ErrorType {
            match self {
                TestKind::Connect | TestKind::Disconnect => ErrorType::Client,
                TestKind::ServiceError => ErrorType::Service,
            }
        }
    }

    #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
    enum TestError {
        #[error("Connect err: {0}")]
        Connect(&'static str),
        #[error("Disconnect")]
        Disconnect,
        #[error("InternalServiceError")]
        Service(&'static str),
    }

    impl ErrorDiagnostic for TestError {
        type Kind = TestKind;

        fn kind(&self) -> Self::Kind {
            match self {
                TestError::Connect(_) => TestKind::Connect,
                TestError::Disconnect => TestKind::Disconnect,
                TestError::Service(_) => TestKind::ServiceError,
            }
        }

        fn service(&self) -> Option<&'static str> {
            Some("test")
        }

        fn signature(&self) -> &'static str {
            match self {
                TestError::Connect(_) => "Client-Connect",
                TestError::Disconnect => "Client-Disconnect",
                TestError::Service(_) => "Service-Internal",
            }
        }
    }

    #[test]
    fn test_error() {
        let err: Error<TestError> = TestError::Service("409 Error").into();
        assert_eq!(err.kind(), TestKind::ServiceError);
        assert_eq!((*err).kind(), TestKind::ServiceError);
        assert_eq!(err.to_string(), "InternalServiceError");
        assert_eq!(err.service(), Some("test"));
        assert_eq!(
            err,
            Into::<Error<TestError>>::into(TestError::Service("409 Error"))
        );
        assert!(err.backtrace().is_some());
        assert!(err.is_service());

        let err = err.set_service("SVC");
        assert_eq!(err.service(), Some("SVC"));

        assert_eq!(
            TestError::Connect("").kind().error_type(),
            ErrorType::Client
        );
        assert_eq!(TestError::Disconnect.kind().error_type(), ErrorType::Client);
        assert_eq!(
            TestError::Service("").kind().error_type(),
            ErrorType::Service
        );
        assert_eq!(TestError::Connect("").to_string(), "Connect err: ");
        assert_eq!(TestError::Disconnect.to_string(), "Disconnect");
        assert_eq!(TestError::Disconnect.service(), Some("test"));
        assert!(TestError::Disconnect.backtrace().is_none());

        assert_eq!(ErrorType::Client.as_str(), "ClientError");
        assert_eq!(ErrorType::Service.as_str(), "ServiceError");
        assert_eq!(ErrorType::Client.error_type(), ErrorType::Client);
        assert_eq!(ErrorType::Service.error_type(), ErrorType::Service);
        assert_eq!(ErrorType::Client.to_string(), "ClientError");
        assert_eq!(ErrorType::Service.to_string(), "ServiceError");

        assert_eq!(TestKind::Connect.to_string(), "Connect");
        assert_eq!(TestError::Connect("").signature(), "Client-Connect");
        assert_eq!(TestKind::Disconnect.to_string(), "Disconnect");
        assert_eq!(TestError::Disconnect.signature(), "Client-Disconnect");
        assert_eq!(TestKind::ServiceError.to_string(), "ServiceError");
        assert_eq!(TestError::Service("").signature(), "Service-Internal");

        let err = err.into_error().chain();
        assert_eq!(err.kind(), TestKind::ServiceError);
        assert_eq!(err.kind(), TestError::Service("409 Error").kind());
        assert_eq!(err.to_string(), "InternalServiceError");
        assert!(format!("{err:?}").contains("Service(\"409 Error\")"));

        let err: Error<TestError> = TestError::Service("404 Error").into();
        let err: ErrorChain<TestKind> = err.into();
        assert_eq!(err.kind(), TestKind::ServiceError);
        assert_eq!(err.kind(), TestError::Service("404 Error").kind());
        assert_eq!(err.service(), Some("test"));
        assert_eq!(err.signature(), "Service-Internal");
        assert_eq!(err.to_string(), "InternalServiceError");
        assert!(err.backtrace().is_some());
        assert!(format!("{err:?}").contains("Service(\"404 Error\")"));

        assert_eq!(24, mem::size_of::<TestError>());
        assert_eq!(8, mem::size_of::<Error<TestError>>());
    }

    /// `Error::new`, `map` and `chain` must carry the service override and the
    /// originally recorded backtrace along.
    #[test]
    fn test_error_new_map_chain() {
        let err = Error::<TestError>::new(TestError::Disconnect, "svc");
        assert_eq!(err.service(), Some("svc"));
        assert_eq!(err, TestError::Disconnect);
        assert!(!err.is_service());
        assert_eq!(err.signature(), "Client-Disconnect");
        let repr = err.backtrace().map(|bt| bt.repr().to_owned());
        assert!(repr.is_some());

        // `map` keeps the service and the original backtrace.
        let mapped = err.map(|_| TestError::Connect("x"));
        assert_eq!(mapped.service(), Some("svc"));
        assert_eq!(mapped.kind(), TestKind::Connect);
        assert_eq!(mapped.backtrace().map(|bt| bt.repr().to_owned()), repr);

        // Chaining keeps the service override and reuses the backtrace.
        let chained: ErrorChain<TestKind> = mapped.clone().chain();
        assert_eq!(chained.service(), Some("svc"));
        assert_eq!(chained.kind(), TestKind::Connect);
        assert_eq!(chained.signature(), "Client-Connect");
        assert_eq!(chained.to_string(), "Connect err: x");
        assert_eq!(chained.backtrace().map(|bt| bt.repr().to_owned()), repr);

        assert_eq!(mapped.into_error(), TestError::Connect("x"));
    }
}