#[cfg(target_arch = "wasm32")]
use js_sys::Float32Array;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
use std::collections::HashMap;
/// Wrapper around a loaded `crate::session::Session` for use from JavaScript.
///
/// On `wasm32` this type is exported via `wasm_bindgen`. On other targets it is
/// a plain Rust struct whose inference methods return `String` errors.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
pub struct WasmSession {
/// The underlying session used for model loading (`from_bytes`) and inference (`run`).
inner: crate::session::Session,
}
/// Result type for code shared between targets (e.g. `WasmSession::new`).
/// On wasm32 the error is a `JsValue`, so wasm-bindgen can hand it to JavaScript.
#[cfg(target_arch = "wasm32")]
type WasmResult<T> = Result<T, JsValue>;
/// Native counterpart of `WasmResult`: the error is a plain `String` message.
#[cfg(not(target_arch = "wasm32"))]
type WasmResult<T> = Result<T, String>;
/// Converts an error message into the wasm32 error type (a JS string value).
#[cfg(target_arch = "wasm32")]
fn to_err(msg: String) -> JsValue {
JsValue::from_str(&msg)
}
/// Converts an error message into the native error type; this is the identity.
#[cfg(not(target_arch = "wasm32"))]
fn to_err(msg: String) -> String {
msg
}
/// Splits a flat list of dimensions into one shape per tensor.
///
/// Any negative entry (not only `-1`) separates two shapes. Each separator
/// closes the shape collected so far, even when that shape is empty, which
/// yields a scalar `[]` shape. Dimensions after the last separator become a
/// final shape only if at least one follows it, so a trailing separator is
/// optional and an empty input produces no shapes.
///
/// ```text
/// [1, 3, -1, 2, 2]  => [[1, 3], [2, 2]]
/// [1, 2, -1]        => [[1, 2]]
/// [-1, 3, 4]        => [[], [3, 4]]
/// ```
fn parse_shapes(flat: &[i32]) -> Vec<Vec<usize>> {
    let mut segments: Vec<Vec<usize>> = flat
        .split(|&entry| entry < 0)
        .map(|segment| segment.iter().map(|&entry| entry as usize).collect())
        .collect();
    // `split` always yields the segment after the final separator (or the
    // single empty segment of an empty input); an empty tail is not a shape.
    if matches!(segments.last(), Some(tail) if tail.is_empty()) {
        segments.pop();
    }
    segments
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
impl WasmSession {
#[cfg_attr(target_arch = "wasm32", wasm_bindgen(constructor))]
pub fn new(model_bytes: &[u8]) -> WasmResult<WasmSession> {
let session =
crate::session::Session::from_bytes(model_bytes).map_err(|e| to_err(format!("{e}")))?;
Ok(Self { inner: session })
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen(getter))]
pub fn input_names(&self) -> String {
self.inner.input_names().join(",")
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen(getter))]
pub fn output_names(&self) -> String {
self.inner.output_names().join(",")
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen(getter))]
pub fn node_count(&self) -> usize {
self.inner.model_info().node_count
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen(getter))]
pub fn parameter_count(&self) -> usize {
self.inner.model_info().parameter_count
}
}
/// Target-independent inference helpers shared by the wasm32 bindings and the
/// native API, so both paths validate inputs and select outputs identically.
impl WasmSession {
    /// Runs the model with a single named input and passes the data of the
    /// primary output to `emit`, returning whatever `emit` produces.
    ///
    /// The primary output is the first name in `output_names()` that the run
    /// actually produced. If none of the declared names is present, an
    /// arbitrary produced output is used instead.
    ///
    /// # Errors
    /// Returns the session's error text if inference fails, or
    /// `"No outputs produced"` if the run yields no outputs.
    fn run_one_core<T>(
        &self,
        input_name: &str,
        data: &[f32],
        shape: &[usize],
        emit: impl FnOnce(&[f32]) -> T,
    ) -> Result<T, String> {
        let tensor = oxionnx_core::Tensor::new(data.to_vec(), shape.to_vec());
        let mut inputs = HashMap::new();
        inputs.insert(input_name, tensor);
        let outputs = self.inner.run(&inputs).map_err(|e| format!("{e}"))?;
        // The iteration order of `outputs` is not guaranteed to match the
        // declared output order (e.g. if it is a `HashMap`), so resolve the
        // first output by declared name rather than by `values().next()`.
        let output_names = self.inner.output_names();
        let first_output = output_names
            .iter()
            .find_map(|name| outputs.get(name.as_str()))
            .or_else(|| outputs.values().next())
            .ok_or_else(|| "No outputs produced".to_string())?;
        Ok(emit(first_output.data.as_slice()))
    }

    /// Runs the model with several inputs packed into flat buffers and returns
    /// the data of every declared output, concatenated in `output_names()` order.
    ///
    /// * `input_names` - comma-separated input names.
    /// * `input_data` - all input tensors' data, back to back, in name order.
    /// * `input_shapes_flat` - the shapes, separated by negative values
    ///   (see `parse_shapes`).
    ///
    /// Declared outputs missing from the run result are skipped. Data past the
    /// last declared shape is ignored.
    ///
    /// # Errors
    /// * `"Name count N != shape count M"` if names and shapes disagree.
    /// * `"Input data too short for declared shapes"` if `input_data` cannot
    ///   cover every shape, including when a shape's element count overflows
    ///   `usize`.
    /// * The session's error text if inference fails.
    fn run_multi_core(
        &self,
        input_names: &str,
        input_data: &[f32],
        input_shapes_flat: &[i32],
    ) -> Result<Vec<f32>, String> {
        let names: Vec<&str> = input_names.split(',').collect();
        let shapes = parse_shapes(input_shapes_flat);
        if names.len() != shapes.len() {
            return Err(format!(
                "Name count {} != shape count {}",
                names.len(),
                shapes.len()
            ));
        }
        let mut inputs = HashMap::with_capacity(names.len());
        let mut remaining = input_data;
        for (name, shape) in names.into_iter().zip(shapes) {
            // Shapes come from the caller. Use checked arithmetic so an
            // overflowing element count (easy with 32-bit `usize` on wasm32)
            // is rejected instead of wrapping to a small, wrong size.
            let size = shape
                .iter()
                .try_fold(1usize, |acc, &dim| acc.checked_mul(dim));
            let chunk = match size {
                Some(size) if size <= remaining.len() => {
                    let (chunk, rest) = remaining.split_at(size);
                    remaining = rest;
                    chunk
                }
                _ => return Err("Input data too short for declared shapes".to_string()),
            };
            // Later duplicates of a name overwrite earlier ones, as before.
            inputs.insert(name, oxionnx_core::Tensor::new(chunk.to_vec(), shape));
        }
        let outputs = self.inner.run(&inputs).map_err(|e| format!("{e}"))?;
        let mut all_data: Vec<f32> = Vec::new();
        for name in self.inner.output_names() {
            if let Some(t) = outputs.get(name.as_str()) {
                all_data.extend_from_slice(&t.data);
            }
        }
        Ok(all_data)
    }
}

#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl WasmSession {
    /// Runs the model with one input and returns the primary output's data
    /// (the first declared output that was produced) as a `Float32Array`.
    ///
    /// # Errors
    /// Returns `Err` (thrown into JavaScript by wasm-bindgen) if inference
    /// fails or produces no outputs.
    #[wasm_bindgen]
    pub fn run_one(
        &self,
        input_name: &str,
        data: &[f32],
        shape: &[usize],
    ) -> Result<Float32Array, JsValue> {
        self.run_one_core(input_name, data, shape, to_float32_array)
            .map_err(to_err)
    }

    /// Runs the model with multiple packed inputs and returns all declared
    /// outputs' data concatenated into one `Float32Array`.
    ///
    /// `input_names` is comma-separated. `input_shapes_flat` separates shapes
    /// with negative values.
    ///
    /// # Errors
    /// Returns `Err` (thrown into JavaScript by wasm-bindgen) on a name/shape
    /// count mismatch, insufficient input data, or inference failure.
    #[wasm_bindgen]
    pub fn run_multi(
        &self,
        input_names: &str,
        input_data: &[f32],
        input_shapes_flat: &[i32],
    ) -> Result<Float32Array, JsValue> {
        self.run_multi_core(input_names, input_data, input_shapes_flat)
            .map(|all_data| to_float32_array(&all_data))
            .map_err(to_err)
    }
}

/// Copies `values` into a newly allocated JS `Float32Array`.
#[cfg(target_arch = "wasm32")]
fn to_float32_array(values: &[f32]) -> Float32Array {
    // `usize` is 32 bits on wasm32, so this cast cannot truncate.
    let array = Float32Array::new_with_length(values.len() as u32);
    array.copy_from(values);
    array
}

#[cfg(not(target_arch = "wasm32"))]
impl WasmSession {
    /// Native equivalent of `run_one`: returns the primary output's data
    /// (the first declared output that was produced) as a `Vec<f32>`.
    ///
    /// # Errors
    /// Returns an error message if inference fails or produces no outputs.
    pub fn run_one_native(
        &self,
        input_name: &str,
        data: &[f32],
        shape: &[usize],
    ) -> Result<Vec<f32>, String> {
        self.run_one_core(input_name, data, shape, |values| values.to_vec())
    }

    /// Native equivalent of `run_multi`: returns all declared outputs' data
    /// concatenated in `output_names()` order.
    ///
    /// # Errors
    /// Returns an error message on a name/shape count mismatch, insufficient
    /// input data, or inference failure.
    pub fn run_multi_native(
        &self,
        input_names: &str,
        input_data: &[f32],
        input_shapes_flat: &[i32],
    ) -> Result<Vec<f32>, String> {
        self.run_multi_core(input_names, input_data, input_shapes_flat)
    }
}
/// Module start hook. `#[wasm_bindgen(start)]` makes wasm-bindgen call this
/// automatically when the module is instantiated. It currently does no setup.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen(start)]
pub fn wasm_init() {
// Intentionally empty; presumably a placeholder for one-time initialization.
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_shapes_basic() {
        let flat: Vec<i32> = vec![1, 3, -1, 1, 5, -1, 2, 2];
        let shapes = parse_shapes(&flat);
        assert_eq!(shapes, vec![vec![1, 3], vec![1, 5], vec![2, 2]]);
    }

    #[test]
    fn test_parse_shapes_single() {
        let flat: Vec<i32> = vec![2, 3, 4];
        let shapes = parse_shapes(&flat);
        assert_eq!(shapes, vec![vec![2, 3, 4]]);
    }

    #[test]
    fn test_parse_shapes_empty() {
        let flat: Vec<i32> = vec![];
        let shapes = parse_shapes(&flat);
        assert!(shapes.is_empty());
    }

    #[test]
    fn test_parse_shapes_leading_separator() {
        let flat: Vec<i32> = vec![-1, 3, 4];
        let shapes = parse_shapes(&flat);
        assert_eq!(shapes, vec![vec![], vec![3, 4]]);
    }

    #[test]
    fn test_parse_shapes_trailing_separator() {
        let flat: Vec<i32> = vec![1, 2, -1];
        let shapes = parse_shapes(&flat);
        assert_eq!(shapes, vec![vec![1, 2]]);
    }

    /// Two adjacent separators enclose an empty (scalar) shape.
    #[test]
    fn test_parse_shapes_consecutive_separators() {
        let flat: Vec<i32> = vec![1, -1, -1, 2];
        let shapes = parse_shapes(&flat);
        assert_eq!(shapes, vec![vec![1], vec![], vec![2]]);
    }

    /// A scalar shape in last position needs an explicit trailing separator.
    #[test]
    fn test_parse_shapes_trailing_scalar() {
        let flat: Vec<i32> = vec![1, 2, -1, -1];
        let shapes = parse_shapes(&flat);
        assert_eq!(shapes, vec![vec![1, 2], vec![]]);
    }

    /// Any negative value separates shapes, and zero is kept as a dimension.
    #[test]
    fn test_parse_shapes_any_negative_and_zero_dim() {
        let flat: Vec<i32> = vec![0, 3, -7, 4];
        let shapes = parse_shapes(&flat);
        assert_eq!(shapes, vec![vec![0, 3], vec![4]]);
    }
}