// polkavm_common/utils.rs

1#![allow(unsafe_code)]
2
3use core::mem::MaybeUninit;
4use core::ops::{Deref, Range};
5
6use crate::cast::cast;
7use crate::program::Reg;
8#[cfg(feature = "alloc")]
9use alloc::{borrow::Cow, sync::Arc, vec::Vec};
10
#[cfg(feature = "alloc")]
#[derive(Clone)]
/// Keeps the backing storage of an `ArcBytes` alive.
enum LifetimeObject {
    /// No owner: the bytes are `'static`, or the `ArcBytes` is empty.
    None,
    /// The bytes are owned by a reference-counted slice.
    Arc { obj: Arc<[u8]> },
    /// The bytes are owned by some other refcounted container
    /// (currently only a `Vec<u8>`; see `From<Vec<u8>> for ArcBytes`).
    Other { obj: Arc<dyn AsRef<[u8]>> },
}
18
/// A cheaply cloneable, immutable view into a byte buffer.
///
/// Stores a raw pointer and length into the underlying storage plus (when the
/// `alloc` feature is enabled) a refcounted handle keeping that storage alive.
#[derive(Clone)]
pub struct ArcBytes {
    // Points into the buffer owned by `lifetime`, or into a `'static` slice.
    pointer: core::ptr::NonNull<u8>,
    // Number of bytes reachable through `pointer`.
    length: usize,

    // Keeps the backing storage alive for as long as this handle exists.
    #[cfg(feature = "alloc")]
    lifetime: LifetimeObject,
}
27
28impl core::fmt::Debug for ArcBytes {
29    fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
30        fmt.debug_struct("ArcBytes").field("data", &self.deref()).finish()
31    }
32}
33
34impl Default for ArcBytes {
35    fn default() -> Self {
36        ArcBytes::empty()
37    }
38}
39
// SAFETY: It's always safe to send `ArcBytes` to another thread due to atomic refcounting.
// NOTE(review): this also relies on every payload stored behind `LifetimeObject`
// (currently `Arc<[u8]>` and `Vec<u8>`) being `Send` itself — re-confirm if new
// `LifetimeObject::Other` payloads are ever added.
unsafe impl Send for ArcBytes {}

// SAFETY: It's always safe to access `ArcBytes` from multiple threads due to atomic refcounting.
// The bytes are never mutated through `ArcBytes`, so shared access is read-only.
unsafe impl Sync for ArcBytes {}
45
46impl ArcBytes {
47    pub const fn empty() -> Self {
48        ArcBytes {
49            pointer: core::ptr::NonNull::dangling(),
50            length: 0,
51
52            #[cfg(feature = "alloc")]
53            lifetime: LifetimeObject::None,
54        }
55    }
56
57    pub const fn from_static(bytes: &'static [u8]) -> Self {
58        ArcBytes {
59            // SAFETY: `bytes` is always a valid slice, so its pointer is also always non-null and valid.
60            pointer: unsafe { core::ptr::NonNull::new_unchecked(bytes.as_ptr().cast_mut()) },
61            length: bytes.len(),
62
63            #[cfg(feature = "alloc")]
64            lifetime: LifetimeObject::None,
65        }
66    }
67
68    pub(crate) fn subslice(&self, subrange: Range<usize>) -> Self {
69        if subrange.start == subrange.end {
70            return Self::empty();
71        }
72
73        assert!(subrange.end >= subrange.start);
74        let length = subrange.end - subrange.start;
75        assert!(length <= self.length);
76
77        ArcBytes {
78            // TODO: Use `NonNull::add` once we migrate to Rust 1.80+.
79            // SAFETY: We've checked that the new subslice is valid with `assert`s.
80            pointer: unsafe { core::ptr::NonNull::new_unchecked(self.pointer.as_ptr().add(subrange.start)) },
81            length,
82
83            #[cfg(feature = "alloc")]
84            lifetime: self.lifetime.clone(),
85        }
86    }
87
88    #[cfg(feature = "alloc")]
89    pub(crate) fn parent_address_range(&self) -> Range<usize> {
90        let slice = match &self.lifetime {
91            LifetimeObject::None => return 0..0,
92            LifetimeObject::Arc { obj } => obj.as_ref(),
93            LifetimeObject::Other { obj } => obj.as_ref().as_ref(),
94        };
95        slice.as_ptr() as usize..(slice.as_ptr() as usize + slice.len())
96    }
97}
98
// Equality is by byte content (see `PartialEq` below), which is a total relation.
impl Eq for ArcBytes {}
100
101impl PartialEq for ArcBytes {
102    fn eq(&self, rhs: &ArcBytes) -> bool {
103        self.deref() == rhs.deref()
104    }
105}
106
impl Deref for ArcBytes {
    type Target = [u8];

    /// Borrows the underlying bytes as a slice.
    fn deref(&self) -> &Self::Target {
        // SAFETY: `pointer` is always non-null and `length` is always valid.
        unsafe { core::slice::from_raw_parts(self.pointer.as_ptr(), self.length) }
    }
}
115
116impl AsRef<[u8]> for ArcBytes {
117    fn as_ref(&self) -> &[u8] {
118        self.deref()
119    }
120}
121
#[cfg(feature = "alloc")]
impl<'a> From<&'a [u8]> for ArcBytes {
    /// Copies the bytes into a fresh reference-counted buffer.
    fn from(data: &'a [u8]) -> Self {
        Self::from(Arc::<[u8]>::from(data))
    }
}
129
#[cfg(not(feature = "alloc"))]
impl From<&'static [u8]> for ArcBytes {
    /// Without `alloc` we cannot copy, so only `'static` slices can be wrapped.
    fn from(data: &'static [u8]) -> Self {
        ArcBytes::from_static(data)
    }
}
136
#[cfg(feature = "alloc")]
impl From<Vec<u8>> for ArcBytes {
    /// Takes ownership of the vector without copying its contents.
    fn from(data: Vec<u8>) -> Self {
        ArcBytes {
            // Moving the `Vec` into the `Arc` below does not move its heap
            // buffer, so this pointer stays valid for as long as `lifetime` lives.
            pointer: core::ptr::NonNull::new(data.as_ptr().cast_mut()).unwrap(),
            length: data.len(),
            lifetime: LifetimeObject::Other { obj: Arc::new(data) },
        }
    }
}
147
#[cfg(feature = "alloc")]
impl From<Arc<[u8]>> for ArcBytes {
    /// Shares the given refcounted slice without copying it.
    fn from(data: Arc<[u8]>) -> Self {
        let length = data.len();
        let pointer = core::ptr::NonNull::new(data.deref().as_ptr().cast_mut()).unwrap();
        ArcBytes {
            pointer,
            length,
            lifetime: LifetimeObject::Arc { obj: data },
        }
    }
}
158
#[cfg(feature = "alloc")]
impl<'a> From<Cow<'a, [u8]>> for ArcBytes {
    /// Converts either cow variant: a borrowed slice is copied into fresh
    /// storage, an owned vector is taken over without copying.
    fn from(cow: Cow<'a, [u8]>) -> Self {
        match cow {
            Cow::Borrowed(data) => Self::from(data),
            Cow::Owned(data) => Self::from(data),
        }
    }
}
168
macro_rules! define_align_to_next_page {
    ($name:ident, $type:ty) => {
        /// Aligns the `value` to the next `page_size`, or returns the `value` as-is if it's already aligned.
        ///
        /// Returns `None` when rounding up would overflow the integer type.
        ///
        /// Panics (at runtime or in const evaluation) if `page_size` is not a power of two.
        #[inline]
        pub const fn $name(page_size: $type, value: $type) -> Option<$type> {
            assert!(
                page_size != 0 && (page_size & (page_size - 1)) == 0,
                "page size is not a power of two"
            );
            // Explicit parentheses: `&` binds more loosely than `-`, so the
            // original `value & page_size - 1` already meant this, but make it obvious.
            if value & (page_size - 1) == 0 {
                // Already aligned.
                Some(value)
            } else if value <= <$type>::MAX - page_size {
                // Round up to the next multiple of `page_size`.
                Some((value + page_size) & !(page_size - 1))
            } else {
                // Rounding up would overflow.
                None
            }
        }
    };
}
190
// Concrete instantiations for the integer widths used across the crate.
define_align_to_next_page!(align_to_next_page_u32, u32);
define_align_to_next_page!(align_to_next_page_u64, u64);
define_align_to_next_page!(align_to_next_page_usize, usize);
194
#[test]
fn test_align_to_next_page() {
    // Already-aligned values (including zero) pass through unchanged.
    assert_eq!(align_to_next_page_u64(4096, 0), Some(0));
    assert_eq!(align_to_next_page_u64(4096, 4096), Some(4096));
    // Unaligned values round up to the next page boundary.
    assert_eq!(align_to_next_page_u64(4096, 1), Some(4096));
    assert_eq!(align_to_next_page_u64(4096, 4095), Some(4096));
    assert_eq!(align_to_next_page_u64(4096, 4097), Some(8192));
    // The highest 4096-aligned u64 (2^64 - 4096) is returned as-is...
    let max = u64::MAX - 4095;
    assert_eq!(align_to_next_page_u64(4096, max), Some(max));
    // ...but rounding up anything above it would overflow.
    assert_eq!(align_to_next_page_u64(4096, max + 1), None);
}
206
/// A buffer which can be viewed as a slice of possibly-uninitialized bytes.
pub trait AsUninitSliceMut {
    /// Returns the buffer as a mutable slice of `MaybeUninit<u8>`.
    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>];
}
210
impl AsUninitSliceMut for [MaybeUninit<u8>] {
    // Already the right type; this is an identity conversion.
    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        self
    }
}
216
impl AsUninitSliceMut for [u8] {
    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        #[allow(unsafe_code)]
        // SAFETY: `MaybeUninit<T>` is guaranteed to have the same representation as `T`,
        //         so casting `[T]` into `[MaybeUninit<T>]` is safe. (Going in this
        //         direction is always sound: initialized bytes are a valid
        //         `MaybeUninit<u8>`.)
        unsafe {
            core::slice::from_raw_parts_mut(self.as_mut_ptr().cast(), self.len())
        }
    }
}
227
impl<const N: usize> AsUninitSliceMut for MaybeUninit<[u8; N]> {
    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        #[allow(unsafe_code)]
        // SAFETY: `MaybeUninit<T>` is guaranteed to have the same representation as `T`,
        //         so casting `[T; N]` into `[MaybeUninit<T>]` is safe; the resulting
        //         slice covers exactly the `N` bytes of the array.
        unsafe {
            core::slice::from_raw_parts_mut(self.as_mut_ptr().cast(), N)
        }
    }
}
238
239impl<const N: usize> AsUninitSliceMut for [u8; N] {
240    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>] {
241        let slice: &mut [u8] = &mut self[..];
242        slice.as_uninit_slice_mut()
243    }
244}
245
// Copied from `MaybeUninit::slice_assume_init_mut`.
// TODO: Remove this once this API is stabilized.
/// Converts `&mut [MaybeUninit<T>]` into `&mut [T]`, assuming every element
/// is initialized.
///
/// # Safety
///
/// Every element of `slice` must have been fully initialized before calling this.
#[allow(clippy::missing_safety_doc)]
#[allow(unsafe_code)]
pub unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    // SAFETY: The caller is responsible for making sure the `slice` was properly initialized.
    unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
}
254
#[allow(unsafe_code)]
/// Fills `dst` with the contents of `src` and returns `dst` as a
/// fully-initialized byte slice.
///
/// Panics if the two slices differ in length.
pub fn byte_slice_init<'dst>(dst: &'dst mut [MaybeUninit<u8>], src: &[u8]) -> &'dst mut [u8] {
    assert_eq!(dst.len(), src.len());

    // Initialize every destination byte; `MaybeUninit::write` is a safe operation.
    for (slot, &byte) in dst.iter_mut().zip(src.iter()) {
        slot.write(byte);
    }

    // SAFETY: The loop above initialized every element of `dst`, and
    // `MaybeUninit<u8>` has the same layout as `u8`.
    unsafe { &mut *(dst as *mut [MaybeUninit<u8>] as *mut [u8]) }
}
272
/// Parses a 32-bit immediate from its textual form.
///
/// Accepts `0x` hex, `0b` binary, and decimal (either signed or unsigned);
/// unsigned values above `i32::MAX` wrap into the `i32` domain bit-for-bit.
/// Returns `None` when the text is not a valid number.
pub fn parse_imm(text: &str) -> Option<i32> {
    let text = text.trim();

    if let Some(digits) = text.strip_prefix("0x") {
        return u32::from_str_radix(digits, 16).ok().map(|value| value as i32);
    }

    if let Some(digits) = text.strip_prefix("0b") {
        return u32::from_str_radix(digits, 2).ok().map(|value| value as i32);
    }

    // Try a signed parse first; fall back to unsigned and reinterpret the bits.
    text.parse::<i32>()
        .ok()
        .or_else(|| text.parse::<u32>().ok().map(|value| value as i32))
}
291
/// The result of [`parse_immediate`].
#[derive(Debug, PartialEq)]
pub enum ParsedImmediate {
    /// A value representable as a 32-bit immediate (possibly sign-extended to 64 bits).
    U32(u32),
    /// A value which needs all 64 bits.
    U64(u64),
}
297
298impl TryFrom<ParsedImmediate> for u32 {
299    type Error = &'static str;
300
301    fn try_from(value: ParsedImmediate) -> Result<Self, Self::Error> {
302        match value {
303            ParsedImmediate::U32(v) => Ok(v),
304            ParsedImmediate::U64(_) => Err("value is too large for u32"),
305        }
306    }
307}
308
impl From<ParsedImmediate> for u64 {
    fn from(value: ParsedImmediate) -> Self {
        match value {
            // NOTE(review): the 32-bit form widens by *sign* extension (via
            // `crate::cast`) — presumably mirroring how the VM materializes
            // 32-bit immediates into 64-bit registers; confirm against the
            // interpreter before relying on this elsewhere.
            ParsedImmediate::U32(v) => cast(v).to_u64_sign_extend(),
            ParsedImmediate::U64(v) => v,
        }
    }
}
317
/// Parses a textual immediate into either a 32-bit or a 64-bit value.
///
/// Accepts decimal (optionally negative), `0x` hex and `0b` binary forms.
/// An `i64 ` prefix forces the 64-bit representation. Returns `None` when
/// the text is not a valid number.
pub fn parse_immediate(text: &str) -> Option<ParsedImmediate> {
    let text = text.trim();

    // An explicit "i64 " prefix forces a 64-bit immediate even when the
    // value would fit in 32 bits.
    let (force_imm64, text) = if let Some(text) = text.strip_prefix("i64 ") {
        (true, text.trim())
    } else {
        (false, text)
    };

    let value = if let Some(text) = text.strip_prefix("0x") {
        u64::from_str_radix(text, 16).ok()?
    } else if let Some(text) = text.strip_prefix("0b") {
        u64::from_str_radix(text, 2).ok()?
    } else {
        // Negative decimal values are parsed as i64 and reinterpreted bit-for-bit.
        match text.parse::<i64>() {
            Ok(signed) => signed as u64,
            Err(_) => return None,
        }
    };

    if force_imm64 {
        return Some(ParsedImmediate::U64(value));
    }

    // Use the 32-bit form when the value survives a truncate-to-u32 +
    // sign-extend round trip (the first comparison is just a fast path for
    // small positive values); otherwise keep all 64 bits.
    if value < 0x7fffffff || cast(cast(value).truncate_to_u32()).to_u64_sign_extend() == value {
        Some(ParsedImmediate::U32(cast(value).truncate_to_u32()))
    } else {
        Some(ParsedImmediate::U64(value))
    }
}
348
#[test]
fn test_parse_immediate() {
    // "Special cases": values that do NOT survive a truncate-to-u32 +
    // sign-extend round trip must stay 64-bit, while large values whose top
    // 32 bits are all ones CAN be encoded as a sign-extended 32-bit immediate.
    assert_eq!(parse_immediate("0xffffffff"), Some(ParsedImmediate::U64(0xffffffff)));
    assert_eq!(parse_immediate("0xffffffff87654321"), Some(ParsedImmediate::U32(0x87654321)));
    assert_eq!(parse_immediate("0x80000075"), Some(ParsedImmediate::U64(0x80000075)));
    // "Normal cases".
    assert_eq!(parse_immediate("0x1234"), Some(ParsedImmediate::U32(0x1234)));
    assert_eq!(parse_immediate("0x12345678"), Some(ParsedImmediate::U32(0x12345678)));
    assert_eq!(parse_immediate("0x1234567890"), Some(ParsedImmediate::U64(0x1234567890)));
    // Negative decimals become sign-extendable 32-bit immediates.
    assert_eq!(parse_immediate("-1"), Some(ParsedImmediate::U32(0xffffffff)));
    assert_eq!(parse_immediate("-2"), Some(ParsedImmediate::U32(0xfffffffe)));
    // The "i64 " prefix forces the 64-bit form.
    assert_eq!(parse_immediate("i64 0xffffffff"), Some(ParsedImmediate::U64(0xffffffff)));
    assert_eq!(parse_immediate("0xdeadbeef"), Some(ParsedImmediate::U64(0xdeadbeef)));
    assert_eq!(
        parse_immediate("0xffffffff00000000"),
        Some(ParsedImmediate::U64(0xffffffff00000000))
    );
    // Conversions: widening keeps the value; narrowing a U64 fails.
    assert_eq!(parse_immediate("0xf000000e").map(Into::into), Some(0xf000000eu64));
    assert_eq!(parse_immediate("0x80000075").and_then(|imm| imm.try_into().ok()), None::<u32>);
}
370
371pub fn parse_reg(text: &str) -> Option<Reg> {
372    const REG_NAME_ALT: [&str; 13] = ["r0", "r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", "r11", "r12"];
373
374    let text = text.trim();
375    for (reg, name_alt) in Reg::ALL.into_iter().zip(REG_NAME_ALT) {
376        if text == reg.name() || text == name_alt {
377            return Some(reg);
378        }
379    }
380
381    None
382}
383
#[test]
fn test_arc_bytes() {
    // An empty `ArcBytes` dereferences to an empty slice, and its (dangling)
    // pointer is deterministic across instances.
    assert_eq!(&*ArcBytes::empty(), b"");
    assert_eq!(ArcBytes::empty().as_ptr(), ArcBytes::empty().as_ptr());

    #[cfg(feature = "alloc")]
    #[allow(clippy::redundant_clone)]
    {
        // Cloning shares the same backing buffer — no copy is made.
        let ab = ArcBytes::from(alloc::vec![1, 2, 3, 4]);
        assert_eq!(ab.as_ptr(), ab.as_ptr());
        assert_eq!(ab.clone().as_ptr(), ab.as_ptr());
        assert_eq!(&*ab, &[1, 2, 3, 4]);
        // Subslices view into the same storage.
        assert_eq!(&*ab.subslice(0..4), &[1, 2, 3, 4]);
        assert_eq!(&*ab.subslice(0..3), &[1, 2, 3]);
        assert_eq!(&*ab.subslice(1..4), &[2, 3, 4]);

        // Wrapping an `Arc<[u8]>` bumps its refcount (so `get_mut` fails),
        // and dropping the `ArcBytes` releases it again.
        let mut arc = Arc::<[u8]>::from(alloc::vec![1, 2, 3, 4]);
        assert!(Arc::get_mut(&mut arc).is_some());
        let ab2 = ArcBytes::from(Arc::clone(&arc));
        assert!(Arc::get_mut(&mut arc).is_none());
        assert_eq!(ab2.as_ptr(), ab2.as_ptr());
        assert_eq!(ab2.clone().as_ptr(), ab2.as_ptr());
        core::mem::drop(ab2);
        assert!(Arc::get_mut(&mut arc).is_some());
    }
}
409}