cranelift_codegen_meta/cdsl/
typevar.rs

1use std::cell::RefCell;
2use std::collections::BTreeSet;
3use std::fmt;
4use std::hash;
5use std::ops;
6use std::rc::Rc;
7
8use crate::cdsl::types::{LaneType, ReferenceType, ValueType};
9
/// Largest SIMD lane count a type set may contain (inclusive).
const MAX_LANES: u16 = 256;
/// Widest integer lane type, in bits (inclusive).
const MAX_BITS: u16 = 128;
/// Widest floating-point lane type, in bits (inclusive).
const MAX_FLOAT_BITS: u16 = 128;
13
14/// Type variables can be used in place of concrete types when defining
15/// instructions. This makes the instructions *polymorphic*.
16///
17/// A type variable is restricted to vary over a subset of the value types.
18/// This subset is specified by a set of flags that control the permitted base
19/// types and whether the type variable can assume scalar or vector types, or
20/// both.
21#[derive(Debug)]
22pub(crate) struct TypeVarContent {
23    /// Short name of type variable used in instruction descriptions.
24    pub name: String,
25
26    /// Documentation string.
27    pub doc: String,
28
29    /// Type set associated to the type variable.
30    /// This field must remain private; use `get_typeset()` or `get_raw_typeset()` to get the
31    /// information you want.
32    type_set: TypeSet,
33
34    pub base: Option<TypeVarParent>,
35}
36
37#[derive(Clone, Debug)]
38pub(crate) struct TypeVar {
39    content: Rc<RefCell<TypeVarContent>>,
40}
41
42impl TypeVar {
43    pub fn new(name: impl Into<String>, doc: impl Into<String>, type_set: TypeSet) -> Self {
44        Self {
45            content: Rc::new(RefCell::new(TypeVarContent {
46                name: name.into(),
47                doc: doc.into(),
48                type_set,
49                base: None,
50            })),
51        }
52    }
53
54    pub fn new_singleton(value_type: ValueType) -> Self {
55        let (name, doc) = (value_type.to_string(), value_type.doc());
56        let mut builder = TypeSetBuilder::new();
57
58        let (scalar_type, num_lanes) = match value_type {
59            ValueType::Reference(ReferenceType(reference_type)) => {
60                let bits = reference_type as RangeBound;
61                return TypeVar::new(name, doc, builder.refs(bits..bits).build());
62            }
63            ValueType::Lane(lane_type) => (lane_type, 1),
64            ValueType::Vector(vec_type) => {
65                (vec_type.lane_type(), vec_type.lane_count() as RangeBound)
66            }
67            ValueType::DynamicVector(vec_type) => (
68                vec_type.lane_type(),
69                vec_type.minimum_lane_count() as RangeBound,
70            ),
71        };
72
73        builder = builder.simd_lanes(num_lanes..num_lanes);
74
75        // Only generate dynamic types for multiple lanes.
76        if num_lanes > 1 {
77            builder = builder.dynamic_simd_lanes(num_lanes..num_lanes);
78        }
79
80        let builder = match scalar_type {
81            LaneType::Int(int_type) => {
82                let bits = int_type as RangeBound;
83                builder.ints(bits..bits)
84            }
85            LaneType::Float(float_type) => {
86                let bits = float_type as RangeBound;
87                builder.floats(bits..bits)
88            }
89        };
90        TypeVar::new(name, doc, builder.build())
91    }
92
93    /// Get a fresh copy of self, named after `name`. Can only be called on non-derived typevars.
94    pub fn copy_from(other: &TypeVar, name: String) -> TypeVar {
95        assert!(
96            other.base.is_none(),
97            "copy_from() can only be called on non-derived type variables"
98        );
99        TypeVar {
100            content: Rc::new(RefCell::new(TypeVarContent {
101                name,
102                doc: "".into(),
103                type_set: other.type_set.clone(),
104                base: None,
105            })),
106        }
107    }
108
109    /// Returns the typeset for this TV. If the TV is derived, computes it recursively from the
110    /// derived function and the base's typeset.
111    /// Note this can't be done non-lazily in the constructor, because the TypeSet of the base may
112    /// change over time.
113    pub fn get_typeset(&self) -> TypeSet {
114        match &self.base {
115            Some(base) => base.type_var.get_typeset().image(base.derived_func),
116            None => self.type_set.clone(),
117        }
118    }
119
120    /// Returns this typevar's type set, assuming this type var has no parent.
121    pub fn get_raw_typeset(&self) -> &TypeSet {
122        assert_eq!(self.type_set, self.get_typeset());
123        &self.type_set
124    }
125
126    /// If the associated typeset has a single type return it. Otherwise return None.
127    pub fn singleton_type(&self) -> Option<ValueType> {
128        let type_set = self.get_typeset();
129        if type_set.size() == 1 {
130            Some(type_set.get_singleton())
131        } else {
132            None
133        }
134    }
135
136    /// Get the free type variable controlling this one.
137    pub fn free_typevar(&self) -> Option<TypeVar> {
138        match &self.base {
139            Some(base) => base.type_var.free_typevar(),
140            None => {
141                match self.singleton_type() {
142                    // A singleton type isn't a proper free variable.
143                    Some(_) => None,
144                    None => Some(self.clone()),
145                }
146            }
147        }
148    }
149
150    /// Create a type variable that is a function of another.
151    pub fn derived(&self, derived_func: DerivedFunc) -> TypeVar {
152        let ts = self.get_typeset();
153
154        // Safety checks to avoid over/underflows.
155        match derived_func {
156            DerivedFunc::HalfWidth => {
157                assert!(
158                    ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
159                    "can't halve all integer types"
160                );
161                assert!(
162                    ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 16,
163                    "can't halve all float types"
164                );
165            }
166            DerivedFunc::DoubleWidth => {
167                assert!(
168                    ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
169                    "can't double all integer types"
170                );
171                assert!(
172                    ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
173                    "can't double all float types"
174                );
175            }
176            DerivedFunc::SplitLanes => {
177                assert!(
178                    ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
179                    "can't halve all integer types"
180                );
181                assert!(
182                    ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 16,
183                    "can't halve all float types"
184                );
185                assert!(
186                    *ts.lanes.iter().max().unwrap() < MAX_LANES,
187                    "can't double 256 lanes"
188                );
189            }
190            DerivedFunc::MergeLanes => {
191                assert!(
192                    ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
193                    "can't double all integer types"
194                );
195                assert!(
196                    ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
197                    "can't double all float types"
198                );
199                assert!(
200                    *ts.lanes.iter().min().unwrap() > 1,
201                    "can't halve a scalar type"
202                );
203            }
204            DerivedFunc::Narrower => {
205                assert_eq!(
206                    *ts.lanes.iter().max().unwrap(),
207                    1,
208                    "The `narrower` constraint does not apply to vectors"
209                );
210                assert!(
211                    (!ts.ints.is_empty() || !ts.floats.is_empty())
212                        && ts.refs.is_empty()
213                        && ts.dynamic_lanes.is_empty(),
214                    "The `narrower` constraint only applies to scalar ints or floats"
215                );
216            }
217            DerivedFunc::Wider => {
218                assert_eq!(
219                    *ts.lanes.iter().max().unwrap(),
220                    1,
221                    "The `wider` constraint does not apply to vectors"
222                );
223                assert!(
224                    (!ts.ints.is_empty() || !ts.floats.is_empty())
225                        && ts.refs.is_empty()
226                        && ts.dynamic_lanes.is_empty(),
227                    "The `wider` constraint only applies to scalar ints or floats"
228                );
229            }
230            DerivedFunc::LaneOf | DerivedFunc::AsTruthy | DerivedFunc::DynamicToVector => {
231                /* no particular assertions */
232            }
233        }
234
235        TypeVar {
236            content: Rc::new(RefCell::new(TypeVarContent {
237                name: format!("{}({})", derived_func.name(), self.name),
238                doc: "".into(),
239                type_set: ts,
240                base: Some(TypeVarParent {
241                    type_var: self.clone(),
242                    derived_func,
243                }),
244            })),
245        }
246    }
247
248    pub fn lane_of(&self) -> TypeVar {
249        self.derived(DerivedFunc::LaneOf)
250    }
251    pub fn as_truthy(&self) -> TypeVar {
252        self.derived(DerivedFunc::AsTruthy)
253    }
254    pub fn half_width(&self) -> TypeVar {
255        self.derived(DerivedFunc::HalfWidth)
256    }
257    pub fn double_width(&self) -> TypeVar {
258        self.derived(DerivedFunc::DoubleWidth)
259    }
260    pub fn split_lanes(&self) -> TypeVar {
261        self.derived(DerivedFunc::SplitLanes)
262    }
263    pub fn merge_lanes(&self) -> TypeVar {
264        self.derived(DerivedFunc::MergeLanes)
265    }
266    pub fn dynamic_to_vector(&self) -> TypeVar {
267        self.derived(DerivedFunc::DynamicToVector)
268    }
269
270    /// Make a new [TypeVar] that includes all types narrower than self.
271    pub fn narrower(&self) -> TypeVar {
272        self.derived(DerivedFunc::Narrower)
273    }
274
275    /// Make a new [TypeVar] that includes all types wider than self.
276    pub fn wider(&self) -> TypeVar {
277        self.derived(DerivedFunc::Wider)
278    }
279}
280
281impl From<&TypeVar> for TypeVar {
282    fn from(type_var: &TypeVar) -> Self {
283        type_var.clone()
284    }
285}
286impl From<ValueType> for TypeVar {
287    fn from(value_type: ValueType) -> Self {
288        TypeVar::new_singleton(value_type)
289    }
290}
291
292// Hash TypeVars by pointers.
293// There might be a better way to do this, but since TypeVar's content (namely TypeSet) can be
294// mutated, it makes sense to use pointer equality/hashing here.
295impl hash::Hash for TypeVar {
296    fn hash<H: hash::Hasher>(&self, h: &mut H) {
297        match &self.base {
298            Some(base) => {
299                base.type_var.hash(h);
300                base.derived_func.hash(h);
301            }
302            None => {
303                (&**self as *const TypeVarContent).hash(h);
304            }
305        }
306    }
307}
308
309impl PartialEq for TypeVar {
310    fn eq(&self, other: &TypeVar) -> bool {
311        match (&self.base, &other.base) {
312            (Some(base1), Some(base2)) => {
313                base1.type_var.eq(&base2.type_var) && base1.derived_func == base2.derived_func
314            }
315            (None, None) => Rc::ptr_eq(&self.content, &other.content),
316            _ => false,
317        }
318    }
319}
320
321// Allow TypeVar as map keys, based on pointer equality (see also above PartialEq impl).
322impl Eq for TypeVar {}
323
324impl ops::Deref for TypeVar {
325    type Target = TypeVarContent;
326    fn deref(&self) -> &Self::Target {
327        unsafe { self.content.as_ptr().as_ref().unwrap() }
328    }
329}
330
/// A function mapping one type variable's type set to another's (see `TypeSet::image`).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) enum DerivedFunc {
    /// Replaces the lane counts with a single lane (the scalar lane type).
    LaneOf,
    /// Maps to integer types usable as a truth value (see `TypeSet::as_truthy`).
    AsTruthy,
    /// Halves every int and float width.
    HalfWidth,
    /// Doubles every int and float width.
    DoubleWidth,
    /// Halves lane widths and doubles lane counts.
    SplitLanes,
    /// Doubles lane widths and halves lane counts.
    MergeLanes,
    /// Turns dynamic lane counts into static vector lane counts.
    DynamicToVector,
    /// Leaves the type set unchanged; used for the `narrower` constraint on scalars.
    Narrower,
    /// Leaves the type set unchanged; used for the `wider` constraint on scalars.
    Wider,
}
343
344impl DerivedFunc {
345    pub fn name(self) -> &'static str {
346        match self {
347            DerivedFunc::LaneOf => "lane_of",
348            DerivedFunc::AsTruthy => "as_truthy",
349            DerivedFunc::HalfWidth => "half_width",
350            DerivedFunc::DoubleWidth => "double_width",
351            DerivedFunc::SplitLanes => "split_lanes",
352            DerivedFunc::MergeLanes => "merge_lanes",
353            DerivedFunc::DynamicToVector => "dynamic_to_vector",
354            DerivedFunc::Narrower => "narrower",
355            DerivedFunc::Wider => "wider",
356        }
357    }
358}
359
360#[derive(Debug, Hash)]
361pub(crate) struct TypeVarParent {
362    pub type_var: TypeVar,
363    pub derived_func: DerivedFunc,
364}
365
/// Integer type used for lane counts and bit widths.
type RangeBound = u16;

/// A pair of power-of-two bounds.
///
/// Although this is a half-open `std::ops::Range`, the number sets built from it (see
/// `range_to_set`) include *both* endpoints.
type Range = ops::Range<RangeBound>;

/// An ordered set of lane counts or bit widths.
type NumSet = BTreeSet<RangeBound>;

/// Builds a [`NumSet`] from a comma-separated list of values (a trailing comma is allowed).
macro_rules! num_set {
    ($($expr:expr),* $(,)?) => {
        NumSet::from([$($expr),*])
    };
}

/// A set of value types.
///
/// Arbitrary subsets of types are not allowed. Instead, a type set is parametrized by a few sets
/// of numbers, normally powers of two as produced by `TypeSetBuilder`:
/// - `lanes`: the permitted vector lane counts, where 1 indicates a scalar type;
/// - `dynamic_lanes`: the permitted lane counts of dynamic vector types;
/// - `ints`: the permitted integer widths, in bits;
/// - `floats`: the permitted floating-point widths, in bits;
/// - `refs`: the permitted reference widths, in bits.
///
/// The concrete types are formed by combining each lane count with each lane type; see
/// `TypeSet::concrete_types()`.
///
/// Type sets implement `Eq` and `Hash`, so they can be used as map keys.
#[derive(Clone, PartialEq, Eq, Hash)]
pub(crate) struct TypeSet {
    pub lanes: NumSet,
    pub dynamic_lanes: NumSet,
    pub ints: NumSet,
    pub floats: NumSet,
    pub refs: NumSet,
}
399
400impl TypeSet {
401    fn new(
402        lanes: NumSet,
403        dynamic_lanes: NumSet,
404        ints: NumSet,
405        floats: NumSet,
406        refs: NumSet,
407    ) -> Self {
408        Self {
409            lanes,
410            dynamic_lanes,
411            ints,
412            floats,
413            refs,
414        }
415    }
416
417    /// Return the number of concrete types represented by this typeset.
418    pub fn size(&self) -> usize {
419        self.lanes.len() * (self.ints.len() + self.floats.len() + self.refs.len())
420            + self.dynamic_lanes.len() * (self.ints.len() + self.floats.len() + self.refs.len())
421    }
422
423    /// Return the image of self across the derived function func.
424    fn image(&self, derived_func: DerivedFunc) -> TypeSet {
425        match derived_func {
426            DerivedFunc::LaneOf => self.lane_of(),
427            DerivedFunc::AsTruthy => self.as_truthy(),
428            DerivedFunc::HalfWidth => self.half_width(),
429            DerivedFunc::DoubleWidth => self.double_width(),
430            DerivedFunc::SplitLanes => self.half_width().double_vector(),
431            DerivedFunc::MergeLanes => self.double_width().half_vector(),
432            DerivedFunc::DynamicToVector => self.dynamic_to_vector(),
433            DerivedFunc::Narrower => self.clone(),
434            DerivedFunc::Wider => self.clone(),
435        }
436    }
437
438    /// Return a TypeSet describing the image of self across lane_of.
439    fn lane_of(&self) -> TypeSet {
440        let mut copy = self.clone();
441        copy.lanes = num_set![1];
442        copy
443    }
444
445    /// Return a TypeSet describing the image of self across as_truthy.
446    fn as_truthy(&self) -> TypeSet {
447        let mut copy = self.clone();
448
449        // If this type set represents a scalar, `as_truthy` produces an I8, otherwise it returns a
450        // vector of the same number of lanes, whose elements are integers of the same width. For
451        // example, F32X4 gets turned into I32X4, while I32 gets turned into I8.
452        if self.lanes.len() == 1 && self.lanes.contains(&1) {
453            copy.ints = NumSet::from([8]);
454        } else {
455            copy.ints.extend(&self.floats)
456        }
457
458        copy.floats = NumSet::new();
459        copy.refs = NumSet::new();
460        copy
461    }
462
463    /// Return a TypeSet describing the image of self across halfwidth.
464    fn half_width(&self) -> TypeSet {
465        let mut copy = self.clone();
466        copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x > 8).map(|&x| x / 2));
467        copy.floats = NumSet::from_iter(self.floats.iter().filter(|&&x| x > 16).map(|&x| x / 2));
468        copy
469    }
470
471    /// Return a TypeSet describing the image of self across doublewidth.
472    fn double_width(&self) -> TypeSet {
473        let mut copy = self.clone();
474        copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x < MAX_BITS).map(|&x| x * 2));
475        copy.floats = NumSet::from_iter(
476            self.floats
477                .iter()
478                .filter(|&&x| x < MAX_FLOAT_BITS)
479                .map(|&x| x * 2),
480        );
481        copy
482    }
483
484    /// Return a TypeSet describing the image of self across halfvector.
485    fn half_vector(&self) -> TypeSet {
486        let mut copy = self.clone();
487        copy.lanes = NumSet::from_iter(self.lanes.iter().filter(|&&x| x > 1).map(|&x| x / 2));
488        copy
489    }
490
491    /// Return a TypeSet describing the image of self across doublevector.
492    fn double_vector(&self) -> TypeSet {
493        let mut copy = self.clone();
494        copy.lanes = NumSet::from_iter(
495            self.lanes
496                .iter()
497                .filter(|&&x| x < MAX_LANES)
498                .map(|&x| x * 2),
499        );
500        copy
501    }
502
503    fn dynamic_to_vector(&self) -> TypeSet {
504        let mut copy = self.clone();
505        copy.lanes = NumSet::from_iter(
506            self.dynamic_lanes
507                .iter()
508                .filter(|&&x| x < MAX_LANES)
509                .copied(),
510        );
511        copy.dynamic_lanes = NumSet::new();
512        copy
513    }
514
515    fn concrete_types(&self) -> Vec<ValueType> {
516        let mut ret = Vec::new();
517        for &num_lanes in &self.lanes {
518            for &bits in &self.ints {
519                ret.push(LaneType::int_from_bits(bits).by(num_lanes));
520            }
521            for &bits in &self.floats {
522                ret.push(LaneType::float_from_bits(bits).by(num_lanes));
523            }
524            for &bits in &self.refs {
525                ret.push(ReferenceType::ref_from_bits(bits).into());
526            }
527        }
528        for &num_lanes in &self.dynamic_lanes {
529            for &bits in &self.ints {
530                ret.push(LaneType::int_from_bits(bits).to_dynamic(num_lanes));
531            }
532            for &bits in &self.floats {
533                ret.push(LaneType::float_from_bits(bits).to_dynamic(num_lanes));
534            }
535        }
536        ret
537    }
538
539    /// Return the singleton type represented by self. Can only call on typesets containing 1 type.
540    fn get_singleton(&self) -> ValueType {
541        let mut types = self.concrete_types();
542        assert_eq!(types.len(), 1);
543        types.remove(0)
544    }
545}
546
547impl fmt::Debug for TypeSet {
548    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
549        write!(fmt, "TypeSet(")?;
550
551        let mut subsets = Vec::new();
552        if !self.lanes.is_empty() {
553            subsets.push(format!(
554                "lanes={{{}}}",
555                Vec::from_iter(self.lanes.iter().map(|x| x.to_string())).join(", ")
556            ));
557        }
558        if !self.dynamic_lanes.is_empty() {
559            subsets.push(format!(
560                "dynamic_lanes={{{}}}",
561                Vec::from_iter(self.dynamic_lanes.iter().map(|x| x.to_string())).join(", ")
562            ));
563        }
564        if !self.ints.is_empty() {
565            subsets.push(format!(
566                "ints={{{}}}",
567                Vec::from_iter(self.ints.iter().map(|x| x.to_string())).join(", ")
568            ));
569        }
570        if !self.floats.is_empty() {
571            subsets.push(format!(
572                "floats={{{}}}",
573                Vec::from_iter(self.floats.iter().map(|x| x.to_string())).join(", ")
574            ));
575        }
576        if !self.refs.is_empty() {
577            subsets.push(format!(
578                "refs={{{}}}",
579                Vec::from_iter(self.refs.iter().map(|x| x.to_string())).join(", ")
580            ));
581        }
582
583        write!(fmt, "{})", subsets.join(", "))?;
584        Ok(())
585    }
586}
587
588pub(crate) struct TypeSetBuilder {
589    ints: Interval,
590    floats: Interval,
591    refs: Interval,
592    includes_scalars: bool,
593    simd_lanes: Interval,
594    dynamic_simd_lanes: Interval,
595}
596
597impl TypeSetBuilder {
598    pub fn new() -> Self {
599        Self {
600            ints: Interval::None,
601            floats: Interval::None,
602            refs: Interval::None,
603            includes_scalars: true,
604            simd_lanes: Interval::None,
605            dynamic_simd_lanes: Interval::None,
606        }
607    }
608
609    pub fn ints(mut self, interval: impl Into<Interval>) -> Self {
610        assert!(self.ints == Interval::None);
611        self.ints = interval.into();
612        self
613    }
614    pub fn floats(mut self, interval: impl Into<Interval>) -> Self {
615        assert!(self.floats == Interval::None);
616        self.floats = interval.into();
617        self
618    }
619    pub fn refs(mut self, interval: impl Into<Interval>) -> Self {
620        assert!(self.refs == Interval::None);
621        self.refs = interval.into();
622        self
623    }
624    pub fn includes_scalars(mut self, includes_scalars: bool) -> Self {
625        self.includes_scalars = includes_scalars;
626        self
627    }
628    pub fn simd_lanes(mut self, interval: impl Into<Interval>) -> Self {
629        assert!(self.simd_lanes == Interval::None);
630        self.simd_lanes = interval.into();
631        self
632    }
633    pub fn dynamic_simd_lanes(mut self, interval: impl Into<Interval>) -> Self {
634        assert!(self.dynamic_simd_lanes == Interval::None);
635        self.dynamic_simd_lanes = interval.into();
636        self
637    }
638
639    pub fn build(self) -> TypeSet {
640        let min_lanes = if self.includes_scalars { 1 } else { 2 };
641
642        TypeSet::new(
643            range_to_set(self.simd_lanes.to_range(min_lanes..MAX_LANES, Some(1))),
644            range_to_set(self.dynamic_simd_lanes.to_range(2..MAX_LANES, None)),
645            range_to_set(self.ints.to_range(8..MAX_BITS, None)),
646            range_to_set(self.floats.to_range(16..MAX_FLOAT_BITS, None)),
647            range_to_set(self.refs.to_range(32..64, None)),
648        )
649    }
650}
651
652#[derive(PartialEq)]
653pub(crate) enum Interval {
654    None,
655    All,
656    Range(Range),
657}
658
659impl Interval {
660    fn to_range(&self, full_range: Range, default: Option<RangeBound>) -> Option<Range> {
661        match self {
662            Interval::None => default.map(|default_val| default_val..default_val),
663
664            Interval::All => Some(full_range),
665
666            Interval::Range(range) => {
667                let (low, high) = (range.start, range.end);
668                assert!(low.is_power_of_two());
669                assert!(high.is_power_of_two());
670                assert!(low <= high);
671                assert!(low >= full_range.start);
672                assert!(high <= full_range.end);
673                Some(low..high)
674            }
675        }
676    }
677}
678
679impl From<Range> for Interval {
680    fn from(range: Range) -> Self {
681        Interval::Range(range)
682    }
683}
684
685/// Generates a set with all the powers of two included in the range.
686fn range_to_set(range: Option<Range>) -> NumSet {
687    let mut set = NumSet::new();
688
689    let (low, high) = match range {
690        Some(range) => (range.start, range.end),
691        None => return set,
692    };
693
694    assert!(low.is_power_of_two());
695    assert!(high.is_power_of_two());
696    assert!(low <= high);
697
698    for i in low.trailing_zeros()..=high.trailing_zeros() {
699        assert!(1 << i <= RangeBound::max_value());
700        set.insert(1 << i);
701    }
702    set
703}
704
#[test]
fn test_typevar_builder() {
    // Scalar integers only.
    let type_set = TypeSetBuilder::new().ints(Interval::All).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert!(type_set.floats.is_empty());
    assert_eq!(type_set.ints, num_set![8, 16, 32, 64, 128]);

    // Scalar floats only.
    let type_set = TypeSetBuilder::new().floats(Interval::All).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert!(type_set.ints.is_empty());

    // Reference types only.
    let type_set = TypeSetBuilder::new().refs(Interval::All).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert_eq!(type_set.refs, num_set![32, 64]);
    assert!(type_set.ints.is_empty());
    assert!(type_set.floats.is_empty());

    // Float vectors, excluding scalars.
    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(type_set.lanes, num_set![2, 4, 8, 16, 32, 64, 128, 256]);
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert!(type_set.dynamic_lanes.is_empty());
    assert!(type_set.ints.is_empty());

    // Float vectors, including scalars.
    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .simd_lanes(Interval::All)
        .includes_scalars(true)
        .build();
    assert_eq!(type_set.lanes, num_set![1, 2, 4, 8, 16, 32, 64, 128, 256]);
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert!(type_set.ints.is_empty());

    // Dynamic vectors. `includes_scalars(false)` only affects an explicit `simd_lanes` interval.
    // Without one, the static lanes still default to {1}.
    let type_set = TypeSetBuilder::new()
        .ints(Interval::All)
        .floats(Interval::All)
        .dynamic_simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(
        type_set.dynamic_lanes,
        num_set![2, 4, 8, 16, 32, 64, 128, 256]
    );
    assert_eq!(type_set.ints, num_set![8, 16, 32, 64, 128]);
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert_eq!(type_set.lanes, num_set![1]);

    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .dynamic_simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(
        type_set.dynamic_lanes,
        num_set![2, 4, 8, 16, 32, 64, 128, 256]
    );
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert_eq!(type_set.lanes, num_set![1]);
    assert!(type_set.ints.is_empty());

    // Explicit bounds are inclusive at both ends.
    let type_set = TypeSetBuilder::new().ints(16..64).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert_eq!(type_set.ints, num_set![16, 32, 64]);
    assert!(type_set.floats.is_empty());
}
777
#[test]
fn test_dynamic_to_vector() {
    // Dynamic lane counts start at 2 (single-lane dynamic types are never generated). The full
    // dynamic range is therefore 2..=256. `dynamic_to_vector` drops lane counts that are not
    // below MAX_LANES (256), so the resulting static vectors have 2..=128 lanes.
    assert_eq!(
        TypeSetBuilder::new()
            .dynamic_simd_lanes(Interval::All)
            .ints(Interval::All)
            .build()
            .dynamic_to_vector(),
        TypeSetBuilder::new()
            .simd_lanes(2..128)
            .ints(Interval::All)
            .build()
    );
    assert_eq!(
        TypeSetBuilder::new()
            .dynamic_simd_lanes(Interval::All)
            .floats(Interval::All)
            .build()
            .dynamic_to_vector(),
        TypeSetBuilder::new()
            .simd_lanes(2..128)
            .floats(Interval::All)
            .build()
    );
}
805
#[test]
#[should_panic]
fn test_typevar_builder_too_high_bound_panic() {
    // Integer bounds may not exceed MAX_BITS; twice that must be rejected when building.
    let too_wide = 2 * MAX_BITS;
    TypeSetBuilder::new().ints(16..too_wide).build();
}
811
#[test]
#[should_panic]
fn test_typevar_builder_inverted_bounds_panic() {
    // The lower bound must not exceed the upper bound.
    let (low, high) = (32, 16);
    TypeSetBuilder::new().ints(low..high).build();
}
817
#[test]
fn test_as_truthy() {
    let vectors = TypeSetBuilder::new()
        .simd_lanes(2..8)
        .ints(8..8)
        .floats(32..32)
        .build();

    // `lane_of` collapses the lane counts to a single lane and keeps the lane types.
    let expected_lane = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    assert_eq!(vectors.lane_of(), expected_lane);

    // For vectors, each float width becomes an int width of the same size.
    let mut expected_truthy = TypeSetBuilder::new().simd_lanes(2..8).build();
    expected_truthy.ints = num_set![8, 32];
    assert_eq!(vectors.as_truthy(), expected_truthy);

    // For scalars, the result is always i8.
    let scalars = TypeSetBuilder::new().ints(8..32).floats(32..64).build();
    assert_eq!(scalars.as_truthy(), TypeSetBuilder::new().ints(8..8).build());
}
838
#[test]
fn test_forward_images() {
    let empty_set = TypeSetBuilder::new().build();
    let lanes = |range: Range| TypeSetBuilder::new().simd_lanes(range).build();
    let ints = |range: Range| TypeSetBuilder::new().ints(range).build();
    let floats = |range: Range| TypeSetBuilder::new().floats(range).build();

    // Half vector.
    assert_eq!(lanes(1..32).half_vector(), lanes(1..16));

    // Double vector; 256 lanes cannot be doubled and are dropped.
    assert_eq!(lanes(1..32).double_vector(), lanes(2..64));
    assert_eq!(lanes(128..256).double_vector(), lanes(256..256));

    // Half width; 16-bit floats cannot be halved and are dropped.
    assert_eq!(ints(8..32).half_width(), ints(8..16));
    assert_eq!(floats(16..16).half_width(), empty_set);
    assert_eq!(floats(32..128).half_width(), floats(16..64));

    // Double width.
    assert_eq!(ints(8..32).double_width(), ints(16..64));
    assert_eq!(ints(32..64).double_width(), ints(64..128));
    assert_eq!(floats(32..32).double_width(), floats(64..64));
    assert_eq!(floats(16..64).double_width(), floats(32..128));
}
900
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_types() {
    // One int type plus one float type: two types, so there is no singleton.
    let two_types = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    two_types.get_singleton();
}
910
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_lanes() {
    // f32 with one and with two lanes: two types, so there is no singleton.
    let two_lane_counts = TypeSetBuilder::new()
        .simd_lanes(1..2)
        .floats(32..32)
        .build();
    two_lane_counts.get_singleton();
}
920
#[test]
fn test_typeset_singleton() {
    use crate::shared::types as shared_types;

    let i16_set = TypeSetBuilder::new().ints(16..16).build();
    assert_eq!(
        i16_set.get_singleton(),
        ValueType::Lane(shared_types::Int::I16.into())
    );

    let f64_set = TypeSetBuilder::new().floats(64..64).build();
    assert_eq!(
        f64_set.get_singleton(),
        ValueType::Lane(shared_types::Float::F64.into())
    );

    let i32x4_set = TypeSetBuilder::new()
        .simd_lanes(4..4)
        .ints(32..32)
        .build();
    assert_eq!(
        i32x4_set.get_singleton(),
        LaneType::from(shared_types::Int::I32).by(4)
    );
}
941
#[test]
fn test_typevar_functions() {
    // Derived type variables are named after the chain of functions applied to their base.
    let x = TypeVar::new(
        "x",
        "i16 and up",
        TypeSetBuilder::new().ints(16..64).build(),
    );
    let halved = x.half_width();
    assert_eq!(halved.name, "half_width(x)");
    assert_eq!(
        halved.double_width().name,
        "double_width(half_width(x))"
    );

    let x = TypeVar::new("x", "up to i32", TypeSetBuilder::new().ints(8..32).build());
    let doubled = x.double_width();
    assert_eq!(doubled.name, "double_width(x)");
}
958
#[test]
fn test_typevar_singleton() {
    use crate::cdsl::types::VectorType;
    use crate::shared::types as shared_types;

    // A scalar i32.
    let i32_type = ValueType::Lane(LaneType::Int(shared_types::Int::I32));
    let typevar = TypeVar::new_singleton(i32_type);
    assert_eq!(typevar.name, "i32");
    assert_eq!(typevar.type_set.ints, num_set![32]);
    assert!(typevar.type_set.floats.is_empty());
    assert_eq!(typevar.type_set.lanes, num_set![1]);

    // A four-lane f32 vector.
    let f32x4_type = ValueType::Vector(VectorType::new(
        LaneType::Float(shared_types::Float::F32),
        4,
    ));
    let typevar = TypeVar::new_singleton(f32x4_type);
    assert_eq!(typevar.name, "f32x4");
    assert!(typevar.type_set.ints.is_empty());
    assert_eq!(typevar.type_set.floats, num_set![32]);
    assert_eq!(typevar.type_set.lanes, num_set![4]);
}