hyperlight_host/mem/
ptr_offset.rs

/*
Copyright 2024 The Hyperlight Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
17use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
18use std::convert::From;
19use std::ops::{Add, Sub};
20
21use tracing::{instrument, Span};
22
23use crate::error::HyperlightError;
24use crate::Result;
25
/// A typed offset into some address space.
///
/// Wrapping the raw `u64` in a dedicated type keeps offsets from being
/// confused with raw pointers. Equality and ordering are derived from the
/// wrapped value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Offset(u64);
31
32impl Offset {
33    /// Get the offset representing `0`
34    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
35    pub(super) fn zero() -> Self {
36        Self::default()
37    }
38
39    /// round up to the nearest multiple of `alignment`
40    pub(super) fn round_up_to(self, alignment: u64) -> Self {
41        let remainder = self.0 % alignment;
42        let multiples = self.0 / alignment;
43        match remainder {
44            0 => self,
45            _ => Offset::from((multiples + 1) * alignment),
46        }
47    }
48}
49
50impl Default for Offset {
51    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
52    fn default() -> Self {
53        Offset::from(0_u64)
54    }
55}
56
57impl From<u64> for Offset {
58    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
59    fn from(val: u64) -> Self {
60        Self(val)
61    }
62}
63
64impl From<&Offset> for u64 {
65    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
66    fn from(val: &Offset) -> u64 {
67        val.0
68    }
69}
70
71impl From<Offset> for u64 {
72    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
73    fn from(val: Offset) -> u64 {
74        val.0
75    }
76}
77
78impl TryFrom<Offset> for i64 {
79    type Error = HyperlightError;
80    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
81    fn try_from(val: Offset) -> Result<i64> {
82        Ok(i64::try_from(val.0)?)
83    }
84}
85
86impl TryFrom<i64> for Offset {
87    type Error = HyperlightError;
88    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
89    fn try_from(val: i64) -> Result<Offset> {
90        let val_u64 = u64::try_from(val)?;
91        Ok(Offset::from(val_u64))
92    }
93}
94
95impl TryFrom<usize> for Offset {
96    type Error = HyperlightError;
97    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
98    fn try_from(val: usize) -> Result<Offset> {
99        Ok(u64::try_from(val).map(Offset::from)?)
100    }
101}
102
103/// Convert an `Offset` to a `usize`, returning an `Err` if the
104/// conversion couldn't be made.
105impl TryFrom<&Offset> for usize {
106    type Error = HyperlightError;
107    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
108    fn try_from(val: &Offset) -> Result<usize> {
109        Ok(usize::try_from(val.0)?)
110    }
111}
112
113impl TryFrom<Offset> for usize {
114    type Error = HyperlightError;
115    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
116    fn try_from(val: Offset) -> Result<usize> {
117        usize::try_from(&val)
118    }
119}
120
121impl Add<Offset> for Offset {
122    type Output = Offset;
123    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
124    fn add(self, rhs: Offset) -> Offset {
125        Offset::from(self.0 + rhs.0)
126    }
127}
128
129impl Add<usize> for Offset {
130    type Output = Offset;
131    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
132    fn add(self, rhs: usize) -> Offset {
133        Offset(self.0 + rhs as u64)
134    }
135}
136
137impl Add<Offset> for usize {
138    type Output = Offset;
139    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
140    fn add(self, rhs: Offset) -> Offset {
141        rhs.add(self)
142    }
143}
144
145impl Add<u64> for Offset {
146    type Output = Offset;
147    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
148    fn add(self, rhs: u64) -> Offset {
149        Offset(self.0 + rhs)
150    }
151}
152
153impl Add<Offset> for u64 {
154    type Output = Offset;
155    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
156    fn add(self, rhs: Offset) -> Offset {
157        rhs.add(self)
158    }
159}
160
161impl Sub<Offset> for Offset {
162    type Output = Offset;
163    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
164    fn sub(self, rhs: Offset) -> Offset {
165        Offset::from(self.0 - rhs.0)
166    }
167}
168
169impl Sub<usize> for Offset {
170    type Output = Offset;
171    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
172    fn sub(self, rhs: usize) -> Offset {
173        Offset(self.0 - rhs as u64)
174    }
175}
176
177impl Sub<Offset> for usize {
178    type Output = Offset;
179    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
180    fn sub(self, rhs: Offset) -> Offset {
181        rhs.sub(self)
182    }
183}
184
185impl Sub<u64> for Offset {
186    type Output = Offset;
187    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
188    fn sub(self, rhs: u64) -> Offset {
189        Offset(self.0 - rhs)
190    }
191}
192
193impl Sub<Offset> for u64 {
194    type Output = Offset;
195    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
196    fn sub(self, rhs: Offset) -> Offset {
197        rhs.sub(self)
198    }
199}
200
201impl PartialEq<usize> for Offset {
202    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
203    fn eq(&self, other: &usize) -> bool {
204        if let Ok(offset_usize) = usize::try_from(self) {
205            offset_usize == *other
206        } else {
207            false
208        }
209    }
210}
211
212impl PartialOrd<usize> for Offset {
213    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
214    fn partial_cmp(&self, rhs: &usize) -> Option<Ordering> {
215        match usize::try_from(self) {
216            Ok(offset_usize) if offset_usize > *rhs => Some(Ordering::Greater),
217            Ok(offset_usize) if offset_usize == *rhs => Some(Ordering::Equal),
218            Ok(_) => Some(Ordering::Less),
219            Err(_) => None,
220        }
221    }
222}
223
224impl PartialEq<u64> for Offset {
225    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
226    fn eq(&self, rhs: &u64) -> bool {
227        u64::from(self) == *rhs
228    }
229}
230
231impl PartialOrd<u64> for Offset {
232    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
233    fn partial_cmp(&self, rhs: &u64) -> Option<Ordering> {
234        let lhs: u64 = self.into();
235        match lhs > *rhs {
236            true => Some(Ordering::Greater),
237            false if lhs == *rhs => Some(Ordering::Equal),
238            false => Some(Ordering::Less),
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::Offset;

    proptest! {
        // Negative `i64` values must be rejected; non-negative values must
        // survive a round trip through `Offset` unchanged.
        #[test]
        fn i64_roundtrip(i64_val in (i64::MIN..i64::MAX)) {
            match Offset::try_from(i64_val) {
                Err(_) => assert!(i64_val < 0),
                Ok(offset) => {
                    assert!(i64_val >= 0);
                    let back = i64::try_from(offset);
                    assert!(back.is_ok());
                    assert_eq!(i64_val, back.unwrap());
                }
            }
        }

        // Every generated `usize` converts to an `Offset` and back unchanged.
        #[test]
        fn usize_roundtrip(val in (usize::MIN..usize::MAX)) {
            let offset = Offset::try_from(val).unwrap();
            let back = usize::try_from(offset).unwrap();
            assert_eq!(val, back);
        }

        // Adding a number to the zero offset yields that same number.
        #[test]
        fn add_numeric_types(usize_val in (usize::MIN..usize::MAX), u64_val in (u64::MIN..u64::MAX)) {
            let start = Offset::default();

            // add usize to offset
            let usize_sum = start + usize_val;
            assert_eq!(usize_val, usize::try_from(usize_sum).unwrap());

            // add u64 to offset
            let u64_sum = start + u64_val;
            assert_eq!(u64_val, u64::from(u64_sum));
        }
    }

    /// Rounding to a multiple of 4 leaves aligned values alone and bumps
    /// everything else to the next multiple.
    #[test]
    fn round_up_to() {
        // (input, expected result of rounding up to a multiple of 4)
        let cases: [(u64, u64); 5] = [(0, 0), (1, 4), (3, 4), (4, 4), (5, 8)];

        for (input, expected) in cases {
            let rounded = Offset::from(input).round_up_to(4);
            assert_eq!(rounded, Offset::from(expected));
        }
    }
}