mirror-mirror 0.1.21

Reflection library for Rust
Documentation
use crate::{
    type_info::{OpaqueType, Type, VariantField},
    Reflect, ScalarRef,
};
use alloc::string::String;

/// Expands to a default trait method named `$method` that takes one scalar of
/// type `$scalar`, ignores it, and returns `Ok(())`.
///
/// Because the generated body is a no-op, implementors only need to override
/// the scalar callbacks they actually care about.
macro_rules! visit_scalar_fn {
    ($method:ident, $scalar:ty) => {
        // The string callback takes `&String` rather than `&str`, which clippy's
        // `ptr_arg` lint would otherwise flag.
        #[inline]
        #[allow(clippy::ptr_arg)]
        fn $method(&mut self, value: $scalar) -> Result<(), Self::Error> {
            // Default behavior: accept the value without inspecting it.
            let _ = value;
            Ok(())
        }
    };
}

/// A visitor that is handed every scalar and opaque leaf of a reflected value
/// and may abort the traversal by returning an error.
///
/// Drive it with [`try_visit`]. Every callback has a default implementation that
/// ignores its argument and returns `Ok(())`, so implementors only override the
/// callbacks for the leaf types they are interested in.
// The default callback bodies never read their arguments.
#[allow(unused_variables)]
pub trait TryVisit {
    /// The error type returned by every callback.
    type Error;

    // Unsigned integers.
    visit_scalar_fn!(try_visit_usize, usize);
    visit_scalar_fn!(try_visit_u8, u8);
    visit_scalar_fn!(try_visit_u16, u16);
    visit_scalar_fn!(try_visit_u32, u32);
    visit_scalar_fn!(try_visit_u64, u64);
    visit_scalar_fn!(try_visit_u128, u128);

    // Signed integers.
    visit_scalar_fn!(try_visit_i8, i8);
    visit_scalar_fn!(try_visit_i16, i16);
    visit_scalar_fn!(try_visit_i32, i32);
    visit_scalar_fn!(try_visit_i64, i64);
    visit_scalar_fn!(try_visit_i128, i128);

    // Floating point.
    visit_scalar_fn!(try_visit_f32, f32);
    visit_scalar_fn!(try_visit_f64, f64);

    // Remaining primitive scalars.
    visit_scalar_fn!(try_visit_bool, bool);
    visit_scalar_fn!(try_visit_char, char);

    // Strings (passed as `&String`, not `&str`).
    visit_scalar_fn!(try_visit_string, &String);

    /// Called for a value whose type descriptor is [`Type::Opaque`], with the
    /// value itself and its opaque type information.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    #[inline]
    fn try_visit_opaque(
        &mut self,
        value: &dyn Reflect,
        ty: OpaqueType<'_>,
    ) -> Result<(), Self::Error> {
        Ok(())
    }
}

/// Walks `value` depth-first, guided by its type descriptor `ty`, and calls the
/// matching [`TryVisit`] callback for every scalar and opaque value encountered.
///
/// Composite values are not reported to the visitor themselves. This covers
/// structs, tuple structs, tuples, enum variants, lists, arrays and maps. Each of
/// their fields or elements is visited recursively instead. For enums only the
/// fields of the value's active variant are visited. For maps, each key is
/// visited immediately before its value.
///
/// # Errors
///
/// Returns the first error produced by any visitor callback. Traversal stops at
/// that point, and no further callbacks are made.
///
/// # Panics
///
/// Panics if `value` does not have the shape described by `ty`. Examples: `ty`
/// is [`Type::Struct`] but `value` is not a struct, or a field listed in the
/// descriptor cannot be found on the value.
pub fn try_visit<V>(visitor: &mut V, value: &dyn Reflect, ty: Type<'_>) -> Result<(), V::Error>
where
    V: TryVisit,
{
    match ty {
        Type::Scalar(_) => {
            // The descriptor only says "this is a scalar"; the concrete scalar
            // kind is taken from the value itself.
            let scalar = value
                .as_scalar()
                .expect("type descriptor is `Scalar` but value is not a scalar");
            match scalar {
                ScalarRef::usize(inner) => visitor.try_visit_usize(inner)?,
                ScalarRef::f32(inner) => visitor.try_visit_f32(inner)?,
                ScalarRef::bool(inner) => visitor.try_visit_bool(inner)?,
                ScalarRef::u8(inner) => visitor.try_visit_u8(inner)?,
                ScalarRef::u16(inner) => visitor.try_visit_u16(inner)?,
                ScalarRef::u32(inner) => visitor.try_visit_u32(inner)?,
                ScalarRef::u64(inner) => visitor.try_visit_u64(inner)?,
                ScalarRef::u128(inner) => visitor.try_visit_u128(inner)?,
                ScalarRef::i8(inner) => visitor.try_visit_i8(inner)?,
                ScalarRef::i16(inner) => visitor.try_visit_i16(inner)?,
                ScalarRef::i32(inner) => visitor.try_visit_i32(inner)?,
                ScalarRef::i64(inner) => visitor.try_visit_i64(inner)?,
                ScalarRef::i128(inner) => visitor.try_visit_i128(inner)?,
                ScalarRef::char(inner) => visitor.try_visit_char(inner)?,
                ScalarRef::f64(inner) => visitor.try_visit_f64(inner)?,
                ScalarRef::String(inner) => visitor.try_visit_string(inner)?,
            }
        }
        Type::Struct(struct_ty) => {
            let struct_ = value
                .as_struct()
                .expect("type descriptor is `Struct` but value is not a struct");

            for field_ty in struct_ty.field_types() {
                let field = struct_
                    .field(field_ty.name())
                    .expect("struct value is missing a field listed in its type descriptor");
                try_visit(visitor, field, field_ty.get_type())?;
            }
        }
        Type::TupleStruct(tuple_struct_ty) => {
            let tuple_struct = value
                .as_tuple_struct()
                .expect("type descriptor is `TupleStruct` but value is not a tuple struct");

            for (idx, field_ty) in tuple_struct_ty.field_types().enumerate() {
                let field = tuple_struct
                    .field_at(idx)
                    .expect("tuple struct value is missing a field listed in its type descriptor");
                try_visit(visitor, field, field_ty.get_type())?;
            }
        }
        Type::Tuple(tuple_ty) => {
            let tuple = value
                .as_tuple()
                .expect("type descriptor is `Tuple` but value is not a tuple");

            for (idx, field_ty) in tuple_ty.field_types().enumerate() {
                let field = tuple
                    .field_at(idx)
                    .expect("tuple value is missing a field listed in its type descriptor");
                try_visit(visitor, field, field_ty.get_type())?;
            }
        }
        Type::Enum(enum_ty) => {
            let enum_ = value
                .as_enum()
                .expect("type descriptor is `Enum` but value is not an enum");
            let variant_ty = enum_ty
                .variant(enum_.variant_name())
                .expect("enum value's active variant is not listed in its type descriptor");

            for (idx, field_ty) in variant_ty.field_types().enumerate() {
                // Named variant fields are looked up by name, unnamed ones by position.
                let field = match field_ty {
                    VariantField::Named(named_field_ty) => enum_
                        .field(named_field_ty.name())
                        .expect("enum value is missing a named field listed in its variant"),
                    VariantField::Unnamed(_) => enum_
                        .field_at(idx)
                        .expect("enum value is missing a positional field listed in its variant"),
                };
                try_visit(visitor, field, field_ty.get_type())?;
            }
        }
        Type::List(list_ty) => {
            let list = value
                .as_list()
                .expect("type descriptor is `List` but value is not a list");
            let element_ty = list_ty.element_type();

            for element in list.iter() {
                try_visit(visitor, element, element_ty)?;
            }
        }
        Type::Array(array_ty) => {
            let array = value
                .as_array()
                .expect("type descriptor is `Array` but value is not an array");
            let element_ty = array_ty.element_type();

            for element in array.iter() {
                try_visit(visitor, element, element_ty)?;
            }
        }
        Type::Map(map_ty) => {
            let map = value
                .as_map()
                .expect("type descriptor is `Map` but value is not a map");
            let key_ty = map_ty.key_type();
            let value_ty = map_ty.value_type();

            // `entry_value` avoids shadowing the function's `value` parameter.
            for (key, entry_value) in map.iter() {
                try_visit(visitor, key, key_ty)?;
                try_visit(visitor, entry_value, value_ty)?;
            }
        }
        Type::Opaque(opaque_ty) => {
            // Opaque values are not descended into; the visitor gets them whole.
            visitor.try_visit_opaque(value, opaque_ty)?;
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::DescribeType;
    use alloc::collections::BTreeMap;
    use core::convert::Infallible;

    #[derive(Debug, Clone, Reflect)]
    #[reflect(crate_name(crate))]
    struct Foo {
        a: String,
        b: i32,
        c: Vec<Bar>,
    }

    #[derive(Debug, Clone, Reflect)]
    #[reflect(crate_name(crate))]
    enum Bar {
        A(BTreeMap<i32, i32>),
    }

    /// Counts the `String` and `i32` leaves a traversal reports; never fails.
    #[derive(Default, Debug)]
    struct CountsI32sAndStrings {
        string_count: usize,
        i32_count: usize,
    }

    impl TryVisit for CountsI32sAndStrings {
        type Error = Infallible;

        fn try_visit_string(&mut self, _value: &String) -> Result<(), Self::Error> {
            self.string_count += 1;
            Ok(())
        }

        fn try_visit_i32(&mut self, _value: i32) -> Result<(), Self::Error> {
            self.i32_count += 1;
            Ok(())
        }
    }

    /// Returns an error from the first `i32` leaf it is handed, counting how many
    /// `i32` callbacks were made so tests can check traversal stopped.
    #[derive(Default, Debug)]
    struct FailsOnFirstI32 {
        i32_count: usize,
    }

    impl TryVisit for FailsOnFirstI32 {
        type Error = ();

        fn try_visit_i32(&mut self, _value: i32) -> Result<(), Self::Error> {
            self.i32_count += 1;
            Err(())
        }
    }

    #[test]
    fn works() {
        let foo = Foo {
            a: "a".to_owned(),
            b: 1337,
            c: Vec::from([Bar::A(BTreeMap::from_iter([(1, 1), (2, 2)]))]),
        };

        let mut visitor = CountsI32sAndStrings::default();
        try_visit(
            &mut visitor,
            &foo,
            <Foo as DescribeType>::type_descriptor().get_type(),
        )
        .unwrap();

        // `a` is the only string. The five i32s are `b`, plus the two keys and
        // two values of the map inside `c`.
        assert_eq!(visitor.string_count, 1);
        assert_eq!(visitor.i32_count, 5);
    }

    #[test]
    fn recursive() {
        #[derive(Debug, Clone, Reflect)]
        #[reflect(crate_name(crate))]
        struct Recursive(i32, Vec<Recursive>);

        let value = Recursive(
            1,
            Vec::from([Recursive(2, Vec::from([Recursive(3, Vec::new())]))]),
        );

        let mut visitor = CountsI32sAndStrings::default();
        try_visit(
            &mut visitor,
            &value,
            <Recursive as DescribeType>::type_descriptor().get_type(),
        )
        .unwrap();

        // One i32 per nesting level, three levels deep.
        assert_eq!(visitor.string_count, 0);
        assert_eq!(visitor.i32_count, 3);
    }

    #[test]
    fn stops_at_first_error() {
        let foo = Foo {
            a: "a".to_owned(),
            b: 1337,
            c: Vec::from([Bar::A(BTreeMap::from_iter([(1, 1), (2, 2)]))]),
        };

        let mut visitor = FailsOnFirstI32::default();
        let result = try_visit(
            &mut visitor,
            &foo,
            <Foo as DescribeType>::type_descriptor().get_type(),
        );

        // `foo` holds five i32s. The first `Err` must end the traversal, so
        // exactly one i32 callback is made regardless of field order.
        assert!(result.is_err());
        assert_eq!(visitor.i32_count, 1);
    }
}