Skip to main content

llama_rs/
engine.rs

1//! High-level inference engine for llama-rs.
2//!
3//! Provides [`Engine`] and [`ChatEngine`] for easy model loading and text generation
4//! without needing to manually wire together GGUF files, tokenizers, backends, and samplers.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use llama_rs::engine::{Engine, EngineConfig};
10//!
11//! let engine = Engine::load(EngineConfig {
12//!     model_path: "model.gguf".into(),
13//!     ..Default::default()
14//! }).unwrap();
15//!
16//! let response = engine.generate("What is a tort?", 256).unwrap();
17//! println!("{}", response);
18//! ```
19
20use std::sync::Arc;
21
22use crate::backend::Backend;
23use crate::gguf::GgufFile;
24use crate::model::{
25    EmbeddingConfig, EmbeddingExtractor, InferenceContext, Model, ModelConfig, ModelLoader,
26    ModelSource, build_llama_model,
27};
28use crate::safetensors::SafeTensorsLoader;
29use crate::sampling::{Sampler, SamplerConfig};
30use crate::tokenizer::Tokenizer;
31
32// ============================================================================
33// Error type
34// ============================================================================
35
/// Errors that can occur during engine operations.
///
/// Every variant except [`EngineError::Other`] wraps an error type from a
/// lower layer and is marked `#[from]`, so those errors convert automatically
/// when propagated with `?`.
#[derive(thiserror::Error, Debug)]
pub enum EngineError {
    /// A [`std::io::Error`] surfaced while the engine was working.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// A failure reported by the GGUF layer (`crate::gguf::GgufError`).
    #[error("GGUF error: {0}")]
    Gguf(#[from] crate::gguf::GgufError),

    /// A failure reported by the model layer (`crate::model::ModelError`).
    #[error("Model error: {0}")]
    Model(#[from] crate::model::ModelError),

    /// A failure reported by the tokenizer (`crate::tokenizer::TokenizerError`).
    #[error("Tokenizer error: {0}")]
    Tokenizer(#[from] crate::tokenizer::TokenizerError),

    /// A failure reported by the embedding code (`crate::model::EmbeddingError`).
    #[error("Embedding error: {0}")]
    Embedding(#[from] crate::model::EmbeddingError),

    /// Free-form engine error carrying only a message. There is no `#[from]`
    /// conversion; it is constructed explicitly with the message text.
    #[error("Engine error: {0}")]
    Other(String),
}
57
58// ============================================================================
59// Configuration
60// ============================================================================
61
/// Configuration for creating an [`Engine`].
///
/// Can be constructed manually, from a [`Config`](crate::config::Config) TOML file,
/// or with [`Default::default()`].
///
/// Because of `#[serde(default)]`, any field missing from serialized input
/// takes its value from the `Default` implementation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct EngineConfig {
    /// Path to the model.
    ///
    /// `Engine::load` accepts a GGUF file, an `.onnx` file, a `.safetensors`
    /// file, or a directory containing `config.json` (treated as a
    /// SafeTensors model directory). An empty path is rejected.
    pub model_path: String,

    /// Optional path to a tokenizer file.
    ///
    /// A path ending in `.json` is loaded as a HuggingFace tokenizer JSON;
    /// any other path is opened as a GGUF file and only its tokenizer is used.
    /// When unset, GGUF models use the tokenizer embedded in the model file,
    /// while ONNX and SafeTensors models look for `tokenizer.json` next to
    /// the model.
    pub tokenizer_path: Option<String>,

    /// Temperature for sampling (0.0 = greedy, higher = more random).
    pub temperature: f32,

    /// Top-K sampling: only consider the K most likely tokens (0 = disabled).
    pub top_k: usize,

    /// Top-P (nucleus) sampling: only consider tokens with cumulative probability <= p.
    pub top_p: f32,

    /// Repetition penalty (1.0 = no penalty).
    pub repeat_penalty: f32,

    /// Default maximum tokens to generate.
    pub max_tokens: usize,

    /// Random seed for reproducible generation (None = random).
    pub seed: Option<u64>,

    /// Request GPU acceleration.
    ///
    /// The GPU selection code is gated on the `cuda`, `vulkan`, `metal`,
    /// `dx12` and `hailo` cargo features, so at least one of them must be
    /// enabled for this flag to have a GPU path to choose from.
    pub use_gpu: bool,

    /// Maximum context length override.
    ///
    /// If set, caps the model's context length (and thus KV cache size).
    /// The model's native `max_seq_len` from GGUF metadata can be very large
    /// (e.g. 32768) which may exhaust GPU memory for the KV cache alone.
    /// Set this to a smaller value (e.g. 2048) to reduce VRAM usage.
    /// If `None` or `0`, uses the model's native `max_seq_len`.
    ///
    /// A cap that is not smaller than the native length is ignored.
    // NOTE(review): in this file the cap is applied by the SafeTensors loader
    // and by GPU model selection; the GGUF CPU path does not clamp here —
    // confirm whether it is honoured elsewhere.
    pub max_context_len: Option<usize>,

    /// Optional Hailo accelerator configuration.
    /// When set, enables Hailo NPU inference in the GPU selection chain.
    /// Only present when the crate is built with the `hailo` feature.
    #[cfg(feature = "hailo")]
    pub hailo_config: Option<crate::backend::hailo::HailoConfig>,

    /// KV cache type for memory-efficient inference.
    pub kv_cache_type: crate::model::KVCacheType,
}
116
117impl Default for EngineConfig {
118    fn default() -> Self {
119        Self {
120            model_path: String::new(),
121            tokenizer_path: None,
122            temperature: 0.7,
123            top_k: 40,
124            top_p: 0.95,
125            repeat_penalty: 1.1,
126            max_tokens: 512,
127            seed: None,
128            use_gpu: false,
129            max_context_len: None,
130            #[cfg(feature = "hailo")]
131            hailo_config: None,
132            kv_cache_type: crate::model::KVCacheType::F32,
133        }
134    }
135}
136
137impl EngineConfig {
138    /// Load an `EngineConfig` from a [`Config`](crate::config::Config) TOML file.
139    ///
140    /// This is a convenience method that loads the full config and extracts
141    /// the engine-relevant sections.
142    pub fn from_config_file(
143        path: impl AsRef<std::path::Path>,
144    ) -> Result<Self, crate::config::ConfigError> {
145        let config = crate::config::Config::from_file(path)?;
146        Ok(config.to_engine_config(None))
147    }
148
149    /// Load an `EngineConfig` using the full config precedence chain.
150    ///
151    /// Searches default config file locations, then applies environment
152    /// variable overrides.
153    pub fn from_config(
154        config_path: Option<impl AsRef<std::path::Path>>,
155    ) -> Result<Self, crate::config::ConfigError> {
156        let config = crate::config::Config::load(config_path)?;
157        Ok(config.to_engine_config(None))
158    }
159}
160
161// ============================================================================
162// Chat template detection
163// ============================================================================
164
/// Chat prompt format used to wrap user/system text for a model.
///
/// The format is detected from one of three sources, depending on how the
/// model was loaded: the GGUF `tokenizer.chat_template` /
/// `general.architecture` metadata, the `chat_template` field of a
/// HuggingFace `tokenizer_config.json`, or the model type string.
#[derive(Debug, Clone, PartialEq)]
pub enum ChatTemplate {
    /// `<|user|>\n...<|assistant|>\n` format
    UserAssistant,
    /// ChatML: `<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n`
    ChatML,
    /// Llama-2: `[INST] ... [/INST]`
    Llama2,
    /// No template detected, use raw text.
    ///
    /// This is a variant of `ChatTemplate`, not `Option::None`; it is always
    /// written as `ChatTemplate::None` to keep the two apart.
    None,
}
177
178impl ChatTemplate {
179    /// Detect chat template from a HuggingFace `tokenizer_config.json` file.
180    ///
181    /// Reads the `chat_template` field (a Jinja2 string) and pattern-matches
182    /// against known token markers.  Returns `None` on any IO or parse error.
183    pub fn from_tokenizer_config(path: &std::path::Path) -> Option<Self> {
184        let data = std::fs::read_to_string(path).ok()?;
185        let json: serde_json::Value = serde_json::from_str(&data).ok()?;
186        let template = json.get("chat_template")?.as_str()?;
187
188        if template.contains("<|user|>") {
189            Some(ChatTemplate::UserAssistant)
190        } else if template.contains("<|im_start|>") {
191            Some(ChatTemplate::ChatML)
192        } else if template.contains("[INST]") {
193            Some(ChatTemplate::Llama2)
194        } else {
195            Some(ChatTemplate::None)
196        }
197    }
198
199    /// Detect chat template from model type string (for ONNX models without GGUF metadata).
200    pub fn detect_from_model_type(model_type: Option<&str>) -> Self {
201        match model_type {
202            Some("qwen2" | "qwen") => ChatTemplate::ChatML,
203            Some("llama" | "codellama") => ChatTemplate::Llama2,
204            Some("mistral" | "mixtral") => ChatTemplate::Llama2,
205            _ => ChatTemplate::None,
206        }
207    }
208
209    /// Detect chat template from GGUF metadata.
210    pub fn detect(gguf: &GgufFile) -> Self {
211        if let Some(template) = gguf.data.get_string("tokenizer.chat_template") {
212            if template.contains("<|user|>") {
213                ChatTemplate::UserAssistant
214            } else if template.contains("<|im_start|>") {
215                ChatTemplate::ChatML
216            } else if template.contains("[INST]") {
217                ChatTemplate::Llama2
218            } else {
219                ChatTemplate::None
220            }
221        } else if let Some(arch) = gguf.data.get_string("general.architecture") {
222            match arch.to_lowercase().as_str() {
223                "qwen2" | "qwen" | "qwen3" | "qwen35" | "qwen3moe" | "qwen3next" => {
224                    ChatTemplate::ChatML
225                }
226                _ => ChatTemplate::None,
227            }
228        } else {
229            ChatTemplate::None
230        }
231    }
232
233    /// Wrap a raw prompt in the appropriate chat format.
234    pub fn wrap_prompt(&self, prompt: &str) -> String {
235        // If prompt already contains chat tokens, return as-is
236        if prompt.contains("<|user|>")
237            || prompt.contains("<|im_start|>")
238            || prompt.contains("[INST]")
239        {
240            return prompt.to_string();
241        }
242
243        match self {
244            ChatTemplate::UserAssistant => {
245                format!("<|user|>\n{}<|assistant|>\n", prompt)
246            }
247            ChatTemplate::ChatML => {
248                format!(
249                    "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
250                    prompt
251                )
252            }
253            ChatTemplate::Llama2 => {
254                format!("[INST] {} [/INST]", prompt)
255            }
256            ChatTemplate::None => prompt.to_string(),
257        }
258    }
259
260    /// Format a chat message with system prompt for the first turn.
261    pub fn format_first_turn(&self, system_prompt: &str, user_message: &str) -> String {
262        match self {
263            ChatTemplate::UserAssistant => {
264                format!(
265                    "<|system|>\n{}<|user|>\n{}<|assistant|>\n",
266                    system_prompt, user_message
267                )
268            }
269            ChatTemplate::ChatML => {
270                format!(
271                    "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
272                    system_prompt, user_message
273                )
274            }
275            ChatTemplate::Llama2 => {
276                format!(
277                    "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]",
278                    system_prompt, user_message
279                )
280            }
281            ChatTemplate::None => {
282                format!(
283                    "System: {}\n\nUser: {}\n\nAssistant:",
284                    system_prompt, user_message
285                )
286            }
287        }
288    }
289
290    /// Format a continuation turn (not the first message).
291    pub fn format_continuation(&self, user_message: &str) -> String {
292        match self {
293            ChatTemplate::UserAssistant => {
294                format!("<|user|>\n{}<|assistant|>\n", user_message)
295            }
296            ChatTemplate::ChatML => {
297                format!(
298                    "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
299                    user_message
300                )
301            }
302            ChatTemplate::Llama2 => {
303                format!(" [INST] {} [/INST]", user_message)
304            }
305            ChatTemplate::None => {
306                format!("\n\nUser: {}\n\nAssistant:", user_message)
307            }
308        }
309    }
310
311    /// Patterns that indicate the model is trying to generate a new user turn
312    /// (i.e., the response is complete).
313    pub fn stop_patterns(&self) -> &[&str] {
314        match self {
315            ChatTemplate::UserAssistant => &["<|user|>", "<|end|>"],
316            ChatTemplate::ChatML => &["<|im_end|>", "<|im_start|>"],
317            ChatTemplate::Llama2 => &["[INST]", "</s>"],
318            ChatTemplate::None => &["User:", "\nUser:"],
319        }
320    }
321}
322
323// ============================================================================
324// Engine
325// ============================================================================
326
/// High-level inference engine that wraps model loading, tokenization, and generation.
///
/// `Engine` is `Send + Sync` safe for the immutable model and config, but the
/// mutable inference context and sampler are created per-call via internal state.
///
/// NOTE(review): the `Send + Sync` statement depends on the bounds of
/// `dyn Model`, `dyn Backend`, `Tokenizer` and `GgufFile`, none of which are
/// visible in this file — confirm before relying on it.
pub struct Engine {
    /// The opened GGUF file. `Some` when loaded from GGUF; `None` for the
    /// SafeTensors and ONNX loaders.
    gguf: Option<GgufFile>,
    /// The loaded model behind the `Model` trait object.
    model: Box<dyn Model>,
    /// Tokenizer, taken from the model's GGUF, a user-supplied tokenizer
    /// path, or a `tokenizer.json` next to the model.
    tokenizer: Tokenizer,
    /// Model hyper-parameters, cloned from the loader's config.
    config: ModelConfig,
    /// Compute backend chosen at load time.
    backend: Arc<dyn Backend>,
    /// Sampler settings built from the `EngineConfig` temperature, top-k,
    /// top-p, repeat penalty and seed; other fields use `SamplerConfig`
    /// defaults.
    sampler_config: SamplerConfig,
    /// Chat prompt format detected at load time.
    chat_template: ChatTemplate,
    /// Whether a BOS token should be added. For GGUF this comes from
    /// `tokenizer.ggml.add_bos_token`, falling back to the tokenizer's
    /// `has_explicit_bos`; the SafeTensors and ONNX loaders set it to `true`.
    add_bos: bool,
    /// The configuration the engine was loaded with.
    engine_config: EngineConfig,
}
342
343impl Engine {
344    /// Load a model and create an inference engine.
345    ///
346    /// This opens the model file (GGUF or ONNX), loads the tokenizer and model weights,
347    /// and selects the appropriate backend (CPU or GPU).
348    ///
349    /// Format is auto-detected by file extension:
350    /// - `.gguf` -- GGUF format (default)
351    /// - `.onnx` -- ONNX format (requires `onnx` feature, companion config.json + tokenizer.json)
352    pub fn load(config: EngineConfig) -> Result<Self, EngineError> {
353        if config.model_path.is_empty() {
354            return Err(EngineError::Other("model_path is required".into()));
355        }
356
357        let path = std::path::Path::new(&config.model_path);
358
359        // Detect format by extension or directory contents
360        match path.extension().and_then(|e| e.to_str()) {
361            #[cfg(feature = "onnx")]
362            Some("onnx") => Self::load_onnx(config),
363            #[cfg(not(feature = "onnx"))]
364            Some("onnx") => Err(EngineError::Other(
365                "ONNX support requires the `onnx` feature. Build with: cargo build --features onnx"
366                    .into(),
367            )),
368            Some("safetensors") => Self::load_safetensors(config),
369            _ if path.is_dir() && path.join("config.json").exists() => {
370                Self::load_safetensors(config)
371            }
372            _ => Self::load_gguf(config),
373        }
374    }
375
376    /// Load from a GGUF model file (existing path).
377    fn load_gguf(config: EngineConfig) -> Result<Self, EngineError> {
378        tracing::info!("Loading GGUF model from: {}", config.model_path);
379
380        // Load GGUF file
381        let gguf = GgufFile::open(&config.model_path)?;
382
383        // Load tokenizer
384        tracing::info!("Loading tokenizer...");
385        let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
386            // User-specified tokenizer (could be a tokenizer.json or a GGUF file)
387            if tok_path.ends_with(".json") {
388                Tokenizer::from_hf_json(tok_path)?
389            } else {
390                let tok_gguf = GgufFile::open(tok_path)?;
391                Tokenizer::from_gguf(&tok_gguf)?
392            }
393        } else {
394            Tokenizer::from_gguf(&gguf)?
395        };
396        tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
397
398        // Load model weights
399        tracing::info!("Loading model weights...");
400        let loader = ModelLoader::load(&config.model_path)?;
401        let model_config = loader.config().clone();
402        tracing::info!(
403            "Model: {} layers, {} heads, {} hidden dim, {} ctx",
404            model_config.num_layers,
405            model_config.num_heads,
406            model_config.hidden_size,
407            model_config.max_seq_len,
408        );
409
410        let arch = loader.architecture();
411
412        let (backend, model): (Arc<dyn Backend>, Box<dyn Model>) = if arch.is_encoder_only() {
413            tracing::info!("Detected encoder-only architecture: {:?}", arch);
414            let bert_model = loader.build_bert_model()?;
415            (
416                Arc::new(crate::backend::cpu::CpuBackend::new()),
417                Box::new(bert_model),
418            )
419        } else {
420            let concrete_model = loader.build_model()?;
421
422            // When CUDA is available we try to create a GpuModelWrapper first.
423            // This runs the entire forward pass on GPU with pre-allocated scratch
424            // buffers, eliminating ~770 host↔device transfers per token that the
425            // standard Backend-trait path incurs.  If GPU-only init fails we
426            // fall back to the regular CudaBackend (per-op transfers) or CPU.
427            if config.use_gpu {
428                Self::select_gpu_model(concrete_model, &model_config, &config)
429            } else {
430                (
431                    Arc::new(crate::backend::cpu::CpuBackend::new()),
432                    Box::new(concrete_model),
433                )
434            }
435        };
436
437        // Detect chat template
438        let chat_template = ChatTemplate::detect(&gguf);
439        tracing::info!("Chat template: {:?}", chat_template);
440
441        // Check BOS token preference.
442        // If the GGUF explicitly says add_bos_token, use that. Otherwise only
443        // add BOS when the model actually defines a bos_token_id.
444        let add_bos = gguf
445            .data
446            .get_bool("tokenizer.ggml.add_bos_token")
447            .unwrap_or(tokenizer.has_explicit_bos);
448
449        // Build sampler config
450        let sampler_config = SamplerConfig {
451            temperature: config.temperature,
452            top_k: config.top_k,
453            top_p: config.top_p,
454            repeat_penalty: config.repeat_penalty,
455            seed: config.seed,
456            ..Default::default()
457        };
458
459        tracing::info!("Engine ready");
460
461        Ok(Self {
462            gguf: Some(gguf),
463            model,
464            tokenizer,
465            config: model_config,
466            backend,
467            sampler_config,
468            chat_template,
469            add_bos,
470            engine_config: config,
471        })
472    }
473
474    /// Load from a HuggingFace SafeTensors model directory.
475    ///
476    /// Expects a directory containing `config.json`, `tokenizer.json`, and one
477    /// or more `.safetensors` weight files.  Also accepts a path pointing to a
478    /// specific `.safetensors` file, in which case the parent directory is used.
479    fn load_safetensors(config: EngineConfig) -> Result<Self, EngineError> {
480        tracing::info!("Loading SafeTensors model from: {}", config.model_path);
481
482        let path = std::path::Path::new(&config.model_path);
483        let dir = if path.is_dir() {
484            path
485        } else {
486            path.parent().unwrap_or(std::path::Path::new("."))
487        };
488
489        // Load tokenizer
490        tracing::info!("Loading tokenizer...");
491        let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
492            // User-specified tokenizer
493            if tok_path.ends_with(".json") {
494                Tokenizer::from_hf_json(tok_path)?
495            } else {
496                let tok_gguf = GgufFile::open(tok_path)?;
497                Tokenizer::from_gguf(&tok_gguf)?
498            }
499        } else {
500            // Look for tokenizer.json in the model directory
501            let tok_path = dir.join("tokenizer.json");
502            if tok_path.exists() {
503                tracing::info!("Using tokenizer.json from: {}", tok_path.display());
504                Tokenizer::from_hf_json(&tok_path)?
505            } else {
506                return Err(EngineError::Other(format!(
507                    "No tokenizer.json found in {}. Use --tokenizer to specify one.",
508                    dir.display()
509                )));
510            }
511        };
512        tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
513
514        // Load model weights from SafeTensors
515        tracing::info!("Loading model weights...");
516        let mut loader = SafeTensorsLoader::load(path)?;
517
518        // Clamp context length if configured
519        if let Some(cap) = config.max_context_len {
520            if cap > 0 && cap < loader.config().max_seq_len {
521                tracing::info!(
522                    "Capping context length from {} to {}",
523                    loader.config().max_seq_len,
524                    cap
525                );
526                loader.config_mut().max_seq_len = cap;
527            }
528        }
529
530        let model_config = loader.config().clone();
531        let architecture = loader.architecture();
532
533        tracing::info!(
534            "Model: {} layers, {} heads, {} hidden dim, {} ctx, arch={:?}",
535            model_config.num_layers,
536            model_config.num_heads,
537            model_config.hidden_size,
538            model_config.max_seq_len,
539            architecture,
540        );
541
542        // Build model using the shared format-independent builder
543        let concrete_model = build_llama_model(&loader)?;
544
545        let (backend, model): (Arc<dyn Backend>, Box<dyn Model>) = if config.use_gpu {
546            Self::select_gpu_model(concrete_model, &model_config, &config)
547        } else {
548            (
549                Arc::new(crate::backend::cpu::CpuBackend::new()),
550                Box::new(concrete_model),
551            )
552        };
553
554        // Detect chat template from tokenizer_config.json, then fall back
555        // to model-type heuristic.
556        let chat_template = {
557            let tc_path = dir.join("tokenizer_config.json");
558            ChatTemplate::from_tokenizer_config(&tc_path)
559        }
560        .unwrap_or_else(|| {
561            // Read model_type from config.json for fallback detection
562            let config_path = dir.join("config.json");
563            let model_type = std::fs::read_to_string(&config_path)
564                .ok()
565                .and_then(|s| {
566                    let v: serde_json::Value = serde_json::from_str(&s).ok()?;
567                    v.get("model_type")?.as_str().map(|s| s.to_string())
568                });
569            ChatTemplate::detect_from_model_type(model_type.as_deref())
570        });
571        tracing::info!("Chat template: {:?}", chat_template);
572
573        // For HF models, default to adding BOS
574        let add_bos = true;
575
576        let sampler_config = SamplerConfig {
577            temperature: config.temperature,
578            top_k: config.top_k,
579            top_p: config.top_p,
580            repeat_penalty: config.repeat_penalty,
581            seed: config.seed,
582            ..Default::default()
583        };
584
585        tracing::info!("Engine ready (SafeTensors)");
586
587        Ok(Self {
588            gguf: None,
589            model,
590            tokenizer,
591            config: model_config,
592            backend,
593            sampler_config,
594            chat_template,
595            add_bos,
596            engine_config: config,
597        })
598    }
599
600    /// Load from an ONNX model file with companion config.json and tokenizer.json.
601    #[cfg(feature = "onnx")]
602    fn load_onnx(config: EngineConfig) -> Result<Self, EngineError> {
603        use crate::onnx::OnnxModelLoader;
604
605        tracing::info!("Loading ONNX model from: {}", config.model_path);
606
607        let model_dir = std::path::Path::new(&config.model_path)
608            .parent()
609            .unwrap_or(std::path::Path::new("."));
610
611        // Load model via ONNX loader
612        let loader = OnnxModelLoader::load(&config.model_path)
613            .map_err(|e| EngineError::Other(format!("ONNX load error: {}", e)))?;
614        let model_config = loader.config().clone();
615        let hf_config = loader.hf_config().clone();
616
617        tracing::info!(
618            "Model: {} layers, {} heads, {} hidden dim, {} ctx",
619            model_config.num_layers,
620            model_config.num_heads,
621            model_config.hidden_size,
622            model_config.max_seq_len,
623        );
624
625        let concrete_model = loader
626            .build_model()
627            .map_err(|e| EngineError::Other(format!("ONNX model build error: {}", e)))?;
628
629        // Load tokenizer
630        tracing::info!("Loading tokenizer...");
631        let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
632            if tok_path.ends_with(".json") {
633                Tokenizer::from_hf_json(tok_path)?
634            } else {
635                let tok_gguf = GgufFile::open(tok_path)?;
636                Tokenizer::from_gguf(&tok_gguf)?
637            }
638        } else {
639            // Look for tokenizer.json in the same directory
640            let tokenizer_path = model_dir.join("tokenizer.json");
641            if tokenizer_path.exists() {
642                tracing::info!("Using tokenizer.json from: {}", tokenizer_path.display());
643                Tokenizer::from_hf_json(&tokenizer_path)?
644            } else {
645                return Err(EngineError::Other(format!(
646                    "No tokenizer found. ONNX models require a tokenizer.json file \
647                     in the same directory as the model, or specify --tokenizer <path>. \
648                     Looked for: {}",
649                    tokenizer_path.display()
650                )));
651            }
652        };
653        tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
654
655        // Select backend
656        let backend: Arc<dyn Backend> = if config.use_gpu {
657            Self::select_gpu_backend(&concrete_model)
658        } else {
659            Arc::new(crate::backend::cpu::CpuBackend::new())
660        };
661
662        let model: Box<dyn Model> = Box::new(concrete_model);
663
664        // Infer chat template from model type
665        let chat_template = ChatTemplate::detect_from_model_type(hf_config.model_type.as_deref());
666        tracing::info!("Chat template: {:?}", chat_template);
667
668        // For ONNX models, default to adding BOS
669        let add_bos = true;
670
671        let sampler_config = SamplerConfig {
672            temperature: config.temperature,
673            top_k: config.top_k,
674            top_p: config.top_p,
675            repeat_penalty: config.repeat_penalty,
676            seed: config.seed,
677            ..Default::default()
678        };
679
680        tracing::info!("Engine ready (ONNX)");
681
682        Ok(Self {
683            gguf: None,
684            model,
685            tokenizer,
686            config: model_config,
687            backend,
688            sampler_config,
689            chat_template,
690            add_bos,
691            engine_config: config,
692        })
693    }
694
695    /// Select the best GPU model + backend combination.
696    ///
697    /// Tries GPU-only inference first (all computation on GPU, ~386× fewer
698    /// host↔device transfers), then falls back to the per-op CudaBackend,
699    /// then to CPU.
700    #[allow(unused_variables)]
701    fn select_gpu_model(
702        model: crate::model::LlamaModel,
703        config: &ModelConfig,
704        engine_config: &EngineConfig,
705    ) -> (Arc<dyn Backend>, Box<dyn Model>) {
706        let gpu_seq_len = match engine_config.max_context_len {
707            Some(cap) if cap > 0 && cap < config.max_seq_len => {
708                tracing::info!(
709                    "Capping GPU context length from {} to {} (max_context_len)",
710                    config.max_seq_len,
711                    cap
712                );
713                cap
714            }
715            _ => config.max_seq_len,
716        };
717
718        // Priority: CUDA > Vulkan > Metal > DX12 > per-op fallback
719        // Each gpu_only engine consumes the model, so only one can be tried.
720
721        #[cfg(feature = "cuda")]
722        {
723            if cudarc::driver::CudaContext::new(0).is_ok() {
724                let architecture = model.architecture();
725                match crate::backend::cuda::gpu_only::GpuOnlyInference::from_model(
726                    model,
727                    gpu_seq_len,
728                ) {
729                    Ok(gpu) => {
730                        tracing::info!(
731                            "Using full GPU inference (attention + DeltaNet + MoE all on CUDA)"
732                        );
733                        let wrapper = crate::backend::GpuModelWrapper::new(
734                            gpu,
735                            config.clone(),
736                            architecture,
737                        );
738                        return (
739                            Arc::new(crate::backend::cpu::CpuBackend::new()),
740                            Box::new(wrapper),
741                        );
742                    }
743                    Err(e) => {
744                        eprintln!("Error: CUDA GPU inference init failed: {}", e);
745                        eprintln!("The model was consumed during init. Please restart without --gpu.");
746                        std::process::exit(1);
747                    }
748                }
749            } else {
750                tracing::info!("No CUDA device available, trying other GPU backends...");
751            }
752        }
753
754        #[cfg(feature = "vulkan")]
755        {
756            if crate::backend::vulkan::VulkanBackend::new().is_ok() {
757                let architecture = model.architecture();
758                match crate::backend::vulkan::gpu_only::VulkanGpuInference::from_model(
759                    model,
760                    gpu_seq_len,
761                ) {
762                    Ok(gpu) => {
763                        tracing::info!("Using full GPU inference on Vulkan");
764                        let wrapper = crate::backend::GpuModelWrapper::new(
765                            gpu,
766                            config.clone(),
767                            architecture,
768                        );
769                        return (
770                            Arc::new(crate::backend::cpu::CpuBackend::new()),
771                            Box::new(wrapper),
772                        );
773                    }
774                    Err(e) => {
775                        eprintln!("Error: Vulkan GPU inference init failed: {}", e);
776                        eprintln!("The model was consumed during init. Please restart without --gpu.");
777                        std::process::exit(1);
778                    }
779                }
780            } else {
781                tracing::info!("No Vulkan device available, trying other GPU backends...");
782            }
783        }
784
785        #[cfg(all(feature = "metal", target_os = "macos"))]
786        {
787            if crate::backend::metal::MetalBackend::new().is_ok() {
788                let architecture = model.architecture();
789                match crate::backend::metal::gpu_only::MetalGpuInference::from_model(
790                    model,
791                    gpu_seq_len,
792                ) {
793                    Ok(gpu) => {
794                        tracing::info!("Using full GPU inference on Metal");
795                        let wrapper = crate::backend::GpuModelWrapper::new(
796                            gpu,
797                            config.clone(),
798                            architecture,
799                        );
800                        return (
801                            Arc::new(crate::backend::cpu::CpuBackend::new()),
802                            Box::new(wrapper),
803                        );
804                    }
805                    Err(e) => {
806                        eprintln!("Error: Metal GPU inference init failed: {}", e);
807                        eprintln!("The model was consumed during init. Please restart without --gpu.");
808                        std::process::exit(1);
809                    }
810                }
811            } else {
812                tracing::info!("No Metal device available, trying other GPU backends...");
813            }
814        }
815
816        #[cfg(all(feature = "dx12", target_os = "windows"))]
817        {
818            if crate::backend::dx12::Dx12Backend::new().is_ok() {
819                let architecture = model.architecture();
820                match crate::backend::dx12::gpu_only::Dx12GpuInference::from_model(
821                    model,
822                    gpu_seq_len,
823                ) {
824                    Ok(gpu) => {
825                        tracing::info!("Using full GPU inference on DX12");
826                        let wrapper = crate::backend::GpuModelWrapper::new(
827                            gpu,
828                            config.clone(),
829                            architecture,
830                        );
831                        return (
832                            Arc::new(crate::backend::cpu::CpuBackend::new()),
833                            Box::new(wrapper),
834                        );
835                    }
836                    Err(e) => {
837                        eprintln!("Error: DX12 GPU inference init failed: {}", e);
838                        eprintln!("The model was consumed during init. Please restart without --gpu.");
839                        std::process::exit(1);
840                    }
841                }
842            } else {
843                tracing::info!("No DX12 device available");
844            }
845        }
846
847        #[cfg(feature = "hailo")]
848        {
849            if let Some(ref hailo_config) = engine_config.hailo_config {
850                if crate::backend::hailo::context::check_device_available().is_ok() {
851                    let architecture = model.architecture();
852                    match crate::backend::hailo::gpu_only::HailoGpuInference::from_model(
853                        model,
854                        gpu_seq_len,
855                        hailo_config.clone(),
856                    ) {
857                        Ok(gpu) => {
858                            tracing::info!("Using hybrid CPU+Hailo inference");
859                            let wrapper = crate::backend::GpuModelWrapper::new(
860                                gpu,
861                                config.clone(),
862                                architecture,
863                            );
864                            return (
865                                Arc::new(crate::backend::cpu::CpuBackend::new()),
866                                Box::new(wrapper),
867                            );
868                        }
869                        Err(e) => {
870                            eprintln!("Error: Hailo inference init failed: {}", e);
871                            eprintln!("The model was consumed during init. Please restart without --hailo.");
872                            std::process::exit(1);
873                        }
874                    }
875                } else {
876                    tracing::info!("No Hailo device available, falling back to CPU...");
877                }
878            }
879        }
880
881        // No GPU-only engine available: fall back to per-op backend
882        let backend = Self::select_gpu_backend(&model);
883        (backend, Box::new(model))
884    }
885
886    /// Select the best available GPU backend.
887    ///
888    /// Priority: CUDA > Metal > DX12 > Vulkan > CPU fallback.
889    #[allow(unused_variables)]
890    pub fn select_gpu_backend(model: &crate::model::LlamaModel) -> Arc<dyn Backend> {
891        // Try CUDA first (NVIDIA GPUs)
892        #[cfg(feature = "cuda")]
893        {
894            match crate::backend::cuda::CudaBackend::new() {
895                Ok(mut cuda) => {
896                    tracing::info!("Using CUDA backend: {}", cuda.device_name());
897                    if let Err(e) = cuda.load_model_weights(model) {
898                        tracing::warn!("Failed to load GPU weights ({}), using quantized ops", e);
899                    }
900                    return Arc::new(cuda);
901                }
902                Err(e) => {
903                    tracing::info!("CUDA not available ({}), trying Metal...", e);
904                }
905            }
906        }
907
908        // Try Metal (native macOS / Apple Silicon)
909        #[cfg(all(feature = "metal", target_os = "macos"))]
910        {
911            match crate::backend::metal::MetalBackend::new() {
912                Ok(metal) => {
913                    tracing::info!("Using Metal backend: {}", metal.device_name());
914                    return Arc::new(metal);
915                }
916                Err(e) => {
917                    tracing::info!("Metal not available ({}), trying DX12...", e);
918                }
919            }
920        }
921
922        // Try DX12 (native Windows GPU compute)
923        #[cfg(all(feature = "dx12", target_os = "windows"))]
924        {
925            match crate::backend::dx12::Dx12Backend::new() {
926                Ok(dx12) => {
927                    tracing::info!("Using DX12 backend: {}", dx12.device_name());
928                    return Arc::new(dx12);
929                }
930                Err(e) => {
931                    tracing::info!("DX12 not available ({}), trying Vulkan...", e);
932                }
933            }
934        }
935
936        // Try Vulkan (cross-platform: AMD, Intel, NVIDIA, etc.)
937        #[cfg(feature = "vulkan")]
938        {
939            match crate::backend::vulkan::VulkanBackend::new() {
940                Ok(vk) => {
941                    tracing::info!("Using Vulkan backend: {}", vk.device_name());
942                    return Arc::new(vk);
943                }
944                Err(e) => {
945                    tracing::warn!("Vulkan not available ({}), falling back to CPU", e);
946                }
947            }
948        }
949
950        // Fallback message when no GPU backend is compiled
951        #[cfg(not(any(
952            feature = "cuda",
953            feature = "vulkan",
954            all(feature = "metal", target_os = "macos"),
955            all(feature = "dx12", target_os = "windows")
956        )))]
957        {
958            tracing::warn!(
959                "No GPU backend compiled. Build with --features cuda, --features metal, --features dx12, or --features vulkan"
960            );
961        }
962
963        Arc::new(crate::backend::cpu::CpuBackend::new())
964    }
965
    /// Get the model configuration.
    ///
    /// Returns a borrow of the [`ModelConfig`] stored on this engine.
    pub fn model_config(&self) -> &ModelConfig {
        &self.config
    }
970
    /// Get the detected chat template.
    ///
    /// This is the template [`Engine::generate`] uses to wrap prompts and to look
    /// up stop patterns.
    pub fn chat_template(&self) -> &ChatTemplate {
        &self.chat_template
    }
975
    /// Get the GGUF file metadata, if the engine holds any.
    ///
    /// Returns `None` when no GGUF file is attached (e.g. for ONNX-loaded models).
    pub fn gguf(&self) -> Option<&GgufFile> {
        self.gguf.as_ref()
    }
980
    /// Get the tokenizer.
    ///
    /// This is the tokenizer [`Engine::generate`] uses to encode prompts and decode
    /// sampled tokens.
    pub fn tokenizer(&self) -> &Tokenizer {
        &self.tokenizer
    }
985
    /// Get the engine configuration.
    ///
    /// Returns a borrow of the [`EngineConfig`] stored on this engine.
    pub fn engine_config(&self) -> &EngineConfig {
        &self.engine_config
    }
990
    /// Get the underlying model (for advanced usage like perplexity computation).
    ///
    /// The model is exposed as a `dyn Model` trait object borrowed from the engine.
    pub fn model(&self) -> &dyn Model {
        &*self.model
    }
995
    /// Get the backend.
    ///
    /// Returns a borrow of the shared `Arc<dyn Backend>` handle; clone the `Arc`
    /// to keep the backend beyond the borrow.
    pub fn backend(&self) -> &Arc<dyn Backend> {
        &self.backend
    }
1000
    /// Whether to add a BOS token when encoding prompts.
    ///
    /// [`Engine::generate`] passes this flag to the tokenizer's `encode`.
    pub fn add_bos(&self) -> bool {
        self.add_bos
    }
1005
1006    /// Generate text from a prompt.
1007    ///
1008    /// The prompt is automatically wrapped with the detected chat template
1009    /// unless it already contains chat formatting tokens.
1010    ///
1011    /// Create an InferenceContext respecting the configured KV cache type.
1012    pub fn create_inference_context(&self) -> InferenceContext {
1013        if self.engine_config.kv_cache_type.is_turboquant() {
1014            InferenceContext::new_with_cache_type(
1015                &self.config,
1016                self.backend.clone(),
1017                self.engine_config.kv_cache_type,
1018            )
1019        } else {
1020            self.model.create_context(self.backend.clone())
1021        }
1022    }
1023
1024    /// Returns the generated text (not including the prompt).
1025    pub fn generate(&self, prompt: &str, max_tokens: usize) -> Result<String, EngineError> {
1026        let mut ctx = self.create_inference_context();
1027        let mut sampler = Sampler::new(self.sampler_config.clone(), self.config.vocab_size);
1028
1029        // Wrap prompt with chat template
1030        let formatted = self.chat_template.wrap_prompt(prompt);
1031        let mut tokens = self.tokenizer.encode(&formatted, self.add_bos)?;
1032
1033        let mut output = String::new();
1034
1035        for _ in 0..max_tokens {
1036            // Check if we hit EOS
1037            if let Some(&last) = tokens.last()
1038                && last == self.tokenizer.special_tokens.eos_token_id
1039            {
1040                break;
1041            }
1042
1043            // Run forward pass
1044            let input_tokens = if ctx.position == 0 {
1045                &tokens[..]
1046            } else {
1047                &tokens[tokens.len() - 1..]
1048            };
1049
1050            let logits = self.model.forward(input_tokens, &mut ctx)?;
1051            let next_token = sampler.sample(&logits, &tokens);
1052
1053            // Check for EOS
1054            if next_token == self.tokenizer.special_tokens.eos_token_id {
1055                break;
1056            }
1057
1058            // Decode token
1059            if let Ok(text) = self.tokenizer.decode(&[next_token]) {
1060                // Check for stop patterns
1061                let combined = format!("{}{}", output, text);
1062                let stop = self
1063                    .chat_template
1064                    .stop_patterns()
1065                    .iter()
1066                    .any(|p| combined.contains(p));
1067
1068                if stop {
1069                    // Add only the text before the stop pattern
1070                    for pattern in self.chat_template.stop_patterns() {
1071                        if let Some(idx) = combined.find(pattern) {
1072                            output = combined[..idx].to_string();
1073                            return Ok(output.trim().to_string());
1074                        }
1075                    }
1076                    break;
1077                }
1078
1079                output.push_str(&text);
1080            }
1081
1082            tokens.push(next_token);
1083        }
1084
1085        Ok(output.trim().to_string())
1086    }
1087
    /// Generate text from a prompt, yielding tokens as they are produced.
    ///
    /// Each item in the returned iterator is a `Result<String, EngineError>` containing
    /// the decoded text of one or more tokens. The returned [`GenerationStream`]
    /// borrows this engine; generation stops after at most `max_tokens` tokens.
    pub fn generate_streaming(&self, prompt: &str, max_tokens: usize) -> GenerationStream<'_> {
        GenerationStream::new(self, prompt, max_tokens)
    }
1095
1096    /// Extract embeddings from text using the model.
1097    pub fn embed(&self, text: &str) -> Result<Vec<f32>, EngineError> {
1098        let mut ctx = self.create_inference_context();
1099        let embed_config = EmbeddingConfig::default();
1100        let extractor = EmbeddingExtractor::new(embed_config, &self.config);
1101        let embedding =
1102            extractor.embed_text(self.model.as_ref(), &self.tokenizer, &mut ctx, text)?;
1103        Ok(embedding)
1104    }
1105}
1106
1107// ============================================================================
1108// Streaming generation
1109// ============================================================================
1110
1111/// Iterator that yields generated tokens as strings.
1112///
1113/// Created by [`Engine::generate_streaming`].
1114pub struct GenerationStream<'a> {
1115    engine: &'a Engine,
1116    ctx: InferenceContext,
1117    sampler: Sampler,
1118    tokens: Vec<u32>,
1119    remaining: usize,
1120    done: bool,
1121    accumulated: String,
1122    /// Pending bytes for incomplete UTF-8 sequences across token boundaries
1123    pending_bytes: Vec<u8>,
1124}
1125
1126impl<'a> GenerationStream<'a> {
1127    fn new(engine: &'a Engine, prompt: &str, max_tokens: usize) -> Self {
1128        let ctx = engine.create_inference_context();
1129        let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
1130
1131        let formatted = engine.chat_template.wrap_prompt(prompt);
1132        if std::env::var("LLAMA_DEBUG").is_ok() {
1133            eprintln!("[DEBUG] formatted prompt: {:?}", formatted);
1134            eprintln!("[DEBUG] add_bos: {}", engine.add_bos);
1135        }
1136        let tokens = engine
1137            .tokenizer
1138            .encode(&formatted, engine.add_bos)
1139            .unwrap_or_default();
1140        if std::env::var("LLAMA_DEBUG").is_ok() {
1141            eprintln!("[DEBUG] encoded {} tokens: {:?}", tokens.len(), &tokens[..tokens.len().min(50)]);
1142            for (i, &tid) in tokens.iter().enumerate() {
1143                if let Some(s) = engine.tokenizer.get_token(tid) {
1144                    eprintln!("[DEBUG]   token[{}] = {} -> {:?}", i, tid, s);
1145                }
1146            }
1147        }
1148
1149        Self {
1150            engine,
1151            ctx,
1152            sampler,
1153            tokens,
1154            remaining: max_tokens,
1155            done: false,
1156            accumulated: String::new(),
1157            pending_bytes: Vec::new(),
1158        }
1159    }
1160}
1161
1162impl<'a> Iterator for GenerationStream<'a> {
1163    type Item = Result<String, EngineError>;
1164
1165    fn next(&mut self) -> Option<Self::Item> {
1166        if self.done || self.remaining == 0 {
1167            return None;
1168        }
1169
1170        // Check EOS from last token
1171        if let Some(&last) = self.tokens.last()
1172            && last == self.engine.tokenizer.special_tokens.eos_token_id
1173        {
1174            self.done = true;
1175            return None;
1176        }
1177
1178        // Forward pass
1179        let input_tokens = if self.ctx.position == 0 {
1180            &self.tokens[..]
1181        } else {
1182            &self.tokens[self.tokens.len() - 1..]
1183        };
1184
1185        let logits = match self.engine.model.forward(input_tokens, &mut self.ctx) {
1186            Ok(l) => l,
1187            Err(e) => {
1188                self.done = true;
1189                return Some(Err(EngineError::Model(e)));
1190            }
1191        };
1192
1193        let next_token = self.sampler.sample(&logits, &self.tokens);
1194
1195        if std::env::var("LLAMA_DEBUG_LOGITS").is_ok() {
1196            let logit_data = logits.as_f32().unwrap();
1197            let mut indexed: Vec<(usize, f32)> = logit_data.iter().copied().enumerate().collect();
1198            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1199            let step = self.tokens.len();
1200            eprint!("[LOGIT] step={} top5:", step);
1201            for (id, score) in indexed.iter().take(5) {
1202                let tok_str = self.engine.tokenizer.get_token(*id as u32).unwrap_or_default();
1203                eprint!(" {}({:.2})={:?}", id, score, tok_str);
1204            }
1205            let chosen_str = self.engine.tokenizer.get_token(next_token).unwrap_or_default();
1206            eprintln!(" → chosen={}({:?})", next_token, chosen_str);
1207        }
1208
1209        // Check EOS
1210        if next_token == self.engine.tokenizer.special_tokens.eos_token_id {
1211            self.done = true;
1212            return None;
1213        }
1214
1215        // Decode (streaming-aware: accumulates incomplete UTF-8 across tokens)
1216        match self
1217            .engine
1218            .tokenizer
1219            .decode_token_streaming(next_token, &mut self.pending_bytes)
1220        {
1221            Ok(text) => {
1222                self.tokens.push(next_token);
1223                self.remaining -= 1;
1224
1225                if text.is_empty() {
1226                    // Pending bytes not yet forming valid UTF-8; recurse to next token
1227                    return self.next();
1228                }
1229
1230                // Check stop patterns
1231                let combined = format!("{}{}", self.accumulated, text);
1232                for pattern in self.engine.chat_template.stop_patterns() {
1233                    if combined.contains(pattern) {
1234                        self.done = true;
1235                        if let Some(idx) = combined.find(pattern) {
1236                            if idx > self.accumulated.len() {
1237                                let before = &combined[self.accumulated.len()..idx];
1238                                return Some(Ok(before.to_string()));
1239                            }
1240                        }
1241                        return None;
1242                    }
1243                }
1244
1245                self.accumulated.push_str(&text);
1246                Some(Ok(text))
1247            }
1248            Err(e) => {
1249                self.tokens.push(next_token);
1250                self.remaining -= 1;
1251                Some(Err(EngineError::Tokenizer(e)))
1252            }
1253        }
1254    }
1255}
1256
1257// ============================================================================
1258// ChatEngine
1259// ============================================================================
1260
/// High-level chat engine that maintains conversation state.
///
/// Wraps an [`Engine`] with conversation history, context management,
/// and automatic chat template formatting.
pub struct ChatEngine {
    /// Loaded engine that supplies the model, tokenizer and chat template.
    engine: Engine,
    /// System prompt passed to the template when formatting the first turn.
    system_prompt: String,
    /// Token ids of the conversation so far (prompt and generated tokens).
    conversation_tokens: Vec<u32>,
    /// Inference context reused across turns; reset by `clear_history`.
    ctx: InferenceContext,
    /// Sampler reused across turns; reset by `clear_history`.
    sampler: Sampler,
    /// True until the first message has been processed, and again after `clear_history`.
    is_first_turn: bool,
}
1273
1274impl ChatEngine {
1275    /// Create a new chat engine from a loaded [`Engine`].
1276    pub fn new(engine: Engine, system_prompt: Option<String>) -> Self {
1277        let ctx = engine.create_inference_context();
1278        let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
1279
1280        Self {
1281            system_prompt: system_prompt
1282                .unwrap_or_else(|| "You are a helpful AI assistant.".to_string()),
1283            conversation_tokens: Vec::new(),
1284            ctx,
1285            sampler,
1286            is_first_turn: true,
1287            engine,
1288        }
1289    }
1290
    /// Get a reference to the underlying engine.
    ///
    /// This is the [`Engine`] that was passed to [`ChatEngine::new`].
    pub fn engine(&self) -> &Engine {
        &self.engine
    }
1295
    /// Get the current system prompt.
    ///
    /// This is the text passed to the chat template when the first turn is formatted.
    pub fn system_prompt(&self) -> &str {
        &self.system_prompt
    }
1300
    /// Get the number of tokens in the current conversation context.
    ///
    /// This is the length of the tracked conversation token list, which holds both
    /// prompt tokens and generated tokens.
    pub fn context_len(&self) -> usize {
        self.conversation_tokens.len()
    }
1305
1306    /// Send a message and get the full response.
1307    pub fn chat(&mut self, message: &str) -> Result<String, EngineError> {
1308        let max_tokens = self.engine.engine_config.max_tokens;
1309
1310        // Format the message using the chat template
1311        let formatted = if self.is_first_turn {
1312            self.engine
1313                .chat_template
1314                .format_first_turn(&self.system_prompt, message)
1315        } else {
1316            self.engine.chat_template.format_continuation(message)
1317        };
1318
1319        // Encode new tokens
1320        let new_tokens = self
1321            .engine
1322            .tokenizer
1323            .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;
1324
1325        // Check context length and trim if needed
1326        self.ensure_context_space(new_tokens.len(), max_tokens);
1327
1328        // Add new tokens to conversation
1329        self.conversation_tokens.extend(&new_tokens);
1330
1331        // Batch-prefill: process ALL prompt tokens in a single forward pass.
1332        // This is dramatically faster than one-at-a-time, especially on GPU
1333        // backends where kernel launch overhead dominates for tiny batches.
1334        let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
1335        let mut response_text = String::new();
1336
1337        if new_tokens.is_empty() {
1338            self.is_first_turn = false;
1339            return Ok(response_text);
1340        }
1341
1342        let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
1343        let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);
1344
1345        if first_token == eos_id {
1346            self.is_first_turn = false;
1347            return Ok(response_text);
1348        }
1349
1350        if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
1351            response_text.push_str(&text);
1352        }
1353        self.conversation_tokens.push(first_token);
1354
1355        // Autoregressive decode for the remaining tokens
1356        for _ in 1..max_tokens {
1357            // Check stop patterns
1358            let should_stop = self
1359                .engine
1360                .chat_template
1361                .stop_patterns()
1362                .iter()
1363                .any(|p| response_text.contains(p));
1364            if should_stop {
1365                for pattern in self.engine.chat_template.stop_patterns() {
1366                    if let Some(idx) = response_text.find(pattern) {
1367                        response_text.truncate(idx);
1368                        break;
1369                    }
1370                }
1371                break;
1372            }
1373
1374            let last_token = *self
1375                .conversation_tokens
1376                .last()
1377                .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);
1378
1379            let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
1380            let next_token = self.sampler.sample(&logits, &self.conversation_tokens);
1381
1382            if next_token == eos_id {
1383                break;
1384            }
1385
1386            if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
1387                response_text.push_str(&text);
1388            }
1389
1390            self.conversation_tokens.push(next_token);
1391        }
1392
1393        self.is_first_turn = false;
1394        Ok(response_text.trim().to_string())
1395    }
1396
1397    /// Send a message and get the full response, with a prefix injected as the
1398    /// start of the assistant's reply. The prefix tokens are prefilled alongside
1399    /// the prompt tokens so the model continues from the prefix text. The prefix
1400    /// is prepended to the returned string.
1401    ///
1402    /// This is useful for forcing the model to start with a particular token
1403    /// sequence (e.g. `{` for JSON output).
1404    pub fn chat_with_prefix(
1405        &mut self,
1406        message: &str,
1407        prefix: &str,
1408    ) -> Result<String, EngineError> {
1409        let max_tokens = self.engine.engine_config.max_tokens;
1410
1411        let formatted = if self.is_first_turn {
1412            self.engine
1413                .chat_template
1414                .format_first_turn(&self.system_prompt, message)
1415        } else {
1416            self.engine.chat_template.format_continuation(message)
1417        };
1418
1419        // Append prefix to the formatted prompt so it becomes part of the prefill
1420        let formatted_with_prefix = format!("{}{}", formatted, prefix);
1421
1422        let new_tokens = self
1423            .engine
1424            .tokenizer
1425            .encode(&formatted_with_prefix, self.is_first_turn && self.engine.add_bos)?;
1426
1427        self.ensure_context_space(new_tokens.len(), max_tokens);
1428        self.conversation_tokens.extend(&new_tokens);
1429
1430        let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
1431        let mut response_text = prefix.to_string();
1432
1433        if new_tokens.is_empty() {
1434            self.is_first_turn = false;
1435            return Ok(response_text);
1436        }
1437
1438        let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
1439        let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);
1440
1441        if first_token == eos_id {
1442            self.is_first_turn = false;
1443            return Ok(response_text);
1444        }
1445
1446        if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
1447            response_text.push_str(&text);
1448        }
1449        self.conversation_tokens.push(first_token);
1450
1451        for _ in 1..max_tokens {
1452            let should_stop = self
1453                .engine
1454                .chat_template
1455                .stop_patterns()
1456                .iter()
1457                .any(|p| response_text.contains(p));
1458            if should_stop {
1459                for pattern in self.engine.chat_template.stop_patterns() {
1460                    if let Some(idx) = response_text.find(pattern) {
1461                        response_text.truncate(idx);
1462                        break;
1463                    }
1464                }
1465                break;
1466            }
1467
1468            let last_token = *self
1469                .conversation_tokens
1470                .last()
1471                .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);
1472
1473            let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
1474            let next_token = self.sampler.sample(&logits, &self.conversation_tokens);
1475
1476            if next_token == eos_id {
1477                break;
1478            }
1479
1480            if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
1481                response_text.push_str(&text);
1482            }
1483
1484            self.conversation_tokens.push(next_token);
1485        }
1486
1487        self.is_first_turn = false;
1488        Ok(response_text.trim().to_string())
1489    }
1490
1491    /// Send a message and stream the response token by token.
1492    ///
1493    /// Returns an iterator of `Result<String, EngineError>` where each item is
1494    /// a decoded token chunk.
1495    pub fn chat_streaming(&mut self, message: &str) -> Result<ChatStream<'_>, EngineError> {
1496        let max_tokens = self.engine.engine_config.max_tokens;
1497
1498        // Format the message
1499        let formatted = if self.is_first_turn {
1500            self.engine
1501                .chat_template
1502                .format_first_turn(&self.system_prompt, message)
1503        } else {
1504            self.engine.chat_template.format_continuation(message)
1505        };
1506
1507        // Encode new tokens
1508        let new_tokens = self
1509            .engine
1510            .tokenizer
1511            .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;
1512
1513        // Ensure context space
1514        self.ensure_context_space(new_tokens.len(), max_tokens);
1515
1516        // Add new tokens to conversation
1517        self.conversation_tokens.extend(&new_tokens);
1518
1519        // Batch-prefill all prompt tokens in a single forward pass.
1520        let prefill_logits = if !new_tokens.is_empty() {
1521            Some(self.engine.model.forward(&new_tokens, &mut self.ctx)?)
1522        } else {
1523            None
1524        };
1525
1526        self.is_first_turn = false;
1527
1528        Ok(ChatStream {
1529            chat_engine: self,
1530            remaining: max_tokens,
1531            done: false,
1532            accumulated: String::new(),
1533            prefill_logits,
1534        })
1535    }
1536
    /// Clear conversation history and reset context.
    ///
    /// Empties the tracked conversation tokens, resets the inference context and
    /// the sampler, and marks the next message as a first turn again so it is
    /// formatted together with the system prompt.
    pub fn clear_history(&mut self) {
        self.conversation_tokens.clear();
        self.ctx.reset();
        self.sampler.reset();
        self.is_first_turn = true;
    }
1544
1545    /// Ensure there's enough space in the context for new tokens + generation.
1546    fn ensure_context_space(&mut self, new_token_count: usize, max_gen_tokens: usize) {
1547        let total_len = self.conversation_tokens.len() + new_token_count + max_gen_tokens;
1548
1549        if total_len > self.engine.config.max_seq_len {
1550            let excess = total_len - self.engine.config.max_seq_len + 100;
1551
1552            if excess >= self.conversation_tokens.len() {
1553                tracing::warn!("Context full, resetting conversation");
1554                self.conversation_tokens.clear();
1555                self.ctx.reset();
1556            } else {
1557                tracing::info!("Trimming {} tokens from context", excess);
1558                self.conversation_tokens = self.conversation_tokens[excess..].to_vec();
1559                self.ctx.kv_cache.shift_left(excess);
1560                self.ctx.position = self.ctx.position.saturating_sub(excess);
1561            }
1562        }
1563    }
1564}
1565
1566// ============================================================================
1567// Chat streaming
1568// ============================================================================
1569
/// Iterator that yields chat response tokens as strings.
///
/// Created by [`ChatEngine::chat_streaming`].
pub struct ChatStream<'a> {
    /// Chat engine whose conversation tokens, context and sampler this stream advances.
    chat_engine: &'a mut ChatEngine,
    /// Number of tokens that may still be generated in this turn.
    remaining: usize,
    /// Set once generation has finished (EOS, stop pattern or model error).
    done: bool,
    /// Text yielded so far in this turn; scanned for stop patterns.
    accumulated: String,
    /// Logits from the batched prefill pass; consumed on the first `next()` call.
    prefill_logits: Option<crate::tensor::Tensor>,
}
1581
1582impl<'a> Iterator for ChatStream<'a> {
1583    type Item = Result<String, EngineError>;
1584
1585    fn next(&mut self) -> Option<Self::Item> {
1586        if self.done || self.remaining == 0 {
1587            return None;
1588        }
1589
1590        // Check stop patterns on accumulated text
1591        for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
1592            if self.accumulated.contains(pattern) {
1593                self.done = true;
1594                return None;
1595            }
1596        }
1597
1598        // On the first call, use the prefill logits (no extra forward pass).
1599        // On subsequent calls, run the standard single-token decode.
1600        let logits = if let Some(prefill) = self.prefill_logits.take() {
1601            prefill
1602        } else {
1603            let last_token = *self.chat_engine.conversation_tokens.last().unwrap_or(
1604                &self
1605                    .chat_engine
1606                    .engine
1607                    .tokenizer
1608                    .special_tokens
1609                    .bos_token_id,
1610            );
1611
1612            match self
1613                .chat_engine
1614                .engine
1615                .model
1616                .forward(&[last_token], &mut self.chat_engine.ctx)
1617            {
1618                Ok(l) => l,
1619                Err(e) => {
1620                    self.done = true;
1621                    return Some(Err(EngineError::Model(e)));
1622                }
1623            }
1624        };
1625
1626        let next_token = self
1627            .chat_engine
1628            .sampler
1629            .sample(&logits, &self.chat_engine.conversation_tokens);
1630
1631        // Check for EOS
1632        if next_token
1633            == self
1634                .chat_engine
1635                .engine
1636                .tokenizer
1637                .special_tokens
1638                .eos_token_id
1639        {
1640            self.done = true;
1641            return None;
1642        }
1643
1644        match self.chat_engine.engine.tokenizer.decode(&[next_token]) {
1645            Ok(text) => {
1646                // Check stop patterns in accumulated + new text
1647                let combined = format!("{}{}", self.accumulated, text);
1648                for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
1649                    if combined.contains(pattern) {
1650                        self.done = true;
1651                        if let Some(idx) = combined.find(pattern) {
1652                            let before = &combined[self.accumulated.len()..idx];
1653                            self.chat_engine.conversation_tokens.push(next_token);
1654                            if !before.is_empty() {
1655                                return Some(Ok(before.to_string()));
1656                            }
1657                        }
1658                        return None;
1659                    }
1660                }
1661
1662                self.accumulated.push_str(&text);
1663                self.chat_engine.conversation_tokens.push(next_token);
1664                self.remaining -= 1;
1665                Some(Ok(text))
1666            }
1667            Err(e) => {
1668                self.chat_engine.conversation_tokens.push(next_token);
1669                self.remaining -= 1;
1670                Some(Err(EngineError::Tokenizer(e)))
1671            }
1672        }
1673    }
1674}