Skip to main content

oxillama_runtime/
engine.rs

1//! Main inference engine — orchestrates model loading and text generation.
2
3use std::path::Path;
4
/// Sequence-length threshold for routing attention through the
/// memory-efficient tiled flash-attention kernel instead of the naïve
/// full-score-matrix path.
///
/// For long sequences the O(N²) memory cost of materialising the full
/// attention matrix becomes the bottleneck; the tiled kernel keeps the
/// working set at O(BQ × BK) per tile instead.
///
/// The actual dispatch — including whether the comparison against this value
/// is inclusive — lives inside `oxillama-arch`'s `ForwardPass::forward`
/// implementations.  This constant is exported so that arch crates and
/// callers can apply the same policy without hard-coding the threshold.
pub const FLASH_ATTN_THRESHOLD: usize = 512;
17
18use oxillama_arch::config::ModelConfig;
19use oxillama_arch::traits::{ForwardPass, KvCacheAccess};
20use oxillama_gguf::GgufModel;
21
22use crate::embedding::{pool_hidden_states, PoolingMode};
23use crate::error::{RuntimeError, RuntimeResult};
24use crate::kv_cache::{KvCache, KvCacheSnapshot};
25use crate::metrics::{EngineMetrics, MetricsSnapshot};
26use crate::offload::{LayerPager, OffloadPolicy};
27use crate::sampling::{Sampler, SamplerConfig};
28use crate::tokenizer_bridge::TokenizerBridge;
29use std::sync::Arc;
30use std::time::Instant;
31
32/// Configuration for the inference engine.
33#[derive(Debug, Clone)]
34pub struct EngineConfig {
35    /// Path to the GGUF model file.
36    pub model_path: String,
37    /// Path to the tokenizer JSON file (if not embedded in GGUF).
38    pub tokenizer_path: Option<String>,
39    /// Context size override (None = use model default).
40    pub context_size: Option<usize>,
41    /// Number of threads for parallel computation.
42    pub num_threads: usize,
43    /// Sampling configuration.
44    pub sampler: SamplerConfig,
45    /// Prefill chunk size: how many prompt tokens to process per forward call.
46    ///
47    /// Set to 0 or `usize::MAX` to process the entire prompt in one batch.
48    /// Smaller values reduce peak memory usage for long prompts at the cost
49    /// of slightly higher overhead from multiple forward calls.
50    /// Default: 512.
51    pub prefill_chunk_size: usize,
52
53    /// CPU/disk offload policy.
54    ///
55    /// Controls which model weights are kept resident in RAM and which are
56    /// evicted to disk and reloaded on demand.
57    ///
58    /// Default: [`OffloadPolicy::None`] — all weights remain in RAM, matching
59    /// classic llama.cpp behaviour.
60    pub offload_policy: OffloadPolicy,
61}
62
63impl Default for EngineConfig {
64    fn default() -> Self {
65        Self {
66            model_path: String::new(),
67            tokenizer_path: None,
68            context_size: None,
69            num_threads: 4,
70            sampler: SamplerConfig::default(),
71            prefill_chunk_size: 512,
72            offload_policy: OffloadPolicy::None,
73        }
74    }
75}
76
77impl EngineConfig {
78    /// Set the CPU/disk offload policy, consuming self and returning the
79    /// updated config (builder pattern).
80    pub fn with_offload(mut self, policy: OffloadPolicy) -> Self {
81        self.offload_policy = policy;
82        self
83    }
84}
85
86/// The main inference engine.
87///
88/// Manages model loading, forward pass execution, and token generation.
89/// The full pipeline: load GGUF → parse metadata → build architecture → generate.
90pub struct InferenceEngine {
91    config: EngineConfig,
92    /// Loaded GGUF model (None until load_model is called).
93    gguf_model: Option<GgufModel>,
94    /// Parsed model configuration from GGUF metadata.
95    model_config: Option<ModelConfig>,
96    /// Forward pass implementation (architecture-specific).
97    forward_pass: Option<Box<dyn ForwardPass>>,
98    /// Key-value cache.
99    kv_cache: Option<KvCache>,
100    /// Tokenizer bridge.
101    tokenizer: Option<TokenizerBridge>,
102    /// EOS token ID for stopping generation.
103    eos_token_id: Option<u32>,
104    /// Live metrics counters.
105    metrics: Arc<EngineMetrics>,
106    /// Stack of active LoRA adapters (in insertion order).
107    lora_stack: oxillama_arch::LoraStack,
108    /// Optional CPU/disk layer pager (None when offload_policy is None).
109    ///
110    /// When present, linear-layer forward passes can call
111    /// `layer_pager.acquire(&tensor_id)` to get (or load from disk) the
112    /// raw quantized bytes for a given weight tensor.  This is the
113    /// graceful-fallback path: if `layer_pager` is `None`, existing in-RAM
114    /// weight references are used unchanged.
115    ///
116    /// Full integration with the arch-layer forward kernels (wiring through
117    /// `acquire` at each GEMM site) requires changes in `oxillama-arch` and
118    /// is deferred to a follow-up subtask (R1-arch integration).
119    layer_pager: Option<Arc<LayerPager>>,
120}
121
122impl InferenceEngine {
123    /// Create a new inference engine with the given configuration.
124    pub fn new(config: EngineConfig) -> Self {
125        Self {
126            config,
127            gguf_model: None,
128            model_config: None,
129            forward_pass: None,
130            kv_cache: None,
131            tokenizer: None,
132            eos_token_id: None,
133            metrics: EngineMetrics::new(),
134            lora_stack: oxillama_arch::LoraStack::new(),
135            layer_pager: None,
136        }
137    }
138
139    /// Return a reference to the active layer pager, if offloading is enabled.
140    ///
141    /// This is the inspection / integration hook that arch-layer code (or
142    /// higher-level callers) can use to acquire tensors on demand.  When
143    /// the pager is `None`, the engine is running in the default fully-in-RAM
144    /// mode.
145    pub fn layer_pager(&self) -> Option<&Arc<LayerPager>> {
146        self.layer_pager.as_ref()
147    }
148
149    /// Attach a pre-built [`LayerPager`] to this engine.
150    ///
151    /// This is the integration point for callers that construct their own
152    /// pager (e.g. from a custom [`PagerSource`][crate::offload::PagerSource])
153    /// and want to inject it rather than relying on the engine to build one
154    /// automatically from the GGUF file.
155    pub fn set_layer_pager(&mut self, pager: Arc<LayerPager>) {
156        self.layer_pager = Some(pager);
157    }
158
159    /// Load the model from an in-memory GGUF byte buffer.
160    ///
161    /// This is the preferred entry point for environments that cannot access
162    /// the filesystem, such as `wasm32-unknown-unknown`.  The tokenizer must be
163    /// provided separately as a JSON string because GGUF metadata rarely
164    /// contains the full HuggingFace `tokenizer.json`.
165    ///
166    /// The loading pipeline is identical to `load_model` except:
167    /// - The GGUF data comes from the supplied `model_bytes` slice (copied into
168    ///   owned storage inside [`GgufModel::from_bytes`]).
169    /// - The tokenizer is loaded from `tokenizer_json` rather than a file path.
170    ///
171    /// Any `context_size` override from [`EngineConfig`] is still applied.
172    pub fn load_model_from_bytes(
173        &mut self,
174        model_bytes: &[u8],
175        tokenizer_json: &str,
176    ) -> RuntimeResult<()> {
177        // ── Step 1: Parse GGUF from owned bytes ──────────────────────────────
178        let gguf = GgufModel::from_bytes(model_bytes.to_vec())?;
179        tracing::info!(
180            arch = gguf.architecture().unwrap_or("unknown"),
181            tensors = gguf.file.header.tensor_count,
182            "GGUF file parsed from bytes"
183        );
184
185        // ── Step 2: Extract model configuration ──────────────────────────────
186        let mut model_config = ModelConfig::from_metadata(&gguf.file.metadata)?;
187        if let Some(ctx) = self.config.context_size {
188            model_config.max_context_length = ctx;
189        }
190
191        tracing::info!(
192            arch = %model_config.architecture,
193            layers = model_config.num_layers,
194            hidden = model_config.hidden_size,
195            heads = model_config.num_attention_heads,
196            kv_heads = model_config.num_kv_heads,
197            vocab = model_config.vocab_size,
198            ctx = model_config.max_context_length,
199            "model config loaded from bytes"
200        );
201
202        // ── Step 3: Build forward pass ────────────────────────────────────────
203        let forward_pass = build_forward_pass(&gguf, &model_config)?;
204
205        // ── Step 4: KV cache ──────────────────────────────────────────────────
206        let kv_dim = model_config.num_kv_heads * model_config.head_dim;
207        let kv_cache = KvCache::new(
208            model_config.num_layers,
209            model_config.max_context_length,
210            kv_dim,
211        );
212        tracing::info!(
213            layers = model_config.num_layers,
214            max_ctx = model_config.max_context_length,
215            kv_dim = kv_dim,
216            "KV cache initialized (from-bytes path)"
217        );
218
219        // ── Step 5: Tokenizer from JSON string ────────────────────────────────
220        let tokenizer = TokenizerBridge::from_bytes(tokenizer_json.as_bytes())?;
221        let eos_token_id = tokenizer.eos_token_id();
222        tracing::info!(
223            vocab_size = tokenizer.vocab_size(),
224            eos = ?eos_token_id,
225            "tokenizer loaded from JSON string"
226        );
227
228        self.model_config = Some(model_config);
229        self.forward_pass = Some(forward_pass);
230        self.kv_cache = Some(kv_cache);
231        self.tokenizer = Some(tokenizer);
232        self.eos_token_id = eos_token_id;
233        self.gguf_model = Some(gguf);
234
235        Ok(())
236    }
237
238    /// Load the model from the configured path.
239    ///
240    /// This performs the full loading pipeline:
241    /// 1. Parse GGUF file (header, metadata, tensor info)
242    /// 2. Extract model configuration from metadata
243    /// 3. Build the architecture-specific forward pass
244    /// 4. Initialize KV cache
245    /// 5. Load tokenizer
246    pub fn load_model(&mut self) -> RuntimeResult<()> {
247        let path = Path::new(&self.config.model_path);
248        if !path.exists() {
249            return Err(RuntimeError::ModelLoadError {
250                message: format!("model file not found: {}", self.config.model_path),
251            });
252        }
253
254        tracing::info!(path = %self.config.model_path, "loading GGUF model");
255
256        // Step 1: Load and parse GGUF
257        let gguf = GgufModel::load(&self.config.model_path)?;
258        tracing::info!(
259            arch = gguf.architecture().unwrap_or("unknown"),
260            tensors = gguf.file.header.tensor_count,
261            "GGUF file parsed"
262        );
263
264        // Step 2: Extract model config from metadata
265        let mut model_config = ModelConfig::from_metadata(&gguf.file.metadata)?;
266
267        // Apply context size override
268        if let Some(ctx) = self.config.context_size {
269            model_config.max_context_length = ctx;
270        }
271
272        tracing::info!(
273            arch = %model_config.architecture,
274            layers = model_config.num_layers,
275            hidden = model_config.hidden_size,
276            heads = model_config.num_attention_heads,
277            kv_heads = model_config.num_kv_heads,
278            vocab = model_config.vocab_size,
279            ctx = model_config.max_context_length,
280            "model config loaded"
281        );
282
283        // Step 3: Build architecture-specific forward pass
284        let forward_pass = build_forward_pass(&gguf, &model_config)?;
285
286        // Step 4: Initialize KV cache
287        let kv_dim = model_config.num_kv_heads * model_config.head_dim;
288        let kv_cache = KvCache::new(
289            model_config.num_layers,
290            model_config.max_context_length,
291            kv_dim,
292        );
293        tracing::info!(
294            layers = model_config.num_layers,
295            max_ctx = model_config.max_context_length,
296            kv_dim = kv_dim,
297            "KV cache initialized"
298        );
299
300        // Step 5: Load tokenizer
301        let tokenizer = load_tokenizer(&self.config, &gguf)?;
302        let eos_token_id = tokenizer.eos_token_id();
303        tracing::info!(
304            vocab_size = tokenizer.vocab_size(),
305            eos = ?eos_token_id,
306            "tokenizer loaded"
307        );
308
309        self.model_config = Some(model_config);
310        self.forward_pass = Some(forward_pass);
311        self.kv_cache = Some(kv_cache);
312        self.tokenizer = Some(tokenizer);
313        self.eos_token_id = eos_token_id;
314        self.gguf_model = Some(gguf);
315
316        Ok(())
317    }
318
319    /// Generate tokens from a prompt.
320    ///
321    /// Runs the full generation pipeline:
322    /// 1. Tokenize the prompt
323    /// 2. Prefill: process all prompt tokens through the model
324    /// 3. Decode: autoregressive generation until EOS or max_tokens
325    ///
326    /// The callback is invoked with each decoded token's text as it's generated.
327    pub fn generate(
328        &mut self,
329        prompt: &str,
330        max_tokens: usize,
331        mut callback: impl FnMut(&str),
332    ) -> RuntimeResult<String> {
333        let tokenizer = self
334            .tokenizer
335            .as_ref()
336            .ok_or(RuntimeError::ModelNotLoaded)?;
337        let forward_pass = self
338            .forward_pass
339            .as_mut()
340            .ok_or(RuntimeError::ModelNotLoaded)?;
341        let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
342
343        // Step 1: Tokenize prompt
344        let prompt_tokens = tokenizer.encode(prompt)?;
345        if prompt_tokens.is_empty() {
346            return Ok(String::new());
347        }
348
349        tracing::debug!(n_tokens = prompt_tokens.len(), "prompt tokenized");
350
351        // Track recent tokens for repetition penalty
352        let mut recent_tokens = prompt_tokens.clone();
353        let mut generated_tokens: Vec<u32> = Vec::new();
354        let mut output_text = String::new();
355
356        // Step 2: Batch prefill — process all prompt tokens through the model.
357        //
358        // Instead of processing tokens one-at-a-time (N separate forward calls
359        // each computing and discarding full logits), we batch them into chunks.
360        // The architecture's forward() handles multi-token input: it iterates
361        // internally and only computes logits for the last hidden state.
362        //
363        // For very long prompts, we chunk into `prefill_chunk_size` pieces to
364        // bound peak memory usage in the attention computation.
365        let chunk_size = if self.config.prefill_chunk_size == 0 {
366            prompt_tokens.len()
367        } else {
368            self.config.prefill_chunk_size
369        };
370
371        let mut logits = if prompt_tokens.len() <= chunk_size {
372            // Short prompt: single batch forward
373            tracing::debug!(
374                chunk = 1,
375                tokens = prompt_tokens.len(),
376                "prefill: single batch"
377            );
378            let prefill_start = Instant::now();
379            let result = forward_pass.forward(&prompt_tokens, kv_cache)?;
380            self.metrics
381                .record_prefill(prompt_tokens.len() as u64, prefill_start.elapsed());
382            result
383        } else {
384            // Long prompt: chunked prefill
385            let n_chunks = prompt_tokens.len().div_ceil(chunk_size);
386            tracing::debug!(
387                n_chunks = n_chunks,
388                chunk_size = chunk_size,
389                total = prompt_tokens.len(),
390                "prefill: chunked"
391            );
392
393            let prefill_start = Instant::now();
394            let mut last_logits = Vec::new();
395            for (i, chunk) in prompt_tokens.chunks(chunk_size).enumerate() {
396                tracing::trace!(
397                    chunk_idx = i,
398                    chunk_len = chunk.len(),
399                    kv_pos = kv_cache.seq_len(),
400                    "prefill chunk"
401                );
402                last_logits = forward_pass.forward(chunk, kv_cache)?;
403            }
404            self.metrics
405                .record_prefill(prompt_tokens.len() as u64, prefill_start.elapsed());
406            last_logits
407        };
408
409        // Create a stateful sampler so grammar state is maintained across tokens.
410        let mut sampler = Sampler::new(self.config.sampler.clone());
411
412        // Step 3: Autoregressive decode loop
413        self.metrics.record_request_start();
414        for _step in 0..max_tokens {
415            // Sample next token (grammar masking happens inside the sampler)
416            let next_token = sampler.sample(&logits, &recent_tokens);
417
418            // Check for EOS
419            if Some(next_token) == self.eos_token_id {
420                tracing::debug!("EOS token generated, stopping");
421                break;
422            }
423
424            // Check context length
425            if kv_cache.seq_len() >= forward_pass.max_context_length() {
426                tracing::warn!("context length reached, stopping generation");
427                break;
428            }
429
430            // Decode token to text
431            let token_text = tokenizer.decode(&[next_token])?;
432            callback(&token_text);
433            output_text.push_str(&token_text);
434
435            // Track for repetition penalty
436            recent_tokens.push(next_token);
437            generated_tokens.push(next_token);
438
439            // Forward pass for next token
440            let decode_start = Instant::now();
441            logits = forward_pass.forward(&[next_token], kv_cache)?;
442            self.metrics.record_decode_token(decode_start.elapsed());
443        }
444        self.metrics.record_request_complete();
445
446        tracing::info!(
447            prompt_tokens = prompt_tokens.len(),
448            generated_tokens = generated_tokens.len(),
449            "generation complete"
450        );
451
452        Ok(output_text)
453    }
454
455    /// Generate tokens using an explicit sampler config instead of the engine default.
456    ///
457    /// This is the preferred entry point for per-request sampler customization
458    /// (e.g., grammar-constrained sampling from the API server).
459    pub fn generate_with_config(
460        &mut self,
461        prompt: &str,
462        max_tokens: usize,
463        sampler_config: SamplerConfig,
464        mut callback: impl FnMut(&str),
465    ) -> RuntimeResult<String> {
466        let tokenizer = self
467            .tokenizer
468            .as_ref()
469            .ok_or(RuntimeError::ModelNotLoaded)?;
470        let forward_pass = self
471            .forward_pass
472            .as_mut()
473            .ok_or(RuntimeError::ModelNotLoaded)?;
474        let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
475
476        let prompt_tokens = tokenizer.encode(prompt)?;
477        if prompt_tokens.is_empty() {
478            return Ok(String::new());
479        }
480
481        let mut recent_tokens = prompt_tokens.clone();
482        let mut generated_tokens: Vec<u32> = Vec::new();
483        let mut output_text = String::new();
484
485        for &token in &prompt_tokens[..prompt_tokens.len() - 1] {
486            forward_pass.forward(&[token], kv_cache)?;
487        }
488
489        let last = *prompt_tokens.last().ok_or(RuntimeError::ModelNotLoaded)?;
490        let mut logits = forward_pass.forward(&[last], kv_cache)?;
491
492        let mut sampler = Sampler::new(sampler_config);
493        self.metrics.record_request_start();
494        for _step in 0..max_tokens {
495            let next_token = sampler.sample(&logits, &recent_tokens);
496
497            if Some(next_token) == self.eos_token_id {
498                tracing::debug!("EOS token generated, stopping");
499                break;
500            }
501
502            if kv_cache.seq_len() >= forward_pass.max_context_length() {
503                tracing::warn!("context length reached, stopping generation");
504                break;
505            }
506
507            let token_text = tokenizer.decode(&[next_token])?;
508            callback(&token_text);
509            output_text.push_str(&token_text);
510
511            recent_tokens.push(next_token);
512            generated_tokens.push(next_token);
513
514            let decode_start = Instant::now();
515            logits = forward_pass.forward(&[next_token], kv_cache)?;
516            self.metrics.record_decode_token(decode_start.elapsed());
517        }
518        self.metrics.record_request_complete();
519
520        tracing::info!(
521            prompt_tokens = prompt_tokens.len(),
522            generated_tokens = generated_tokens.len(),
523            "generation (with custom config) complete"
524        );
525
526        Ok(output_text)
527    }
528
529    /// Build the vocabulary byte table, used for grammar-constrained sampling.
530    ///
531    /// Returns `None` if no tokenizer is loaded.
532    pub fn vocab_bytes(&self) -> Option<Vec<(u32, Vec<u8>)>> {
533        self.tokenizer.as_ref().map(|t| t.vocab_bytes())
534    }
535
536    /// Apply a loaded LoRA adapter to the model's linear layers.
537    ///
538    /// Delegates to the architecture-specific [`ForwardPass::apply_lora`]
539    /// implementation, which walks the model's layers and attaches
540    /// [`LoraAdapter`](oxillama_quant::LoraAdapter) instances to each
541    /// matching `QuantLinear` field.
542    ///
543    /// # Errors
544    ///
545    /// Returns [`RuntimeError::ModelNotLoaded`] if no model has been loaded.
546    pub fn apply_lora_adapters(
547        &mut self,
548        lora: &oxillama_arch::lora::LoadedLora,
549    ) -> RuntimeResult<()> {
550        let fp = self
551            .forward_pass
552            .as_mut()
553            .ok_or(RuntimeError::ModelNotLoaded)?;
554        fp.apply_lora(lora).map_err(RuntimeError::Arch)?;
555        Ok(())
556    }
557
558    // -------------------------------------------------------------------------
559    // Multi-LoRA hot-swap
560    // -------------------------------------------------------------------------
561
562    /// Push a LoRA adapter onto the stack with a per-entry scale multiplier.
563    ///
564    /// The adapter is applied additively during inference:
565    /// `output += scale · (alpha/rank) · B @ A @ input`
566    pub fn push_lora(&mut self, lora: std::sync::Arc<oxillama_arch::lora::LoadedLora>, scale: f32) {
567        self.lora_stack.push(lora, scale);
568    }
569
570    /// Remove the last adapter pushed onto the stack.
571    ///
572    /// Returns `None` if the stack is empty.
573    pub fn pop_lora(&mut self) -> Option<(std::sync::Arc<oxillama_arch::lora::LoadedLora>, f32)> {
574        self.lora_stack.pop()
575    }
576
577    /// Remove all LoRA adapters from the stack.
578    pub fn clear_loras(&mut self) {
579        self.lora_stack.clear();
580    }
581
582    /// Inspect the current LoRA stack.
583    pub fn lora_stack(&self) -> &oxillama_arch::LoraStack {
584        &self.lora_stack
585    }
586
587    /// Apply the stacked LoRA adapters to the loaded model's linear layers.
588    ///
589    /// This is a hot-swap operation: it can be called at any time without
590    /// reloading the model.  If the stack is empty this is a no-op.
591    ///
592    /// Returns [`RuntimeError::ModelNotLoaded`] if no model has been loaded.
593    pub fn apply_lora_stack(&mut self) -> RuntimeResult<()> {
594        if self.lora_stack.is_empty() {
595            return Ok(());
596        }
597        let fp = self
598            .forward_pass
599            .as_mut()
600            .ok_or(RuntimeError::ModelNotLoaded)?;
601        fp.apply_lora_stack(&self.lora_stack)
602            .map_err(RuntimeError::Arch)?;
603        Ok(())
604    }
605
606    /// Remove all LoRA adapters from the loaded model's linear layers.
607    ///
608    /// Clears the `lora_stack` and calls `unapply_all_loras` on the forward
609    /// pass so every `QuantLinear.lora` field is set back to `None`.
610    ///
611    /// This is the necessary counterpart to `apply_lora_stack` for per-request
612    /// LoRA hot-swap: push adapters, apply, generate, then unapply.
613    ///
614    /// Does nothing when no model is loaded.
615    pub fn unapply_all_loras(&mut self) {
616        self.lora_stack.clear();
617        if let Some(fp) = self.forward_pass.as_mut() {
618            fp.unapply_all_loras();
619        }
620    }
621
622    /// Restore the KV cache from a cached prefix snapshot and run prefill for
623    /// the suffix tokens that follow the cached prefix.
624    ///
625    /// This is the prefix-KV-cache fast path: instead of re-prefilling the
626    /// entire prompt from scratch, the engine restores the KV state for the
627    /// longest matching cached prefix and only runs the forward pass for the
628    /// remaining suffix tokens.
629    ///
630    /// Restriction: the cached prefix must start at position 0 (i.e. it was
631    /// stored from the beginning of a sequence).  This matches how
632    /// `PrefixKvCache::store` snapshots KV state.
633    ///
634    /// # Errors
635    ///
636    /// Returns [`RuntimeError::ModelNotLoaded`] if no model is loaded.
637    pub fn prime_with_prefix(
638        &mut self,
639        cached: &crate::kv_cache::prefix::CachedKvState,
640        restore_to: usize,
641        suffix_tokens: &[u32],
642    ) -> RuntimeResult<Vec<f32>> {
643        if suffix_tokens.is_empty() {
644            return Err(RuntimeError::ModelLoadError {
645                message: "prime_with_prefix: suffix_tokens must contain at least one token"
646                    .to_string(),
647            });
648        }
649        // Restore the KV cache to the requested number of prefix positions
650        // (which may be less than cached.seq_len() to handle the edge case
651        // where the caller wants to re-process the last cached token).
652        {
653            let kv = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
654            kv.restore_from_snapshot(cached.keys(), cached.values(), restore_to);
655        }
656        // Process suffix tokens from the restored position and return logits
657        // of the last suffix token (used to seed the autoregressive decode loop).
658        let forward_pass = self
659            .forward_pass
660            .as_mut()
661            .ok_or(RuntimeError::ModelNotLoaded)?;
662        let kv = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
663        let logits = forward_pass
664            .forward(suffix_tokens, kv)
665            .map_err(RuntimeError::Arch)?;
666        Ok(logits)
667    }
668
669    /// Run the autoregressive decode loop starting from pre-computed logits.
670    ///
671    /// Unlike [`Self::generate_with_config`], this does **not** run prefill — the
672    /// caller must have already primed the KV cache (via [`Self::prime_with_prefix`]
673    /// or a full prefill) and obtained the initial logits from that step.
674    ///
675    /// # Errors
676    ///
677    /// Returns [`RuntimeError::ModelNotLoaded`] if no model is loaded.
678    pub fn generate_with_logits(
679        &mut self,
680        prompt_tokens: &[u32],
681        initial_logits: Vec<f32>,
682        max_tokens: usize,
683        sampler_config: SamplerConfig,
684        mut callback: impl FnMut(&str),
685    ) -> RuntimeResult<String> {
686        let tokenizer = self
687            .tokenizer
688            .as_ref()
689            .ok_or(RuntimeError::ModelNotLoaded)?;
690        let forward_pass = self
691            .forward_pass
692            .as_mut()
693            .ok_or(RuntimeError::ModelNotLoaded)?;
694        let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
695        let max_ctx = forward_pass.max_context_length();
696        let eos_token_id = self.eos_token_id;
697
698        let mut recent_tokens: Vec<u32> = prompt_tokens.to_vec();
699        let mut output_text = String::new();
700        let mut logits = initial_logits;
701
702        let mut sampler = Sampler::new(sampler_config);
703        self.metrics.record_request_start();
704
705        for _step in 0..max_tokens {
706            let next_token = sampler.sample(&logits, &recent_tokens);
707
708            if Some(next_token) == eos_token_id {
709                tracing::debug!("EOS token generated, stopping (primed path)");
710                break;
711            }
712
713            if kv_cache.seq_len() >= max_ctx {
714                tracing::warn!("context length reached, stopping generation (primed path)");
715                break;
716            }
717
718            let token_text = tokenizer.decode(&[next_token])?;
719            callback(&token_text);
720            output_text.push_str(&token_text);
721            recent_tokens.push(next_token);
722
723            let decode_start = Instant::now();
724            logits = forward_pass
725                .forward(&[next_token], kv_cache)
726                .map_err(RuntimeError::Arch)?;
727            self.metrics.record_decode_token(decode_start.elapsed());
728        }
729
730        self.metrics.record_request_complete();
731        Ok(output_text)
732    }
733
734    /// Returns whether a model is currently loaded.
735    pub fn is_loaded(&self) -> bool {
736        self.forward_pass.is_some()
737    }
738
739    /// Returns the engine configuration.
740    pub fn config(&self) -> &EngineConfig {
741        &self.config
742    }
743
744    /// Returns the model configuration, if loaded.
745    pub fn model_config(&self) -> Option<&ModelConfig> {
746        self.model_config.as_ref()
747    }
748
749    /// Returns a shared reference to the KV cache, if a model is loaded.
750    pub(crate) fn kv_cache_ref(&self) -> Option<&KvCache> {
751        self.kv_cache.as_ref()
752    }
753
754    /// Returns a mutable reference to the KV cache, if a model is loaded.
755    pub(crate) fn kv_cache_mut(&mut self) -> Option<&mut KvCache> {
756        self.kv_cache.as_mut()
757    }
758
759    /// Store the current KV cache state into a `PrefixKvCache` under `tokens`.
760    ///
761    /// This is the public integration point for server-side prefix caching:
762    /// after a successful generation pass the worker calls this to persist the
763    /// KV state so future requests sharing the same prefix can skip prefill.
764    ///
765    /// If no model is loaded (KV cache absent) the call is a silent no-op.
766    pub fn store_kv_in_prefix_cache(
767        &mut self,
768        tokens: &[u32],
769        prefix_cache: &mut crate::kv_cache::prefix::PrefixKvCache,
770    ) {
771        if let Some(kv) = self.kv_cache.as_mut() {
772            let seq_len = kv.seq_len();
773            let kv_dim = kv.kv_dim();
774            let num_layers = kv.num_layers();
775            prefix_cache.store(tokens, kv, seq_len, kv_dim, num_layers);
776        }
777    }
778
779    /// Reset the KV cache (for starting a new conversation).
780    pub fn reset(&mut self) {
781        if let Some(ref mut cache) = self.kv_cache {
782            cache.clear();
783        }
784    }
785
786    // -------------------------------------------------------------------------
787    // Speculative-decoding primitives
788    // -------------------------------------------------------------------------
789
790    /// Tokenize text and return token IDs.
791    ///
792    /// Requires that a model (and thus a tokenizer) has been loaded.
793    pub fn tokenize(&self, text: &str) -> RuntimeResult<Vec<u32>> {
794        let tokenizer = self
795            .tokenizer
796            .as_ref()
797            .ok_or(RuntimeError::ModelNotLoaded)?;
798        tokenizer.encode(text)
799    }
800
801    /// Prefill the KV cache with the given token sequence without returning logits.
802    ///
803    /// Processes all tokens in order, updating the KV cache at each position.
804    /// The last token's logits are discarded; callers typically follow up with
805    /// `forward_one` to begin autoregressive generation.
806    pub fn prefill(&mut self, tokens: &[u32]) -> RuntimeResult<()> {
807        if tokens.is_empty() {
808            return Ok(());
809        }
810        let forward_pass = self
811            .forward_pass
812            .as_mut()
813            .ok_or(RuntimeError::ModelNotLoaded)?;
814        let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
815        for &token in tokens {
816            forward_pass.forward(&[token], kv_cache)?;
817        }
818        Ok(())
819    }
820
821    /// Run a batched prefill forward pass for the given chunk of tokens.
822    ///
823    /// This is the per-chunk entry point for the chunked-prefill scheduler
824    /// fairness path (A3).  It differs from `prefill` in two ways:
825    ///
826    /// 1. It accepts a multi-token slice and dispatches a *single* batched
827    ///    forward call, matching the `generate` path's chunked prefill logic.
828    /// 2. It returns the logits of the last token in the chunk so that the
829    ///    caller can immediately begin decode sampling if `pos_end` equals the
830    ///    full prompt length.
831    ///
832    /// `pos_start` is the KV-cache position at which this chunk begins.  It
833    /// must equal the current `kv_cache.seq_len()` on entry; the parameter is
834    /// provided explicitly so that callers (e.g. the scheduler) can assert the
835    /// invariant in debug builds.
836    ///
837    /// # Errors
838    ///
839    /// Returns [`RuntimeError::ModelNotLoaded`] if no model is loaded, or
840    /// any arch-level error from the forward pass.
841    pub fn forward_prefill(&mut self, tokens: &[u32], pos_start: usize) -> RuntimeResult<Vec<f32>> {
842        if tokens.is_empty() {
843            return Err(RuntimeError::ModelLoadError {
844                message: "forward_prefill called with empty token slice".to_string(),
845            });
846        }
847        let forward_pass = self
848            .forward_pass
849            .as_mut()
850            .ok_or(RuntimeError::ModelNotLoaded)?;
851        let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
852
853        debug_assert_eq!(
854            kv_cache.seq_len(),
855            pos_start,
856            "forward_prefill: pos_start ({pos_start}) must equal kv_cache.seq_len() ({})",
857            kv_cache.seq_len(),
858        );
859
860        let logits = forward_pass.forward(tokens, kv_cache)?;
861        Ok(logits)
862    }
863
864    /// Run a single autoregressive decode step for `token` and return logits.
865    ///
866    /// This is the per-step entry point for the chunked-prefill scheduler
867    /// fairness path (A3).  It is semantically equivalent to `forward_one`
868    /// but named differently to make the prefill/decode distinction explicit
869    /// in call sites inside the engine and scheduler integration layer.
870    ///
871    /// `pos` is the current sequence position (= `kv_cache.seq_len()`).  It
872    /// is accepted as a parameter so that callers can assert the invariant.
873    ///
874    /// # Errors
875    ///
876    /// Returns [`RuntimeError::ModelNotLoaded`] if no model is loaded.
877    pub fn forward_decode(&mut self, token: u32, pos: usize) -> RuntimeResult<Vec<f32>> {
878        let forward_pass = self
879            .forward_pass
880            .as_mut()
881            .ok_or(RuntimeError::ModelNotLoaded)?;
882        let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
883
884        debug_assert_eq!(
885            kv_cache.seq_len(),
886            pos,
887            "forward_decode: pos ({pos}) must equal kv_cache.seq_len() ({})",
888            kv_cache.seq_len(),
889        );
890
891        let logits = forward_pass.forward(&[token], kv_cache)?;
892        Ok(logits)
893    }
894
895    /// Run a single forward pass for `token` and return raw logits.
896    ///
897    /// The KV cache is updated (one position advanced).
898    pub fn forward_one(&mut self, token: u32) -> RuntimeResult<Vec<f32>> {
899        let forward_pass = self
900            .forward_pass
901            .as_mut()
902            .ok_or(RuntimeError::ModelNotLoaded)?;
903        let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
904        let logits = forward_pass.forward(&[token], kv_cache)?;
905        Ok(logits)
906    }
907
908    /// Returns `true` if `token` is the EOS token for this model.
909    pub fn is_eos(&self, token: u32) -> bool {
910        self.eos_token_id == Some(token)
911    }
912
913    /// Decode a single token ID to its string representation.
914    pub fn decode_token(&self, token: u32) -> RuntimeResult<String> {
915        let tokenizer = self
916            .tokenizer
917            .as_ref()
918            .ok_or(RuntimeError::ModelNotLoaded)?;
919        tokenizer.decode(&[token])
920    }
921
922    /// Returns a shared reference to the engine's live metrics counters.
923    pub fn metrics(&self) -> Arc<EngineMetrics> {
924        Arc::clone(&self.metrics)
925    }
926
927    /// Returns a point-in-time [`MetricsSnapshot`] of the engine's counters.
928    pub fn metrics_snapshot(&self) -> MetricsSnapshot {
929        self.metrics.snapshot()
930    }
931
932    /// Capture a [`KvCacheSnapshot`] from the current KV cache state.
933    ///
934    /// Returns `None` if no model (and thus no KV cache) is loaded.
935    pub fn kv_snapshot(&self) -> Option<KvCacheSnapshot> {
936        self.kv_cache.as_ref().map(|c| c.snapshot())
937    }
938
939    /// Restore the KV cache state from a previously captured [`KvCacheSnapshot`].
940    ///
941    /// Returns [`RuntimeError::ModelNotLoaded`] if no model is loaded.
942    pub fn kv_restore(&mut self, snapshot: &KvCacheSnapshot) -> RuntimeResult<()> {
943        let kv = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
944        kv.restore_from_snapshot(&snapshot.keys, &snapshot.values, snapshot.seq_len);
945        Ok(())
946    }
947
948    /// Truncate the KV cache to `n` tokens.
949    ///
950    /// After this call the engine behaves as if only `n` tokens have been
951    /// processed.  This is the low-level primitive used by speculative
952    /// decoding on divergence rollback.
953    ///
954    /// # Errors
955    ///
956    /// Returns [`RuntimeError::ModelNotLoaded`] if no model is loaded.
957    pub fn truncate(&mut self, n: usize) -> RuntimeResult<()> {
958        let kv = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
959        kv.truncate(n);
960        Ok(())
961    }
962
963    /// Return the current KV cache sequence length.
964    ///
965    /// Returns 0 if no model is loaded.
966    pub fn kv_cache_seq_len(&self) -> usize {
967        self.kv_cache.as_ref().map(|c| c.seq_len()).unwrap_or(0)
968    }
969
970    /// Returns the model's hidden state dimension, if a model is loaded.
971    pub fn hidden_size(&self) -> Option<usize> {
972        self.model_config.as_ref().map(|c| c.hidden_size)
973    }
974
975    /// Compute a semantic embedding vector for the given text using `PoolingMode::Last`.
976    ///
977    /// This is a convenience wrapper around [`Self::embed_with`]. Runs tokenization →
978    /// full transformer layers → final RMSNorm, then L2-normalises the resulting
979    /// `hidden_size`-dimensional vector. The KV cache is reset before the pass
980    /// so that embeddings for different inputs are independent of each other.
981    ///
982    /// Returns `RuntimeError::ModelNotLoaded` if no model has been loaded.
983    pub fn embed(&mut self, text: &str) -> RuntimeResult<Vec<f32>> {
984        self.embed_with(text, PoolingMode::Last)
985    }
986
987    /// Compute a semantic embedding vector for the given text using the specified
988    /// pooling strategy.
989    ///
990    /// Runs tokenization → full transformer layers → final RMSNorm → pooling,
991    /// then L2-normalises the resulting `hidden_size`-dimensional vector.
992    /// The KV cache is reset before the pass so that embeddings for different
993    /// inputs are independent of each other.
994    ///
995    /// # Pooling modes
996    ///
997    /// * [`PoolingMode::Last`] — last token hidden state (causal / decoder models).
998    /// * [`PoolingMode::Mean`] — mean across all token positions.
999    /// * [`PoolingMode::Max`]  — elementwise max across all token positions.
1000    /// * [`PoolingMode::Cls`]  — first token hidden state (BERT / encoder models).
1001    ///
1002    /// Returns `RuntimeError::ModelNotLoaded` if no model has been loaded.
1003    pub fn embed_with(&mut self, text: &str, mode: PoolingMode) -> RuntimeResult<Vec<f32>> {
1004        // Step 1: Reset the KV cache so this embedding is independent.
1005        // Must happen before we take any partial borrows below.
1006        self.reset();
1007
1008        // Step 2: Validate that all components are loaded.
1009        let forward_pass = self
1010            .forward_pass
1011            .as_mut()
1012            .ok_or(RuntimeError::ModelNotLoaded)?;
1013        let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
1014
1015        // Step 3: Tokenize. We need the tokenizer reference independently,
1016        // so read it before we borrow forward_pass/kv_cache mutably.
1017        let tokens = {
1018            let tok = self
1019                .tokenizer
1020                .as_ref()
1021                .ok_or(RuntimeError::ModelNotLoaded)?;
1022            tok.encode(text)?
1023        };
1024
1025        if tokens.is_empty() {
1026            // Return a zero vector of the appropriate dimension if available,
1027            // otherwise an empty vector. An empty input has no well-defined embedding.
1028            let dim = self
1029                .model_config
1030                .as_ref()
1031                .map(|c| c.hidden_size)
1032                .unwrap_or(0);
1033            return Ok(vec![0.0f32; dim]);
1034        }
1035
1036        // Step 4: Run the embed forward pass (all layers + output_norm, no LM head).
1037        // The forward pass returns the hidden states for all token positions as a
1038        // flat [seq_len × hidden_size] vector (or just hidden_size for single-token
1039        // models). We retrieve all-states where available; otherwise fall back to
1040        // the last-state path for backward compatibility.
1041        let seq_len = tokens.len();
1042        let hidden_size = forward_pass.hidden_size();
1043
1044        let all_hidden = forward_pass.embed_all(&tokens, kv_cache);
1045        let hidden = match all_hidden {
1046            Ok(states) if states.len() == seq_len * hidden_size && seq_len > 0 => {
1047                // embed_all is supported and returned a correctly-shaped matrix.
1048                pool_hidden_states(&states, seq_len, hidden_size, mode)?
1049            }
1050            _ => {
1051                // Fall back to embed() (last-token path).
1052                forward_pass.embed(&tokens, kv_cache)?
1053            }
1054        };
1055
1056        // Step 5: L2-normalise the hidden vector for cosine-similarity compatibility.
1057        let norm: f32 = hidden.iter().map(|x| x * x).sum::<f32>().sqrt();
1058        if norm > 1e-9 {
1059            Ok(hidden.into_iter().map(|x| x / norm).collect())
1060        } else {
1061            Ok(hidden)
1062        }
1063    }
1064
1065    /// Extract embedding vectors for multiple input texts using `PoolingMode::Last`.
1066    ///
1067    /// This is a convenience wrapper around [`Self::embed_batch_with`].
1068    /// Each text is processed independently with a fresh KV cache.
1069    pub fn embed_batch(&mut self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
1070        let str_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
1071        self.embed_batch_with(&str_refs, PoolingMode::Last)
1072    }
1073
1074    /// Extract embedding vectors for multiple input texts using the specified
1075    /// pooling strategy.
1076    ///
1077    /// Each text is processed independently with a fresh KV cache.
1078    /// The output order matches the input order.
1079    ///
1080    /// Returns `RuntimeError::ModelNotLoaded` if no model has been loaded.
1081    pub fn embed_batch_with(
1082        &mut self,
1083        texts: &[&str],
1084        mode: PoolingMode,
1085    ) -> RuntimeResult<Vec<Vec<f32>>> {
1086        // Validate that the engine is loaded by borrowing forward_pass and
1087        // tokenizer — this also gives us the hidden_size for the zero-vector path.
1088        {
1089            let _fp = self
1090                .forward_pass
1091                .as_ref()
1092                .ok_or(RuntimeError::ModelNotLoaded)?;
1093            let _tok = self
1094                .tokenizer
1095                .as_ref()
1096                .ok_or(RuntimeError::ModelNotLoaded)?;
1097        }
1098
1099        let mut embeddings = Vec::with_capacity(texts.len());
1100        for &text in texts {
1101            embeddings.push(self.embed_with(text, mode)?);
1102        }
1103        Ok(embeddings)
1104    }
1105}
1106
1107/// Build the forward pass from a loaded GGUF model.
1108fn build_forward_pass(
1109    gguf: &GgufModel,
1110    config: &ModelConfig,
1111) -> RuntimeResult<Box<dyn ForwardPass>> {
1112    match config.architecture.as_str() {
1113        #[cfg(feature = "llama")]
1114        "llama" => {
1115            let model = oxillama_arch::llama::load_llama_from_gguf(gguf, config)?;
1116            Ok(Box::new(model))
1117        }
1118        #[cfg(feature = "qwen3")]
1119        "qwen3" => {
1120            let model = oxillama_arch::qwen3::load_qwen3_from_gguf(gguf, config)?;
1121            Ok(Box::new(model))
1122        }
1123        #[cfg(feature = "mistral")]
1124        "mistral" => {
1125            let model = oxillama_arch::mistral::load_mistral_from_gguf(gguf, config)?;
1126            Ok(Box::new(model))
1127        }
1128        #[cfg(feature = "gemma")]
1129        "gemma" | "gemma2" | "gemma3" => {
1130            let model = oxillama_arch::gemma::load_gemma_from_gguf(gguf, config)?;
1131            Ok(Box::new(model))
1132        }
1133        #[cfg(feature = "phi")]
1134        "phi3" | "phi" => {
1135            let model = oxillama_arch::phi::load_phi_from_gguf(gguf, config)?;
1136            Ok(Box::new(model))
1137        }
1138        #[cfg(feature = "command-r")]
1139        "command-r" => {
1140            let model = oxillama_arch::command_r::load_command_r_from_gguf(gguf, config)?;
1141            Ok(Box::new(model))
1142        }
1143        #[cfg(feature = "starcoder")]
1144        "starcoder" => {
1145            let model = oxillama_arch::starcoder::load_starcoder_from_gguf(gguf, config)?;
1146            Ok(Box::new(model))
1147        }
1148        arch => Err(RuntimeError::ModelLoadError {
1149            message: format!("unsupported architecture: '{arch}'"),
1150        }),
1151    }
1152}
1153
1154/// Load the tokenizer, trying multiple sources.
1155fn load_tokenizer(config: &EngineConfig, gguf: &GgufModel) -> RuntimeResult<TokenizerBridge> {
1156    // Try explicit tokenizer path first
1157    if let Some(ref path) = config.tokenizer_path {
1158        return TokenizerBridge::from_file(path);
1159    }
1160
1161    // Try to extract tokenizer from GGUF metadata
1162    if let Some(tokenizer_json) = gguf
1163        .file
1164        .metadata
1165        .get("tokenizer.ggml.tokens")
1166        .and_then(|_| {
1167            // If there's a full tokenizer JSON in metadata, use it
1168            gguf.file
1169                .metadata
1170                .get("tokenizer.huggingface.json")
1171                .and_then(|v| v.as_str())
1172        })
1173    {
1174        return TokenizerBridge::from_bytes(tokenizer_json.as_bytes());
1175    }
1176
1177    // Try to find tokenizer.json next to the model file
1178    let model_dir = Path::new(&config.model_path)
1179        .parent()
1180        .unwrap_or(Path::new("."));
1181    let tokenizer_path = model_dir.join("tokenizer.json");
1182    if tokenizer_path.exists() {
1183        return TokenizerBridge::from_file(tokenizer_path.to_str().unwrap_or("tokenizer.json"));
1184    }
1185
1186    Err(RuntimeError::TokenizerError {
1187        message: "no tokenizer found: provide --tokenizer path or place tokenizer.json next to the model file".to_string(),
1188    })
1189}
1190
1191#[cfg(test)]
1192mod tests {
1193    use super::*;
1194
1195    // ── A3: forward_prefill / forward_decode tests ───────────────────────────
1196
1197    /// `forward_prefill` must return `ModelNotLoaded` when no model is loaded.
1198    #[test]
1199    fn test_forward_prefill_errors_when_not_loaded() {
1200        let mut engine = InferenceEngine::new(EngineConfig::default());
1201        let result = engine.forward_prefill(&[1, 2, 3], 0);
1202        assert!(
1203            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1204            "expected ModelNotLoaded from forward_prefill, got {result:?}"
1205        );
1206    }
1207
1208    /// `forward_prefill` with empty token slice must return an error
1209    /// (even if a model were loaded — callers must supply at least one token).
1210    #[test]
1211    fn test_forward_prefill_empty_slice_errors() {
1212        let mut engine = InferenceEngine::new(EngineConfig::default());
1213        // No model loaded; empty slice should error with ModelNotLoaded because
1214        // the empty-slice guard fires first (returns ModelLoadError), but any
1215        // error is acceptable — the point is that it never returns Ok.
1216        let result = engine.forward_prefill(&[], 0);
1217        assert!(
1218            result.is_err(),
1219            "forward_prefill with empty slice must return Err, got Ok"
1220        );
1221    }
1222
1223    /// `forward_decode` must return `ModelNotLoaded` when no model is loaded.
1224    #[test]
1225    fn test_forward_decode_errors_when_not_loaded() {
1226        let mut engine = InferenceEngine::new(EngineConfig::default());
1227        let result = engine.forward_decode(42, 0);
1228        assert!(
1229            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1230            "expected ModelNotLoaded from forward_decode, got {result:?}"
1231        );
1232    }
1233
1234    // ── A3: forward_prefill / forward_decode with loaded model ────────────────
1235
1236    /// `forward_prefill` with a loaded model must return a logits vector whose
1237    /// length equals the model's vocab size (32 in the synthetic fixture).
1238    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1239    #[test]
1240    fn test_forward_prefill_returns_logits_after_load() {
1241        let mut engine = make_loaded_engine();
1242        // Fresh engine: KV cache is empty so pos_start = 0.
1243        let result = engine.forward_prefill(&[3, 4, 5], 0);
1244        assert!(
1245            result.is_ok(),
1246            "forward_prefill must return Ok when model is loaded, got {result:?}"
1247        );
1248        let logits = result.expect("forward_prefill Ok");
1249        assert_eq!(
1250            logits.len(),
1251            32,
1252            "logits length must equal vocab_size=32, got {}",
1253            logits.len()
1254        );
1255    }
1256
1257    /// `forward_decode` with a loaded model must return a logits vector of
1258    /// the correct vocab-size length.
1259    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1260    #[test]
1261    fn test_forward_decode_returns_logits_after_load() {
1262        let mut engine = make_loaded_engine();
1263        // Prefill one token to prime the KV cache (pos becomes 1).
1264        engine
1265            .forward_prefill(&[3], 0)
1266            .expect("prefill must succeed");
1267        // Now KV cache seq_len == 1.
1268        let result = engine.forward_decode(4, 1);
1269        assert!(
1270            result.is_ok(),
1271            "forward_decode must return Ok when model is loaded, got {result:?}"
1272        );
1273        let logits = result.expect("forward_decode Ok");
1274        assert_eq!(
1275            logits.len(),
1276            32,
1277            "logits length must equal vocab_size=32, got {}",
1278            logits.len()
1279        );
1280    }
1281
1282    /// Verify that chunked-prefill produces the same final logits as
1283    /// single-shot prefill (the core KV-state invariant from A3).
1284    ///
1285    /// Both paths must agree on the logit vector produced after processing
1286    /// the same prompt tokens, within floating-point tolerance.
1287    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1288    #[test]
1289    fn chunked_prefill_kv_matches_singleshot() {
1290        let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
1291        let tokenizer_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
1292        let prompt_tokens = vec![3u32, 4, 5, 6];
1293
1294        // ── Single-shot path ──────────────────────────────────────────────────
1295        let mut engine_single = InferenceEngine::new(EngineConfig::default());
1296        engine_single
1297            .load_model_from_bytes(&model_bytes, tokenizer_json)
1298            .expect("single-shot load");
1299        // Fresh engine: pos_start = 0.
1300        let logits_single = engine_single
1301            .forward_prefill(&prompt_tokens, 0)
1302            .expect("single-shot prefill");
1303
1304        // ── Chunked path (chunk = 2) ──────────────────────────────────────────
1305        let mut engine_chunked = InferenceEngine::new(EngineConfig::default());
1306        engine_chunked
1307            .load_model_from_bytes(&model_bytes, tokenizer_json)
1308            .expect("chunked load");
1309
1310        let mut logits_chunked = Vec::new();
1311        let chunk_size = 2usize;
1312        let mut pos = 0usize;
1313        for slice in prompt_tokens.chunks(chunk_size) {
1314            logits_chunked = engine_chunked
1315                .forward_prefill(slice, pos)
1316                .expect("chunked prefill");
1317            pos += slice.len();
1318        }
1319
1320        // ── Compare logits ────────────────────────────────────────────────────
1321        assert_eq!(
1322            logits_single.len(),
1323            logits_chunked.len(),
1324            "logit vector lengths must match"
1325        );
1326        let tol = 1e-4f32;
1327        let max_diff = logits_single
1328            .iter()
1329            .zip(logits_chunked.iter())
1330            .map(|(a, b)| (a - b).abs())
1331            .fold(0.0f32, f32::max);
1332        assert!(
1333            max_diff < tol,
1334            "chunked and single-shot prefill logits differ by {max_diff} > tolerance {tol}"
1335        );
1336    }
1337
1338    // ── End A3 ────────────────────────────────────────────────────────────────
1339
1340    /// embed() must return an error when no model has been loaded,
1341    /// rather than panicking or producing a garbage vector.
1342    #[test]
1343    fn test_embed_returns_err_when_not_loaded() {
1344        let mut engine = InferenceEngine::new(EngineConfig::default());
1345        let result = engine.embed("hello world");
1346        assert!(
1347            result.is_err(),
1348            "embed() should return Err when no model is loaded"
1349        );
1350    }
1351
1352    /// hidden_size() returns None when no model is loaded.
1353    #[test]
1354    fn test_hidden_size_none_when_not_loaded() {
1355        let engine = InferenceEngine::new(EngineConfig::default());
1356        assert!(
1357            engine.hidden_size().is_none(),
1358            "hidden_size() should be None before load_model()"
1359        );
1360    }
1361
1362    /// is_loaded() must be false for a freshly created engine.
1363    #[test]
1364    fn test_is_loaded_false_initially() {
1365        let engine = InferenceEngine::new(EngineConfig::default());
1366        assert!(!engine.is_loaded());
1367    }
1368
1369    #[test]
1370    fn test_model_config_none_when_not_loaded() {
1371        let engine = InferenceEngine::new(EngineConfig::default());
1372        assert!(engine.model_config().is_none());
1373    }
1374
1375    #[test]
1376    fn test_config_roundtrip() {
1377        let cfg = EngineConfig {
1378            model_path: "test.gguf".to_string(),
1379            num_threads: 8,
1380            ..EngineConfig::default()
1381        };
1382        let engine = InferenceEngine::new(cfg);
1383        assert_eq!(engine.config().model_path, "test.gguf");
1384        assert_eq!(engine.config().num_threads, 8);
1385    }
1386
1387    #[test]
1388    fn test_generate_errors_when_not_loaded() {
1389        let mut engine = InferenceEngine::new(EngineConfig::default());
1390        let result = engine.generate("hello", 10, |_| {});
1391        assert!(
1392            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1393            "expected ModelNotLoaded, got {result:?}"
1394        );
1395    }
1396
1397    #[test]
1398    fn test_generate_with_config_errors_when_not_loaded() {
1399        let mut engine = InferenceEngine::new(EngineConfig::default());
1400        let result = engine.generate_with_config("hello", 5, SamplerConfig::greedy(), |_| {});
1401        assert!(
1402            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1403            "expected ModelNotLoaded, got {result:?}"
1404        );
1405    }
1406
1407    #[test]
1408    fn test_tokenize_errors_when_not_loaded() {
1409        let engine = InferenceEngine::new(EngineConfig::default());
1410        let result = engine.tokenize("hello world");
1411        assert!(
1412            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1413            "expected ModelNotLoaded, got {result:?}"
1414        );
1415    }
1416
1417    #[test]
1418    fn test_prefill_errors_when_not_loaded() {
1419        let mut engine = InferenceEngine::new(EngineConfig::default());
1420        let result = engine.prefill(&[1, 2, 3]);
1421        assert!(
1422            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1423            "expected ModelNotLoaded, got {result:?}"
1424        );
1425    }
1426
1427    #[test]
1428    fn test_prefill_empty_slice_ok_when_no_model() {
1429        let mut engine = InferenceEngine::new(EngineConfig::default());
1430        // Empty slice is a no-op and returns Ok regardless of model state.
1431        let result = engine.prefill(&[]);
1432        assert!(result.is_ok(), "empty prefill should be Ok, got {result:?}");
1433    }
1434
1435    #[test]
1436    fn test_forward_one_errors_when_not_loaded() {
1437        let mut engine = InferenceEngine::new(EngineConfig::default());
1438        let result = engine.forward_one(42);
1439        assert!(
1440            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1441            "expected ModelNotLoaded, got {result:?}"
1442        );
1443    }
1444
1445    #[test]
1446    fn test_decode_token_errors_when_not_loaded() {
1447        let engine = InferenceEngine::new(EngineConfig::default());
1448        let result = engine.decode_token(1);
1449        assert!(
1450            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1451            "expected ModelNotLoaded, got {result:?}"
1452        );
1453    }
1454
1455    #[test]
1456    fn test_is_eos_false_when_not_loaded() {
1457        let engine = InferenceEngine::new(EngineConfig::default());
1458        assert!(!engine.is_eos(0));
1459        assert!(!engine.is_eos(u32::MAX));
1460    }
1461
1462    #[test]
1463    fn test_vocab_bytes_none_when_not_loaded() {
1464        let engine = InferenceEngine::new(EngineConfig::default());
1465        assert!(engine.vocab_bytes().is_none());
1466    }
1467
1468    #[test]
1469    fn test_reset_does_not_panic_when_no_kv_cache() {
1470        let mut engine = InferenceEngine::new(EngineConfig::default());
1471        engine.reset(); // should be a no-op, not a panic
1472    }
1473
1474    #[test]
1475    fn test_apply_lora_adapters_errors_when_not_loaded() {
1476        use oxillama_arch::lora::LoadedLora;
1477        let mut engine = InferenceEngine::new(EngineConfig::default());
1478        let lora = LoadedLora {
1479            rank: 8,
1480            alpha: 1.0,
1481            adapters: std::collections::HashMap::new(),
1482        };
1483        let result = engine.apply_lora_adapters(&lora);
1484        assert!(
1485            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1486            "expected ModelNotLoaded, got {result:?}"
1487        );
1488    }
1489
1490    #[test]
1491    fn test_load_model_missing_file_errors() {
1492        let cfg = EngineConfig {
1493            model_path: "/nonexistent/path/model_abc_xyz.gguf".to_string(),
1494            ..EngineConfig::default()
1495        };
1496        let mut engine = InferenceEngine::new(cfg);
1497        let result = engine.load_model();
1498        assert!(
1499            matches!(result, Err(RuntimeError::ModelLoadError { .. })),
1500            "expected ModelLoadError for missing file, got {result:?}"
1501        );
1502    }
1503
1504    #[test]
1505    fn test_load_model_from_bytes_bad_magic_errors() {
1506        let cfg = EngineConfig::default();
1507        let mut engine = InferenceEngine::new(cfg);
1508        // Bytes that look nothing like a GGUF file (wrong magic)
1509        let bad_bytes = b"THIS IS NOT A GGUF FILE AT ALL";
1510        let result = engine.load_model_from_bytes(bad_bytes, "{}");
1511        assert!(
1512            result.is_err(),
1513            "load_model_from_bytes with garbage bytes should error, got Ok(())"
1514        );
1515    }
1516
1517    #[test]
1518    fn test_load_model_from_bytes_empty_errors() {
1519        let cfg = EngineConfig::default();
1520        let mut engine = InferenceEngine::new(cfg);
1521        let result = engine.load_model_from_bytes(&[], "{}");
1522        assert!(
1523            result.is_err(),
1524            "load_model_from_bytes with empty bytes should error"
1525        );
1526    }
1527
1528    #[test]
1529    fn test_engine_config_default_fields() {
1530        let cfg = EngineConfig::default();
1531        assert!(
1532            cfg.model_path.is_empty(),
1533            "default model_path should be empty"
1534        );
1535        assert!(
1536            cfg.tokenizer_path.is_none(),
1537            "default tokenizer_path should be None"
1538        );
1539        assert!(
1540            cfg.context_size.is_none(),
1541            "default context_size should be None"
1542        );
1543        assert_eq!(cfg.num_threads, 4, "default num_threads should be 4");
1544    }
1545
1546    #[test]
1547    fn test_engine_config_context_override() {
1548        let cfg = EngineConfig {
1549            context_size: Some(2048),
1550            ..EngineConfig::default()
1551        };
1552        assert_eq!(cfg.context_size, Some(2048));
1553    }
1554
1555    #[test]
1556    fn test_generate_with_config_errors_when_not_loaded_variant() {
1557        // Additional variant with explicit SamplerConfig fields
1558        let mut engine = InferenceEngine::new(EngineConfig::default());
1559        let sc = SamplerConfig {
1560            temperature: 0.7,
1561            top_k: 40,
1562            ..SamplerConfig::default()
1563        };
1564        let result = engine.generate_with_config("test prompt", 5, sc, |_| {});
1565        assert!(
1566            matches!(result, Err(RuntimeError::ModelNotLoaded)),
1567            "expected ModelNotLoaded, got {result:?}"
1568        );
1569    }
1570
1571    /// load_model() with a file that *exists* but contains garbage (not valid GGUF)
1572    /// must return an error without panicking, exercising the GgufModel::load parse path.
1573    #[test]
1574    fn test_load_model_existing_invalid_file_errors() {
1575        let mut tmp = std::env::temp_dir();
1576        tmp.push("oxillama_engine_bad_magic_test.gguf");
1577        // Write garbage bytes that will fail GGUF magic-byte check.
1578        std::fs::write(&tmp, b"NOT A GGUF FILE AT ALL - GARBAGE BYTES 0123456789")
1579            .expect("write temp file");
1580        let cfg = EngineConfig {
1581            model_path: tmp
1582                .to_str()
1583                .expect("temp path must be valid UTF-8")
1584                .to_string(),
1585            ..EngineConfig::default()
1586        };
1587        let mut engine = InferenceEngine::new(cfg);
1588        let result = engine.load_model();
1589        // Clean up before asserting so the file is always removed.
1590        let _ = std::fs::remove_file(&tmp);
1591        assert!(
1592            result.is_err(),
1593            "load_model with invalid GGUF content should return Err"
1594        );
1595    }
1596
1597    /// is_loaded() must remain false after a failed load_model() call.
1598    #[test]
1599    fn test_is_loaded_remains_false_after_failed_load() {
1600        let cfg = EngineConfig {
1601            model_path: "/nonexistent/guaranteed_missing_model.gguf".to_string(),
1602            ..EngineConfig::default()
1603        };
1604        let mut engine = InferenceEngine::new(cfg);
1605        // This must fail (file doesn't exist).
1606        let _ = engine.load_model();
1607        assert!(
1608            !engine.is_loaded(),
1609            "is_loaded() must be false after a failed load_model()"
1610        );
1611    }
1612
1613    /// EngineConfig implements Clone; verify the clone is independent.
1614    #[test]
1615    fn test_engine_config_clone_is_independent() {
1616        let original = EngineConfig {
1617            model_path: "original.gguf".to_string(),
1618            num_threads: 16,
1619            context_size: Some(4096),
1620            ..EngineConfig::default()
1621        };
1622        let mut cloned = original.clone();
1623        cloned.model_path = "cloned.gguf".to_string();
1624        cloned.num_threads = 1;
1625        // Original must be unaffected.
1626        assert_eq!(original.model_path, "original.gguf");
1627        assert_eq!(original.num_threads, 16);
1628        assert_eq!(original.context_size, Some(4096));
1629    }
1630
1631    // -------------------------------------------------------------------------
1632    // Tests backed by the synthetic GGUF fixture (requires tokenizer feature)
1633    // -------------------------------------------------------------------------
1634
1635    /// Return a loaded engine from the synthetic GGUF + tokenizer.
1636    /// These tests are only meaningful when a tokenizer backend is active.
1637    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1638    fn make_loaded_engine() -> InferenceEngine {
1639        let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
1640        let tokenizer_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
1641        let mut engine = InferenceEngine::new(EngineConfig::default());
1642        engine
1643            .load_model_from_bytes(&model_bytes, tokenizer_json)
1644            .expect("synthetic GGUF must load successfully");
1645        engine
1646    }
1647
1648    /// load_model_from_bytes with the synthetic fixture must succeed and
1649    /// set is_loaded() to true.
1650    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1651    #[test]
1652    fn test_load_model_from_bytes_succeeds() {
1653        let engine = make_loaded_engine();
1654        assert!(
1655            engine.is_loaded(),
1656            "is_loaded() must be true after a successful load_model_from_bytes()"
1657        );
1658    }
1659
1660    /// model_config() must be Some with the expected hidden size after loading.
1661    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1662    #[test]
1663    fn test_hidden_size_after_load() {
1664        let engine = make_loaded_engine();
1665        let hs = engine.hidden_size();
1666        assert_eq!(
1667            hs,
1668            Some(32),
1669            "hidden_size() must be Some(32) after loading the synthetic model, got {hs:?}"
1670        );
1671    }
1672
1673    /// tokenize() must return Ok with at least one token after loading.
1674    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1675    #[test]
1676    fn test_tokenize_after_load() {
1677        let engine = make_loaded_engine();
1678        let result = engine.tokenize("abc");
1679        assert!(
1680            result.is_ok(),
1681            "tokenize() must return Ok after model is loaded, got {result:?}"
1682        );
1683        let tokens = result.expect("tokenize succeeded");
1684        assert!(
1685            !tokens.is_empty(),
1686            "tokenize('abc') must produce at least one token"
1687        );
1688    }
1689
1690    /// is_eos() must return true for token id 2 (</s> = EOS in the synthetic tokenizer).
1691    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1692    #[test]
1693    fn test_is_eos_after_load() {
1694        let engine = make_loaded_engine();
1695        assert!(
1696            engine.is_eos(2),
1697            "is_eos(2) must be true — </s> is the EOS token in the synthetic tokenizer"
1698        );
1699        assert!(
1700            !engine.is_eos(3),
1701            "is_eos(3) must be false — token 3 ('a') is not EOS"
1702        );
1703    }
1704
1705    /// decode_token(3) should decode successfully (token 3 = 'a' in the synthetic vocab).
1706    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1707    #[test]
1708    fn test_decode_token_after_load() {
1709        let engine = make_loaded_engine();
1710        let result = engine.decode_token(3);
1711        assert!(
1712            result.is_ok(),
1713            "decode_token(3) must return Ok, got {result:?}"
1714        );
1715    }
1716
1717    /// generate() must return Ok after loading; the returned string may be empty
1718    /// if the EOS token is sampled immediately, but it must not panic or error.
1719    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1720    #[test]
1721    fn test_generate_after_load() {
1722        let mut engine = make_loaded_engine();
1723        let result = engine.generate("a", 3, |_| {});
1724        assert!(
1725            result.is_ok(),
1726            "generate() must return Ok after model is loaded, got {result:?}"
1727        );
1728    }
1729
1730    /// generate() with max_tokens=5 must produce at most 5 tokens of output.
1731    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1732    #[test]
1733    fn test_generate_respects_max_tokens() {
1734        let mut engine = make_loaded_engine();
1735        let max = 5usize;
1736        // Count callback invocations as a proxy for generated tokens.
1737        let mut count = 0usize;
1738        let result = engine.generate("a", max, |_tok| {
1739            count += 1;
1740        });
1741        assert!(result.is_ok(), "generate() must return Ok, got {result:?}");
1742        assert!(
1743            count <= max,
1744            "callback was invoked {count} times but max_tokens={max}"
1745        );
1746    }
1747
1748    /// generate_streaming — count callback invocations to verify the streaming path fires.
1749    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1750    #[test]
1751    fn test_generate_streaming_calls_callback() {
1752        let mut engine = make_loaded_engine();
1753        let mut invocations = 0usize;
1754        let max_tokens = 4;
1755        let result = engine.generate("a", max_tokens, |_piece| {
1756            invocations += 1;
1757        });
1758        assert!(
1759            result.is_ok(),
1760            "generate() streaming path must return Ok, got {result:?}"
1761        );
1762        // invocations may be 0 if EOS is sampled on the first step; just assert <= max.
1763        assert!(
1764            invocations <= max_tokens,
1765            "streaming callback fired {invocations} > max_tokens={max_tokens}"
1766        );
1767    }
1768
1769    /// embed() must return Ok with a non-empty vector after loading.
1770    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1771    #[test]
1772    fn test_embed_after_load() {
1773        let mut engine = make_loaded_engine();
1774        let result = engine.embed("a");
1775        assert!(
1776            result.is_ok(),
1777            "embed() must return Ok after model is loaded, got {result:?}"
1778        );
1779        let vec = result.expect("embed succeeded");
1780        assert!(!vec.is_empty(), "embed() must return a non-empty vector");
1781    }
1782
1783    /// embed() must return a vector of length == hidden_size (32 for the synthetic model).
1784    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1785    #[test]
1786    fn test_embed_returns_hidden_size_vector() {
1787        let mut engine = make_loaded_engine();
1788        let vec = engine
1789            .embed("a")
1790            .expect("embed() must succeed after loading");
1791        assert_eq!(
1792            vec.len(),
1793            32,
1794            "embed() vector length must equal hidden_size=32, got {}",
1795            vec.len()
1796        );
1797    }
1798
1799    /// Reload: loading the model a second time must succeed and leave is_loaded() true.
1800    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1801    #[test]
1802    fn test_reload_model_succeeds() {
1803        let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
1804        let tokenizer_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
1805        let mut engine = InferenceEngine::new(EngineConfig::default());
1806
1807        // First load.
1808        engine
1809            .load_model_from_bytes(&model_bytes, tokenizer_json)
1810            .expect("first load must succeed");
1811        assert!(engine.is_loaded(), "is_loaded() after first load");
1812
1813        // Second load (reload).
1814        engine
1815            .load_model_from_bytes(&model_bytes, tokenizer_json)
1816            .expect("second (re)load must succeed");
1817        assert!(
1818            engine.is_loaded(),
1819            "is_loaded() after reload must still be true"
1820        );
1821    }
1822
1823    /// vocab_bytes() must return Some with non-empty entries after loading.
1824    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1825    #[test]
1826    fn test_vocab_bytes_some_after_load() {
1827        let engine = make_loaded_engine();
1828        let vb = engine.vocab_bytes();
1829        assert!(
1830            vb.is_some(),
1831            "vocab_bytes() must be Some after model is loaded"
1832        );
1833        let entries = vb.expect("vocab_bytes is Some");
1834        assert!(
1835            !entries.is_empty(),
1836            "vocab_bytes() must contain at least one entry"
1837        );
1838    }
1839
1840    /// model_config() must return Some with correct metadata after loading.
1841    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1842    #[test]
1843    fn test_model_config_some_after_load() {
1844        let engine = make_loaded_engine();
1845        let cfg = engine.model_config();
1846        assert!(cfg.is_some(), "model_config() must be Some after loading");
1847        let mc = cfg.expect("model_config is Some");
1848        assert_eq!(mc.architecture, "llama", "architecture must be 'llama'");
1849        assert_eq!(
1850            mc.num_layers, 1,
1851            "num_layers must be 1 for the synthetic model"
1852        );
1853        assert_eq!(mc.vocab_size, 32, "vocab_size must be 32");
1854    }
1855
1856    /// reset() must not panic when a model is loaded, and the engine remains usable.
1857    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1858    #[test]
1859    fn test_reset_when_loaded_does_not_panic() {
1860        let mut engine = make_loaded_engine();
1861        engine.reset(); // must not panic
1862                        // After reset, basic queries still work.
1863        assert!(
1864            engine.is_loaded(),
1865            "is_loaded() must still be true after reset()"
1866        );
1867        assert_eq!(engine.hidden_size(), Some(32));
1868    }
1869
1870    // -------------------------------------------------------------------------
1871    // Architecture forward-pass integration tests
1872    // -------------------------------------------------------------------------
1873    // Each test loads the synthetic GGUF for a specific architecture, verifies
1874    // is_loaded(), runs generate() for 2 tokens, and verifies the call
1875    // succeeds.  Two architectures additionally run embed() to cover the
1876    // embedding endpoint code path.
1877
1878    /// Qwen3 forward pass: load, generate, assert ok.
1879    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1880    #[test]
1881    fn test_generate_qwen3_arch() {
1882        use oxillama_gguf::test_utils::{build_minimal_qwen3_gguf, minimal_tokenizer_json};
1883
1884        let bytes = build_minimal_qwen3_gguf();
1885        let json = minimal_tokenizer_json();
1886        let mut engine = InferenceEngine::new(EngineConfig::default());
1887        engine
1888            .load_model_from_bytes(&bytes, json)
1889            .expect("test: load qwen3");
1890        assert!(engine.is_loaded(), "qwen3: is_loaded() must be true");
1891        let _out = engine
1892            .generate("abc", 2, |_| {})
1893            .expect("test: generate qwen3");
1894    }
1895
1896    /// Qwen3 embed: load and call embed().
1897    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1898    #[test]
1899    fn test_embed_qwen3_arch() {
1900        use oxillama_gguf::test_utils::{build_minimal_qwen3_gguf, minimal_tokenizer_json};
1901
1902        let bytes = build_minimal_qwen3_gguf();
1903        let json = minimal_tokenizer_json();
1904        let mut engine = InferenceEngine::new(EngineConfig::default());
1905        engine
1906            .load_model_from_bytes(&bytes, json)
1907            .expect("test: load qwen3 for embed");
1908        let vec = engine.embed("abc").expect("test: embed qwen3");
1909        assert_eq!(
1910            vec.len(),
1911            32,
1912            "qwen3 embed must return hidden_size=32 vector"
1913        );
1914    }
1915
1916    /// Mistral forward pass: load, generate with sliding window, assert ok.
1917    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1918    #[test]
1919    fn test_generate_mistral_arch() {
1920        use oxillama_gguf::test_utils::{build_minimal_mistral_gguf, minimal_tokenizer_json};
1921
1922        let bytes = build_minimal_mistral_gguf();
1923        let json = minimal_tokenizer_json();
1924        let mut engine = InferenceEngine::new(EngineConfig::default());
1925        engine
1926            .load_model_from_bytes(&bytes, json)
1927            .expect("test: load mistral");
1928        assert!(engine.is_loaded(), "mistral: is_loaded() must be true");
1929        let _out = engine
1930            .generate("abc", 2, |_| {})
1931            .expect("test: generate mistral");
1932    }
1933
1934    /// Gemma forward pass: load with soft-capping metadata, generate, assert ok.
1935    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1936    #[test]
1937    fn test_generate_gemma_arch() {
1938        use oxillama_gguf::test_utils::{build_minimal_gemma_gguf, minimal_tokenizer_json};
1939
1940        let bytes = build_minimal_gemma_gguf();
1941        let json = minimal_tokenizer_json();
1942        let mut engine = InferenceEngine::new(EngineConfig::default());
1943        engine
1944            .load_model_from_bytes(&bytes, json)
1945            .expect("test: load gemma");
1946        assert!(engine.is_loaded(), "gemma: is_loaded() must be true");
1947        let _out = engine
1948            .generate("abc", 2, |_| {})
1949            .expect("test: generate gemma");
1950    }
1951
1952    /// Gemma embed: load and call embed() to cover the Gemma embedding path.
1953    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1954    #[test]
1955    fn test_embed_gemma_arch() {
1956        use oxillama_gguf::test_utils::{build_minimal_gemma_gguf, minimal_tokenizer_json};
1957
1958        let bytes = build_minimal_gemma_gguf();
1959        let json = minimal_tokenizer_json();
1960        let mut engine = InferenceEngine::new(EngineConfig::default());
1961        engine
1962            .load_model_from_bytes(&bytes, json)
1963            .expect("test: load gemma for embed");
1964        let vec = engine.embed("abc").expect("test: embed gemma");
1965        assert_eq!(
1966            vec.len(),
1967            32,
1968            "gemma embed must return hidden_size=32 vector"
1969        );
1970    }
1971
1972    /// Phi-3 forward pass: merged QKV + partial RoPE, generate, assert ok.
1973    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1974    #[test]
1975    fn test_generate_phi3_arch() {
1976        use oxillama_gguf::test_utils::{build_minimal_phi3_gguf, minimal_tokenizer_json};
1977
1978        let bytes = build_minimal_phi3_gguf();
1979        let json = minimal_tokenizer_json();
1980        let mut engine = InferenceEngine::new(EngineConfig::default());
1981        engine
1982            .load_model_from_bytes(&bytes, json)
1983            .expect("test: load phi3");
1984        assert!(engine.is_loaded(), "phi3: is_loaded() must be true");
1985        let _out = engine
1986            .generate("abc", 2, |_| {})
1987            .expect("test: generate phi3");
1988    }
1989
1990    /// Command-R forward pass: logit scaling, generate, assert ok.
1991    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1992    #[test]
1993    fn test_generate_command_r_arch() {
1994        use oxillama_gguf::test_utils::{build_minimal_command_r_gguf, minimal_tokenizer_json};
1995
1996        let bytes = build_minimal_command_r_gguf();
1997        let json = minimal_tokenizer_json();
1998        let mut engine = InferenceEngine::new(EngineConfig::default());
1999        engine
2000            .load_model_from_bytes(&bytes, json)
2001            .expect("test: load command-r");
2002        assert!(engine.is_loaded(), "command-r: is_loaded() must be true");
2003        let _out = engine
2004            .generate("abc", 2, |_| {})
2005            .expect("test: generate command-r");
2006    }
2007
2008    /// StarCoder forward pass: MQA + absolute position embeddings, generate, assert ok.
2009    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
2010    #[test]
2011    fn test_generate_starcoder_arch() {
2012        use oxillama_gguf::test_utils::{build_minimal_starcoder_gguf, minimal_tokenizer_json};
2013
2014        let bytes = build_minimal_starcoder_gguf();
2015        let json = minimal_tokenizer_json();
2016        let mut engine = InferenceEngine::new(EngineConfig::default());
2017        engine
2018            .load_model_from_bytes(&bytes, json)
2019            .expect("test: load starcoder");
2020        assert!(engine.is_loaded(), "starcoder: is_loaded() must be true");
2021        let _out = engine
2022            .generate("abc", 2, |_| {})
2023            .expect("test: generate starcoder");
2024    }
2025
2026    // -------------------------------------------------------------------------
2027    // Multi-LoRA hot-swap tests
2028    // -------------------------------------------------------------------------
2029
2030    #[test]
2031    fn lora_stack_push_pop() {
2032        use oxillama_arch::lora::LoadedLora;
2033        use oxillama_quant::LoraAdapter;
2034        use std::collections::HashMap;
2035        use std::sync::Arc;
2036
2037        fn make_lora() -> Arc<LoadedLora> {
2038            let adapter = LoraAdapter::new(vec![0.0f32; 4 * 8], vec![0.0f32; 8 * 4], 4, 1.0, 8, 8)
2039                .expect("valid lora adapter");
2040            let mut adapters = HashMap::new();
2041            adapters.insert("test.weight".to_string(), Arc::new(adapter));
2042            Arc::new(LoadedLora {
2043                adapters,
2044                rank: 4,
2045                alpha: 1.0,
2046            })
2047        }
2048
2049        let mut engine = InferenceEngine::new(EngineConfig::default());
2050
2051        // Initially empty
2052        assert!(engine.lora_stack().is_empty());
2053        assert_eq!(engine.lora_stack().len(), 0);
2054
2055        // Push two adapters
2056        engine.push_lora(make_lora(), 1.0);
2057        engine.push_lora(make_lora(), 0.5);
2058        assert_eq!(engine.lora_stack().len(), 2);
2059        assert!(!engine.lora_stack().is_empty());
2060
2061        // Pop one
2062        let popped = engine.pop_lora();
2063        assert!(popped.is_some());
2064        let (_, scale) = popped.expect("pop must return Some");
2065        assert!((scale - 0.5).abs() < 1e-6);
2066        assert_eq!(engine.lora_stack().len(), 1);
2067
2068        // Clear
2069        engine.clear_loras();
2070        assert!(engine.lora_stack().is_empty());
2071
2072        // Pop from empty returns None
2073        assert!(engine.pop_lora().is_none());
2074    }
2075
2076    #[test]
2077    fn lora_apply_stack_errors_when_not_loaded() {
2078        use oxillama_arch::lora::LoadedLora;
2079        use oxillama_quant::LoraAdapter;
2080        use std::collections::HashMap;
2081        use std::sync::Arc;
2082
2083        let adapter = LoraAdapter::new(vec![0.0f32; 4 * 8], vec![0.0f32; 8 * 4], 4, 1.0, 8, 8)
2084            .expect("valid lora adapter");
2085        let mut adapters = HashMap::new();
2086        adapters.insert("test.weight".to_string(), Arc::new(adapter));
2087        let lora = Arc::new(LoadedLora {
2088            adapters,
2089            rank: 4,
2090            alpha: 1.0,
2091        });
2092
2093        let mut engine = InferenceEngine::new(EngineConfig::default());
2094        engine.push_lora(lora, 1.0);
2095        let result = engine.apply_lora_stack();
2096        assert!(
2097            matches!(result, Err(RuntimeError::ModelNotLoaded)),
2098            "expected ModelNotLoaded, got {:?}",
2099            result
2100        );
2101    }
2102
2103    /// `unapply_all_loras` on an unloaded engine must not panic.
2104    #[test]
2105    fn unapply_all_loras_noop_when_unloaded() {
2106        let mut engine = InferenceEngine::new(EngineConfig::default());
2107        engine.unapply_all_loras(); // must not panic
2108        assert!(!engine.is_loaded());
2109    }
2110
2111    /// `prime_with_prefix` on an unloaded engine returns `ModelNotLoaded`.
2112    #[test]
2113    fn prime_with_prefix_returns_model_not_loaded() {
2114        use crate::kv_cache::prefix::{PrefixCacheConfig, PrefixKvCache};
2115        use crate::kv_cache::KvCache;
2116
2117        let mut engine = InferenceEngine::new(EngineConfig::default());
2118
2119        // Build a minimal prefix cache and store a tiny entry so lookup returns Some.
2120        let mut prefix_cache = PrefixKvCache::new(PrefixCacheConfig {
2121            max_entries: 16,
2122            max_memory_bytes: 1024 * 1024,
2123            min_prefix_len: 1, // allow tiny prefixes
2124        });
2125        let kv = KvCache::new(1, 4, 32);
2126        let tokens: Vec<u32> = vec![1, 2, 3];
2127        prefix_cache.store(&tokens, &kv, 3, 4, 1);
2128
2129        if let Some((match_len, cached)) = prefix_cache.lookup(&tokens) {
2130            // Pass one suffix token so we don't hit the empty-suffix guard.
2131            let suffix = &tokens[match_len.min(tokens.len() - 1)..];
2132            let result = engine.prime_with_prefix(cached, match_len.saturating_sub(1), suffix);
2133            assert!(
2134                matches!(result, Err(RuntimeError::ModelNotLoaded)),
2135                "unloaded engine must return ModelNotLoaded, got {:?}",
2136                result
2137            );
2138        }
2139        // If store did not cache (e.g. tokens.len() < min_prefix_len for some
2140        // reason), the assertion is vacuously satisfied — the important property
2141        // is that prime_with_prefix never panics when called without a model.
2142    }
2143}