use crate::error::WasmError;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
/// Flat `f64` buffer together with the shape it should be interpreted with.
///
/// The type is serde-(de)serializable and exported to JavaScript via
/// `wasm_bindgen`. All fields are public, so the length/shape agreement checked
/// by `TransferableArray::new` is not re-validated after construction or after
/// deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen(getter_with_clone)]
pub struct TransferableArray {
    /// Flat element storage.
    pub data: Vec<f64>,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Element type tag; `TransferableArray::new` always sets it to `"f64"`.
    pub dtype: String,
}
#[wasm_bindgen]
impl TransferableArray {
    /// Builds an array from flat `data` and its `shape`, tagging it as `"f64"`.
    ///
    /// An empty `shape` has product 1, so it pairs with exactly one element.
    ///
    /// # Errors
    /// Returns `WasmError::InvalidParameter` (converted into a `JsValue`) when
    /// the product of `shape` differs from `data.len()`, including the case
    /// where that product does not fit in `usize`.
    #[wasm_bindgen(constructor)]
    pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Result<TransferableArray, JsValue> {
        // A zero extent makes the product zero whatever the other extents are.
        // Otherwise multiply with overflow checks: a plain `product()` panics in
        // debug builds and wraps in release builds (usize is 32-bit on wasm32),
        // and a wrapped product could spuriously equal `data.len()`.
        let expected = if shape.contains(&0) {
            Some(0)
        } else {
            shape
                .iter()
                .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        };
        if expected != Some(data.len()) {
            let product = expected.map_or_else(|| "overflow".to_string(), |n| n.to_string());
            return Err(WasmError::InvalidParameter(format!(
                "TransferableArray: data length {} does not match shape product {}",
                data.len(),
                product,
            ))
            .into());
        }
        Ok(TransferableArray {
            data,
            shape,
            dtype: "f64".to_string(),
        })
    }

    /// Number of stored elements (`data.len()`).
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Number of dimensions (`shape.len()`).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Serializes the array (data, shape and dtype) to a JSON string.
    ///
    /// # Errors
    /// Returns `WasmError::SerializationError` if `serde_json` fails.
    pub fn to_json(&self) -> Result<String, JsValue> {
        serde_json::to_string(self)
            .map_err(|e| WasmError::SerializationError(e.to_string()).into())
    }

    /// Parses an array from a JSON string.
    ///
    /// Only the JSON structure is checked here; the shape/data agreement is
    /// not verified by this method.
    ///
    /// # Errors
    /// Returns `WasmError::SerializationError` if `json` cannot be decoded.
    pub fn from_json(json: &str) -> Result<TransferableArray, JsValue> {
        serde_json::from_str(json)
            .map_err(|e| WasmError::SerializationError(e.to_string()).into())
    }
}
/// Kinds of matrix operation that can be named in a `WorkerMessage::MatrixOp`.
///
/// Variants are (de)serialized under their snake_case names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MatrixOpKind {
    /// Wire name `"multiply"`.
    Multiply,
    /// Wire name `"transpose"`.
    Transpose,
    /// Wire name `"inverse"`.
    Inverse,
    /// Wire name `"eigenvalues"`.
    Eigenvalues,
    /// Wire name `"svd"`.
    Svd,
}
/// Kinds of statistics operation that can be named in a `WorkerMessage::StatsOp`.
///
/// Variants are (de)serialized under their snake_case names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StatsOpKind {
    /// Wire name `"descriptive"`.
    Descriptive,
    /// Wire name `"ks_test"`.
    KsTest,
    /// Wire name `"t_test"`.
    TTest,
    /// Wire name `"correlation"`.
    Correlation,
    /// Wire name `"histogram"`; carries the requested bin count.
    Histogram { bins: usize },
}
/// Kinds of FFT operation that can be named in a `WorkerMessage::FftOp`.
///
/// Variants are (de)serialized under their snake_case names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FftOpKind {
    /// Wire name `"forward"`.
    Forward,
    /// Wire name `"inverse"`.
    Inverse,
    /// Wire name `"stft"`; carries the window and hop sizes.
    Stft { window_size: usize, hop_size: usize },
    /// Wire name `"psd"`.
    Psd,
}
/// Messages exchanged with workers.
///
/// Serialized in serde's adjacently tagged form: the snake_case variant name is
/// stored under `"op_type"` and the variant's fields under `"payload"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op_type", content = "payload", rename_all = "snake_case")]
pub enum WorkerMessage {
    /// Matrix operation request (`"matrix_op"`).
    MatrixOp {
        task_id: String,
        op: MatrixOpKind,
        /// First operand.
        a: TransferableArray,
        /// Optional second operand.
        b: Option<TransferableArray>,
    },
    /// Statistics operation request (`"stats_op"`).
    StatsOp {
        task_id: String,
        op: StatsOpKind,
        /// Primary input.
        data: TransferableArray,
        /// Optional second input.
        data2: Option<TransferableArray>,
    },
    /// FFT operation request (`"fft_op"`).
    FftOp {
        task_id: String,
        op: FftOpKind,
        signal: TransferableArray,
        is_complex: bool,
    },
    /// Cancellation message for a task (`"cancel"`).
    Cancel { task_id: String },
    /// Outcome of a task (`"result"`).
    Result {
        task_id: String,
        success: bool,
        data: Option<TransferableArray>,
        metadata: Option<String>,
        error: Option<String>,
    },
    /// Progress report for a task (`"progress"`).
    Progress {
        task_id: String,
        percent: f64,
        status: String,
    },
}
impl WorkerMessage {
    /// Encodes this message as a JSON string.
    ///
    /// # Errors
    /// Returns `WasmError::SerializationError` (as a `JsValue`) when
    /// `serde_json` cannot encode the message.
    pub fn serialize_message(&self) -> Result<String, JsValue> {
        match serde_json::to_string(self) {
            Ok(encoded) => Ok(encoded),
            Err(err) => Err(WasmError::SerializationError(err.to_string()).into()),
        }
    }
}
/// Decodes `json` as a `WorkerMessage` and re-encodes it as JSON.
///
/// The round trip rejects input that is not a well-formed `WorkerMessage`.
///
/// # Errors
/// Returns `WasmError::SerializationError` if decoding or encoding fails.
#[wasm_bindgen]
pub fn serialize_worker_message(json: &str) -> Result<String, JsValue> {
    let parsed = serde_json::from_str::<WorkerMessage>(json)
        .map_err(|err| WasmError::SerializationError(err.to_string()))?;
    match serde_json::to_string(&parsed) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(WasmError::SerializationError(err.to_string()).into()),
    }
}
/// Copies `data` into a `TransferableArray` with the given `shape` and returns
/// its JSON encoding.
///
/// # Errors
/// Propagates the error from `TransferableArray::new` when `shape` does not
/// match `data`, or from `TransferableArray::to_json` when encoding fails.
#[wasm_bindgen]
pub fn serialize_for_worker(data: &[f64], shape: Vec<usize>) -> Result<String, JsValue> {
    let owned = data.to_vec();
    TransferableArray::new(owned, shape).and_then(|array| array.to_json())
}
/// Parses a JSON-encoded `TransferableArray` and verifies that its shape
/// agrees with the amount of data it carries.
///
/// # Errors
/// * `WasmError::SerializationError` if `json` cannot be decoded.
/// * `WasmError::InvalidParameter` if the product of `shape` differs from
///   `data.len()` or does not fit in `usize`.
#[wasm_bindgen]
pub fn deserialize_from_worker(json: &str) -> Result<TransferableArray, JsValue> {
    let arr = TransferableArray::from_json(json)?;
    // The shape comes straight from the JSON input, so multiply with overflow
    // checks: a wrapping product (usize is 32-bit on wasm32) could otherwise
    // collapse to a small value and match `data.len()` by accident. A zero
    // extent always yields a zero product, whatever the other extents are.
    let expected = if arr.shape.contains(&0) {
        Some(0)
    } else {
        arr.shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    };
    match expected {
        Some(n) if n == arr.data.len() => Ok(arr),
        Some(n) => Err(WasmError::InvalidParameter(format!(
            "deserialize_from_worker: shape product {} ≠ data length {}",
            n,
            arr.data.len(),
        ))
        .into()),
        None => Err(WasmError::InvalidParameter(format!(
            "deserialize_from_worker: shape product overflows usize (data length {})",
            arr.data.len(),
        ))
        .into()),
    }
}
/// Decodes `json` as a `WorkerMessage` and converts it into a JavaScript value
/// via `serde_wasm_bindgen`.
///
/// # Errors
/// Returns `WasmError::SerializationError` if the JSON cannot be decoded or the
/// conversion to a `JsValue` fails.
#[wasm_bindgen]
pub fn parse_worker_message(json: &str) -> Result<JsValue, JsValue> {
    let parsed = serde_json::from_str::<WorkerMessage>(json)
        .map_err(|err| WasmError::SerializationError(err.to_string()))?;
    match serde_wasm_bindgen::to_value(&parsed) {
        Ok(value) => Ok(value),
        Err(err) => Err(WasmError::SerializationError(err.to_string()).into()),
    }
}
/// Size and layout description of a buffer made of a 32-byte header followed
/// by element data.
///
/// This struct holds only metadata; it does not own the buffer itself. All
/// fields are public, so values computed by `SharedBuffer::new` can be changed
/// afterwards without being re-derived.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen(getter_with_clone)]
pub struct SharedBuffer {
    /// Total size in bytes: header plus `element_count * element_size`.
    pub byte_length: usize,
    /// Number of elements (product of `shape`).
    pub element_count: usize,
    /// Size of one element in bytes; `SharedBuffer::new` accepts 4 or 8.
    pub element_size: usize,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Caller-supplied label.
    pub label: String,
}
#[wasm_bindgen]
impl SharedBuffer {
    /// Describes a buffer holding `shape` elements of `element_size` bytes each,
    /// preceded by a 32-byte header.
    ///
    /// # Errors
    /// Returns `WasmError::InvalidParameter` (as a `JsValue`) when
    /// * `shape` is empty,
    /// * `element_size` is neither 4 nor 8, or
    /// * the element count or total byte length does not fit in `usize`.
    #[wasm_bindgen(constructor)]
    pub fn new(
        shape: Vec<usize>,
        element_size: usize,
        label: String,
    ) -> Result<SharedBuffer, JsValue> {
        // Must stay in sync with `data_offset` and the layout in `build_header`.
        const HEADER_BYTES: usize = 32;

        if shape.is_empty() {
            return Err(
                WasmError::InvalidParameter("SharedBuffer: shape must not be empty".to_string())
                    .into(),
            );
        }
        if element_size != 4 && element_size != 8 {
            return Err(WasmError::InvalidParameter(format!(
                "SharedBuffer: element_size must be 4 (f32) or 8 (f64), got {}",
                element_size
            ))
            .into());
        }

        // Overflow-checked sizing: unchecked arithmetic would panic in debug
        // builds and wrap in release builds (usize is 32-bit on wasm32), which
        // would describe a buffer far smaller than the shape requires. A zero
        // extent always yields zero elements.
        let element_count = if shape.contains(&0) {
            Some(0)
        } else {
            shape
                .iter()
                .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        };
        let sizes = element_count.and_then(|count| {
            count
                .checked_mul(element_size)?
                .checked_add(HEADER_BYTES)
                .map(|bytes| (count, bytes))
        });
        let (element_count, byte_length) = match sizes {
            Some(pair) => pair,
            None => {
                return Err(WasmError::InvalidParameter(
                    "SharedBuffer: shape is too large, byte length overflows usize".to_string(),
                )
                .into());
            }
        };

        Ok(SharedBuffer {
            byte_length,
            element_count,
            element_size,
            shape,
            label,
        })
    }

    /// Byte offset at which element data starts, i.e. the header size (32).
    pub fn data_offset(&self) -> usize {
        32
    }

    /// Serializes this descriptor to a JSON string.
    ///
    /// # Errors
    /// Returns `WasmError::SerializationError` if `serde_json` fails.
    pub fn to_json(&self) -> Result<String, JsValue> {
        serde_json::to_string(self)
            .map_err(|e| WasmError::SerializationError(e.to_string()).into())
    }

    /// Parses a descriptor from a JSON string.
    ///
    /// The decoded fields are taken as-is; they are not re-derived or
    /// cross-checked against each other.
    ///
    /// # Errors
    /// Returns `WasmError::SerializationError` if `json` cannot be decoded.
    pub fn from_json(json: &str) -> Result<SharedBuffer, JsValue> {
        serde_json::from_str(json)
            .map_err(|e| WasmError::SerializationError(e.to_string()).into())
    }

    /// Builds the 32-byte header. All integers are little-endian `u32`:
    ///
    /// | bytes  | content                                   |
    /// |--------|-------------------------------------------|
    /// | 0..4   | magic `"SCRS"`                            |
    /// | 4..8   | state word, written as 0                  |
    /// | 8..12  | `element_count`                           |
    /// | 12..16 | `element_size`                            |
    /// | 16..20 | number of dimensions stored (at most 3)   |
    /// | 20..32 | up to three dimension extents             |
    ///
    /// Only the first three dimensions fit; further dimensions are omitted.
    pub fn build_header(&self) -> Vec<u8> {
        let mut header = vec![0u8; self.data_offset()];
        header[0..4].copy_from_slice(b"SCRS");

        let mut put_u32 = |offset: usize, value: u32| {
            header[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
        };
        // The `as u32` casts are lossless where usize is 32-bit (wasm32); on
        // wider targets values above u32::MAX are truncated to their low 32 bits.
        put_u32(4, 0);
        put_u32(8, self.element_count as u32);
        put_u32(12, self.element_size as u32);
        put_u32(16, self.shape.len().min(3) as u32);
        for (i, &dim) in self.shape.iter().take(3).enumerate() {
            put_u32(20 + i * 4, dim as u32);
        }

        header
    }
}
/// State recorded for a worker slot.
///
/// Variants are (de)serialized under their snake_case names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkerState {
    /// Wire name `"idle"`.
    Idle,
    /// Wire name `"busy"`.
    Busy,
    /// Wire name `"terminated"`.
    Terminated,
    /// Wire name `"error"`.
    Error,
}
/// Bookkeeping record for a single worker slot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerEntry {
    /// Position of this worker within its pool.
    pub index: usize,
    /// Current state of the worker.
    pub state: WorkerState,
    /// Task currently assigned to the worker, if any.
    pub current_task_id: Option<String>,
    /// Count of tasks reported as successful.
    pub tasks_completed: u64,
    /// Count of tasks reported as failed.
    pub tasks_failed: u64,
}
/// Settings for a worker pool.
///
/// All fields are public, so the limits applied by `WorkerPoolConfig::new` and
/// the `with_*` builders can be bypassed by assigning fields directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen(getter_with_clone)]
pub struct WorkerPoolConfig {
    /// Number of workers; `WorkerPoolConfig::new` clamps it to `1..=64`.
    pub worker_count: usize,
    /// URL of the worker script.
    pub worker_script_url: String,
    /// Queue depth setting; `WorkerPoolConfig::new` starts it at 256.
    pub max_queue_depth: usize,
    /// Task timeout in milliseconds; `WorkerPoolConfig::new` starts it at 30 000.
    pub task_timeout_ms: u64,
}
#[wasm_bindgen]
impl WorkerPoolConfig {
    /// Creates a configuration with `worker_count` clamped to `1..=64`, a queue
    /// depth of 256 and a task timeout of 30 000 ms.
    ///
    /// # Errors
    /// Returns `WasmError::InvalidParameter` (as a `JsValue`) when
    /// `worker_script_url` is empty.
    #[wasm_bindgen(constructor)]
    pub fn new(worker_count: usize, worker_script_url: String) -> Result<WorkerPoolConfig, JsValue> {
        if !worker_script_url.is_empty() {
            return Ok(WorkerPoolConfig {
                worker_count: worker_count.clamp(1, 64),
                worker_script_url,
                max_queue_depth: 256,
                task_timeout_ms: 30_000,
            });
        }
        Err(WasmError::InvalidParameter(
            "WorkerPoolConfig: worker_script_url must not be empty".to_string(),
        )
        .into())
    }

    /// Returns the configuration with `max_queue_depth` replaced; a depth of 0
    /// is raised to 1.
    pub fn with_max_queue_depth(self, depth: usize) -> WorkerPoolConfig {
        WorkerPoolConfig {
            max_queue_depth: depth.max(1),
            ..self
        }
    }

    /// Returns the configuration with `task_timeout_ms` replaced by `ms`.
    pub fn with_task_timeout_ms(self, ms: u64) -> WorkerPoolConfig {
        WorkerPoolConfig {
            task_timeout_ms: ms,
            ..self
        }
    }

    /// Serializes the configuration to a JSON string.
    ///
    /// # Errors
    /// Returns `WasmError::SerializationError` if `serde_json` fails.
    pub fn to_json(&self) -> Result<String, JsValue> {
        match serde_json::to_string(self) {
            Ok(encoded) => Ok(encoded),
            Err(err) => Err(WasmError::SerializationError(err.to_string()).into()),
        }
    }
}
/// Bookkeeping for a fixed set of worker slots and a FIFO queue of task ids.
///
/// This type only tracks state (which slot holds which task id, and counters);
/// nothing here creates workers or sends them messages. `to_init_script`
/// merely produces JavaScript source text.
#[wasm_bindgen]
pub struct WorkerPool {
    config: WorkerPoolConfig,
    workers: Vec<WorkerEntry>,
    /// Task ids waiting for an idle worker, oldest at the front.
    pending_queue: std::collections::VecDeque<String>,
    /// Sequence number appended to the next generated task id.
    next_task_seq: u64,
}

#[wasm_bindgen]
impl WorkerPool {
    /// Creates a pool with `config.worker_count` idle worker slots and an empty
    /// queue.
    #[wasm_bindgen(constructor)]
    pub fn new(config: WorkerPoolConfig) -> WorkerPool {
        let workers: Vec<WorkerEntry> = (0..config.worker_count)
            .map(|i| WorkerEntry {
                index: i,
                state: WorkerState::Idle,
                current_task_id: None,
                tasks_completed: 0,
                tasks_failed: 0,
            })
            .collect();
        WorkerPool {
            config,
            workers,
            pending_queue: std::collections::VecDeque::new(),
            next_task_seq: 0,
        }
    }

    /// Appends a new task to the queue and returns its id, formed as
    /// `"{op_type}-{sequence}"` with a sequence number that increases by one
    /// per call.
    ///
    /// NOTE(review): `config.max_queue_depth` is not consulted here, so the
    /// queue can grow past it. Enforcing the limit would need a fallible
    /// signature; confirm whether callers are expected to check
    /// `pending_count` themselves.
    pub fn enqueue_task(&mut self, op_type: &str) -> String {
        let task_id = format!("{}-{}", op_type, self.next_task_seq);
        self.next_task_seq += 1;
        self.pending_queue.push_back(task_id.clone());
        task_id
    }

    /// Assigns the oldest queued task to the first idle worker.
    ///
    /// Returns the dispatched task id, or `None` when the queue is empty or no
    /// worker is idle (in which case nothing is changed).
    pub fn try_dispatch(&mut self) -> Option<String> {
        // Locate the worker before popping so the queue stays untouched when
        // every worker is occupied.
        let worker = self
            .workers
            .iter_mut()
            .find(|w| w.state == WorkerState::Idle)?;
        // `pop_front` is O(1); the queue is a VecDeque so dispatching does not
        // shift the remaining entries.
        let task_id = self.pending_queue.pop_front()?;
        worker.state = WorkerState::Busy;
        worker.current_task_id = Some(task_id.clone());
        Some(task_id)
    }

    /// Records the outcome of `task_id` and returns its worker to the idle
    /// state.
    ///
    /// Returns `false` when no worker currently holds `task_id` (for example if
    /// the task is still queued), in which case nothing is changed.
    pub fn task_completed(&mut self, task_id: &str, success: bool) -> bool {
        let owner = self
            .workers
            .iter_mut()
            .find(|w| w.current_task_id.as_deref() == Some(task_id));
        match owner {
            Some(worker) => {
                if success {
                    worker.tasks_completed += 1;
                } else {
                    worker.tasks_failed += 1;
                }
                worker.state = WorkerState::Idle;
                worker.current_task_id = None;
                true
            }
            None => false,
        }
    }

    /// Returns a JSON object with `worker_count`, `idle`, `busy`,
    /// `pending_queue_depth`, `total_completed` and `total_failed`.
    ///
    /// # Errors
    /// Returns `WasmError::SerializationError` if `serde_json` fails.
    pub fn stats(&self) -> Result<String, JsValue> {
        let idle = self.workers.iter().filter(|w| w.state == WorkerState::Idle).count();
        let busy = self.workers.iter().filter(|w| w.state == WorkerState::Busy).count();
        let completed: u64 = self.workers.iter().map(|w| w.tasks_completed).sum();
        let failed: u64 = self.workers.iter().map(|w| w.tasks_failed).sum();
        let stats = serde_json::json!({
            "worker_count": self.config.worker_count,
            "idle": idle,
            "busy": busy,
            "pending_queue_depth": self.pending_queue.len(),
            "total_completed": completed,
            "total_failed": failed,
        });
        serde_json::to_string(&stats)
            .map_err(|e| WasmError::SerializationError(e.to_string()).into())
    }

    /// Configured number of workers.
    pub fn worker_count(&self) -> usize {
        self.config.worker_count
    }

    /// Number of tasks still waiting in the queue.
    pub fn pending_count(&self) -> usize {
        self.pending_queue.len()
    }

    /// Generates JavaScript source that declares `POOL_CONFIG` from this pool's
    /// configuration and creates one module `Worker` per configured worker.
    ///
    /// The script URL is embedded as a JSON string literal so quotes and other
    /// special characters in it are escaped.
    ///
    /// # Errors
    /// Returns `WasmError::SerializationError` if the URL cannot be JSON-encoded.
    pub fn to_init_script(&self) -> Result<String, JsValue> {
        let url_json = serde_json::to_string(&self.config.worker_script_url)
            .map_err(|e| WasmError::SerializationError(e.to_string()))?;
        let script = format!(
            r#"
// Auto-generated by WorkerPool::to_init_script
const POOL_CONFIG = {{
workerCount: {worker_count},
workerScriptUrl: {url_json},
maxQueueDepth: {max_queue},
taskTimeoutMs: {timeout_ms},
}};
const pool = Array.from({{ length: POOL_CONFIG.workerCount }}, () =>
new Worker(POOL_CONFIG.workerScriptUrl, {{ type: 'module' }})
);
"#,
            worker_count = self.config.worker_count,
            url_json = url_json,
            max_queue = self.config.max_queue_depth,
            timeout_ms = self.config.task_timeout_ms,
        );
        Ok(script)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An array survives a JSON round trip with data and shape intact.
    #[test]
    fn test_transferable_array_roundtrip() {
        let original = TransferableArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .expect("construction ok");
        let encoded = original.to_json().expect("to_json ok");
        let decoded = TransferableArray::from_json(&encoded).expect("from_json ok");
        assert_eq!(decoded.data, original.data);
        assert_eq!(decoded.shape, vec![2, 2]);
    }

    /// Three elements cannot fill a 2x2 shape.
    #[test]
    fn test_transferable_array_shape_mismatch() {
        let outcome = TransferableArray::new(vec![1.0, 2.0, 3.0], vec![2, 2]);
        assert!(outcome.is_err());
    }

    /// The free-function helpers round-trip a slice through JSON.
    #[test]
    fn test_serialize_for_worker() {
        let values = vec![1.0_f64, 2.0, 3.0];
        let encoded = serialize_for_worker(&values, vec![3]).expect("serialize ok");
        let decoded = deserialize_from_worker(&encoded).expect("deserialize ok");
        assert_eq!(decoded.data, values);
    }

    /// Sizes are derived from the shape and the header starts with the magic,
    /// a zero state word and the element count.
    #[test]
    fn test_shared_buffer_header() {
        let descriptor = SharedBuffer::new(vec![4, 4], 4, "test".to_string())
            .expect("SharedBuffer ok");
        assert_eq!(descriptor.element_count, 16);
        assert_eq!(descriptor.element_size, 4);
        assert_eq!(descriptor.byte_length, 32 + 64);

        let bytes = descriptor.build_header();
        assert_eq!(&bytes[0..4], b"SCRS");
        assert_eq!(&bytes[4..8], &0u32.to_le_bytes());
        assert_eq!(&bytes[8..12], &16u32.to_le_bytes());
    }

    /// Element sizes other than 4 and 8 are rejected.
    #[test]
    fn test_shared_buffer_bad_element_size() {
        let outcome = SharedBuffer::new(vec![4], 3, "bad".to_string());
        assert!(outcome.is_err());
    }

    /// Tasks are dispatched in FIFO order, and a finished task frees its worker.
    #[test]
    fn test_worker_pool_dispatch() {
        let config =
            WorkerPoolConfig::new(2, "/worker.js".to_string()).expect("config ok");
        let mut pool = WorkerPool::new(config);

        let first = pool.enqueue_task("matrix_op");
        let second = pool.enqueue_task("stats_op");
        assert_eq!(pool.pending_count(), 2);

        let sent_first = pool.try_dispatch().expect("dispatch ok");
        assert_eq!(sent_first, first);
        assert_eq!(pool.pending_count(), 1);

        let sent_second = pool.try_dispatch().expect("dispatch ok");
        assert_eq!(sent_second, second);
        assert_eq!(pool.pending_count(), 0);

        // Queue drained: nothing left to hand out.
        assert!(pool.try_dispatch().is_none());

        // Completing the first task makes a worker available again.
        assert!(pool.task_completed(&first, true));
        let third = pool.enqueue_task("fft_op");
        let sent_third = pool.try_dispatch().expect("dispatch ok");
        assert_eq!(sent_third, third);
    }

    /// The generated script embeds the worker count and script URL.
    #[test]
    fn test_pool_init_script() {
        let config =
            WorkerPoolConfig::new(4, "https://example.com/worker.js".to_string())
                .expect("config ok");
        let pool = WorkerPool::new(config);
        let js = pool.to_init_script().expect("script ok");
        assert!(js.contains("workerCount: 4"));
        assert!(js.contains("https://example.com/worker.js"));
    }

    /// Worker counts are clamped into 1..=64.
    #[test]
    fn test_worker_pool_config_clamp() {
        let too_few = WorkerPoolConfig::new(0, "/w.js".to_string()).expect("ok");
        assert_eq!(too_few.worker_count, 1);
        let too_many = WorkerPoolConfig::new(1000, "/w.js".to_string()).expect("ok");
        assert_eq!(too_many.worker_count, 64);
    }

    /// A serialized message carries the task id and the snake_case op name.
    #[test]
    fn test_worker_message_serialization() {
        let operand = TransferableArray::new(vec![1.0, 2.0], vec![2]).expect("ok");
        let message = WorkerMessage::MatrixOp {
            task_id: "task-0".to_string(),
            op: MatrixOpKind::Transpose,
            a: operand,
            b: None,
        };
        let encoded = message.serialize_message().expect("serialize ok");
        assert!(encoded.contains("task-0"));
        assert!(encoded.contains("transpose"));
    }
}