1use std::fmt::Write as FmtWrite;
9use std::fs;
10use std::path::Path;
11
12use forgellm_frontend::ir::*;
13
/// Errors that can occur while generating a WASM project.
#[derive(Debug, thiserror::Error)]
pub enum WasmCodegenError {
    /// The input graph carries no model configuration (`graph.config` is `None`).
    #[error("graph has no model config")]
    MissingConfig,

    /// A filesystem operation failed; wraps the underlying `std::io::Error`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Formatting into the output buffer failed; wraps the underlying `std::fmt::Error`.
    #[error("format error: {0}")]
    Fmt(#[from] std::fmt::Error),
}
26
27pub fn generate_wasm_project(
34 graph: &Graph,
35 output_dir: &Path,
36 model_name: &str,
37) -> Result<(), WasmCodegenError> {
38 let config = graph
39 .config
40 .as_ref()
41 .ok_or(WasmCodegenError::MissingConfig)?;
42
43 let src_dir = output_dir.join("src");
44 let pkg_dir = output_dir.join("pkg");
45 fs::create_dir_all(&src_dir)?;
46 fs::create_dir_all(&pkg_dir)?;
47
48 fs::write(
50 output_dir.join("Cargo.toml"),
51 generate_cargo_toml(model_name),
52 )?;
53
54 let lib_code = generate_lib_rs(graph, config)?;
56 fs::write(src_dir.join("lib.rs"), lib_code)?;
57
58 fs::write(pkg_dir.join("model.js"), generate_model_js())?;
60
61 Ok(())
62}
63
/// Turns an arbitrary model name into a string usable as a Cargo package name.
///
/// The name is lower-cased, every character that is neither alphanumeric nor
/// `-` is replaced by `-`, and leading/trailing dashes are stripped.
///
/// Cargo rejects an empty package name and one that starts with a digit, so
/// two extra rules keep the generated manifest loadable:
///
/// * if nothing is left after cleaning, the fallback name `model` is returned;
/// * if the cleaned name does not start with a letter, it is prefixed with
///   `model-`.
///
/// Names that were already valid are returned exactly as before.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    let lowered = name.to_lowercase();
    let cleaned = lowered.replace(|c: char| !c.is_alphanumeric() && c != '-', "-");
    let trimmed = cleaned.trim_matches('-');

    match trimmed.chars().next() {
        // Empty or punctuation-only input.
        None => FALLBACK.to_string(),
        // After trimming the first character is alphanumeric, so "not a letter"
        // means a leading digit.
        Some(first) if !first.is_alphabetic() => format!("{FALLBACK}-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
70
71fn generate_cargo_toml(model_name: &str) -> String {
72 let sanitized = sanitize_name(model_name);
73 format!(
74 r#"[package]
75name = "{sanitized}"
76version = "0.1.0"
77edition = "2021"
78
79[lib]
80crate-type = ["cdylib"]
81
82[dependencies]
83wasm-bindgen = "0.2"
84js-sys = "0.3"
85getrandom = {{ version = "0.2", features = ["js"] }}
86console_error_panic_hook = "0.1"
87
88[profile.release]
89opt-level = 3
90lto = "fat"
91codegen-units = 1
92panic = "abort"
93"#
94 )
95}
96
/// Returns the JavaScript glue module written to `pkg/model.js`.
///
/// The module exports `loadModel(wasmUrl, weightsUrl)`, which imports and
/// initialises the wasm-bindgen module, fetches the weight file and constructs
/// a `WasmModel` from its bytes.
///
/// The generated loader throws if the weight request does not succeed, so an
/// HTTP error page is never handed to `WasmModel` as if it were weight data.
fn generate_model_js() -> String {
    r#"// model.js - JS glue for ForgeLLM WASM model
export async function loadModel(wasmUrl, weightsUrl) {
 const { default: init, WasmModel } = await import(wasmUrl);
 await init();
 const weightsResp = await fetch(weightsUrl);
 if (!weightsResp.ok) {
  throw new Error(`failed to fetch weights from ${weightsUrl}: HTTP ${weightsResp.status}`);
 }
 const weightsBytes = new Uint8Array(await weightsResp.arrayBuffer());
 return new WasmModel(weightsBytes);
}
"#
    .to_string()
}
109
110fn generate_lib_rs(graph: &Graph, config: &ModelConfig) -> Result<String, WasmCodegenError> {
111 let mut code = String::with_capacity(32 * 1024);
112
113 emit_lib_header(&mut code, config)?;
114 emit_wasm_kernels(&mut code)?;
115 emit_wasm_specialized_matmul_functions(&mut code, config)?;
116 emit_wasm_forward_function(&mut code, graph, config)?;
117 emit_wasm_bindgen_exports(&mut code, config)?;
118
119 Ok(code)
120}
121
122fn emit_lib_header(code: &mut String, config: &ModelConfig) -> Result<(), WasmCodegenError> {
123 writeln!(code, "//! Auto-generated by ForgeLLM WASM codegen.")?;
124 writeln!(
125 code,
126 "//! Model: {} ({} layers, hidden={})",
127 config.architecture, config.num_layers, config.hidden_size
128 )?;
129 writeln!(code, "//!")?;
130 writeln!(
131 code,
132 "//! Targets wasm32-unknown-unknown with optional SIMD128 acceleration."
133 )?;
134 writeln!(code)?;
135 writeln!(code, "#![allow(clippy::excessive_precision)]")?;
136 writeln!(
137 code,
138 "#![allow(dead_code, unused_imports, unused_assignments)]"
139 )?;
140 writeln!(code)?;
141 writeln!(code, "use wasm_bindgen::prelude::*;")?;
142 writeln!(code)?;
143 writeln!(code, "// Model constants")?;
144 writeln!(
145 code,
146 "pub const HIDDEN_SIZE: usize = {};",
147 config.hidden_size
148 )?;
149 writeln!(
150 code,
151 "pub const INTERMEDIATE_SIZE: usize = {};",
152 config.intermediate_size
153 )?;
154 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
155 writeln!(
156 code,
157 "pub const NUM_HEADS: usize = {};",
158 config.num_attention_heads
159 )?;
160 writeln!(
161 code,
162 "pub const NUM_KV_HEADS: usize = {};",
163 config.num_kv_heads
164 )?;
165 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
166 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
167 let effective_seq_len = config.max_seq_len.min(4096);
168 writeln!(
169 code,
170 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
171 effective_seq_len, config.max_seq_len
172 )?;
173 writeln!(
174 code,
175 "pub const RMS_NORM_EPS: f32 = {:e};",
176 config.rms_norm_eps
177 )?;
178 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
179 writeln!(code)?;
180
181 Ok(())
182}
183
184fn emit_wasm_kernels(code: &mut String) -> Result<(), WasmCodegenError> {
185 code.push_str(
186 r#"
187// --- WASM SIMD128 dot product ---
188#[cfg(target_feature = "simd128")]
189#[inline]
190fn dot_f32(a: &[f32], b: &[f32], len: usize) -> f32 {
191 use std::arch::wasm32::*;
192 unsafe {
193 let mut acc = f32x4_splat(0.0);
194 let chunks = len / 4;
195 for i in 0..chunks {
196 let base = i * 4;
197 let va = v128_load(a.as_ptr().add(base) as *const v128);
198 let vb = v128_load(b.as_ptr().add(base) as *const v128);
199 acc = f32x4_add(acc, f32x4_mul(va, vb));
200 }
201 let s = f32x4_extract_lane::<0>(acc) + f32x4_extract_lane::<1>(acc)
202 + f32x4_extract_lane::<2>(acc) + f32x4_extract_lane::<3>(acc);
203 let mut r = s;
204 for i in (chunks * 4)..len { r += *a.get_unchecked(i) * *b.get_unchecked(i); }
205 r
206 }
207}
208
209#[cfg(not(target_feature = "simd128"))]
210#[inline]
211fn dot_f32(a: &[f32], b: &[f32], len: usize) -> f32 {
212 let mut sum = 0.0f32;
213 for i in 0..len { sum += a[i] * b[i]; }
214 sum
215}
216
217#[inline]
218pub fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
219 let n = input.len();
220 let sum_sq = dot_f32(input, input, n);
221 let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();
222 for i in 0..n { output[i] = input[i] * inv_rms * weight[i]; }
223}
224
225#[inline]
226pub fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
227 for i in 0..m {
228 let row = &input[i*k..(i+1)*k];
229 for j in 0..n {
230 output[i*n+j] = dot_f32(row, &weight[j*k..(j+1)*k], k);
231 }
232 }
233}
234
235#[inline]
236pub fn silu(output: &mut [f32], input: &[f32]) {
237 for (o, &x) in output.iter_mut().zip(input.iter()) { *o = x / (1.0 + (-x).exp()); }
238}
239
240#[inline]
241pub fn silu_mul(output: &mut [f32], gate: &[f32], up: &[f32]) {
242 for i in 0..gate.len() {
243 let x = gate[i];
244 output[i] = (x / (1.0 + (-x).exp())) * up[i];
245 }
246}
247
248#[inline]
249pub fn residual_add(a: &mut [f32], b: &[f32]) {
250 for i in 0..a.len() { a[i] += b[i]; }
251}
252
253#[inline]
254pub fn softmax(values: &mut [f32]) {
255 let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
256 let mut sum = 0.0f32;
257 for v in values.iter_mut() { *v = (*v - max_val).exp(); sum += *v; }
258 let inv = if sum > 0.0 { 1.0 / sum } else { 0.0 };
259 for v in values.iter_mut() { *v *= inv; }
260}
261
262#[inline]
263pub fn rope_freqs(head_dim: usize, theta: f32) -> Vec<f32> {
264 (0..head_dim / 2).map(|i| 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32)).collect()
265}
266
267#[inline]
268pub fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, freqs: &[f32]) {
269 let half = head_dim / 2;
270 let mut cos_table = vec![0.0f32; half];
271 let mut sin_table = vec![0.0f32; half];
272 for i in 0..half {
273 let angle = pos as f32 * freqs[i];
274 let (s, c) = angle.sin_cos();
275 cos_table[i] = c;
276 sin_table[i] = s;
277 }
278 for h in 0..num_heads {
279 let off = h * head_dim;
280 for i in 0..half {
281 let (x0, x1) = (data[off + 2*i], data[off + 2*i + 1]);
282 data[off + 2*i] = x0 * cos_table[i] - x1 * sin_table[i];
283 data[off + 2*i + 1] = x0 * sin_table[i] + x1 * cos_table[i];
284 }
285 }
286}
287
288#[inline]
289pub fn attention(
290 output: &mut [f32], q: &[f32], k_cache: &[f32], v_cache: &[f32],
291 seq_len: usize, num_heads: usize, num_kv_heads: usize, head_dim: usize,
292) {
293 let gsize = num_heads / num_kv_heads;
294 let scale = 1.0 / (head_dim as f32).sqrt();
295 let kv_stride = num_kv_heads * head_dim;
296 let mut scores = vec![0.0f32; seq_len];
297 for h in 0..num_heads {
298 let kv_h = h / gsize;
299 let qo = h * head_dim;
300 for t in 0..seq_len {
301 let ko = t * kv_stride + kv_h * head_dim;
302 scores[t] = dot_f32(&q[qo..qo+head_dim], &k_cache[ko..ko+head_dim], head_dim) * scale;
303 }
304 softmax(&mut scores[..seq_len]);
305 for d in 0..head_dim {
306 let mut sum = 0.0f32;
307 for t in 0..seq_len {
308 sum += scores[t] * v_cache[t * kv_stride + kv_h * head_dim + d];
309 }
310 output[qo+d] = sum;
311 }
312 }
313}
314
315#[inline]
316pub fn embedding(output: &mut [f32], token_id: u32, weight: &[f32], embed_dim: usize) {
317 let off = token_id as usize * embed_dim;
318 output.copy_from_slice(&weight[off..off + embed_dim]);
319}
320
321"#,
322 );
323
324 Ok(())
325}
326
327fn matmul_shapes(config: &ModelConfig) -> Vec<(usize, usize)> {
329 let hidden = config.hidden_size;
330 let intermediate = config.intermediate_size;
331 let num_heads = config.num_attention_heads;
332 let num_kv_heads = config.num_kv_heads;
333 let head_dim = config.head_dim;
334 let vocab = config.vocab_size;
335 let qk_size = num_heads * head_dim;
336 let kv_size = num_kv_heads * head_dim;
337
338 let mut shapes = vec![
339 (hidden, qk_size), (hidden, kv_size), (qk_size, hidden), (hidden, intermediate), (intermediate, hidden), (hidden, vocab), ];
346 shapes.sort();
347 shapes.dedup();
348 shapes
349}
350
351fn emit_wasm_specialized_matmul_functions(
353 code: &mut String,
354 config: &ModelConfig,
355) -> Result<(), WasmCodegenError> {
356 writeln!(
357 code,
358 "// --- Shape-specialized matmul functions (m=1, single-threaded) ---"
359 )?;
360 writeln!(
361 code,
362 "// All dimensions baked in at compile time — no runtime size parameters."
363 )?;
364 writeln!(code)?;
365
366 for &(k, n) in &matmul_shapes(config) {
367 writeln!(
368 code,
369 "/// Specialized matmul: [1, {k}] x [{n}, {k}]^T -> [1, {n}]"
370 )?;
371 writeln!(code, "#[inline]")?;
372 writeln!(
373 code,
374 "fn matmul_vec_{k}x{n}(output: &mut [f32; {n}], input: &[f32; {k}], weight: &[f32]) {{"
375 )?;
376 let n_chunks = n / 4;
377 let n_remainder = n % 4;
378 if n_chunks > 0 {
379 writeln!(
380 code,
381 " // Process 4 output rows at a time for instruction-level parallelism"
382 )?;
383 writeln!(code, " for chunk in 0..{n_chunks} {{")?;
384 writeln!(code, " let j0 = chunk * 4;")?;
385 writeln!(
386 code,
387 " output[j0] = dot_f32(&input[..], &weight[j0*{k}..(j0+1)*{k}], {k});"
388 )?;
389 writeln!(
390 code,
391 " output[j0+1] = dot_f32(&input[..], &weight[(j0+1)*{k}..(j0+2)*{k}], {k});"
392 )?;
393 writeln!(
394 code,
395 " output[j0+2] = dot_f32(&input[..], &weight[(j0+2)*{k}..(j0+3)*{k}], {k});"
396 )?;
397 writeln!(
398 code,
399 " output[j0+3] = dot_f32(&input[..], &weight[(j0+3)*{k}..(j0+4)*{k}], {k});"
400 )?;
401 writeln!(code, " }}")?;
402 }
403 if n_remainder > 0 {
404 writeln!(code, " // Handle remaining {n_remainder} output rows")?;
405 writeln!(code, " let base = {n_chunks} * 4;")?;
406 for r in 0..n_remainder {
407 writeln!(code, " output[base+{r}] = dot_f32(&input[..], &weight[(base+{r})*{k}..(base+{r}+1)*{k}], {k});")?;
408 }
409 }
410 writeln!(code, "}}")?;
411 writeln!(code)?;
412 }
413
414 Ok(())
415}
416
417fn emit_wasm_forward_function(
418 code: &mut String,
419 _graph: &Graph,
420 config: &ModelConfig,
421) -> Result<(), WasmCodegenError> {
422 let hidden = config.hidden_size;
423 let intermediate = config.intermediate_size;
424 let num_heads = config.num_attention_heads;
425 let num_kv_heads = config.num_kv_heads;
426 let head_dim = config.head_dim;
427 let vocab = config.vocab_size;
428 let qk_size = num_heads * head_dim;
429 let kv_size = num_kv_heads * head_dim;
430
431 writeln!(
433 code,
434 "/// Model weights — loaded once, passed to forward()."
435 )?;
436 writeln!(code, "pub struct Weights {{")?;
437 writeln!(
438 code,
439 " pub embed_tokens: Vec<f32>, // [{vocab} * {hidden}]"
440 )?;
441 writeln!(code, " pub layers: Vec<LayerWeights>,")?;
442 writeln!(code, " pub final_norm: Vec<f32>, // [{hidden}]")?;
443 writeln!(
444 code,
445 " pub lm_head: Vec<f32>, // [{vocab} * {hidden}]"
446 )?;
447 writeln!(code, "}}")?;
448 writeln!(code)?;
449
450 writeln!(code, "pub struct LayerWeights {{")?;
451 writeln!(code, " pub attn_norm: Vec<f32>, // [{hidden}]")?;
452 writeln!(
453 code,
454 " pub q_proj: Vec<f32>, // [{} * {hidden}]",
455 num_heads * head_dim
456 )?;
457 writeln!(
458 code,
459 " pub k_proj: Vec<f32>, // [{} * {hidden}]",
460 num_kv_heads * head_dim
461 )?;
462 writeln!(
463 code,
464 " pub v_proj: Vec<f32>, // [{} * {hidden}]",
465 num_kv_heads * head_dim
466 )?;
467 writeln!(
468 code,
469 " pub o_proj: Vec<f32>, // [{hidden} * {}]",
470 num_heads * head_dim
471 )?;
472 writeln!(code, " pub ffn_norm: Vec<f32>, // [{hidden}]")?;
473 writeln!(
474 code,
475 " pub gate_proj: Vec<f32>, // [{intermediate} * {hidden}]"
476 )?;
477 writeln!(
478 code,
479 " pub up_proj: Vec<f32>, // [{intermediate} * {hidden}]"
480 )?;
481 writeln!(
482 code,
483 " pub down_proj: Vec<f32>, // [{hidden} * {intermediate}]"
484 )?;
485 writeln!(code, "}}")?;
486 writeln!(code)?;
487
488 writeln!(code, "/// KV cache for autoregressive generation.")?;
490 writeln!(code, "pub struct KVCache {{")?;
491 writeln!(
492 code,
493 " pub k: Vec<Vec<f32>>, // [num_layers][MAX_SEQ_LEN * {kv_size}]"
494 )?;
495 writeln!(
496 code,
497 " pub v: Vec<Vec<f32>>, // [num_layers][MAX_SEQ_LEN * {kv_size}]"
498 )?;
499 writeln!(code, " pub len: usize,")?;
500 writeln!(code, "}}")?;
501 writeln!(code)?;
502
503 writeln!(code, "impl KVCache {{")?;
504 writeln!(code, " pub fn new() -> Self {{")?;
505 writeln!(code, " Self {{")?;
506 writeln!(
507 code,
508 " k: (0..NUM_LAYERS).map(|_| vec![0.0f32; MAX_SEQ_LEN * {kv_size}]).collect(),"
509 )?;
510 writeln!(
511 code,
512 " v: (0..NUM_LAYERS).map(|_| vec![0.0f32; MAX_SEQ_LEN * {kv_size}]).collect(),"
513 )?;
514 writeln!(code, " len: 0,")?;
515 writeln!(code, " }}")?;
516 writeln!(code, " }}")?;
517 writeln!(code)?;
518 writeln!(code, " pub fn reset(&mut self) {{")?;
519 writeln!(code, " self.len = 0;")?;
520 writeln!(code, " }}")?;
521 writeln!(code, "}}")?;
522 writeln!(code)?;
523 writeln!(code, "impl Default for KVCache {{")?;
524 writeln!(code, " fn default() -> Self {{ Self::new() }}")?;
525 writeln!(code, "}}")?;
526 writeln!(code)?;
527
528 writeln!(
530 code,
531 "/// Run forward pass for a single token. Returns logits [{vocab}]."
532 )?;
533 writeln!(
534 code,
535 "pub fn forward(token_id: u32, weights: &Weights, cache: &mut KVCache) -> Vec<f32> {{"
536 )?;
537 writeln!(code, " let pos = cache.len;")?;
538 writeln!(code)?;
539
540 writeln!(code, " // Embedding lookup")?;
541 writeln!(code, " let mut hidden_state = [0.0f32; HIDDEN_SIZE];")?;
542 writeln!(
543 code,
544 " embedding(&mut hidden_state, token_id, &weights.embed_tokens, HIDDEN_SIZE);"
545 )?;
546 writeln!(code)?;
547
548 writeln!(code, " // Fixed-size buffers")?;
549 writeln!(code, " let mut normed = [0.0f32; {hidden}];")?;
550 writeln!(code, " let mut q = [0.0f32; {qk_size}];")?;
551 writeln!(code, " let mut k = [0.0f32; {kv_size}];")?;
552 writeln!(code, " let mut v = [0.0f32; {kv_size}];")?;
553 writeln!(code, " let mut attn_out = [0.0f32; {qk_size}];")?;
554 writeln!(code, " let mut attn_proj = [0.0f32; {hidden}];")?;
555 writeln!(code, " let mut gate = [0.0f32; {intermediate}];")?;
556 writeln!(code, " let mut up = [0.0f32; {intermediate}];")?;
557 writeln!(code, " let mut ffn_hidden = [0.0f32; {intermediate}];")?;
558 writeln!(code, " let mut ffn_out = [0.0f32; {hidden}];")?;
559 writeln!(code)?;
560 writeln!(
561 code,
562 " let rope_freqs = rope_freqs(HEAD_DIM, ROPE_THETA);"
563 )?;
564 writeln!(code)?;
565
566 writeln!(code, " // Transformer layers")?;
567 writeln!(code, " for layer_idx in 0..NUM_LAYERS {{")?;
568 writeln!(code, " let lw = &weights.layers[layer_idx];")?;
569 writeln!(code)?;
570 writeln!(code, " // Attention norm")?;
571 writeln!(
572 code,
573 " rms_norm(&mut normed, &hidden_state, &lw.attn_norm, RMS_NORM_EPS);"
574 )?;
575 writeln!(code)?;
576 writeln!(code, " // QKV projections")?;
577 writeln!(
578 code,
579 " matmul_vec_{hidden}x{qk_size}(&mut q, &normed, &lw.q_proj);"
580 )?;
581 writeln!(
582 code,
583 " matmul_vec_{hidden}x{kv_size}(&mut k, &normed, &lw.k_proj);"
584 )?;
585 writeln!(
586 code,
587 " matmul_vec_{hidden}x{kv_size}(&mut v, &normed, &lw.v_proj);"
588 )?;
589 writeln!(code)?;
590 writeln!(code, " // RoPE")?;
591 writeln!(
592 code,
593 " rope(&mut q, pos, HEAD_DIM, NUM_HEADS, &rope_freqs);"
594 )?;
595 writeln!(
596 code,
597 " rope(&mut k, pos, HEAD_DIM, NUM_KV_HEADS, &rope_freqs);"
598 )?;
599 writeln!(code)?;
600 writeln!(code, " // Update KV cache")?;
601 writeln!(
602 code,
603 " cache.k[layer_idx][pos*{kv_size}..(pos+1)*{kv_size}].copy_from_slice(&k);"
604 )?;
605 writeln!(
606 code,
607 " cache.v[layer_idx][pos*{kv_size}..(pos+1)*{kv_size}].copy_from_slice(&v);"
608 )?;
609 writeln!(code)?;
610 writeln!(code, " // Attention")?;
611 writeln!(code, " attention(")?;
612 writeln!(code, " &mut attn_out, &q,")?;
613 writeln!(
614 code,
615 " &cache.k[layer_idx][..(pos+1)*{kv_size}], &cache.v[layer_idx][..(pos+1)*{kv_size}],"
616 )?;
617 writeln!(
618 code,
619 " pos + 1, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM,"
620 )?;
621 writeln!(code, " );")?;
622 writeln!(code)?;
623 writeln!(code, " // Output projection + residual")?;
624 writeln!(
625 code,
626 " matmul_vec_{qk_size}x{hidden}(&mut attn_proj, &attn_out, &lw.o_proj);"
627 )?;
628 writeln!(code, " residual_add(&mut hidden_state, &attn_proj);")?;
629 writeln!(code)?;
630 writeln!(code, " // FFN norm")?;
631 writeln!(
632 code,
633 " rms_norm(&mut normed, &hidden_state, &lw.ffn_norm, RMS_NORM_EPS);"
634 )?;
635 writeln!(code)?;
636 writeln!(code, " // FFN: fused silu_mul")?;
637 writeln!(
638 code,
639 " matmul_vec_{hidden}x{intermediate}(&mut gate, &normed, &lw.gate_proj);"
640 )?;
641 writeln!(
642 code,
643 " matmul_vec_{hidden}x{intermediate}(&mut up, &normed, &lw.up_proj);"
644 )?;
645 writeln!(code, " silu_mul(&mut ffn_hidden, &gate, &up);")?;
646 writeln!(
647 code,
648 " matmul_vec_{intermediate}x{hidden}(&mut ffn_out, &ffn_hidden, &lw.down_proj);"
649 )?;
650 writeln!(code)?;
651 writeln!(code, " residual_add(&mut hidden_state, &ffn_out);")?;
652 writeln!(code, " }}")?;
653 writeln!(code)?;
654
655 writeln!(code, " // Final norm")?;
656 writeln!(
657 code,
658 " rms_norm(&mut normed, &hidden_state, &weights.final_norm, RMS_NORM_EPS);"
659 )?;
660 writeln!(code)?;
661
662 writeln!(code, " // Logits projection")?;
663 writeln!(code, " let mut logits = vec![0.0f32; VOCAB_SIZE];")?;
664 writeln!(code, " for j in 0..VOCAB_SIZE {{")?;
665 writeln!(
666 code,
667 " logits[j] = dot_f32(&normed[..], &weights.lm_head[j*{hidden}..(j+1)*{hidden}], {hidden});"
668 )?;
669 writeln!(code, " }}")?;
670 writeln!(code)?;
671 writeln!(code, " cache.len += 1;")?;
672 writeln!(code, " logits")?;
673 writeln!(code, "}}")?;
674 writeln!(code)?;
675
676 Ok(())
677}
678
679fn emit_wasm_bindgen_exports(
680 code: &mut String,
681 config: &ModelConfig,
682) -> Result<(), WasmCodegenError> {
683 let hidden = config.hidden_size;
684 let num_layers = config.num_layers;
685 let num_heads = config.num_attention_heads;
686 let num_kv_heads = config.num_kv_heads;
687 let head_dim = config.head_dim;
688 let vocab = config.vocab_size;
689 let intermediate = config.intermediate_size;
690 let qk_size = num_heads * head_dim;
691 let kv_size = num_kv_heads * head_dim;
692
693 writeln!(
694 code,
695 "/// Initialize panic hook for better error messages in browser console."
696 )?;
697 writeln!(code, "#[wasm_bindgen]")?;
698 writeln!(code, "pub fn init_panic_hook() {{")?;
699 writeln!(code, " console_error_panic_hook::set_once();")?;
700 writeln!(code, "}}")?;
701 writeln!(code)?;
702
703 writeln!(
704 code,
705 "/// WASM-exported model handle. Holds weights + KV cache."
706 )?;
707 writeln!(code, "#[wasm_bindgen]")?;
708 writeln!(code, "pub struct WasmModel {{")?;
709 writeln!(code, " weights: Weights,")?;
710 writeln!(code, " cache: KVCache,")?;
711 writeln!(code, "}}")?;
712 writeln!(code)?;
713
714 writeln!(code, "#[wasm_bindgen]")?;
715 writeln!(code, "impl WasmModel {{")?;
716
717 let embed_elems = vocab * hidden;
719 let final_norm_elems = hidden;
720 let lm_head_elems = vocab * hidden;
721 let attn_norm_elems = hidden;
722 let q_proj_elems = qk_size * hidden;
723 let k_proj_elems = kv_size * hidden;
724 let v_proj_elems = kv_size * hidden;
725 let o_proj_elems = hidden * qk_size;
726 let ffn_norm_elems = hidden;
727 let gate_proj_elems = intermediate * hidden;
728 let up_proj_elems = intermediate * hidden;
729 let down_proj_elems = hidden * intermediate;
730
731 let layer_elems = attn_norm_elems
732 + q_proj_elems
733 + k_proj_elems
734 + v_proj_elems
735 + o_proj_elems
736 + ffn_norm_elems
737 + gate_proj_elems
738 + up_proj_elems
739 + down_proj_elems;
740
741 writeln!(code, " /// Load model from raw f32 weight bytes.")?;
742 writeln!(
743 code,
744 " /// Expected layout: embed_tokens | layer0 | layer1 | ... | final_norm | lm_head"
745 )?;
746 writeln!(code, " #[wasm_bindgen(constructor)]")?;
747 writeln!(code, " pub fn new(weights_bytes: &[u8]) -> WasmModel {{")?;
748 writeln!(code, " init_panic_hook();")?;
749 writeln!(code, " // Parse f32 weight bytes")?;
750 writeln!(code, " let n = weights_bytes.len() / 4;")?;
751 writeln!(code, " let mut raw = vec![0.0f32; n];")?;
752 writeln!(code, " for i in 0..n {{")?;
753 writeln!(
754 code,
755 " raw[i] = f32::from_le_bytes([weights_bytes[i*4], weights_bytes[i*4+1], weights_bytes[i*4+2], weights_bytes[i*4+3]]);"
756 )?;
757 writeln!(code, " }}")?;
758 writeln!(code, " let mut off = 0usize;")?;
759 writeln!(
760 code,
761 " let embed_tokens = raw[off..off+{embed_elems}].to_vec(); off += {embed_elems};"
762 )?;
763 writeln!(
764 code,
765 " let mut layers = Vec::with_capacity({num_layers});"
766 )?;
767 writeln!(code, " for _ in 0..{num_layers} {{")?;
768 writeln!(
769 code,
770 " let attn_norm = raw[off..off+{attn_norm_elems}].to_vec(); off += {attn_norm_elems};"
771 )?;
772 writeln!(
773 code,
774 " let q_proj = raw[off..off+{q_proj_elems}].to_vec(); off += {q_proj_elems};"
775 )?;
776 writeln!(
777 code,
778 " let k_proj = raw[off..off+{k_proj_elems}].to_vec(); off += {k_proj_elems};"
779 )?;
780 writeln!(
781 code,
782 " let v_proj = raw[off..off+{v_proj_elems}].to_vec(); off += {v_proj_elems};"
783 )?;
784 writeln!(
785 code,
786 " let o_proj = raw[off..off+{o_proj_elems}].to_vec(); off += {o_proj_elems};"
787 )?;
788 writeln!(
789 code,
790 " let ffn_norm = raw[off..off+{ffn_norm_elems}].to_vec(); off += {ffn_norm_elems};"
791 )?;
792 writeln!(
793 code,
794 " let gate_proj = raw[off..off+{gate_proj_elems}].to_vec(); off += {gate_proj_elems};"
795 )?;
796 writeln!(
797 code,
798 " let up_proj = raw[off..off+{up_proj_elems}].to_vec(); off += {up_proj_elems};"
799 )?;
800 writeln!(
801 code,
802 " let down_proj = raw[off..off+{down_proj_elems}].to_vec(); off += {down_proj_elems};"
803 )?;
804 writeln!(code, " layers.push(LayerWeights {{ attn_norm, q_proj, k_proj, v_proj, o_proj, ffn_norm, gate_proj, up_proj, down_proj }});")?;
805 writeln!(code, " }}")?;
806 writeln!(
807 code,
808 " let final_norm = raw[off..off+{final_norm_elems}].to_vec(); off += {final_norm_elems};"
809 )?;
810 writeln!(
811 code,
812 " let lm_head = raw[off..off+{lm_head_elems}].to_vec();"
813 )?;
814 writeln!(
815 code,
816 " let _ = ({layer_elems}, {embed_elems}, {lm_head_elems}, {final_norm_elems}); // suppress unused warnings"
817 )?;
818 writeln!(
819 code,
820 " let weights = Weights {{ embed_tokens, layers, final_norm, lm_head }};"
821 )?;
822 writeln!(code, " let cache = KVCache::new();")?;
823 writeln!(code, " WasmModel {{ weights, cache }}")?;
824 writeln!(code, " }}")?;
825 writeln!(code)?;
826
827 writeln!(
828 code,
829 " /// Run a single forward step. Returns logit for most-likely next token."
830 )?;
831 writeln!(
832 code,
833 " pub fn forward(&mut self, token_id: u32) -> u32 {{"
834 )?;
835 writeln!(
836 code,
837 " let logits = forward(token_id, &self.weights, &mut self.cache);"
838 )?;
839 writeln!(code, " // Argmax sampling")?;
840 writeln!(code, " let mut best = 0usize;")?;
841 writeln!(code, " let mut best_val = f32::NEG_INFINITY;")?;
842 writeln!(code, " for (i, &v) in logits.iter().enumerate() {{")?;
843 writeln!(
844 code,
845 " if v > best_val {{ best_val = v; best = i; }}"
846 )?;
847 writeln!(code, " }}")?;
848 writeln!(code, " best as u32")?;
849 writeln!(code, " }}")?;
850 writeln!(code)?;
851
852 writeln!(code, " /// Reset the KV cache (start a new generation).")?;
853 writeln!(code, " pub fn reset_cache(&mut self) {{")?;
854 writeln!(code, " self.cache.reset();")?;
855 writeln!(code, " }}")?;
856 writeln!(code, "}}")?;
857
858 Ok(())
859}
860
#[cfg(test)]
mod tests {
    use super::*;
    use forgellm_frontend::{graph_builder, ir::ModelConfig};

    /// A deliberately small Llama-style configuration so generation is fast.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 64,
            intermediate_size: 128,
            num_layers: 2,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 16,
            vocab_size: 256,
            max_seq_len: 64,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F16,
            sliding_window_size: None,
            qkv_bias: false,
            hidden_activation: HiddenActivation::SiLU,
        }
    }

    /// Generates the project for `tiny_config()` into a fresh temp directory.
    ///
    /// The returned guard must be kept alive for as long as the files are read.
    fn generate_tiny_project() -> tempfile::TempDir {
        let config = tiny_config();
        let graph = graph_builder::build_graph(&config).unwrap();
        let dir = tempfile::tempdir().unwrap();
        generate_wasm_project(&graph, dir.path(), "test-model").unwrap();
        dir
    }

    /// Reads one generated file, given its path relative to the project root.
    fn read_generated(dir: &tempfile::TempDir, relative: &str) -> String {
        std::fs::read_to_string(dir.path().join(relative)).unwrap()
    }

    #[test]
    fn generate_wasm_project_creates_all_files() {
        let dir = generate_tiny_project();

        for relative in ["Cargo.toml", "src/lib.rs", "pkg/model.js"] {
            assert!(dir.path().join(relative).exists());
        }
    }

    #[test]
    fn generated_lib_rs_contains_wasm_bindgen() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");

        assert!(lib_rs.contains("use wasm_bindgen::prelude::*;"));
    }

    #[test]
    fn generated_lib_rs_contains_wasm_model() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");

        assert!(lib_rs.contains("pub struct WasmModel"));
    }

    #[test]
    fn generated_lib_rs_contains_dot_f32_kernel() {
        let dir = generate_tiny_project();
        let lib_rs = read_generated(&dir, "src/lib.rs");

        assert!(lib_rs.contains("fn dot_f32("));
        assert!(lib_rs.contains("simd128"));
    }

    #[test]
    fn generated_cargo_toml_has_cdylib() {
        let dir = generate_tiny_project();
        let cargo_toml = read_generated(&dir, "Cargo.toml");

        assert!(cargo_toml.contains("cdylib"));
        assert!(cargo_toml.contains("wasm-bindgen"));
    }

    #[test]
    fn generate_placeholder() {
        // A freshly created graph is empty.
        let graph = Graph::new("test");
        assert_eq!(graph.len(), 0);
    }
}