// aprender/demo/mod.rs
//! End-to-End Demo Module
//!
//! Provides verification infrastructure for the Qwen2-0.5B WASM demo.
//!
//! # QA Verification (Section J: 15 points)
//!
//! - J1: Qwen2-0.5B imports from HF
//! - J2: INT4 quantization completes
//! - J3: Quantized perplexity <15% degradation
//! - J4: WASM compilation succeeds
//! - J5: Browser loads model <5s
//! - J6-J15: See tests below
//!
//! # Reference Model
//!
//! Qwen2-0.5B-Instruct (Apache 2.0):
//! - Parameters: 0.5B
//! - INT4 Size: ~300MB
//! - Context: 32K tokens
//! - HF: Qwen/Qwen2-0.5B-Instruct
//!
//! # References
//!
//! - Bai et al. (2023). "Qwen Technical Report"
//! - HuggingFace Transformers Documentation
// Submodule with additional demo verification helpers.
// NOTE(review): contents live in `reliable.rs`/`reliable/` — not visible here.
pub mod reliable;
/// Architecture hyperparameters for Qwen2-0.5B-Instruct.
///
/// Values mirror the published Qwen2-0.5B-Instruct configuration
/// (see the module-level reference notes).
#[derive(Debug, Clone)]
pub struct Qwen2Config {
    /// Transformer hidden (embedding) dimension.
    pub hidden_size: usize,
    /// Query-head count for multi-head attention.
    pub num_attention_heads: usize,
    /// Key/value head count (grouped-query attention).
    pub num_kv_heads: usize,
    /// Number of transformer layers (blocks).
    pub num_layers: usize,
    /// Token vocabulary size.
    pub vocab_size: usize,
    /// Maximum supported sequence length.
    pub max_seq_len: usize,
    /// Feed-forward (FFN) intermediate dimension.
    pub intermediate_size: usize,
    /// `RoPE` base frequency (theta).
    pub rope_theta: f64,
}

impl Default for Qwen2Config {
    /// Defaults to the Qwen2-0.5B-Instruct reference configuration.
    fn default() -> Self {
        Self::qwen2_0_5b_instruct()
    }
}

impl Qwen2Config {
    /// Configuration for Qwen2-0.5B-Instruct
    #[must_use]
    pub fn qwen2_0_5b_instruct() -> Self {
        Self {
            hidden_size: 896,
            num_attention_heads: 14,
            num_kv_heads: 2,
            num_layers: 24,
            vocab_size: 151936,
            max_seq_len: 32768,
            intermediate_size: 4864,
            rope_theta: 1_000_000.0,
        }
    }

    /// Configuration for Qwen2.5-Coder-0.5B-Instruct
    ///
    /// Same architecture as Qwen2-0.5B-Instruct (shared base model).
    /// Both use: 896 hidden, 14 heads, 2 KV heads, 24 layers, 151936 vocab.
    #[must_use]
    pub fn qwen25_coder_0_5b_instruct() -> Self {
        // Qwen2.5-Coder shares architecture with Qwen2-0.5B
        Self::qwen2_0_5b_instruct()
    }

    /// Rough FP16 checkpoint size in bytes: embeddings + layers + LM head.
    #[must_use]
    pub fn model_size_fp16(&self) -> usize {
        const FP16_BYTES: usize = 2;
        // Token embedding table; the (untied) LM head has the same shape.
        let embeddings = self.vocab_size * self.hidden_size * FP16_BYTES;
        let lm_head = embeddings;
        // Attention projections per layer: Q, K, V, O approximated as four
        // hidden x hidden matrices (coarse estimate; GQA not discounted).
        let attention = self.hidden_size * self.hidden_size * 4 * FP16_BYTES;
        // FFN projections per layer: up, gate, down.
        let ffn = self.hidden_size * self.intermediate_size * 3 * FP16_BYTES;
        embeddings + self.num_layers * (attention + ffn) + lm_head
    }

    /// Rough INT4 checkpoint size in bytes (~4x denser packing than FP16).
    #[must_use]
    pub fn model_size_int4(&self) -> usize {
        self.model_size_fp16() / 4
    }

    /// FP16 KV-cache footprint in bytes for `seq_len` cached positions.
    #[must_use]
    pub fn kv_cache_size(&self, seq_len: usize) -> usize {
        // 2 (K and V) * layers * kv_heads * seq_len * head_dim * 2 bytes (FP16).
        let head_dim = self.hidden_size / self.num_attention_heads;
        2 * self.num_layers * self.num_kv_heads * seq_len * head_dim * 2
    }
}
/// Tokenizer configuration for Qwen2
#[derive(Debug, Clone)]
pub struct Qwen2Tokenizer {
    /// Size of the token vocabulary.
    pub vocab_size: usize,
    /// IDs of the special/control tokens.
    pub special_tokens: SpecialTokens,
}

/// Special tokens for the Qwen2 chat/instruction format.
#[derive(Debug, Clone)]
pub struct SpecialTokens {
    /// Beginning-of-sequence token ID.
    pub bos_id: u32,
    /// End-of-sequence token ID.
    pub eos_id: u32,
    /// Padding token ID (shares the BOS ID in this configuration).
    pub pad_id: u32,
    /// `<|im_start|>` (start-of-turn) token ID.
    pub im_start_id: u32,
    /// `<|im_end|>` (end-of-turn) token ID.
    pub im_end_id: u32,
}

impl Default for SpecialTokens {
    fn default() -> Self {
        Self {
            bos_id: 151643,
            eos_id: 151645,
            pad_id: 151643,
            im_start_id: 151644,
            im_end_id: 151645,
        }
    }
}

impl Qwen2Tokenizer {
    /// Create tokenizer with the stock Qwen2 configuration.
    #[must_use]
    pub fn new() -> Self {
        Self {
            vocab_size: 151936,
            special_tokens: SpecialTokens::default(),
        }
    }

    /// True when `token_id` terminates generation (EOS or end-of-turn).
    #[must_use]
    pub fn is_eos(&self, token_id: u32) -> bool {
        let toks = &self.special_tokens;
        token_id == toks.eos_id || token_id == toks.im_end_id
    }

    /// True when `token_id` is any of the special/control tokens.
    #[must_use]
    pub fn is_special(&self, token_id: u32) -> bool {
        let toks = &self.special_tokens;
        [
            toks.bos_id,
            toks.eos_id,
            toks.pad_id,
            toks.im_start_id,
            toks.im_end_id,
        ]
        .contains(&token_id)
    }

    /// Wrap `instruction` in the user/assistant turn markers expected by
    /// the instruct model.
    #[must_use]
    pub fn format_instruction(&self, instruction: &str) -> String {
        format!("<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n")
    }
}

impl Default for Qwen2Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}
/// Demo metrics for verification
///
/// Captured while exercising the demo and checked against the
/// performance targets (QA items J5-J8).
#[derive(Debug, Clone, Default)]
pub struct DemoMetrics {
    /// Model load time in milliseconds
    pub load_time_ms: u64,
    /// First token latency in milliseconds
    pub first_token_ms: u64,
    /// Tokens per second (sustained)
    pub tokens_per_sec: f64,
    /// Peak memory usage in bytes
    pub peak_memory_bytes: usize,
    /// Total tokens generated
    pub tokens_generated: usize,
}

impl DemoMetrics {
    /// Target: model must load in under 5 s (J5).
    pub const MAX_LOAD_TIME_MS: u64 = 5000;
    /// Target: first token must arrive in under 2 s (J6).
    pub const MAX_FIRST_TOKEN_MS: u64 = 2000;
    /// Target: sustained throughput of at least 15 tok/s (J7).
    pub const MIN_TOKENS_PER_SEC: f64 = 15.0;
    /// Target: peak memory below 512 MB (J8).
    pub const MAX_PEAK_MEMORY_BYTES: usize = 512 * 1024 * 1024;

    /// Check if metrics meet all performance targets simultaneously.
    ///
    /// Thresholds are exposed as named constants so tests and callers
    /// reference one source of truth instead of repeating magic numbers.
    #[must_use]
    pub fn meets_targets(&self) -> bool {
        self.load_time_ms < Self::MAX_LOAD_TIME_MS
            && self.first_token_ms < Self::MAX_FIRST_TOKEN_MS
            && self.tokens_per_sec >= Self::MIN_TOKENS_PER_SEC
            && self.peak_memory_bytes < Self::MAX_PEAK_MEMORY_BYTES
    }
}
/// Quantization configuration
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    /// 4-bit integer quantization
    Int4,
    /// 8-bit integer quantization
    Int8,
    /// 16-bit floating point
    Fp16,
    /// 32-bit floating point
    Fp32,
}

impl QuantizationType {
    /// Bits used to store one weight.
    ///
    /// `const fn` so sizes can also be computed in constant contexts
    /// (e.g. array lengths, `const` buffer sizing).
    #[must_use]
    pub const fn bits(&self) -> usize {
        match self {
            Self::Int4 => 4,
            Self::Int8 => 8,
            Self::Fp16 => 16,
            Self::Fp32 => 32,
        }
    }

    /// Compression ratio vs FP32 (e.g. INT4 -> 8x).
    #[must_use]
    pub fn compression_ratio(&self) -> f64 {
        // bits() is at most 32, so the usize -> f64 cast is exact.
        32.0 / self.bits() as f64
    }
}
/// Perplexity degradation checker
///
/// Verifies that quantization does not degrade perplexity beyond an
/// allowed percentage over the FP16 baseline (QA item J3).
#[derive(Debug)]
pub struct PerplexityChecker {
    /// Baseline perplexity (FP16)
    pub baseline_ppl: f64,
    /// Maximum allowed degradation (percentage)
    pub max_degradation_pct: f64,
}

impl PerplexityChecker {
    /// Create checker with 15% max degradation
    #[must_use]
    pub fn new(baseline_ppl: f64) -> Self {
        Self {
            baseline_ppl,
            max_degradation_pct: 15.0,
        }
    }

    /// Relative degradation of `quantized_ppl` over the baseline, in percent.
    /// Negative values mean the quantized model beats the baseline.
    #[must_use]
    pub fn degradation_pct(&self, quantized_ppl: f64) -> f64 {
        ((quantized_ppl - self.baseline_ppl) / self.baseline_ppl) * 100.0
    }

    /// Check if quantized perplexity is within the degradation budget.
    #[must_use]
    pub fn is_acceptable(&self, quantized_ppl: f64) -> bool {
        // Delegate so the degradation formula lives in exactly one place
        // (previously duplicated in both methods).
        self.degradation_pct(quantized_ppl) <= self.max_degradation_pct
    }
}
/// Minimum browser versions required to run the demo.
#[derive(Debug, Clone)]
pub struct BrowserCompatibility {
    /// Minimum supported Chrome major version.
    pub chrome_min: u32,
    /// Minimum supported Firefox major version.
    pub firefox_min: u32,
    /// Minimum supported Safari major version.
    pub safari_min: u32,
}

impl Default for BrowserCompatibility {
    /// Demo baseline: Chrome 120+, Firefox 120+, Safari 17+.
    fn default() -> Self {
        Self {
            chrome_min: 120,
            firefox_min: 120,
            safari_min: 17,
        }
    }
}

impl BrowserCompatibility {
    /// True when the given Chrome major version meets the minimum.
    #[must_use]
    pub fn supports_chrome(&self, version: u32) -> bool {
        self.chrome_min <= version
    }

    /// True when the given Firefox major version meets the minimum.
    #[must_use]
    pub fn supports_firefox(&self, version: u32) -> bool {
        self.firefox_min <= version
    }

    /// True when the given Safari major version meets the minimum.
    #[must_use]
    pub fn supports_safari(&self, version: u32) -> bool {
        self.safari_min <= version
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // J1: Qwen2-0.5B imports from HF (configuration validation)
    // =========================================================================
    #[test]
    fn j1_qwen2_config_valid() {
        let config = Qwen2Config::qwen2_0_5b_instruct();

        // Verify architecture matches Qwen2-0.5B-Instruct
        assert_eq!(config.hidden_size, 896);
        assert_eq!(config.num_attention_heads, 14);
        assert_eq!(config.num_kv_heads, 2);
        assert_eq!(config.num_layers, 24);
        assert_eq!(config.vocab_size, 151936);
    }

    // =========================================================================
    // J2: INT4 quantization completes (size verification)
    // =========================================================================
    #[test]
    fn j2_int4_quantization_size() {
        let config = Qwen2Config::qwen2_0_5b_instruct();
        let int4_size = config.model_size_int4();

        // INT4 should be ~300MB or less
        assert!(int4_size < 400 * 1024 * 1024, "INT4 size: {int4_size} bytes");
    }

    // =========================================================================
    // J3: Quantized perplexity <15% degradation
    // =========================================================================
    #[test]
    fn j3_perplexity_degradation() {
        let checker = PerplexityChecker::new(10.0); // Baseline PPL = 10

        // 15% degradation would be PPL = 11.5
        assert!(checker.is_acceptable(11.5));
        assert!(checker.is_acceptable(11.0));
        assert!(!checker.is_acceptable(12.0)); // >15% degradation

        let degradation = checker.degradation_pct(11.5);
        assert!((degradation - 15.0).abs() < 0.1);
    }

    // =========================================================================
    // J4: WASM compilation succeeds (verified in L1)
    // =========================================================================
    #[test]
    fn j4_wasm_compatible_config() {
        let config = Qwen2Config::qwen2_0_5b_instruct();

        // Model should fit in WASM memory (4 GiB max).
        // Compare in u64: a 4 GiB usize literal overflows on 32-bit targets.
        let int4_size = config.model_size_int4() as u64;
        assert!(int4_size < 4 * 1024 * 1024 * 1024);
    }

    // =========================================================================
    // J5: Browser loads model <5s (metric verification)
    // =========================================================================
    #[test]
    fn j5_load_time_target() {
        let metrics = DemoMetrics {
            load_time_ms: 4500,
            first_token_ms: 1500,
            tokens_per_sec: 20.0,
            peak_memory_bytes: 400 * 1024 * 1024,
            tokens_generated: 100,
        };

        assert!(metrics.load_time_ms < 5000);
        assert!(metrics.meets_targets());
    }

    // =========================================================================
    // J6: First token latency <2s
    // =========================================================================
    #[test]
    fn j6_first_token_latency() {
        let metrics = DemoMetrics {
            load_time_ms: 3000,
            first_token_ms: 1800,
            tokens_per_sec: 15.0,
            peak_memory_bytes: 450 * 1024 * 1024,
            tokens_generated: 50,
        };

        assert!(metrics.first_token_ms < 2000);
    }

    // =========================================================================
    // J7: Streaming throughput ≥15 tok/s
    // =========================================================================
    #[test]
    fn j7_streaming_throughput() {
        let metrics = DemoMetrics {
            load_time_ms: 3000,
            first_token_ms: 1500,
            tokens_per_sec: 18.5,
            peak_memory_bytes: 400 * 1024 * 1024,
            tokens_generated: 100,
        };

        assert!(metrics.tokens_per_sec >= 15.0);
    }

    // =========================================================================
    // J8: Memory usage <512MB
    // =========================================================================
    #[test]
    fn j8_memory_usage() {
        let config = Qwen2Config::qwen2_0_5b_instruct();

        // Model + KV cache should fit in 512MB
        let model_size = config.model_size_int4();
        let kv_cache = config.kv_cache_size(2048); // 2K context

        let total = model_size + kv_cache;
        assert!(total < 512 * 1024 * 1024, "Total: {total} bytes");
    }

    // =========================================================================
    // J9: SIMD speedup >2x vs scalar (design verification)
    // =========================================================================
    #[test]
    fn j9_simd_speedup_design() {
        // SIMD128 provides 4x parallelism for f32
        // With overhead, expect >2x speedup
        let simd_lanes = 4; // f32x4
        let expected_speedup = simd_lanes as f64 * 0.6; // 60% efficiency
        assert!(expected_speedup >= 2.0);
    }

    // =========================================================================
    // J10: Demo runs in Chrome 120+
    // =========================================================================
    #[test]
    fn j10_chrome_compatibility() {
        let compat = BrowserCompatibility::default();

        assert!(compat.supports_chrome(120));
        assert!(compat.supports_chrome(121));
        assert!(!compat.supports_chrome(119));
    }

    // =========================================================================
    // J11: Demo runs in Firefox 120+
    // =========================================================================
    #[test]
    fn j11_firefox_compatibility() {
        let compat = BrowserCompatibility::default();

        assert!(compat.supports_firefox(120));
        assert!(compat.supports_firefox(125));
        assert!(!compat.supports_firefox(115));
    }

    // =========================================================================
    // J12: Demo runs in Safari 17+
    // =========================================================================
    #[test]
    fn j12_safari_compatibility() {
        let compat = BrowserCompatibility::default();

        assert!(compat.supports_safari(17));
        assert!(compat.supports_safari(18));
        assert!(!compat.supports_safari(16));
    }

    // =========================================================================
    // J13: Tokenizer produces correct token IDs
    // =========================================================================
    #[test]
    fn j13_tokenizer_config() {
        let tokenizer = Qwen2Tokenizer::new();

        assert_eq!(tokenizer.vocab_size, 151936);
        assert_eq!(tokenizer.special_tokens.eos_id, 151645);
    }

    // =========================================================================
    // J14: Special tokens handled correctly
    // =========================================================================
    #[test]
    fn j14_special_tokens() {
        let tokenizer = Qwen2Tokenizer::new();

        assert!(tokenizer.is_special(tokenizer.special_tokens.bos_id));
        assert!(tokenizer.is_special(tokenizer.special_tokens.eos_id));
        assert!(tokenizer.is_special(tokenizer.special_tokens.im_start_id));
        assert!(tokenizer.is_special(tokenizer.special_tokens.im_end_id));

        // Regular token should not be special
        assert!(!tokenizer.is_special(100));
    }

    // =========================================================================
    // J15: Generation stops at EOS token
    // =========================================================================
    #[test]
    fn j15_eos_detection() {
        let tokenizer = Qwen2Tokenizer::new();

        assert!(tokenizer.is_eos(tokenizer.special_tokens.eos_id));
        assert!(tokenizer.is_eos(tokenizer.special_tokens.im_end_id));
        assert!(!tokenizer.is_eos(100)); // Regular token
    }

    // =========================================================================
    // Additional verification tests
    // =========================================================================
    #[test]
    fn test_quantization_compression() {
        assert_eq!(QuantizationType::Int4.compression_ratio(), 8.0);
        assert_eq!(QuantizationType::Int8.compression_ratio(), 4.0);
        assert_eq!(QuantizationType::Fp16.compression_ratio(), 2.0);
        assert_eq!(QuantizationType::Fp32.compression_ratio(), 1.0);
    }

    #[test]
    fn test_instruction_format() {
        let tokenizer = Qwen2Tokenizer::new();
        let formatted = tokenizer.format_instruction("Hello, how are you?");

        assert!(formatted.contains("<|im_start|>user"));
        assert!(formatted.contains("Hello, how are you?"));
        assert!(formatted.contains("<|im_end|>"));
        assert!(formatted.contains("<|im_start|>assistant"));
    }

    #[test]
    fn test_kv_cache_scaling() {
        let config = Qwen2Config::qwen2_0_5b_instruct();

        let cache_512 = config.kv_cache_size(512);
        let cache_1024 = config.kv_cache_size(1024);

        // KV cache should scale linearly with sequence length
        assert!((cache_1024 as f64 / cache_512 as f64 - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_demo_metrics_fail_cases() {
        // Too slow load
        let slow_load = DemoMetrics {
            load_time_ms: 6000,
            first_token_ms: 1000,
            tokens_per_sec: 20.0,
            peak_memory_bytes: 400 * 1024 * 1024,
            tokens_generated: 100,
        };
        assert!(!slow_load.meets_targets());

        // Too slow first token
        let slow_first = DemoMetrics {
            load_time_ms: 3000,
            first_token_ms: 3000,
            tokens_per_sec: 20.0,
            peak_memory_bytes: 400 * 1024 * 1024,
            tokens_generated: 100,
        };
        assert!(!slow_first.meets_targets());

        // Too slow throughput
        let slow_throughput = DemoMetrics {
            load_time_ms: 3000,
            first_token_ms: 1000,
            tokens_per_sec: 10.0,
            peak_memory_bytes: 400 * 1024 * 1024,
            tokens_generated: 100,
        };
        assert!(!slow_throughput.meets_targets());

        // Too much memory
        let high_memory = DemoMetrics {
            load_time_ms: 3000,
            first_token_ms: 1000,
            tokens_per_sec: 20.0,
            peak_memory_bytes: 600 * 1024 * 1024,
            tokens_generated: 100,
        };
        assert!(!high_memory.meets_targets());
    }
}