1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
use syn::{Data, DataStruct, DeriveInput, Field, Fields};

/// Panic message used when a tuple struct does not have exactly one field.
const NEWTYPE_MUST_HAVE_ONLY_ONE_FIELD: &str = concat!(
    "Newtype struct must only have one field.\n",
    "See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types ",
    "for more information."
);

/// Panic message used when the input is not a tuple struct at all.
const MACRO_MUST_BE_USED_ON_NEWTYPE_STRUCT: &str = concat!(
    "This macro must be used on a newtype struct.\n",
    "See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types ",
    "for more information."
);

/// Functions to make it ergonomic to work with newtype `struct` ASTs.
///
/// A newtype here is a tuple struct with exactly one unnamed field,
/// e.g. `struct Meters(u32);`.
pub trait DeriveInputNewtypeExt {
    /// Returns a shared reference to the single unnamed field of this
    /// newtype struct's AST.
    ///
    /// # Panics
    ///
    /// Panics if the AST is not a tuple struct (unit structs, structs with
    /// named fields, enums and unions all panic), or if it is a tuple struct
    /// that does not have exactly one field.
    fn inner_type(&self) -> &Field;

    /// Returns a mutable reference to the single unnamed field of this
    /// newtype struct's AST, allowing callers to rewrite it in place.
    ///
    /// # Panics
    ///
    /// Panics if the AST is not a tuple struct (unit structs, structs with
    /// named fields, enums and unions all panic), or if it is a tuple struct
    /// that does not have exactly one field.
    fn inner_type_mut(&mut self) -> &mut Field;
}

impl DeriveInputNewtypeExt for DeriveInput {
    fn inner_type(&self) -> &Field {
        if let Data::Struct(DataStruct {
            fields: Fields::Unnamed(fields_unnamed),
            ..
        }) = &self.data
        {
            if fields_unnamed.unnamed.len() == 1 {
                fields_unnamed
                    .unnamed
                    .first()
                    .expect("Expected field to exist.")
                    .value()
            } else {
                panic!(NEWTYPE_MUST_HAVE_ONLY_ONE_FIELD)
            }
        } else {
            panic!(MACRO_MUST_BE_USED_ON_NEWTYPE_STRUCT)
        }
    }

    fn inner_type_mut(&mut self) -> &mut Field {
        if let Data::Struct(DataStruct {
            fields: Fields::Unnamed(fields_unnamed),
            ..
        }) = &mut self.data
        {
            if fields_unnamed.unnamed.len() == 1 {
                fields_unnamed
                    .unnamed
                    .iter_mut()
                    .next()
                    .expect("Expected field to exist.")
            } else {
                panic!(NEWTYPE_MUST_HAVE_ONLY_ONE_FIELD)
            }
        } else {
            panic!(MACRO_MUST_BE_USED_ON_NEWTYPE_STRUCT)
        }
    }
}

#[cfg(test)]
mod tests {
    use syn::{parse_quote, DeriveInput, Type};

    use super::DeriveInputNewtypeExt;

    #[test]
    fn inner_type_returns_field() {
        let ast: DeriveInput = parse_quote! {
            struct Newtype(u32);
        };

        let inner_field = ast.inner_type();

        let expected_type: Type = Type::Path(parse_quote!(u32));
        assert_eq!(expected_type, inner_field.ty);
    }

    #[test]
    #[should_panic(expected = "This macro must be used on a newtype struct.\n\
        See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types \
        for more information.")]
    fn inner_type_panics_when_struct_fields_not_unnamed() {
        let ast: DeriveInput = parse_quote! {
            struct Unit;
        };

        ast.inner_type();
    }

    // Named fields are a different `Fields` variant from unit structs, so
    // they get their own coverage.
    #[test]
    #[should_panic(expected = "This macro must be used on a newtype struct.\n\
        See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types \
        for more information.")]
    fn inner_type_panics_when_struct_fields_named() {
        let ast: DeriveInput = parse_quote! {
            struct Named { value: u32 }
        };

        ast.inner_type();
    }

    #[test]
    #[should_panic(expected = "This macro must be used on a newtype struct.\n\
        See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types \
        for more information.")]
    fn inner_type_panics_when_not_a_struct() {
        let ast: DeriveInput = parse_quote! {
            enum NotAStruct { A }
        };

        ast.inner_type();
    }

    #[test]
    #[should_panic(expected = "Newtype struct must only have one field.\n\
        See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types \
        for more information.")]
    fn inner_type_panics_when_struct_has_multiple_fields() {
        let ast: DeriveInput = parse_quote! {
            struct Newtype(u32, u32);
        };

        ast.inner_type();
    }

    // An empty tuple struct has unnamed fields but none of them, so it must
    // hit the field-count panic rather than the "not a newtype" panic.
    #[test]
    #[should_panic(expected = "Newtype struct must only have one field.\n\
        See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types \
        for more information.")]
    fn inner_type_panics_when_struct_has_no_fields() {
        let ast: DeriveInput = parse_quote! {
            struct Empty();
        };

        ast.inner_type();
    }

    #[test]
    fn inner_type_mut_returns_field() {
        let mut ast: DeriveInput = parse_quote! {
            struct Newtype(u32);
        };

        let inner_field = ast.inner_type_mut();

        let expected_type: Type = Type::Path(parse_quote!(u32));
        assert_eq!(expected_type, inner_field.ty);
    }

    #[test]
    fn inner_type_mut_allows_modifying_field() {
        let mut ast: DeriveInput = parse_quote! {
            struct Newtype(u32);
        };

        ast.inner_type_mut().ty = parse_quote!(u64);

        let expected_type: Type = Type::Path(parse_quote!(u64));
        assert_eq!(expected_type, ast.inner_type().ty);
    }

    #[test]
    #[should_panic(expected = "This macro must be used on a newtype struct.\n\
        See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types \
        for more information.")]
    fn inner_type_mut_panics_when_struct_fields_not_unnamed() {
        let mut ast: DeriveInput = parse_quote! {
            struct Unit;
        };

        ast.inner_type_mut();
    }

    #[test]
    #[should_panic(expected = "Newtype struct must only have one field.\n\
        See https://doc.rust-lang.org/book/ch19-04-advanced-types.html#advanced-types \
        for more information.")]
    fn inner_type_mut_panics_when_struct_has_multiple_fields() {
        let mut ast: DeriveInput = parse_quote! {
            struct Newtype(u32, u32);
        };

        ast.inner_type_mut();
    }
}