1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
use crate::error::SpirvCrossError;
use crate::{error, Compiler, PhantomCompiler};
use spirv_cross_sys::spvc_compiler_s;
use std::fmt::{Debug, Formatter};
use std::ptr::NonNull;

use crate::sealed::Sealed;

/// A SPIR-V ID to a specialization constant.
pub use spirv_cross_sys::ConstantId;

/// A SPIR-V ID to a type.
pub use spirv_cross_sys::TypeId;

/// A SPIR-V ID to a variable.
pub use spirv_cross_sys::VariableId;

/// A pointer used purely as an identity tag.
///
/// The code in this module only compares the wrapped address for equality and
/// formats it for `Debug`; it never reads through the pointer.
#[repr(transparent)]
struct PointerOnlyForComparison<T>(NonNull<T>);

// `Clone`/`Copy` are implemented by hand instead of derived: `#[derive]` would
// add `T: Clone` / `T: Copy` bounds, but copying a `NonNull<T>` never requires
// anything of `T`.
impl<T> Clone for PointerOnlyForComparison<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for PointerOnlyForComparison<T> {}

impl<T> PartialEq for PointerOnlyForComparison<T> {
    /// Two tags are equal exactly when they hold the same address.
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl<T> Eq for PointerOnlyForComparison<T> {}

impl<T> Debug for PointerOnlyForComparison<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Print a compact 32-bit fingerprint of the address instead of the full pointer:
        // - the low 2 bits are dropped, since they are assumed to be zero due to alignment;
        // - the next 32 bits are kept, which is enough to tell instances apart.
        // Only shifting right (no `<< 16` first) keeps every meaningful address bit on
        // 32-bit targets, and yields the same value as before on 64-bit targets.
        // The `as u32` truncation is intentional.
        let addr = self.0.as_ptr() as usize;
        write!(f, "Tag({:x})", (addr >> 2) as u32)
    }
}

/// A reference to an ID referring to an item in the compiler instance.
///
/// The usage of `Handle<T>` ensures that item IDs can not be forged from
/// a different compiler instance or from a `u32`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Handle<T> {
    id: T,
    tag: PointerOnlyForComparison<spvc_compiler_s>,
}

impl<T: Id> Handle<T> {
    /// Return the `u32` part of the Id.
    ///
    /// Note that [`Handle<T>`] **can not** implement [`Id`]
    /// for safety reasons. Getting an `impl Id` out of a
    /// [`Handle<T>`] requires using [`Compiler::yield_id`].
    pub fn id(&self) -> u32 {
        self.id.id()
    }
}

/// Trait for SPIRV-Cross ID types.
pub trait Id: Sealed + Debug + 'static {
    /// Return the `u32` part of the Id.
    fn id(&self) -> u32;
}

impl Sealed for TypeId {}
impl Id for TypeId {
    #[inline(always)]
    fn id(&self) -> u32 {
        self.0 .0
    }
}

impl Sealed for VariableId {}
impl Id for VariableId {
    #[inline(always)]
    fn id(&self) -> u32 {
        self.0 .0
    }
}

impl Sealed for ConstantId {}
impl Id for ConstantId {
    #[inline(always)]
    fn id(&self) -> u32 {
        self.0 .0
    }
}

impl<T: Id> Handle<T> {
    /// Converts this handle into one whose ID is boxed as a `dyn Id`, keeping
    /// the original tag.
    ///
    /// This is used to build [`SpirvCrossError::InvalidHandle`] values (see
    /// `yield_id`) and has no other purpose. It is `#[cold]` because it sits
    /// on an error path.
    #[cold]
    fn erase_type(self) -> Handle<Box<dyn Id>> {
        let Handle { id, tag } = self;
        let erased: Box<dyn Id> = Box::new(id);
        Handle { id: erased, tag }
    }
}

/// APIs for creating, validating and unwrapping handles.
impl<T> Compiler<'_, T> {
    /// Wraps `id` in a [`Handle`] tagged with this compiler instance.
    ///
    /// # Safety
    /// When creating a handle, the ID must be valid for the compilation.
    #[inline(always)]
    pub unsafe fn create_handle<I>(&self, id: I) -> Handle<I> {
        let tag = PointerOnlyForComparison(self.ptr);
        Handle { id, tag }
    }

    /// Wraps `id` in a [`Handle`] tagged with this compiler instance, or
    /// returns `None` when the raw value of `id` is zero.
    ///
    /// # Safety
    /// When creating a handle, the ID must be valid for the compilation.
    #[inline(always)]
    pub unsafe fn create_handle_if_not_zero<I: Id>(&self, id: I) -> Option<Handle<I>> {
        let is_zero = id.id() == 0;
        (!is_zero).then(|| Handle {
            id,
            tag: PointerOnlyForComparison(self.ptr),
        })
    }

    /// Returns `true` if `handle` carries this compiler instance's tag.
    pub fn handle_is_valid<I>(&self, handle: &Handle<I>) -> bool {
        let ours = PointerOnlyForComparison(self.ptr);
        ours == handle.tag
    }

    /// Unwraps the ID held by `handle`.
    ///
    /// # Errors
    /// Returns [`SpirvCrossError::InvalidHandle`] when the handle was not
    /// created by this compiler instance.
    pub fn yield_id<I: Id>(&self, handle: Handle<I>) -> error::Result<I> {
        if !self.handle_is_valid(&handle) {
            return Err(SpirvCrossError::InvalidHandle(handle.erase_type()));
        }
        Ok(handle.id)
    }
}

impl PhantomCompiler<'_> {
    /// Crate-internal constructor for a [`Handle`] tagged with this compiler.
    ///
    /// Unlike the public `Compiler::create_handle`, this is intentionally not
    /// `unsafe`. It is only called internally with IDs that are valid for this
    /// compiler instance, so it never produces an invalid handle. Marking it
    /// `unsafe` would bury the genuinely unsafe code in noise during audits.
    #[inline(always)]
    pub(crate) fn create_handle<I>(&self, id: I) -> Handle<I> {
        let tag = PointerOnlyForComparison(self.ptr);
        Handle { id, tag }
    }
}