1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//! Validators add validation capabilities by wrapping and extending basic validators.

use crate::{validation::SharedContext, Fallible};
use core::{any::TypeId, fmt};

#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
#[cfg(feature = "std")]
use std::collections::HashMap;

/// Errors that can occur when checking shared memory.
#[derive(Debug)]
pub enum SharedError {
    /// Multiple pointers exist to the same location with different types.
    ///
    /// Reported when an address that has already been registered is
    /// registered again under a different [`TypeId`].
    TypeMismatch {
        /// The type that the location was previously checked as
        previous: TypeId,
        /// The type that the location is currently being checked as
        current: TypeId,
    },
}

impl fmt::Display for SharedError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SharedError::TypeMismatch { previous, current } => write!(
                f,
                "the same memory region has been claimed as two different types ({:?} and {:?})",
                previous, current
            ),
        }
    }
}

// Only available with the `std` feature, since the `Error` trait is taken from
// `std` here.
#[cfg(feature = "std")]
impl std::error::Error for SharedError {
    /// Returns the underlying cause of this error; no variant wraps another
    /// error, so this is always `None`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Exhaustive on purpose: a future variant that carries a cause must be
        // handled explicitly here.
        match *self {
            Self::TypeMismatch { .. } => None,
        }
    }
}

/// A validator that can verify shared memory.
///
/// Remembers the [`TypeId`] under which each shared pointer was first
/// registered, so that a later registration of the same address as a different
/// type can be rejected.
#[derive(Debug)]
pub struct SharedValidator {
    // Maps each registered address to the type it was first registered as.
    // The pointers serve only as lookup keys; nothing in this file dereferences them.
    shared: HashMap<*const u8, TypeId>,
}

// SAFETY: SharedValidator is safe to send to another thread. The raw pointers it
// stores are only used as hash map keys and are never dereferenced in this file.
// This trait is not automatically implemented because the struct contains a pointer
unsafe impl Send for SharedValidator {}

// SAFETY: SharedValidator is safe to share between threads. The only method in this
// file that touches the map takes `&mut self`, and the stored raw pointers are only
// used as hash map keys, never dereferenced.
// This trait is not automatically implemented because the struct contains a pointer
unsafe impl Sync for SharedValidator {}

impl SharedValidator {
    /// Creates a shared memory validator with no registered pointers.
    #[inline]
    pub fn new() -> Self {
        // TODO: consider deferring creation of this map to avoid the overhead of
        // constructing it up front.
        let shared = HashMap::new();
        Self { shared }
    }

    /// Creates a shared memory validator with no registered pointers whose
    /// backing map is pre-sized for at least `capacity` entries.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let shared = HashMap::with_capacity(capacity);
        Self { shared }
    }
}

impl Default for SharedValidator {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

// Failures reported by this validator are expressed as `SharedError`.
impl Fallible for SharedValidator {
    type Error = SharedError;
}

impl SharedContext for SharedValidator {
    /// Records that the memory at `ptr` is being validated as `type_id`.
    ///
    /// Returns `Ok(true)` when the address has not been registered before (it is
    /// recorded now), and `Ok(false)` when it was already registered with the
    /// same type.
    ///
    /// # Errors
    ///
    /// Returns [`SharedError::TypeMismatch`] when the address was previously
    /// registered with a different type. The original registration is kept.
    #[inline]
    fn register_shared_ptr(
        &mut self,
        ptr: *const u8,
        type_id: TypeId,
    ) -> Result<bool, Self::Error> {
        #[cfg(not(feature = "std"))]
        use hashbrown::hash_map::Entry;
        #[cfg(feature = "std")]
        use std::collections::hash_map::Entry;

        // First sighting of this address: remember its type and report it as new.
        let existing = match self.shared.entry(ptr) {
            Entry::Vacant(slot) => {
                slot.insert(type_id);
                return Ok(true);
            }
            Entry::Occupied(existing) => existing,
        };

        // The address is already known; it must have been registered as the same type.
        let previous = *existing.get();
        if previous == type_id {
            Ok(false)
        } else {
            Err(SharedError::TypeMismatch {
                previous,
                current: type_id,
            })
        }
    }
}