Skip to main content

forgellm_codegen_wasm/
lib.rs

//! Forge WASM Code Generation — WASM (SIMD128) code emission.
2//!
3//! Generates a complete Cargo project targeting `wasm32-unknown-unknown` with:
4//! - WASM SIMD128 intrinsics for matmul acceleration
5//! - `wasm_bindgen` exports for JS integration
6//! - A companion JS glue layer for browser use
7
8use std::fmt::Write as FmtWrite;
9use std::fs;
10use std::path::Path;
11
12use forgellm_frontend::ir::*;
13
14/// Errors during WASM code generation.
15#[derive(Debug, thiserror::Error)]
16pub enum WasmCodegenError {
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    #[error("I/O error: {0}")]
21    Io(#[from] std::io::Error),
22
23    #[error("format error: {0}")]
24    Fmt(#[from] std::fmt::Error),
25}
26
27/// Generate a complete WASM Cargo project from a computation graph.
28///
29/// Creates:
30/// - `Cargo.toml` — targeting `wasm32-unknown-unknown` with wasm-bindgen
31/// - `src/lib.rs` — SIMD128-accelerated kernels + `WasmModel` export
32/// - `pkg/model.js` — JS glue layer for browser integration
33pub fn generate_wasm_project(
34    graph: &Graph,
35    output_dir: &Path,
36    model_name: &str,
37) -> Result<(), WasmCodegenError> {
38    let config = graph
39        .config
40        .as_ref()
41        .ok_or(WasmCodegenError::MissingConfig)?;
42
43    let src_dir = output_dir.join("src");
44    let pkg_dir = output_dir.join("pkg");
45    fs::create_dir_all(&src_dir)?;
46    fs::create_dir_all(&pkg_dir)?;
47
48    // Generate Cargo.toml
49    fs::write(
50        output_dir.join("Cargo.toml"),
51        generate_cargo_toml(model_name),
52    )?;
53
54    // Generate src/lib.rs
55    let lib_code = generate_lib_rs(graph, config)?;
56    fs::write(src_dir.join("lib.rs"), lib_code)?;
57
58    // Generate pkg/model.js
59    fs::write(pkg_dir.join("model.js"), generate_model_js())?;
60
61    Ok(())
62}
63
/// Turn an arbitrary model name into a usable Cargo package name.
///
/// The name is lowercased. Every character that is not an ASCII letter, an ASCII digit
/// or `-` is replaced by `-`, and leading/trailing dashes are trimmed. Non-ASCII letters
/// are replaced because non-ASCII package/crate names are not supported.
///
/// Cargo rejects empty package names and names starting with a digit, so:
/// - an empty result falls back to `forgellm-model`;
/// - a result starting with a digit gets a `model-` prefix.
fn sanitize_name(name: &str) -> String {
    let cleaned: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = cleaned.trim_matches('-');

    match trimmed.chars().next() {
        None => "forgellm-model".to_string(),
        Some(first) if first.is_ascii_digit() => format!("model-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
70
71fn generate_cargo_toml(model_name: &str) -> String {
72    let sanitized = sanitize_name(model_name);
73    format!(
74        r#"[package]
75name = "{sanitized}"
76version = "0.1.0"
77edition = "2021"
78
79[lib]
80crate-type = ["cdylib"]
81
82[dependencies]
83wasm-bindgen = "0.2"
84js-sys = "0.3"
85getrandom = {{ version = "0.2", features = ["js"] }}
86console_error_panic_hook = "0.1"
87
88[profile.release]
89opt-level = 3
90lto = "fat"
91codegen-units = 1
92panic = "abort"
93"#
94    )
95}
96
/// Render `pkg/model.js`, the browser-side loader for the generated module.
///
/// The emitted `loadModel(wasmUrl, weightsUrl)` does the following:
/// 1. Dynamically imports the JS module at `wasmUrl` and awaits its default `init()` export.
/// 2. Downloads the raw weight file from `weightsUrl`.
/// 3. Hands the bytes to the `WasmModel` constructor.
///
/// `fetch()` only rejects on network failures. An HTTP error status (e.g. a 404 page)
/// resolves normally, so the loader checks `response.ok` and throws. Otherwise the error
/// page would be parsed as weight data.
fn generate_model_js() -> String {
    r#"// model.js - JS glue for ForgeLLM WASM model
export async function loadModel(wasmUrl, weightsUrl) {
  const { default: init, WasmModel } = await import(wasmUrl);
  await init();
  const weightsResp = await fetch(weightsUrl);
  // fetch() resolves on HTTP errors; don't feed an error page to the model as weights.
  if (!weightsResp.ok) {
    throw new Error(`failed to fetch weights from ${weightsUrl}: HTTP ${weightsResp.status}`);
  }
  const weightsBytes = new Uint8Array(await weightsResp.arrayBuffer());
  return new WasmModel(weightsBytes);
}
"#
    .to_string()
}
109
110fn generate_lib_rs(graph: &Graph, config: &ModelConfig) -> Result<String, WasmCodegenError> {
111    let mut code = String::with_capacity(32 * 1024);
112
113    emit_lib_header(&mut code, config)?;
114    emit_wasm_kernels(&mut code)?;
115    emit_wasm_specialized_matmul_functions(&mut code, config)?;
116    emit_wasm_forward_function(&mut code, graph, config)?;
117    emit_wasm_bindgen_exports(&mut code, config)?;
118
119    Ok(code)
120}
121
122fn emit_lib_header(code: &mut String, config: &ModelConfig) -> Result<(), WasmCodegenError> {
123    writeln!(code, "//! Auto-generated by ForgeLLM WASM codegen.")?;
124    writeln!(
125        code,
126        "//! Model: {} ({} layers, hidden={})",
127        config.architecture, config.num_layers, config.hidden_size
128    )?;
129    writeln!(code, "//!")?;
130    writeln!(
131        code,
132        "//! Targets wasm32-unknown-unknown with optional SIMD128 acceleration."
133    )?;
134    writeln!(code)?;
135    writeln!(code, "#![allow(clippy::excessive_precision)]")?;
136    writeln!(
137        code,
138        "#![allow(dead_code, unused_imports, unused_assignments)]"
139    )?;
140    writeln!(code)?;
141    writeln!(code, "use wasm_bindgen::prelude::*;")?;
142    writeln!(code)?;
143    writeln!(code, "// Model constants")?;
144    writeln!(
145        code,
146        "pub const HIDDEN_SIZE: usize = {};",
147        config.hidden_size
148    )?;
149    writeln!(
150        code,
151        "pub const INTERMEDIATE_SIZE: usize = {};",
152        config.intermediate_size
153    )?;
154    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
155    writeln!(
156        code,
157        "pub const NUM_HEADS: usize = {};",
158        config.num_attention_heads
159    )?;
160    writeln!(
161        code,
162        "pub const NUM_KV_HEADS: usize = {};",
163        config.num_kv_heads
164    )?;
165    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
166    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
167    let effective_seq_len = config.max_seq_len.min(4096);
168    writeln!(
169        code,
170        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
171        effective_seq_len, config.max_seq_len
172    )?;
173    writeln!(
174        code,
175        "pub const RMS_NORM_EPS: f32 = {:e};",
176        config.rms_norm_eps
177    )?;
178    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
179    writeln!(code)?;
180
181    Ok(())
182}
183
184fn emit_wasm_kernels(code: &mut String) -> Result<(), WasmCodegenError> {
185    code.push_str(
186        r#"
187// --- WASM SIMD128 dot product ---
188#[cfg(target_feature = "simd128")]
189#[inline]
190fn dot_f32(a: &[f32], b: &[f32], len: usize) -> f32 {
191    use std::arch::wasm32::*;
192    unsafe {
193        let mut acc = f32x4_splat(0.0);
194        let chunks = len / 4;
195        for i in 0..chunks {
196            let base = i * 4;
197            let va = v128_load(a.as_ptr().add(base) as *const v128);
198            let vb = v128_load(b.as_ptr().add(base) as *const v128);
199            acc = f32x4_add(acc, f32x4_mul(va, vb));
200        }
201        let s = f32x4_extract_lane::<0>(acc) + f32x4_extract_lane::<1>(acc)
202              + f32x4_extract_lane::<2>(acc) + f32x4_extract_lane::<3>(acc);
203        let mut r = s;
204        for i in (chunks * 4)..len { r += *a.get_unchecked(i) * *b.get_unchecked(i); }
205        r
206    }
207}
208
209#[cfg(not(target_feature = "simd128"))]
210#[inline]
211fn dot_f32(a: &[f32], b: &[f32], len: usize) -> f32 {
212    let mut sum = 0.0f32;
213    for i in 0..len { sum += a[i] * b[i]; }
214    sum
215}
216
217#[inline]
218pub fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
219    let n = input.len();
220    let sum_sq = dot_f32(input, input, n);
221    let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();
222    for i in 0..n { output[i] = input[i] * inv_rms * weight[i]; }
223}
224
225#[inline]
226pub fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
227    for i in 0..m {
228        let row = &input[i*k..(i+1)*k];
229        for j in 0..n {
230            output[i*n+j] = dot_f32(row, &weight[j*k..(j+1)*k], k);
231        }
232    }
233}
234
235#[inline]
236pub fn silu(output: &mut [f32], input: &[f32]) {
237    for (o, &x) in output.iter_mut().zip(input.iter()) { *o = x / (1.0 + (-x).exp()); }
238}
239
240#[inline]
241pub fn silu_mul(output: &mut [f32], gate: &[f32], up: &[f32]) {
242    for i in 0..gate.len() {
243        let x = gate[i];
244        output[i] = (x / (1.0 + (-x).exp())) * up[i];
245    }
246}
247
248#[inline]
249pub fn residual_add(a: &mut [f32], b: &[f32]) {
250    for i in 0..a.len() { a[i] += b[i]; }
251}
252
253#[inline]
254pub fn softmax(values: &mut [f32]) {
255    let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
256    let mut sum = 0.0f32;
257    for v in values.iter_mut() { *v = (*v - max_val).exp(); sum += *v; }
258    let inv = if sum > 0.0 { 1.0 / sum } else { 0.0 };
259    for v in values.iter_mut() { *v *= inv; }
260}
261
262#[inline]
263pub fn rope_freqs(head_dim: usize, theta: f32) -> Vec<f32> {
264    (0..head_dim / 2).map(|i| 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32)).collect()
265}
266
267#[inline]
268pub fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, freqs: &[f32]) {
269    let half = head_dim / 2;
270    let mut cos_table = vec![0.0f32; half];
271    let mut sin_table = vec![0.0f32; half];
272    for i in 0..half {
273        let angle = pos as f32 * freqs[i];
274        let (s, c) = angle.sin_cos();
275        cos_table[i] = c;
276        sin_table[i] = s;
277    }
278    for h in 0..num_heads {
279        let off = h * head_dim;
280        for i in 0..half {
281            let (x0, x1) = (data[off + 2*i], data[off + 2*i + 1]);
282            data[off + 2*i] = x0 * cos_table[i] - x1 * sin_table[i];
283            data[off + 2*i + 1] = x0 * sin_table[i] + x1 * cos_table[i];
284        }
285    }
286}
287
288#[inline]
289pub fn attention(
290    output: &mut [f32], q: &[f32], k_cache: &[f32], v_cache: &[f32],
291    seq_len: usize, num_heads: usize, num_kv_heads: usize, head_dim: usize,
292) {
293    let gsize = num_heads / num_kv_heads;
294    let scale = 1.0 / (head_dim as f32).sqrt();
295    let kv_stride = num_kv_heads * head_dim;
296    let mut scores = vec![0.0f32; seq_len];
297    for h in 0..num_heads {
298        let kv_h = h / gsize;
299        let qo = h * head_dim;
300        for t in 0..seq_len {
301            let ko = t * kv_stride + kv_h * head_dim;
302            scores[t] = dot_f32(&q[qo..qo+head_dim], &k_cache[ko..ko+head_dim], head_dim) * scale;
303        }
304        softmax(&mut scores[..seq_len]);
305        for d in 0..head_dim {
306            let mut sum = 0.0f32;
307            for t in 0..seq_len {
308                sum += scores[t] * v_cache[t * kv_stride + kv_h * head_dim + d];
309            }
310            output[qo+d] = sum;
311        }
312    }
313}
314
315#[inline]
316pub fn embedding(output: &mut [f32], token_id: u32, weight: &[f32], embed_dim: usize) {
317    let off = token_id as usize * embed_dim;
318    output.copy_from_slice(&weight[off..off + embed_dim]);
319}
320
321"#,
322    );
323
324    Ok(())
325}
326
327/// Collect all unique (k, n) matmul shapes used in the forward pass.
328fn matmul_shapes(config: &ModelConfig) -> Vec<(usize, usize)> {
329    let hidden = config.hidden_size;
330    let intermediate = config.intermediate_size;
331    let num_heads = config.num_attention_heads;
332    let num_kv_heads = config.num_kv_heads;
333    let head_dim = config.head_dim;
334    let vocab = config.vocab_size;
335    let qk_size = num_heads * head_dim;
336    let kv_size = num_kv_heads * head_dim;
337
338    let mut shapes = vec![
339        (hidden, qk_size),      // q_proj
340        (hidden, kv_size),      // k_proj, v_proj
341        (qk_size, hidden),      // o_proj
342        (hidden, intermediate), // gate_proj, up_proj
343        (intermediate, hidden), // down_proj
344        (hidden, vocab),        // lm_head
345    ];
346    shapes.sort();
347    shapes.dedup();
348    shapes
349}
350
351/// Emit shape-specialized matmul_vec functions (WASM, single-threaded — no rayon).
352fn emit_wasm_specialized_matmul_functions(
353    code: &mut String,
354    config: &ModelConfig,
355) -> Result<(), WasmCodegenError> {
356    writeln!(
357        code,
358        "// --- Shape-specialized matmul functions (m=1, single-threaded) ---"
359    )?;
360    writeln!(
361        code,
362        "// All dimensions baked in at compile time — no runtime size parameters."
363    )?;
364    writeln!(code)?;
365
366    for &(k, n) in &matmul_shapes(config) {
367        writeln!(
368            code,
369            "/// Specialized matmul: [1, {k}] x [{n}, {k}]^T -> [1, {n}]"
370        )?;
371        writeln!(code, "#[inline]")?;
372        writeln!(
373            code,
374            "fn matmul_vec_{k}x{n}(output: &mut [f32; {n}], input: &[f32; {k}], weight: &[f32]) {{"
375        )?;
376        let n_chunks = n / 4;
377        let n_remainder = n % 4;
378        if n_chunks > 0 {
379            writeln!(
380                code,
381                "    // Process 4 output rows at a time for instruction-level parallelism"
382            )?;
383            writeln!(code, "    for chunk in 0..{n_chunks} {{")?;
384            writeln!(code, "        let j0 = chunk * 4;")?;
385            writeln!(
386                code,
387                "        output[j0]   = dot_f32(&input[..], &weight[j0*{k}..(j0+1)*{k}], {k});"
388            )?;
389            writeln!(
390                code,
391                "        output[j0+1] = dot_f32(&input[..], &weight[(j0+1)*{k}..(j0+2)*{k}], {k});"
392            )?;
393            writeln!(
394                code,
395                "        output[j0+2] = dot_f32(&input[..], &weight[(j0+2)*{k}..(j0+3)*{k}], {k});"
396            )?;
397            writeln!(
398                code,
399                "        output[j0+3] = dot_f32(&input[..], &weight[(j0+3)*{k}..(j0+4)*{k}], {k});"
400            )?;
401            writeln!(code, "    }}")?;
402        }
403        if n_remainder > 0 {
404            writeln!(code, "    // Handle remaining {n_remainder} output rows")?;
405            writeln!(code, "    let base = {n_chunks} * 4;")?;
406            for r in 0..n_remainder {
407                writeln!(code, "    output[base+{r}] = dot_f32(&input[..], &weight[(base+{r})*{k}..(base+{r}+1)*{k}], {k});")?;
408            }
409        }
410        writeln!(code, "}}")?;
411        writeln!(code)?;
412    }
413
414    Ok(())
415}
416
417fn emit_wasm_forward_function(
418    code: &mut String,
419    _graph: &Graph,
420    config: &ModelConfig,
421) -> Result<(), WasmCodegenError> {
422    let hidden = config.hidden_size;
423    let intermediate = config.intermediate_size;
424    let num_heads = config.num_attention_heads;
425    let num_kv_heads = config.num_kv_heads;
426    let head_dim = config.head_dim;
427    let vocab = config.vocab_size;
428    let qk_size = num_heads * head_dim;
429    let kv_size = num_kv_heads * head_dim;
430
431    // Weights struct
432    writeln!(
433        code,
434        "/// Model weights — loaded once, passed to forward()."
435    )?;
436    writeln!(code, "pub struct Weights {{")?;
437    writeln!(
438        code,
439        "    pub embed_tokens: Vec<f32>,       // [{vocab} * {hidden}]"
440    )?;
441    writeln!(code, "    pub layers: Vec<LayerWeights>,")?;
442    writeln!(code, "    pub final_norm: Vec<f32>,          // [{hidden}]")?;
443    writeln!(
444        code,
445        "    pub lm_head: Vec<f32>,             // [{vocab} * {hidden}]"
446    )?;
447    writeln!(code, "}}")?;
448    writeln!(code)?;
449
450    writeln!(code, "pub struct LayerWeights {{")?;
451    writeln!(code, "    pub attn_norm: Vec<f32>,           // [{hidden}]")?;
452    writeln!(
453        code,
454        "    pub q_proj: Vec<f32>,              // [{} * {hidden}]",
455        num_heads * head_dim
456    )?;
457    writeln!(
458        code,
459        "    pub k_proj: Vec<f32>,              // [{} * {hidden}]",
460        num_kv_heads * head_dim
461    )?;
462    writeln!(
463        code,
464        "    pub v_proj: Vec<f32>,              // [{} * {hidden}]",
465        num_kv_heads * head_dim
466    )?;
467    writeln!(
468        code,
469        "    pub o_proj: Vec<f32>,              // [{hidden} * {}]",
470        num_heads * head_dim
471    )?;
472    writeln!(code, "    pub ffn_norm: Vec<f32>,            // [{hidden}]")?;
473    writeln!(
474        code,
475        "    pub gate_proj: Vec<f32>,           // [{intermediate} * {hidden}]"
476    )?;
477    writeln!(
478        code,
479        "    pub up_proj: Vec<f32>,             // [{intermediate} * {hidden}]"
480    )?;
481    writeln!(
482        code,
483        "    pub down_proj: Vec<f32>,           // [{hidden} * {intermediate}]"
484    )?;
485    writeln!(code, "}}")?;
486    writeln!(code)?;
487
488    // KVCache
489    writeln!(code, "/// KV cache for autoregressive generation.")?;
490    writeln!(code, "pub struct KVCache {{")?;
491    writeln!(
492        code,
493        "    pub k: Vec<Vec<f32>>,  // [num_layers][MAX_SEQ_LEN * {kv_size}]"
494    )?;
495    writeln!(
496        code,
497        "    pub v: Vec<Vec<f32>>,  // [num_layers][MAX_SEQ_LEN * {kv_size}]"
498    )?;
499    writeln!(code, "    pub len: usize,")?;
500    writeln!(code, "}}")?;
501    writeln!(code)?;
502
503    writeln!(code, "impl KVCache {{")?;
504    writeln!(code, "    pub fn new() -> Self {{")?;
505    writeln!(code, "        Self {{")?;
506    writeln!(
507        code,
508        "            k: (0..NUM_LAYERS).map(|_| vec![0.0f32; MAX_SEQ_LEN * {kv_size}]).collect(),"
509    )?;
510    writeln!(
511        code,
512        "            v: (0..NUM_LAYERS).map(|_| vec![0.0f32; MAX_SEQ_LEN * {kv_size}]).collect(),"
513    )?;
514    writeln!(code, "            len: 0,")?;
515    writeln!(code, "        }}")?;
516    writeln!(code, "    }}")?;
517    writeln!(code)?;
518    writeln!(code, "    pub fn reset(&mut self) {{")?;
519    writeln!(code, "        self.len = 0;")?;
520    writeln!(code, "    }}")?;
521    writeln!(code, "}}")?;
522    writeln!(code)?;
523    writeln!(code, "impl Default for KVCache {{")?;
524    writeln!(code, "    fn default() -> Self {{ Self::new() }}")?;
525    writeln!(code, "}}")?;
526    writeln!(code)?;
527
528    // forward() function
529    writeln!(
530        code,
531        "/// Run forward pass for a single token. Returns logits [{vocab}]."
532    )?;
533    writeln!(
534        code,
535        "pub fn forward(token_id: u32, weights: &Weights, cache: &mut KVCache) -> Vec<f32> {{"
536    )?;
537    writeln!(code, "    let pos = cache.len;")?;
538    writeln!(code)?;
539
540    writeln!(code, "    // Embedding lookup")?;
541    writeln!(code, "    let mut hidden_state = [0.0f32; HIDDEN_SIZE];")?;
542    writeln!(
543        code,
544        "    embedding(&mut hidden_state, token_id, &weights.embed_tokens, HIDDEN_SIZE);"
545    )?;
546    writeln!(code)?;
547
548    writeln!(code, "    // Fixed-size buffers")?;
549    writeln!(code, "    let mut normed = [0.0f32; {hidden}];")?;
550    writeln!(code, "    let mut q = [0.0f32; {qk_size}];")?;
551    writeln!(code, "    let mut k = [0.0f32; {kv_size}];")?;
552    writeln!(code, "    let mut v = [0.0f32; {kv_size}];")?;
553    writeln!(code, "    let mut attn_out = [0.0f32; {qk_size}];")?;
554    writeln!(code, "    let mut attn_proj = [0.0f32; {hidden}];")?;
555    writeln!(code, "    let mut gate = [0.0f32; {intermediate}];")?;
556    writeln!(code, "    let mut up = [0.0f32; {intermediate}];")?;
557    writeln!(code, "    let mut ffn_hidden = [0.0f32; {intermediate}];")?;
558    writeln!(code, "    let mut ffn_out = [0.0f32; {hidden}];")?;
559    writeln!(code)?;
560    writeln!(
561        code,
562        "    let rope_freqs = rope_freqs(HEAD_DIM, ROPE_THETA);"
563    )?;
564    writeln!(code)?;
565
566    writeln!(code, "    // Transformer layers")?;
567    writeln!(code, "    for layer_idx in 0..NUM_LAYERS {{")?;
568    writeln!(code, "        let lw = &weights.layers[layer_idx];")?;
569    writeln!(code)?;
570    writeln!(code, "        // Attention norm")?;
571    writeln!(
572        code,
573        "        rms_norm(&mut normed, &hidden_state, &lw.attn_norm, RMS_NORM_EPS);"
574    )?;
575    writeln!(code)?;
576    writeln!(code, "        // QKV projections")?;
577    writeln!(
578        code,
579        "        matmul_vec_{hidden}x{qk_size}(&mut q, &normed, &lw.q_proj);"
580    )?;
581    writeln!(
582        code,
583        "        matmul_vec_{hidden}x{kv_size}(&mut k, &normed, &lw.k_proj);"
584    )?;
585    writeln!(
586        code,
587        "        matmul_vec_{hidden}x{kv_size}(&mut v, &normed, &lw.v_proj);"
588    )?;
589    writeln!(code)?;
590    writeln!(code, "        // RoPE")?;
591    writeln!(
592        code,
593        "        rope(&mut q, pos, HEAD_DIM, NUM_HEADS, &rope_freqs);"
594    )?;
595    writeln!(
596        code,
597        "        rope(&mut k, pos, HEAD_DIM, NUM_KV_HEADS, &rope_freqs);"
598    )?;
599    writeln!(code)?;
600    writeln!(code, "        // Update KV cache")?;
601    writeln!(
602        code,
603        "        cache.k[layer_idx][pos*{kv_size}..(pos+1)*{kv_size}].copy_from_slice(&k);"
604    )?;
605    writeln!(
606        code,
607        "        cache.v[layer_idx][pos*{kv_size}..(pos+1)*{kv_size}].copy_from_slice(&v);"
608    )?;
609    writeln!(code)?;
610    writeln!(code, "        // Attention")?;
611    writeln!(code, "        attention(")?;
612    writeln!(code, "            &mut attn_out, &q,")?;
613    writeln!(
614        code,
615        "            &cache.k[layer_idx][..(pos+1)*{kv_size}], &cache.v[layer_idx][..(pos+1)*{kv_size}],"
616    )?;
617    writeln!(
618        code,
619        "            pos + 1, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM,"
620    )?;
621    writeln!(code, "        );")?;
622    writeln!(code)?;
623    writeln!(code, "        // Output projection + residual")?;
624    writeln!(
625        code,
626        "        matmul_vec_{qk_size}x{hidden}(&mut attn_proj, &attn_out, &lw.o_proj);"
627    )?;
628    writeln!(code, "        residual_add(&mut hidden_state, &attn_proj);")?;
629    writeln!(code)?;
630    writeln!(code, "        // FFN norm")?;
631    writeln!(
632        code,
633        "        rms_norm(&mut normed, &hidden_state, &lw.ffn_norm, RMS_NORM_EPS);"
634    )?;
635    writeln!(code)?;
636    writeln!(code, "        // FFN: fused silu_mul")?;
637    writeln!(
638        code,
639        "        matmul_vec_{hidden}x{intermediate}(&mut gate, &normed, &lw.gate_proj);"
640    )?;
641    writeln!(
642        code,
643        "        matmul_vec_{hidden}x{intermediate}(&mut up, &normed, &lw.up_proj);"
644    )?;
645    writeln!(code, "        silu_mul(&mut ffn_hidden, &gate, &up);")?;
646    writeln!(
647        code,
648        "        matmul_vec_{intermediate}x{hidden}(&mut ffn_out, &ffn_hidden, &lw.down_proj);"
649    )?;
650    writeln!(code)?;
651    writeln!(code, "        residual_add(&mut hidden_state, &ffn_out);")?;
652    writeln!(code, "    }}")?;
653    writeln!(code)?;
654
655    writeln!(code, "    // Final norm")?;
656    writeln!(
657        code,
658        "    rms_norm(&mut normed, &hidden_state, &weights.final_norm, RMS_NORM_EPS);"
659    )?;
660    writeln!(code)?;
661
662    writeln!(code, "    // Logits projection")?;
663    writeln!(code, "    let mut logits = vec![0.0f32; VOCAB_SIZE];")?;
664    writeln!(code, "    for j in 0..VOCAB_SIZE {{")?;
665    writeln!(
666        code,
667        "        logits[j] = dot_f32(&normed[..], &weights.lm_head[j*{hidden}..(j+1)*{hidden}], {hidden});"
668    )?;
669    writeln!(code, "    }}")?;
670    writeln!(code)?;
671    writeln!(code, "    cache.len += 1;")?;
672    writeln!(code, "    logits")?;
673    writeln!(code, "}}")?;
674    writeln!(code)?;
675
676    Ok(())
677}
678
679fn emit_wasm_bindgen_exports(
680    code: &mut String,
681    config: &ModelConfig,
682) -> Result<(), WasmCodegenError> {
683    let hidden = config.hidden_size;
684    let num_layers = config.num_layers;
685    let num_heads = config.num_attention_heads;
686    let num_kv_heads = config.num_kv_heads;
687    let head_dim = config.head_dim;
688    let vocab = config.vocab_size;
689    let intermediate = config.intermediate_size;
690    let qk_size = num_heads * head_dim;
691    let kv_size = num_kv_heads * head_dim;
692
693    writeln!(
694        code,
695        "/// Initialize panic hook for better error messages in browser console."
696    )?;
697    writeln!(code, "#[wasm_bindgen]")?;
698    writeln!(code, "pub fn init_panic_hook() {{")?;
699    writeln!(code, "    console_error_panic_hook::set_once();")?;
700    writeln!(code, "}}")?;
701    writeln!(code)?;
702
703    writeln!(
704        code,
705        "/// WASM-exported model handle. Holds weights + KV cache."
706    )?;
707    writeln!(code, "#[wasm_bindgen]")?;
708    writeln!(code, "pub struct WasmModel {{")?;
709    writeln!(code, "    weights: Weights,")?;
710    writeln!(code, "    cache: KVCache,")?;
711    writeln!(code, "}}")?;
712    writeln!(code)?;
713
714    writeln!(code, "#[wasm_bindgen]")?;
715    writeln!(code, "impl WasmModel {{")?;
716
717    // Emit byte-size constants for weight parsing
718    let embed_elems = vocab * hidden;
719    let final_norm_elems = hidden;
720    let lm_head_elems = vocab * hidden;
721    let attn_norm_elems = hidden;
722    let q_proj_elems = qk_size * hidden;
723    let k_proj_elems = kv_size * hidden;
724    let v_proj_elems = kv_size * hidden;
725    let o_proj_elems = hidden * qk_size;
726    let ffn_norm_elems = hidden;
727    let gate_proj_elems = intermediate * hidden;
728    let up_proj_elems = intermediate * hidden;
729    let down_proj_elems = hidden * intermediate;
730
731    let layer_elems = attn_norm_elems
732        + q_proj_elems
733        + k_proj_elems
734        + v_proj_elems
735        + o_proj_elems
736        + ffn_norm_elems
737        + gate_proj_elems
738        + up_proj_elems
739        + down_proj_elems;
740
741    writeln!(code, "    /// Load model from raw f32 weight bytes.")?;
742    writeln!(
743        code,
744        "    /// Expected layout: embed_tokens | layer0 | layer1 | ... | final_norm | lm_head"
745    )?;
746    writeln!(code, "    #[wasm_bindgen(constructor)]")?;
747    writeln!(code, "    pub fn new(weights_bytes: &[u8]) -> WasmModel {{")?;
748    writeln!(code, "        init_panic_hook();")?;
749    writeln!(code, "        // Parse f32 weight bytes")?;
750    writeln!(code, "        let n = weights_bytes.len() / 4;")?;
751    writeln!(code, "        let mut raw = vec![0.0f32; n];")?;
752    writeln!(code, "        for i in 0..n {{")?;
753    writeln!(
754        code,
755        "            raw[i] = f32::from_le_bytes([weights_bytes[i*4], weights_bytes[i*4+1], weights_bytes[i*4+2], weights_bytes[i*4+3]]);"
756    )?;
757    writeln!(code, "        }}")?;
758    writeln!(code, "        let mut off = 0usize;")?;
759    writeln!(
760        code,
761        "        let embed_tokens = raw[off..off+{embed_elems}].to_vec(); off += {embed_elems};"
762    )?;
763    writeln!(
764        code,
765        "        let mut layers = Vec::with_capacity({num_layers});"
766    )?;
767    writeln!(code, "        for _ in 0..{num_layers} {{")?;
768    writeln!(
769        code,
770        "            let attn_norm = raw[off..off+{attn_norm_elems}].to_vec(); off += {attn_norm_elems};"
771    )?;
772    writeln!(
773        code,
774        "            let q_proj = raw[off..off+{q_proj_elems}].to_vec(); off += {q_proj_elems};"
775    )?;
776    writeln!(
777        code,
778        "            let k_proj = raw[off..off+{k_proj_elems}].to_vec(); off += {k_proj_elems};"
779    )?;
780    writeln!(
781        code,
782        "            let v_proj = raw[off..off+{v_proj_elems}].to_vec(); off += {v_proj_elems};"
783    )?;
784    writeln!(
785        code,
786        "            let o_proj = raw[off..off+{o_proj_elems}].to_vec(); off += {o_proj_elems};"
787    )?;
788    writeln!(
789        code,
790        "            let ffn_norm = raw[off..off+{ffn_norm_elems}].to_vec(); off += {ffn_norm_elems};"
791    )?;
792    writeln!(
793        code,
794        "            let gate_proj = raw[off..off+{gate_proj_elems}].to_vec(); off += {gate_proj_elems};"
795    )?;
796    writeln!(
797        code,
798        "            let up_proj = raw[off..off+{up_proj_elems}].to_vec(); off += {up_proj_elems};"
799    )?;
800    writeln!(
801        code,
802        "            let down_proj = raw[off..off+{down_proj_elems}].to_vec(); off += {down_proj_elems};"
803    )?;
804    writeln!(code, "            layers.push(LayerWeights {{ attn_norm, q_proj, k_proj, v_proj, o_proj, ffn_norm, gate_proj, up_proj, down_proj }});")?;
805    writeln!(code, "        }}")?;
806    writeln!(
807        code,
808        "        let final_norm = raw[off..off+{final_norm_elems}].to_vec(); off += {final_norm_elems};"
809    )?;
810    writeln!(
811        code,
812        "        let lm_head = raw[off..off+{lm_head_elems}].to_vec();"
813    )?;
814    writeln!(
815        code,
816        "        let _ = ({layer_elems}, {embed_elems}, {lm_head_elems}, {final_norm_elems}); // suppress unused warnings"
817    )?;
818    writeln!(
819        code,
820        "        let weights = Weights {{ embed_tokens, layers, final_norm, lm_head }};"
821    )?;
822    writeln!(code, "        let cache = KVCache::new();")?;
823    writeln!(code, "        WasmModel {{ weights, cache }}")?;
824    writeln!(code, "    }}")?;
825    writeln!(code)?;
826
827    writeln!(
828        code,
829        "    /// Run a single forward step. Returns logit for most-likely next token."
830    )?;
831    writeln!(
832        code,
833        "    pub fn forward(&mut self, token_id: u32) -> u32 {{"
834    )?;
835    writeln!(
836        code,
837        "        let logits = forward(token_id, &self.weights, &mut self.cache);"
838    )?;
839    writeln!(code, "        // Argmax sampling")?;
840    writeln!(code, "        let mut best = 0usize;")?;
841    writeln!(code, "        let mut best_val = f32::NEG_INFINITY;")?;
842    writeln!(code, "        for (i, &v) in logits.iter().enumerate() {{")?;
843    writeln!(
844        code,
845        "            if v > best_val {{ best_val = v; best = i; }}"
846    )?;
847    writeln!(code, "        }}")?;
848    writeln!(code, "        best as u32")?;
849    writeln!(code, "    }}")?;
850    writeln!(code)?;
851
852    writeln!(code, "    /// Reset the KV cache (start a new generation).")?;
853    writeln!(code, "    pub fn reset_cache(&mut self) {{")?;
854    writeln!(code, "        self.cache.reset();")?;
855    writeln!(code, "    }}")?;
856    writeln!(code, "}}")?;
857
858    Ok(())
859}
860
#[cfg(test)]
mod tests {
    use super::*;
    use forgellm_frontend::{graph_builder, ir::ModelConfig};

    /// A deliberately tiny Llama-style config, so that generation in tests is fast.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 64,
            intermediate_size: 128,
            num_layers: 2,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 16,
            vocab_size: 256,
            max_seq_len: 64,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F16,
            sliding_window_size: None,
            qkv_bias: false,
            hidden_activation: HiddenActivation::SiLU,
        }
    }

    /// Generate the tiny model into a fresh temp dir.
    ///
    /// The returned guard must stay alive while the files are inspected; dropping it
    /// deletes the directory.
    fn generate_tiny_project() -> tempfile::TempDir {
        let graph = graph_builder::build_graph(&tiny_config()).unwrap();
        let dir = tempfile::tempdir().unwrap();
        generate_wasm_project(&graph, dir.path(), "test-model").unwrap();
        dir
    }

    /// Read a generated file, given its path relative to the project root.
    fn read_generated(dir: &tempfile::TempDir, relative: &str) -> String {
        std::fs::read_to_string(dir.path().join(relative)).unwrap()
    }

    #[test]
    fn generate_wasm_project_creates_all_files() {
        let dir = generate_tiny_project();
        for relative in ["Cargo.toml", "src/lib.rs", "pkg/model.js"] {
            assert!(dir.path().join(relative).exists(), "missing {relative}");
        }
    }

    #[test]
    fn generated_lib_rs_contains_wasm_bindgen() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");
        assert!(lib_rs.contains("use wasm_bindgen::prelude::*;"));
    }

    #[test]
    fn generated_lib_rs_contains_wasm_model() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");
        assert!(lib_rs.contains("pub struct WasmModel"));
    }

    #[test]
    fn generated_lib_rs_contains_dot_f32_kernel() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");
        assert!(lib_rs.contains("fn dot_f32("));
        assert!(lib_rs.contains("simd128"));
    }

    #[test]
    fn generated_cargo_toml_has_cdylib() {
        let dir = generate_tiny_project();
        let cargo_toml = read_generated(&dir, "Cargo.toml");
        assert!(cargo_toml.contains("cdylib"));
        assert!(cargo_toml.contains("wasm-bindgen"));
    }

    #[test]
    fn generate_placeholder() {
        // Legacy smoke test: an empty graph can still be constructed.
        let graph = Graph::new("test");
        assert_eq!(graph.len(), 0);
    }
}