midenc_hir_type/
layout.rs

1use alloc::{alloc::Layout, collections::VecDeque};
2use core::cmp::{self, Ordering};
3
4use smallvec::{smallvec, SmallVec};
5
6use super::*;
7
/// Number of bytes a single field element is assumed to occupy in memory: the size of a `u32`.
const FELT_SIZE: usize = core::mem::size_of::<u32>();
/// Number of bytes in a word: the size of four `u32`-sized chunks (`[u32; 4]`).
const WORD_SIZE: usize = core::mem::size_of::<[u32; 4]>();
10
11impl Type {
12    /// Convert this type into a vector of types corresponding to how this type
13    /// will be represented in memory.
14    ///
15    /// The largest "part" size is 32 bits, so types that fit in 32 bits remain
16    /// unchanged. For types larger than 32 bits, they will be broken up into parts
17    /// that do fit in 32 bits, preserving accurate types to the extent possible.
18    /// For types smaller than 32 bits, they will be merged into packed structs no
19    /// larger than 32 bits, to preserve the type information, and make it possible
20    /// to reason about how to extract parts of the original type.
21    ///
22    /// For an example, a struct of type `{ *ptr, u8, u8 }` will be encoded on the
23    /// operand stack as `[*ptr, {u8, u8}]`, where the first value is the 32-bit pointer
24    /// field, and the remaining fields are encoded as a 16-bit struct in the second value.
25    pub fn to_raw_parts(self) -> Option<SmallVec<[Type; 4]>> {
26        match self {
27            Type::Unknown => None,
28            ty => {
29                let mut parts = SmallVec::<[Type; 4]>::default();
30                let (part, mut rest) = ty.split(4);
31                parts.push(part);
32                while let Some(ty) = rest.take() {
33                    let (part, remaining) = ty.split(4);
34                    parts.push(part);
35                    rest = remaining;
36                }
37                Some(parts)
38            }
39        }
40    }
41
42    /// Split this type into two parts:
43    ///
44    /// * The first part is no more than `n` bytes in size, and may contain the type itself if it
45    ///   fits
46    /// * The second part is None if the first part is smaller than or equal in size to the
47    ///   requested split size
48    /// * The second part is Some if there is data left in the original type after the split. This
49    ///   part will be a type that attempts to preserve, to the extent possible, the original type
50    ///   structure, but will fall back to an array of bytes if a larger type must be split down
51    ///   the middle somewhere.
52    pub fn split(self, n: usize) -> (Type, Option<Type>) {
53        if n == 0 {
54            return (self, None);
55        }
56
57        let size_in_bytes = self.size_in_bytes();
58        if n >= size_in_bytes {
59            return (self, None);
60        }
61
62        // The type is larger than the split size
63        match self {
64            ty @ (Self::U256
65            | Self::I128
66            | Self::U128
67            | Self::I64
68            | Self::U64
69            | Self::F64
70            | Self::Felt
71            | Self::I32
72            | Self::U32
73            | Self::Ptr(_)
74            | Self::I16
75            | Self::U16
76            | Self::Function(_)) => {
77                let len = ty.size_in_bytes();
78                let remaining = len - n;
79                match (n, remaining) {
80                    (0, _) | (_, 0) => unreachable!(),
81                    (1, 1) => (Type::U8, Some(Type::U8)),
82                    (1, remaining) => {
83                        (Type::U8, Some(Type::from(ArrayType::new(Type::U8, remaining))))
84                    }
85                    (taken, 1) => (Type::from(ArrayType::new(Type::U8, taken)), Some(Type::U8)),
86                    (taken, remaining) => (
87                        Type::from(ArrayType::new(Type::U8, taken)),
88                        Some(Type::from(ArrayType::new(Type::U8, remaining))),
89                    ),
90                }
91            }
92            Self::Array(array_type) => match &*array_type {
93                ArrayType {
94                    ty: elem_ty,
95                    len: 1,
96                } => elem_ty.clone().split(n),
97                ArrayType {
98                    ty: elem_ty,
99                    len: array_len,
100                } => {
101                    let elem_size = elem_ty.size_in_bytes();
102                    if n >= elem_size {
103                        // The requested split consumes 1 or more elements..
104                        let take = n / elem_size;
105                        let extra = n % elem_size;
106                        if extra == 0 {
107                            // The split is on an element boundary
108                            let split = match take {
109                                1 => (*elem_ty).clone(),
110                                _ => Self::from(ArrayType::new(elem_ty.clone(), take)),
111                            };
112                            let rest = match array_len - take {
113                                0 => unreachable!(),
114                                1 => elem_ty.clone(),
115                                len => Self::from(ArrayType::new(elem_ty.clone(), len)),
116                            };
117                            (split, Some(rest))
118                        } else {
119                            // The element type must be split somewhere in order to get the input type
120                            // down to the requested size
121                            let (partial1, partial2) = (*elem_ty).clone().split(elem_size - extra);
122                            match array_len - take {
123                                0 => unreachable!(),
124                                1 => {
125                                    let taken = Self::from(ArrayType::new(elem_ty.clone(), take));
126                                    let split = Self::from(StructType::new_with_repr(
127                                        TypeRepr::packed(1),
128                                        [taken, partial1],
129                                    ));
130                                    (split, partial2)
131                                }
132                                remaining => {
133                                    let remaining_input =
134                                        Self::from(ArrayType::new(elem_ty.clone(), remaining));
135                                    let taken = Self::from(ArrayType::new(elem_ty.clone(), take));
136                                    let split = Self::from(StructType::new_with_repr(
137                                        TypeRepr::packed(1),
138                                        [taken, partial1],
139                                    ));
140                                    let rest = Self::from(StructType::new_with_repr(
141                                        TypeRepr::packed(1),
142                                        [partial2.unwrap(), remaining_input],
143                                    ));
144                                    (split, Some(rest))
145                                }
146                            }
147                        }
148                    } else {
149                        // The requested split consumes less than one element
150                        let (partial1, partial2) = (*elem_ty).clone().split(n);
151                        let remaining_input = match array_len - 1 {
152                            0 => unreachable!(),
153                            1 => (*elem_ty).clone(),
154                            len => Self::from(ArrayType::new(elem_ty.clone(), len - 1)),
155                        };
156                        let rest = Self::from(StructType::new_with_repr(
157                            TypeRepr::packed(1),
158                            [partial2.unwrap(), remaining_input],
159                        ));
160                        (partial1, Some(rest))
161                    }
162                }
163            },
164            Self::Struct(struct_ty) => match &*struct_ty {
165                StructType {
166                    repr: TypeRepr::Transparent,
167                    fields,
168                    ..
169                } => {
170                    let underlying = fields
171                        .into_iter()
172                        .find(|f| !f.ty.is_zst())
173                        .expect("invalid type: expected non-zero sized field");
174                    underlying.ty.clone().split(n)
175                }
176                struct_ty => {
177                    let original_repr = struct_ty.repr;
178                    let original_size = struct_ty.size;
179                    let mut fields = VecDeque::from_iter(struct_ty.fields.iter().cloned());
180                    let mut split = StructType {
181                        repr: original_repr,
182                        size: 0,
183                        fields: smallvec![],
184                    };
185                    let mut remaining = StructType {
186                        repr: TypeRepr::packed(1),
187                        size: 0,
188                        fields: smallvec![],
189                    };
190                    let mut needed: u32 = n.try_into().expect(
191                        "invalid type split: number of bytes is larger than what is representable \
192                         in memory",
193                    );
194                    let mut current_offset = 0u32;
195                    while let Some(mut field) = fields.pop_front() {
196                        let padding = field.offset - current_offset;
197                        // If the padding was exactly what was needed, add it to the `split`
198                        // struct, and then place the remaining fields in a new struct
199                        let original_offset = field.offset;
200                        if padding == needed {
201                            split.size += needed;
202                            // Handle the edge case where padding is at the front of the struct
203                            if split.fields.is_empty() {
204                                split.fields.push(StructField {
205                                    index: 0,
206                                    align: 1,
207                                    offset: 0,
208                                    ty: Type::from(ArrayType::new(Type::U8, needed as usize)),
209                                });
210                            }
211                            let mut prev_offset = original_offset;
212                            let mut field_offset = 0;
213                            field.index = 0;
214                            field.offset = field_offset;
215                            remaining.repr = TypeRepr::Default;
216                            remaining.size = original_size - split.size;
217                            remaining.fields.reserve(1 + fields.len());
218                            field_offset += field.ty.size_in_bytes() as u32;
219                            remaining.fields.push(field);
220                            for (index, mut field) in fields.into_iter().enumerate() {
221                                field.index = (index + 1) as u8;
222                                let align_offset = field.offset - prev_offset;
223                                let field_size = field.ty.size_in_bytes() as u32;
224                                prev_offset = field.offset + field_size;
225                                field.offset = field_offset + align_offset;
226                                field_offset += align_offset;
227                                field_offset += field_size;
228                                remaining.fields.push(field);
229                            }
230                            break;
231                        }
232
233                        // If the padding is more than was needed, we fill out the rest of the
234                        // request by padding the size of the `split` struct, and then adjust
235                        // the remaining struct to account for the leftover padding.
236                        if padding > needed {
237                            // The struct size must match the requested split size
238                            split.size += needed;
239                            // Handle the edge case where padding is at the front of the struct
240                            if split.fields.is_empty() {
241                                split.fields.push(StructField {
242                                    index: 0,
243                                    align: 1,
244                                    offset: 0,
245                                    ty: Type::from(ArrayType::new(Type::U8, needed as usize)),
246                                });
247                            }
248                            // What's left must account for what has been split off
249                            let leftover_padding = u16::try_from(padding - needed).expect(
250                                "invalid type: padding is larger than maximum allowed alignment",
251                            );
252                            let effective_alignment = leftover_padding.prev_power_of_two();
253                            let align_offset = leftover_padding % effective_alignment;
254                            let default_alignment = cmp::max(
255                                fields.iter().map(|f| f.align).max().unwrap_or(1),
256                                field.align,
257                            );
258                            let repr = match default_alignment.cmp(&effective_alignment) {
259                                Ordering::Equal => TypeRepr::Default,
260                                Ordering::Greater => TypeRepr::packed(effective_alignment),
261                                Ordering::Less => TypeRepr::align(effective_alignment),
262                            };
263                            let mut prev_offset = original_offset;
264                            let mut field_offset = align_offset as u32;
265                            field.index = 0;
266                            field.offset = field_offset;
267                            remaining.repr = repr;
268                            remaining.size = original_size - split.size;
269                            remaining.fields.reserve(1 + fields.len());
270                            field_offset += field.ty.size_in_bytes() as u32;
271                            remaining.fields.push(field);
272                            for (index, mut field) in fields.into_iter().enumerate() {
273                                field.index = (index + 1) as u8;
274                                let align_offset = field.offset - prev_offset;
275                                let field_size = field.ty.size_in_bytes() as u32;
276                                prev_offset = field.offset + field_size;
277                                field.offset = field_offset + align_offset;
278                                field_offset += align_offset;
279                                field_offset += field_size;
280                                remaining.fields.push(field);
281                            }
282                            break;
283                        }
284
285                        // The padding must be less than what was needed, so consume it, and
286                        // then process the current field for the rest of the request
287                        split.size += padding;
288                        needed -= padding;
289                        current_offset += padding;
290                        let field_size = field.ty.size_in_bytes() as u32;
291                        // If the field fully satisifies the remainder of the request, then
292                        // finalize the `split` struct, and place remaining fields in a trailing
293                        // struct with an appropriate repr
294                        if field_size == needed {
295                            split.size += field_size;
296                            field.offset = current_offset;
297                            split.fields.push(field);
298
299                            debug_assert!(
300                                !fields.is_empty(),
301                                "expected struct that is the exact size of the split request to \
302                                 have been handled elsewhere"
303                            );
304
305                            remaining.repr = original_repr;
306                            remaining.size = original_size - split.size;
307                            remaining.fields.reserve(fields.len());
308                            let mut prev_offset = current_offset + field_size;
309                            let mut field_offset = 0;
310                            for (index, mut field) in fields.into_iter().enumerate() {
311                                field.index = index as u8;
312                                let align_offset = field.offset - prev_offset;
313                                let field_size = field.ty.size_in_bytes() as u32;
314                                prev_offset = field.offset + field_size;
315                                field.offset = field_offset + align_offset;
316                                field_offset += align_offset;
317                                field_offset += field_size;
318                                remaining.fields.push(field);
319                            }
320                            break;
321                        }
322
323                        // If the field is larger than what is needed, we have to split it
324                        if field_size > needed {
325                            split.size += needed;
326
327                            // Add the portion needed to `split`
328                            let index = field.index;
329                            let offset = current_offset;
330                            let align = field.align;
331                            let (partial1, partial2) = field.ty.split(needed as usize);
332                            // The second half of the split will always be a type
333                            let partial2 = partial2.unwrap();
334                            split.fields.push(StructField {
335                                index,
336                                offset,
337                                align,
338                                ty: partial1,
339                            });
340
341                            // Build a struct with the remaining fields and trailing partial field
342                            let mut prev_offset = current_offset + needed;
343                            let mut field_offset = needed + partial2.size_in_bytes() as u32;
344                            remaining.size = original_size - split.size;
345                            remaining.fields.reserve(1 + fields.len());
346                            remaining.fields.push(StructField {
347                                index: 0,
348                                offset: 1,
349                                align: 1,
350                                ty: partial2,
351                            });
352                            for (index, mut field) in fields.into_iter().enumerate() {
353                                field.index = (index + 1) as u8;
354                                let align_offset = field.offset - prev_offset;
355                                let field_size = field.ty.size_in_bytes() as u32;
356                                prev_offset = field.offset + needed + field_size;
357                                field.offset = field_offset + align_offset;
358                                field_offset += align_offset;
359                                field_offset += field_size;
360                                remaining.fields.push(field);
361                            }
362                            break;
363                        }
364
365                        // We need to process more fields for this request (i.e. field_size < needed)
366                        needed -= field_size;
367                        split.size += field_size;
368                        field.offset = current_offset;
369                        current_offset += field_size;
370                        split.fields.push(field);
371                    }
372
373                    let split = if split.fields.len() > 1 {
374                        Type::from(split)
375                    } else {
376                        split.fields.pop().map(|f| f.ty).unwrap()
377                    };
378                    match remaining.fields.len() {
379                        0 => (split, None),
380                        1 => (split, remaining.fields.pop().map(|f| f.ty)),
381                        _ => (split, Some(remaining.into())),
382                    }
383                }
384            },
385            Type::List(_) => {
386                todo!("invalid type: list has no defined representation yet, so cannot be split")
387            }
388            // These types either have no size, or are 1 byte in size, so must have
389            // been handled above when checking if the size of the type is <= the
390            // requested split size
391            Self::Unknown | Self::Never | Self::I1 | Self::U8 | Self::I8 => {
392                unreachable!()
393            }
394        }
395    }
396
397    /// Returns the minimum alignment, in bytes, of this type
398    pub fn min_alignment(&self) -> usize {
399        match self {
400            // These types don't have a meaningful alignment, so choose byte-aligned
401            Self::Unknown | Self::Never => 1,
402            // Felts must be naturally aligned to a 32-bit boundary (4 bytes)
403            Self::Felt => 4,
404            // 256-bit and 128-bit integers must be word-aligned
405            Self::U256 | Self::I128 | Self::U128 => 16,
406            // 64-bit integers and floats must be element-aligned
407            Self::I64 | Self::U64 | Self::F64 => 4,
408            // 32-bit integers and pointers must be element-aligned
409            Self::I32 | Self::U32 | Self::Ptr(_) | Self::Function(..) => 4,
410            // 16-bit integers can be naturally aligned
411            Self::I16 | Self::U16 => 2,
412            // 8-bit integers and booleans can be naturally aligned
413            Self::I8 | Self::U8 | Self::I1 => 1,
414            // Structs use the minimum alignment of their first field, or 1 if a zero-sized type
415            Self::Struct(ref struct_ty) => struct_ty.min_alignment(),
416            // Arrays use the minimum alignment of their element type
417            Self::Array(ref array_ty) => array_ty.min_alignment(),
418            // Lists use the minimum alignment of their element type
419            Self::List(ref element_ty) => element_ty.min_alignment(),
420        }
421    }
422
423    /// Returns the size in bits of this type, without alignment padding.
424    pub fn size_in_bits(&self) -> usize {
425        match self {
426            // These types have no representation in memory
427            Self::Unknown | Self::Never => 0,
428            // Booleans are represented as i1
429            Self::I1 => 1,
430            // Integers are naturally sized
431            Self::I8 | Self::U8 => 8,
432            Self::I16 | Self::U16 => 16,
433            // Field elements have a range that is almost 64 bits, but because
434            // our byte-addressable memory model only sees each element as a 32-bit
435            // chunk, we treat field elements in this model as 32-bit values. This
436            // has no effect on their available range, just how much memory they are
437            // assumed to require for storage.
438            Self::I32 | Self::U32 | Self::Felt => 32,
439            Self::I64 | Self::U64 | Self::F64 => 64,
440            Self::I128 | Self::U128 => 128,
441            Self::U256 => 256,
442            // Raw pointers  are 32-bits, the same size as the native integer width, u32
443            Self::Ptr(_) | Self::Function(_) => 32,
444            // Packed structs have no alignment padding between fields
445            Self::Struct(ref struct_ty) => struct_ty.size as usize * 8,
446            Self::Array(ref array_ty) => array_ty.size_in_bits(),
447            Type::List(_) => todo!(
448                "invalid type: list has no defined representation yet, so its size cannot be \
449                 determined"
450            ),
451        }
452    }
453
454    /// Returns the minimum number of bytes required to store a value of this type
455    pub fn size_in_bytes(&self) -> usize {
456        let bits = self.size_in_bits();
457        (bits / 8) + (!bits.is_multiple_of(8)) as usize
458    }
459
460    /// Same as `size_in_bytes`, but with sufficient padding to guarantee alignment of the value.
461    pub fn aligned_size_in_bytes(&self) -> usize {
462        let align = self.min_alignment();
463        let size = self.size_in_bytes();
464        // Zero-sized types have no alignment
465        if size == 0 {
466            return 0;
467        }
468
469        // Assuming that a pointer is allocated with the worst possible alignment,
470        // i.e. it is not aligned on a power-of-two boundary, we can ensure that there
471        // is enough space to align the pointer to the required minimum alignment and
472        // still fit it in the allocated block of memory without overflowing its bounds,
473        // by adding `align` to size.
474        //
475        // We panic if padding the size overflows `usize`.
476        //
477        // So let's say we have a type with a min alignment of 16, and size of 24. If
478        // we add 16 to 24, we get 40. We then allocate a block of memory of 40 bytes,
479        // the pointer of which happens to be at address 0x01. If we align that pointer
480        // to 0x10 (the next closest aligned address within the block we allocated),
481        // that consumes 15 bytes of the 40 we have, leaving us with 25 bytes to hold
482        // our 24 byte value.
483        size.checked_add(align)
484            .expect("type cannot meet its minimum alignment requirement due to its size")
485    }
486
487    /// Returns the size in field elements of this type
488    pub fn size_in_felts(&self) -> usize {
489        let bytes = self.size_in_bytes();
490        let trailing = bytes % FELT_SIZE;
491        (bytes / FELT_SIZE) + ((trailing > 0) as usize)
492    }
493
494    /// Returns the size in words of this type
495    pub fn size_in_words(&self) -> usize {
496        let bytes = self.size_in_bytes();
497        let trailing = bytes % WORD_SIZE;
498        (bytes / WORD_SIZE) + ((trailing > 0) as usize)
499    }
500
501    /// Returns the layout of this type in memory
502    pub fn layout(&self) -> Layout {
503        Layout::from_size_align(self.size_in_bytes(), self.min_alignment())
504            .expect("invalid layout: the size, when padded for alignment, overflows isize")
505    }
506
507    /// Returns true if this type can be loaded on to the operand stack
508    ///
509    /// The rule for "loadability" is a bit arbitrary, but the purpose is to
510    /// force users of the IR to either pass large values by reference, or calculate
511    /// the addresses of the individual fields needed from a large structure or array,
512    /// and issue loads/stores against those instead.
513    ///
514    /// In effect, we reject loads of values that are larger than a single word, as that
515    /// is the largest value which can be worked with on the operand stack of the Miden VM.
516    pub fn is_loadable(&self) -> bool {
517        self.size_in_words() <= WORD_SIZE
518    }
519}
520
521#[cfg(test)]
522mod tests {
523    use smallvec::smallvec;
524
525    use crate::*;
526
527    #[test]
528    fn struct_type_test() {
529        let ptr_ty = Type::from(PointerType::new(Type::U32));
530        // A struct with default alignment and padding between fields
531        let struct_ty = StructType::new([ptr_ty.clone(), Type::U8, Type::I32]);
532        assert_eq!(struct_ty.min_alignment(), ptr_ty.min_alignment());
533        assert_eq!(struct_ty.size(), 12);
534        assert_eq!(
535            struct_ty.get(0),
536            &StructField {
537                index: 0,
538                align: 4,
539                offset: 0,
540                ty: ptr_ty.clone()
541            }
542        );
543        assert_eq!(
544            struct_ty.get(1),
545            &StructField {
546                index: 1,
547                align: 1,
548                offset: 4,
549                ty: Type::U8
550            }
551        );
552        assert_eq!(
553            struct_ty.get(2),
554            &StructField {
555                index: 2,
556                align: 4,
557                offset: 8,
558                ty: Type::I32
559            }
560        );
561
562        // A struct with no alignment requirement, and no alignment padding between fields
563        let struct_ty =
564            StructType::new_with_repr(TypeRepr::packed(1), [ptr_ty.clone(), Type::U8, Type::I32]);
565        assert_eq!(struct_ty.min_alignment(), 1);
566        assert_eq!(struct_ty.size(), 9);
567        assert_eq!(
568            struct_ty.get(0),
569            &StructField {
570                index: 0,
571                align: 1,
572                offset: 0,
573                ty: ptr_ty.clone()
574            }
575        );
576        assert_eq!(
577            struct_ty.get(1),
578            &StructField {
579                index: 1,
580                align: 1,
581                offset: 4,
582                ty: Type::U8
583            }
584        );
585        assert_eq!(
586            struct_ty.get(2),
587            &StructField {
588                index: 2,
589                align: 1,
590                offset: 5,
591                ty: Type::I32
592            }
593        );
594
595        // A struct with larger-than-default alignment, but default alignment for the fields
596        let struct_ty =
597            StructType::new_with_repr(TypeRepr::align(8), [ptr_ty.clone(), Type::U8, Type::I32]);
598        assert_eq!(struct_ty.min_alignment(), 8);
599        assert_eq!(struct_ty.size(), 16);
600        assert_eq!(
601            struct_ty.get(0),
602            &StructField {
603                index: 0,
604                align: 4,
605                offset: 0,
606                ty: ptr_ty.clone()
607            }
608        );
609        assert_eq!(
610            struct_ty.get(1),
611            &StructField {
612                index: 1,
613                align: 1,
614                offset: 4,
615                ty: Type::U8
616            }
617        );
618        assert_eq!(
619            struct_ty.get(2),
620            &StructField {
621                index: 2,
622                align: 4,
623                offset: 8,
624                ty: Type::I32
625            }
626        );
627    }
628
629    #[test]
630    fn type_to_raw_parts_test() {
631        let ty = Type::from(ArrayType::new(Type::U8, 5));
632        assert_eq!(
633            ty.to_raw_parts(),
634            Some(smallvec![Type::from(ArrayType::new(Type::U8, 4)), Type::U8,])
635        );
636
637        let ty = Type::from(ArrayType::new(Type::I16, 3));
638        assert_eq!(
639            ty.to_raw_parts(),
640            Some(smallvec![Type::from(ArrayType::new(Type::I16, 2)), Type::I16,])
641        );
642
643        let ptr_ty = Type::from(PointerType::new(Type::U32));
644
645        // Default struct
646        let ty = Type::from(StructType::new([ptr_ty.clone(), Type::U8, Type::I32]));
647        assert_eq!(ty.to_raw_parts(), Some(smallvec![ptr_ty.clone(), Type::U8, Type::I32,]));
648
649        // Packed struct
650        let ty = Type::from(StructType::new_with_repr(
651            TypeRepr::packed(1),
652            [ptr_ty.clone(), Type::U8, Type::I32],
653        ));
654        let partial_ty = Type::from(StructType::new_with_repr(
655            TypeRepr::packed(1),
656            [Type::U8, Type::from(ArrayType::new(Type::U8, 3))],
657        ));
658        assert_eq!(ty.to_raw_parts(), Some(smallvec![ptr_ty.clone(), partial_ty, Type::U8]));
659    }
660}