use vyre::ir::DataType as IrDataType;
use crate::value::Value;
use vyre::ir::DataType;
use std::sync::{Arc, RwLock};
/// Typed, shareable byte storage used by the load/store helpers in this module.
///
/// Cloning a `Buffer` clones the inner `Arc`, so every clone shares (and observes writes to)
/// the same underlying bytes.
#[derive(Debug, Clone)]
pub struct Buffer {
/// Raw backing bytes, behind an `RwLock` so they can be read or written through a shared handle.
pub(crate) bytes: Arc<RwLock<Vec<u8>>>,
/// Element type used to interpret `bytes` (element count, stride, decoding and encoding).
pub(crate) element: IrDataType,
}
impl Buffer {
    /// Creates a buffer that owns `bytes` and interprets them as elements of type `element`.
    #[must_use]
    pub fn new(bytes: Vec<u8>, element: DataType) -> Self {
        Self {
            bytes: Arc::new(RwLock::new(bytes)),
            element,
        }
    }

    /// Returns the number of `element`-typed values held by the buffer.
    ///
    /// The count is derived from the element's bit width when it reports a non-zero one,
    /// otherwise from its non-zero byte size, and otherwise falls back to the raw byte length.
    /// A count that does not fit in `u32` saturates to `u32::MAX` (and trips a debug assertion).
    pub(crate) fn len(&self) -> u32 {
        let byte_len = self.byte_len();
        let count = if let Some(bits) = self.element.bit_width().filter(|&bits| bits != 0) {
            // Widen before multiplying: `byte_len * 8` can overflow `usize` on 32-bit targets,
            // and the zero-width filter above keeps the division from panicking.
            (byte_len as u128) * 8 / (bits as u128)
        } else if let Some(stride) = self.element.size_bytes().filter(|&stride| stride != 0) {
            (byte_len / stride) as u128
        } else {
            byte_len as u128
        };
        u32::try_from(count).unwrap_or_else(|_| {
            debug_assert!(
                false,
                "Buffer::len overflowed u32::MAX for byte_len={}; element={:?}. \
                 Fix: split or downsize the buffer so per-element indexing remains representable.",
                byte_len,
                self.element
            );
            u32::MAX
        })
    }

    /// Returns the size of the backing storage in bytes.
    ///
    /// A poisoned lock is tolerated: the data is read anyway.
    pub(crate) fn byte_len(&self) -> usize {
        self.bytes
            .read()
            .unwrap_or_else(|error| error.into_inner())
            .len()
    }

    /// Returns the element type used to interpret the stored bytes.
    pub(crate) fn element(&self) -> &IrDataType {
        &self.element
    }

    /// Overwrites every stored byte with zero, keeping the length unchanged.
    ///
    /// The storage is shared with every clone of this buffer, so clones observe the reset too.
    pub(crate) fn zero_fill(&self) {
        self.bytes
            .write()
            .unwrap_or_else(|error| error.into_inner())
            .fill(0);
    }

    /// Consumes the buffer and returns its bytes.
    ///
    /// When this is the last handle to the shared storage, the vector is moved out without
    /// copying. Otherwise the current contents are cloned and the other handles keep theirs.
    pub(crate) fn into_bytes(self) -> Vec<u8> {
        std::sync::Arc::try_unwrap(self.bytes)
            .map(|rw| rw.into_inner().unwrap_or_else(|error| error.into_inner()))
            .unwrap_or_else(|a| a.read().unwrap_or_else(|error| error.into_inner()).clone())
    }

    /// Consumes the buffer and wraps its raw bytes in a [`Value`](crate::value::Value).
    #[must_use]
    pub fn to_value(self) -> crate::value::Value {
        crate::value::Value::from(self.into_bytes())
    }
}
pub(crate) fn load(buffer: &Buffer, index: u32) -> Value {
let bytes_guard = buffer
.bytes
.read()
.unwrap_or_else(|error| error.into_inner());
let stride = buffer.element.min_bytes();
let ty = ir_to_conform_type(buffer.element.clone());
if matches!(buffer.element, IrDataType::Bytes) {
let offset = index as usize;
if offset > bytes_guard.len() {
return Value::from(Vec::new());
}
return Value::from(&bytes_guard[offset..]);
}
let Some(offset) = byte_offset(index, stride) else {
return Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new()));
};
if stride == 0 || offset + stride > bytes_guard.len() {
return Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new()));
}
read_element(ty.clone(), &bytes_guard[offset..offset + stride])
.unwrap_or_else(|_| Value::try_zero_for(ty).unwrap_or_else(|| Value::from(Vec::new())))
}
/// Encodes `value` into element `index` of `buffer`.
///
/// * `Bytes` buffers treat `index` as a byte offset. As many bytes of `value.to_bytes()` as
///   fit before the end of the buffer are copied in; the rest are dropped.
/// * Every other element type treats `index` as an element index scaled by the element's
///   `min_bytes()` stride, and overwrites exactly one stride via `write_element`.
///
/// Writes that would fall outside the buffer, or that target a zero-width element type, are
/// silently ignored rather than panicking.
pub(crate) fn store(buffer: &mut Buffer, index: u32, value: &Value) {
    let mut bytes_guard = buffer
        .bytes
        .write()
        .unwrap_or_else(|error| error.into_inner());
    if matches!(buffer.element, IrDataType::Bytes) {
        let offset = index as usize;
        if offset >= bytes_guard.len() {
            return;
        }
        let bytes = value.to_bytes();
        let target = &mut bytes_guard[offset..];
        let write_len = bytes.len().min(target.len());
        target[..write_len].copy_from_slice(&bytes[..write_len]);
        return;
    }
    let stride = buffer.element.min_bytes();
    // `checked_add` keeps `offset + stride` from wrapping on 32-bit targets. `get_mut` then
    // rejects any range that runs past the end instead of panicking.
    let Some(range) = byte_offset(index, stride)
        .filter(|_| stride != 0)
        .and_then(|offset| offset.checked_add(stride).map(|end| offset..end))
    else {
        return;
    };
    let Some(target) = bytes_guard.get_mut(range) else {
        return;
    };
    write_element(buffer.element.clone(), target, value);
}
/// Reads the little-endian `u32` stored at element `index` of `buffer`.
///
/// Elements are addressed with a stride of `max(min_bytes(), 4)`, and only the first four
/// bytes of the addressed element are read, under the buffer's read lock.
///
/// Returns `None` when the offset computation overflows or the word would extend past the end
/// of the buffer.
pub(crate) fn atomic_load(buffer: &Buffer, index: u32) -> Option<u32> {
    let bytes_guard = buffer
        .bytes
        .read()
        .unwrap_or_else(|error| error.into_inner());
    let stride = buffer.element.min_bytes().max(4);
    let offset = byte_offset(index, stride)?;
    // `checked_add` keeps the end offset from wrapping on 32-bit targets.
    let word = bytes_guard.get(offset..offset.checked_add(4)?)?;
    Some(read_u32(word))
}
/// Writes `value` as a little-endian `u32` at element `index` of `buffer`.
///
/// Elements are addressed with a stride of `max(min_bytes(), 4)`, and only the first four
/// bytes of the addressed element are overwritten, under the buffer's write lock.
///
/// Out-of-range or overflowing indices are silently ignored.
pub(crate) fn atomic_store(buffer: &mut Buffer, index: u32, value: u32) {
    let mut bytes_guard = buffer
        .bytes
        .write()
        .unwrap_or_else(|error| error.into_inner());
    let stride = buffer.element.min_bytes().max(4);
    let Some(offset) = byte_offset(index, stride) else {
        return;
    };
    // `checked_add` keeps the end offset from wrapping on 32-bit targets.
    let Some(end) = offset.checked_add(4) else {
        return;
    };
    if let Some(word) = bytes_guard.get_mut(offset..end) {
        write_u32(word, value);
    }
}
/// Converts an element index into a byte offset (`index * stride`).
///
/// Returns `None` if the product does not fit in `usize`.
fn byte_offset(index: u32, stride: usize) -> Option<usize> {
    let element_index = index as usize;
    stride.checked_mul(element_index)
}
/// Encodes `value` into `target`, which holds exactly one element of type `element`.
///
/// `F32` is the only type with custom handling:
/// * a `Value::Float` is narrowed to `f32`;
/// * a `Value::U32` is reinterpreted as raw `f32` bits;
/// * any other variant becomes `0.0`.
///
/// The result is passed through `canonical_f32` before being written little-endian. Every
/// other element type delegates to `Value::write_bytes_width_into`.
fn write_element(element: IrDataType, target: &mut [u8], value: &Value) {
    if !matches!(element, IrDataType::F32) {
        value.write_bytes_width_into(target);
        return;
    }
    let raw = match value {
        Value::Float(float) => *float as f32,
        Value::U32(bits) => f32::from_bits(*bits),
        _ => 0.0,
    };
    let canonical = crate::execution::typed_ops::canonical_f32(raw);
    target.copy_from_slice(&canonical.to_le_bytes());
}
fn read_element(ty: DataType, bytes: &[u8]) -> Result<Value, String> {
Value::from_element_bytes(ty, bytes)
}
/// Decodes the first four bytes of `bytes` as a little-endian `u32`.
///
/// Panics if `bytes` holds fewer than four bytes; any extra bytes are ignored.
fn read_u32(bytes: &[u8]) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&bytes[..4]);
    u32::from_le_bytes(word)
}
/// Writes `value` into `bytes` in little-endian order.
///
/// Panics unless `bytes` is exactly four bytes long.
fn write_u32(bytes: &mut [u8], value: u32) {
    let encoded: [u8; 4] = value.to_le_bytes();
    bytes.copy_from_slice(&encoded);
}
/// Maps a buffer's element type to the type used to decode loaded values.
///
/// `IrDataType` and `DataType` both name `vyre::ir::DataType`, so this is the identity for every
/// type except `Bool`, which is decoded as `U32`.
fn ir_to_conform_type(ty: IrDataType) -> DataType {
    match ty {
        IrDataType::Bool => DataType::U32,
        unchanged => unchanged,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the IEEE-754 single-precision bit pattern carried by `value`.
    ///
    /// A `Value::Float` is narrowed to `f32` first. Any other variant is expected to expose at
    /// least four bytes through `to_bytes`, and the first four are read little-endian.
    fn f32_bits(value: Value) -> u32 {
        if let Value::Float(wide) = value {
            return (wide as f32).to_bits();
        }
        let raw = value.to_bytes();
        u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]])
    }

    #[test]
    fn f32_load_canonicalizes_subnormal_and_nan_payloads() {
        // (stored bit pattern, expected canonical bit pattern after load)
        let cases: [(u32, u32); 3] = [
            (0x0000_0001, 0x0000_0000),
            (0x8000_0001, 0x8000_0000),
            (0x7fa0_0001, 0x7fc0_0000),
        ];
        for &(stored, expected) in cases.iter() {
            let buffer = Buffer::new(stored.to_le_bytes().to_vec(), DataType::F32);
            assert_eq!(
                f32_bits(load(&buffer, 0)),
                expected,
                "stored bits {:#010x}",
                stored
            );
        }
    }

    #[test]
    fn f32_store_canonicalizes_subnormal_and_nan_payloads() {
        // (value to store, expected canonical bit pattern in the buffer)
        let cases = [
            (
                Value::Float(f64::from(f32::from_bits(0x8000_0001))),
                0x8000_0000u32,
            ),
            (Value::U32(0x7fa0_0001), 0x7fc0_0000u32),
        ];
        for (input, expected) in cases.iter() {
            let mut buffer = Buffer::new(vec![0; 4], DataType::F32);
            store(&mut buffer, 0, input);
            assert_eq!(f32_bits(buffer.to_value()), *expected);
        }
    }
}