1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
//! Validators add validation capabilities by wrapping and extending basic validators.

use crate::{validation::SharedContext, Fallible};
use core::{any::TypeId, fmt};

#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
#[cfg(feature = "std")]
use std::collections::HashMap;

/// Errors that can occur when checking shared memory.
///
/// In this module the only producer is `SharedValidator`'s `register_shared_ptr`, which
/// returns this error when an address is registered a second time under a different
/// [`TypeId`] than the one recorded on its first registration.
#[derive(Debug)]
pub enum SharedError {
    /// Multiple pointers exist to the same location with different types
    TypeMismatch {
        /// A previous type that the location was checked as
        /// (the `TypeId` already stored for the address)
        previous: TypeId,
        /// The current type that the location is checked as
        /// (the `TypeId` supplied by the conflicting registration)
        current: TypeId,
    },
}

impl fmt::Display for SharedError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SharedError::TypeMismatch { previous, current } => write!(
                f,
                "the same memory region has been claimed as two different types ({:?} and {:?})",
                previous, current
            ),
        }
    }
}

#[cfg(feature = "std")]
const _: () = {
    use std::error::Error;

    impl Error for SharedError {
        /// Returns the underlying cause of this error.
        ///
        /// A type mismatch does not wrap another error, so this is always `None`.
        fn source(&self) -> Option<&(dyn Error + 'static)> {
            // Irrefutable destructuring keeps this exhaustive: a newly added variant
            // makes the pattern refutable and this line stops compiling until its
            // source is decided.
            let SharedError::TypeMismatch { .. } = self;
            None
        }
    }
};

/// A validator that can verify shared memory.
///
/// It remembers, for every address registered through `register_shared_ptr`, the
/// [`TypeId`] that address was first registered with, so that a later registration of
/// the same address under a different type can be rejected.
#[derive(Debug)]
pub struct SharedValidator {
    // Maps each registered address to the `TypeId` supplied on its first registration.
    shared: HashMap<*const u8, TypeId>,
}

// SAFETY: SharedValidator is safe to send to another thread. Within this file the stored
// `*const u8` values are only used as map keys (hashed and compared) and are never
// dereferenced.
// This trait is not automatically implemented because the struct contains a pointer
unsafe impl Send for SharedValidator {}

// SAFETY: SharedValidator is safe to share between threads. Within this file the stored
// `*const u8` values are only used as map keys (hashed and compared) and are never
// dereferenced, and the map is only mutated through `&mut self`.
// This trait is not automatically implemented because the struct contains a pointer
unsafe impl Sync for SharedValidator {}

impl SharedValidator {
    /// Creates a new shared memory validator with no registered pointers.
    #[inline]
    pub fn new() -> Self {
        // TODO: consider deferring construction of the map until it is first needed, to
        // avoid the overhead of constructing it for validators that never use it.
        let shared = HashMap::new();
        Self { shared }
    }
}

impl Default for SharedValidator {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

// Every fallible operation on `SharedValidator` reports failures as `SharedError`;
// this is the `Self::Error` returned by `register_shared_ptr` below.
impl Fallible for SharedValidator {
    type Error = SharedError;
}

impl SharedContext for SharedValidator {
    /// Records that `ptr` is being checked as the type identified by `type_id`.
    ///
    /// Returns `Ok(true)` when `ptr` had not been registered before (it is recorded now),
    /// and `Ok(false)` when it was already registered with the same `type_id`.
    ///
    /// # Errors
    ///
    /// Returns [`SharedError::TypeMismatch`] when `ptr` was previously registered with a
    /// different `TypeId`. The stored registration is left unchanged in that case.
    #[inline]
    fn register_shared_ptr(
        &mut self,
        ptr: *const u8,
        type_id: TypeId,
    ) -> Result<bool, Self::Error> {
        // `TypeId` is `Copy`, so copying the looked-up value releases the borrow of the
        // map before the insertion in the first arm.
        match self.shared.get(&ptr).copied() {
            None => {
                self.shared.insert(ptr, type_id);
                Ok(true)
            }
            Some(previous) if previous == type_id => Ok(false),
            Some(previous) => Err(SharedError::TypeMismatch {
                previous,
                current: type_id,
            }),
        }
    }
}