tor-basic-utils 0.41.0

General helpers used by Tor
Documentation
//! Helpers for iterating over error sources.

use std::{io, sync::Arc};

/// An iterator over an [`std::error::Error`] and its lower-level error sources, including
/// errors that are wrapped inside an [`io::Error`].
///
/// One of the main reasons to use this instead of calling [`std::error::Error::source`]
/// repeatedly is that the `source` implementation of [`io::Error`] doesn't return wrapped errors
/// unless you call `get_ref` on them (see: <https://github.com/rust-lang/rust/pull/124536>).
/// You can think of this iterator as walking down the chain of how an error was constructed.
/// However, this iterator shouldn't be used to display or format errors: doing so could result
/// in displaying the same error twice (due to the wrapping behavior of `io::Error`).
///
/// Each call to [`Iterator::next`] will attempt to peel off the outer layer of the error.
///
/// The first item returned is always the original error. Subsequent items are generated by calling:
///   * [`io::Error::get_ref`] if the last error could be downcast to an [`io::Error`] or
///     [`Arc<io::Error>`], or
///   * [`std::error::Error::source`] in all other cases
///
/// Cloning an `ErrorSources` produces an independent iterator that resumes from the same
/// position as the original.
///
/// # Limitations
///
/// This does not currently handle [`io::Error`]s that are wrapped in containers such as `Box`,
/// `Rc`, etc.
#[derive(Clone, Debug)]
pub struct ErrorSources<'a> {
    /// The next error to yield: the last error we managed to get via `get_ref` or `source`.
    ///
    /// Initially this is set to the error passed in via [`Self::new`]. It becomes `None` once
    /// the chain of errors has been exhausted.
    error: Option<&'a (dyn std::error::Error + 'static)>,
}

impl<'a> ErrorSources<'a> {
    /// Create an iterator over the lower-level sources of this error.
    pub fn new(error: &'a (dyn std::error::Error + 'static)) -> Self {
        Self { error: Some(error) }
    }
}

impl<'a> Iterator for ErrorSources<'a> {
    type Item = &'a (dyn std::error::Error + 'static);

    fn next(&mut self) -> Option<Self::Item> {
        let error = self.error.take()?;

        if let Some(io_error) = error.downcast_ref::<io::Error>() {
            // This match is necessary to cast from `&dyn Error + Send + Sync` to `&dyn Error` :/
            //
            // The use of `get_ref` here is intentional because we want to save the error that
            // this `io::Error` is wrapping. If we used `source` that would give us the source of
            // the error that's being wrapped.
            self.error = io_error.get_ref().map(|e| e as _);
        } else if let Some(io_error) = error.downcast_ref::<Arc<io::Error>>() {
            self.error = io_error.get_ref().map(|e| e as _);
        } else {
            self.error = error.source();
        }

        Some(error)
    }
}

// ----------------------------------------------------------------------

#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    #[derive(thiserror::Error, Debug)]
    #[error("my error")]
    struct MyError;

    /// An error whose `source` is a `MyError`, to exercise plain `source()` chains.
    #[derive(thiserror::Error, Debug)]
    #[error("outer error")]
    struct OuterError(#[source] MyError);

    /// Take the next item from `$errors` and downcast it to `$ty`, panicking on mismatch.
    macro_rules! downcast_next {
        ($errors:expr, $ty:ty) => {
            $errors.next().unwrap().downcast_ref::<$ty>().unwrap()
        };
    }

    #[test]
    fn error_sources() {
        let wrapped_error = io::Error::new(
            io::ErrorKind::ConnectionReset,
            Arc::new(io::Error::new(io::ErrorKind::ConnectionReset, MyError)),
        );
        let mut errors = ErrorSources::new(&wrapped_error);

        downcast_next!(errors, io::Error);
        downcast_next!(errors, Arc<io::Error>);
        downcast_next!(errors, MyError);
        assert!(errors.next().is_none());
    }

    #[test]
    fn error_sources_plain_source_chain() {
        let outer = OuterError(MyError);
        let mut errors = ErrorSources::new(&outer);

        downcast_next!(errors, OuterError);
        downcast_next!(errors, MyError);
        assert!(errors.next().is_none());
        // Once exhausted, the iterator stays exhausted.
        assert!(errors.next().is_none());
    }

    #[test]
    fn error_sources_io_error_without_payload() {
        let simple = io::Error::from(io::ErrorKind::ConnectionReset);
        let mut errors = ErrorSources::new(&simple);

        downcast_next!(errors, io::Error);
        assert!(errors.next().is_none());
    }

    #[test]
    fn error_sources_arc_io_error_at_top() {
        let wrapped = Arc::new(io::Error::new(io::ErrorKind::ConnectionReset, MyError));
        let mut errors = ErrorSources::new(&wrapped);

        downcast_next!(errors, Arc<io::Error>);
        downcast_next!(errors, MyError);
        assert!(errors.next().is_none());
    }
}