Skip to main content

ntex_error/
lib.rs

1//! Error management.
2#![deny(clippy::pedantic)]
3#![allow(clippy::must_use_candidate)]
4use std::collections::HashMap;
5use std::hash::{BuildHasher, Hasher};
6use std::marker::PhantomData;
7use std::panic::Location;
8use std::{cell::RefCell, error, fmt, fmt::Write, ops, os, ptr, sync::Arc};
9
10use backtrace::{BacktraceFmt, BacktraceFrame, BytesOrWideString};
11
thread_local! {
    // Per-thread cache of resolved frames, keyed by instruction pointer, so that
    // symbol resolution runs at most once per address (see `Backtrace::new`).
    // Entries are never removed.
    static FRAMES: RefCell<HashMap<*mut os::raw::c_void, BacktraceFrame>> = RefCell::new(HashMap::default());
    // Per-thread cache of rendered backtrace text, keyed by the hash computed in
    // `Backtrace::new`. Entries are never removed.
    static REPRS: RefCell<HashMap<u64, Arc<str>>> = RefCell::new(HashMap::default());
}
// Optional `(file, line)` markers. When rendering a backtrace, the first frame
// matching a marker is removed together with the populated frames after it
// (see `find_loc_start`); a `line` of `0` matches any line of that file.
// NOTE(review): `static mut` gives no synchronization. Writers
// (`set_backtrace_start*`) racing with readers (`Backtrace::new`) on other
// threads is a data race; a `RwLock`/`OnceLock` would make this sound.
static mut START: Option<(&'static str, u32)> = None;
static mut START_ALT: Option<(&'static str, u32)> = None;
18
/// Register the source location at which rendered backtraces are cut off.
///
/// When a [`Backtrace`] is created, the first captured frame whose symbol file
/// path ends with `file` (and whose line equals `line`, unless `line` is `0`)
/// is dropped, together with the populated frames that follow it.
///
/// Calling this again replaces the previous marker.
///
/// NOTE(review): this is a safe function performing an unsynchronized write to
/// a `static mut`. Calling it while another thread is creating a `Backtrace` is
/// a data race. It appears intended to be called once during start-up; confirm
/// with callers.
pub fn set_backtrace_start(file: &'static str, line: u32) {
    unsafe {
        START = Some((file, line));
    }
}
24
/// Register a second, independent cut-off marker for rendered backtraces.
///
/// Behaves like [`set_backtrace_start`], but stores its marker separately. Both
/// markers are applied when a [`Backtrace`] is rendered.
///
/// NOTE(review): same unsynchronized `static mut` write as
/// `set_backtrace_start`; concurrent use with backtrace capture is a data race.
#[doc(hidden)]
pub fn set_backtrace_start_alt(file: &'static str, line: u32) {
    unsafe {
        START_ALT = Some((file, line));
    }
}
31
/// Coarse classification shared by every error kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// Client-side classification; its canonical name is `"ClientError"`.
    Client,
    /// Service-side classification; its canonical name is `"ServiceError"`.
    Service,
}

impl ErrorType {
    /// Returns the canonical, static name of this classification.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Client => "ClientError",
            Self::Service => "ServiceError",
        }
    }
}
46
/// A fine-grained error classification that maps onto the coarse [`ErrorType`].
///
/// Implementors must be printable (`Display` + `Debug`) and `'static`.
pub trait ErrorKind: fmt::Display + fmt::Debug + 'static {
    /// Returns the coarse [`ErrorType`] this kind belongs to.
    fn error_type(&self) -> ErrorType;
}
51
/// An [`ErrorType`] is its own classification, so it can be used directly as an
/// [`ErrorDiagnostic::Kind`].
impl ErrorKind for ErrorType {
    fn error_type(&self) -> ErrorType {
        *self
    }
}
57
58pub trait ErrorDiagnostic: error::Error + 'static {
59    type Kind: ErrorKind;
60
61    /// Provides specific kind of the error
62    fn kind(&self) -> Self::Kind;
63
64    /// Provides a string to identify responsible service
65    fn service(&self) -> Option<&'static str> {
66        None
67    }
68
69    /// Check if error is service related
70    fn is_service(&self) -> bool {
71        self.kind().error_type() == ErrorType::Service
72    }
73
74    /// Provides a string to identify specific kind of the error
75    fn signature(&self) -> &'static str {
76        self.kind().error_type().as_str()
77    }
78
79    /// Provides error call location
80    fn backtrace(&self) -> Option<&Backtrace> {
81        None
82    }
83
84    #[track_caller]
85    fn chain(self) -> ErrorChain<Self::Kind>
86    where
87        Self: Sized,
88    {
89        ErrorChain::new(self)
90    }
91}
92
/// An error value `E` bundled with an optional service name and a backtrace.
///
/// The payload is boxed, so `Error<E>` itself is a single pointer wide
/// regardless of the size of `E`.
#[derive(Clone)]
pub struct Error<E> {
    inner: Box<ErrorInner<E>>,
}
97
/// Heap-allocated payload of [`Error`].
#[derive(Debug, Clone)]
struct ErrorInner<E> {
    // The wrapped error value.
    error: E,
    // Explicitly assigned service name; `None` means "ask `error`".
    service: Option<&'static str>,
    // Backtrace captured (or inherited) when the `Error` was built.
    backtrace: Backtrace,
}
104
105impl<E> Error<E> {
106    #[track_caller]
107    pub fn new<T>(error: T, service: &'static str) -> Self
108    where
109        E: ErrorDiagnostic,
110        E: From<T>,
111    {
112        let error = E::from(error);
113        let backtrace = if let Some(bt) = error.backtrace() {
114            bt.clone()
115        } else {
116            Backtrace::new(Location::caller())
117        };
118        Self {
119            inner: Box::new(ErrorInner {
120                error,
121                backtrace,
122                service: Some(service),
123            }),
124        }
125    }
126
127    #[must_use]
128    /// Set response service
129    pub fn set_service(mut self, name: &'static str) -> Self {
130        self.inner.service = Some(name);
131        self
132    }
133
134    /// Map inner error to new error
135    ///
136    /// Keep same `service` and `location`
137    pub fn map<U, F>(self, f: F) -> Error<U>
138    where
139        F: FnOnce(E) -> U,
140    {
141        Error {
142            inner: Box::new(ErrorInner {
143                error: f(self.inner.error),
144                service: self.inner.service,
145                backtrace: self.inner.backtrace,
146            }),
147        }
148    }
149
150    /// Get inner error value
151    pub fn into_error(self) -> E {
152        self.inner.error
153    }
154}
155
156impl<E: ErrorDiagnostic> From<E> for Error<E> {
157    #[track_caller]
158    fn from(error: E) -> Self {
159        let backtrace = if let Some(bt) = error.backtrace() {
160            bt.clone()
161        } else {
162            Backtrace::new(Location::caller())
163        };
164        Self {
165            inner: Box::new(ErrorInner {
166                error,
167                backtrace,
168                service: None,
169            }),
170        }
171    }
172}
173
// `PartialEq` for `Error` compares only `E` and the `Option<&'static str>`
// service (never the backtrace), so it is a full equivalence when `E: Eq`.
impl<E> Eq for Error<E> where E: Eq {}
175
176impl<E> PartialEq for Error<E>
177where
178    E: PartialEq,
179{
180    fn eq(&self, other: &Self) -> bool {
181        self.inner.error.eq(&other.inner.error) && self.inner.service == other.inner.service
182    }
183}
184
185impl<E> PartialEq<E> for Error<E>
186where
187    E: PartialEq,
188{
189    fn eq(&self, other: &E) -> bool {
190        self.inner.error.eq(other)
191    }
192}
193
194impl<E> ops::Deref for Error<E> {
195    type Target = E;
196
197    fn deref(&self) -> &E {
198        &self.inner.error
199    }
200}
201
202impl<E> error::Error for Error<E>
203where
204    E: ErrorDiagnostic,
205{
206    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
207        Some(&self.inner.error)
208    }
209}
210
211impl<E> ErrorDiagnostic for Error<E>
212where
213    E: ErrorDiagnostic,
214{
215    type Kind = E::Kind;
216
217    fn kind(&self) -> Self::Kind {
218        self.inner.error.kind()
219    }
220
221    fn service(&self) -> Option<&'static str> {
222        if self.inner.service.is_some() {
223            self.inner.service
224        } else {
225            self.inner.error.service()
226        }
227    }
228
229    fn signature(&self) -> &'static str {
230        self.inner.error.signature()
231    }
232
233    fn backtrace(&self) -> Option<&Backtrace> {
234        Some(&self.inner.backtrace)
235    }
236}
237
/// A type-erased error whose kind has been converted into `K`.
///
/// The concrete error is stored behind an `Arc`, so cloning an `ErrorChain`
/// shares the same underlying error rather than copying it.
#[derive(Debug, Clone)]
pub struct ErrorChain<K: ErrorKind> {
    error: Arc<dyn ErrorDiagnostic<Kind = K>>,
}
242
243impl<K: ErrorKind> ErrorChain<K> {
244    #[track_caller]
245    pub fn new<E>(error: E) -> Self
246    where
247        E: ErrorDiagnostic + Sized,
248        E::Kind: Into<K>,
249    {
250        let service = error.service();
251        let backtrace = if let Some(bt) = error.backtrace() {
252            bt.clone()
253        } else {
254            Backtrace::new(Location::caller())
255        };
256
257        Self {
258            error: Arc::new(ErrorChainWrapper {
259                error,
260                service,
261                backtrace,
262                _k: PhantomData,
263            }),
264        }
265    }
266}
267
268impl<E, K> From<Error<E>> for ErrorChain<K>
269where
270    E: ErrorDiagnostic + Sized,
271    E::Kind: Into<K>,
272    K: ErrorKind,
273{
274    fn from(err: Error<E>) -> Self {
275        Self {
276            error: Arc::new(ErrorChainWrapper {
277                service: err.service(),
278                error: err.inner.error,
279                backtrace: err.inner.backtrace,
280                _k: PhantomData,
281            }),
282        }
283    }
284}
285
286impl<K> error::Error for ErrorChain<K>
287where
288    K: ErrorKind,
289{
290    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
291        self.error.source()
292    }
293}
294
295impl<K> ErrorDiagnostic for ErrorChain<K>
296where
297    K: ErrorKind,
298{
299    type Kind = K;
300
301    fn kind(&self) -> Self::Kind {
302        self.error.kind()
303    }
304
305    fn service(&self) -> Option<&'static str> {
306        self.error.service()
307    }
308
309    fn signature(&self) -> &'static str {
310        self.error.signature()
311    }
312
313    fn backtrace(&self) -> Option<&Backtrace> {
314        self.error.backtrace()
315    }
316}
317
// Concrete value behind `ErrorChain`'s trait object: the original error plus
// the service name and backtrace recorded when the chain was built. `K` is the
// target kind that `E::Kind` is converted into by `kind()`.
struct ErrorChainWrapper<E: Sized, K> {
    error: E,
    service: Option<&'static str>,
    backtrace: Backtrace,
    // Ties the wrapper to `K` without storing a value of it.
    _k: PhantomData<K>,
}
324
325impl<E, K> error::Error for ErrorChainWrapper<E, K>
326where
327    E: ErrorDiagnostic,
328{
329    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
330        Some(&self.error)
331    }
332}
333
334impl<E, K> ErrorDiagnostic for ErrorChainWrapper<E, K>
335where
336    E: ErrorDiagnostic,
337    E::Kind: Into<K>,
338    K: ErrorKind,
339{
340    type Kind = K;
341
342    fn kind(&self) -> Self::Kind {
343        self.error.kind().into()
344    }
345
346    fn service(&self) -> Option<&'static str> {
347        self.service
348    }
349
350    fn signature(&self) -> &'static str {
351        self.error.signature()
352    }
353
354    fn backtrace(&self) -> Option<&Backtrace> {
355        Some(&self.backtrace)
356    }
357}
358
#[derive(Clone)]
/// Rendered representation of a backtrace.
///
/// A `Backtrace` is created with [`Backtrace::new`], which captures the call
/// stack at that point and immediately renders it to text. Only the text is
/// kept, as a shared `Arc<str>`, so cloning just bumps a reference count. The
/// text is available via [`Backtrace::repr`] or the `Display`/`Debug` impls.
pub struct Backtrace(Arc<str>);
366
367impl Backtrace {
368    /// Create new backtrace
369    pub fn new(loc: &Location<'_>) -> Self {
370        let repr = FRAMES.with(|c| {
371            let mut cache = c.borrow_mut();
372            let mut idx = 0;
373            let mut st = foldhash::fast::FixedState::default().build_hasher();
374            let mut idxs: [*mut os::raw::c_void; 128] = [ptr::null_mut(); 128];
375
376            backtrace::trace(|frm| {
377                let ip = frm.ip();
378                st.write_usize(ip as usize);
379                cache.entry(ip).or_insert_with(|| {
380                    let mut f = BacktraceFrame::from(frm.clone());
381                    f.resolve();
382                    f
383                });
384                idxs[idx] = ip;
385                idx += 1;
386
387                idx < 128
388            });
389
390            let id = st.finish();
391
392            REPRS.with(|r| {
393                let mut reprs = r.borrow_mut();
394                if let Some(repr) = reprs.get(&id) {
395                    repr.clone()
396                } else {
397                    let mut frames: [Option<&BacktraceFrame>; 128] = [None; 128];
398                    for (idx, ip) in idxs.as_ref().iter().enumerate() {
399                        if !ip.is_null() {
400                            frames[idx] = Some(&cache[ip]);
401                        }
402                    }
403
404                    find_loc(loc, &mut frames);
405
406                    #[allow(static_mut_refs)]
407                    {
408                        if let Some(start) = unsafe { START } {
409                            find_loc_start(start, &mut frames);
410                        }
411                        if let Some(start) = unsafe { START_ALT } {
412                            find_loc_start(start, &mut frames);
413                        }
414                    }
415
416                    let bt = Bt(&frames[..]);
417                    let mut buf = String::new();
418                    let _ = write!(&mut buf, "\n{bt:?}");
419                    let repr: Arc<str> = Arc::from(buf);
420                    reprs.insert(id, repr.clone());
421                    repr
422                }
423            })
424        });
425
426        Self(repr)
427    }
428
429    /// Backtrace repr
430    pub fn repr(&self) -> &str {
431        &self.0
432    }
433}
434
435fn find_loc(loc: &Location<'_>, frames: &mut [Option<&BacktraceFrame>]) {
436    for (idx, frm) in frames.iter_mut().enumerate() {
437        if let Some(f) = frm {
438            for sym in f.symbols() {
439                if let Some(fname) = sym.filename()
440                    && let Some(lineno) = sym.lineno()
441                    && fname.ends_with(loc.file())
442                    && lineno == loc.line()
443                {
444                    for f in frames.iter_mut().take(idx) {
445                        *f = None;
446                    }
447                    return;
448                }
449            }
450        } else {
451            break;
452        }
453    }
454}
455
456fn find_loc_start(loc: (&str, u32), frames: &mut [Option<&BacktraceFrame>]) {
457    let mut idx = 0;
458    while idx < frames.len() {
459        if let Some(frm) = &frames[idx] {
460            for sym in frm.symbols() {
461                if let Some(fname) = sym.filename()
462                    && let Some(lineno) = sym.lineno()
463                    && fname.ends_with(loc.0)
464                    && (loc.1 == 0 || lineno == loc.1)
465                {
466                    for f in frames.iter_mut().skip(idx) {
467                        if f.is_some() {
468                            *f = None;
469                        } else {
470                            return;
471                        }
472                    }
473                    break;
474                }
475            }
476        }
477        idx += 1;
478    }
479}
480
// Formatting adapter over the retained frames; `None` slots are skipped when printed.
struct Bt<'a>(&'a [Option<&'a BacktraceFrame>]);
482
483impl fmt::Debug for Bt<'_> {
484    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
485        let cwd = std::env::current_dir();
486        let mut print_path =
487            move |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
488                let path = path.into_path_buf();
489                if let Ok(cwd) = &cwd
490                    && let Ok(suffix) = path.strip_prefix(cwd)
491                {
492                    return fmt::Display::fmt(&suffix.display(), fmt);
493                }
494                fmt::Display::fmt(&path.display(), fmt)
495            };
496
497        let mut f = BacktraceFmt::new(fmt, backtrace::PrintFmt::Short, &mut print_path);
498        f.add_context()?;
499        for frm in self.0.iter().flatten() {
500            f.frame().backtrace_frame(frm)?;
501        }
502        f.finish()?;
503        Ok(())
504    }
505}
506
507impl fmt::Debug for Backtrace {
508    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
509        fmt::Display::fmt(&self.0, f)
510    }
511}
512
513impl fmt::Display for Backtrace {
514    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
515        fmt::Display::fmt(&self.0, f)
516    }
517}
518
519impl<E> fmt::Debug for Error<E>
520where
521    E: ErrorDiagnostic,
522{
523    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
524        f.debug_struct("Error")
525            .field("error", &self.inner.error)
526            .field("service", &self.inner.service)
527            .field("backtrace", &self.inner.backtrace)
528            .finish()
529    }
530}
531
532impl<E> fmt::Display for Error<E>
533where
534    E: ErrorDiagnostic,
535{
536    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
537        fmt::Display::fmt(&self.inner.error, f)
538    }
539}
540
541impl<K> fmt::Display for ErrorChain<K>
542where
543    K: ErrorKind,
544{
545    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
546        fmt::Display::fmt(&self.error, f)
547    }
548}
549
550impl<E, K> fmt::Display for ErrorChainWrapper<E, K>
551where
552    E: error::Error,
553{
554    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
555        fmt::Display::fmt(&self.error, f)
556    }
557}
558
559impl<E, K> fmt::Debug for ErrorChainWrapper<E, K>
560where
561    E: error::Error,
562{
563    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
564        fmt::Debug::fmt(&self.error, f)
565    }
566}
567
568impl fmt::Display for ErrorType {
569    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
570        match self {
571            ErrorType::Client => write!(f, "ClientError"),
572            ErrorType::Service => write!(f, "ServiceError"),
573        }
574    }
575}
576
// `dead_code` is allowed for the fixtures below: some payloads (e.g. the
// `&'static str` in `TestError::Service`) appear to be read only by derived impls.
#[allow(dead_code)]
#[cfg(test)]
mod tests {
    use std::{error::Error as StdError, mem};

    use super::*;

    #[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
    enum TestKind {
        #[error("Connect")]
        Connect,
        #[error("Disconnect")]
        Disconnect,
        #[error("ServiceError")]
        ServiceError,
    }

    impl ErrorKind for TestKind {
        fn error_type(&self) -> ErrorType {
            match self {
                TestKind::Connect | TestKind::Disconnect => ErrorType::Client,
                TestKind::ServiceError => ErrorType::Service,
            }
        }
    }

    #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
    enum TestError {
        #[error("Connect err: {0}")]
        Connect(&'static str),
        #[error("Disconnect")]
        Disconnect,
        #[error("InternalServiceError")]
        Service(&'static str),
    }

    impl ErrorDiagnostic for TestError {
        type Kind = TestKind;

        fn kind(&self) -> Self::Kind {
            match self {
                TestError::Connect(_) => TestKind::Connect,
                TestError::Disconnect => TestKind::Disconnect,
                TestError::Service(_) => TestKind::ServiceError,
            }
        }

        fn service(&self) -> Option<&'static str> {
            Some("test")
        }

        fn signature(&self) -> &'static str {
            match self {
                TestError::Connect(_) => "Client-Connect",
                TestError::Disconnect => "Client-Disconnect",
                TestError::Service(_) => "Service-Internal",
            }
        }
    }

    #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
    #[error("TestError2")]
    struct TestError2;
    impl ErrorDiagnostic for TestError2 {
        type Kind = ErrorType;

        fn kind(&self) -> Self::Kind {
            ErrorType::Client
        }
    }

    #[test]
    fn test_error() {
        let err: Error<TestError> = TestError::Service("409 Error").into();
        assert_eq!(err.kind(), TestKind::ServiceError);
        assert_eq!((*err).kind(), TestKind::ServiceError);
        assert_eq!(err.to_string(), "InternalServiceError");
        assert_eq!(err.service(), Some("test"));
        assert_eq!(err.signature(), "Service-Internal");
        assert_eq!(
            err,
            Into::<Error<TestError>>::into(TestError::Service("409 Error"))
        );
        assert!(err.backtrace().is_some());
        assert!(err.is_service());
        assert!(
            format!("{:?}", err.source()).contains("Service(\"409 Error\")"),
            "{:?}",
            err.source().unwrap()
        );

        let err = err.set_service("SVC");
        assert_eq!(err.service(), Some("SVC"));

        let err2: Error<TestError> = Error::new(TestError::Service("409 Error"), "TEST");
        assert!(err != err2);
        assert_eq!(err, TestError::Service("409 Error"));

        let err2 = err2.set_service("SVC");
        assert_eq!(err, err2);
        let err2 = err2.map(|_| TestError::Disconnect);
        assert!(err != err2);

        assert_eq!(
            TestError::Connect("").kind().error_type(),
            ErrorType::Client
        );
        assert_eq!(TestError::Disconnect.kind().error_type(), ErrorType::Client);
        assert_eq!(
            TestError::Service("").kind().error_type(),
            ErrorType::Service
        );
        assert_eq!(TestError::Connect("").to_string(), "Connect err: ");
        assert_eq!(TestError::Disconnect.to_string(), "Disconnect");
        assert_eq!(TestError::Disconnect.service(), Some("test"));
        assert!(TestError::Disconnect.backtrace().is_none());

        assert_eq!(ErrorType::Client.as_str(), "ClientError");
        assert_eq!(ErrorType::Service.as_str(), "ServiceError");
        assert_eq!(ErrorType::Client.error_type(), ErrorType::Client);
        assert_eq!(ErrorType::Service.error_type(), ErrorType::Service);
        assert_eq!(ErrorType::Client.to_string(), "ClientError");
        assert_eq!(ErrorType::Service.to_string(), "ServiceError");

        assert_eq!(TestKind::Connect.to_string(), "Connect");
        assert_eq!(TestError::Connect("").signature(), "Client-Connect");
        assert_eq!(TestKind::Disconnect.to_string(), "Disconnect");
        assert_eq!(TestError::Disconnect.signature(), "Client-Disconnect");
        assert_eq!(TestKind::ServiceError.to_string(), "ServiceError");
        assert_eq!(TestError::Service("").signature(), "Service-Internal");

        let err = err.into_error().chain();
        assert_eq!(err.kind(), TestKind::ServiceError);
        assert_eq!(err.kind(), TestError::Service("409 Error").kind());
        assert_eq!(err.to_string(), "InternalServiceError");
        assert!(format!("{err:?}").contains("Service(\"409 Error\")"));
        assert!(
            format!("{:?}", err.source()).contains("Service(\"409 Error\")"),
            "{:?}",
            err.source().unwrap()
        );

        let err: Error<TestError> = TestError::Service("404 Error").into();
        let err: ErrorChain<TestKind> = err.into();
        assert_eq!(err.kind(), TestKind::ServiceError);
        assert_eq!(err.kind(), TestError::Service("404 Error").kind());
        assert_eq!(err.service(), Some("test"));
        assert_eq!(err.signature(), "Service-Internal");
        assert_eq!(err.to_string(), "InternalServiceError");
        assert!(err.backtrace().is_some());
        assert!(format!("{err:?}").contains("Service(\"404 Error\")"));

        // Sizes are expressed in machine words so these layout checks hold on
        // both 32- and 64-bit targets. `TestError` is a two-word `&str` payload
        // plus a padded discriminant; `Error<_>` is a single `Box` pointer.
        assert_eq!(3 * mem::size_of::<usize>(), mem::size_of::<TestError>());
        assert_eq!(mem::size_of::<usize>(), mem::size_of::<Error<TestError>>());

        assert_eq!(TestError2.service(), None);
        assert_eq!(TestError2.signature(), "ClientError");
    }
}