air_types/
types.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use num::BigInt;
use std::fmt::{Debug, Display, Formatter};
use std::num::{NonZero, NonZeroU8};

/// The `void` type.
pub const VOID: Type = Type::Void;
/// The 1-bit integer type (`i1`), which also serves as the boolean type.
pub const BOOL: Type = Type::int(1);
/// The 8-bit integer type (`i8`).
pub const I8: Type = Type::int(8);
/// The 16-bit integer type (`i16`).
pub const I16: Type = Type::int(16);
/// The 32-bit integer type (`i32`).
pub const I32: Type = Type::int(32);
/// The 64-bit integer type (`i64`).
pub const I64: Type = Type::int(64);

/// Every predefined non-void type, ordered by increasing bit width.
///
/// Other integer widths can still be built with [`Type::int`]; this list only covers the
/// named constants above.
pub const TYPES: [Type; 5] = [BOOL, I8, I16, I32, I64];

/// A value type: either `void` or an integer of some fixed, non-zero bit width.
///
/// The width is stored as a [`NonZeroU8`], so a zero-width integer cannot be represented.
/// A 1-bit integer doubles as the boolean type.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// The `void` type.
    Void,
    /// An integer type; the payload is its width in bits.
    Int(NonZeroU8),
}

impl Type {
    /// Return the bit width of a value
    pub fn bit_width(&self) -> u32 {
        match self {
            Type::Void => 0,
            Type::Int(w) => w.get().into(),
        }
    }
}

impl Type {
    /// Creates an integer type that is `n` bits wide.
    ///
    /// This is a `const fn`, so it can be used to initialize constants.
    ///
    /// # Panics
    ///
    /// Panics with `"Cannot have i0"` if `n` is zero, because a zero-width integer type is not
    /// representable. When evaluated in a `const` initializer this becomes a compile-time error.
    pub const fn int(n: u8) -> Self {
        // `NonZero::new` is a const fn that performs the zero check itself, so no `unsafe`
        // is required to build the width.
        match NonZero::new(n) {
            Some(width) => Self::Int(width),
            None => panic!("Cannot have i0"),
        }
    }

    /// Returns `true` if `self` is [`Type::Void`].
    #[inline]
    pub const fn is_void(&self) -> bool {
        matches!(self, Self::Void)
    }

    /// Returns `true` if `self` is the 1-bit integer type (`i1`), i.e. the boolean type.
    #[inline]
    pub const fn is_bool(&self) -> bool {
        matches!(self, Self::Int(w) if w.get() == 1)
    }

    /// Returns the largest unsigned value representable in this type, `2^width - 1`.
    ///
    /// For example, the 8-bit type yields `255` and the 1-bit type yields `1`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is [`Type::Void`], which has no maximum value.
    #[inline]
    pub fn max_val(&self) -> BigInt {
        match self {
            Type::Void => panic!("Cannot get max value of void"),
            Type::Int(w) => (BigInt::from(1) << w.get()) - 1,
        }
    }
}

impl Display for Type {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Type::Void => write!(f, "void"),
            Type::Int(w) => write!(f, "i{}", w),
        }
    }
}

impl Debug for Type {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Type::Void => write!(f, "VOID"),
            Type::Int(w) => write!(f, "I{}", w),
        }
    }
}

impl TryFrom<&str> for Type {
    type Error = ();

    /// Parses the textual name of a predefined type.
    ///
    /// Accepts `void`, `bool` (an alias for `i1`), and `i1`, `i8`, `i16`, `i32`, `i64`.
    /// Any other input, including other integer widths, yields `Err(())`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let parsed = match value {
            "void" => Self::Void,
            "bool" | "i1" => BOOL,
            "i8" => I8,
            "i16" => I16,
            "i32" => I32,
            "i64" => I64,
            _ => return Err(()),
        };
        Ok(parsed)
    }
}

#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Type {
    /// Picks one of the five predefined integer types using an index drawn from `u`.
    /// `Void` is never produced.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let index = u.int_in_range(0..=4)?;
        Ok(match index {
            0 => BOOL,
            1 => I8,
            2 => I16,
            3 => I32,
            4 => I64,
            _ => unreachable!(),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every predefined type, including `void`, in the order the expectations below use.
    const ALL: [Type; 6] = [VOID, BOOL, I8, I16, I32, I64];

    /// Renders each entry of `ALL` with `render` and joins the results with ", ".
    fn render_all(render: impl Fn(&Type) -> String) -> String {
        ALL.iter().map(render).collect::<Vec<_>>().join(", ")
    }

    #[test]
    fn test_display() {
        assert_eq!(render_all(|ty| ty.to_string()), "void, i1, i8, i16, i32, i64");
    }

    #[test]
    fn test_debug() {
        assert_eq!(render_all(|ty| format!("{ty:?}")), "VOID, I1, I8, I16, I32, I64");
    }

    #[test]
    fn test_parse_round_trips_display() {
        for ty in ALL {
            assert_eq!(Type::try_from(ty.to_string().as_str()), Ok(ty));
        }
        assert_eq!(Type::try_from("bool"), Ok(BOOL));
        // Names outside the predefined set are rejected.
        assert_eq!(Type::try_from("i3"), Err(()));
        assert_eq!(Type::try_from(""), Err(()));
    }

    #[test]
    fn test_bit_width_and_max_val() {
        let widths: Vec<u32> = ALL.iter().map(Type::bit_width).collect();
        assert_eq!(widths, [0, 1, 8, 16, 32, 64]);
        assert_eq!(BOOL.max_val(), BigInt::from(1));
        assert_eq!(I8.max_val(), BigInt::from(255));
        assert_eq!(I64.max_val(), BigInt::from(u64::MAX));
    }

    #[test]
    fn test_predicates() {
        assert!(VOID.is_void() && !VOID.is_bool());
        assert!(BOOL.is_bool() && !BOOL.is_void());
        assert!(TYPES.iter().all(|ty| !ty.is_void()));
    }

    #[test]
    #[should_panic(expected = "Cannot have i0")]
    fn test_zero_width_panics() {
        let _ = Type::int(0);
    }
}