mockiato-codegen 0.9.6

Internally used by mockiato for code generation. This crate should never be used directly.
Documentation
use std::collections::hash_map::RandomState;
use std::collections::HashSet;

use proc_macro2::Ident;
use syn::spanned::Spanned;
use syn::visit::{visit_type, Visit};
use syn::{GenericParam, Generics, Path, Signature, TraitItem, TraitItemMethod, Type, TypePath};

use crate::diagnostic::DiagnosticBuilder;
use crate::parse::check_option_is_none;
use crate::parse::method_decl::{MethodDecl, MethodDeclParser};
use crate::parse::method_inputs::MethodInputsParser;
use crate::result::{merge_results, Error, Result};
use crate::syn_ext::PathExt;

/// Parser that turns the items of a mocked trait into [`MethodDecl`]s.
///
/// Parsing of each method's argument list is handed off to the boxed [`MethodInputsParser`].
#[derive(Debug)]
pub(crate) struct MethodDeclParserImpl {
    /// Parser applied to every method's `inputs` (its argument list).
    method_inputs_parser: Box<dyn MethodInputsParser>,
}

impl MethodDeclParserImpl {
    pub(crate) fn new(method_inputs_parser: Box<dyn MethodInputsParser>) -> Self {
        Self {
            method_inputs_parser,
        }
    }
}

impl MethodDeclParser for MethodDeclParserImpl {
    /// Parses a single trait item into a [`MethodDecl`].
    ///
    /// Only method items are accepted. Every other kind of trait item is rejected with a
    /// diagnostic spanning that item.
    fn parse(
        &self,
        trait_item: TraitItem,
        generic_types_on_trait: &HashSet<Ident, RandomState>,
    ) -> Result<MethodDecl> {
        // Guard clause: anything that is not a method is reported immediately.
        let method = match trait_item {
            TraitItem::Method(method) => method,
            other => return Err(invalid_trait_item_error(&other)),
        };

        self.parse_method(method, generic_types_on_trait)
    }
}

impl MethodDeclParserImpl {
    /// Validates a trait method and converts it into a [`MethodDecl`].
    ///
    /// The method is rejected when its signature does any of the following:
    /// - references one of `generic_types_on_trait` behind a reference type;
    /// - declares non-lifetime generic parameters;
    /// - is `const` or `async`.
    ///
    /// The checks run in that order, and the first failure is returned via `?`. Only after
    /// they all pass are the arguments handed to the inputs parser.
    ///
    /// NOTE(review): signature fields not copied into `MethodDecl` (e.g. `abi`, `variadic`) are
    /// dropped without validation. Confirm this is intended.
    fn parse_method(
        &self,
        method: TraitItemMethod,
        generic_types_on_trait: &HashSet<Ident>,
    ) -> Result<MethodDecl> {
        // Capture the span of the whole method before its fields are moved out.
        let span = method.span();
        let attrs = method.attrs;
        let signature = method.sig;

        validate_usage_of_generic_types(&signature, generic_types_on_trait)?;
        validate_generic_type_parameters(&signature.generics)?;
        check_option_is_none(
            &signature.constness,
            span,
            "`const` methods are not supported",
        )?;
        check_option_is_none(
            &signature.asyncness,
            span,
            "`async` methods are not supported",
        )?;

        let inputs = self.method_inputs_parser.parse(signature.inputs)?;

        Ok(MethodDecl {
            attrs,
            unsafety: signature.unsafety,
            ident: signature.ident,
            generics: signature.generics,
            span,
            inputs,
            output: signature.output,
        })
    }
}

fn invalid_trait_item_error(trait_item: &TraitItem) -> Error {
    DiagnosticBuilder::error(
        trait_item.span(),
        "Traits are only allowed to contain methods",
    )
    .build()
    .into()
}

/// Ensures that a method declares only lifetime generic parameters.
///
/// Every offending (type or const) parameter yields its own error. The per-parameter
/// results are combined with `merge_results`, so several problems can be reported at once.
fn validate_generic_type_parameters(generics: &Generics) -> Result<()> {
    let per_param = generics.params.iter().map(|param| {
        if let GenericParam::Lifetime(_) = param {
            Ok(())
        } else {
            Err(invalid_generic_param(param))
        }
    });

    merge_results(per_param).map(|_| ())
}

fn invalid_generic_param(generic_param: &GenericParam) -> Error {
    let error_message = "Only lifetimes are supported as generic parameters on methods";
    DiagnosticBuilder::error(generic_param.span(), error_message)
        .build()
        .into()
}

/// Rejects signatures that place one of the trait's generic types behind a reference.
///
/// Each offending reference type produces its own error, and all of them are collected into a
/// single returned `Error`.
fn validate_usage_of_generic_types(
    signature: &Signature,
    generic_types_on_trait: &HashSet<Ident>,
) -> Result<()> {
    let offending = find_references_to_generic_types(signature, generic_types_on_trait);

    if !offending.is_empty() {
        return Err(offending
            .into_iter()
            .map(error_for_reference_to_generic_type)
            .collect());
    }

    Ok(())
}

fn error_for_reference_to_generic_type(ty: &Type) -> Error {
    DiagnosticBuilder::error(ty.span(), "References to generic types are not supported")
        .build()
        .into()
}

/// Walks the whole `signature` and returns every outermost reference type whose referent
/// mentions one of `generic_types_on_trait`, in the order they were visited.
fn find_references_to_generic_types<'a>(
    signature: &'a Signature,
    generic_types_on_trait: &'a HashSet<Ident>,
) -> Vec<&'a Type> {
    let mut collector = TypeVisitor {
        state: TypeVisitorState::Initial,
        references_to_generic_types: Vec::new(),
        generic_types_on_trait,
    };

    collector.visit_signature(signature);

    collector.references_to_generic_types
}

/// Syntax visitor that collects reference types whose referent mentions one of the trait's
/// generic type parameters.
struct TypeVisitor<'a> {
    /// Identifiers of the generic type parameters declared on the trait.
    generic_types_on_trait: &'a HashSet<Ident>,
    /// Outermost offending reference types found so far, in visiting order.
    references_to_generic_types: Vec<&'a Type>,
    /// Position of the visitor relative to the reference currently being inspected.
    state: TypeVisitorState,
}

/// Progress of a [`TypeVisitor`] through a reference type.
#[derive(Clone, Copy)]
enum TypeVisitorState {
    /// Not inside a reference. The next reference encountered starts a new check.
    Initial,
    /// Walking the type behind a reference, looking for a trait generic type.
    CheckingReferenceInner,
    /// The reference currently being checked was found to mention a trait generic type.
    FoundReferenceToGenericType,
}

impl<'a> Visit<'a> for TypeVisitor<'a> {
    /// Dispatches on the kind of type and the visitor's current state.
    ///
    /// A reference seen in the `Initial` state starts a new check. A path seen while checking
    /// a reference's inner type is tested against the trait's generics. Everything else falls
    /// back to syn's default traversal.
    fn visit_type(&mut self, ty: &'a Type) {
        match (ty, self.state) {
            (Type::Reference(_), TypeVisitorState::Initial) => {
                self.visit_reference_in_initial_state(ty)
            }
            (Type::Path(type_path), TypeVisitorState::CheckingReferenceInner) => {
                let TypePath { path, .. } = type_path;
                self.visit_path_when_checking_reference_inner(ty, path)
            }
            _ => visit_type(self, ty),
        }
    }
}

impl<'a> TypeVisitor<'a> {
    /// Handles a reference type encountered while not inside any other reference.
    ///
    /// Switches to `CheckingReferenceInner` and walks the reference's contents. If that walk
    /// flagged a trait generic type, `ty` is recorded. The state is then reset to `Initial`,
    /// so later sibling types are checked independently.
    fn visit_reference_in_initial_state(&mut self, ty: &'a Type) {
        self.state = TypeVisitorState::CheckingReferenceInner;

        visit_type(self, ty);
        if let TypeVisitorState::FoundReferenceToGenericType = self.state {
            self.references_to_generic_types.push(ty);
        }

        self.state = TypeVisitorState::Initial;
    }

    /// Handles a path type found while checking the inside of a reference.
    ///
    /// If the path's first segment names one of the trait's generic types (e.g. `T` or
    /// `T::Assoc`), the enclosing reference is flagged. Otherwise the path's own contents
    /// (such as the generic arguments of `Vec<T>`) are walked further.
    fn visit_path_when_checking_reference_inner(&mut self, ty: &'a Type, path: &Path) {
        // `first_segment_as_ident` returns an `Option`. A `None` cannot name a trait generic,
        // so it is handled as "not generic" instead of panicking inside the proc macro.
        let names_generic_type = match path.first_segment_as_ident() {
            Some(ident) => self.generic_types_on_trait.contains(ident),
            None => false,
        };

        if names_generic_type {
            self.state = TypeVisitorState::FoundReferenceToGenericType;
        } else {
            visit_type(self, ty)
        }
    }
}