1use std::sync::Arc;
21
22use crate::backend::Backend;
23use crate::gguf::GgufFile;
24use crate::model::{
25 EmbeddingConfig, EmbeddingExtractor, InferenceContext, Model, ModelConfig, ModelLoader,
26};
27use crate::sampling::{Sampler, SamplerConfig};
28use crate::tokenizer::Tokenizer;
29
/// Errors that can occur while loading models or running inference.
#[derive(thiserror::Error, Debug)]
pub enum EngineError {
    /// Underlying I/O failure (e.g. reading model files).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Failure while parsing a GGUF model file.
    #[error("GGUF error: {0}")]
    Gguf(#[from] crate::gguf::GgufError),

    /// Failure inside the model implementation (forward pass, context).
    #[error("Model error: {0}")]
    Model(#[from] crate::model::ModelError),

    /// Tokenizer encode/decode failure.
    #[error("Tokenizer error: {0}")]
    Tokenizer(#[from] crate::tokenizer::TokenizerError),

    /// Failure while extracting embeddings.
    #[error("Embedding error: {0}")]
    Embedding(#[from] crate::model::EmbeddingError),

    /// Catch-all for engine-level errors without a dedicated variant.
    #[error("Engine error: {0}")]
    Other(String),
}
55
/// User-facing engine configuration: model location, sampling parameters,
/// and backend selection.
///
/// `#[serde(default)]` allows config files to specify only a subset of fields.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct EngineConfig {
    /// Path to the model file (GGUF, or ONNX with the `onnx` feature).
    pub model_path: String,

    /// Optional external tokenizer (HF `tokenizer.json` or a GGUF file).
    /// When `None`, the tokenizer is read from the model's own metadata.
    pub tokenizer_path: Option<String>,

    /// Sampling temperature.
    pub temperature: f32,

    /// Top-k sampling cutoff.
    pub top_k: usize,

    /// Nucleus (top-p) sampling cutoff.
    pub top_p: f32,

    /// Penalty applied to recently generated tokens.
    pub repeat_penalty: f32,

    /// Maximum number of tokens generated per request/turn.
    pub max_tokens: usize,

    /// RNG seed; `None` means non-deterministic sampling.
    pub seed: Option<u64>,

    /// Try GPU backends (CUDA/Vulkan/Metal/DX12) before falling back to CPU.
    pub use_gpu: bool,

    /// Optional cap on GPU context length (see `Engine::select_gpu_model`).
    pub max_context_len: Option<usize>,

    /// Hailo accelerator configuration, when compiled with the feature.
    #[cfg(feature = "hailo")]
    pub hailo_config: Option<crate::backend::hailo::HailoConfig>,

    /// KV-cache element type (e.g. F32 or a TurboQuant variant).
    pub kv_cache_type: crate::model::KVCacheType,
}
114
impl Default for EngineConfig {
    /// Conservative sampling defaults; `model_path` is empty and must be set
    /// before calling `Engine::load` (which rejects an empty path).
    fn default() -> Self {
        Self {
            model_path: String::new(),
            tokenizer_path: None,
            temperature: 0.7,
            top_k: 40,
            top_p: 0.95,
            repeat_penalty: 1.1,
            max_tokens: 512,
            seed: None,
            use_gpu: false,
            max_context_len: None,
            #[cfg(feature = "hailo")]
            hailo_config: None,
            kv_cache_type: crate::model::KVCacheType::F32,
        }
    }
}
134
135impl EngineConfig {
136 pub fn from_config_file(
141 path: impl AsRef<std::path::Path>,
142 ) -> Result<Self, crate::config::ConfigError> {
143 let config = crate::config::Config::from_file(path)?;
144 Ok(config.to_engine_config(None))
145 }
146
147 pub fn from_config(
152 config_path: Option<impl AsRef<std::path::Path>>,
153 ) -> Result<Self, crate::config::ConfigError> {
154 let config = crate::config::Config::load(config_path)?;
155 Ok(config.to_engine_config(None))
156 }
157}
158
/// Prompt formats the engine knows how to emit and detect.
#[derive(Debug, Clone, PartialEq)]
pub enum ChatTemplate {
    /// `<|user|>` / `<|assistant|>` style markers.
    UserAssistant,
    /// ChatML (`<|im_start|>role ... <|im_end|>`), used by Qwen models.
    ChatML,
    /// Llama 2 `[INST] ... [/INST]` with `<<SYS>>` system blocks.
    Llama2,
    /// No template; plain `User:` / `Assistant:` text.
    None,
}
175
176impl ChatTemplate {
177 pub fn detect_from_model_type(model_type: Option<&str>) -> Self {
179 match model_type {
180 Some("qwen2" | "qwen") => ChatTemplate::ChatML,
181 Some("llama" | "codellama") => ChatTemplate::Llama2,
182 Some("mistral" | "mixtral") => ChatTemplate::Llama2,
183 _ => ChatTemplate::None,
184 }
185 }
186
187 pub fn detect(gguf: &GgufFile) -> Self {
189 if let Some(template) = gguf.data.get_string("tokenizer.chat_template") {
190 if template.contains("<|user|>") {
191 ChatTemplate::UserAssistant
192 } else if template.contains("<|im_start|>") {
193 ChatTemplate::ChatML
194 } else if template.contains("[INST]") {
195 ChatTemplate::Llama2
196 } else {
197 ChatTemplate::None
198 }
199 } else if let Some(arch) = gguf.data.get_string("general.architecture") {
200 match arch.to_lowercase().as_str() {
201 "qwen2" | "qwen" | "qwen3" | "qwen35" | "qwen3moe" | "qwen3next" => {
202 ChatTemplate::ChatML
203 }
204 _ => ChatTemplate::None,
205 }
206 } else {
207 ChatTemplate::None
208 }
209 }
210
211 pub fn wrap_prompt(&self, prompt: &str) -> String {
213 if prompt.contains("<|user|>")
215 || prompt.contains("<|im_start|>")
216 || prompt.contains("[INST]")
217 {
218 return prompt.to_string();
219 }
220
221 match self {
222 ChatTemplate::UserAssistant => {
223 format!("<|user|>\n{}<|assistant|>\n", prompt)
224 }
225 ChatTemplate::ChatML => {
226 format!(
227 "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
228 prompt
229 )
230 }
231 ChatTemplate::Llama2 => {
232 format!("[INST] {} [/INST]", prompt)
233 }
234 ChatTemplate::None => prompt.to_string(),
235 }
236 }
237
238 pub fn format_first_turn(&self, system_prompt: &str, user_message: &str) -> String {
240 match self {
241 ChatTemplate::UserAssistant => {
242 format!(
243 "<|system|>\n{}<|user|>\n{}<|assistant|>\n",
244 system_prompt, user_message
245 )
246 }
247 ChatTemplate::ChatML => {
248 format!(
249 "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
250 system_prompt, user_message
251 )
252 }
253 ChatTemplate::Llama2 => {
254 format!(
255 "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]",
256 system_prompt, user_message
257 )
258 }
259 ChatTemplate::None => {
260 format!(
261 "System: {}\n\nUser: {}\n\nAssistant:",
262 system_prompt, user_message
263 )
264 }
265 }
266 }
267
268 pub fn format_continuation(&self, user_message: &str) -> String {
270 match self {
271 ChatTemplate::UserAssistant => {
272 format!("<|user|>\n{}<|assistant|>\n", user_message)
273 }
274 ChatTemplate::ChatML => {
275 format!(
276 "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
277 user_message
278 )
279 }
280 ChatTemplate::Llama2 => {
281 format!(" [INST] {} [/INST]", user_message)
282 }
283 ChatTemplate::None => {
284 format!("\n\nUser: {}\n\nAssistant:", user_message)
285 }
286 }
287 }
288
289 pub fn stop_patterns(&self) -> &[&str] {
292 match self {
293 ChatTemplate::UserAssistant => &["<|user|>", "<|end|>"],
294 ChatTemplate::ChatML => &["<|im_end|>", "<|im_start|>"],
295 ChatTemplate::Llama2 => &["[INST]", "</s>"],
296 ChatTemplate::None => &["User:", "\nUser:"],
297 }
298 }
299}
300
/// A loaded model plus everything needed to run inference on it:
/// tokenizer, compute backend, sampling configuration, and chat template.
pub struct Engine {
    // GGUF file handle; `None` when the model was loaded from ONNX.
    gguf: Option<GgufFile>,
    model: Box<dyn Model>,
    tokenizer: Tokenizer,
    // Model hyperparameters (layer count, dims, max context, ...).
    config: ModelConfig,
    backend: Arc<dyn Backend>,
    sampler_config: SamplerConfig,
    chat_template: ChatTemplate,
    // Whether to prepend BOS when encoding prompts.
    add_bos: bool,
    engine_config: EngineConfig,
}
320
321impl Engine {
    /// Loads an engine from `config.model_path`, dispatching on file
    /// extension: `.onnx` goes through the ONNX loader (when the `onnx`
    /// feature is enabled); everything else is treated as GGUF.
    ///
    /// Returns `EngineError::Other` for an empty path or an `.onnx` file in a
    /// build without ONNX support.
    pub fn load(config: EngineConfig) -> Result<Self, EngineError> {
        if config.model_path.is_empty() {
            return Err(EngineError::Other("model_path is required".into()));
        }

        let path = std::path::Path::new(&config.model_path);

        match path.extension().and_then(|e| e.to_str()) {
            #[cfg(feature = "onnx")]
            Some("onnx") => Self::load_onnx(config),
            #[cfg(not(feature = "onnx"))]
            Some("onnx") => Err(EngineError::Other(
                "ONNX support requires the `onnx` feature. Build with: cargo build --features onnx"
                    .into(),
            )),
            _ => Self::load_gguf(config),
        }
    }
349
    /// Loads a GGUF model plus tokenizer and assembles a ready-to-use engine.
    fn load_gguf(config: EngineConfig) -> Result<Self, EngineError> {
        tracing::info!("Loading GGUF model from: {}", config.model_path);

        let gguf = GgufFile::open(&config.model_path)?;

        tracing::info!("Loading tokenizer...");
        // Tokenizer can come from an explicit path (HF JSON or a GGUF file)
        // or be embedded in the model's own GGUF metadata.
        let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
            if tok_path.ends_with(".json") {
                Tokenizer::from_hf_json(tok_path)?
            } else {
                let tok_gguf = GgufFile::open(tok_path)?;
                Tokenizer::from_gguf(&tok_gguf)?
            }
        } else {
            Tokenizer::from_gguf(&gguf)?
        };
        tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);

        tracing::info!("Loading model weights...");
        // NOTE(review): the model file is opened a second time here by the
        // loader — presumably cheap (mmap); confirm if load time matters.
        let loader = ModelLoader::load(&config.model_path)?;
        let model_config = loader.config().clone();
        tracing::info!(
            "Model: {} layers, {} heads, {} hidden dim, {} ctx",
            model_config.num_layers,
            model_config.num_heads,
            model_config.hidden_size,
            model_config.max_seq_len,
        );

        let arch = loader.architecture();

        // Encoder-only (BERT-style) models always run on CPU; decoder models
        // may be promoted to a GPU backend when requested.
        let (backend, model): (Arc<dyn Backend>, Box<dyn Model>) = if arch.is_encoder_only() {
            tracing::info!("Detected encoder-only architecture: {:?}", arch);
            let bert_model = loader.build_bert_model()?;
            (
                Arc::new(crate::backend::cpu::CpuBackend::new()),
                Box::new(bert_model),
            )
        } else {
            let concrete_model = loader.build_model()?;

            if config.use_gpu {
                Self::select_gpu_model(concrete_model, &model_config, &config)
            } else {
                (
                    Arc::new(crate::backend::cpu::CpuBackend::new()),
                    Box::new(concrete_model),
                )
            }
        };

        let chat_template = ChatTemplate::detect(&gguf);
        tracing::info!("Chat template: {:?}", chat_template);

        // Honor the model's explicit BOS preference; otherwise fall back to
        // whether the tokenizer itself declares a BOS token.
        let add_bos = gguf
            .data
            .get_bool("tokenizer.ggml.add_bos_token")
            .unwrap_or(tokenizer.has_explicit_bos);

        let sampler_config = SamplerConfig {
            temperature: config.temperature,
            top_k: config.top_k,
            top_p: config.top_p,
            repeat_penalty: config.repeat_penalty,
            seed: config.seed,
            ..Default::default()
        };

        tracing::info!("Engine ready");

        Ok(Self {
            gguf: Some(gguf),
            model,
            tokenizer,
            config: model_config,
            backend,
            sampler_config,
            chat_template,
            add_bos,
            engine_config: config,
        })
    }
447
    /// Loads an ONNX model (feature `onnx`) plus a tokenizer and assembles
    /// the engine. Requires an HF `tokenizer.json` next to the model unless
    /// an explicit tokenizer path is configured.
    #[cfg(feature = "onnx")]
    fn load_onnx(config: EngineConfig) -> Result<Self, EngineError> {
        use crate::onnx::OnnxModelLoader;

        tracing::info!("Loading ONNX model from: {}", config.model_path);

        // Directory holding the model; used to locate tokenizer.json.
        let model_dir = std::path::Path::new(&config.model_path)
            .parent()
            .unwrap_or(std::path::Path::new("."));

        let loader = OnnxModelLoader::load(&config.model_path)
            .map_err(|e| EngineError::Other(format!("ONNX load error: {}", e)))?;
        let model_config = loader.config().clone();
        let hf_config = loader.hf_config().clone();

        tracing::info!(
            "Model: {} layers, {} heads, {} hidden dim, {} ctx",
            model_config.num_layers,
            model_config.num_heads,
            model_config.hidden_size,
            model_config.max_seq_len,
        );

        let concrete_model = loader
            .build_model()
            .map_err(|e| EngineError::Other(format!("ONNX model build error: {}", e)))?;

        tracing::info!("Loading tokenizer...");
        let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
            if tok_path.ends_with(".json") {
                Tokenizer::from_hf_json(tok_path)?
            } else {
                let tok_gguf = GgufFile::open(tok_path)?;
                Tokenizer::from_gguf(&tok_gguf)?
            }
        } else {
            // No explicit tokenizer: look for tokenizer.json beside the model.
            let tokenizer_path = model_dir.join("tokenizer.json");
            if tokenizer_path.exists() {
                tracing::info!("Using tokenizer.json from: {}", tokenizer_path.display());
                Tokenizer::from_hf_json(&tokenizer_path)?
            } else {
                return Err(EngineError::Other(format!(
                    "No tokenizer found. ONNX models require a tokenizer.json file \
                     in the same directory as the model, or specify --tokenizer <path>. \
                     Looked for: {}",
                    tokenizer_path.display()
                )));
            }
        };
        tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);

        let backend: Arc<dyn Backend> = if config.use_gpu {
            Self::select_gpu_backend(&concrete_model)
        } else {
            Arc::new(crate::backend::cpu::CpuBackend::new())
        };

        let model: Box<dyn Model> = Box::new(concrete_model);

        let chat_template = ChatTemplate::detect_from_model_type(hf_config.model_type.as_deref());
        tracing::info!("Chat template: {:?}", chat_template);

        // NOTE(review): BOS is unconditionally enabled for ONNX models —
        // confirm this matches the loaded tokenizer's convention.
        let add_bos = true;

        let sampler_config = SamplerConfig {
            temperature: config.temperature,
            top_k: config.top_k,
            top_p: config.top_p,
            repeat_penalty: config.repeat_penalty,
            seed: config.seed,
            ..Default::default()
        };

        tracing::info!("Engine ready (ONNX)");

        Ok(Self {
            gguf: None,
            model,
            tokenizer,
            config: model_config,
            backend,
            sampler_config,
            chat_template,
            add_bos,
            engine_config: config,
        })
    }
542
    /// Tries to move inference fully onto an accelerator, probing in order
    /// CUDA → Vulkan → Metal → DX12 → Hailo (each only when its feature is
    /// compiled in and a device is present). Falls back to the CPU model with
    /// a possibly GPU-accelerated matmul backend.
    ///
    /// NOTE(review): on a failed full-GPU init this prints to stderr and
    /// calls `process::exit(1)`, because `model` was consumed by the
    /// attempted conversion and cannot be recovered for a CPU fallback.
    /// Exiting from library code is heavy-handed — consider restructuring
    /// the `from_model` APIs to return the model on failure.
    #[allow(unused_variables)]
    fn select_gpu_model(
        model: crate::model::LlamaModel,
        config: &ModelConfig,
        engine_config: &EngineConfig,
    ) -> (Arc<dyn Backend>, Box<dyn Model>) {
        // Optionally cap the GPU context allocation below the model maximum
        // to reduce KV-cache memory use.
        let gpu_seq_len = match engine_config.max_context_len {
            Some(cap) if cap > 0 && cap < config.max_seq_len => {
                tracing::info!(
                    "Capping GPU context length from {} to {} (max_context_len)",
                    config.max_seq_len,
                    cap
                );
                cap
            }
            _ => config.max_seq_len,
        };

        // CUDA: preferred when available.
        #[cfg(feature = "cuda")]
        {
            if cudarc::driver::CudaDevice::new(0).is_ok() {
                let architecture = model.architecture();
                match crate::backend::cuda::gpu_only::GpuOnlyInference::from_model(
                    model,
                    gpu_seq_len,
                ) {
                    Ok(gpu) => {
                        tracing::info!(
                            "Using full GPU inference (attention + DeltaNet + MoE all on CUDA)"
                        );
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: CUDA GPU inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --gpu.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No CUDA device available, trying other GPU backends...");
            }
        }

        // Vulkan: cross-platform fallback.
        #[cfg(feature = "vulkan")]
        {
            if crate::backend::vulkan::VulkanBackend::new().is_ok() {
                let architecture = model.architecture();
                match crate::backend::vulkan::gpu_only::VulkanGpuInference::from_model(
                    model,
                    gpu_seq_len,
                ) {
                    Ok(gpu) => {
                        tracing::info!("Using full GPU inference on Vulkan");
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: Vulkan GPU inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --gpu.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No Vulkan device available, trying other GPU backends...");
            }
        }

        // Metal: macOS only.
        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            if crate::backend::metal::MetalBackend::new().is_ok() {
                let architecture = model.architecture();
                match crate::backend::metal::gpu_only::MetalGpuInference::from_model(
                    model,
                    gpu_seq_len,
                ) {
                    Ok(gpu) => {
                        tracing::info!("Using full GPU inference on Metal");
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: Metal GPU inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --gpu.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No Metal device available, trying other GPU backends...");
            }
        }

        // DX12: Windows only.
        #[cfg(all(feature = "dx12", target_os = "windows"))]
        {
            if crate::backend::dx12::Dx12Backend::new().is_ok() {
                let architecture = model.architecture();
                match crate::backend::dx12::gpu_only::Dx12GpuInference::from_model(
                    model,
                    gpu_seq_len,
                ) {
                    Ok(gpu) => {
                        tracing::info!("Using full GPU inference on DX12");
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: DX12 GPU inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --gpu.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No DX12 device available");
            }
        }

        // Hailo accelerator: only attempted when explicitly configured.
        #[cfg(feature = "hailo")]
        {
            if let Some(ref hailo_config) = engine_config.hailo_config {
                if crate::backend::hailo::context::check_device_available().is_ok() {
                    let architecture = model.architecture();
                    match crate::backend::hailo::gpu_only::HailoGpuInference::from_model(
                        model,
                        gpu_seq_len,
                        hailo_config.clone(),
                    ) {
                        Ok(gpu) => {
                            tracing::info!("Using hybrid CPU+Hailo inference");
                            let wrapper = crate::backend::GpuModelWrapper::new(
                                gpu,
                                config.clone(),
                                architecture,
                            );
                            return (
                                Arc::new(crate::backend::cpu::CpuBackend::new()),
                                Box::new(wrapper),
                            );
                        }
                        Err(e) => {
                            eprintln!("Error: Hailo inference init failed: {}", e);
                            eprintln!("The model was consumed during init. Please restart without --hailo.");
                            std::process::exit(1);
                        }
                    }
                } else {
                    tracing::info!("No Hailo device available, falling back to CPU...");
                }
            }
        }

        // No full-GPU path taken: keep the CPU model but run matmuls on
        // whatever accelerated backend is available.
        let backend = Self::select_gpu_backend(&model);
        (backend, Box::new(model))
    }
733
    /// Picks the best available matmul backend for a CPU-resident model,
    /// probing CUDA → Metal → DX12 → Vulkan and falling back to CPU.
    ///
    /// Unlike `select_gpu_model`, failures here are non-fatal: the model
    /// stays on the CPU and only matrix ops are offloaded.
    #[allow(unused_variables)]
    pub fn select_gpu_backend(model: &crate::model::LlamaModel) -> Arc<dyn Backend> {
        #[cfg(feature = "cuda")]
        {
            match crate::backend::cuda::CudaBackend::new() {
                Ok(mut cuda) => {
                    tracing::info!("Using CUDA backend: {}", cuda.device_name());
                    // Weight upload failure is non-fatal; quantized ops still work.
                    if let Err(e) = cuda.load_model_weights(model) {
                        tracing::warn!("Failed to load GPU weights ({}), using quantized ops", e);
                    }
                    return Arc::new(cuda);
                }
                Err(e) => {
                    tracing::info!("CUDA not available ({}), trying Metal...", e);
                }
            }
        }

        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            match crate::backend::metal::MetalBackend::new() {
                Ok(metal) => {
                    tracing::info!("Using Metal backend: {}", metal.device_name());
                    return Arc::new(metal);
                }
                Err(e) => {
                    tracing::info!("Metal not available ({}), trying DX12...", e);
                }
            }
        }

        #[cfg(all(feature = "dx12", target_os = "windows"))]
        {
            match crate::backend::dx12::Dx12Backend::new() {
                Ok(dx12) => {
                    tracing::info!("Using DX12 backend: {}", dx12.device_name());
                    return Arc::new(dx12);
                }
                Err(e) => {
                    tracing::info!("DX12 not available ({}), trying Vulkan...", e);
                }
            }
        }

        #[cfg(feature = "vulkan")]
        {
            match crate::backend::vulkan::VulkanBackend::new() {
                Ok(vk) => {
                    tracing::info!("Using Vulkan backend: {}", vk.device_name());
                    return Arc::new(vk);
                }
                Err(e) => {
                    tracing::warn!("Vulkan not available ({}), falling back to CPU", e);
                }
            }
        }

        // Warn when the binary was built without any GPU backend at all.
        #[cfg(not(any(
            feature = "cuda",
            feature = "vulkan",
            all(feature = "metal", target_os = "macos"),
            all(feature = "dx12", target_os = "windows")
        )))]
        {
            tracing::warn!(
                "No GPU backend compiled. Build with --features cuda, --features metal, --features dx12, or --features vulkan"
            );
        }

        Arc::new(crate::backend::cpu::CpuBackend::new())
    }
813
    /// Model hyperparameters parsed at load time.
    pub fn model_config(&self) -> &ModelConfig {
        &self.config
    }

    /// Chat template detected for this model.
    pub fn chat_template(&self) -> &ChatTemplate {
        &self.chat_template
    }

    /// The underlying GGUF file (`None` for ONNX-loaded engines).
    pub fn gguf(&self) -> Option<&GgufFile> {
        self.gguf.as_ref()
    }

    /// Tokenizer used for encode/decode.
    pub fn tokenizer(&self) -> &Tokenizer {
        &self.tokenizer
    }

    /// The configuration this engine was constructed with.
    pub fn engine_config(&self) -> &EngineConfig {
        &self.engine_config
    }

    /// The loaded model as a trait object.
    pub fn model(&self) -> &dyn Model {
        &*self.model
    }

    /// Compute backend shared by inference contexts.
    pub fn backend(&self) -> &Arc<dyn Backend> {
        &self.backend
    }

    /// Whether a BOS token is prepended when encoding prompts.
    pub fn add_bos(&self) -> bool {
        self.add_bos
    }
853
854 pub fn create_inference_context(&self) -> InferenceContext {
861 if self.engine_config.kv_cache_type.is_turboquant() {
862 InferenceContext::new_with_cache_type(
863 &self.config,
864 self.backend.clone(),
865 self.engine_config.kv_cache_type,
866 )
867 } else {
868 self.model.create_context(self.backend.clone())
869 }
870 }
871
872 pub fn generate(&self, prompt: &str, max_tokens: usize) -> Result<String, EngineError> {
874 let mut ctx = self.create_inference_context();
875 let mut sampler = Sampler::new(self.sampler_config.clone(), self.config.vocab_size);
876
877 let formatted = self.chat_template.wrap_prompt(prompt);
879 let mut tokens = self.tokenizer.encode(&formatted, self.add_bos)?;
880
881 let mut output = String::new();
882
883 for _ in 0..max_tokens {
884 if let Some(&last) = tokens.last()
886 && last == self.tokenizer.special_tokens.eos_token_id
887 {
888 break;
889 }
890
891 let input_tokens = if ctx.position == 0 {
893 &tokens[..]
894 } else {
895 &tokens[tokens.len() - 1..]
896 };
897
898 let logits = self.model.forward(input_tokens, &mut ctx)?;
899 let next_token = sampler.sample(&logits, &tokens);
900
901 if next_token == self.tokenizer.special_tokens.eos_token_id {
903 break;
904 }
905
906 if let Ok(text) = self.tokenizer.decode(&[next_token]) {
908 let combined = format!("{}{}", output, text);
910 let stop = self
911 .chat_template
912 .stop_patterns()
913 .iter()
914 .any(|p| combined.contains(p));
915
916 if stop {
917 for pattern in self.chat_template.stop_patterns() {
919 if let Some(idx) = combined.find(pattern) {
920 output = combined[..idx].to_string();
921 return Ok(output.trim().to_string());
922 }
923 }
924 break;
925 }
926
927 output.push_str(&text);
928 }
929
930 tokens.push(next_token);
931 }
932
933 Ok(output.trim().to_string())
934 }
935
    /// Starts streaming generation for `prompt`; the returned iterator yields
    /// decoded text fragments as tokens are sampled.
    pub fn generate_streaming(&self, prompt: &str, max_tokens: usize) -> GenerationStream<'_> {
        GenerationStream::new(self, prompt, max_tokens)
    }
943
944 pub fn embed(&self, text: &str) -> Result<Vec<f32>, EngineError> {
946 let mut ctx = self.create_inference_context();
947 let embed_config = EmbeddingConfig::default();
948 let extractor = EmbeddingExtractor::new(embed_config, &self.config);
949 let embedding =
950 extractor.embed_text(self.model.as_ref(), &self.tokenizer, &mut ctx, text)?;
951 Ok(embedding)
952 }
953}
954
/// Iterator over decoded text fragments, created by
/// [`Engine::generate_streaming`].
pub struct GenerationStream<'a> {
    engine: &'a Engine,
    ctx: InferenceContext,
    sampler: Sampler,
    // Prompt tokens plus every token generated so far.
    tokens: Vec<u32>,
    // Tokens still allowed before the budget is exhausted.
    remaining: usize,
    done: bool,
    // All text emitted so far; used for stop-pattern detection.
    accumulated: String,
    // Partial UTF-8 bytes carried across tokens by streaming decode.
    pending_bytes: Vec<u8>,
}
973
impl<'a> GenerationStream<'a> {
    /// Prepares a stream: formats and tokenizes the prompt but does not run
    /// the model yet (the first `next()` call performs prefill).
    ///
    /// NOTE(review): tokenizer errors are swallowed by `unwrap_or_default`,
    /// silently yielding an empty token stream — consider surfacing them.
    fn new(engine: &'a Engine, prompt: &str, max_tokens: usize) -> Self {
        let ctx = engine.create_inference_context();
        let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);

        let formatted = engine.chat_template.wrap_prompt(prompt);
        // Opt-in debug dump of the exact prompt and token stream.
        if std::env::var("LLAMA_DEBUG").is_ok() {
            eprintln!("[DEBUG] formatted prompt: {:?}", formatted);
            eprintln!("[DEBUG] add_bos: {}", engine.add_bos);
        }
        let tokens = engine
            .tokenizer
            .encode(&formatted, engine.add_bos)
            .unwrap_or_default();
        if std::env::var("LLAMA_DEBUG").is_ok() {
            eprintln!("[DEBUG] encoded {} tokens: {:?}", tokens.len(), &tokens[..tokens.len().min(50)]);
            for (i, &tid) in tokens.iter().enumerate() {
                if let Some(s) = engine.tokenizer.get_token(tid) {
                    eprintln!("[DEBUG] token[{}] = {} -> {:?}", i, tid, s);
                }
            }
        }

        Self {
            engine,
            ctx,
            sampler,
            tokens,
            remaining: max_tokens,
            done: false,
            accumulated: String::new(),
            pending_bytes: Vec::new(),
        }
    }
}
1009
1010impl<'a> Iterator for GenerationStream<'a> {
1011 type Item = Result<String, EngineError>;
1012
1013 fn next(&mut self) -> Option<Self::Item> {
1014 if self.done || self.remaining == 0 {
1015 return None;
1016 }
1017
1018 if let Some(&last) = self.tokens.last()
1020 && last == self.engine.tokenizer.special_tokens.eos_token_id
1021 {
1022 self.done = true;
1023 return None;
1024 }
1025
1026 let input_tokens = if self.ctx.position == 0 {
1028 &self.tokens[..]
1029 } else {
1030 &self.tokens[self.tokens.len() - 1..]
1031 };
1032
1033 let logits = match self.engine.model.forward(input_tokens, &mut self.ctx) {
1034 Ok(l) => l,
1035 Err(e) => {
1036 self.done = true;
1037 return Some(Err(EngineError::Model(e)));
1038 }
1039 };
1040
1041 let next_token = self.sampler.sample(&logits, &self.tokens);
1042
1043 if std::env::var("LLAMA_DEBUG_LOGITS").is_ok() {
1044 let logit_data = logits.as_f32().unwrap();
1045 let mut indexed: Vec<(usize, f32)> = logit_data.iter().copied().enumerate().collect();
1046 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1047 let step = self.tokens.len();
1048 eprint!("[LOGIT] step={} top5:", step);
1049 for (id, score) in indexed.iter().take(5) {
1050 let tok_str = self.engine.tokenizer.get_token(*id as u32).unwrap_or_default();
1051 eprint!(" {}({:.2})={:?}", id, score, tok_str);
1052 }
1053 let chosen_str = self.engine.tokenizer.get_token(next_token).unwrap_or_default();
1054 eprintln!(" → chosen={}({:?})", next_token, chosen_str);
1055 }
1056
1057 if next_token == self.engine.tokenizer.special_tokens.eos_token_id {
1059 self.done = true;
1060 return None;
1061 }
1062
1063 match self
1065 .engine
1066 .tokenizer
1067 .decode_token_streaming(next_token, &mut self.pending_bytes)
1068 {
1069 Ok(text) => {
1070 self.tokens.push(next_token);
1071 self.remaining -= 1;
1072
1073 if text.is_empty() {
1074 return self.next();
1076 }
1077
1078 let combined = format!("{}{}", self.accumulated, text);
1080 for pattern in self.engine.chat_template.stop_patterns() {
1081 if combined.contains(pattern) {
1082 self.done = true;
1083 if let Some(idx) = combined.find(pattern) {
1084 if idx > self.accumulated.len() {
1085 let before = &combined[self.accumulated.len()..idx];
1086 return Some(Ok(before.to_string()));
1087 }
1088 }
1089 return None;
1090 }
1091 }
1092
1093 self.accumulated.push_str(&text);
1094 Some(Ok(text))
1095 }
1096 Err(e) => {
1097 self.tokens.push(next_token);
1098 self.remaining -= 1;
1099 Some(Err(EngineError::Tokenizer(e)))
1100 }
1101 }
1102 }
1103}
1104
/// Stateful multi-turn chat wrapper around [`Engine`]: keeps the
/// conversation token history, KV-cache context, and sampler across turns.
pub struct ChatEngine {
    engine: Engine,
    system_prompt: String,
    // Full token history of the conversation (all turns, prompt + replies).
    conversation_tokens: Vec<u32>,
    ctx: InferenceContext,
    sampler: Sampler,
    // True until the first user message has been processed.
    is_first_turn: bool,
}
1121
1122impl ChatEngine {
1123 pub fn new(engine: Engine, system_prompt: Option<String>) -> Self {
1125 let ctx = engine.create_inference_context();
1126 let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
1127
1128 Self {
1129 system_prompt: system_prompt
1130 .unwrap_or_else(|| "You are a helpful AI assistant.".to_string()),
1131 conversation_tokens: Vec::new(),
1132 ctx,
1133 sampler,
1134 is_first_turn: true,
1135 engine,
1136 }
1137 }
1138
    /// The wrapped engine.
    pub fn engine(&self) -> &Engine {
        &self.engine
    }

    /// The active system prompt.
    pub fn system_prompt(&self) -> &str {
        &self.system_prompt
    }

    /// Number of tokens currently held in conversation history.
    pub fn context_len(&self) -> usize {
        self.conversation_tokens.len()
    }
1153
    /// Runs one blocking chat turn: formats `message` with the chat template
    /// (the first turn includes the system prompt), prefills it, then samples
    /// up to `max_tokens` response tokens. Returns the trimmed reply.
    ///
    /// NOTE(review): this is nearly identical to `chat_with_prefix` — the
    /// shared generation loop is a candidate for extraction.
    pub fn chat(&mut self, message: &str) -> Result<String, EngineError> {
        let max_tokens = self.engine.engine_config.max_tokens;

        let formatted = if self.is_first_turn {
            self.engine
                .chat_template
                .format_first_turn(&self.system_prompt, message)
        } else {
            self.engine.chat_template.format_continuation(message)
        };

        // BOS only on the very first turn (and only if the model wants it).
        let new_tokens = self
            .engine
            .tokenizer
            .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;

        // Make room for prompt + response, trimming old history if needed.
        self.ensure_context_space(new_tokens.len(), max_tokens);

        self.conversation_tokens.extend(&new_tokens);

        let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
        let mut response_text = String::new();

        if new_tokens.is_empty() {
            self.is_first_turn = false;
            return Ok(response_text);
        }

        // Prefill the new prompt chunk in one forward pass; its logits give
        // the first response token.
        let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
        let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);

        if first_token == eos_id {
            self.is_first_turn = false;
            return Ok(response_text);
        }

        // Decode failures are deliberately ignored (best-effort text output).
        if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
            response_text.push_str(&text);
        }
        self.conversation_tokens.push(first_token);

        // Autoregressive loop: one token per step, stopping on EOS or a
        // template stop pattern.
        for _ in 1..max_tokens {
            let should_stop = self
                .engine
                .chat_template
                .stop_patterns()
                .iter()
                .any(|p| response_text.contains(p));
            if should_stop {
                // Truncate at the first pattern we can locate.
                for pattern in self.engine.chat_template.stop_patterns() {
                    if let Some(idx) = response_text.find(pattern) {
                        response_text.truncate(idx);
                        break;
                    }
                }
                break;
            }

            let last_token = *self
                .conversation_tokens
                .last()
                .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);

            let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
            let next_token = self.sampler.sample(&logits, &self.conversation_tokens);

            if next_token == eos_id {
                break;
            }

            if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
                response_text.push_str(&text);
            }

            self.conversation_tokens.push(next_token);
        }

        self.is_first_turn = false;
        Ok(response_text.trim().to_string())
    }
1244
    /// Like [`ChatEngine::chat`], but appends `prefix` to the formatted
    /// prompt so the model continues from it; the prefix is also included in
    /// the returned response text.
    ///
    /// NOTE(review): if `prefix` itself contains a stop pattern, the stop
    /// check will fire immediately — presumably intentional, but worth
    /// confirming.
    pub fn chat_with_prefix(
        &mut self,
        message: &str,
        prefix: &str,
    ) -> Result<String, EngineError> {
        let max_tokens = self.engine.engine_config.max_tokens;

        let formatted = if self.is_first_turn {
            self.engine
                .chat_template
                .format_first_turn(&self.system_prompt, message)
        } else {
            self.engine.chat_template.format_continuation(message)
        };

        // The model sees prompt + prefix as its input.
        let formatted_with_prefix = format!("{}{}", formatted, prefix);

        let new_tokens = self
            .engine
            .tokenizer
            .encode(&formatted_with_prefix, self.is_first_turn && self.engine.add_bos)?;

        self.ensure_context_space(new_tokens.len(), max_tokens);
        self.conversation_tokens.extend(&new_tokens);

        let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
        // The response starts with the forced prefix.
        let mut response_text = prefix.to_string();

        if new_tokens.is_empty() {
            self.is_first_turn = false;
            return Ok(response_text);
        }

        let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
        let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);

        if first_token == eos_id {
            self.is_first_turn = false;
            return Ok(response_text);
        }

        if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
            response_text.push_str(&text);
        }
        self.conversation_tokens.push(first_token);

        // Same autoregressive loop as `chat`.
        for _ in 1..max_tokens {
            let should_stop = self
                .engine
                .chat_template
                .stop_patterns()
                .iter()
                .any(|p| response_text.contains(p));
            if should_stop {
                for pattern in self.engine.chat_template.stop_patterns() {
                    if let Some(idx) = response_text.find(pattern) {
                        response_text.truncate(idx);
                        break;
                    }
                }
                break;
            }

            let last_token = *self
                .conversation_tokens
                .last()
                .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);

            let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
            let next_token = self.sampler.sample(&logits, &self.conversation_tokens);

            if next_token == eos_id {
                break;
            }

            if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
                response_text.push_str(&text);
            }

            self.conversation_tokens.push(next_token);
        }

        self.is_first_turn = false;
        Ok(response_text.trim().to_string())
    }
1338
    /// Starts a streaming chat turn: prompt prefill happens here, and the
    /// returned [`ChatStream`] yields response fragments as they are sampled.
    pub fn chat_streaming(&mut self, message: &str) -> Result<ChatStream<'_>, EngineError> {
        let max_tokens = self.engine.engine_config.max_tokens;

        let formatted = if self.is_first_turn {
            self.engine
                .chat_template
                .format_first_turn(&self.system_prompt, message)
        } else {
            self.engine.chat_template.format_continuation(message)
        };

        // BOS only on the very first turn (and only if the model wants it).
        let new_tokens = self
            .engine
            .tokenizer
            .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;

        self.ensure_context_space(new_tokens.len(), max_tokens);

        self.conversation_tokens.extend(&new_tokens);

        // Prefill now so the stream's first `next()` can sample immediately.
        let prefill_logits = if !new_tokens.is_empty() {
            Some(self.engine.model.forward(&new_tokens, &mut self.ctx)?)
        } else {
            None
        };

        self.is_first_turn = false;

        Ok(ChatStream {
            chat_engine: self,
            remaining: max_tokens,
            done: false,
            accumulated: String::new(),
            prefill_logits,
        })
    }
1384
    /// Clears conversation history and resets context, sampler state, and the
    /// first-turn flag, as if the session had just been created.
    pub fn clear_history(&mut self) {
        self.conversation_tokens.clear();
        self.ctx.reset();
        self.sampler.reset();
        self.is_first_turn = true;
    }
1392
    /// Ensures prompt + generation fits within the model context, trimming
    /// the oldest conversation tokens (and shifting the KV cache to match)
    /// when it would overflow, or resetting entirely when trimming cannot
    /// free enough room.
    fn ensure_context_space(&mut self, new_token_count: usize, max_gen_tokens: usize) {
        let total_len = self.conversation_tokens.len() + new_token_count + max_gen_tokens;

        if total_len > self.engine.config.max_seq_len {
            // +100 keeps slack so we don't immediately trim again next turn.
            let excess = total_len - self.engine.config.max_seq_len + 100;

            if excess >= self.conversation_tokens.len() {
                tracing::warn!("Context full, resetting conversation");
                self.conversation_tokens.clear();
                self.ctx.reset();
            } else {
                tracing::info!("Trimming {} tokens from context", excess);
                self.conversation_tokens = self.conversation_tokens[excess..].to_vec();
                // Keep KV cache and position consistent with trimmed history.
                self.ctx.kv_cache.shift_left(excess);
                self.ctx.position = self.ctx.position.saturating_sub(excess);
            }
        }
    }
1412}
1413
/// Streaming iterator over one chat turn, created by
/// [`ChatEngine::chat_streaming`].
pub struct ChatStream<'a> {
    chat_engine: &'a mut ChatEngine,
    // Tokens still allowed this turn.
    remaining: usize,
    done: bool,
    // Text emitted so far this turn; used for stop-pattern detection.
    accumulated: String,
    // Logits from prompt prefill, consumed on the first `next()` call.
    prefill_logits: Option<crate::tensor::Tensor>,
}
1429
1430impl<'a> Iterator for ChatStream<'a> {
1431 type Item = Result<String, EngineError>;
1432
1433 fn next(&mut self) -> Option<Self::Item> {
1434 if self.done || self.remaining == 0 {
1435 return None;
1436 }
1437
1438 for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
1440 if self.accumulated.contains(pattern) {
1441 self.done = true;
1442 return None;
1443 }
1444 }
1445
1446 let logits = if let Some(prefill) = self.prefill_logits.take() {
1449 prefill
1450 } else {
1451 let last_token = *self.chat_engine.conversation_tokens.last().unwrap_or(
1452 &self
1453 .chat_engine
1454 .engine
1455 .tokenizer
1456 .special_tokens
1457 .bos_token_id,
1458 );
1459
1460 match self
1461 .chat_engine
1462 .engine
1463 .model
1464 .forward(&[last_token], &mut self.chat_engine.ctx)
1465 {
1466 Ok(l) => l,
1467 Err(e) => {
1468 self.done = true;
1469 return Some(Err(EngineError::Model(e)));
1470 }
1471 }
1472 };
1473
1474 let next_token = self
1475 .chat_engine
1476 .sampler
1477 .sample(&logits, &self.chat_engine.conversation_tokens);
1478
1479 if next_token
1481 == self
1482 .chat_engine
1483 .engine
1484 .tokenizer
1485 .special_tokens
1486 .eos_token_id
1487 {
1488 self.done = true;
1489 return None;
1490 }
1491
1492 match self.chat_engine.engine.tokenizer.decode(&[next_token]) {
1493 Ok(text) => {
1494 let combined = format!("{}{}", self.accumulated, text);
1496 for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
1497 if combined.contains(pattern) {
1498 self.done = true;
1499 if let Some(idx) = combined.find(pattern) {
1500 let before = &combined[self.accumulated.len()..idx];
1501 self.chat_engine.conversation_tokens.push(next_token);
1502 if !before.is_empty() {
1503 return Some(Ok(before.to_string()));
1504 }
1505 }
1506 return None;
1507 }
1508 }
1509
1510 self.accumulated.push_str(&text);
1511 self.chat_engine.conversation_tokens.push(next_token);
1512 self.remaining -= 1;
1513 Some(Ok(text))
1514 }
1515 Err(e) => {
1516 self.chat_engine.conversation_tokens.push(next_token);
1517 self.remaining -= 1;
1518 Some(Err(EngineError::Tokenizer(e)))
1519 }
1520 }
1521 }
1522}