use js_sys::{Atomics, Int32Array, JsString, SharedArrayBuffer};
use wasm_bindgen::prelude::*;
/// A fixed-length array of 32-bit integers backed by a JavaScript
/// `SharedArrayBuffer` and exported to JavaScript through `wasm_bindgen`.
///
/// All element access in this file goes through the `Atomics` API on the
/// `Int32Array` view stored alongside the buffer.
#[wasm_bindgen]
pub struct SharedWasmArray {
/// The `SharedArrayBuffer` that provides the storage.
sab: SharedArrayBuffer,
/// `Int32Array` view constructed over `sab`; passed to every `Atomics` call.
view: Int32Array,
/// Number of `i32` elements requested at construction; used for bounds checks.
length: usize,
}
#[wasm_bindgen]
impl SharedWasmArray {
#[wasm_bindgen(constructor)]
pub fn new(length: usize) -> Result<SharedWasmArray, JsValue> {
if length == 0 {
return Err(JsValue::from_str(
"SharedWasmArray: length must be at least 1",
));
}
let byte_len = length
.checked_mul(4)
.ok_or_else(|| JsValue::from_str("SharedWasmArray: length overflow"))?;
let sab = SharedArrayBuffer::new(byte_len as u32);
let view = Int32Array::new(&sab);
Ok(SharedWasmArray { sab, view, length })
}
pub fn atomic_store_i32(&self, index: usize, value: i32) -> Result<(), JsValue> {
self.check_bounds(index)?;
Atomics::store(&self.view, index as u32, value)
.map_err(|e| JsValue::from_str(&format!("Atomics.store failed: {e:?}")))?;
Ok(())
}
pub fn atomic_load_i32(&self, index: usize) -> Result<i32, JsValue> {
self.check_bounds(index)?;
Atomics::load(&self.view, index as u32)
.map_err(|e| JsValue::from_str(&format!("Atomics.load failed: {e:?}")))
}
pub fn atomic_wait(
&self,
index: usize,
expected: i32,
timeout_ms: f64,
) -> Result<JsValue, JsValue> {
self.check_bounds(index)?;
let js_str: JsString =
Atomics::wait_with_timeout(&self.view, index as u32, expected, timeout_ms)
.map_err(|e| JsValue::from_str(&format!("Atomics.wait failed: {e:?}")))?;
Ok(js_str.into())
}
pub fn atomic_notify(&self, index: usize, count: u32) -> Result<u32, JsValue> {
self.check_bounds(index)?;
Atomics::notify_with_count(&self.view, index as u32, count)
.map_err(|e| JsValue::from_str(&format!("Atomics.notify failed: {e:?}")))
}
pub fn atomic_compare_exchange(
&self,
index: usize,
expected: i32,
replacement: i32,
) -> Result<i32, JsValue> {
self.check_bounds(index)?;
Atomics::compare_exchange(&self.view, index as u32, expected, replacement)
.map_err(|e| JsValue::from_str(&format!("Atomics.compareExchange failed: {e:?}")))
}
pub fn atomic_add(&self, index: usize, value: i32) -> Result<i32, JsValue> {
self.check_bounds(index)?;
Atomics::add(&self.view, index as u32, value)
.map_err(|e| JsValue::from_str(&format!("Atomics.add failed: {e:?}")))
}
pub fn buffer(&self) -> SharedArrayBuffer {
self.sab.clone()
}
pub fn length(&self) -> usize {
self.length
}
pub fn byte_length(&self) -> u32 {
SharedArrayBuffer::byte_length(&self.sab)
}
}
impl SharedWasmArray {
    /// Verifies that `index` addresses an element inside the array.
    ///
    /// Returns `Ok(())` when `index < self.length`; otherwise returns an
    /// error string naming the offending index and the array length.
    fn check_bounds(&self, index: usize) -> Result<(), JsValue> {
        if index < self.length {
            return Ok(());
        }

        let message = format!(
            "SharedWasmArray: index {index} out of bounds (length {})",
            self.length
        );
        Err(JsValue::from_str(&message))
    }
}
/// Reports whether the JavaScript global object has a `SharedArrayBuffer`
/// property that is neither `undefined` nor `null`.
///
/// A failed property lookup is treated as "not available".
#[wasm_bindgen]
pub fn shared_array_buffer_available() -> bool {
    let property = JsValue::from_str("SharedArrayBuffer");
    js_sys::Reflect::get(&js_sys::global(), &property)
        .map(|ctor| !(ctor.is_undefined() || ctor.is_null()))
        .unwrap_or(false)
}
// Host-side tests (compiled only for non-wasm32 targets). They mirror the
// validation rules of `SharedWasmArray` using plain `String` errors rather
// than calling the type itself.
// NOTE(review): presumably because `JsValue` is not usable off-wasm — confirm.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    /// An index past the end yields an error mentioning the index and length.
    #[test]
    fn check_bounds_out_of_range_error_message() {
        let (length, index) = (8usize, 9usize);
        let result: Result<(), String> = if index < length {
            Ok(())
        } else {
            Err(format!(
                "SharedWasmArray: index {index} out of bounds (length {length})"
            ))
        };

        assert!(result.is_err());
        let msg = result.unwrap_err();
        for needle in ["9", "length 8"] {
            assert!(msg.contains(needle), "message: {msg}");
        }
    }

    /// The last valid index passes the bounds rule.
    #[test]
    fn check_bounds_in_range() {
        let (length, index) = (8usize, 7usize);
        let result: Result<(), String> = if index < length {
            Ok(())
        } else {
            Err("out of bounds".to_string())
        };

        assert!(result.is_ok());
    }

    /// Multiplying the largest `usize` by the element size overflows.
    #[test]
    fn length_overflow_detection() {
        let overflow = usize::MAX.checked_mul(4);
        assert!(overflow.is_none(), "Expected overflow");
    }

    /// A zero length is rejected by the constructor's rule.
    #[test]
    fn zero_length_rejected() {
        let length = 0usize;
        let result: Result<(), String> = if length != 0 {
            Ok(())
        } else {
            Err("length must be at least 1".to_string())
        };

        assert!(result.is_err());
    }
}