1use vyre::ir::DataType as IrDataType;
10
11use crate::value::Value;
12use vyre::ir::DataType;
13
14use std::sync::{Arc, RwLock};
15
16#[derive(Debug, Clone)]
22pub struct Buffer {
23 pub(crate) bytes: Arc<RwLock<Vec<u8>>>,
24 pub(crate) element: IrDataType,
25}
26
27impl Buffer {
28 #[must_use]
30 pub fn new(bytes: Vec<u8>, element: DataType) -> Self {
31 Self {
32 bytes: Arc::new(RwLock::new(bytes)),
33 element,
34 }
35 }
36
37 pub(crate) fn len(&self) -> u32 {
38 let bytes_guard = self.bytes.read().unwrap_or_else(|error| error.into_inner());
39 let count = if let Some(bits) = self.element.bit_width() {
40 bytes_guard
41 .len()
42 .checked_mul(8)
43 .map(|total_bits| total_bits / bits)
44 .unwrap_or(usize::MAX)
45 } else if let Some(stride) = self.element.size_bytes() {
46 if stride == 0 {
47 bytes_guard.len()
48 } else {
49 bytes_guard.len() / stride
50 }
51 } else {
52 bytes_guard.len()
53 };
54 match u32::try_from(count) {
55 Ok(value) => value,
56 Err(_) => {
57 debug_assert!(
58 false,
59 "Buffer::len overflowed u32::MAX for byte_len={}; element={:?}. \
60 Fix: split or downsize the buffer so per-element indexing remains representable.",
61 bytes_guard.len(),
62 self.element
63 );
64 u32::MAX
65 }
66 }
67 }
68
69 pub(crate) fn byte_len(&self) -> usize {
70 self.bytes
71 .read()
72 .unwrap_or_else(|error| error.into_inner())
73 .len()
74 }
75
76 pub(crate) fn element(&self) -> &IrDataType {
77 &self.element
78 }
79
80 pub(crate) fn zero_fill(&self) {
81 self.bytes
82 .write()
83 .unwrap_or_else(|error| error.into_inner())
84 .fill(0);
85 }
86
87 pub(crate) fn into_bytes(self) -> Vec<u8> {
88 std::sync::Arc::try_unwrap(self.bytes)
89 .map(|rw| rw.into_inner().unwrap_or_else(|error| error.into_inner()))
90 .unwrap_or_else(|a| a.read().unwrap_or_else(|error| error.into_inner()).clone())
91 }
92
93 #[must_use]
95 pub fn to_value(self) -> crate::value::Value {
96 crate::value::Value::from(self.into_bytes())
97 }
98}
99
100pub(crate) fn load(buffer: &Buffer, index: u32) -> Value {
101 let bytes_guard = buffer
102 .bytes
103 .read()
104 .unwrap_or_else(|error| error.into_inner());
105 let stride = buffer.element.min_bytes();
106 let ty = ir_to_conform_type(buffer.element.clone());
107 if matches!(buffer.element, IrDataType::Bytes) {
108 let offset = index as usize;
109 if offset > bytes_guard.len() {
110 return Value::from(Vec::new());
111 }
112 return Value::from(&bytes_guard[offset..]);
113 }
114 let Some(offset) = byte_offset(index, stride) else {
115 return Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new()));
116 };
117 if stride == 0 || offset + stride > bytes_guard.len() {
118 return Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new()));
119 }
120 read_element(ty.clone(), &bytes_guard[offset..offset + stride])
121 .unwrap_or_else(|_| Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new())))
122}
123
124pub(crate) fn store(buffer: &mut Buffer, index: u32, value: &Value) {
125 let mut bytes_guard = buffer
126 .bytes
127 .write()
128 .unwrap_or_else(|error| error.into_inner());
129 let stride = buffer.element.min_bytes();
130 if matches!(buffer.element, IrDataType::Bytes) {
131 let offset = index as usize;
132 if offset >= bytes_guard.len() {
133 return;
134 }
135 let bytes = value.to_bytes();
136 let available = bytes_guard.len() - offset;
137 let write_len = bytes.len().min(available);
138 bytes_guard[offset..offset + write_len].copy_from_slice(&bytes[..write_len]);
139 return;
140 }
141 let Some(offset) = byte_offset(index, stride) else {
142 return;
143 };
144 if stride == 0 || offset + stride > bytes_guard.len() {
145 return;
146 }
147 write_element(
148 buffer.element.clone(),
149 &mut bytes_guard[offset..offset + stride],
150 value,
151 );
152}
153
154pub(crate) fn atomic_load(buffer: &Buffer, index: u32) -> Option<u32> {
155 let bytes_guard = buffer
156 .bytes
157 .read()
158 .unwrap_or_else(|error| error.into_inner());
159 let stride = buffer.element.min_bytes().max(4);
160 let offset = byte_offset(index, stride)?;
161 if offset + 4 > bytes_guard.len() {
162 None
163 } else {
164 Some(read_u32(&bytes_guard[offset..offset + 4]))
165 }
166}
167
168pub(crate) fn atomic_store(buffer: &mut Buffer, index: u32, value: u32) {
169 let mut bytes_guard = buffer
170 .bytes
171 .write()
172 .unwrap_or_else(|error| error.into_inner());
173 let stride = buffer.element.min_bytes().max(4);
174 let Some(offset) = byte_offset(index, stride) else {
175 return;
176 };
177 if offset + 4 <= bytes_guard.len() {
178 write_u32(&mut bytes_guard[offset..offset + 4], value);
179 }
180}
181
/// Byte offset of element `index` for a slot width of `stride` bytes.
///
/// Returns `None` when `index * stride` is not representable as `usize`. That happens when the
/// product overflows, or when `index` itself does not fit in `usize` (16-bit targets), where
/// an `as` cast would silently truncate it.
fn byte_offset(index: u32, stride: usize) -> Option<usize> {
    usize::try_from(index).ok()?.checked_mul(stride)
}
185
186fn write_element(element: IrDataType, target: &mut [u8], value: &Value) {
187 match element {
188 IrDataType::U32 => {
189 value.write_bytes_width_into(target);
190 }
191 IrDataType::I32 => {
192 value.write_bytes_width_into(target);
193 }
194 IrDataType::Bool => {
195 value.write_bytes_width_into(target);
196 }
197 IrDataType::U64 => {
198 value.write_bytes_width_into(target);
199 }
200 IrDataType::F32 => {
201 let v = match value {
206 Value::Float(v) => *v as f32,
207 Value::U32(v) => f32::from_bits(*v),
208 _ => 0.0,
209 };
210 let v = crate::execution::typed_ops::canonical_f32(v);
211 target.copy_from_slice(&v.to_le_bytes());
212 }
213 IrDataType::Bytes | IrDataType::Vec2U32 | IrDataType::Vec4U32 => {
214 value.write_bytes_width_into(target);
215 }
216 _ => {
217 value.write_bytes_width_into(target);
218 }
219 }
220}
221
222fn read_element(ty: DataType, bytes: &[u8]) -> Result<Value, String> {
223 Value::from_element_bytes(ty, bytes)
224}
225
/// Decodes the first four bytes of `bytes` as a little-endian `u32`.
///
/// Panics if `bytes` is shorter than four bytes; any bytes past the fourth are ignored.
fn read_u32(bytes: &[u8]) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&bytes[..4]);
    u32::from_le_bytes(word)
}
229
/// Stores `value` into `bytes` as a little-endian `u32`.
///
/// Panics unless `bytes` is exactly four bytes long.
fn write_u32(bytes: &mut [u8], value: u32) {
    let encoded: [u8; 4] = value.to_le_bytes();
    bytes.copy_from_slice(&encoded);
}
233
234fn ir_to_conform_type(ty: IrDataType) -> DataType {
235 match ty {
236 IrDataType::U32 => DataType::U32,
237 IrDataType::I32 => DataType::I32,
238 IrDataType::U64 => DataType::U64,
239 IrDataType::F32 => DataType::F32,
240 IrDataType::F64 => DataType::F64,
241 IrDataType::Vec2U32 => DataType::Vec2U32,
242 IrDataType::Vec4U32 => DataType::Vec4U32,
243 IrDataType::Bool => DataType::U32,
244 IrDataType::Bytes => DataType::Bytes,
245 other => other,
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the raw IEEE-754 bits of an `f32` result.
    ///
    /// A `Value::Float` is narrowed from f64 first. Any other value is read from the first four
    /// little-endian bytes of its encoding.
    fn f32_bits(value: Value) -> u32 {
        if let Value::Float(wide) = value {
            return (wide as f32).to_bits();
        }
        let encoded = value.to_bytes();
        u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]])
    }

    #[test]
    fn f32_load_canonicalizes_subnormal_and_nan_payloads() {
        // (bits stored in the buffer, bits expected back from `load`)
        let cases: [(u32, u32); 3] = [
            (0x0000_0001, 0x0000_0000), // positive subnormal -> +0
            (0x8000_0001, 0x8000_0000), // negative subnormal -> -0
            (0x7fa0_0001, 0x7fc0_0000), // NaN with payload -> canonical quiet NaN
        ];
        for &(stored, expected) in &cases {
            let buffer = Buffer::new(stored.to_le_bytes().to_vec(), DataType::F32);
            assert_eq!(f32_bits(load(&buffer, 0)), expected);
        }
    }

    #[test]
    fn f32_store_canonicalizes_subnormal_and_nan_payloads() {
        // (value handed to `store`, bits expected in the buffer afterwards)
        let cases = [
            (
                Value::Float(f64::from(f32::from_bits(0x8000_0001))),
                0x8000_0000_u32,
            ),
            (Value::U32(0x7fa0_0001), 0x7fc0_0000),
        ];
        for (value, expected) in cases.iter() {
            let mut buffer = Buffer::new(vec![0; 4], DataType::F32);
            store(&mut buffer, 0, value);
            assert_eq!(f32_bits(buffer.to_value()), *expected);
        }
    }
}