use crate::system::{HasSystemState, SystemHostFunctions, SystemState};
use crate::wasi::WasiContext;
use mielin_hal::capabilities::HardwareCapabilities;
use mielin_tensor::{Tensor, TensorRuntime};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use wasmtime::{Caller, Linker};
/// Host-side state used as the store data type for `Linker<HostState>` and
/// `Caller<'_, HostState>` in this file.
///
/// It bundles the tensor runtime, the table of live tensors, the system
/// state, and an optional WASI context.
pub struct HostState {
    /// Tensor runtime built from the `HardwareCapabilities` given at construction.
    tensor_runtime: Arc<TensorRuntime>,
    /// Table of tensors addressed by integer id, behind a mutex.
    tensors: Arc<Mutex<TensorStore>>,
    /// System state, created with or without an explicit seed.
    system_state: SystemState,
    /// WASI context; `None` unless one was supplied to the constructor.
    wasi_context: Option<Arc<WasiContext>>,
}

impl HostState {
    /// Shared constructor behind all public `new` / `with_*` variants.
    ///
    /// Fields are initialised in declaration order. `seed` selects between
    /// `SystemState::with_seed` and `SystemState::new`; `wasi`, when present,
    /// is wrapped in an `Arc`.
    fn assemble(
        capabilities: HardwareCapabilities,
        seed: Option<u64>,
        wasi: Option<WasiContext>,
    ) -> Self {
        let tensor_runtime = Arc::new(TensorRuntime::new(capabilities));
        let tensors = Arc::new(Mutex::new(TensorStore::new()));
        let system_state = match seed {
            Some(value) => SystemState::with_seed(value),
            None => SystemState::new(),
        };
        let wasi_context = wasi.map(Arc::new);
        Self {
            tensor_runtime,
            tensors,
            system_state,
            wasi_context,
        }
    }

    /// Creates a state with an unseeded `SystemState` and no WASI context.
    pub fn new(capabilities: HardwareCapabilities) -> Self {
        Self::assemble(capabilities, None, None)
    }

    /// Creates a state whose `SystemState` is built with `seed`; no WASI context.
    pub fn with_seed(capabilities: HardwareCapabilities, seed: u64) -> Self {
        Self::assemble(capabilities, Some(seed), None)
    }

    /// Creates a state that owns `wasi` and uses an unseeded `SystemState`.
    pub fn with_wasi(capabilities: HardwareCapabilities, wasi: WasiContext) -> Self {
        Self::assemble(capabilities, None, Some(wasi))
    }

    /// Creates a state that owns `wasi` and whose `SystemState` is built with `seed`.
    pub fn with_wasi_and_seed(
        capabilities: HardwareCapabilities,
        wasi: WasiContext,
        seed: u64,
    ) -> Self {
        Self::assemble(capabilities, Some(seed), Some(wasi))
    }

    /// Borrows the tensor runtime.
    pub fn tensor_runtime(&self) -> &TensorRuntime {
        &*self.tensor_runtime
    }

    /// Borrows the shared, mutex-protected tensor table.
    pub fn tensors(&self) -> &Arc<Mutex<TensorStore>> {
        &self.tensors
    }

    /// Borrows the system state.
    pub fn system_state(&self) -> &SystemState {
        &self.system_state
    }

    /// Borrows the WASI context, if one was supplied at construction.
    pub fn wasi_context(&self) -> Option<&Arc<WasiContext>> {
        match &self.wasi_context {
            Some(context) => Some(context),
            None => None,
        }
    }
}

impl HasSystemState for HostState {
    /// Trait accessor; returns the same `SystemState` as the inherent getter.
    fn system_state(&self) -> &SystemState {
        &self.system_state
    }
}
/// Table of `Tensor<f32>` values addressed by `u32` ids.
///
/// Id `0` is never handed out, so callers can use it as a "no tensor" /
/// failure value.
pub struct TensorStore {
    /// Live tensors keyed by id.
    tensors: HashMap<u32, Tensor<f32>>,
    /// First candidate id for the next allocation. May be `0` or already in
    /// use after the counter wraps; `allocate` skips such values.
    next_id: u32,
}

impl TensorStore {
    /// Creates an empty store whose first allocated id is `1`.
    pub fn new() -> Self {
        Self {
            tensors: HashMap::new(),
            next_id: 1,
        }
    }

    /// Stores `tensor` and returns the id under which it can be retrieved.
    ///
    /// Ids are handed out sequentially starting at `1`. When the counter
    /// wraps past `u32::MAX`, id `0` and any id still in use are skipped, so
    /// the returned id is never `0` and never replaces a live tensor.
    ///
    /// # Panics
    ///
    /// Panics if every non-zero `u32` id is already in use.
    pub fn allocate(&mut self, tensor: Tensor<f32>) -> u32 {
        // Guarantees the search below terminates: at least one non-zero id is free.
        assert!(
            self.tensors.len() < u32::MAX as usize,
            "tensor id space exhausted"
        );
        let mut id = self.next_id;
        // In normal operation this loop body never runs; it only matters after
        // the counter has wrapped around.
        while id == 0 || self.tensors.contains_key(&id) {
            id = id.wrapping_add(1);
        }
        self.next_id = id.wrapping_add(1);
        self.tensors.insert(id, tensor);
        id
    }

    /// Returns a reference to the tensor stored under `id`, if any.
    pub fn get(&self, id: u32) -> Option<&Tensor<f32>> {
        self.tensors.get(&id)
    }

    /// Removes and returns the tensor stored under `id`, if any.
    pub fn remove(&mut self, id: u32) -> Option<Tensor<f32>> {
        self.tensors.remove(&id)
    }

    /// Drops every stored tensor. The id counter is not reset.
    pub fn clear(&mut self) {
        self.tensors.clear();
    }

    /// Returns `true` when no tensors are stored.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Returns the number of stored tensors.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }
}

impl Default for TensorStore {
    /// Equivalent to [`TensorStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Namespace type for registering the `mielin` tensor host functions on a
/// wasmtime [`Linker`].
pub struct HostFunctions;

impl HostFunctions {
    /// Registers the system host functions and every `mielin.tensor_*` import
    /// on `linker`.
    ///
    /// Guest-visible conventions:
    ///
    /// * `tensor_supports_{sve2,neon,avx2}() -> i32`: `1` if the runtime
    ///   reports the feature, else `0`.
    /// * `tensor_zeros` / `tensor_ones(shape_ptr, shape_len) -> i32`: reads
    ///   `shape_len` little-endian `u32` dimensions from guest memory and
    ///   returns a new tensor handle, or `0` on failure.
    /// * `tensor_dot(a, b, result_ptr) -> i32`: writes the 4-byte
    ///   little-endian result to `result_ptr` and returns `1`, or `0` on failure.
    /// * `tensor_add` / `tensor_matmul(a, b) -> i32`: returns a handle to the
    ///   result tensor, or `0` on failure.
    /// * `tensor_get_shape(id, shape_ptr, max_dims) -> i32`: writes up to
    ///   `max_dims` little-endian `u32` dimensions and returns how many were
    ///   written (`0` on failure).
    /// * `tensor_free(id) -> i32`: `1` if a tensor was removed, else `0`.
    ///
    /// Handles are `u32` ids reinterpreted as `i32`. Guest pointers are
    /// treated as unsigned 32-bit addresses, and every guest memory access is
    /// bounds-checked with overflow-safe arithmetic, so malformed arguments
    /// produce a `0` return instead of a host panic.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SystemHostFunctions::register` or
    /// `Linker::func_wrap`.
    ///
    /// # Panics
    ///
    /// The registered functions panic if the tensor store mutex is poisoned.
    pub fn register(linker: &mut Linker<HostState>) -> anyhow::Result<()> {
        SystemHostFunctions::register(linker)?;

        linker.func_wrap(
            "mielin",
            "tensor_supports_sve2",
            |caller: Caller<'_, HostState>| caller.data().tensor_runtime().supports_sve2() as i32,
        )?;
        linker.func_wrap(
            "mielin",
            "tensor_supports_neon",
            |caller: Caller<'_, HostState>| caller.data().tensor_runtime().supports_neon() as i32,
        )?;
        linker.func_wrap(
            "mielin",
            "tensor_supports_avx2",
            |caller: Caller<'_, HostState>| caller.data().tensor_runtime().supports_avx2() as i32,
        )?;

        linker.func_wrap(
            "mielin",
            "tensor_zeros",
            |mut caller: Caller<'_, HostState>, shape_ptr: i32, shape_len: i32| -> i32 {
                Self::alloc_from_shape(&mut caller, shape_ptr, shape_len, |shape| {
                    Tensor::zeros(shape)
                })
            },
        )?;
        linker.func_wrap(
            "mielin",
            "tensor_ones",
            |mut caller: Caller<'_, HostState>, shape_ptr: i32, shape_len: i32| -> i32 {
                Self::alloc_from_shape(&mut caller, shape_ptr, shape_len, |shape| {
                    Tensor::ones(shape)
                })
            },
        )?;

        linker.func_wrap(
            "mielin",
            "tensor_dot",
            |mut caller: Caller<'_, HostState>, a_id: i32, b_id: i32, result_ptr: i32| -> i32 {
                // Compute under the lock, then release it before touching guest memory.
                let result = {
                    let host = caller.data();
                    let store = host.tensors().lock().expect("Tensor store lock poisoned");
                    let a = match store.get(a_id as u32) {
                        Some(t) => t,
                        None => return 0,
                    };
                    let b = match store.get(b_id as u32) {
                        Some(t) => t,
                        None => return 0,
                    };
                    let product = host.tensor_runtime().ops().dot(a, b);
                    match product {
                        Some(r) => r,
                        None => return 0,
                    }
                };
                let memory = match Self::guest_memory(&mut caller) {
                    Some(mem) => mem,
                    None => return 0,
                };
                let written = Self::write_guest(
                    memory.data_mut(&mut caller),
                    result_ptr,
                    &result.to_le_bytes(),
                );
                i32::from(written)
            },
        )?;

        linker.func_wrap(
            "mielin",
            "tensor_add",
            |caller: Caller<'_, HostState>, a_id: i32, b_id: i32| -> i32 {
                Self::binary_op(&caller, a_id, b_id, |runtime, a, b| runtime.ops().add(a, b))
            },
        )?;
        linker.func_wrap(
            "mielin",
            "tensor_matmul",
            |caller: Caller<'_, HostState>, a_id: i32, b_id: i32| -> i32 {
                Self::binary_op(&caller, a_id, b_id, |runtime, a, b| {
                    runtime.ops().matmul(a, b)
                })
            },
        )?;

        linker.func_wrap(
            "mielin",
            "tensor_get_shape",
            |mut caller: Caller<'_, HostState>,
             tensor_id: i32,
             shape_ptr: i32,
             max_dims: i32|
             -> i32 {
                // Copy the shape out so the lock is released before the memory write.
                let shape: Vec<usize> = {
                    let store = caller
                        .data()
                        .tensors()
                        .lock()
                        .expect("Tensor store lock poisoned");
                    let dims = match store.get(tensor_id as u32) {
                        Some(t) => t.shape().to_vec(),
                        None => return 0,
                    };
                    dims
                };

                // A negative capacity means the guest buffer holds nothing.
                let capacity = if max_dims > 0 { max_dims as usize } else { 0 };
                let dims_to_copy = shape.len().min(capacity);

                // Encode first so a failure never leaves a half-written buffer.
                let mut encoded = Vec::with_capacity(dims_to_copy * 4);
                for &dim in &shape[..dims_to_copy] {
                    // Refuse to silently truncate a dimension that does not fit in u32.
                    if dim > u32::MAX as usize {
                        return 0;
                    }
                    encoded.extend_from_slice(&(dim as u32).to_le_bytes());
                }

                let memory = match Self::guest_memory(&mut caller) {
                    Some(mem) => mem,
                    None => return 0,
                };
                if Self::write_guest(memory.data_mut(&mut caller), shape_ptr, &encoded) {
                    // Bounded by `max_dims`, so this cast cannot overflow.
                    dims_to_copy as i32
                } else {
                    0
                }
            },
        )?;

        linker.func_wrap(
            "mielin",
            "tensor_free",
            |caller: Caller<'_, HostState>, tensor_id: i32| -> i32 {
                let mut store = caller
                    .data()
                    .tensors()
                    .lock()
                    .expect("Tensor store lock poisoned");
                i32::from(store.remove(tensor_id as u32).is_some())
            },
        )?;

        Ok(())
    }

    /// Looks up the calling instance's exported linear memory named `"memory"`.
    fn guest_memory(caller: &mut Caller<'_, HostState>) -> Option<wasmtime::Memory> {
        match caller.get_export("memory") {
            Some(wasmtime::Extern::Memory(mem)) => Some(mem),
            _ => None,
        }
    }

    /// Validates a guest `(ptr, byte_len)` pair against a memory of `mem_len`
    /// bytes and returns the corresponding index range.
    fn guest_range(ptr: i32, byte_len: usize, mem_len: usize) -> Option<std::ops::Range<usize>> {
        // Guest addresses are unsigned 32-bit values carried in an i32;
        // reinterpret the bits instead of sign-extending.
        let start = ptr as u32 as usize;
        let end = start.checked_add(byte_len)?;
        if end <= mem_len {
            Some(start..end)
        } else {
            None
        }
    }

    /// Decodes `shape_len` little-endian `u32` dimensions starting at `shape_ptr`.
    fn read_shape(mem: &[u8], shape_ptr: i32, shape_len: i32) -> Option<Vec<usize>> {
        if shape_len < 0 {
            return None;
        }
        let byte_len = (shape_len as usize).checked_mul(4)?;
        let range = Self::guest_range(shape_ptr, byte_len, mem.len())?;
        Some(
            mem[range]
                .chunks_exact(4)
                .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as usize)
                .collect(),
        )
    }

    /// Copies `bytes` into guest memory at `ptr`; returns `false` if out of bounds.
    fn write_guest(mem: &mut [u8], ptr: i32, bytes: &[u8]) -> bool {
        match Self::guest_range(ptr, bytes.len(), mem.len()) {
            Some(range) => {
                mem[range].copy_from_slice(bytes);
                true
            }
            None => false,
        }
    }

    /// Shared body of `tensor_zeros` / `tensor_ones`: reads a shape from guest
    /// memory, builds a tensor with `make`, and stores it.
    fn alloc_from_shape(
        caller: &mut Caller<'_, HostState>,
        shape_ptr: i32,
        shape_len: i32,
        make: impl FnOnce(Vec<usize>) -> Tensor<f32>,
    ) -> i32 {
        let memory = match Self::guest_memory(caller) {
            Some(mem) => mem,
            None => return 0,
        };
        let shape = match Self::read_shape(memory.data(&*caller), shape_ptr, shape_len) {
            Some(shape) => shape,
            None => return 0,
        };
        // NOTE(review): the dimensions are guest-controlled and no upper bound
        // on the element count is enforced here; consider adding a limit.
        let tensor = make(shape);
        let mut store = caller
            .data()
            .tensors()
            .lock()
            .expect("Tensor store lock poisoned");
        store.allocate(tensor) as i32
    }

    /// Shared body of `tensor_add` / `tensor_matmul`: applies `op` to the two
    /// stored tensors and stores the result.
    ///
    /// The operands are borrowed from the store rather than cloned; the
    /// borrows end before the result is inserted.
    fn binary_op(
        caller: &Caller<'_, HostState>,
        a_id: i32,
        b_id: i32,
        op: impl FnOnce(&TensorRuntime, &Tensor<f32>, &Tensor<f32>) -> Option<Tensor<f32>>,
    ) -> i32 {
        let host = caller.data();
        let mut store = host.tensors().lock().expect("Tensor store lock poisoned");
        let a = match store.get(a_id as u32) {
            Some(t) => t,
            None => return 0,
        };
        let b = match store.get(b_id as u32) {
            Some(t) => t,
            None => return 0,
        };
        let outcome = op(host.tensor_runtime(), a, b);
        match outcome {
            Some(result) => store.allocate(result) as i32,
            None => 0,
        }
    }
}