Skip to main content

vyre_reference/
oob.rs

1//! Out-of-bounds rules enforced by the parity engine.
2//!
3//! GPU drivers differ on what happens when a shader indexes past the end of a
4//! buffer: some clamp, some return zero, some crash. The reference interpreter
//! eliminates that ambiguity by defining one deterministic behavior: a zero value of
//! the declared type for scalar loads, an empty slice for `Bytes` loads, and a silent no-op for stores.
7//! Any backend that diverges from these rules fails the conform gate.
8
9use vyre::ir::DataType as IrDataType;
10
11use crate::value::Value;
12use vyre::ir::DataType;
13
14use std::sync::{Arc, RwLock};
15
/// Typed bytes backing one declared IR buffer.
///
/// This struct exists to give the reference interpreter a single place to enforce
/// stride-correct indexing and OOB semantics, independent of how any GPU driver
/// handles buffer bounds.
///
/// The byte storage sits behind an `Arc<RwLock<..>>`, so the derived `Clone` produces
/// another handle to the *same* bytes rather than an independent copy.
#[derive(Debug, Clone)]
pub struct Buffer {
    /// Raw backing bytes, shared between clones and guarded by a reader/writer lock.
    pub(crate) bytes: Arc<RwLock<Vec<u8>>>,
    /// Declared element type; it determines how indices are scaled into byte offsets.
    pub(crate) element: IrDataType,
}
26
27impl Buffer {
28    /// Create a buffer from typed bytes.
29    #[must_use]
30    pub fn new(bytes: Vec<u8>, element: DataType) -> Self {
31        Self {
32            bytes: Arc::new(RwLock::new(bytes)),
33            element,
34        }
35    }
36
37    pub(crate) fn len(&self) -> u32 {
38        let bytes_guard = self.bytes.read().unwrap_or_else(|error| error.into_inner());
39        let count = if let Some(bits) = self.element.bit_width() {
40            bytes_guard
41                .len()
42                .checked_mul(8)
43                .map(|total_bits| total_bits / bits)
44                .unwrap_or(usize::MAX)
45        } else if let Some(stride) = self.element.size_bytes() {
46            if stride == 0 {
47                bytes_guard.len()
48            } else {
49                bytes_guard.len() / stride
50            }
51        } else {
52            bytes_guard.len()
53        };
54        match u32::try_from(count) {
55            Ok(value) => value,
56            Err(_) => {
57                debug_assert!(
58                    false,
59                    "Buffer::len overflowed u32::MAX for byte_len={}; element={:?}. \
60                     Fix: split or downsize the buffer so per-element indexing remains representable.",
61                    bytes_guard.len(),
62                    self.element
63                );
64                u32::MAX
65            }
66        }
67    }
68
69    pub(crate) fn byte_len(&self) -> usize {
70        self.bytes
71            .read()
72            .unwrap_or_else(|error| error.into_inner())
73            .len()
74    }
75
76    pub(crate) fn element(&self) -> &IrDataType {
77        &self.element
78    }
79
80    pub(crate) fn zero_fill(&self) {
81        self.bytes
82            .write()
83            .unwrap_or_else(|error| error.into_inner())
84            .fill(0);
85    }
86
87    pub(crate) fn into_bytes(self) -> Vec<u8> {
88        std::sync::Arc::try_unwrap(self.bytes)
89            .map(|rw| rw.into_inner().unwrap_or_else(|error| error.into_inner()))
90            .unwrap_or_else(|a| a.read().unwrap_or_else(|error| error.into_inner()).clone())
91    }
92
93    /// Consume this buffer and return its contents as a Value.
94    #[must_use]
95    pub fn to_value(self) -> crate::value::Value {
96        crate::value::Value::from(self.into_bytes())
97    }
98}
99
100pub(crate) fn load(buffer: &Buffer, index: u32) -> Value {
101    let bytes_guard = buffer
102        .bytes
103        .read()
104        .unwrap_or_else(|error| error.into_inner());
105    let stride = buffer.element.min_bytes();
106    let ty = ir_to_conform_type(buffer.element.clone());
107    if matches!(buffer.element, IrDataType::Bytes) {
108        let offset = index as usize;
109        if offset > bytes_guard.len() {
110            return Value::from(Vec::new());
111        }
112        return Value::from(&bytes_guard[offset..]);
113    }
114    let Some(offset) = byte_offset(index, stride) else {
115        return Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new()));
116    };
117    if stride == 0 || offset + stride > bytes_guard.len() {
118        return Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new()));
119    }
120    read_element(ty.clone(), &bytes_guard[offset..offset + stride])
121        .unwrap_or_else(|_| Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new())))
122}
123
124pub(crate) fn store(buffer: &mut Buffer, index: u32, value: &Value) {
125    let mut bytes_guard = buffer
126        .bytes
127        .write()
128        .unwrap_or_else(|error| error.into_inner());
129    let stride = buffer.element.min_bytes();
130    if matches!(buffer.element, IrDataType::Bytes) {
131        let offset = index as usize;
132        if offset >= bytes_guard.len() {
133            return;
134        }
135        let bytes = value.to_bytes();
136        let available = bytes_guard.len() - offset;
137        let write_len = bytes.len().min(available);
138        bytes_guard[offset..offset + write_len].copy_from_slice(&bytes[..write_len]);
139        return;
140    }
141    let Some(offset) = byte_offset(index, stride) else {
142        return;
143    };
144    if stride == 0 || offset + stride > bytes_guard.len() {
145        return;
146    }
147    write_element(
148        buffer.element.clone(),
149        &mut bytes_guard[offset..offset + stride],
150        value,
151    );
152}
153
154pub(crate) fn atomic_load(buffer: &Buffer, index: u32) -> Option<u32> {
155    let bytes_guard = buffer
156        .bytes
157        .read()
158        .unwrap_or_else(|error| error.into_inner());
159    let stride = buffer.element.min_bytes().max(4);
160    let offset = byte_offset(index, stride)?;
161    if offset + 4 > bytes_guard.len() {
162        None
163    } else {
164        Some(read_u32(&bytes_guard[offset..offset + 4]))
165    }
166}
167
168pub(crate) fn atomic_store(buffer: &mut Buffer, index: u32, value: u32) {
169    let mut bytes_guard = buffer
170        .bytes
171        .write()
172        .unwrap_or_else(|error| error.into_inner());
173    let stride = buffer.element.min_bytes().max(4);
174    let Some(offset) = byte_offset(index, stride) else {
175        return;
176    };
177    if offset + 4 <= bytes_guard.len() {
178        write_u32(&mut bytes_guard[offset..offset + 4], value);
179    }
180}
181
/// Scale an element `index` by `stride` bytes, returning `None` if the product
/// does not fit in `usize`.
fn byte_offset(index: u32, stride: usize) -> Option<usize> {
    let element_index = index as usize;
    stride.checked_mul(element_index)
}
185
186fn write_element(element: IrDataType, target: &mut [u8], value: &Value) {
187    match element {
188        IrDataType::U32 => {
189            value.write_bytes_width_into(target);
190        }
191        IrDataType::I32 => {
192            value.write_bytes_width_into(target);
193        }
194        IrDataType::Bool => {
195            value.write_bytes_width_into(target);
196        }
197        IrDataType::U64 => {
198            value.write_bytes_width_into(target);
199        }
200        IrDataType::F32 => {
201            // Value::Float carries an f64; the GPU buffer is four bytes
202            // of f32, so narrow via `as f32` before writing. Dropping the
203            // upper four bytes of `v.to_le_bytes()` (what the default
204            // to_bytes_width path does) would mangle the f32 bit pattern.
205            let v = match value {
206                Value::Float(v) => *v as f32,
207                Value::U32(v) => f32::from_bits(*v),
208                _ => 0.0,
209            };
210            let v = crate::execution::typed_ops::canonical_f32(v);
211            target.copy_from_slice(&v.to_le_bytes());
212        }
213        IrDataType::Bytes | IrDataType::Vec2U32 | IrDataType::Vec4U32 => {
214            value.write_bytes_width_into(target);
215        }
216        _ => {
217            value.write_bytes_width_into(target);
218        }
219    }
220}
221
/// Decode one element of type `ty` from `bytes` by delegating to
/// `Value::from_element_bytes`.
///
/// Returns the decoder's `Err(String)` unchanged when decoding fails.
// NOTE(review): the exact failure conditions live in `Value::from_element_bytes`,
// which is not visible from this file.
fn read_element(ty: DataType, bytes: &[u8]) -> Result<Value, String> {
    Value::from_element_bytes(ty, bytes)
}
225
/// Decode the first four bytes of `bytes` as a little-endian `u32`.
///
/// Panics if `bytes` holds fewer than four bytes; extra trailing bytes are ignored.
fn read_u32(bytes: &[u8]) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&bytes[..4]);
    u32::from_le_bytes(word)
}
229
/// Encode `value` as little-endian into `bytes`.
///
/// Panics unless `bytes` is exactly four bytes long.
fn write_u32(bytes: &mut [u8], value: u32) {
    let encoded = value.to_le_bytes();
    bytes.copy_from_slice(&encoded);
}
233
234fn ir_to_conform_type(ty: IrDataType) -> DataType {
235    match ty {
236        IrDataType::U32 => DataType::U32,
237        IrDataType::I32 => DataType::I32,
238        IrDataType::U64 => DataType::U64,
239        IrDataType::F32 => DataType::F32,
240        IrDataType::F64 => DataType::F64,
241        IrDataType::Vec2U32 => DataType::Vec2U32,
242        IrDataType::Vec4U32 => DataType::Vec4U32,
243        IrDataType::Bool => DataType::U32,
244        IrDataType::Bytes => DataType::Bytes,
245        other => other,
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Extract the f32 bit pattern from a loaded value: `Float` values are narrowed to
    /// f32, anything else is read as the first four little-endian bytes.
    fn f32_bits(value: Value) -> u32 {
        if let Value::Float(float) = value {
            return (float as f32).to_bits();
        }
        let raw = value.to_bytes();
        u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]])
    }

    #[test]
    fn f32_load_canonicalizes_subnormal_and_nan_payloads() {
        // (stored bits, expected bits after load): positive subnormal, negative
        // subnormal, NaN with payload.
        let cases: [(u32, u32); 3] = [
            (0x0000_0001, 0x0000_0000),
            (0x8000_0001, 0x8000_0000),
            (0x7fa0_0001, 0x7fc0_0000),
        ];
        for (stored, expected) in cases {
            let buffer = Buffer::new(stored.to_le_bytes().to_vec(), DataType::F32);
            assert_eq!(f32_bits(load(&buffer, 0)), expected);
        }
    }

    #[test]
    fn f32_store_canonicalizes_subnormal_and_nan_payloads() {
        // (value to store, expected stored bits): negative subnormal passed as a
        // float, NaN with payload passed as raw bits.
        let negative_subnormal = f64::from(f32::from_bits(0x8000_0001));
        let cases = [
            (Value::Float(negative_subnormal), 0x8000_0000u32),
            (Value::U32(0x7fa0_0001), 0x7fc0_0000u32),
        ];
        for (input, expected) in cases {
            let mut buffer = Buffer::new(vec![0; 4], DataType::F32);
            store(&mut buffer, 0, &input);
            assert_eq!(f32_bits(buffer.to_value()), expected);
        }
    }
}