use std::fmt::Write as FmtWrite;
use std::fs;
use std::path::Path;
use forgellm_frontend::ir::*;
/// Everything that can go wrong while generating the WASM project.
#[derive(Debug, thiserror::Error)]
pub enum WasmCodegenError {
    /// Creating an output directory or writing a generated file failed.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Formatting generated source text into a `String` failed.
    #[error("format error: {0}")]
    Fmt(#[from] std::fmt::Error),

    /// The input graph carries no `ModelConfig`, so there are no dimensions to bake in.
    #[error("graph has no model config")]
    MissingConfig,
}
/// Write a standalone wasm-bindgen crate for `graph` into `output_dir`.
///
/// Produces `Cargo.toml`, `src/lib.rs` (the generated model) and `pkg/model.js` (JS loader),
/// creating `src/` and `pkg/` as needed.
///
/// # Errors
/// - [`WasmCodegenError::MissingConfig`] if `graph.config` is `None` (checked before any I/O).
/// - [`WasmCodegenError::Io`] if a directory cannot be created or a file cannot be written.
/// - [`WasmCodegenError::Fmt`] if emitting `lib.rs` fails.
pub fn generate_wasm_project(
    graph: &Graph,
    output_dir: &Path,
    model_name: &str,
) -> Result<(), WasmCodegenError> {
    let config = match graph.config.as_ref() {
        Some(cfg) => cfg,
        None => return Err(WasmCodegenError::MissingConfig),
    };

    let src_dir = output_dir.join("src");
    let pkg_dir = output_dir.join("pkg");
    for dir in [&src_dir, &pkg_dir] {
        fs::create_dir_all(dir)?;
    }

    // Written in the same order as the directories above: manifest, Rust source, JS glue.
    let manifest = generate_cargo_toml(model_name);
    fs::write(output_dir.join("Cargo.toml"), manifest)?;
    let rust_source = generate_lib_rs(graph, config)?;
    fs::write(src_dir.join("lib.rs"), rust_source)?;
    fs::write(pkg_dir.join("model.js"), generate_model_js())?;
    Ok(())
}
/// Turn an arbitrary model name into a valid Cargo package name.
///
/// Cargo package names must be non-empty, must not start with a digit, and should only use
/// ASCII letters, digits, `-` and `_`. The name is lowercased (ASCII), and every run of other
/// characters (including `-`, `_`, spaces and non-ASCII letters) becomes a single `-`.
/// Leading and trailing `-` are dropped. A digit-leading result is prefixed with `model-`,
/// and an empty result falls back to `forgellm-model`.
fn sanitize_name(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    for c in name.chars() {
        if c.is_ascii_alphanumeric() {
            out.push(c.to_ascii_lowercase());
        } else if !out.is_empty() && !out.ends_with('-') {
            // Collapse separator runs into one '-' and never emit a leading '-'.
            out.push('-');
        }
    }
    // The collapsing above leaves at most one trailing '-'.
    if out.ends_with('-') {
        out.pop();
    }
    if out.is_empty() {
        return String::from("forgellm-model");
    }
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        // Cargo rejects package names that begin with a digit.
        out.insert_str(0, "model-");
    }
    out
}
/// Build the generated crate's `Cargo.toml`.
///
/// The package name is `model_name` passed through `sanitize_name`. The crate is a `cdylib`
/// (as wasm-bindgen requires), and the release profile favors speed: `opt-level = 3`,
/// fat LTO, a single codegen unit and `panic = "abort"`. Every line, including the last,
/// is newline-terminated.
fn generate_cargo_toml(model_name: &str) -> String {
    let package_line = format!("name = \"{}\"", sanitize_name(model_name));
    let manifest_lines: [&str; 16] = [
        "[package]",
        &package_line,
        "version = \"0.1.0\"",
        "edition = \"2021\"",
        "[lib]",
        "crate-type = [\"cdylib\"]",
        "[dependencies]",
        "wasm-bindgen = \"0.2\"",
        "js-sys = \"0.3\"",
        "getrandom = { version = \"0.2\", features = [\"js\"] }",
        "console_error_panic_hook = \"0.1\"",
        "[profile.release]",
        "opt-level = 3",
        "lto = \"fat\"",
        "codegen-units = 1",
        "panic = \"abort\"",
    ];
    manifest_lines
        .iter()
        .fold(String::new(), |mut toml, line| {
            toml.push_str(line);
            toml.push('\n');
            toml
        })
}
/// Return the JS loader written to `pkg/model.js`.
///
/// `loadModel(wasmUrl, weightsUrl)` does the following:
/// 1. Dynamically imports the wasm-bindgen ES module at `wasmUrl`.
/// 2. Runs the module's default (init) export.
/// 3. Fetches the raw weight blob and hands its bytes to the `WasmModel` constructor.
///
/// A non-2xx weights response raises an `Error`. Previously such a response body (e.g. an
/// HTML 404 page) was passed on as if it were weight data.
fn generate_model_js() -> String {
    r#"// model.js - JS glue for ForgeLLM WASM model

/**
 * Load the ForgeLLM WASM model.
 * @param {string} wasmUrl URL of the wasm-bindgen generated ES module (passed to import()).
 * @param {string} weightsUrl URL of the raw little-endian f32 weight blob.
 * @returns {Promise<WasmModel>}
 */
export async function loadModel(wasmUrl, weightsUrl) {
  const { default: init, WasmModel } = await import(wasmUrl);
  await init();
  const weightsResp = await fetch(weightsUrl);
  if (!weightsResp.ok) {
    throw new Error(`failed to fetch weights from ${weightsUrl}: HTTP ${weightsResp.status}`);
  }
  const weightsBytes = new Uint8Array(await weightsResp.arrayBuffer());
  return new WasmModel(weightsBytes);
}
"#
    .to_string()
}
/// Assemble the generated crate's `src/lib.rs`.
///
/// Sections are appended in this order:
/// 1. Header: crate docs, lint allowances, imports and model constants.
/// 2. Generic kernels.
/// 3. Shape-specialized matmuls.
/// 4. Weight/KV-cache types plus `forward()`.
/// 5. The `#[wasm_bindgen]` exports.
fn generate_lib_rs(graph: &Graph, config: &ModelConfig) -> Result<String, WasmCodegenError> {
    // Only an initial reservation; the String grows as needed.
    const INITIAL_CAPACITY: usize = 32 * 1024;
    let mut lib_rs = String::with_capacity(INITIAL_CAPACITY);

    emit_lib_header(&mut lib_rs, config)?;
    emit_wasm_kernels(&mut lib_rs)?;
    emit_wasm_specialized_matmul_functions(&mut lib_rs, config)?;
    emit_wasm_forward_function(&mut lib_rs, graph, config)?;
    emit_wasm_bindgen_exports(&mut lib_rs, config)?;

    Ok(lib_rs)
}
/// Emit the top of `lib.rs`:
/// - crate-level `//!` docs;
/// - lint allowances;
/// - the `wasm_bindgen` prelude import;
/// - the model hyper-parameters as `pub const`s.
///
/// `MAX_SEQ_LEN` is the model's `max_seq_len` clamped to at most 4096. The two float
/// constants are written in `{:e}` form (e.g. `1e-5`).
fn emit_lib_header(code: &mut String, config: &ModelConfig) -> Result<(), WasmCodegenError> {
    write!(
        code,
        r#"//! Auto-generated by ForgeLLM WASM codegen.
//! Model: {arch} ({layers} layers, hidden={hidden})
//!
//! Targets wasm32-unknown-unknown with optional SIMD128 acceleration.

#![allow(clippy::excessive_precision)]
#![allow(dead_code, unused_imports, unused_assignments)]

use wasm_bindgen::prelude::*;

// Model constants
pub const HIDDEN_SIZE: usize = {hidden};
pub const INTERMEDIATE_SIZE: usize = {intermediate};
pub const NUM_LAYERS: usize = {layers};
pub const NUM_HEADS: usize = {heads};
pub const NUM_KV_HEADS: usize = {kv_heads};
pub const HEAD_DIM: usize = {head_dim};
pub const VOCAB_SIZE: usize = {vocab};
pub const MAX_SEQ_LEN: usize = {seq_len}; // capped from model's {model_seq_len}
pub const RMS_NORM_EPS: f32 = {eps:e};
pub const ROPE_THETA: f32 = {theta:e};

"#,
        arch = config.architecture,
        layers = config.num_layers,
        hidden = config.hidden_size,
        intermediate = config.intermediate_size,
        heads = config.num_attention_heads,
        kv_heads = config.num_kv_heads,
        head_dim = config.head_dim,
        vocab = config.vocab_size,
        seq_len = config.max_seq_len.min(4096),
        model_seq_len = config.max_seq_len,
        eps = config.rms_norm_eps,
        theta = config.rope_theta,
    )?;
    Ok(())
}
/// Emit the shape-generic numeric kernels used by the generated `forward()`:
/// `dot_f32` (SIMD128 and scalar variants), `rms_norm`, `matmul`, `silu`, `silu_mul`,
/// `residual_add`, `softmax`, `rope_freqs`, `rope`, `attention` and `embedding`.
///
/// Fixes over the previous emission:
/// - `dot_f32` is a safe fn, but the SIMD variant did unchecked loads up to a caller-supplied
///   `len`. Both variants now re-slice to `len` first, which panics on a too-large `len`
///   instead of reading out of bounds.
/// - `rope` no longer allocates two sin/cos tables on every call. The same values are now
///   computed inline, with identical arithmetic per element.
fn emit_wasm_kernels(code: &mut String) -> Result<(), WasmCodegenError> {
    code.push_str(
        r#"
// --- WASM SIMD128 dot product ---
// dot_f32(a, b, len) returns sum(a[i] * b[i]) for i < len. It panics if `len` exceeds either
// slice; that up-front re-slice is what keeps the unchecked loads in the SIMD variant sound.
#[cfg(target_feature = "simd128")]
#[inline]
fn dot_f32(a: &[f32], b: &[f32], len: usize) -> f32 {
    use std::arch::wasm32::*;
    let (a, b) = (&a[..len], &b[..len]);
    // SAFETY: both slices now hold exactly `len` elements. Each v128_load reads elements
    // [base, base + 4) with base + 4 <= chunks * 4 <= len (v128_load has no alignment
    // requirement), and every scalar tail index is < len.
    unsafe {
        let mut acc = f32x4_splat(0.0);
        let chunks = len / 4;
        for i in 0..chunks {
            let base = i * 4;
            let va = v128_load(a.as_ptr().add(base) as *const v128);
            let vb = v128_load(b.as_ptr().add(base) as *const v128);
            acc = f32x4_add(acc, f32x4_mul(va, vb));
        }
        let s = f32x4_extract_lane::<0>(acc) + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc) + f32x4_extract_lane::<3>(acc);
        let mut r = s;
        for i in (chunks * 4)..len { r += *a.get_unchecked(i) * *b.get_unchecked(i); }
        r
    }
}

#[cfg(not(target_feature = "simd128"))]
#[inline]
fn dot_f32(a: &[f32], b: &[f32], len: usize) -> f32 {
    let (a, b) = (&a[..len], &b[..len]);
    let mut sum = 0.0f32;
    for (x, y) in a.iter().zip(b) { sum += x * y; }
    sum
}

#[inline]
pub fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
    let n = input.len();
    let sum_sq = dot_f32(input, input, n);
    let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();
    for i in 0..n { output[i] = input[i] * inv_rms * weight[i]; }
}

#[inline]
pub fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
    for i in 0..m {
        let row = &input[i*k..(i+1)*k];
        for j in 0..n {
            output[i*n+j] = dot_f32(row, &weight[j*k..(j+1)*k], k);
        }
    }
}

#[inline]
pub fn silu(output: &mut [f32], input: &[f32]) {
    for (o, &x) in output.iter_mut().zip(input.iter()) { *o = x / (1.0 + (-x).exp()); }
}

#[inline]
pub fn silu_mul(output: &mut [f32], gate: &[f32], up: &[f32]) {
    for i in 0..gate.len() {
        let x = gate[i];
        output[i] = (x / (1.0 + (-x).exp())) * up[i];
    }
}

#[inline]
pub fn residual_add(a: &mut [f32], b: &[f32]) {
    for i in 0..a.len() { a[i] += b[i]; }
}

#[inline]
pub fn softmax(values: &mut [f32]) {
    let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for v in values.iter_mut() { *v = (*v - max_val).exp(); sum += *v; }
    let inv = if sum > 0.0 { 1.0 / sum } else { 0.0 };
    for v in values.iter_mut() { *v *= inv; }
}

#[inline]
pub fn rope_freqs(head_dim: usize, theta: f32) -> Vec<f32> {
    (0..head_dim / 2).map(|i| 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32)).collect()
}

// Rotary position embedding, in place, over `num_heads` consecutive heads of `head_dim`
// values. Adjacent pairs (data[2i], data[2i + 1]) are rotated by pos * freqs[i].
// NOTE(review): this is the interleaved-pair layout; confirm the exported Q/K weights use it
// (rotate-half checkpoints pair i with i + head_dim / 2 instead).
#[inline]
pub fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, freqs: &[f32]) {
    let half = head_dim / 2;
    for i in 0..half {
        // One sin/cos per frequency, shared by every head (no per-call table allocation).
        let angle = pos as f32 * freqs[i];
        let (s, c) = angle.sin_cos();
        for h in 0..num_heads {
            let off = h * head_dim + 2 * i;
            let (x0, x1) = (data[off], data[off + 1]);
            data[off] = x0 * c - x1 * s;
            data[off + 1] = x0 * s + x1 * c;
        }
    }
}

#[inline]
pub fn attention(
    output: &mut [f32], q: &[f32], k_cache: &[f32], v_cache: &[f32],
    seq_len: usize, num_heads: usize, num_kv_heads: usize, head_dim: usize,
) {
    let gsize = num_heads / num_kv_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let kv_stride = num_kv_heads * head_dim;
    let mut scores = vec![0.0f32; seq_len];
    for h in 0..num_heads {
        let kv_h = h / gsize;
        let qo = h * head_dim;
        for t in 0..seq_len {
            let ko = t * kv_stride + kv_h * head_dim;
            scores[t] = dot_f32(&q[qo..qo+head_dim], &k_cache[ko..ko+head_dim], head_dim) * scale;
        }
        softmax(&mut scores[..seq_len]);
        for d in 0..head_dim {
            let mut sum = 0.0f32;
            for t in 0..seq_len {
                sum += scores[t] * v_cache[t * kv_stride + kv_h * head_dim + d];
            }
            output[qo+d] = sum;
        }
    }
}

#[inline]
pub fn embedding(output: &mut [f32], token_id: u32, weight: &[f32], embed_dim: usize) {
    let off = token_id as usize * embed_dim;
    output.copy_from_slice(&weight[off..off + embed_dim]);
}
"#,
    );
    Ok(())
}
/// Distinct `(k, n)` shapes for which a `matmul_vec_{k}x{n}` function is generated.
///
/// Returned sorted ascending with duplicates removed. For example, `(hidden, q_width)` and
/// `(q_width, hidden)` collapse into one entry when the attention width equals `hidden`.
fn matmul_shapes(config: &ModelConfig) -> Vec<(usize, usize)> {
    let hidden = config.hidden_size;
    let ffn = config.intermediate_size;
    let q_width = config.num_attention_heads * config.head_dim;
    let kv_width = config.num_kv_heads * config.head_dim;

    // A BTreeSet yields its elements sorted and unique, matching sort + dedup.
    let unique: std::collections::BTreeSet<(usize, usize)> = [
        (hidden, q_width),            // q_proj
        (hidden, kv_width),           // k_proj / v_proj
        (q_width, hidden),            // o_proj
        (hidden, ffn),                // gate_proj / up_proj
        (ffn, hidden),                // down_proj
        (hidden, config.vocab_size),  // lm_head dimensions
    ]
    .iter()
    .copied()
    .collect();
    unique.into_iter().collect()
}
/// Emit one `matmul_vec_{k}x{n}` function per shape from `matmul_shapes`, preceded by a
/// short section banner.
fn emit_wasm_specialized_matmul_functions(
    code: &mut String,
    config: &ModelConfig,
) -> Result<(), WasmCodegenError> {
    code.push_str("// --- Shape-specialized matmul functions (m=1, single-threaded) ---\n");
    code.push_str("// All dimensions baked in at compile time — no runtime size parameters.\n\n");
    for (k, n) in matmul_shapes(config) {
        emit_one_matmul_vec(code, k, n)?;
    }
    Ok(())
}

/// Emit `matmul_vec_{k}x{n}`: `n` dot products of a `k`-vector against consecutive `k`-wide
/// rows of `weight`. Rows are handled four per loop iteration, plus a straight-line tail
/// for the `n % 4` leftover rows.
fn emit_one_matmul_vec(code: &mut String, k: usize, n: usize) -> Result<(), WasmCodegenError> {
    let full_groups = n / 4;
    let tail = n % 4;

    writeln!(code, "/// Specialized matmul: [1, {k}] x [{n}, {k}]^T -> [1, {n}]")?;
    writeln!(code, "#[inline]")?;
    writeln!(
        code,
        "fn matmul_vec_{k}x{n}(output: &mut [f32; {n}], input: &[f32; {k}], weight: &[f32]) {{"
    )?;

    if full_groups != 0 {
        writeln!(code, " // Process 4 output rows at a time for instruction-level parallelism")?;
        writeln!(code, " for chunk in 0..{full_groups} {{")?;
        writeln!(code, " let j0 = chunk * 4;")?;
        for lane in 0..4usize {
            // Lane 0 is spelled `j0` rather than `j0+0` / `(j0+0)`.
            let (out_idx, row_start) = if lane == 0 {
                (String::from("j0"), String::from("j0"))
            } else {
                (format!("j0+{lane}"), format!("(j0+{lane})"))
            };
            let row_end = format!("(j0+{})", lane + 1);
            writeln!(
                code,
                " output[{out_idx}] = dot_f32(&input[..], &weight[{row_start}*{k}..{row_end}*{k}], {k});"
            )?;
        }
        writeln!(code, " }}")?;
    }

    if tail != 0 {
        writeln!(code, " // Handle remaining {tail} output rows")?;
        writeln!(code, " let base = {full_groups} * 4;")?;
        for r in 0..tail {
            writeln!(
                code,
                " output[base+{r}] = dot_f32(&input[..], &weight[(base+{r})*{k}..(base+{r}+1)*{k}], {k});"
            )?;
        }
    }

    writeln!(code, "}}")?;
    writeln!(code)?;
    Ok(())
}
/// Emit the `Weights` / `LayerWeights` / `KVCache` types and the single-token `forward()`.
///
/// The emitted code relies on:
/// - the kernels from `emit_wasm_kernels`;
/// - the constants from `emit_lib_header` (e.g. `MAX_SEQ_LEN`, `VOCAB_SIZE`);
/// - the `matmul_vec_{k}x{n}` functions generated for each shape in `matmul_shapes`.
///
/// The emitted `forward()` now checks, up front, that the KV cache has room (`pos < MAX_SEQ_LEN`)
/// and that `token_id < VOCAB_SIZE`. Previously both failures surfaced as opaque slice-range
/// panics deep inside the first layer. Since `MAX_SEQ_LEN` is silently capped, hitting the
/// cache limit is a realistic case.
fn emit_wasm_forward_function(
    code: &mut String,
    _graph: &Graph,
    config: &ModelConfig,
) -> Result<(), WasmCodegenError> {
    emit_weight_structs(code, config)?;
    emit_kv_cache(code, config)?;
    emit_forward_fn(code, config)?;
    Ok(())
}

/// Emit `Weights` and `LayerWeights`; the trailing comments record each tensor's dimensions.
fn emit_weight_structs(code: &mut String, config: &ModelConfig) -> Result<(), WasmCodegenError> {
    let hidden = config.hidden_size;
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    let qk_size = config.num_attention_heads * config.head_dim;
    let kv_size = config.num_kv_heads * config.head_dim;
    write!(
        code,
        r#"/// Model weights — loaded once, passed to forward().
pub struct Weights {{
    pub embed_tokens: Vec<f32>, // [{vocab} * {hidden}]
    pub layers: Vec<LayerWeights>,
    pub final_norm: Vec<f32>, // [{hidden}]
    pub lm_head: Vec<f32>, // [{vocab} * {hidden}]
}}

pub struct LayerWeights {{
    pub attn_norm: Vec<f32>, // [{hidden}]
    pub q_proj: Vec<f32>, // [{qk_size} * {hidden}]
    pub k_proj: Vec<f32>, // [{kv_size} * {hidden}]
    pub v_proj: Vec<f32>, // [{kv_size} * {hidden}]
    pub o_proj: Vec<f32>, // [{hidden} * {qk_size}]
    pub ffn_norm: Vec<f32>, // [{hidden}]
    pub gate_proj: Vec<f32>, // [{intermediate} * {hidden}]
    pub up_proj: Vec<f32>, // [{intermediate} * {hidden}]
    pub down_proj: Vec<f32>, // [{hidden} * {intermediate}]
}}

"#
    )?;
    Ok(())
}

/// Emit `KVCache`: one preallocated K and V buffer of `MAX_SEQ_LEN` positions per layer.
fn emit_kv_cache(code: &mut String, config: &ModelConfig) -> Result<(), WasmCodegenError> {
    let kv_size = config.num_kv_heads * config.head_dim;
    write!(
        code,
        r#"/// KV cache for autoregressive generation.
pub struct KVCache {{
    pub k: Vec<Vec<f32>>, // [num_layers][MAX_SEQ_LEN * {kv_size}]
    pub v: Vec<Vec<f32>>, // [num_layers][MAX_SEQ_LEN * {kv_size}]
    pub len: usize,
}}

impl KVCache {{
    pub fn new() -> Self {{
        Self {{
            k: (0..NUM_LAYERS).map(|_| vec![0.0f32; MAX_SEQ_LEN * {kv_size}]).collect(),
            v: (0..NUM_LAYERS).map(|_| vec![0.0f32; MAX_SEQ_LEN * {kv_size}]).collect(),
            len: 0,
        }}
    }}

    pub fn reset(&mut self) {{
        self.len = 0;
    }}
}}

impl Default for KVCache {{
    fn default() -> Self {{ Self::new() }}
}}

"#
    )?;
    Ok(())
}

/// Emit `forward()`: one decoder step for a single token.
///
/// The step is: embedding, then per layer (RMSNorm, QKV projections, RoPE, KV-cache append,
/// attention, output projection plus residual, RMSNorm, SiLU-gated FFN plus residual), then a
/// final norm and the logits projection. It returns `VOCAB_SIZE` logits and advances
/// `cache.len`.
fn emit_forward_fn(code: &mut String, config: &ModelConfig) -> Result<(), WasmCodegenError> {
    let hidden = config.hidden_size;
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    let qk_size = config.num_attention_heads * config.head_dim;
    let kv_size = config.num_kv_heads * config.head_dim;
    write!(
        code,
        r#"/// Run forward pass for a single token. Returns logits [{vocab}].
///
/// Panics if the KV cache is full (`cache.len >= MAX_SEQ_LEN`) or `token_id >= VOCAB_SIZE`.
pub fn forward(token_id: u32, weights: &Weights, cache: &mut KVCache) -> Vec<f32> {{
    let pos = cache.len;
    // Fail with an actionable message instead of an opaque slice-range panic inside a layer.
    assert!(
        pos < MAX_SEQ_LEN,
        "KV cache full: position {{}} reached MAX_SEQ_LEN ({{}}); reset the cache to continue",
        pos, MAX_SEQ_LEN
    );
    assert!(
        (token_id as usize) < VOCAB_SIZE,
        "token id {{}} out of range (VOCAB_SIZE = {{}})",
        token_id, VOCAB_SIZE
    );

    // Embedding lookup
    let mut hidden_state = [0.0f32; HIDDEN_SIZE];
    embedding(&mut hidden_state, token_id, &weights.embed_tokens, HIDDEN_SIZE);

    // Fixed-size buffers
    let mut normed = [0.0f32; {hidden}];
    let mut q = [0.0f32; {qk_size}];
    let mut k = [0.0f32; {kv_size}];
    let mut v = [0.0f32; {kv_size}];
    let mut attn_out = [0.0f32; {qk_size}];
    let mut attn_proj = [0.0f32; {hidden}];
    let mut gate = [0.0f32; {intermediate}];
    let mut up = [0.0f32; {intermediate}];
    let mut ffn_hidden = [0.0f32; {intermediate}];
    let mut ffn_out = [0.0f32; {hidden}];

    let rope_freqs = rope_freqs(HEAD_DIM, ROPE_THETA);

    // Transformer layers
    for layer_idx in 0..NUM_LAYERS {{
        let lw = &weights.layers[layer_idx];

        // Attention norm
        rms_norm(&mut normed, &hidden_state, &lw.attn_norm, RMS_NORM_EPS);

        // QKV projections
        matmul_vec_{hidden}x{qk_size}(&mut q, &normed, &lw.q_proj);
        matmul_vec_{hidden}x{kv_size}(&mut k, &normed, &lw.k_proj);
        matmul_vec_{hidden}x{kv_size}(&mut v, &normed, &lw.v_proj);

        // RoPE
        rope(&mut q, pos, HEAD_DIM, NUM_HEADS, &rope_freqs);
        rope(&mut k, pos, HEAD_DIM, NUM_KV_HEADS, &rope_freqs);

        // Update KV cache
        cache.k[layer_idx][pos*{kv_size}..(pos+1)*{kv_size}].copy_from_slice(&k);
        cache.v[layer_idx][pos*{kv_size}..(pos+1)*{kv_size}].copy_from_slice(&v);

        // Attention
        attention(
            &mut attn_out, &q,
            &cache.k[layer_idx][..(pos+1)*{kv_size}], &cache.v[layer_idx][..(pos+1)*{kv_size}],
            pos + 1, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM,
        );

        // Output projection + residual
        matmul_vec_{qk_size}x{hidden}(&mut attn_proj, &attn_out, &lw.o_proj);
        residual_add(&mut hidden_state, &attn_proj);

        // FFN norm
        rms_norm(&mut normed, &hidden_state, &lw.ffn_norm, RMS_NORM_EPS);

        // FFN: fused silu_mul
        matmul_vec_{hidden}x{intermediate}(&mut gate, &normed, &lw.gate_proj);
        matmul_vec_{hidden}x{intermediate}(&mut up, &normed, &lw.up_proj);
        silu_mul(&mut ffn_hidden, &gate, &up);
        matmul_vec_{intermediate}x{hidden}(&mut ffn_out, &ffn_hidden, &lw.down_proj);

        residual_add(&mut hidden_state, &ffn_out);
    }}

    // Final norm
    rms_norm(&mut normed, &hidden_state, &weights.final_norm, RMS_NORM_EPS);

    // Logits projection
    let mut logits = vec![0.0f32; VOCAB_SIZE];
    for j in 0..VOCAB_SIZE {{
        logits[j] = dot_f32(&normed[..], &weights.lm_head[j*{hidden}..(j+1)*{hidden}], {hidden});
    }}

    cache.len += 1;
    logits
}}

"#
    )?;
    Ok(())
}
/// Emit the `#[wasm_bindgen]` JS-facing API:
/// - `init_panic_hook`;
/// - the `WasmModel` handle and its weight-loading constructor;
/// - greedy `forward`;
/// - `reset_cache`.
///
/// The constructor decodes a flat little-endian f32 blob laid out as
/// `embed_tokens | layer0 | ... | layerN-1 | final_norm | lm_head`. Each layer is
/// `attn_norm | q_proj | k_proj | v_proj | o_proj | ffn_norm | gate_proj | up_proj | down_proj`.
///
/// Compared with the previous emission:
/// - The blob's byte length is validated up front against the exact layout size. A truncated
///   or mismatched file now raises a JS exception (`Err(JsValue)`) instead of trapping on an
///   out-of-range slice or silently ignoring extra bytes.
/// - Each tensor is decoded directly from the byte buffer. This avoids the full-size
///   intermediate `Vec<f32>` copy of the whole model that used to be held while loading.
/// - `forward`'s doc now states that it returns a token id (argmax), not a logit.
fn emit_wasm_bindgen_exports(
    code: &mut String,
    config: &ModelConfig,
) -> Result<(), WasmCodegenError> {
    let hidden = config.hidden_size;
    let num_layers = config.num_layers;
    let vocab = config.vocab_size;
    let intermediate = config.intermediate_size;
    let qk_size = config.num_attention_heads * config.head_dim;
    let kv_size = config.num_kv_heads * config.head_dim;

    let embed_elems = vocab * hidden;
    let final_norm_elems = hidden;
    let lm_head_elems = vocab * hidden;
    // Per-layer tensors in file order; the names double as `LayerWeights` field names.
    let layer_tensors: [(&str, usize); 9] = [
        ("attn_norm", hidden),
        ("q_proj", qk_size * hidden),
        ("k_proj", kv_size * hidden),
        ("v_proj", kv_size * hidden),
        ("o_proj", hidden * qk_size),
        ("ffn_norm", hidden),
        ("gate_proj", intermediate * hidden),
        ("up_proj", intermediate * hidden),
        ("down_proj", hidden * intermediate),
    ];
    let layer_elems: u64 = layer_tensors.iter().map(|&(_, n)| n as u64).sum();
    // Use u64 so that neither this sum nor the emitted literal can overflow a 32-bit usize.
    let expected_bytes: u64 = 4
        * (embed_elems as u64
            + num_layers as u64 * layer_elems
            + final_norm_elems as u64
            + lm_head_elems as u64);
    let layer_fields = layer_tensors
        .iter()
        .map(|&(name, _)| name)
        .collect::<Vec<_>>()
        .join(", ");

    code.push_str(
        r#"/// Initialize panic hook for better error messages in browser console.
#[wasm_bindgen]
pub fn init_panic_hook() {
    console_error_panic_hook::set_once();
}

/// Decode `count` little-endian f32 values starting at byte `*off`, then advance `*off`.
fn read_f32s(bytes: &[u8], off: &mut usize, count: usize) -> Vec<f32> {
    let end = *off + count * 4;
    let values = bytes[*off..end]
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    *off = end;
    values
}

/// WASM-exported model handle. Holds weights + KV cache.
#[wasm_bindgen]
pub struct WasmModel {
    weights: Weights,
    cache: KVCache,
}

"#,
    );

    write!(
        code,
        r#"#[wasm_bindgen]
impl WasmModel {{
    /// Load model from raw little-endian f32 weight bytes ({expected_bytes} bytes).
    /// Expected layout: embed_tokens | layer0 | layer1 | ... | final_norm | lm_head
    /// Throws a JS error if the byte length does not match this layout.
    #[wasm_bindgen(constructor)]
    pub fn new(weights_bytes: &[u8]) -> Result<WasmModel, JsValue> {{
        init_panic_hook();
        const EXPECTED_BYTES: u64 = {expected_bytes};
        if weights_bytes.len() as u64 != EXPECTED_BYTES {{
            return Err(JsValue::from_str(&format!(
                "weight blob is {{}} bytes, expected {{}} bytes of little-endian f32 for this model",
                weights_bytes.len(),
                EXPECTED_BYTES
            )));
        }}
        // Decode each tensor straight from the byte buffer: no intermediate f32 copy of
        // the whole model is ever allocated.
        let mut off = 0usize;
        let embed_tokens = read_f32s(weights_bytes, &mut off, {embed_elems});
        let mut layers = Vec::with_capacity({num_layers});
        for _ in 0..{num_layers} {{
"#
    )?;
    for &(name, elems) in &layer_tensors {
        writeln!(
            code,
            "            let {name} = read_f32s(weights_bytes, &mut off, {elems});"
        )?;
    }
    write!(
        code,
        r#"            layers.push(LayerWeights {{ {layer_fields} }});
        }}
        let final_norm = read_f32s(weights_bytes, &mut off, {final_norm_elems});
        let lm_head = read_f32s(weights_bytes, &mut off, {lm_head_elems});
        debug_assert_eq!(off as u64, EXPECTED_BYTES);
        let weights = Weights {{ embed_tokens, layers, final_norm, lm_head }};
        let cache = KVCache::new();
        Ok(WasmModel {{ weights, cache }})
    }}

    /// Run a single forward step. Returns the id of the most-likely next token (greedy argmax).
    pub fn forward(&mut self, token_id: u32) -> u32 {{
        let logits = forward(token_id, &self.weights, &mut self.cache);
        // Argmax sampling
        let mut best = 0usize;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &v) in logits.iter().enumerate() {{
            if v > best_val {{ best_val = v; best = i; }}
        }}
        best as u32
    }}

    /// Reset the KV cache (start a new generation).
    pub fn reset_cache(&mut self) {{
        self.cache.reset();
    }}
}}
"#
    )?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use forgellm_frontend::{graph_builder, ir::ModelConfig};

    /// Minimal Llama-style config, small enough that generation is instant.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 64,
            intermediate_size: 128,
            num_layers: 2,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 16,
            vocab_size: 256,
            max_seq_len: 64,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F16,
            sliding_window_size: None,
            qkv_bias: false,
            hidden_activation: HiddenActivation::SiLU,
        }
    }

    /// Generate the tiny model's project into a fresh temp dir.
    /// The returned guard keeps the directory alive for the duration of the test.
    fn generate_tiny_project() -> tempfile::TempDir {
        let graph = graph_builder::build_graph(&tiny_config()).unwrap();
        let dir = tempfile::tempdir().unwrap();
        generate_wasm_project(&graph, dir.path(), "test-model").unwrap();
        dir
    }

    /// Read a generated file given its path relative to the project root.
    fn read_generated(dir: &tempfile::TempDir, relative: &str) -> String {
        std::fs::read_to_string(dir.path().join(relative)).unwrap()
    }

    #[test]
    fn generate_wasm_project_creates_all_files() {
        let dir = generate_tiny_project();
        for relative in ["Cargo.toml", "src/lib.rs", "pkg/model.js"] {
            assert!(dir.path().join(relative).exists(), "missing {relative}");
        }
    }

    #[test]
    fn generated_lib_rs_contains_wasm_bindgen() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");
        assert!(lib_rs.contains("use wasm_bindgen::prelude::*;"));
    }

    #[test]
    fn generated_lib_rs_contains_wasm_model() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");
        assert!(lib_rs.contains("pub struct WasmModel"));
    }

    #[test]
    fn generated_lib_rs_contains_dot_f32_kernel() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");
        assert!(lib_rs.contains("fn dot_f32("));
        assert!(lib_rs.contains("simd128"));
    }

    #[test]
    fn generated_cargo_toml_has_cdylib() {
        let dir = generate_tiny_project();
        let cargo_toml = read_generated(&dir, "Cargo.toml");
        assert!(cargo_toml.contains("cdylib"));
        assert!(cargo_toml.contains("wasm-bindgen"));
    }

    #[test]
    fn generate_placeholder() {
        let graph = Graph::new("test");
        assert_eq!(graph.len(), 0);
    }
}