Skip to main content

llama_gguf/
engine.rs

1//! High-level inference engine for llama-gguf.
2//!
3//! Provides [`Engine`] and [`ChatEngine`] for easy model loading and text generation
4//! without needing to manually wire together GGUF files, tokenizers, backends, and samplers.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use llama_gguf::engine::{Engine, EngineConfig};
10//!
11//! let engine = Engine::load(EngineConfig {
12//!     model_path: "model.gguf".into(),
13//!     ..Default::default()
14//! }).unwrap();
15//!
16//! let response = engine.generate("What is a tort?", 256).unwrap();
17//! println!("{}", response);
18//! ```
19
20use std::sync::Arc;
21
22use crate::backend::Backend;
23use crate::gguf::GgufFile;
24use crate::model::{
25    EmbeddingConfig, EmbeddingExtractor, InferenceContext, Model, ModelConfig, ModelLoader,
26};
27use crate::sampling::{Sampler, SamplerConfig};
28use crate::tokenizer::Tokenizer;
29
30// ============================================================================
31// Error type
32// ============================================================================
33
/// Errors that can occur during engine operations.
///
/// Every variant except [`EngineError::Other`] wraps an error type from a lower
/// layer and has a `#[from]` conversion, so those errors propagate with `?`.
#[derive(thiserror::Error, Debug)]
pub enum EngineError {
    /// A standard-library I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// An error reported by the GGUF layer (e.g. by `GgufFile::open`).
    #[error("GGUF error: {0}")]
    Gguf(#[from] crate::gguf::GgufError),

    /// An error reported by the model layer (`crate::model::ModelError`).
    #[error("Model error: {0}")]
    Model(#[from] crate::model::ModelError),

    /// An error reported by the tokenizer layer (`crate::tokenizer::TokenizerError`).
    #[error("Tokenizer error: {0}")]
    Tokenizer(#[from] crate::tokenizer::TokenizerError),

    /// An error reported by the embedding layer (`crate::model::EmbeddingError`).
    #[error("Embedding error: {0}")]
    Embedding(#[from] crate::model::EmbeddingError),

    /// An engine-level failure described only by a message, e.g. a missing
    /// `model_path`, a missing ONNX feature/tokenizer, or an ONNX loader error.
    #[error("Engine error: {0}")]
    Other(String),
}
55
56// ============================================================================
57// Configuration
58// ============================================================================
59
/// Configuration for creating an [`Engine`].
///
/// Can be constructed manually, from a [`Config`](crate::config::Config) TOML file,
/// or with [`Default::default()`].
///
/// The container-level `#[serde(default)]` means that fields missing from a
/// deserialized document take their value from this type's `Default` impl.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct EngineConfig {
    /// Path to the model file (GGUF or ONNX).
    ///
    /// `Engine::load` rejects an empty path and picks the loader from the
    /// file extension. Default: empty string.
    pub model_path: String,

    /// Optional path to a tokenizer file.
    ///
    /// A path ending in `.json` is loaded as a HuggingFace `tokenizer.json`;
    /// any other path is opened as a GGUF file and only its tokenizer is used.
    /// When unset, GGUF models use the tokenizer embedded in the model file and
    /// ONNX models look for `tokenizer.json` in the model's directory.
    pub tokenizer_path: Option<String>,

    /// Temperature for sampling (0.0 = greedy, higher = more random). Default: `0.7`.
    pub temperature: f32,

    /// Top-K sampling: only consider the K most likely tokens (0 = disabled). Default: `40`.
    pub top_k: usize,

    /// Top-P (nucleus) sampling: only consider tokens with cumulative probability <= p.
    /// Default: `0.95`.
    pub top_p: f32,

    /// Repetition penalty (1.0 = no penalty). Default: `1.1`.
    pub repeat_penalty: f32,

    /// Default maximum tokens to generate. Default: `512`.
    pub max_tokens: usize,

    /// Random seed for reproducible generation (None = random).
    pub seed: Option<u64>,

    /// Use GPU acceleration. Default: `false`.
    ///
    /// Requires a GPU backend to be compiled in (`cuda`, `vulkan`, `metal` on
    /// macOS, or `dx12` on Windows); the backend-selection code falls back to
    /// the CPU backend when none is compiled in or available.
    pub use_gpu: bool,

    /// Maximum context length override.
    ///
    /// If set, caps the model's context length (and thus KV cache size).
    /// The model's native `max_seq_len` from GGUF metadata can be very large
    /// (e.g. 32768) which may exhaust GPU memory for the KV cache alone.
    /// Set this to a smaller value (e.g. 2048) to reduce VRAM usage.
    /// If `None` or `0`, uses the model's native `max_seq_len`.
    ///
    /// A value greater than or equal to the native length is ignored by the
    /// GPU model selection code.
    pub max_context_len: Option<usize>,

    /// Optional Hailo accelerator configuration.
    /// When set, enables Hailo NPU inference in the GPU selection chain.
    /// That chain is only entered when `use_gpu` is `true`.
    #[cfg(feature = "hailo")]
    pub hailo_config: Option<crate::backend::hailo::HailoConfig>,

    /// KV cache type for memory-efficient inference. Default: `KVCacheType::F32`.
    pub kv_cache_type: crate::model::KVCacheType,
}
114
/// Default engine settings: no model selected, CPU inference, and moderately
/// random sampling.
impl Default for EngineConfig {
    /// Build the default configuration.
    ///
    /// `model_path` is left empty, so the result must be given a model path
    /// before it can be passed to `Engine::load`, which rejects an empty path.
    fn default() -> Self {
        Self {
            // Intentionally empty: the caller must supply a model.
            model_path: String::new(),
            tokenizer_path: None,
            // Sampling defaults.
            temperature: 0.7,
            top_k: 40,
            top_p: 0.95,
            repeat_penalty: 1.1,
            max_tokens: 512,
            // `None` requests a non-deterministic seed.
            seed: None,
            // CPU unless explicitly requested otherwise.
            use_gpu: false,
            // `None` keeps the model's native context length.
            max_context_len: None,
            #[cfg(feature = "hailo")]
            hailo_config: None,
            kv_cache_type: crate::model::KVCacheType::F32,
        }
    }
}
134
135impl EngineConfig {
136    /// Load an `EngineConfig` from a [`Config`](crate::config::Config) TOML file.
137    ///
138    /// This is a convenience method that loads the full config and extracts
139    /// the engine-relevant sections.
140    pub fn from_config_file(
141        path: impl AsRef<std::path::Path>,
142    ) -> Result<Self, crate::config::ConfigError> {
143        let config = crate::config::Config::from_file(path)?;
144        Ok(config.to_engine_config(None))
145    }
146
147    /// Load an `EngineConfig` using the full config precedence chain.
148    ///
149    /// Searches default config file locations, then applies environment
150    /// variable overrides.
151    pub fn from_config(
152        config_path: Option<impl AsRef<std::path::Path>>,
153    ) -> Result<Self, crate::config::ConfigError> {
154        let config = crate::config::Config::load(config_path)?;
155        Ok(config.to_engine_config(None))
156    }
157}
158
159// ============================================================================
160// Chat template detection
161// ============================================================================
162
/// Chat prompt format used to wrap user/system messages for a model.
///
/// Detected either from GGUF metadata or, for ONNX models, from the model
/// type string.
#[derive(Debug, Clone, PartialEq)]
pub enum ChatTemplate {
    /// `<|user|>\n...<|assistant|>\n` format
    UserAssistant,
    /// ChatML: `<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n`
    ChatML,
    /// Llama-2: `[INST] ... [/INST]`
    Llama2,
    /// No template detected, use raw text.
    ///
    /// Note: this is `ChatTemplate::None`, not `Option::None`.
    None,
}
175
176impl ChatTemplate {
177    /// Detect chat template from model type string (for ONNX models without GGUF metadata).
178    pub fn detect_from_model_type(model_type: Option<&str>) -> Self {
179        match model_type {
180            Some("qwen2" | "qwen") => ChatTemplate::ChatML,
181            Some("llama" | "codellama") => ChatTemplate::Llama2,
182            Some("mistral" | "mixtral") => ChatTemplate::Llama2,
183            _ => ChatTemplate::None,
184        }
185    }
186
187    /// Detect chat template from GGUF metadata.
188    pub fn detect(gguf: &GgufFile) -> Self {
189        if let Some(template) = gguf.data.get_string("tokenizer.chat_template") {
190            if template.contains("<|user|>") {
191                ChatTemplate::UserAssistant
192            } else if template.contains("<|im_start|>") {
193                ChatTemplate::ChatML
194            } else if template.contains("[INST]") {
195                ChatTemplate::Llama2
196            } else {
197                ChatTemplate::None
198            }
199        } else if let Some(arch) = gguf.data.get_string("general.architecture") {
200            match arch.to_lowercase().as_str() {
201                "qwen2" | "qwen" | "qwen3" | "qwen35" | "qwen3moe" | "qwen3next" => {
202                    ChatTemplate::ChatML
203                }
204                _ => ChatTemplate::None,
205            }
206        } else {
207            ChatTemplate::None
208        }
209    }
210
211    /// Wrap a raw prompt in the appropriate chat format.
212    pub fn wrap_prompt(&self, prompt: &str) -> String {
213        // If prompt already contains chat tokens, return as-is
214        if prompt.contains("<|user|>")
215            || prompt.contains("<|im_start|>")
216            || prompt.contains("[INST]")
217        {
218            return prompt.to_string();
219        }
220
221        match self {
222            ChatTemplate::UserAssistant => {
223                format!("<|user|>\n{}<|assistant|>\n", prompt)
224            }
225            ChatTemplate::ChatML => {
226                format!(
227                    "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
228                    prompt
229                )
230            }
231            ChatTemplate::Llama2 => {
232                format!("[INST] {} [/INST]", prompt)
233            }
234            ChatTemplate::None => prompt.to_string(),
235        }
236    }
237
238    /// Format a chat message with system prompt for the first turn.
239    pub fn format_first_turn(&self, system_prompt: &str, user_message: &str) -> String {
240        match self {
241            ChatTemplate::UserAssistant => {
242                format!(
243                    "<|system|>\n{}<|user|>\n{}<|assistant|>\n",
244                    system_prompt, user_message
245                )
246            }
247            ChatTemplate::ChatML => {
248                format!(
249                    "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
250                    system_prompt, user_message
251                )
252            }
253            ChatTemplate::Llama2 => {
254                format!(
255                    "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]",
256                    system_prompt, user_message
257                )
258            }
259            ChatTemplate::None => {
260                format!(
261                    "System: {}\n\nUser: {}\n\nAssistant:",
262                    system_prompt, user_message
263                )
264            }
265        }
266    }
267
268    /// Format a continuation turn (not the first message).
269    pub fn format_continuation(&self, user_message: &str) -> String {
270        match self {
271            ChatTemplate::UserAssistant => {
272                format!("<|user|>\n{}<|assistant|>\n", user_message)
273            }
274            ChatTemplate::ChatML => {
275                format!(
276                    "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
277                    user_message
278                )
279            }
280            ChatTemplate::Llama2 => {
281                format!(" [INST] {} [/INST]", user_message)
282            }
283            ChatTemplate::None => {
284                format!("\n\nUser: {}\n\nAssistant:", user_message)
285            }
286        }
287    }
288
289    /// Patterns that indicate the model is trying to generate a new user turn
290    /// (i.e., the response is complete).
291    pub fn stop_patterns(&self) -> &[&str] {
292        match self {
293            ChatTemplate::UserAssistant => &["<|user|>", "<|end|>"],
294            ChatTemplate::ChatML => &["<|im_end|>", "<|im_start|>"],
295            ChatTemplate::Llama2 => &["[INST]", "</s>"],
296            ChatTemplate::None => &["User:", "\nUser:"],
297        }
298    }
299}
300
301// ============================================================================
302// Engine
303// ============================================================================
304
/// High-level inference engine that wraps model loading, tokenization, and generation.
///
/// `Engine` is `Send + Sync` safe for the immutable model and config, but the
/// mutable inference context and sampler are created per-call via internal state.
///
/// NOTE(review): whether `Engine` is actually `Send + Sync` depends on the
/// bounds of the `Model` and `Backend` traits, which are declared elsewhere —
/// confirm before relying on it.
pub struct Engine {
    /// GGUF file the model was loaded from; `None` for ONNX-loaded models.
    gguf: Option<GgufFile>,
    /// The loaded model behind the `Model` trait.
    model: Box<dyn Model>,
    /// Tokenizer loaded from the model file or from `tokenizer_path`.
    tokenizer: Tokenizer,
    /// Model configuration reported by the model loader.
    config: ModelConfig,
    /// Compute backend selected at load time.
    backend: Arc<dyn Backend>,
    /// Sampler settings derived from the engine configuration at load time.
    sampler_config: SamplerConfig,
    /// Chat template detected at load time.
    chat_template: ChatTemplate,
    /// Whether a BOS token should be added when encoding prompts.
    add_bos: bool,
    /// The configuration this engine was loaded with.
    engine_config: EngineConfig,
}
320
321impl Engine {
322    /// Load a model and create an inference engine.
323    ///
324    /// This opens the model file (GGUF or ONNX), loads the tokenizer and model weights,
325    /// and selects the appropriate backend (CPU or GPU).
326    ///
327    /// Format is auto-detected by file extension:
328    /// - `.gguf` -- GGUF format (default)
329    /// - `.onnx` -- ONNX format (requires `onnx` feature, companion config.json + tokenizer.json)
330    pub fn load(config: EngineConfig) -> Result<Self, EngineError> {
331        if config.model_path.is_empty() {
332            return Err(EngineError::Other("model_path is required".into()));
333        }
334
335        let path = std::path::Path::new(&config.model_path);
336
337        // Detect format by extension
338        match path.extension().and_then(|e| e.to_str()) {
339            #[cfg(feature = "onnx")]
340            Some("onnx") => Self::load_onnx(config),
341            #[cfg(not(feature = "onnx"))]
342            Some("onnx") => Err(EngineError::Other(
343                "ONNX support requires the `onnx` feature. Build with: cargo build --features onnx"
344                    .into(),
345            )),
346            _ => Self::load_gguf(config),
347        }
348    }
349
350    /// Load from a GGUF model file (existing path).
351    fn load_gguf(config: EngineConfig) -> Result<Self, EngineError> {
352        tracing::info!("Loading GGUF model from: {}", config.model_path);
353
354        // Load GGUF file
355        let gguf = GgufFile::open(&config.model_path)?;
356
357        // Load tokenizer
358        tracing::info!("Loading tokenizer...");
359        let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
360            // User-specified tokenizer (could be a tokenizer.json or a GGUF file)
361            if tok_path.ends_with(".json") {
362                Tokenizer::from_hf_json(tok_path)?
363            } else {
364                let tok_gguf = GgufFile::open(tok_path)?;
365                Tokenizer::from_gguf(&tok_gguf)?
366            }
367        } else {
368            Tokenizer::from_gguf(&gguf)?
369        };
370        tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
371
372        // Load model weights
373        tracing::info!("Loading model weights...");
374        let loader = ModelLoader::load(&config.model_path)?;
375        let model_config = loader.config().clone();
376        tracing::info!(
377            "Model: {} layers, {} heads, {} hidden dim, {} ctx",
378            model_config.num_layers,
379            model_config.num_heads,
380            model_config.hidden_size,
381            model_config.max_seq_len,
382        );
383
384        let arch = loader.architecture();
385
386        let (backend, model): (Arc<dyn Backend>, Box<dyn Model>) = if arch.is_encoder_only() {
387            tracing::info!("Detected encoder-only architecture: {:?}", arch);
388            let bert_model = loader.build_bert_model()?;
389            (
390                Arc::new(crate::backend::cpu::CpuBackend::new()),
391                Box::new(bert_model),
392            )
393        } else {
394            let concrete_model = loader.build_model()?;
395
396            // When CUDA is available we try to create a GpuModelWrapper first.
397            // This runs the entire forward pass on GPU with pre-allocated scratch
398            // buffers, eliminating ~770 host↔device transfers per token that the
399            // standard Backend-trait path incurs.  If GPU-only init fails we
400            // fall back to the regular CudaBackend (per-op transfers) or CPU.
401            if config.use_gpu {
402                Self::select_gpu_model(concrete_model, &model_config, &config)
403            } else {
404                (
405                    Arc::new(crate::backend::cpu::CpuBackend::new()),
406                    Box::new(concrete_model),
407                )
408            }
409        };
410
411        // Detect chat template
412        let chat_template = ChatTemplate::detect(&gguf);
413        tracing::info!("Chat template: {:?}", chat_template);
414
415        // Check BOS token preference.
416        // If the GGUF explicitly says add_bos_token, use that. Otherwise only
417        // add BOS when the model actually defines a bos_token_id.
418        let add_bos = gguf
419            .data
420            .get_bool("tokenizer.ggml.add_bos_token")
421            .unwrap_or(tokenizer.has_explicit_bos);
422
423        // Build sampler config
424        let sampler_config = SamplerConfig {
425            temperature: config.temperature,
426            top_k: config.top_k,
427            top_p: config.top_p,
428            repeat_penalty: config.repeat_penalty,
429            seed: config.seed,
430            ..Default::default()
431        };
432
433        tracing::info!("Engine ready");
434
435        Ok(Self {
436            gguf: Some(gguf),
437            model,
438            tokenizer,
439            config: model_config,
440            backend,
441            sampler_config,
442            chat_template,
443            add_bos,
444            engine_config: config,
445        })
446    }
447
448    /// Load from an ONNX model file with companion config.json and tokenizer.json.
449    #[cfg(feature = "onnx")]
450    fn load_onnx(config: EngineConfig) -> Result<Self, EngineError> {
451        use crate::onnx::OnnxModelLoader;
452
453        tracing::info!("Loading ONNX model from: {}", config.model_path);
454
455        let model_dir = std::path::Path::new(&config.model_path)
456            .parent()
457            .unwrap_or(std::path::Path::new("."));
458
459        // Load model via ONNX loader
460        let loader = OnnxModelLoader::load(&config.model_path)
461            .map_err(|e| EngineError::Other(format!("ONNX load error: {}", e)))?;
462        let model_config = loader.config().clone();
463        let hf_config = loader.hf_config().clone();
464
465        tracing::info!(
466            "Model: {} layers, {} heads, {} hidden dim, {} ctx",
467            model_config.num_layers,
468            model_config.num_heads,
469            model_config.hidden_size,
470            model_config.max_seq_len,
471        );
472
473        let concrete_model = loader
474            .build_model()
475            .map_err(|e| EngineError::Other(format!("ONNX model build error: {}", e)))?;
476
477        // Load tokenizer
478        tracing::info!("Loading tokenizer...");
479        let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
480            if tok_path.ends_with(".json") {
481                Tokenizer::from_hf_json(tok_path)?
482            } else {
483                let tok_gguf = GgufFile::open(tok_path)?;
484                Tokenizer::from_gguf(&tok_gguf)?
485            }
486        } else {
487            // Look for tokenizer.json in the same directory
488            let tokenizer_path = model_dir.join("tokenizer.json");
489            if tokenizer_path.exists() {
490                tracing::info!("Using tokenizer.json from: {}", tokenizer_path.display());
491                Tokenizer::from_hf_json(&tokenizer_path)?
492            } else {
493                return Err(EngineError::Other(format!(
494                    "No tokenizer found. ONNX models require a tokenizer.json file \
495                     in the same directory as the model, or specify --tokenizer <path>. \
496                     Looked for: {}",
497                    tokenizer_path.display()
498                )));
499            }
500        };
501        tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
502
503        // Select backend
504        let backend: Arc<dyn Backend> = if config.use_gpu {
505            Self::select_gpu_backend(&concrete_model)
506        } else {
507            Arc::new(crate::backend::cpu::CpuBackend::new())
508        };
509
510        let model: Box<dyn Model> = Box::new(concrete_model);
511
512        // Infer chat template from model type
513        let chat_template = ChatTemplate::detect_from_model_type(hf_config.model_type.as_deref());
514        tracing::info!("Chat template: {:?}", chat_template);
515
516        // For ONNX models, default to adding BOS
517        let add_bos = true;
518
519        let sampler_config = SamplerConfig {
520            temperature: config.temperature,
521            top_k: config.top_k,
522            top_p: config.top_p,
523            repeat_penalty: config.repeat_penalty,
524            seed: config.seed,
525            ..Default::default()
526        };
527
528        tracing::info!("Engine ready (ONNX)");
529
530        Ok(Self {
531            gguf: None,
532            model,
533            tokenizer,
534            config: model_config,
535            backend,
536            sampler_config,
537            chat_template,
538            add_bos,
539            engine_config: config,
540        })
541    }
542
    /// Select the best GPU model + backend combination.
    ///
    /// Tries GPU-only inference first (all computation on GPU, ~386× fewer
    /// host↔device transfers), then falls back to the per-op CudaBackend,
    /// then to CPU.
    ///
    /// The compiled-in GPU-only engines are probed in the order
    /// CUDA, Vulkan, Metal (macOS), DX12 (Windows), Hailo (only when
    /// `engine_config.hailo_config` is set). The first one whose device probe
    /// succeeds takes ownership of `model`. If none is available, the model is
    /// returned unchanged together with the backend chosen by
    /// [`select_gpu_backend`](Self::select_gpu_backend).
    ///
    /// The context length handed to a GPU-only engine is
    /// `engine_config.max_context_len` when that is non-zero and smaller than
    /// `config.max_seq_len`, otherwise `config.max_seq_len`.
    ///
    /// NOTE(review): if a device is present but engine initialisation fails,
    /// this function prints to stderr and calls `std::process::exit(1)` rather
    /// than returning an error, which terminates the host process even when
    /// used as a library. Returning a `Result` would need a signature change
    /// coordinated with the caller.
    // `gpu_seq_len` is only read inside the cfg-gated blocks, so it is unused
    // when no GPU feature is compiled in.
    #[allow(unused_variables)]
    fn select_gpu_model(
        model: crate::model::LlamaModel,
        config: &ModelConfig,
        engine_config: &EngineConfig,
    ) -> (Arc<dyn Backend>, Box<dyn Model>) {
        // Apply the optional context-length cap (ignored when 0 or not smaller
        // than the model's native length).
        let gpu_seq_len = match engine_config.max_context_len {
            Some(cap) if cap > 0 && cap < config.max_seq_len => {
                tracing::info!(
                    "Capping GPU context length from {} to {} (max_context_len)",
                    config.max_seq_len,
                    cap
                );
                cap
            }
            _ => config.max_seq_len,
        };

        // Priority: CUDA > Vulkan > Metal > DX12 > Hailo > per-op fallback
        // Each gpu_only engine consumes the model, so only one can be tried.
        // Both arms of every `match` below diverge (return or exit), so
        // `model` is still owned whenever control reaches the next block.

        #[cfg(feature = "cuda")]
        {
            if cudarc::driver::CudaDevice::new(0).is_ok() {
                // Read the architecture before `from_model` takes the model.
                let architecture = model.architecture();
                match crate::backend::cuda::gpu_only::GpuOnlyInference::from_model(
                    model,
                    gpu_seq_len,
                ) {
                    Ok(gpu) => {
                        tracing::info!(
                            "Using full GPU inference (attention + DeltaNet + MoE all on CUDA)"
                        );
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        // The wrapper is paired with a CPU backend handle.
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: CUDA GPU inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --gpu.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No CUDA device available, trying other GPU backends...");
            }
        }

        #[cfg(feature = "vulkan")]
        {
            if crate::backend::vulkan::VulkanBackend::new().is_ok() {
                // Read the architecture before `from_model` takes the model.
                let architecture = model.architecture();
                match crate::backend::vulkan::gpu_only::VulkanGpuInference::from_model(
                    model,
                    gpu_seq_len,
                ) {
                    Ok(gpu) => {
                        tracing::info!("Using full GPU inference on Vulkan");
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: Vulkan GPU inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --gpu.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No Vulkan device available, trying other GPU backends...");
            }
        }

        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            if crate::backend::metal::MetalBackend::new().is_ok() {
                // Read the architecture before `from_model` takes the model.
                let architecture = model.architecture();
                match crate::backend::metal::gpu_only::MetalGpuInference::from_model(
                    model,
                    gpu_seq_len,
                ) {
                    Ok(gpu) => {
                        tracing::info!("Using full GPU inference on Metal");
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: Metal GPU inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --gpu.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No Metal device available, trying other GPU backends...");
            }
        }

        #[cfg(all(feature = "dx12", target_os = "windows"))]
        {
            if crate::backend::dx12::Dx12Backend::new().is_ok() {
                // Read the architecture before `from_model` takes the model.
                let architecture = model.architecture();
                match crate::backend::dx12::gpu_only::Dx12GpuInference::from_model(
                    model,
                    gpu_seq_len,
                ) {
                    Ok(gpu) => {
                        tracing::info!("Using full GPU inference on DX12");
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: DX12 GPU inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --gpu.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No DX12 device available");
            }
        }

        #[cfg(feature = "hailo")]
        {
            // Hailo is opt-in: it is only attempted when a Hailo config is present.
            if let Some(ref hailo_config) = engine_config.hailo_config {
                if crate::backend::hailo::context::check_device_available().is_ok() {
                    // Read the architecture before `from_model` takes the model.
                    let architecture = model.architecture();
                    match crate::backend::hailo::gpu_only::HailoGpuInference::from_model(
                        model,
                        gpu_seq_len,
                        hailo_config.clone(),
                    ) {
                        Ok(gpu) => {
                            tracing::info!("Using hybrid CPU+Hailo inference");
                            let wrapper = crate::backend::GpuModelWrapper::new(
                                gpu,
                                config.clone(),
                                architecture,
                            );
                            return (
                                Arc::new(crate::backend::cpu::CpuBackend::new()),
                                Box::new(wrapper),
                            );
                        }
                        Err(e) => {
                            eprintln!("Error: Hailo inference init failed: {}", e);
                            eprintln!("The model was consumed during init. Please restart without --hailo.");
                            std::process::exit(1);
                        }
                    }
                } else {
                    tracing::info!("No Hailo device available, falling back to CPU...");
                }
            }
        }

        // No GPU-only engine available: fall back to per-op backend
        let backend = Self::select_gpu_backend(&model);
        (backend, Box::new(model))
    }
733
    /// Select the best available GPU backend.
    ///
    /// Priority: CUDA > Metal > DX12 > Vulkan > CPU fallback.
    ///
    /// Only backends whose feature (and, for Metal/DX12, target OS) is
    /// compiled in are attempted. The first backend that constructs
    /// successfully is returned; if none does, a CPU backend is returned.
    /// This function never fails.
    ///
    /// `model` is only used by the CUDA path, which tries to load the model's
    /// weights into the backend and logs a warning if that fails.
    // `model` is unused unless the `cuda` feature is enabled.
    #[allow(unused_variables)]
    pub fn select_gpu_backend(model: &crate::model::LlamaModel) -> Arc<dyn Backend> {
        // Try CUDA first (NVIDIA GPUs)
        #[cfg(feature = "cuda")]
        {
            match crate::backend::cuda::CudaBackend::new() {
                Ok(mut cuda) => {
                    tracing::info!("Using CUDA backend: {}", cuda.device_name());
                    // A weight-upload failure is not fatal: the backend is
                    // still returned and the failure is only logged.
                    if let Err(e) = cuda.load_model_weights(model) {
                        tracing::warn!("Failed to load GPU weights ({}), using quantized ops", e);
                    }
                    return Arc::new(cuda);
                }
                Err(e) => {
                    tracing::info!("CUDA not available ({}), trying Metal...", e);
                }
            }
        }

        // Try Metal (native macOS / Apple Silicon)
        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            match crate::backend::metal::MetalBackend::new() {
                Ok(metal) => {
                    tracing::info!("Using Metal backend: {}", metal.device_name());
                    return Arc::new(metal);
                }
                Err(e) => {
                    tracing::info!("Metal not available ({}), trying DX12...", e);
                }
            }
        }

        // Try DX12 (native Windows GPU compute)
        #[cfg(all(feature = "dx12", target_os = "windows"))]
        {
            match crate::backend::dx12::Dx12Backend::new() {
                Ok(dx12) => {
                    tracing::info!("Using DX12 backend: {}", dx12.device_name());
                    return Arc::new(dx12);
                }
                Err(e) => {
                    tracing::info!("DX12 not available ({}), trying Vulkan...", e);
                }
            }
        }

        // Try Vulkan (cross-platform: AMD, Intel, NVIDIA, etc.)
        #[cfg(feature = "vulkan")]
        {
            match crate::backend::vulkan::VulkanBackend::new() {
                Ok(vk) => {
                    tracing::info!("Using Vulkan backend: {}", vk.device_name());
                    return Arc::new(vk);
                }
                Err(e) => {
                    tracing::warn!("Vulkan not available ({}), falling back to CPU", e);
                }
            }
        }

        // Fallback message when no GPU backend is compiled
        #[cfg(not(any(
            feature = "cuda",
            feature = "vulkan",
            all(feature = "metal", target_os = "macos"),
            all(feature = "dx12", target_os = "windows")
        )))]
        {
            tracing::warn!(
                "No GPU backend compiled. Build with --features cuda, --features metal, --features dx12, or --features vulkan"
            );
        }

        // Reached when every compiled-in GPU backend failed or none exists.
        Arc::new(crate::backend::cpu::CpuBackend::new())
    }
813
    /// Get the model configuration.
    ///
    /// Returns a borrow of the [`ModelConfig`] stored in this engine; the same
    /// value supplies `vocab_size` when samplers are created for generation.
    pub fn model_config(&self) -> &ModelConfig {
        &self.config
    }
818
    /// Get the detected chat template.
    ///
    /// Returns a borrow of the engine's [`ChatTemplate`]. The same template is
    /// used by [`Engine::generate`] to wrap prompts and to supply stop patterns.
    pub fn chat_template(&self) -> &ChatTemplate {
        &self.chat_template
    }
823
    /// Get the GGUF file metadata (None for ONNX-loaded models).
    ///
    /// The file is stored as an `Option`; this borrows its contents without
    /// cloning or taking ownership.
    pub fn gguf(&self) -> Option<&GgufFile> {
        self.gguf.as_ref()
    }
828
    /// Get the tokenizer.
    ///
    /// Returns a borrow of the [`Tokenizer`] this engine uses to encode prompts
    /// and decode sampled tokens.
    pub fn tokenizer(&self) -> &Tokenizer {
        &self.tokenizer
    }
833
    /// Get the engine configuration.
    ///
    /// Returns a borrow of the [`EngineConfig`] stored in this engine (for
    /// example, its `kv_cache_type` decides how inference contexts are built).
    pub fn engine_config(&self) -> &EngineConfig {
        &self.engine_config
    }
838
    /// Get the underlying model (for advanced usage like perplexity computation).
    ///
    /// The model is held behind a smart pointer; this dereferences it and
    /// hands out a plain `&dyn Model` borrow tied to `&self`.
    pub fn model(&self) -> &dyn Model {
        &*self.model
    }
843
    /// Get the backend.
    ///
    /// Returns a borrow of the shared `Arc<dyn Backend>`; callers that need
    /// their own handle can clone the `Arc`.
    pub fn backend(&self) -> &Arc<dyn Backend> {
        &self.backend
    }
848
    /// Whether to add a BOS token when encoding prompts.
    ///
    /// This flag is what [`Engine::generate`] passes as the second argument to
    /// the tokenizer's `encode`.
    pub fn add_bos(&self) -> bool {
        self.add_bos
    }
853
854    /// Generate text from a prompt.
855    ///
856    /// The prompt is automatically wrapped with the detected chat template
857    /// unless it already contains chat formatting tokens.
858    ///
859    /// Create an InferenceContext respecting the configured KV cache type.
860    pub fn create_inference_context(&self) -> InferenceContext {
861        if self.engine_config.kv_cache_type.is_turboquant() {
862            InferenceContext::new_with_cache_type(
863                &self.config,
864                self.backend.clone(),
865                self.engine_config.kv_cache_type,
866            )
867        } else {
868            self.model.create_context(self.backend.clone())
869        }
870    }
871
872    /// Returns the generated text (not including the prompt).
873    pub fn generate(&self, prompt: &str, max_tokens: usize) -> Result<String, EngineError> {
874        let mut ctx = self.create_inference_context();
875        let mut sampler = Sampler::new(self.sampler_config.clone(), self.config.vocab_size);
876
877        // Wrap prompt with chat template
878        let formatted = self.chat_template.wrap_prompt(prompt);
879        let mut tokens = self.tokenizer.encode(&formatted, self.add_bos)?;
880
881        let mut output = String::new();
882
883        for _ in 0..max_tokens {
884            // Check if we hit EOS
885            if let Some(&last) = tokens.last()
886                && last == self.tokenizer.special_tokens.eos_token_id
887            {
888                break;
889            }
890
891            // Run forward pass
892            let input_tokens = if ctx.position == 0 {
893                &tokens[..]
894            } else {
895                &tokens[tokens.len() - 1..]
896            };
897
898            let logits = self.model.forward(input_tokens, &mut ctx)?;
899            let next_token = sampler.sample(&logits, &tokens);
900
901            // Check for EOS
902            if next_token == self.tokenizer.special_tokens.eos_token_id {
903                break;
904            }
905
906            // Decode token
907            if let Ok(text) = self.tokenizer.decode(&[next_token]) {
908                // Check for stop patterns
909                let combined = format!("{}{}", output, text);
910                let stop = self
911                    .chat_template
912                    .stop_patterns()
913                    .iter()
914                    .any(|p| combined.contains(p));
915
916                if stop {
917                    // Add only the text before the stop pattern
918                    for pattern in self.chat_template.stop_patterns() {
919                        if let Some(idx) = combined.find(pattern) {
920                            output = combined[..idx].to_string();
921                            return Ok(output.trim().to_string());
922                        }
923                    }
924                    break;
925                }
926
927                output.push_str(&text);
928            }
929
930            tokens.push(next_token);
931        }
932
933        Ok(output.trim().to_string())
934    }
935
    /// Generate text from a prompt, yielding tokens as they are produced.
    ///
    /// Each item in the returned iterator is a `Result<String, EngineError>` containing
    /// the decoded text of one or more tokens.
    ///
    /// This only constructs the [`GenerationStream`] (which borrows the engine);
    /// model forward passes happen as the iterator is advanced.
    pub fn generate_streaming(&self, prompt: &str, max_tokens: usize) -> GenerationStream<'_> {
        GenerationStream::new(self, prompt, max_tokens)
    }
943
944    /// Extract embeddings from text using the model.
945    pub fn embed(&self, text: &str) -> Result<Vec<f32>, EngineError> {
946        let mut ctx = self.create_inference_context();
947        let embed_config = EmbeddingConfig::default();
948        let extractor = EmbeddingExtractor::new(embed_config, &self.config);
949        let embedding =
950            extractor.embed_text(self.model.as_ref(), &self.tokenizer, &mut ctx, text)?;
951        Ok(embedding)
952    }
953}
954
955// ============================================================================
956// Streaming generation
957// ============================================================================
958
/// Iterator that yields generated tokens as strings.
///
/// Created by [`Engine::generate_streaming`].
///
/// Each call to `next` runs one model forward pass, samples a token, and
/// yields its decoded text. Iteration ends on EOS, on a stop pattern, or once
/// the token budget is used up.
pub struct GenerationStream<'a> {
    /// Engine providing the model, tokenizer and chat template.
    engine: &'a Engine,
    /// Inference context owned by this stream.
    ctx: InferenceContext,
    /// Sampler used to pick each next token from the logits.
    sampler: Sampler,
    /// Encoded prompt followed by every token sampled so far.
    tokens: Vec<u32>,
    /// Number of tokens that may still be sampled.
    remaining: usize,
    /// Set once generation has finished; `next` then returns `None`.
    done: bool,
    /// Decoded text gathered so far, scanned for stop patterns.
    accumulated: String,
    /// Pending bytes for incomplete UTF-8 sequences across token boundaries
    pending_bytes: Vec<u8>,
}
973
974impl<'a> GenerationStream<'a> {
975    fn new(engine: &'a Engine, prompt: &str, max_tokens: usize) -> Self {
976        let ctx = engine.create_inference_context();
977        let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
978
979        let formatted = engine.chat_template.wrap_prompt(prompt);
980        if std::env::var("LLAMA_DEBUG").is_ok() {
981            eprintln!("[DEBUG] formatted prompt: {:?}", formatted);
982            eprintln!("[DEBUG] add_bos: {}", engine.add_bos);
983        }
984        let tokens = engine
985            .tokenizer
986            .encode(&formatted, engine.add_bos)
987            .unwrap_or_default();
988        if std::env::var("LLAMA_DEBUG").is_ok() {
989            eprintln!("[DEBUG] encoded {} tokens: {:?}", tokens.len(), &tokens[..tokens.len().min(50)]);
990            for (i, &tid) in tokens.iter().enumerate() {
991                if let Some(s) = engine.tokenizer.get_token(tid) {
992                    eprintln!("[DEBUG]   token[{}] = {} -> {:?}", i, tid, s);
993                }
994            }
995        }
996
997        Self {
998            engine,
999            ctx,
1000            sampler,
1001            tokens,
1002            remaining: max_tokens,
1003            done: false,
1004            accumulated: String::new(),
1005            pending_bytes: Vec::new(),
1006        }
1007    }
1008}
1009
1010impl<'a> Iterator for GenerationStream<'a> {
1011    type Item = Result<String, EngineError>;
1012
1013    fn next(&mut self) -> Option<Self::Item> {
1014        if self.done || self.remaining == 0 {
1015            return None;
1016        }
1017
1018        // Check EOS from last token
1019        if let Some(&last) = self.tokens.last()
1020            && last == self.engine.tokenizer.special_tokens.eos_token_id
1021        {
1022            self.done = true;
1023            return None;
1024        }
1025
1026        // Forward pass
1027        let input_tokens = if self.ctx.position == 0 {
1028            &self.tokens[..]
1029        } else {
1030            &self.tokens[self.tokens.len() - 1..]
1031        };
1032
1033        let logits = match self.engine.model.forward(input_tokens, &mut self.ctx) {
1034            Ok(l) => l,
1035            Err(e) => {
1036                self.done = true;
1037                return Some(Err(EngineError::Model(e)));
1038            }
1039        };
1040
1041        let next_token = self.sampler.sample(&logits, &self.tokens);
1042
1043        if std::env::var("LLAMA_DEBUG_LOGITS").is_ok() {
1044            let logit_data = logits.as_f32().unwrap();
1045            let mut indexed: Vec<(usize, f32)> = logit_data.iter().copied().enumerate().collect();
1046            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1047            let step = self.tokens.len();
1048            eprint!("[LOGIT] step={} top5:", step);
1049            for (id, score) in indexed.iter().take(5) {
1050                let tok_str = self.engine.tokenizer.get_token(*id as u32).unwrap_or_default();
1051                eprint!(" {}({:.2})={:?}", id, score, tok_str);
1052            }
1053            let chosen_str = self.engine.tokenizer.get_token(next_token).unwrap_or_default();
1054            eprintln!(" → chosen={}({:?})", next_token, chosen_str);
1055        }
1056
1057        // Check EOS
1058        if next_token == self.engine.tokenizer.special_tokens.eos_token_id {
1059            self.done = true;
1060            return None;
1061        }
1062
1063        // Decode (streaming-aware: accumulates incomplete UTF-8 across tokens)
1064        match self
1065            .engine
1066            .tokenizer
1067            .decode_token_streaming(next_token, &mut self.pending_bytes)
1068        {
1069            Ok(text) => {
1070                self.tokens.push(next_token);
1071                self.remaining -= 1;
1072
1073                if text.is_empty() {
1074                    // Pending bytes not yet forming valid UTF-8; recurse to next token
1075                    return self.next();
1076                }
1077
1078                // Check stop patterns
1079                let combined = format!("{}{}", self.accumulated, text);
1080                for pattern in self.engine.chat_template.stop_patterns() {
1081                    if combined.contains(pattern) {
1082                        self.done = true;
1083                        if let Some(idx) = combined.find(pattern) {
1084                            if idx > self.accumulated.len() {
1085                                let before = &combined[self.accumulated.len()..idx];
1086                                return Some(Ok(before.to_string()));
1087                            }
1088                        }
1089                        return None;
1090                    }
1091                }
1092
1093                self.accumulated.push_str(&text);
1094                Some(Ok(text))
1095            }
1096            Err(e) => {
1097                self.tokens.push(next_token);
1098                self.remaining -= 1;
1099                Some(Err(EngineError::Tokenizer(e)))
1100            }
1101        }
1102    }
1103}
1104
1105// ============================================================================
1106// ChatEngine
1107// ============================================================================
1108
/// High-level chat engine that maintains conversation state.
///
/// Wraps an [`Engine`] with conversation history, context management,
/// and automatic chat template formatting.
pub struct ChatEngine {
    /// The loaded engine (model, tokenizer, chat template, configuration).
    engine: Engine,
    /// System prompt passed to the chat template when formatting the first turn.
    system_prompt: String,
    /// Prompt tokens and generated tokens of the conversation so far.
    conversation_tokens: Vec<u32>,
    /// Inference context reused across turns of the conversation.
    ctx: InferenceContext,
    /// Sampler reused across turns; reset by `clear_history`.
    sampler: Sampler,
    /// `true` until the first message has been processed; selects first-turn
    /// formatting (and BOS handling) instead of continuation formatting.
    is_first_turn: bool,
}
1121
1122impl ChatEngine {
1123    /// Create a new chat engine from a loaded [`Engine`].
1124    pub fn new(engine: Engine, system_prompt: Option<String>) -> Self {
1125        let ctx = engine.create_inference_context();
1126        let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
1127
1128        Self {
1129            system_prompt: system_prompt
1130                .unwrap_or_else(|| "You are a helpful AI assistant.".to_string()),
1131            conversation_tokens: Vec::new(),
1132            ctx,
1133            sampler,
1134            is_first_turn: true,
1135            engine,
1136        }
1137    }
1138
1139    /// Get a reference to the underlying engine.
1140    pub fn engine(&self) -> &Engine {
1141        &self.engine
1142    }
1143
1144    /// Get the current system prompt.
1145    pub fn system_prompt(&self) -> &str {
1146        &self.system_prompt
1147    }
1148
1149    /// Get the number of tokens in the current conversation context.
1150    pub fn context_len(&self) -> usize {
1151        self.conversation_tokens.len()
1152    }
1153
1154    /// Send a message and get the full response.
1155    pub fn chat(&mut self, message: &str) -> Result<String, EngineError> {
1156        let max_tokens = self.engine.engine_config.max_tokens;
1157
1158        // Format the message using the chat template
1159        let formatted = if self.is_first_turn {
1160            self.engine
1161                .chat_template
1162                .format_first_turn(&self.system_prompt, message)
1163        } else {
1164            self.engine.chat_template.format_continuation(message)
1165        };
1166
1167        // Encode new tokens
1168        let new_tokens = self
1169            .engine
1170            .tokenizer
1171            .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;
1172
1173        // Check context length and trim if needed
1174        self.ensure_context_space(new_tokens.len(), max_tokens);
1175
1176        // Add new tokens to conversation
1177        self.conversation_tokens.extend(&new_tokens);
1178
1179        // Batch-prefill: process ALL prompt tokens in a single forward pass.
1180        // This is dramatically faster than one-at-a-time, especially on GPU
1181        // backends where kernel launch overhead dominates for tiny batches.
1182        let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
1183        let mut response_text = String::new();
1184
1185        if new_tokens.is_empty() {
1186            self.is_first_turn = false;
1187            return Ok(response_text);
1188        }
1189
1190        let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
1191        let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);
1192
1193        if first_token == eos_id {
1194            self.is_first_turn = false;
1195            return Ok(response_text);
1196        }
1197
1198        if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
1199            response_text.push_str(&text);
1200        }
1201        self.conversation_tokens.push(first_token);
1202
1203        // Autoregressive decode for the remaining tokens
1204        for _ in 1..max_tokens {
1205            // Check stop patterns
1206            let should_stop = self
1207                .engine
1208                .chat_template
1209                .stop_patterns()
1210                .iter()
1211                .any(|p| response_text.contains(p));
1212            if should_stop {
1213                for pattern in self.engine.chat_template.stop_patterns() {
1214                    if let Some(idx) = response_text.find(pattern) {
1215                        response_text.truncate(idx);
1216                        break;
1217                    }
1218                }
1219                break;
1220            }
1221
1222            let last_token = *self
1223                .conversation_tokens
1224                .last()
1225                .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);
1226
1227            let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
1228            let next_token = self.sampler.sample(&logits, &self.conversation_tokens);
1229
1230            if next_token == eos_id {
1231                break;
1232            }
1233
1234            if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
1235                response_text.push_str(&text);
1236            }
1237
1238            self.conversation_tokens.push(next_token);
1239        }
1240
1241        self.is_first_turn = false;
1242        Ok(response_text.trim().to_string())
1243    }
1244
1245    /// Send a message and get the full response, with a prefix injected as the
1246    /// start of the assistant's reply. The prefix tokens are prefilled alongside
1247    /// the prompt tokens so the model continues from the prefix text. The prefix
1248    /// is prepended to the returned string.
1249    ///
1250    /// This is useful for forcing the model to start with a particular token
1251    /// sequence (e.g. `{` for JSON output).
1252    pub fn chat_with_prefix(
1253        &mut self,
1254        message: &str,
1255        prefix: &str,
1256    ) -> Result<String, EngineError> {
1257        let max_tokens = self.engine.engine_config.max_tokens;
1258
1259        let formatted = if self.is_first_turn {
1260            self.engine
1261                .chat_template
1262                .format_first_turn(&self.system_prompt, message)
1263        } else {
1264            self.engine.chat_template.format_continuation(message)
1265        };
1266
1267        // Append prefix to the formatted prompt so it becomes part of the prefill
1268        let formatted_with_prefix = format!("{}{}", formatted, prefix);
1269
1270        let new_tokens = self
1271            .engine
1272            .tokenizer
1273            .encode(&formatted_with_prefix, self.is_first_turn && self.engine.add_bos)?;
1274
1275        self.ensure_context_space(new_tokens.len(), max_tokens);
1276        self.conversation_tokens.extend(&new_tokens);
1277
1278        let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
1279        let mut response_text = prefix.to_string();
1280
1281        if new_tokens.is_empty() {
1282            self.is_first_turn = false;
1283            return Ok(response_text);
1284        }
1285
1286        let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
1287        let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);
1288
1289        if first_token == eos_id {
1290            self.is_first_turn = false;
1291            return Ok(response_text);
1292        }
1293
1294        if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
1295            response_text.push_str(&text);
1296        }
1297        self.conversation_tokens.push(first_token);
1298
1299        for _ in 1..max_tokens {
1300            let should_stop = self
1301                .engine
1302                .chat_template
1303                .stop_patterns()
1304                .iter()
1305                .any(|p| response_text.contains(p));
1306            if should_stop {
1307                for pattern in self.engine.chat_template.stop_patterns() {
1308                    if let Some(idx) = response_text.find(pattern) {
1309                        response_text.truncate(idx);
1310                        break;
1311                    }
1312                }
1313                break;
1314            }
1315
1316            let last_token = *self
1317                .conversation_tokens
1318                .last()
1319                .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);
1320
1321            let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
1322            let next_token = self.sampler.sample(&logits, &self.conversation_tokens);
1323
1324            if next_token == eos_id {
1325                break;
1326            }
1327
1328            if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
1329                response_text.push_str(&text);
1330            }
1331
1332            self.conversation_tokens.push(next_token);
1333        }
1334
1335        self.is_first_turn = false;
1336        Ok(response_text.trim().to_string())
1337    }
1338
1339    /// Send a message and stream the response token by token.
1340    ///
1341    /// Returns an iterator of `Result<String, EngineError>` where each item is
1342    /// a decoded token chunk.
1343    pub fn chat_streaming(&mut self, message: &str) -> Result<ChatStream<'_>, EngineError> {
1344        let max_tokens = self.engine.engine_config.max_tokens;
1345
1346        // Format the message
1347        let formatted = if self.is_first_turn {
1348            self.engine
1349                .chat_template
1350                .format_first_turn(&self.system_prompt, message)
1351        } else {
1352            self.engine.chat_template.format_continuation(message)
1353        };
1354
1355        // Encode new tokens
1356        let new_tokens = self
1357            .engine
1358            .tokenizer
1359            .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;
1360
1361        // Ensure context space
1362        self.ensure_context_space(new_tokens.len(), max_tokens);
1363
1364        // Add new tokens to conversation
1365        self.conversation_tokens.extend(&new_tokens);
1366
1367        // Batch-prefill all prompt tokens in a single forward pass.
1368        let prefill_logits = if !new_tokens.is_empty() {
1369            Some(self.engine.model.forward(&new_tokens, &mut self.ctx)?)
1370        } else {
1371            None
1372        };
1373
1374        self.is_first_turn = false;
1375
1376        Ok(ChatStream {
1377            chat_engine: self,
1378            remaining: max_tokens,
1379            done: false,
1380            accumulated: String::new(),
1381            prefill_logits,
1382        })
1383    }
1384
1385    /// Clear conversation history and reset context.
1386    pub fn clear_history(&mut self) {
1387        self.conversation_tokens.clear();
1388        self.ctx.reset();
1389        self.sampler.reset();
1390        self.is_first_turn = true;
1391    }
1392
1393    /// Ensure there's enough space in the context for new tokens + generation.
1394    fn ensure_context_space(&mut self, new_token_count: usize, max_gen_tokens: usize) {
1395        let total_len = self.conversation_tokens.len() + new_token_count + max_gen_tokens;
1396
1397        if total_len > self.engine.config.max_seq_len {
1398            let excess = total_len - self.engine.config.max_seq_len + 100;
1399
1400            if excess >= self.conversation_tokens.len() {
1401                tracing::warn!("Context full, resetting conversation");
1402                self.conversation_tokens.clear();
1403                self.ctx.reset();
1404            } else {
1405                tracing::info!("Trimming {} tokens from context", excess);
1406                self.conversation_tokens = self.conversation_tokens[excess..].to_vec();
1407                self.ctx.kv_cache.shift_left(excess);
1408                self.ctx.position = self.ctx.position.saturating_sub(excess);
1409            }
1410        }
1411    }
1412}
1413
1414// ============================================================================
1415// Chat streaming
1416// ============================================================================
1417
/// Iterator that yields chat response tokens as strings.
///
/// Created by [`ChatEngine::chat_streaming`].
///
/// The stream holds the chat engine mutably while it is alive and appends
/// every sampled token to the engine's conversation as it goes.
pub struct ChatStream<'a> {
    /// Chat engine whose context, sampler and conversation are advanced.
    chat_engine: &'a mut ChatEngine,
    /// Number of tokens that may still be yielded.
    remaining: usize,
    /// Set once the response has finished; `next` then returns `None`.
    done: bool,
    /// Decoded response text gathered so far, scanned for stop patterns.
    accumulated: String,
    /// Logits from the batched prefill pass; consumed on the first `next()` call.
    prefill_logits: Option<crate::tensor::Tensor>,
}
1429
1430impl<'a> Iterator for ChatStream<'a> {
1431    type Item = Result<String, EngineError>;
1432
1433    fn next(&mut self) -> Option<Self::Item> {
1434        if self.done || self.remaining == 0 {
1435            return None;
1436        }
1437
1438        // Check stop patterns on accumulated text
1439        for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
1440            if self.accumulated.contains(pattern) {
1441                self.done = true;
1442                return None;
1443            }
1444        }
1445
1446        // On the first call, use the prefill logits (no extra forward pass).
1447        // On subsequent calls, run the standard single-token decode.
1448        let logits = if let Some(prefill) = self.prefill_logits.take() {
1449            prefill
1450        } else {
1451            let last_token = *self.chat_engine.conversation_tokens.last().unwrap_or(
1452                &self
1453                    .chat_engine
1454                    .engine
1455                    .tokenizer
1456                    .special_tokens
1457                    .bos_token_id,
1458            );
1459
1460            match self
1461                .chat_engine
1462                .engine
1463                .model
1464                .forward(&[last_token], &mut self.chat_engine.ctx)
1465            {
1466                Ok(l) => l,
1467                Err(e) => {
1468                    self.done = true;
1469                    return Some(Err(EngineError::Model(e)));
1470                }
1471            }
1472        };
1473
1474        let next_token = self
1475            .chat_engine
1476            .sampler
1477            .sample(&logits, &self.chat_engine.conversation_tokens);
1478
1479        // Check for EOS
1480        if next_token
1481            == self
1482                .chat_engine
1483                .engine
1484                .tokenizer
1485                .special_tokens
1486                .eos_token_id
1487        {
1488            self.done = true;
1489            return None;
1490        }
1491
1492        match self.chat_engine.engine.tokenizer.decode(&[next_token]) {
1493            Ok(text) => {
1494                // Check stop patterns in accumulated + new text
1495                let combined = format!("{}{}", self.accumulated, text);
1496                for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
1497                    if combined.contains(pattern) {
1498                        self.done = true;
1499                        if let Some(idx) = combined.find(pattern) {
1500                            let before = &combined[self.accumulated.len()..idx];
1501                            self.chat_engine.conversation_tokens.push(next_token);
1502                            if !before.is_empty() {
1503                                return Some(Ok(before.to_string()));
1504                            }
1505                        }
1506                        return None;
1507                    }
1508                }
1509
1510                self.accumulated.push_str(&text);
1511                self.chat_engine.conversation_tokens.push(next_token);
1512                self.remaining -= 1;
1513                Some(Ok(text))
1514            }
1515            Err(e) => {
1516                self.chat_engine.conversation_tokens.push(next_token);
1517                self.remaining -= 1;
1518                Some(Err(EngineError::Tokenizer(e)))
1519            }
1520        }
1521    }
1522}