Skip to main content

ostd_pod/
array_factory.rs

1// SPDX-License-Identifier: MPL-2.0
2
3use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
4
5/// A transparent wrapper around `[u8; N]` with guaranteed 1-byte alignment.
6///
7/// This type implements the zerocopy traits (`FromBytes`, `IntoBytes`, `Immutable`, `KnownLayout`)
8/// making it safe to transmute to/from byte arrays. It is primarily used internally by the
9/// `ArrayFactory` type system to provide aligned arrays for POD unions.
10#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone, Copy)]
11#[repr(transparent)]
12pub struct U8Array<const N: usize>([u8; N]);
13
14const _: () = assert!(align_of::<U8Array<0>>() == 1);
15
16/// A transparent wrapper around `[u16; N]` with guaranteed 2-byte alignment.
17#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone, Copy)]
18#[repr(transparent)]
19pub struct U16Array<const N: usize>([u16; N]);
20
21const _: () = assert!(align_of::<U16Array<0>>() == 2);
22
23/// A transparent wrapper around `[u32; N]` with guaranteed 4-byte alignment.
24#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone, Copy)]
25#[repr(transparent)]
26pub struct U32Array<const N: usize>([u32; N]);
27
28const _: () = assert!(align_of::<U32Array<0>>() == 4);
29
30/// A transparent wrapper around `[u64; N]` with guaranteed 8-byte alignment.
31#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone, Copy)]
32#[repr(transparent)]
33pub struct U64Array<const N: usize>([u64; N]);
34
35const _: () = assert!(align_of::<U64Array<0>>() == 8);
36
/// A type-level factory for creating aligned arrays based on alignment requirements.
///
/// This type uses const generics to select the appropriate underlying array type
/// ([`U8Array`], [`U16Array`], [`U32Array`], or [`U64Array`]) based on the alignment
/// requirement `A` and the number of elements `N`. It is an enum with no variants, so it can
/// never be instantiated; it exists purely at the type level and is consumed through its
/// [`ArrayManufacture`] implementation.
///
/// # Type Parameters
///
/// * `A` - The required alignment in bytes (1, 2, 4, or 8).
/// * `N` - The number of elements in the array.
///
/// The selected array has alignment `A` and a size of `A * N` bytes. Other values of `A` are
/// not supported: `ArrayFactory<A, N>` then has no [`ArrayManufacture`] implementation, and
/// naming its `Array` type is a compile-time error.
///
/// # Examples
///
/// ```rust
/// use ostd_pod::{ArrayFactory, ArrayManufacture};
///
/// // Creates a `U32Array<8>` (8 `u32` elements with 4-byte alignment)
/// type MyArray = <ArrayFactory<4, 8> as ArrayManufacture>::Array;
///
/// assert_eq!(core::mem::align_of::<MyArray>(), 4);
/// assert_eq!(core::mem::size_of::<MyArray>(), 32);
/// ```
pub enum ArrayFactory<const A: usize, const N: usize> {}
57
58/// Trait that associates an `ArrayFactory` with its corresponding aligned array type.
59///
60/// This trait is implemented for `ArrayFactory<A, N>` where `A` is 1, 2, 4, or 8,
61/// mapping to `U8Array`, `U16Array`, `U32Array`, and `U64Array` respectively.
62pub trait ArrayManufacture {
63    /// The aligned array type produced by this factory.
64    type Array: FromBytes + IntoBytes + Immutable;
65}
66
67impl<const N: usize> ArrayManufacture for ArrayFactory<1, N> {
68    type Array = U8Array<N>;
69}
70
71impl<const N: usize> ArrayManufacture for ArrayFactory<2, N> {
72    type Array = U16Array<N>;
73}
74
75impl<const N: usize> ArrayManufacture for ArrayFactory<4, N> {
76    type Array = U32Array<N>;
77}
78
79impl<const N: usize> ArrayManufacture for ArrayFactory<8, N> {
80    type Array = U64Array<N>;
81}
82
#[cfg(test)]
mod tests {
    use core::any::TypeId;

    use super::*;

    #[test]
    fn u8array_alignment() {
        assert_eq!(align_of::<U8Array<0>>(), 1);
        assert_eq!(align_of::<U8Array<1>>(), 1);
        assert_eq!(align_of::<U8Array<10>>(), 1);
    }

    #[test]
    fn u8array_size() {
        assert_eq!(size_of::<U8Array<0>>(), 0);
        assert_eq!(size_of::<U8Array<1>>(), 1);
        assert_eq!(size_of::<U8Array<4>>(), 4);
        assert_eq!(size_of::<U8Array<10>>(), 10);
    }

    #[test]
    fn u16array_alignment() {
        assert_eq!(align_of::<U16Array<0>>(), 2);
        assert_eq!(align_of::<U16Array<1>>(), 2);
        assert_eq!(align_of::<U16Array<10>>(), 2);
    }

    #[test]
    fn u16array_size() {
        assert_eq!(size_of::<U16Array<0>>(), 0);
        assert_eq!(size_of::<U16Array<1>>(), 2);
        assert_eq!(size_of::<U16Array<4>>(), 8);
        assert_eq!(size_of::<U16Array<10>>(), 20);
    }

    #[test]
    fn u32array_alignment() {
        assert_eq!(align_of::<U32Array<0>>(), 4);
        assert_eq!(align_of::<U32Array<1>>(), 4);
        assert_eq!(align_of::<U32Array<10>>(), 4);
    }

    #[test]
    fn u32array_size() {
        assert_eq!(size_of::<U32Array<0>>(), 0);
        assert_eq!(size_of::<U32Array<1>>(), 4);
        assert_eq!(size_of::<U32Array<4>>(), 16);
        assert_eq!(size_of::<U32Array<10>>(), 40);
    }

    #[test]
    fn u64array_alignment() {
        assert_eq!(align_of::<U64Array<0>>(), 8);
        assert_eq!(align_of::<U64Array<1>>(), 8);
        assert_eq!(align_of::<U64Array<10>>(), 8);
    }

    #[test]
    fn u64array_size() {
        assert_eq!(size_of::<U64Array<0>>(), 0);
        assert_eq!(size_of::<U64Array<1>>(), 8);
        assert_eq!(size_of::<U64Array<4>>(), 32);
        assert_eq!(size_of::<U64Array<10>>(), 80);
    }

    #[test]
    fn array_factory_1byte_alignment() {
        type Array = <ArrayFactory<1, 5> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 1);
        assert_eq!(size_of::<Array>(), 5);
    }

    #[test]
    fn array_factory_2byte_alignment() {
        type Array = <ArrayFactory<2, 5> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 2);
        assert_eq!(size_of::<Array>(), 10);
    }

    #[test]
    fn array_factory_4byte_alignment() {
        type Array = <ArrayFactory<4, 5> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 4);
        assert_eq!(size_of::<Array>(), 20);
    }

    #[test]
    fn array_factory_8byte_alignment() {
        type Array = <ArrayFactory<8, 5> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 8);
        assert_eq!(size_of::<Array>(), 40);
    }

    #[test]
    fn array_factory_zero_elements() {
        // An empty array still carries the requested alignment.
        type Array = <ArrayFactory<8, 0> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 8);
        assert_eq!(size_of::<Array>(), 0);
    }

    #[test]
    fn array_factory_selects_expected_type() {
        // Check the exact wrapper type, not merely one with a matching layout.
        assert_eq!(
            TypeId::of::<<ArrayFactory<1, 3> as ArrayManufacture>::Array>(),
            TypeId::of::<U8Array<3>>()
        );
        assert_eq!(
            TypeId::of::<<ArrayFactory<2, 3> as ArrayManufacture>::Array>(),
            TypeId::of::<U16Array<3>>()
        );
        assert_eq!(
            TypeId::of::<<ArrayFactory<4, 3> as ArrayManufacture>::Array>(),
            TypeId::of::<U32Array<3>>()
        );
        assert_eq!(
            TypeId::of::<<ArrayFactory<8, 3> as ArrayManufacture>::Array>(),
            TypeId::of::<U64Array<3>>()
        );
    }

    #[test]
    fn zerocopy_traits() {
        // Test that every wrapper implements the required zerocopy traits and is `Copy`.
        fn assert_traits<T: FromBytes + IntoBytes + Immutable + KnownLayout + Copy>() {}

        assert_traits::<U8Array<4>>();
        assert_traits::<U16Array<4>>();
        assert_traits::<U32Array<4>>();
        assert_traits::<U64Array<4>>();
    }
}