//! test-context-macros 0.5.8
//!
//! Macro crate for test-context.
//!
//! Documentation
use quote::quote;
use syn::FnArg;

/// A test-function parameter that was recognised as the test context.
#[derive(Clone)]
pub struct ContextArg {
    /// The binding pattern of the parameter, i.e. the `pattern` part of `pattern: Type`.
    pub pattern: syn::Pat,
    /// How the test function receives the context: by value, by `&`, or by `&mut`.
    pub mode: ContextArgMode,
}

/// The way a test function takes its context argument.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum ContextArgMode {
    /// Taken by value (`ContextType`). Only valid with `skip_teardown`.
    Owned,
    /// Taken by shared reference (`&ContextType`).
    Reference,
    /// Taken by exclusive reference (`&mut ContextType`).
    MutableReference,
}

impl ContextArgMode {
    /// Reports whether the context is taken by value, i.e. `self` is [`ContextArgMode::Owned`].
    pub fn is_owned(&self) -> bool {
        match *self {
            Self::Owned => true,
            // Both reference modes are spelled out (no `_` arm) so that adding a variant
            // forces this method to be revisited.
            Self::Reference | Self::MutableReference => false,
        }
    }
}

/// A test-function parameter after it has been classified.
#[derive(Clone)]
pub enum TestArg {
    /// A parameter that is not the context argument, passed through untouched.
    Any(FnArg),
    /// The parameter that receives the test context.
    Context(ContextArg),
}

impl TestArg {
    pub fn parse_arg_with_expected_context(arg: FnArg, expected_context_type: &syn::Type) -> Self {
        let pat_type = match arg {
            FnArg::Typed(pat_type) => pat_type,
            FnArg::Receiver(_) => return TestArg::Any(arg),
        };

        // fn example(pattern: arg_type)
        let arg_type = &*pat_type.ty;
        let pattern = &*pat_type.pat;

        // Check for mutable/immutable reference
        if let syn::Type::Reference(type_ref) = arg_type
            && types_equal(&type_ref.elem, expected_context_type)
        {
            let mode = if type_ref.mutability.is_some() {
                ContextArgMode::MutableReference
            } else {
                ContextArgMode::Reference
            };

            return TestArg::Context(ContextArg {
                pattern: pattern.to_owned(),
                mode,
            });
        }

        if types_equal(arg_type, expected_context_type) {
            return TestArg::Context(ContextArg {
                pattern: pattern.to_owned(),
                mode: ContextArgMode::Owned,
            });
        }

        TestArg::Any(FnArg::Typed(pat_type))
    }
}

fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
    if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
        return a_path.path.segments.last().unwrap().ident
            == b_path.path.segments.last().unwrap().ident;
    }
    quote!(#a).to_string() == quote!(#b).to_string()
}