#[cfg(feature = "wasm")]
use crate::{Result, TensorError};
#[cfg(feature = "wasm")]
use std::collections::HashMap;
#[cfg(feature = "wasm")]
use std::hash::Hash;
#[cfg(feature = "wasm")]
use js_sys::*;
#[cfg(feature = "wasm")]
use wasm_bindgen::prelude::*;
#[cfg(feature = "wasm")]
use web_sys::*;
#[cfg(feature = "wasm")]
use super::compression::{
WasmCompressedData, WasmQuantizedData, WasmRunLengthData, WasmSparseData,
};
/// Tensor whose elements are held in one of the [`WasmCompressedData`]
/// encodings, together with its shape and a set of layout flags.
#[cfg(feature = "wasm")]
pub struct WasmOptimizedTensor<T> {
/// Element storage; the encoding is chosen when the tensor is constructed.
data: WasmCompressedData<T>,
/// Dimensions exactly as supplied by the caller.
shape: Vec<usize>,
/// Layout/capability flags recorded when the tensor is constructed.
layout_flags: WasmLayoutFlags,
}
/// Boolean layout/capability flags attached to a [`WasmOptimizedTensor`].
///
/// `WasmOptimizedTensor::new` fills `simd_enabled` and `shared_buffer` from its
/// detection helpers and initialises `memory_mapped` and `streaming` to `false`.
#[cfg(feature = "wasm")]
#[derive(Debug, Clone, Copy)]
pub struct WasmLayoutFlags {
/// Presumably marks data backed by a memory mapping — TODO confirm against users of this flag.
pub memory_mapped: bool,
/// Result of the SIMD support probe performed at construction time.
pub simd_enabled: bool,
/// Result of the `SharedArrayBuffer` availability probe performed at construction time.
pub shared_buffer: bool,
/// Presumably marks data that is consumed incrementally — TODO confirm against users of this flag.
pub streaming: bool,
}
#[cfg(feature = "wasm")]
impl<T> WasmOptimizedTensor<T>
where
    T: Clone + Default + PartialEq,
{
    /// Builds a tensor from `data`, choosing a storage encoding automatically.
    ///
    /// `shape` is stored as given; it is not checked against `data.len()`.
    /// `simd_enabled` and `shared_buffer` are probed from the build target and
    /// the JavaScript host, while `memory_mapped` and `streaming` start `false`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported while encoding `data`.
    pub fn new(data: Vec<T>, shape: Vec<usize>) -> Result<Self> {
        let layout_flags = WasmLayoutFlags {
            memory_mapped: false,
            simd_enabled: Self::detect_simd_support(),
            shared_buffer: Self::detect_shared_buffer_support(),
            streaming: false,
        };
        let storage = Self::choose_optimal_storage(&data)?;
        Ok(Self {
            data: storage,
            shape,
            layout_flags,
        })
    }

    /// Reports whether this build can use WebAssembly SIMD.
    ///
    /// On wasm32 SIMD is a compile-time property of the module, so the answer
    /// is `true` only when the crate is compiled for wasm32 with the `simd128`
    /// target feature enabled. Being on wasm32 alone does not imply SIMD.
    fn detect_simd_support() -> bool {
        cfg!(all(target_arch = "wasm32", target_feature = "simd128"))
    }

    /// Reports whether the JavaScript host exposes `SharedArrayBuffer`.
    ///
    /// The global is looked up through `Reflect` rather than `eval`: pages
    /// served with a Content-Security-Policy that lacks `'unsafe-eval'` reject
    /// `eval`, which would make the probe report "unsupported" incorrectly.
    /// On non-wasm32 targets there is no JavaScript host, so the probe is
    /// skipped and `false` is returned.
    fn detect_shared_buffer_support() -> bool {
        #[cfg(target_arch = "wasm32")]
        {
            let global: wasm_bindgen::JsValue = js_sys::global().into();
            let key = wasm_bindgen::JsValue::from_str("SharedArrayBuffer");
            // Mirrors `typeof SharedArrayBuffer !== 'undefined'`.
            js_sys::Reflect::get(&global, &key)
                .map(|ctor| !ctor.is_undefined())
                .unwrap_or(false)
        }
        #[cfg(not(target_arch = "wasm32"))]
        {
            false
        }
    }

    /// Picks the storage encoding for `data`.
    ///
    /// Rules, checked in order:
    /// 1. more than 1000 elements and sparsity above 0.9 -> sparse;
    /// 2. fewer distinct values than `len / 10` -> run-length;
    /// 3. more than 10000 elements -> quantized;
    /// 4. otherwise -> dense copy.
    ///
    /// Each statistic is computed only when its rule is actually reached. The
    /// distinct-value scan is quadratic, so it is skipped when the sparse rule
    /// already matched or when `len / 10` is zero (the rule cannot match then).
    ///
    /// # Errors
    ///
    /// Propagates errors from the sparse or quantized encoders.
    fn choose_optimal_storage(data: &[T]) -> Result<WasmCompressedData<T>> {
        let data_size = data.len();

        if data_size > 1000 && Self::calculate_sparsity(data) > 0.9 {
            return Ok(WasmCompressedData::Sparse(Self::create_sparse_data(data)?));
        }

        let run_length_threshold = data_size / 10;
        if run_length_threshold > 0 && Self::count_unique_values(data) < run_length_threshold {
            return Ok(WasmCompressedData::RunLength(
                Self::create_run_length_data(data),
            ));
        }

        if data_size > 10000 {
            return Ok(WasmCompressedData::Quantized(
                Self::create_quantized_data(data)?,
            ));
        }

        Ok(WasmCompressedData::Dense(data.to_vec()))
    }

    /// Counts the distinct values in `data` using `PartialEq` only.
    ///
    /// `T` is not required to be `Hash` or `Ord`, so every element is compared
    /// against the values seen so far; the cost grows with
    /// `len * distinct_values`.
    fn count_unique_values(data: &[T]) -> usize {
        let mut seen: Vec<&T> = Vec::new();
        for item in data {
            if !seen.iter().any(|known| *known == item) {
                seen.push(item);
            }
        }
        seen.len()
    }

    /// Returns the fraction of elements equal to `T::default()`.
    ///
    /// Empty input yields `0.0` rather than the NaN produced by `0 / 0`.
    fn calculate_sparsity(data: &[T]) -> f64 {
        if data.is_empty() {
            return 0.0;
        }
        let zero = T::default();
        let zero_count = data.iter().filter(|value| **value == zero).count();
        zero_count as f64 / data.len() as f64
    }

    /// Encodes `data` in a compressed-sparse-row style layout.
    ///
    /// The flat input is viewed as rows of `floor(sqrt(len))` elements (the
    /// last row may be shorter). Elements different from `T::default()` are
    /// stored with their column index; `row_ptr` holds the running count of
    /// stored values after each row, starting with `0`.
    ///
    /// Empty input produces an empty encoding with `row_ptr == [0]` instead of
    /// underflowing and dividing by a zero row width.
    fn create_sparse_data(data: &[T]) -> Result<WasmSparseData<T>> {
        let zero = T::default();
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_ptr = vec![0];

        if !data.is_empty() {
            // Non-empty input guarantees a row width of at least 1.
            let width = (data.len() as f64).sqrt() as usize;
            for row in data.chunks(width) {
                for (col, value) in row.iter().enumerate() {
                    if *value != zero {
                        values.push(value.clone());
                        col_indices.push(col as u32);
                    }
                }
                // NOTE(review): indices are narrowed to u32; lossless on wasm32
                // where usize is 32 bits wide.
                row_ptr.push(values.len() as u32);
            }
        }

        let nnz = values.len();
        Ok(WasmSparseData {
            values,
            row_ptr,
            col_indices,
            nnz,
        })
    }

    /// Run-length encodes `data` into parallel `values` / `lengths` vectors.
    ///
    /// Consecutive equal elements collapse into one entry; empty input yields
    /// two empty vectors.
    fn create_run_length_data(data: &[T]) -> WasmRunLengthData<T> {
        let mut values = Vec::new();
        let mut lengths = Vec::new();

        if let Some((first, rest)) = data.split_first() {
            let mut run_value = first;
            let mut run_length = 1u32;
            for item in rest {
                if item == run_value {
                    run_length += 1;
                } else {
                    values.push(run_value.clone());
                    lengths.push(run_length);
                    run_value = item;
                    run_length = 1;
                }
            }
            // Flush the run that was still open when the input ended.
            values.push(run_value.clone());
            lengths.push(run_length);
        }

        WasmRunLengthData { values, lengths }
    }

    /// Produces the quantized representation for `data`.
    ///
    /// NOTE(review): this is a placeholder. Only `data.len()` is used; every
    /// quantized value is `0` with `scale = 1.0`, `zero_point = 0` and
    /// `bit_width = 8`, so the original element values cannot be recovered.
    /// A real implementation needs a numeric conversion that the current
    /// bounds on `T` do not provide.
    fn create_quantized_data(data: &[T]) -> Result<WasmQuantizedData> {
        Ok(WasmQuantizedData {
            quantized_values: vec![0u8; data.len()],
            scale: 1.0,
            zero_point: 0,
            bit_width: 8,
        })
    }

    /// Returns the dimensions supplied at construction time.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Estimates the number of payload bytes held by the current encoding.
    ///
    /// Counts element and index payloads only (vector lengths, not
    /// capacities); the quantized encoding adds a fixed 16-byte allowance.
    pub fn memory_usage(&self) -> usize {
        // Column indices, row pointers and run lengths are stored as `u32`.
        const INDEX_BYTES: usize = std::mem::size_of::<u32>();
        // Fixed allowance added for the quantized encoding; presumably covers
        // its scalar parameters — TODO confirm.
        const QUANTIZED_OVERHEAD_BYTES: usize = 16;

        let element_bytes = std::mem::size_of::<T>();
        match &self.data {
            WasmCompressedData::Dense(dense) => dense.len() * element_bytes,
            WasmCompressedData::Sparse(sparse) => {
                sparse.values.len() * element_bytes
                    + sparse.col_indices.len() * INDEX_BYTES
                    + sparse.row_ptr.len() * INDEX_BYTES
            }
            WasmCompressedData::Quantized(quant) => {
                quant.quantized_values.len() + QUANTIZED_OVERHEAD_BYTES
            }
            WasmCompressedData::RunLength(rle) => {
                rle.values.len() * element_bytes + rle.lengths.len() * INDEX_BYTES
            }
        }
    }
}
/// Tensor operations exported to JavaScript through `wasm_bindgen`.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub struct WasmTensorOperations {
// Memory manager whose `total_allocated` counter is reported by `get_memory_usage`.
memory_manager: super::memory::WasmMemoryManager,
}
#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl WasmTensorOperations {
    /// Creates the operations object with a memory manager constructed from
    /// `16 * 1024 * 1024` (16 Mi).
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {
            memory_manager: super::memory::WasmMemoryManager::new(16 * 1024 * 1024),
        }
    }

    /// Multiplies two row-major matrices using 32x32x32 blocking.
    ///
    /// `a` is read as an `m x k` matrix, `b` as a `k x n` matrix, and the
    /// returned vector holds the `m x n` product in row-major order.
    ///
    /// Inside each block the loops run row -> inner dimension -> column, so
    /// the innermost loop walks contiguous slices of `b` and of the output
    /// instead of striding through `b` by `n`. Every output element still
    /// accumulates its products in ascending inner-index order, so results are
    /// bit-identical to the naive row -> column -> inner ordering.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is too short for the requested dimensions.
    #[wasm_bindgen]
    pub fn matmul_optimized(
        &mut self,
        a: &[f32],
        b: &[f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> Vec<f32> {
        const BLOCK_SIZE: usize = 32;

        let mut result = vec![0.0f32; m * n];
        for ii in (0..m).step_by(BLOCK_SIZE) {
            let i_end = std::cmp::min(ii + BLOCK_SIZE, m);
            for jj in (0..n).step_by(BLOCK_SIZE) {
                let j_end = std::cmp::min(jj + BLOCK_SIZE, n);
                for kk in (0..k).step_by(BLOCK_SIZE) {
                    let k_end = std::cmp::min(kk + BLOCK_SIZE, k);
                    for i in ii..i_end {
                        for k_idx in kk..k_end {
                            // One scalar of `a` scales a contiguous run of a
                            // `b` row into a contiguous run of the output row.
                            let a_ik = a[i * k + k_idx];
                            let b_row = &b[k_idx * n + jj..k_idx * n + j_end];
                            let out_row = &mut result[i * n + jj..i * n + j_end];
                            for (out, &b_kj) in out_row.iter_mut().zip(b_row) {
                                *out += a_ik * b_kj;
                            }
                        }
                    }
                }
            }
        }
        result
    }

    /// Returns the memory manager's `total_allocated` counter as an `f64`.
    #[wasm_bindgen]
    pub fn get_memory_usage(&self) -> f64 {
        self.memory_manager.total_allocated as f64
    }
}
// The whole module is gated on the `wasm` feature: everything it tests is
// feature-gated too, and gating only the test function would leave
// `use super::*` unused (and warned about) in feature-less test builds.
#[cfg(all(test, feature = "wasm"))]
mod tests {
    use super::*;

    /// A five-element input is too small for any compressed encoding, so the
    /// tensor must be built successfully, keep its shape and store the
    /// elements densely.
    #[test]
    #[ignore = "WASM tests require WASM target - cannot run on native"]
    fn test_wasm_tensor_optimization() {
        let data = vec![1.0f32, 0.0, 0.0, 2.0, 0.0];
        let shape = vec![5];

        let result = WasmOptimizedTensor::new(data, shape);
        assert!(result.is_ok());

        let tensor = result.expect("test: operation should succeed");
        assert_eq!(tensor.shape(), &[5]);
        // Dense storage: five `f32` elements and nothing else.
        assert_eq!(tensor.memory_usage(), 5 * std::mem::size_of::<f32>());
    }
}