murf-macros 0.2.0

Murf's proc macros
Documentation
use std::{
    cell::UnsafeCell,
    ptr::{from_mut, from_ref},
};

use proc_macro2::{Ident, Span};
use syn::{GenericArgument, Lifetime, Path, PathArguments, ReturnType, Type, TypeParamBound};

use super::TempLifetimes;

/// Where the lifetimes come from that `replace_default_lifetime` substitutes
/// for elided reference lifetimes and `'_`.
pub(crate) enum LifetimeReplaceMode<'x> {
    /// Every substituted lifetime is the fixed `'mock`.
    Mock,
    /// Each substituted lifetime is requested from the borrowed `TempLifetimes`.
    Temp(&'x mut TempLifetimes),
}

impl LifetimeReplaceMode<'_> {
    /// Returns the lifetime to use for the next substitution site.
    fn generate(&mut self) -> Lifetime {
        if let Self::Temp(temp) = self {
            return temp.generate();
        }

        Lifetime::new("'mock", Span::call_site())
    }
}

/// Extension methods for inspecting and rewriting `syn::Type` trees.
pub(crate) trait TypeEx {
    /// Returns `true` if a lifetime with the same identifier as `lt` appears
    /// in the type.
    fn contains_lifetime(&self, lt: &Lifetime) -> bool;
    /// Returns `true` if the bare single-segment path `Self` appears in the type.
    fn contains_self_type(&self) -> bool;

    /// Replaces each bare `Self` path with a clone of `type_`.
    ///
    /// Sets `*changed` to `true` when at least one replacement happens; it is
    /// never reset to `false`.
    fn replace_self_type(self, type_: &Type, changed: &mut bool) -> Self;
    /// Gives every reference without an explicit lifetime, and every `'_`,
    /// a lifetime obtained from `mode` (one request per occurrence).
    fn replace_default_lifetime(self, mode: LifetimeReplaceMode<'_>) -> Self;

    /// Rewrites lifetime arguments of angle-bracketed path segments (e.g. the
    /// `'a` in `Foo<'a>`) to `'static`, and removes explicit lifetimes from
    /// references so they become elided.
    fn make_static(self) -> Self;
}

impl TypeEx for Type {
    fn contains_lifetime(&self, lt: &Lifetime) -> bool {
        struct Visitor<'a> {
            lt: &'a Lifetime,
            result: bool,
        }

        impl<'a> TypeVisitor for Visitor<'a> {
            fn visit_lifetime(&mut self, lt: &UnsafeCell<Lifetime>) -> bool {
                let lt = unsafe { &*lt.get() };
                self.result = self.lt.ident == lt.ident || self.result;

                !self.result
            }
        }

        let mut visitor = Visitor { lt, result: false };

        visitor.visit(unsafe_cell_ref(self));

        visitor.result
    }

    fn contains_self_type(&self) -> bool {
        struct Visitor {
            result: bool,
        }

        impl TypeVisitor for Visitor {
            fn visit_type(&mut self, ty: &UnsafeCell<Type>) -> bool {
                let ty = unsafe { &*ty.get() };

                if let Type::Path(t) = ty {
                    if t.path.segments.len() == 1 && t.path.segments[0].ident == "Self" {
                        self.result = true;
                    }
                }

                !self.result
            }
        }

        let mut visitor = Visitor { result: false };

        visitor.visit(unsafe_cell_ref(self));

        visitor.result
    }

    fn replace_self_type(mut self, type_: &Type, changed: &mut bool) -> Self {
        struct Visitor<'a> {
            type_: &'a Type,
            changed: &'a mut bool,
        }

        impl<'a> TypeVisitor for Visitor<'a> {
            fn visit_type(&mut self, ty: &UnsafeCell<Type>) -> bool {
                let ty = unsafe { &mut *ty.get() };

                if let Type::Path(t) = ty {
                    if t.path.segments.len() == 1 && t.path.segments[0].ident == "Self" {
                        *ty = self.type_.clone();
                        *self.changed = true;
                    }
                }

                true
            }
        }

        let mut visitor = Visitor { type_, changed };

        visitor.visit(unsafe_cell_mut(&mut self));

        self
    }

    fn replace_default_lifetime(mut self, mode: LifetimeReplaceMode<'_>) -> Self {
        struct Visitor<'a> {
            mode: LifetimeReplaceMode<'a>,
        }

        impl<'a> TypeVisitor for Visitor<'a> {
            fn visit_type(&mut self, ty: &UnsafeCell<Type>) -> bool {
                let ty = unsafe { &mut *ty.get() };

                if let Type::Reference(r) = ty {
                    if r.lifetime.is_none() {
                        r.lifetime = Some(self.mode.generate());
                    }
                }

                true
            }

            fn visit_lifetime(&mut self, lt: &UnsafeCell<Lifetime>) -> bool {
                let lt = unsafe { &mut *lt.get() };

                if lt.ident == "_" {
                    *lt = self.mode.generate();
                }

                true
            }
        }

        let mut visitor = Visitor { mode };

        visitor.visit(unsafe_cell_mut(&mut self));

        self
    }

    fn make_static(mut self) -> Self {
        struct Visitor;

        impl TypeVisitor for Visitor {
            fn visit_type(&mut self, ty: &UnsafeCell<Type>) -> bool {
                let ty = unsafe { &mut *ty.get() };

                match ty {
                    Type::Path(ty) => {
                        for seg in &mut ty.path.segments {
                            match &mut seg.arguments {
                                PathArguments::None | PathArguments::Parenthesized(_) => (),
                                PathArguments::AngleBracketed(x) => {
                                    for arg in &mut x.args {
                                        if let GenericArgument::Lifetime(lt) = arg {
                                            lt.ident = Ident::new("static", Span::call_site());
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Type::Reference(ty) => ty.lifetime = None,
                    _ => (),
                }

                true
            }
        }

        Visitor.visit(unsafe_cell_mut(&mut self));

        self
    }
}

trait TypeVisitor: Sized {
    fn visit_type(&mut self, ty: &UnsafeCell<Type>) -> bool {
        let _ty = ty;

        true
    }

    fn visit_lifetime(&mut self, lt: &UnsafeCell<Lifetime>) -> bool {
        let _lt = lt;

        true
    }

    fn visit(&mut self, ty: &UnsafeCell<Type>) -> bool {
        fn visit_path<X: TypeVisitor>(this: &mut X, path: &Path) -> bool {
            for seg in &path.segments {
                match &seg.arguments {
                    PathArguments::None => (),
                    PathArguments::AngleBracketed(x) => {
                        for arg in &x.args {
                            match arg {
                                GenericArgument::Type(t) => {
                                    if !this.visit(unsafe_cell_ref(t)) {
                                        return false;
                                    }
                                }
                                GenericArgument::Lifetime(lt) => {
                                    if !this.visit_lifetime(unsafe_cell_ref(lt)) {
                                        return false;
                                    }
                                }
                                GenericArgument::AssocType(t) => {
                                    if !this.visit(unsafe_cell_ref(&t.ty)) {
                                        return false;
                                    }
                                }
                                _ => (),
                            }
                        }
                    }
                    PathArguments::Parenthesized(x) => {
                        for t in &x.inputs {
                            if !this.visit(unsafe_cell_ref(t)) {
                                return false;
                            }
                        }

                        match &x.output {
                            ReturnType::Type(_, t) => {
                                if !this.visit(unsafe_cell_ref(t)) {
                                    return false;
                                }
                            }
                            ReturnType::Default => (),
                        }
                    }
                }
            }

            true
        }

        if !self.visit_type(ty) {
            return false;
        }

        let ty = unsafe { &*ty.get() };

        match ty {
            Type::Path(ty) => visit_path(self, &ty.path),
            Type::Reference(t) => {
                if let Some(lt) = &t.lifetime {
                    if !self.visit_lifetime(unsafe_cell_ref(lt)) {
                        return false;
                    }
                }

                if !self.visit(unsafe_cell_ref(&t.elem)) {
                    return false;
                }

                true
            }
            Type::Array(t) => self.visit(unsafe_cell_ref(&t.elem)),
            Type::Slice(t) => self.visit(unsafe_cell_ref(&t.elem)),
            Type::Tuple(t) => {
                for t in &t.elems {
                    if !self.visit(unsafe_cell_ref(t)) {
                        return false;
                    }
                }

                true
            }
            Type::TraitObject(t) => {
                for b in &t.bounds {
                    match b {
                        TypeParamBound::Lifetime(lt) => {
                            if !self.visit_lifetime(unsafe_cell_ref(lt)) {
                                return false;
                            }
                        }
                        TypeParamBound::Trait(t) => {
                            if !visit_path(self, &t.path) {
                                return false;
                            }
                        }
                        _ => (),
                    }
                }

                true
            }
            _ => true,
        }
    }
}

fn unsafe_cell_ref<T>(value: &T) -> &UnsafeCell<T> {
    unsafe { &*(from_ref(value).cast::<std::cell::UnsafeCell<T>>()) }
}

fn unsafe_cell_mut<T>(value: &mut T) -> &UnsafeCell<T> {
    unsafe { &*(from_mut(value) as *const std::cell::UnsafeCell<T>) }
}