pliron/builtin/types.rs

1use combine::{
2    Parser, choice,
3    parser::char::{spaces, string},
4};
5use pliron::derive::def_type;
6use pliron_derive::{format_type, type_interface_impl};
7
8use crate::{
9    builtin::type_interfaces::FloatType,
10    context::{Context, Ptr},
11    impl_verify_succ,
12    irfmt::parsers::int_parser,
13    parsable::{Parsable, ParseResult, StateStream},
14    printable::{self, Printable},
15    r#type::{Type, TypeObj, TypePtr},
16    utils::apfloat::{self, GetSemantics, Semantics},
17};
18
/// Signedness of an [`IntegerType`]; it also determines the textual prefix
/// of the type (`si`, `ui` or `i`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Signedness {
    /// Explicitly signed, written `si<width>`.
    Signed,
    /// Explicitly unsigned, written `ui<width>`.
    Unsigned,
    /// No signedness attached, written `i<width>`.
    Signless,
}
25
/// Integer type (`builtin.integer`) with a fixed width and a [`Signedness`].
///
/// Its textual form is the signedness prefix (`si`, `ui` or `i`) immediately
/// followed by the width, e.g. `si64`.
#[def_type("builtin.integer")]
#[derive(Hash, PartialEq, Eq, Debug, Clone)]
pub struct IntegerType {
    /// Width as written after the prefix (e.g. `64` in `si64`).
    width: u32,
    /// Signed, unsigned or signless.
    signedness: Signedness,
}
32
33impl IntegerType {
34    /// Get or create a new integer type.
35    pub fn get(ctx: &mut Context, width: u32, signedness: Signedness) -> TypePtr<Self> {
36        Type::register_instance(IntegerType { width, signedness }, ctx)
37    }
38    /// Get, if it already exists, an integer type.
39    pub fn get_existing(
40        ctx: &Context,
41        width: u32,
42        signedness: Signedness,
43    ) -> Option<TypePtr<Self>> {
44        Type::get_instance(IntegerType { width, signedness }, ctx)
45    }
46
47    /// Get width.
48    pub fn get_width(&self) -> u32 {
49        self.width
50    }
51
52    /// Get signedness.
53    pub fn get_signedness(&self) -> Signedness {
54        self.signedness
55    }
56
57    /// Is Signed?
58    pub fn is_signed(&self) -> bool {
59        matches!(self.signedness, Signedness::Signed)
60    }
61
62    /// Is Unsigned?
63    pub fn is_unsigned(&self) -> bool {
64        matches!(self.signedness, Signedness::Unsigned)
65    }
66
67    /// Is Signless?
68    pub fn is_signless(&self) -> bool {
69        matches!(self.signedness, Signedness::Signless)
70    }
71}
72
73impl Parsable for IntegerType {
74    type Arg = ();
75    type Parsed = TypePtr<Self>;
76    fn parse<'a>(
77        state_stream: &mut StateStream<'a>,
78        _arg: Self::Arg,
79    ) -> ParseResult<'a, Self::Parsed>
80    where
81        Self: Sized,
82    {
83        // Choose b/w si/ui/i ...
84        let choicer = choice((
85            string("si").map(|_| Signedness::Signed),
86            string("ui").map(|_| Signedness::Unsigned),
87            string("i").map(|_| Signedness::Signless),
88        ));
89
90        // followed by an integer.
91        let mut parser = spaces().with(choicer.and(int_parser()));
92        parser
93            .parse_stream(state_stream)
94            .map(|(signedness, width)| IntegerType::get(state_stream.state.ctx, width, signedness))
95            .into()
96    }
97}
98
99impl Printable for IntegerType {
100    fn fmt(
101        &self,
102        _ctx: &Context,
103        _state: &printable::State,
104        f: &mut core::fmt::Formatter<'_>,
105    ) -> core::fmt::Result {
106        match &self.signedness {
107            Signedness::Signed => write!(f, "si{}", self.width)?,
108            Signedness::Unsigned => write!(f, "ui{}", self.width)?,
109            Signedness::Signless => write!(f, "i{}", self.width)?,
110        };
111        Ok(())
112    }
113}
114
115impl_verify_succ!(IntegerType);
116
/// Map from a list of inputs to a list of results.
///
/// See MLIR's [FunctionType](https://mlir.llvm.org/docs/Dialects/Builtin/#functiontype).
///
/// Textually written as `<(inputs) -> (results)>`, each list separated by
/// commas, e.g. `<() -> (builtin.integer si32)>`.
#[def_type("builtin.function")]
#[derive(Hash, PartialEq, Eq, Debug)]
#[format_type(
    "`<` `(` vec($inputs, CharSpace(`,`)) `)` `->` `(`vec($results, CharSpace(`,`)) `)` `>`"
)]
pub struct FunctionType {
    /// Function arguments / inputs.
    inputs: Vec<Ptr<TypeObj>>,
    /// Function results / outputs.
    results: Vec<Ptr<TypeObj>>,
}
132
133impl FunctionType {
134    /// Get or create a new Function type.
135    pub fn get(
136        ctx: &mut Context,
137        inputs: Vec<Ptr<TypeObj>>,
138        results: Vec<Ptr<TypeObj>>,
139    ) -> TypePtr<Self> {
140        Type::register_instance(FunctionType { inputs, results }, ctx)
141    }
142    /// Get, if it already exists, a Function type.
143    pub fn get_existing(
144        ctx: &Context,
145        inputs: Vec<Ptr<TypeObj>>,
146        results: Vec<Ptr<TypeObj>>,
147    ) -> Option<TypePtr<Self>> {
148        Type::get_instance(FunctionType { inputs, results }, ctx)
149    }
150
151    /// Get a reference to the function input / argument types.
152    pub fn get_inputs(&self) -> &Vec<Ptr<TypeObj>> {
153        &self.inputs
154    }
155
156    /// Get a reference to the function result / output types.
157    pub fn get_results(&self) -> &Vec<Ptr<TypeObj>> {
158        &self.results
159    }
160}
161
162impl_verify_succ!(FunctionType);
163
/// The `builtin.unit` type. It has no fields, so all of its values are equal.
// Printing/parsing come from `format_type` (presumably the default,
// field-less syntax — confirm against the macro).
#[def_type("builtin.unit")]
#[format_type]
#[derive(Hash, PartialEq, Eq, Debug)]
pub struct UnitType;
168
169impl UnitType {
170    /// Get or create a new unit type.
171    pub fn get(ctx: &mut Context) -> TypePtr<Self> {
172        Type::register_instance(Self {}, ctx)
173    }
174}
175
176impl_verify_succ!(UnitType);
177
/// Single-precision floating-point type (`builtin.fp32`); its semantics are
/// those of `apfloat::Single`.
///
/// This is a singleton: see `FP32Type::register_and_instantiate` and
/// `FP32Type::get`.
#[def_type("builtin.fp32")]
#[format_type]
#[derive(Hash, PartialEq, Eq, Debug)]
pub struct FP32Type;
impl_verify_succ!(FP32Type);
// Expose the floating-point semantics of this type through `FloatType`.
#[type_interface_impl]
impl FloatType for FP32Type {
    fn get_semantics(&self) -> Semantics {
        apfloat::Single::get_semantics()
    }
}
189
190impl FP32Type {
191    /// Register type in dialect and instantiate the singleton instance.
192    pub fn register_and_instantiate(ctx: &mut Context) {
193        Self::register_type_in_dialect(ctx, Self::parser_fn);
194        Type::register_instance(Self {}, ctx);
195    }
196
197    /// Get the singleton fp32 type.
198    pub fn get(ctx: &Context) -> TypePtr<Self> {
199        Type::get_instance(Self {}, ctx).expect("FP32Type singleton not instantiated")
200    }
201}
202
/// Double-precision floating-point type (`builtin.fp64`); its semantics are
/// those of `apfloat::Double`.
///
/// This is a singleton: see `FP64Type::register_and_instantiate` and
/// `FP64Type::get`.
#[def_type("builtin.fp64")]
#[format_type]
#[derive(Hash, PartialEq, Eq, Debug)]
pub struct FP64Type;
impl_verify_succ!(FP64Type);
// Expose the floating-point semantics of this type through `FloatType`.
#[type_interface_impl]
impl FloatType for FP64Type {
    fn get_semantics(&self) -> Semantics {
        apfloat::Double::get_semantics()
    }
}
214
215impl FP64Type {
216    /// Register type in dialect and instantiate the singleton instance.
217    pub fn register_and_instantiate(ctx: &mut Context) {
218        Self::register_type_in_dialect(ctx, Self::parser_fn);
219        Type::register_instance(Self {}, ctx);
220    }
221
222    /// Get or create a new fp64 type.
223    pub fn get(ctx: &Context) -> TypePtr<Self> {
224        Type::get_instance(Self {}, ctx).expect("FP64Type singleton not instantiated")
225    }
226}
227
/// Register the builtin types defined in this module with `ctx`.
///
/// `IntegerType`, `FunctionType` and `UnitType` are only registered in the
/// dialect; their instances are created on demand by their `get` functions,
/// which take `&mut Context`. `FP32Type` and `FP64Type` are additionally
/// instantiated here, because their `get` functions take `&Context`. Those
/// functions can only look the instance up and panic if it is missing.
pub fn register(ctx: &mut Context) {
    IntegerType::register_type_in_dialect(ctx, IntegerType::parser_fn);
    FunctionType::register_type_in_dialect(ctx, FunctionType::parser_fn);
    UnitType::register_type_in_dialect(ctx, UnitType::parser_fn);

    // Float types are singletons: register and intern them eagerly.
    FP32Type::register_and_instantiate(ctx);
    FP64Type::register_and_instantiate(ctx);
}
236
#[cfg(test)]
mod tests {
    use combine::{Parser, eof};
    use expect_test::expect;

    use super::{FP32Type, FP64Type, FunctionType};
    use crate::{
        builtin::{
            self,
            types::{IntegerType, Signedness},
        },
        context::Context,
        location,
        parsable::{self, Parsable, state_stream_from_iterator},
        r#type::Type,
    };

    /// Interning: equal (width, signedness) pairs share one type object,
    /// distinct pairs do not.
    #[test]
    fn test_integer_types() {
        let mut ctx = Context::new();

        let int32_1_ptr = IntegerType::get(&mut ctx, 32, Signedness::Signed);
        let int32_2_ptr = IntegerType::get(&mut ctx, 32, Signedness::Signed);
        let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signed);
        let uint32_ptr = IntegerType::get(&mut ctx, 32, Signedness::Unsigned);

        assert!(int32_1_ptr.deref(&ctx).hash_type() == int32_2_ptr.deref(&ctx).hash_type());
        assert!(int32_1_ptr.deref(&ctx).hash_type() != int64_ptr.deref(&ctx).hash_type());
        assert!(int32_1_ptr.deref(&ctx).hash_type() != uint32_ptr.deref(&ctx).hash_type());
        assert!(int32_1_ptr == int32_2_ptr);
        assert!(int32_1_ptr != int64_ptr);
        assert!(int32_1_ptr != uint32_ptr);

        assert!(int32_1_ptr.deref(&ctx).get_self_ptr(&ctx) == int32_1_ptr.into());
        assert!(int32_2_ptr.deref(&ctx).get_self_ptr(&ctx) == int32_1_ptr.into());
        assert!(int32_2_ptr.deref(&ctx).get_self_ptr(&ctx) == int32_2_ptr.into());
        assert!(int64_ptr.deref(&ctx).get_self_ptr(&ctx) == int64_ptr.into());
        assert!(uint32_ptr.deref(&ctx).get_self_ptr(&ctx) == uint32_ptr.into());
        assert!(uint32_ptr.deref(&ctx).get_self_ptr(&ctx) != int32_1_ptr.into());
        assert!(uint32_ptr.deref(&ctx).get_self_ptr(&ctx) != int64_ptr.into());
    }

    /// Function types keep their inputs and results as given.
    #[test]
    fn test_function_types() {
        let mut ctx = Context::new();
        let int32_1_ptr = IntegerType::get(&mut ctx, 32, Signedness::Signed);
        let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signed);

        let ft_ref = FunctionType::get(&mut ctx, vec![int32_1_ptr.into()], vec![int64_ptr.into()])
            .deref(&ctx);
        assert!(
            ft_ref.get_inputs()[0] == int32_1_ptr.into()
                && ft_ref.get_results()[0] == int64_ptr.into()
        );
    }

    #[test]
    fn test_integer_parsing() {
        let mut ctx = Context::new();
        let state_stream = state_stream_from_iterator(
            "si64".chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );

        let res = IntegerType::parser(())
            .and(eof())
            .parse(state_stream)
            .unwrap()
            .0
            .0;
        assert!(res == IntegerType::get_existing(&ctx, 64, Signedness::Signed).unwrap())
    }

    /// Each of the three signedness prefixes must parse to the matching
    /// interned type (covers the `ui` and bare `i` branches too).
    #[test]
    fn test_integer_parsing_all_prefixes() {
        let cases = [
            ("si64", 64, Signedness::Signed),
            ("ui8", 8, Signedness::Unsigned),
            ("i1", 1, Signedness::Signless),
        ];
        for (input, width, signedness) in cases {
            let mut ctx = Context::new();
            let state_stream = state_stream_from_iterator(
                input.chars(),
                parsable::State::new(&mut ctx, location::Source::InMemory),
            );

            let res = IntegerType::parser(())
                .and(eof())
                .parse(state_stream)
                .unwrap()
                .0
                .0;
            assert!(res == IntegerType::get_existing(&ctx, width, signedness).unwrap());
        }
    }

    #[test]
    fn test_integer_parsing_errs() {
        let mut ctx = Context::new();
        let a = "asi64".to_string();
        let state_stream = state_stream_from_iterator(
            a.chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );

        let res = IntegerType::parser(()).parse(state_stream);
        let err_msg = format!("{}", res.err().unwrap());

        let expected_err_msg = expect![[r#"
            Parse error at line: 1, column: 1
            Unexpected `a`
            Expected whitespaces, si, ui or i
        "#]];
        expected_err_msg.assert_eq(&err_msg);
    }

    #[test]
    fn test_fntype_parsing() {
        let mut ctx = Context::new();
        builtin::register(&mut ctx);

        let si32 = IntegerType::get(&mut ctx, 32, Signedness::Signed);

        let state_stream = state_stream_from_iterator(
            "<() -> (builtin.integer si32)>".chars(),
            parsable::State::new(&mut ctx, location::Source::InMemory),
        );

        let res = FunctionType::parser(())
            .and(eof())
            .parse(state_stream)
            .unwrap()
            .0
            .0;
        assert!(res == FunctionType::get_existing(&ctx, vec![], vec![si32.into()]).unwrap())
    }

    /// After registration, the float singletons are available through `get`,
    /// and repeated lookups return the same pointer.
    #[test]
    fn test_float_singletons() {
        let mut ctx = Context::new();
        builtin::register(&mut ctx);

        assert!(FP32Type::get(&ctx) == FP32Type::get(&ctx));
        assert!(FP64Type::get(&ctx) == FP64Type::get(&ctx));
    }
}