use alloc::alloc::{alloc, alloc_zeroed, dealloc, handle_alloc_error, Layout};
use alloc::boxed::Box;
use core::fmt;
use core::mem;
use core::ops::{Deref, DerefMut};
use core::ptr::{self, NonNull};
use core::slice;

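// Validates a (size, align) pair, panicking with a readable message when it
// does not form a valid `Layout` (align of zero, align not a power of two,
// or a size that overflows `isize::MAX` when rounded up to `align`).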
fn layout(size: usize, align: usize) -> Layout {
    match Layout::from_size_align(size, align) {
        Ok(layout) => layout,
        Err(_) => panic!("Invalid layout: size = {}, align = {}", size, align),
    }
}

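// Shared allocation path for `alloc` and `alloc_zeroed`: aborts via
// `handle_alloc_error` on allocation failure and double-checks the returned
// pointer's alignment in debug builds.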
unsafe fn aligned_alloc(alloc: unsafe fn(Layout) -> *mut u8, layout: Layout) -> *mut u8 {
    let ptr = alloc(layout);
    if ptr.is_null() {
        handle_alloc_error(layout);
    }
    debug_assert!(
        (ptr as usize) % layout.align() == 0,
        "pointer = {:p} is not a multiple of alignment = {}",
        ptr,
        layout.align()
    );
    ptr
}

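/// A heap-allocated byte buffer whose starting address is a multiple of a
/// caller-chosen alignment.
///
/// A minimal usage sketch (the crate path in this example is an assumption,
/// not taken from this file):
///
/// ```ignore
/// use aligned_utils::bytes::AlignedBytes;
///
/// let bytes = AlignedBytes::new_zeroed(1024, 4096);
/// assert_eq!(bytes.as_ptr() as usize % 4096, 0);
/// ```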
pub struct AlignedBytes {
    buf: NonNull<[u8]>,
    align: usize,
}

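// SAFETY: `AlignedBytes` uniquely owns its allocation and hands out access
// only through `&self`/`&mut self`, so it is sound to send it to, or share
// it between, threads.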
unsafe impl Send for AlignedBytes {}
unsafe impl Sync for AlignedBytes {}

#[cfg(feature = "std")]
mod std_impl {
    use super::AlignedBytes;

    use std::panic::{RefUnwindSafe, UnwindSafe};

    impl RefUnwindSafe for AlignedBytes {}
    impl UnwindSafe for AlignedBytes {}
}

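// Zero-length buffers are backed by a dangling (but well-aligned) pointer
// rather than a real allocation, so `drop` must skip `dealloc` for them.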
impl Drop for AlignedBytes {
    fn drop(&mut self) {
        unsafe {
            let size = self.buf.as_mut().len();
            if size != 0 {
                let ptr = self.buf.as_ptr() as *mut u8;
                let layout = Layout::from_size_align_unchecked(size, self.align);
                dealloc(ptr, layout);
            }
        }
    }
}

impl AlignedBytes {
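    /// Allocates a zero-initialized buffer of `len` bytes whose address is a
    /// multiple of `align`. A `len` of zero performs no allocation and uses a
    /// dangling, well-aligned pointer instead.
    ///
    /// # Panics
    ///
    /// Panics if `align` is zero, not a power of two, or if the resulting
    /// layout overflows. Aborts via `handle_alloc_error` if allocation fails.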
    #[must_use]
    pub fn new_zeroed(len: usize, align: usize) -> Self {
        let layout = layout(len, align);
        unsafe {
            let ptr = if len == 0 {
                align as *mut u8
            } else {
                aligned_alloc(alloc_zeroed, layout)
            };
            let buf = NonNull::from(slice::from_raw_parts_mut(ptr, len));
            Self { buf, align }
        }
    }

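    /// Allocates an aligned buffer and copies the contents of `bytes` into
    /// it.
    ///
    /// # Panics
    ///
    /// Panics under the same layout conditions as `new_zeroed`.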
    #[must_use]
    pub fn new_from_slice(bytes: &[u8], align: usize) -> Self {
        let len = bytes.len();

        let layout = layout(len, align);
        unsafe {
            let ptr = if len == 0 {
                align as *mut u8
            } else {
                let dst = aligned_alloc(alloc, layout);
                ptr::copy_nonoverlapping(bytes.as_ptr(), dst, len);
                dst
            };
            let buf = NonNull::from(slice::from_raw_parts_mut(ptr, len));
            Self { buf, align }
        }
    }

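    /// Returns the alignment (in bytes) that this buffer was created with.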
    #[must_use]
    pub const fn align(&self) -> usize {
        self.align
    }

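    /// Consumes the buffer and returns its raw parts without deallocating.
    ///
    /// The parts can be passed back to `from_raw` to reconstruct the buffer;
    /// otherwise the allocation is leaked.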
    #[must_use]
    pub fn into_raw(this: Self) -> (NonNull<[u8]>, usize) {
        let ret = (this.buf, this.align);
        mem::forget(this);
        ret
    }

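    /// Reconstructs an `AlignedBytes` from the parts returned by `into_raw`.
    ///
    /// # Safety
    ///
    /// `buf` and `align` must originate from a single call to `into_raw`,
    /// and the parts must not be reused afterwards (doing so would cause a
    /// double free).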
    #[must_use]
    pub const unsafe fn from_raw(buf: NonNull<[u8]>, align: usize) -> Self {
        Self { buf, align }
    }
}

impl Clone for AlignedBytes {
    fn clone(&self) -> Self {
        Self::new_from_slice(self, self.align)
    }
}

impl Deref for AlignedBytes {
    type Target = [u8];
    fn deref(&self) -> &Self::Target {
        unsafe { self.buf.as_ref() }
    }
}

impl DerefMut for AlignedBytes {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { self.buf.as_mut() }
    }
}

impl AsRef<[u8]> for AlignedBytes {
    fn as_ref(&self) -> &[u8] {
        self
    }
}

impl AsMut<[u8]> for AlignedBytes {
    fn as_mut(&mut self) -> &mut [u8] {
        self
    }
}

impl fmt::Debug for AlignedBytes {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <[u8] as fmt::Debug>::fmt(self, f)
    }
}

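// `Box<[u8]>` from the global allocator is a valid 1-aligned allocation, so
// its memory can be adopted directly with `align = 1`.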
impl From<Box<[u8]>> for AlignedBytes {
    fn from(b: Box<[u8]>) -> Self {
        unsafe {
            let buf = NonNull::new_unchecked(Box::into_raw(b));
            Self { buf, align: 1 }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::AlignedBytes;
    use alloc::boxed::Box;

    #[test]
    fn check_content() {
        {
            let bytes = AlignedBytes::new_zeroed(8, 8);
            assert_eq!(&*bytes, &[0, 0, 0, 0, 0, 0, 0, 0]);
        }
        {
            let bytes = &[1, 2, 3, 4, 5, 6, 7, 8];
            let aligned_bytes = AlignedBytes::new_from_slice(bytes, 8);
            assert_eq!(&*aligned_bytes, bytes);

            let aligned_bytes_cloned = aligned_bytes.clone();
            drop(aligned_bytes);
            assert_eq!(&*aligned_bytes_cloned, bytes);
        }
        {
            let bytes: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8];
            let boxed_bytes: Box<[u8]> = bytes.into();
            let aligned_bytes: AlignedBytes = boxed_bytes.into();
            assert_eq!(&*aligned_bytes, bytes);
        }
    }

    #[test]
    fn check_alignment() {
        let align = 4096;
        let bytes = AlignedBytes::new_zeroed(8, align);
        assert_eq!(bytes.align(), align);
        assert!(bytes.as_ptr() as usize % align == 0);
    }

    #[should_panic(expected = "Invalid layout: size = 1, align = 0")]
    #[test]
    fn check_layout_zero_align() {
        let bytes = AlignedBytes::new_zeroed(1, 0);
        drop(bytes);
    }

    #[should_panic(expected = "Invalid layout: size = 0, align = 0")]
    #[test]
    fn check_layout_zero_len_align() {
        let bytes = AlignedBytes::new_zeroed(0, 0);
        drop(bytes);
    }

    #[should_panic(expected = "Invalid layout: size = 1, align = 3")]
    #[test]
    fn check_layout_align_not_power_of_2() {
        let bytes = AlignedBytes::new_zeroed(1, 3);
        drop(bytes);
    }

    #[should_panic]
    #[test]
    fn check_layout_overflow() {
        let size = core::mem::size_of::<usize>() * 8;
        let bytes = AlignedBytes::new_zeroed((1_usize << (size - 1)) + 1, 1_usize << (size - 1));
        drop(bytes);
    }

    macro_rules! require {
        ($ty:ty: $($markers:tt)+) => {{
            fn __require<T: $($markers)*>() {}
            __require::<$ty>();
        }};
    }

    #[test]
    fn check_markers() {
        require!(AlignedBytes: Send + Sync);

        #[cfg(feature = "std")]
        {
            use std::panic::{RefUnwindSafe, UnwindSafe};
            require!(AlignedBytes: RefUnwindSafe + UnwindSafe);
        }
    }

    #[test]
    fn check_zst() {
        let bytes = AlignedBytes::new_zeroed(0, 2);
        drop(bytes);
    }

    #[test]
    fn check_into_raw() {
        let bytes = AlignedBytes::new_zeroed(0, 2);
        let (buf, align) = AlignedBytes::into_raw(bytes);
        drop(unsafe { AlignedBytes::from_raw(buf, align) });
    }
}