1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
//! POD (specifically, [`bytemuck::Pod`]) versions of some types that have invalid bit patterns,
//! for use in FFI structs.
//!
//! Usually, it's possible to just mark a struct containing one of the
//! fundamental versions of these types as [`bytemuck::CheckedBitPattern`], which allows it
//! to pass through the relevant module host memory utility casting functions by checking the
//! underlying memory bit pattern each time it gets cast. However, the auto implementations of that
//! trait are sometimes suboptimal in terms of performance. For FFI types that:
//!
//! 1. are particularly performance sensitive (e.g. large numbers of them pass through the FFI
//!    boundary per frame), and
//! 2. contain many true-[`Pod`] fields and only a few
//!    [`CheckedBitPattern`][bytemuck::CheckedBitPattern] fields,
//!
//! it may be useful to use these types instead, and then only check the relevant fields for
//! validity when they actually get used.

use core::ops::Deref;
use core::ops::DerefMut;

use bytemuck::AnyBitPattern;
use bytemuck::Pod;
use bytemuck::Zeroable;

/// A thin wrapper around `u128` that ensures 16-byte alignment.
///
/// wasm is happy with an 8-byte aligned `u128`, so when a `u128` value is created in
/// wasm-land, Rust will sometimes place it at an address that is only 8-byte aligned.
/// Accessing that value directly on the host side then breaks memory safety requirements
/// and trips the relevant alignment assertion in the safe casting functions. As a result,
/// FFI types should always use this wrapper instead of a raw `u128`.
///
/// Besides the [`Pod`] / [`Zeroable`] markers, the wrapper derives the same comparison and
/// hashing traits as `u128`, so it can be compared, ordered and used as a map key directly.
#[repr(C, align(16))]
#[derive(Copy, Clone, Pod, Zeroable, Default, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Align16U128(pub u128);

/// Lets an `Align16U128` be used mostly like the `u128` it wraps.
impl Deref for Align16U128 {
    type Target = u128;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Align16U128 {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl AsRef<u128> for Align16U128 {
    fn as_ref(&self) -> &u128 {
        &self.0
    }
}

impl AsMut<u128> for Align16U128 {
    fn as_mut(&mut self) -> &mut u128 {
        &mut self.0
    }
}

impl From<u128> for Align16U128 {
    fn from(value: u128) -> Self {
        Self(value)
    }
}

impl From<Align16U128> for u128 {
    fn from(value: Align16U128) -> Self {
        value.0
    }
}

/// Formats exactly like the wrapped `u128`, including any width, fill, alignment or
/// zero-padding options requested by the caller (e.g. `{:>8}` or `{:08}`).
impl core::fmt::Display for Align16U128 {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Forward the caller's `Formatter` rather than `write!(f, "{}", self.0)`, which would
        // re-format with default options and silently drop any requested width/padding.
        core::fmt::Display::fmt(&self.0, f)
    }
}

/// A thin wrapper around `i128` that ensures 16-byte alignment.
///
/// wasm is happy with an 8-byte aligned `i128`, so when an `i128` value is created in
/// wasm-land, Rust will sometimes place it at an address that is only 8-byte aligned.
/// Accessing that value directly on the host side then breaks memory safety requirements
/// and trips the relevant alignment assertion in the safe casting functions. As a result,
/// FFI types should always use this wrapper instead of a raw `i128`.
///
/// Besides the [`Pod`] / [`Zeroable`] markers, the wrapper derives the same comparison and
/// hashing traits as `i128`, so it can be compared, ordered and used as a map key directly.
#[repr(C, align(16))]
#[derive(Copy, Clone, Pod, Zeroable, Default, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Align16I128(pub i128);

/// Lets an `Align16I128` be used mostly like the `i128` it wraps.
impl Deref for Align16I128 {
    type Target = i128;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Align16I128 {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl AsRef<i128> for Align16I128 {
    fn as_ref(&self) -> &i128 {
        &self.0
    }
}

impl AsMut<i128> for Align16I128 {
    fn as_mut(&mut self) -> &mut i128 {
        &mut self.0
    }
}

impl From<i128> for Align16I128 {
    fn from(value: i128) -> Self {
        Self(value)
    }
}

impl From<Align16I128> for i128 {
    fn from(value: Align16I128) -> Self {
        value.0
    }
}

/// Formats exactly like the wrapped `i128`, including any width, fill, alignment, sign or
/// zero-padding options requested by the caller (e.g. `{:+}`, `{:>8}` or `{:08}`).
impl core::fmt::Display for Align16I128 {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Forward the caller's `Formatter` rather than `write!(f, "{}", self.0)`, which would
        // re-format with default options and silently drop any requested sign/width/padding.
        core::fmt::Display::fmt(&self.0, f)
    }
}

/// A version of `bool` that has the same layout as `bool` but is [`bytemuck::Pod`].
///
/// In Rust, [`bool`] has a defined representation of being one byte with false being 0 and true being 1.
/// However, it is not [`Pod`] because it is *invalid* to interpret a byte that contains any value *other*
/// than 0 or 1 as a `bool` -- doing so is ***undefined behavior***.
///
/// Use this type to get around that limitation: `PodBool` is valid for any value of `u8`, and its
/// value is only checked to be a valid `bool` when it is read through
/// [`as_bool`][PodBool::as_bool] (which panics on an invalid value) or
/// [`try_as_bool`][PodBool::try_as_bool] (which returns an error instead).
///
/// The [`Default`] value is `false` (a zero byte), matching its [`Zeroable`] value.
#[repr(C)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Default, Pod, Zeroable)]
pub struct PodBool(u8);

/// Prints `PodBool(true)` / `PodBool(false)` for valid bytes and `PodBool(InvalidValue(N))`
/// for any other byte `N`. It never panics, so a corrupted value can still be inspected.
/// `debug_tuple` is used so the alternate form (`{:#?}`) pretty-prints like a derived impl.
impl core::fmt::Debug for PodBool {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self.try_as_bool() {
            Ok(b) => f.debug_tuple("PodBool").field(&b).finish(),
            Err(_) => f
                .debug_tuple("PodBool")
                .field(&format_args!("InvalidValue({})", self.0))
                .finish(),
        }
    }
}

/// `false` becomes byte `0` and `true` becomes byte `1`, the only two valid `bool` bytes.
impl From<bool> for PodBool {
    fn from(v: bool) -> Self {
        PodBool(u8::from(v))
    }
}

/// Error returned by [`PodBool::try_as_bool`] when the stored byte is neither `0` nor `1`,
/// i.e. it is not a valid `bool` bit pattern.
///
/// Implements [`Debug`](core::fmt::Debug) so that `try_as_bool().unwrap()` / `.expect(..)`
/// can be used, and [`Display`](core::fmt::Display) for user-facing messages.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct InvalidPodBool {}

impl core::fmt::Display for InvalidPodBool {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("invalid value in PodBool")
    }
}

impl PodBool {
    /// Interprets the stored byte as a `bool` (`0` is `false`, `1` is `true`).
    ///
    /// # Panics
    ///
    /// Panics with `"invalid value in PodBool"` if the stored byte is neither `0` nor `1`.
    /// Use [`try_as_bool`][PodBool::try_as_bool] to handle that case without panicking.
    pub fn as_bool(&self) -> bool {
        // Single source of truth for validation: reuse `try_as_bool`.
        match self.try_as_bool() {
            Ok(value) => value,
            // Panicking on an invalid byte is this method's documented contract; callers that
            // cannot rule out a corrupted byte should use `try_as_bool` instead.
            #[allow(clippy::panic)]
            Err(_) => panic!("invalid value in PodBool"),
        }
    }

    /// Interprets the stored byte as a `bool` (`0` is `false`, `1` is `true`).
    ///
    /// # Errors
    ///
    /// Returns `Err(InvalidPodBool {})` if the stored byte is neither `0` nor `1`.
    pub fn try_as_bool(&self) -> Result<bool, InvalidPodBool> {
        match self.0 {
            0 => Ok(false),
            1 => Ok(true),
            _ => Err(InvalidPodBool {}),
        }
    }
}

/// This type adds some `const PAD` number of "explicit" or "manual" padding
/// bytes to the end of a struct.
///
/// This is useful to make a type not have *real* padding bytes,
/// and therefore be able to be marked as [`bytemuck::NoUninit`]. Specifically,
/// it's used in the `ark_api_macros::ffi_union` macro to equalize the size of all
/// fields of a union and therefore remove any "real" padding bytes from the union, making
/// it safe to store in WASM memory and pass through the ark module host memory utility functions.
/// It may also be useful in other places.
///
/// Because the struct is `repr(C)`, its size is `size_of::<T>() + PAD` rounded up to a multiple
/// of `align_of::<T>()`. If `size_of::<T>() + PAD` is not already such a multiple, the compiler
/// inserts *real* trailing padding after the explicit bytes. When the goal is to eliminate real
/// padding, choose `PAD` accordingly.
///
/// The `Debug` output includes the explicit padding bytes.
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct TransparentPad<T, const PAD: usize>(pub T, [u8; PAD]);

// SAFETY: `TransparentPad` is `repr(C)` and holds only a `T` followed by a `[u8; PAD]`.
// `T: Zeroable` makes an all-zero `T` valid, and an all-zero byte array is trivially valid,
// so the all-zero bit pattern is a valid `TransparentPad<T, PAD>`.
#[allow(unsafe_code)]
unsafe impl<T, const PAD: usize> Zeroable for TransparentPad<T, PAD> where T: Zeroable {}

// SAFETY: Every bit pattern of a `[u8; PAD]` is valid, and `T: AnyBitPattern` makes every bit
// pattern of `T` valid, so every bit pattern of the `repr(C)` struct is valid too. The trait's
// supertraits (`Zeroable`, `Copy`, `'static`) follow from `T: AnyBitPattern` via the impl above
// and the struct's `Copy` derive.
#[allow(unsafe_code)]
unsafe impl<T, const PAD: usize> AnyBitPattern for TransparentPad<T, PAD> where T: AnyBitPattern {}

/// Implements [`bytemuck::CheckedBitPattern`] for `TransparentPad<T, PAD>` (for every `PAD`)
/// for each listed type `T`. Each listed type must itself implement `CheckedBitPattern`.
///
/// Only the wrapped `T` is validated. The explicit padding bytes are plain `u8`s and are
/// therefore always valid.
///
/// A generic `impl<T: CheckedBitPattern>` cannot be written instead. bytemuck already provides
/// a blanket `CheckedBitPattern` impl for every `AnyBitPattern` type, and `TransparentPad<T, PAD>`
/// is `AnyBitPattern` whenever `T` is, so the two impls would overlap.
///
/// Accepts one or more types, which may be paths or generic types, with an optional trailing
/// comma: `impl_checked_bit_pattern_for_transparent_pad!(MyType);` or
/// `impl_checked_bit_pattern_for_transparent_pad!(MyType, other::Type<u8>,);`.
///
/// The expansion names the trait as `bytemuck::CheckedBitPattern`, so `bytemuck` must be
/// reachable under that name at the call site. It names the wrapper as
/// `$crate::TransparentPad`.
#[macro_export]
macro_rules! impl_checked_bit_pattern_for_transparent_pad {
    ($($inner:ty),+ $(,)?) => {
        $(
            // SAFETY: `TransparentPad` is `repr(C)`, so `TransparentPad<$inner, PAD>` and
            // `TransparentPad<<$inner as CheckedBitPattern>::Bits, PAD>` share a layout as long
            // as `$inner` and its `Bits` do, which `$inner`'s own `CheckedBitPattern` impl
            // guarantees. The `[u8; PAD]` tail is valid for every bit pattern, so validating
            // the inner field alone is sufficient.
            #[allow(unsafe_code)]
            unsafe impl<const PAD: usize> bytemuck::CheckedBitPattern
                for $crate::TransparentPad<$inner, PAD>
            {
                type Bits =
                    $crate::TransparentPad<<$inner as bytemuck::CheckedBitPattern>::Bits, PAD>;

                fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
                    <$inner as bytemuck::CheckedBitPattern>::is_valid_bit_pattern(&bits.0)
                }
            }
        )+
    };
}

impl<T, const PAD: usize> TransparentPad<T, PAD> {
    /// Wraps `inner`, filling the `PAD` explicit padding bytes with zeroes.
    ///
    /// This is a `const fn`, so padded values can be built in constants and statics.
    pub const fn new(inner: T) -> Self {
        Self(inner, [0u8; PAD])
    }
}

/// Defaults to `T::default()` followed by zeroed padding bytes.
///
/// Implemented by hand because `[u8; PAD]` only implements `Default` for `PAD <= 32`, so
/// `#[derive(Default)]` would not compile for an arbitrary `PAD`.
impl<T: Default, const PAD: usize> Default for TransparentPad<T, PAD> {
    fn default() -> Self {
        Self::new(T::default())
    }
}

impl<T, const PAD: usize> AsRef<T> for TransparentPad<T, PAD> {
    fn as_ref(&self) -> &T {
        &self.0
    }
}

impl<T, const PAD: usize> AsMut<T> for TransparentPad<T, PAD> {
    fn as_mut(&mut self) -> &mut T {
        &mut self.0
    }
}

impl<T, const PAD: usize> core::ops::Deref for TransparentPad<T, PAD> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T, const PAD: usize> core::ops::DerefMut for TransparentPad<T, PAD> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}