1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
//! Validation implementations for shared pointers.

use super::{ArchivedRc, ArchivedRcWeak, ArchivedRcWeakTag, ArchivedRcWeakVariantSome};
use crate::{
    validation::{ArchiveContext, LayoutRaw, SharedContext},
    ArchivePointee, RelPtr,
};
use bytecheck::{CheckBytes, Error};
use core::{any::TypeId, convert::Infallible, fmt, ptr};
use ptr_meta::Pointee;

/// Errors that can occur while checking archived shared pointers.
///
/// Each type parameter is the error type of one stage of the check: `T` for
/// the pointer itself, `R` for the value it points to, and `C` for the
/// validation context.
#[derive(Debug)]
pub enum SharedPointerError<T, R, C> {
    /// An error occurred while checking the bytes of the shared pointer itself
    /// (its relative pointer and archived metadata)
    PointerCheckBytesError(T),
    /// An error occurred while checking the bytes of the value the shared
    /// pointer refers to
    ValueCheckBytesError(R),
    /// A context error occurred
    ContextError(C),
}

impl<T, R, C> fmt::Display for SharedPointerError<T, R, C>
where
    T: fmt::Display,
    R: fmt::Display,
    C: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SharedPointerError::PointerCheckBytesError(e) => e.fmt(f),
            SharedPointerError::ValueCheckBytesError(e) => e.fmt(f),
            SharedPointerError::ContextError(e) => e.fmt(f),
        }
    }
}

#[cfg(feature = "std")]
const _: () = {
    use std::error::Error;

    // Only compiled with the `std` feature, since it relies on
    // `std::error::Error`.
    impl<T: Error + 'static, R: Error + 'static, C: Error + 'static> Error
        for SharedPointerError<T, R, C>
    {
        /// Returns the wrapped error as the source; every variant carries
        /// one, so this never returns `None`.
        fn source(&self) -> Option<&(dyn Error + 'static)> {
            Some(match self {
                Self::PointerCheckBytesError(pointer_err) => pointer_err as &dyn Error,
                Self::ValueCheckBytesError(value_err) => value_err as &dyn Error,
                Self::ContextError(context_err) => context_err as &dyn Error,
            })
        }
    }
};

/// Errors that can occur while checking archived weak pointers.
///
/// `T`, `R` and `C` are passed unchanged to the wrapped
/// [`SharedPointerError`].
#[derive(Debug)]
pub enum WeakPointerError<T, R, C> {
    /// The weak pointer had an invalid tag; the payload is the offending tag
    /// byte
    InvalidTag(u8),
    /// An error occurred while checking the underlying shared pointer
    CheckBytes(SharedPointerError<T, R, C>),
}

impl<T: fmt::Display, R: fmt::Display, C: fmt::Display> fmt::Display for WeakPointerError<T, R, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            WeakPointerError::InvalidTag(tag) => {
                write!(f, "archived weak had invalid tag: {}", tag)
            }
            WeakPointerError::CheckBytes(e) => e.fmt(f),
        }
    }
}

#[cfg(feature = "std")]
const _: () = {
    use std::error::Error;

    // Only compiled with the `std` feature, since it relies on
    // `std::error::Error`.
    impl<T: Error + 'static, R: Error + 'static, C: Error + 'static> Error
        for WeakPointerError<T, R, C>
    {
        /// Returns the wrapped shared-pointer error as the source, or `None`
        /// for an invalid tag, which has no underlying cause.
        fn source(&self) -> Option<&(dyn Error + 'static)> {
            if let Self::CheckBytes(shared_err) = self {
                Some(shared_err as &dyn Error)
            } else {
                None
            }
        }
    }
};

impl<T, R, C> From<Infallible> for WeakPointerError<T, R, C> {
    /// Converts an [`Infallible`] into a `WeakPointerError`.
    ///
    /// `Infallible` has no values, so this function can never actually be
    /// called. The impl exists so that `?` can be applied to a `Result` whose
    /// error type is `Infallible` inside a function returning
    /// `WeakPointerError` (presumably the tag-byte check in
    /// `ArchivedRcWeak::check_bytes` — confirm against `bytecheck`).
    fn from(never: Infallible) -> Self {
        // An empty match on an uninhabited type is exhaustive, so the compiler
        // proves this is unreachable without any `unsafe`.
        match never {}
    }
}

impl<T, F, C> CheckBytes<C> for ArchivedRc<T, F>
where
    T: ArchivePointee + CheckBytes<C> + LayoutRaw + Pointee + ?Sized + 'static,
    C: ArchiveContext + SharedContext + ?Sized,
    T::ArchivedMetadata: CheckBytes<C>,
    C::Error: Error,
    F: 'static,
{
    type Error =
        SharedPointerError<<T::ArchivedMetadata as CheckBytes<C>>::Error, T::Error, C::Error>;

    /// Validates an archived shared pointer located at `value`.
    ///
    /// The check proceeds in stages, each mapped to its own error variant:
    ///
    /// 1. The relative pointer stored at `value` is checked
    ///    (`PointerCheckBytesError`) and resolved through the context
    ///    (`ContextError`).
    /// 2. The resolved pointer is registered with the context under the
    ///    `TypeId` of this `ArchivedRc` type.
    /// 3. Only if registration returns `true` is the pointee bounds-checked
    ///    and validated (`ValueCheckBytesError`) inside its own prefix
    ///    subtree. Presumably `true` means this is the first time the pointer
    ///    was seen, so shared values are validated once — confirm against
    ///    `SharedContext::register_shared_ptr`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `CheckBytes::check_bytes` for
    /// `value`; on success `value` is dereferenced to produce the returned
    /// reference.
    #[inline]
    unsafe fn check_bytes<'a>(
        value: *const Self,
        context: &mut C,
    ) -> Result<&'a Self, Self::Error> {
        let relative = RelPtr::<T>::manual_check_bytes(value.cast(), context)
            .map_err(SharedPointerError::PointerCheckBytesError)?;
        let target = context
            .check_rel_ptr(relative)
            .map_err(SharedPointerError::ContextError)?;

        let needs_value_check = context
            .register_shared_ptr(target.cast(), TypeId::of::<Self>())
            .map_err(SharedPointerError::ContextError)?;
        if !needs_value_check {
            // Nothing further to validate for this pointer.
            return Ok(&*value);
        }

        context
            .bounds_check_subtree_ptr(target)
            .map_err(SharedPointerError::ContextError)?;

        // The pointee is validated between a matching push/pop of its prefix
        // subtree range.
        let subtree = context
            .push_prefix_subtree(target)
            .map_err(SharedPointerError::ContextError)?;
        T::check_bytes(target, context).map_err(SharedPointerError::ValueCheckBytesError)?;
        context
            .pop_prefix_range(subtree)
            .map_err(SharedPointerError::ContextError)?;

        Ok(&*value)
    }
}

impl ArchivedRcWeakTag {
    const TAG_NONE: u8 = ArchivedRcWeakTag::None as u8;
    const TAG_SOME: u8 = ArchivedRcWeakTag::Some as u8;
}

impl<T, F, C> CheckBytes<C> for ArchivedRcWeak<T, F>
where
    T: ArchivePointee + CheckBytes<C> + LayoutRaw + Pointee + ?Sized + 'static,
    C: ArchiveContext + SharedContext + ?Sized,
    T::ArchivedMetadata: CheckBytes<C>,
    C::Error: Error,
    F: 'static,
{
    type Error =
        WeakPointerError<<T::ArchivedMetadata as CheckBytes<C>>::Error, T::Error, C::Error>;

    /// Validates an archived weak pointer located at `value`.
    ///
    /// The first byte at `value` is read as the tag:
    ///
    /// - `TAG_NONE`: nothing else is checked.
    /// - `TAG_SOME`: `value` is reinterpreted as the `Some` variant and its
    ///   `ArchivedRc` field (field `1`) is validated; a failure is wrapped in
    ///   `WeakPointerError::CheckBytes`.
    /// - any other byte: `WeakPointerError::InvalidTag` carrying that byte.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `CheckBytes::check_bytes` for
    /// `value`; on success `value` is dereferenced to produce the returned
    /// reference.
    #[inline]
    unsafe fn check_bytes<'a>(
        value: *const Self,
        context: &mut C,
    ) -> Result<&'a Self, Self::Error> {
        let tag = *u8::check_bytes(value.cast::<u8>(), context)?;
        match tag {
            ArchivedRcWeakTag::TAG_NONE => Ok(&*value),
            ArchivedRcWeakTag::TAG_SOME => {
                let some_variant = value.cast::<ArchivedRcWeakVariantSome<T, F>>();
                // `addr_of!` takes the field address without creating a
                // reference to the not-yet-validated field.
                let shared = ptr::addr_of!((*some_variant).1);
                ArchivedRc::<T, F>::check_bytes(shared, context)
                    .map_err(WeakPointerError::CheckBytes)?;
                Ok(&*value)
            }
            invalid => Err(WeakPointerError::InvalidTag(invalid)),
        }
    }
}