use anyhow::{Result, anyhow};
use ronn_core::{DataType, Tensor, TensorLayout};
use std::collections::HashMap;
/// Ties together the three pieces this module offers: a byte cache, the
/// tensor/typed-array converter, and the configuration that drives both.
#[derive(Debug)]
pub struct WasmBridge {
    /// Byte cache used by `cache_model_data` and `get_cached_model_data`.
    cache: IndexedDbCache,
    /// Converter between `Tensor` values and `TypedArrayData` buffers.
    typed_array_interface: TypedArrayInterface,
    /// Settings supplied when the bridge was constructed.
    config: WasmBridgeConfig,
}
/// Tunable settings for a `WasmBridge`.
#[derive(Debug, Clone)]
pub struct WasmBridgeConfig {
    /// When `false`, the bridge neither stores nor looks up cached model data.
    pub enable_caching: bool,
    /// Byte budget handed to the bridge's cache at construction.
    pub max_cache_size: usize,
    /// When `false`, `initialize_workers` produces a pool with zero workers.
    pub enable_web_workers: bool,
    /// Number of workers placed in the pool when web workers are enabled.
    pub worker_count: usize,
    /// Intended cache entry lifetime in milliseconds.
    ///
    /// NOTE(review): no code visible in this file reads this field; the cache
    /// applies its own fixed expiry. Confirm whether it should be wired through.
    pub cache_expiry_ms: u64,
}
impl Default for WasmBridgeConfig {
fn default() -> Self {
Self {
enable_caching: true,
max_cache_size: 128 * 1024 * 1024, enable_web_workers: true,
worker_count: navigator_hardware_concurrency().max(1),
cache_expiry_ms: 24 * 60 * 60 * 1000, }
}
}
/// Reports how many workers to assume by default.
///
/// On `wasm32` targets this is the fixed value 4; on every other target it is
/// whatever `num_cpus::get()` returns.
fn navigator_hardware_concurrency() -> usize {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(target_arch = "wasm32")]
    let concurrency = 4;
    #[cfg(not(target_arch = "wasm32"))]
    let concurrency = num_cpus::get();

    concurrency
}
/// Stateless converter between `Tensor` values and `TypedArrayData` buffers.
#[derive(Debug, Clone)]
pub struct TypedArrayInterface;

impl TypedArrayInterface {
    /// Largest batch size `get_optimal_batch_size` will ever suggest.
    const MAX_BATCH_SIZE: usize = 64;

    /// Converts `tensor` into the typed-array variant that matches its dtype.
    ///
    /// The tensor contents are read out through `Tensor::to_vec` and each
    /// element is converted with an `as` cast (which saturates for
    /// float-to-integer conversions). `F16` tensors are exported as `Float32`,
    /// and `Bool` tensors become `Uint8` with values above 0.5 mapped to 1.
    ///
    /// # Errors
    ///
    /// Propagates any error from `tensor.to_vec()`, and returns an error for
    /// every dtype not handled below.
    pub fn tensor_to_typed_array(&self, tensor: &Tensor) -> Result<TypedArrayData> {
        let data = tensor.to_vec()?;
        let typed = match tensor.dtype() {
            // There is no half-precision typed array variant, so F16 shares Float32.
            DataType::F32 | DataType::F16 => TypedArrayData::Float32(data),
            DataType::U8 => {
                TypedArrayData::Uint8(data.into_iter().map(|x| x as u8).collect())
            }
            DataType::I8 => {
                TypedArrayData::Int8(data.into_iter().map(|x| x as i8).collect())
            }
            DataType::I32 => {
                TypedArrayData::Int32(data.into_iter().map(|x| x as i32).collect())
            }
            DataType::U32 => {
                TypedArrayData::Uint32(data.into_iter().map(|x| x as u32).collect())
            }
            DataType::Bool => {
                TypedArrayData::Uint8(data.into_iter().map(|x| u8::from(x > 0.5)).collect())
            }
            _ => {
                return Err(anyhow!(
                    "Unsupported data type for TypedArray conversion: {:?}",
                    tensor.dtype()
                ));
            }
        };
        Ok(typed)
    }

    /// Builds a row-major `Tensor` of the given `shape` and `dtype` from a
    /// typed-array buffer.
    ///
    /// Every variant is first widened or narrowed to `f32` with an `as` cast,
    /// so 64-bit floats and large 32-bit integers may lose precision.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Tensor::from_data` reports.
    pub fn typed_array_to_tensor(
        &self,
        data: TypedArrayData,
        shape: Vec<usize>,
        dtype: DataType,
    ) -> Result<Tensor> {
        let f32_data: Vec<f32> = match data {
            TypedArrayData::Float32(values) => values,
            TypedArrayData::Float64(values) => values.into_iter().map(|x| x as f32).collect(),
            TypedArrayData::Uint8(values) => values.into_iter().map(|x| x as f32).collect(),
            TypedArrayData::Int8(values) => values.into_iter().map(|x| x as f32).collect(),
            TypedArrayData::Uint32(values) => values.into_iter().map(|x| x as f32).collect(),
            TypedArrayData::Int32(values) => values.into_iter().map(|x| x as f32).collect(),
        };
        Tensor::from_data(f32_data, shape, dtype, TensorLayout::RowMajor)
    }

    /// Suggests how many tensors of `tensor_size` elements to process per batch.
    ///
    /// The suggestion is a quarter of the estimated available memory divided
    /// by the per-tensor footprint (`tensor_size` elements of `f32`), clamped
    /// to the range 1 to 64. A `tensor_size` of zero yields 1.
    pub fn get_optimal_batch_size(&self, tensor_size: usize) -> usize {
        // Saturate rather than overflow: on 32-bit targets such as wasm32 an
        // element count above 2^30 would otherwise wrap (or panic in debug).
        let memory_per_tensor = tensor_size.saturating_mul(std::mem::size_of::<f32>());
        if memory_per_tensor == 0 {
            return 1;
        }
        let max_memory_for_batch = self.estimate_available_memory() / 4;
        (max_memory_for_batch / memory_per_tensor).clamp(1, Self::MAX_BATCH_SIZE)
    }

    /// Returns a fixed memory estimate in bytes: 512 MiB on `wasm32`, 1 GiB
    /// on every other target. Nothing is measured at runtime.
    fn estimate_available_memory(&self) -> usize {
        #[cfg(target_arch = "wasm32")]
        {
            512 * 1024 * 1024
        }
        #[cfg(not(target_arch = "wasm32"))]
        {
            1024 * 1024 * 1024
        }
    }
}
/// Owned buffer tagged with its element type, one variant per supported
/// typed-array flavour.
#[derive(Debug, Clone)]
pub enum TypedArrayData {
    /// 32-bit floating point elements.
    Float32(Vec<f32>),
    /// 64-bit floating point elements.
    Float64(Vec<f64>),
    /// Unsigned 8-bit integer elements.
    Uint8(Vec<u8>),
    /// Signed 8-bit integer elements.
    Int8(Vec<i8>),
    /// Unsigned 32-bit integer elements.
    Uint32(Vec<u32>),
    /// Signed 32-bit integer elements.
    Int32(Vec<i32>),
}
/// Size-bounded byte cache keyed by string.
///
/// Despite the name, everything visible here lives in an in-process
/// `HashMap`; this type makes no browser IndexedDB call.
#[derive(Debug)]
pub struct IndexedDbCache {
    /// Cached payloads keyed by the caller-supplied identifier.
    memory_cache: HashMap<String, CacheEntry>,
    /// Upper bound, in bytes, on the combined size of all stored payloads.
    max_size: usize,
    /// Sum of `size` over every entry currently held in `memory_cache`.
    current_size: usize,
    /// Number of `retrieve` calls that returned data.
    hits: u64,
    /// Number of `retrieve` calls that returned `None`.
    misses: u64,
}

/// One cached payload plus the bookkeeping used for expiry and eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Owned copy of the cached bytes.
    data: Vec<u8>,
    /// Value of `current_timestamp_ms()` when the entry was stored.
    timestamp: u64,
    /// Number of successful `retrieve` calls for this entry.
    access_count: u64,
    /// Length of `data` at insertion time, used for size accounting.
    size: usize,
}

impl IndexedDbCache {
    /// Age in milliseconds (24 hours) beyond which an entry counts as expired.
    const ENTRY_TTL_MS: u64 = 24 * 60 * 60 * 1000;

    /// Creates an empty cache that will hold at most `max_size` payload bytes.
    pub fn new(max_size: usize) -> Self {
        Self {
            memory_cache: HashMap::new(),
            max_size,
            current_size: 0,
            hits: 0,
            misses: 0,
        }
    }

    /// Stores a copy of `data` under `key`, evicting other entries if needed.
    ///
    /// Any previous value under `key` is dropped first, so a later lookup
    /// never returns stale bytes. A payload larger than the whole budget is
    /// silently skipped without disturbing unrelated entries.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for interface stability.
    pub async fn store(&mut self, key: &str, data: &[u8]) -> Result<()> {
        // Remove the old value up front so its bytes are released from the
        // accounting; `HashMap::insert` alone would replace it silently and
        // leave `current_size` permanently inflated.
        self.remove_entry(key);

        let size = data.len();
        if size > self.max_size {
            // It can never fit, so do not flush the cache trying to make room.
            return Ok(());
        }

        while self.current_size.saturating_add(size) > self.max_size
            && !self.memory_cache.is_empty()
        {
            self.evict_lru_entry();
        }

        let entry = CacheEntry {
            data: data.to_vec(),
            timestamp: current_timestamp_ms(),
            access_count: 0,
            size,
        };
        self.current_size = self.current_size.saturating_add(size);
        self.memory_cache.insert(key.to_string(), entry);
        Ok(())
    }

    /// Returns a copy of the bytes stored under `key`, if present and fresh.
    ///
    /// An entry older than 24 hours is removed and reported as a miss. Every
    /// call is counted as a hit or a miss for `get_stats`.
    pub async fn retrieve(&mut self, key: &str) -> Option<Vec<u8>> {
        let now = current_timestamp_ms();
        let expired = match self.memory_cache.get(key) {
            // Saturating: a stored timestamp ahead of `now` (clock stepped
            // backwards) must not underflow; it simply counts as fresh.
            Some(entry) => now.saturating_sub(entry.timestamp) > Self::ENTRY_TTL_MS,
            None => {
                self.misses += 1;
                return None;
            }
        };

        if expired {
            // Reclaim the space instead of leaving a dead entry resident.
            self.remove_entry(key);
            self.misses += 1;
            return None;
        }

        self.hits += 1;
        self.memory_cache.get_mut(key).map(|entry| {
            entry.access_count += 1;
            entry.data.clone()
        })
    }

    /// Removes every entry and resets the byte accounting.
    ///
    /// Hit and miss counters are left untouched.
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub async fn clear(&mut self) -> Result<()> {
        self.memory_cache.clear();
        self.current_size = 0;
        Ok(())
    }

    /// Returns a snapshot of entry count, byte usage, budget and hit rate.
    ///
    /// The hit rate is hits divided by all `retrieve` calls so far, or 0.0
    /// when no lookup has happened yet.
    pub fn get_stats(&self) -> CacheStats {
        let lookups = self.hits + self.misses;
        let hit_rate = if lookups == 0 {
            0.0
        } else {
            self.hits as f32 / lookups as f32
        };
        CacheStats {
            entry_count: self.memory_cache.len(),
            total_size: self.current_size,
            max_size: self.max_size,
            hit_rate,
        }
    }

    /// Evicts the entry with the fewest recorded accesses.
    ///
    /// Ties go to the oldest timestamp so the choice does not depend on
    /// `HashMap` iteration order. Despite the method name, selection is by
    /// access count rather than by recency of use.
    fn evict_lru_entry(&mut self) {
        let victim = self
            .memory_cache
            .iter()
            .min_by_key(|(_, entry)| (entry.access_count, entry.timestamp))
            .map(|(key, _)| key.clone());
        if let Some(key) = victim {
            self.remove_entry(&key);
        }
    }

    /// Removes `key` if present and releases its bytes from `current_size`.
    fn remove_entry(&mut self, key: &str) {
        if let Some(entry) = self.memory_cache.remove(key) {
            self.current_size = self.current_size.saturating_sub(entry.size);
        }
    }
}
/// Point-in-time summary of a cache's contents.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries held when the snapshot was taken.
    pub entry_count: usize,
    /// Combined payload size, in bytes, of those entries.
    pub total_size: usize,
    /// Byte budget the cache was created with.
    pub max_size: usize,
    /// Hit rate as reported by the cache that produced this snapshot.
    pub hit_rate: f32,
}
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// On `wasm32` targets this always returns 0; no clock is consulted there.
/// On other targets the system clock is read, and a clock set before the
/// epoch yields 0 instead of panicking.
fn current_timestamp_ms() -> u64 {
    #[cfg(target_arch = "wasm32")]
    {
        0
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // `as_millis` is u128; narrowing to u64 is lossless for any
            // realistic date.
            .map(|elapsed| elapsed.as_millis() as u64)
            // `duration_since` fails when the clock is before the epoch.
            .unwrap_or(0)
    }
}
impl WasmBridge {
pub fn new() -> Self {
Self::with_config(WasmBridgeConfig::default())
}
pub fn with_config(config: WasmBridgeConfig) -> Self {
let cache = IndexedDbCache::new(config.max_cache_size);
let typed_array_interface = TypedArrayInterface;
Self {
cache,
typed_array_interface,
config,
}
}
pub fn export_tensor(&self, tensor: &Tensor) -> Result<TensorExport> {
let typed_array = self.typed_array_interface.tensor_to_typed_array(tensor)?;
Ok(TensorExport {
data: typed_array,
shape: tensor.shape().to_vec(),
dtype: format!("{:?}", tensor.dtype()),
})
}
pub fn import_tensor(&self, export: TensorImport) -> Result<Tensor> {
let dtype = match export.dtype.as_str() {
"F32" => DataType::F32,
"F16" => DataType::F16,
"U8" => DataType::U8,
"I8" => DataType::I8,
"I32" => DataType::I32,
"U32" => DataType::U32,
"Bool" => DataType::Bool,
_ => return Err(anyhow!("Unknown data type: {}", export.dtype)),
};
self.typed_array_interface
.typed_array_to_tensor(export.data, export.shape, dtype)
}
pub async fn cache_model_data(&mut self, model_id: &str, data: &[u8]) -> Result<()> {
if self.config.enable_caching {
self.cache.store(model_id, data).await?;
}
Ok(())
}
pub async fn get_cached_model_data(&mut self, model_id: &str) -> Option<Vec<u8>> {
if self.config.enable_caching {
self.cache.retrieve(model_id).await
} else {
None
}
}
pub fn get_cache_stats(&self) -> CacheStats {
self.cache.get_stats()
}
pub async fn initialize_workers(&self) -> Result<WorkerPool> {
if !self.config.enable_web_workers {
return Ok(WorkerPool::new(0));
}
Ok(WorkerPool::new(self.config.worker_count))
}
}
impl Default for WasmBridge {
fn default() -> Self {
Self::new()
}
}
/// A tensor packaged for export, as produced by `WasmBridge::export_tensor`.
#[derive(Debug, Clone)]
pub struct TensorExport {
    /// Element buffer in typed-array form.
    pub data: TypedArrayData,
    /// Dimensions of the tensor the buffer was taken from.
    pub shape: Vec<usize>,
    /// `Debug` rendering of the source tensor's dtype.
    pub dtype: String,
}
/// Input for `WasmBridge::import_tensor`; it has the same fields as
/// `TensorExport`.
#[derive(Debug, Clone)]
pub struct TensorImport {
    /// Element buffer in typed-array form.
    pub data: TypedArrayData,
    /// Dimensions to give the reconstructed tensor.
    pub shape: Vec<usize>,
    /// Name of the dtype to give the reconstructed tensor.
    pub dtype: String,
}
/// Tracks which of a fixed number of worker slots are currently free.
///
/// Only the availability flags are managed here; no worker is spawned.
#[derive(Debug)]
pub struct WorkerPool {
    /// Number of slots the pool was created with.
    worker_count: usize,
    /// One flag per slot; `true` means the slot is free.
    available_workers: Vec<bool>,
}

impl WorkerPool {
    /// Creates a pool of `worker_count` slots, all initially free.
    pub fn new(worker_count: usize) -> Self {
        let available_workers = vec![true; worker_count];
        WorkerPool {
            worker_count,
            available_workers,
        }
    }

    /// Counts the slots that are currently free.
    pub fn available_count(&self) -> usize {
        self.available_workers
            .iter()
            .map(|&is_free| usize::from(is_free))
            .sum()
    }

    /// Claims the lowest-numbered free slot and returns its index, or `None`
    /// when every slot is taken.
    pub fn reserve_worker(&mut self) -> Option<usize> {
        let slot = self.available_workers.iter().position(|&is_free| is_free)?;
        self.available_workers[slot] = false;
        Some(slot)
    }

    /// Marks slot `worker_id` as free again; out-of-range ids are ignored.
    pub fn release_worker(&mut self, worker_id: usize) {
        if worker_id >= self.worker_count {
            return;
        }
        self.available_workers[worker_id] = true;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A bridge built with `new` has caching and web workers switched on.
    #[test]
    fn test_wasm_bridge_creation() {
        let config = WasmBridge::new().config;
        assert!(config.enable_caching);
        assert!(config.enable_web_workers);
    }

    /// Exporting an F32 tensor and importing the result preserves shape and data.
    #[test]
    fn test_tensor_export_import() -> Result<()> {
        let bridge = WasmBridge::new();
        let source = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let TensorExport { data, shape, dtype } = bridge.export_tensor(&source)?;
        let round_tripped = bridge.import_tensor(TensorImport { data, shape, dtype })?;

        assert_eq!(source.shape(), round_tripped.shape());
        assert_eq!(source.to_vec().unwrap(), round_tripped.to_vec().unwrap());
        Ok(())
    }

    /// An F32 tensor converts to a `Float32` buffer with identical values.
    #[test]
    fn test_typed_array_interface() -> Result<()> {
        let interface = TypedArrayInterface;
        let tensor = Tensor::from_data(
            vec![1.0, -2.0, 3.5],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        match interface.tensor_to_typed_array(&tensor)? {
            TypedArrayData::Float32(values) => assert_eq!(values, vec![1.0, -2.0, 3.5]),
            _ => panic!("Expected Float32 array"),
        }
        Ok(())
    }

    /// A stored payload can be read back and shows up in the statistics.
    #[tokio::test]
    async fn test_indexeddb_cache() -> Result<()> {
        let payload = vec![1, 2, 3, 4, 5];
        let mut cache = IndexedDbCache::new(1024);

        cache.store("test_key", &payload).await?;
        assert_eq!(cache.retrieve("test_key").await, Some(payload));

        let CacheStats {
            entry_count,
            total_size,
            ..
        } = cache.get_stats();
        assert_eq!(entry_count, 1);
        assert_eq!(total_size, 5);
        Ok(())
    }

    /// The suggested batch size always lies between 1 and 64.
    #[test]
    fn test_optimal_batch_size() {
        let suggested = TypedArrayInterface.get_optimal_batch_size(1000);
        assert!(suggested > 0);
        assert!(suggested <= 64);
    }

    /// Reserving hands out the lowest free slot; releasing frees it again.
    #[tokio::test]
    async fn test_worker_pool() {
        let mut pool = WorkerPool::new(4);
        assert_eq!(pool.available_count(), 4);

        assert_eq!(pool.reserve_worker(), Some(0));
        assert_eq!(pool.available_count(), 3);

        assert_eq!(pool.reserve_worker(), Some(1));
        assert_eq!(pool.available_count(), 2);

        pool.release_worker(0);
        assert_eq!(pool.available_count(), 3);
    }
}