#[cfg(all(target_arch = "wasm32", feature = "web"))]
use crate::{MemoryOptimization, MobileConfig, MobileStats};
#[cfg(all(target_arch = "wasm32", feature = "web"))]
use js_sys::{Array, ArrayBuffer, Promise, Reflect, Uint8Array};
#[cfg(all(target_arch = "wasm32", feature = "web"))]
use serde::{Deserialize, Serialize};
#[cfg(all(target_arch = "wasm32", feature = "web"))]
use std::collections::HashMap;
use trustformers_core::error::{CoreError, Result};
#[cfg(all(target_arch = "wasm32", feature = "web"))]
use trustformers_core::Tensor;
#[cfg(all(target_arch = "wasm32", feature = "web"))]
use wasm_bindgen::prelude::*;
#[cfg(all(target_arch = "wasm32", feature = "web"))]
use wasm_bindgen_futures::JsFuture;
#[cfg(all(target_arch = "wasm32", feature = "web"))]
use web_sys::{console, window, Navigator, Performance, WorkerGlobalScope};
#[cfg(all(target_arch = "wasm32", feature = "web"))]
#[wasm_bindgen]
/// Browser-side inference engine exported to JavaScript via `wasm_bindgen`.
///
/// Holds the runtime configuration, the (lazily loaded) model weights,
/// rolling statistics, the capabilities detected from the host browser,
/// and an optional Web Worker pool.
pub struct WasmMobileEngine {
config: WasmMobileConfig,
// `None` until `load_model` has been called successfully.
model_weights: Option<HashMap<String, Tensor>>,
stats: WasmMobileStats,
// Snapshot of capabilities taken once at construction time.
browser_info: BrowserInfo,
// `Some` only when workers were both requested and available at init.
worker_pool: Option<WorkerPool>,
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// JSON-deserializable engine configuration (see `WasmMobileEngine::new`).
pub struct WasmMobileConfig {
/// Run inference on a Web Worker pool when the browser supports it.
pub use_web_workers: bool,
/// Requested pool size; capped by `hardware_concurrency` at init.
pub num_workers: usize,
/// Prefer WebGL acceleration when available.
pub use_webgl: bool,
/// Prefer WebGPU acceleration when available.
pub use_webgpu: bool,
/// Use WASM SIMD instructions when the browser supports them.
pub enable_simd: bool,
/// Memory/speed trade-off strategy (shared with the native mobile path).
pub memory_optimization: MemoryOptimization,
/// Soft memory budget in megabytes.
pub max_memory_mb: usize,
/// Enable streaming model loading/inference.
pub enable_streaming: bool,
/// Maximum number of samples per inference batch.
pub batch_size: usize,
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Capabilities detected (or estimated) from the hosting browser.
///
/// NOTE(review): several of these flags are filled with hard-coded
/// placeholders by `detect_browser_info`, not real feature probes —
/// see that function before trusting individual fields.
pub struct BrowserInfo {
pub user_agent: String,
/// Estimated device memory in MB; a heuristic, not a browser reading.
pub memory_mb: Option<usize>,
/// `navigator.hardwareConcurrency`, clamped to at least 1.
pub hardware_concurrency: usize,
pub has_webgl: bool,
pub has_webgl2: bool,
pub has_webgpu: bool,
pub has_simd: bool,
pub has_web_workers: bool,
pub has_service_workers: bool,
/// Derived from user-agent substrings ("mobile"/"android"/"iphone").
pub is_mobile: bool,
/// True when `navigator.maxTouchPoints > 0`.
pub has_touch: bool,
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Rolling runtime statistics, serialized to JSON by `get_stats`.
pub struct WasmMobileStats {
pub total_inferences: usize,
/// Exponential moving average of per-inference latency (alpha = 0.1).
pub avg_inference_time_ms: f32,
pub memory_usage_mb: usize,
pub worker_utilization: f32,
pub gpu_utilization: f32,
pub model_load_time_ms: f32,
pub wasm_compilation_time_ms: f32,
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
/// Bookkeeping for a pool of Web Workers.
///
/// NOTE(review): `init_worker_pool` currently never spawns workers, so
/// `workers` stays empty and `task_queue` is never drained — confirm
/// intended behavior before depending on the pool.
struct WorkerPool {
workers: Vec<web_sys::Worker>,
task_queue: Vec<WorkerTask>,
// Indices into `workers` that are currently idle.
available_workers: Vec<usize>,
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
/// A unit of work queued for the worker pool: serialized input bytes plus
/// the callback to invoke with the serialized output bytes.
struct WorkerTask {
input_data: Vec<u8>,
callback: Box<dyn FnOnce(Vec<u8>)>,
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
#[wasm_bindgen]
impl WasmMobileEngine {
    /// Builds an engine from a JSON-encoded `WasmMobileConfig`, detects
    /// browser capabilities, and prepares the worker pool when both the
    /// config asks for it and the browser exposes the Worker API.
    #[wasm_bindgen(constructor)]
    pub fn new(config_json: &str) -> Result<WasmMobileEngine, JsValue> {
        // Route Rust panics to the browser console instead of an opaque abort.
        console_error_panic_hook::set_once();
        let config: WasmMobileConfig = serde_json::from_str(config_json)
            .map_err(|e| JsValue::from_str(&format!("Config parse error: {}", e)))?;
        let browser_info = Self::detect_browser_info()
            .map_err(|e| JsValue::from_str(&format!("Browser detection error: {}", e)))?;
        let stats = WasmMobileStats::new();
        let mut engine = Self {
            config,
            model_weights: None,
            stats,
            browser_info,
            worker_pool: None,
        };
        if engine.config.use_web_workers && engine.browser_info.has_web_workers {
            engine
                .init_worker_pool()
                .map_err(|e| JsValue::from_str(&format!("Worker pool init error: {}", e)))?;
        }
        Ok(engine)
    }
    /// Loads model weights from an `ArrayBuffer` and records the load time.
    ///
    /// # Errors
    /// Returns a `JsValue` error string when weight parsing fails.
    #[wasm_bindgen]
    pub async fn load_model(&mut self, model_data: ArrayBuffer) -> Result<(), JsValue> {
        let start_time = Self::get_performance_now();
        let uint8_array = Uint8Array::new(&model_data);
        let data_vec = uint8_array.to_vec();
        console::log_1(&JsValue::from_str(&format!(
            "Loading WASM model ({} bytes)",
            data_vec.len()
        )));
        let weights = self
            .parse_model_weights(&data_vec)
            .map_err(|e| JsValue::from_str(&format!("Model parse error: {}", e)))?;
        self.model_weights = Some(weights);
        let load_time = Self::get_performance_now() - start_time;
        self.stats.model_load_time_ms = load_time;
        console::log_1(&JsValue::from_str(&format!(
            "WASM model loaded in {:.2}ms",
            load_time
        )));
        Ok(())
    }
    /// Runs one inference over serialized input bytes and returns serialized
    /// output bytes, updating latency statistics on success.
    ///
    /// # Errors
    /// Fails when the model is not loaded, or when parsing / inference /
    /// serialization fails.
    #[wasm_bindgen]
    pub async fn inference(&mut self, input_data: ArrayBuffer) -> Result<ArrayBuffer, JsValue> {
        if self.model_weights.is_none() {
            return Err(JsValue::from_str("Model not loaded"));
        }
        let start_time = Self::get_performance_now();
        let uint8_array = Uint8Array::new(&input_data);
        let input_vec = uint8_array.to_vec();
        let input_tensors = self
            .parse_input_data(&input_vec)
            .map_err(|e| JsValue::from_str(&format!("Input parse error: {}", e)))?;
        // Worker path only when workers were requested AND actually created.
        let output_tensors = if self.config.use_web_workers && self.worker_pool.is_some() {
            self.inference_with_workers(&input_tensors).await
        } else {
            self.inference_single_threaded(&input_tensors)
        }
        .map_err(|e| JsValue::from_str(&format!("Inference error: {}", e)))?;
        let output_data = self
            .serialize_output(&output_tensors)
            .map_err(|e| JsValue::from_str(&format!("Output serialize error: {}", e)))?;
        let inference_time = Self::get_performance_now() - start_time;
        self.stats.update_inference(inference_time);
        let output_array = Uint8Array::from(&output_data[..]);
        Ok(output_array.buffer())
    }
    /// Current statistics serialized as JSON ("" on serialization failure).
    #[wasm_bindgen]
    pub fn get_stats(&self) -> String {
        serde_json::to_string(&self.stats).unwrap_or_default()
    }
    /// Detected browser info serialized as JSON ("" on serialization failure).
    #[wasm_bindgen]
    pub fn get_browser_info(&self) -> String {
        serde_json::to_string(&self.browser_info).unwrap_or_default()
    }
    /// Mutates the config in place to match the detected browser: disables
    /// unavailable features, tightens limits on mobile, and reduces worker
    /// count on low-core machines.
    #[wasm_bindgen]
    pub fn optimize_for_browser(&mut self) -> Result<(), JsValue> {
        if !self.browser_info.has_web_workers {
            self.config.use_web_workers = false;
            self.config.num_workers = 0;
        }
        // NOTE(review): lacking WebGL2 also disables plain WebGL here —
        // confirm that WebGL1-only browsers should lose acceleration.
        if !self.browser_info.has_webgl2 && !self.browser_info.has_webgpu {
            self.config.use_webgl = false;
            self.config.use_webgpu = false;
        }
        if !self.browser_info.has_simd {
            self.config.enable_simd = false;
        }
        if self.browser_info.is_mobile {
            self.config.max_memory_mb = self.config.max_memory_mb.min(512);
            self.config.batch_size = self.config.batch_size.min(2);
            self.config.memory_optimization = MemoryOptimization::Maximum;
        }
        if self.browser_info.hardware_concurrency <= 2 {
            self.config.num_workers = 1;
        }
        console::log_1(&JsValue::from_str(&format!(
            "Optimized WASM config for {} (mobile: {})",
            self.browser_info.user_agent.split_whitespace().next().unwrap_or("Unknown"),
            self.browser_info.is_mobile
        )));
        Ok(())
    }
    /// Probes the browser environment for capabilities relevant to inference.
    ///
    /// NOTE(review): `has_webgl`, `has_webgl2`, `has_simd`, and
    /// `has_service_workers` are hard-coded placeholders, and `memory_mb`
    /// is a heuristic — these are not real feature probes.
    fn detect_browser_info() -> Result<BrowserInfo, Box<dyn std::error::Error>> {
        let window = window().ok_or("No window object")?;
        let navigator = window.navigator();
        let user_agent = navigator.user_agent().unwrap_or_default();
        let hardware_concurrency = navigator.hardware_concurrency() as usize;
        // Lowercase once instead of once per substring test.
        let ua_lower = user_agent.to_lowercase();
        let is_mobile = ua_lower.contains("mobile")
            || ua_lower.contains("android")
            || ua_lower.contains("iphone");
        // FIX: web_sys `Window` has no `worker()` accessor; probe Worker
        // support by checking for the global `Worker` constructor instead.
        let has_web_workers =
            Reflect::has(window.as_ref(), &JsValue::from_str("Worker")).unwrap_or(false);
        let has_touch = window.navigator().max_touch_points() > 0;
        // Rough budget estimate — browsers do not expose real device memory.
        let memory_mb = if is_mobile { Some(2048) } else { Some(8192) };
        Ok(BrowserInfo {
            user_agent,
            memory_mb,
            hardware_concurrency: hardware_concurrency.max(1),
            has_webgl: true,
            has_webgl2: true,
            has_webgpu: false,
            has_simd: true,
            has_web_workers,
            has_service_workers: true,
            is_mobile,
            has_touch,
        })
    }
    /// Millisecond timestamp from `performance.now()`, or 0.0 when the
    /// window or Performance object is unavailable.
    fn get_performance_now() -> f32 {
        if let Some(window) = window() {
            // FIX: `Window::performance()` returns `Option<Performance>`,
            // not `Result`, so match `Some` rather than `Ok`.
            if let Some(performance) = window.performance() {
                return performance.now() as f32;
            }
        }
        0.0
    }
    /// Initializes worker-pool bookkeeping.
    ///
    /// NOTE(review): no `web_sys::Worker` is actually spawned yet — the
    /// loop only logs; `workers` remains empty by design of this stub.
    fn init_worker_pool(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        // Never request more workers than the machine reports cores.
        let num_workers = self.config.num_workers.min(self.browser_info.hardware_concurrency);
        let workers = Vec::with_capacity(num_workers);
        for i in 0..num_workers {
            console::log_1(&JsValue::from_str(&format!("Creating worker {}", i)));
        }
        self.worker_pool = Some(WorkerPool {
            workers,
            task_queue: Vec::new(),
            available_workers: (0..num_workers).collect(),
        });
        Ok(())
    }
    /// Stub weight parser: ignores the bytes and returns a single 10x10
    /// placeholder tensor under the key "layer1".
    fn parse_model_weights(
        &self,
        _data: &[u8],
    ) -> Result<HashMap<String, Tensor>, Box<dyn std::error::Error>> {
        let mut weights = HashMap::new();
        weights.insert("layer1".to_string(), Tensor::ones(&[10, 10])?);
        Ok(weights)
    }
    /// Stub input parser: ignores the bytes and returns a single 1x10
    /// placeholder tensor under the key "input".
    fn parse_input_data(
        &self,
        _data: &[u8],
    ) -> Result<HashMap<String, Tensor>, Box<dyn std::error::Error>> {
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), Tensor::ones(&[1, 10])?);
        Ok(inputs)
    }
    /// Worker-pool inference path; currently delegates to the
    /// single-threaded implementation (workers are not yet wired up).
    async fn inference_with_workers(
        &mut self,
        input: &HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Tensor>, Box<dyn std::error::Error>> {
        self.inference_single_threaded(input)
    }
    /// Stub single-threaded inference: echoes the "input" tensor back as
    /// "output"; produces an empty map when no "input" key is present.
    fn inference_single_threaded(
        &self,
        input: &HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Tensor>, Box<dyn std::error::Error>> {
        let mut output = HashMap::new();
        if let Some(input_tensor) = input.get("input") {
            let output_tensor = input_tensor.clone();
            output.insert("output".to_string(), output_tensor);
        }
        Ok(output)
    }
    /// Stub serializer: ignores the tensors and returns 32 zero bytes.
    fn serialize_output(
        &self,
        _output: &HashMap<String, Tensor>,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        Ok(vec![0u8; 32])
    }
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
impl Default for WasmMobileConfig {
    /// Conservative browser defaults: two workers with Web Workers on,
    /// WebGL on / WebGPU off, SIMD on, balanced memory strategy capped at
    /// 512 MB, streaming enabled, single-sample batches.
    fn default() -> Self {
        WasmMobileConfig {
            batch_size: 1,
            enable_simd: true,
            enable_streaming: true,
            max_memory_mb: 512,
            memory_optimization: MemoryOptimization::Balanced,
            num_workers: 2,
            use_web_workers: true,
            use_webgl: true,
            use_webgpu: false,
        }
    }
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
impl WasmMobileStats {
    /// Zero-initialized statistics.
    fn new() -> Self {
        WasmMobileStats {
            total_inferences: 0,
            avg_inference_time_ms: 0.0,
            memory_usage_mb: 0,
            worker_utilization: 0.0,
            gpu_utilization: 0.0,
            model_load_time_ms: 0.0,
            wasm_compilation_time_ms: 0.0,
        }
    }
    /// Records one inference, folding its latency into an exponential
    /// moving average; the very first sample seeds the average directly.
    fn update_inference(&mut self, inference_time_ms: f32) {
        // EMA smoothing factor: weight given to the newest sample.
        const ALPHA: f32 = 0.1;
        self.total_inferences += 1;
        self.avg_inference_time_ms = if self.total_inferences == 1 {
            inference_time_ms
        } else {
            ALPHA * inference_time_ms + (1.0 - ALPHA) * self.avg_inference_time_ms
        };
    }
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
/// Translates a native `MobileConfig` into browser-oriented WASM settings.
///
/// Workers are enabled only when batching is on and more than one thread
/// is requested; the worker count is clamped to the 1..=4 range.
pub fn mobile_config_to_wasm(mobile_config: &MobileConfig) -> WasmMobileConfig {
    let wants_workers = mobile_config.enable_batching && mobile_config.num_threads > 1;
    WasmMobileConfig {
        use_web_workers: wants_workers,
        num_workers: mobile_config.num_threads.clamp(1, 4),
        use_webgl: true,
        use_webgpu: false,
        enable_simd: true,
        memory_optimization: mobile_config.memory_optimization,
        max_memory_mb: mobile_config.max_memory_mb,
        enable_streaming: true,
        batch_size: mobile_config.max_batch_size,
    }
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
#[wasm_bindgen]
extern "C" {
    /// Direct binding to the browser's `console.log`.
    ///
    /// NOTE(review): unused — the rest of this file logs via
    /// `web_sys::console::log_1`; confirm whether this binding is needed.
    #[wasm_bindgen(js_namespace = console)]
    fn log(s: &str);
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
#[wasm_bindgen]
/// Convenience JS entry point: constructs an engine from a JSON config and
/// immediately adapts the config to the detected browser.
///
/// # Errors
/// Propagates construction or optimization failures as `JsValue` strings.
pub async fn create_mobile_engine(config_json: &str) -> Result<WasmMobileEngine, JsValue> {
let mut engine = WasmMobileEngine::new(config_json)?;
engine.optimize_for_browser()?;
Ok(engine)
}
#[cfg(all(target_arch = "wasm32", feature = "web"))]
#[wasm_bindgen]
/// Returns the detected browser capabilities as a JSON string.
///
/// Always yields valid JSON: previously a serialization failure returned
/// an empty string (invalid JSON for JS callers) while a detection failure
/// returned "{}"; both failure paths now consistently return "{}".
pub fn get_browser_capabilities() -> String {
    WasmMobileEngine::detect_browser_info()
        .ok()
        .and_then(|info| serde_json::to_string(&info).ok())
        .unwrap_or_else(|| "{}".to_string())
}
#[cfg(not(all(target_arch = "wasm32", feature = "web")))]
/// Placeholder so the type name exists on non-WASM targets.
pub struct WasmMobileEngine;
#[cfg(not(all(target_arch = "wasm32", feature = "web")))]
impl WasmMobileEngine {
    /// Always fails: the real engine is only available when compiled to
    /// wasm32 with the "web" feature.
    ///
    /// NOTE(review): `Result` here is the `trustformers_core::error::Result`
    /// alias imported at the top of this file; using it with two type
    /// parameters requires the alias to declare a defaultable error type —
    /// confirm against trustformers_core.
    pub fn new(_config: &str) -> Result<Self, Box<dyn std::error::Error>> {
        Err("WebAssembly features only available when compiled to WASM".into())
    }
}
#[cfg(test)]
mod tests {
    // On non-WASM targets the gated items below vanish, which can leave
    // this glob import unused — that is expected.
    #![allow(unused_imports)]
    use super::*;
    /// Verifies `WasmMobileConfig::default` values.
    ///
    /// FIX: the `#[cfg]` gate is on the test function (not inside its
    /// body), so on native targets the test is absent instead of passing
    /// vacuously with an empty body.
    #[cfg(all(target_arch = "wasm32", feature = "web"))]
    #[test]
    fn test_wasm_config_defaults() {
        let config = WasmMobileConfig::default();
        assert!(config.use_web_workers);
        assert_eq!(config.num_workers, 2);
        assert!(config.use_webgl);
        assert!(!config.use_webgpu);
        assert!(config.enable_simd);
        assert_eq!(config.max_memory_mb, 512);
    }
    /// Verifies `mobile_config_to_wasm` field mapping, including the
    /// worker-count clamp and batching-driven worker enablement.
    #[cfg(all(target_arch = "wasm32", feature = "web"))]
    #[test]
    fn test_mobile_to_wasm_config_conversion() {
        let mobile_config = crate::MobileConfig {
            memory_optimization: MemoryOptimization::Maximum,
            max_memory_mb: 256,
            num_threads: 4,
            enable_batching: true,
            max_batch_size: 2,
            ..Default::default()
        };
        let wasm_config = mobile_config_to_wasm(&mobile_config);
        assert_eq!(wasm_config.memory_optimization, MemoryOptimization::Maximum);
        assert_eq!(wasm_config.max_memory_mb, 256);
        assert_eq!(wasm_config.num_workers, 4);
        assert!(wasm_config.use_web_workers);
        assert_eq!(wasm_config.batch_size, 2);
    }
}