1use std::sync::Arc;
21
22use crate::backend::Backend;
23use crate::gguf::GgufFile;
24use crate::model::{
25 EmbeddingConfig, EmbeddingExtractor, InferenceContext, Model, ModelConfig, ModelLoader,
26 ModelSource, build_llama_model,
27};
28use crate::safetensors::SafeTensorsLoader;
29use crate::sampling::{Sampler, SamplerConfig};
30use crate::tokenizer::Tokenizer;
31
/// Errors surfaced by the engine while loading a model or running inference.
///
/// Every variant except [`EngineError::Other`] wraps an error type from
/// another module and is produced automatically by `?` through the
/// `#[from]` conversion declared on it.
#[derive(thiserror::Error, Debug)]
pub enum EngineError {
    /// A standard-library I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// An error raised by the GGUF module (e.g. while opening a GGUF file).
    #[error("GGUF error: {0}")]
    Gguf(#[from] crate::gguf::GgufError),

    /// An error raised by the model module (building a model or a forward pass).
    #[error("Model error: {0}")]
    Model(#[from] crate::model::ModelError),

    /// An error raised by the tokenizer (loading, encoding or decoding).
    #[error("Tokenizer error: {0}")]
    Tokenizer(#[from] crate::tokenizer::TokenizerError),

    /// An error raised while extracting embeddings.
    #[error("Embedding error: {0}")]
    Embedding(#[from] crate::model::EmbeddingError),

    /// A free-form engine error that has no dedicated variant.
    #[error("Engine error: {0}")]
    Other(String),
}
57
/// User-facing configuration for loading an [`Engine`] and sampling from it.
///
/// `#[serde(default)]` means any field missing from a serialized config is
/// filled in from [`EngineConfig::default`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct EngineConfig {
    /// Path to the model. `Engine::load` rejects an empty path and picks the
    /// loader from the extension (`onnx`, `safetensors`), from a directory
    /// containing `config.json`, or otherwise treats it as GGUF.
    pub model_path: String,

    /// Optional explicit tokenizer path. A path ending in `.json` is loaded as
    /// a HuggingFace tokenizer file; anything else is opened as a GGUF file.
    pub tokenizer_path: Option<String>,

    /// Sampling temperature, forwarded to the sampler configuration.
    pub temperature: f32,

    /// Top-k value, forwarded to the sampler configuration.
    pub top_k: usize,

    /// Top-p value, forwarded to the sampler configuration.
    pub top_p: f32,

    /// Repeat penalty, forwarded to the sampler configuration.
    pub repeat_penalty: f32,

    /// Token budget used by `ChatEngine` for each response.
    pub max_tokens: usize,

    /// Optional sampler seed, forwarded to the sampler configuration.
    pub seed: Option<u64>,

    /// When `true`, the loaders try to select a GPU model/backend instead of
    /// the CPU backend.
    pub use_gpu: bool,

    /// Optional cap on the context length. It is only applied when it is
    /// greater than zero and smaller than the model's own `max_seq_len`.
    pub max_context_len: Option<usize>,

    /// Hailo accelerator settings; the Hailo path is only attempted when this
    /// is `Some`.
    #[cfg(feature = "hailo")]
    pub hailo_config: Option<crate::backend::hailo::HailoConfig>,

    /// KV-cache storage type consulted when creating an inference context.
    pub kv_cache_type: crate::model::KVCacheType,
}
116
117impl Default for EngineConfig {
118 fn default() -> Self {
119 Self {
120 model_path: String::new(),
121 tokenizer_path: None,
122 temperature: 0.7,
123 top_k: 40,
124 top_p: 0.95,
125 repeat_penalty: 1.1,
126 max_tokens: 512,
127 seed: None,
128 use_gpu: false,
129 max_context_len: None,
130 #[cfg(feature = "hailo")]
131 hailo_config: None,
132 kv_cache_type: crate::model::KVCacheType::F32,
133 }
134 }
135}
136
137impl EngineConfig {
138 pub fn from_config_file(
143 path: impl AsRef<std::path::Path>,
144 ) -> Result<Self, crate::config::ConfigError> {
145 let config = crate::config::Config::from_file(path)?;
146 Ok(config.to_engine_config(None))
147 }
148
149 pub fn from_config(
154 config_path: Option<impl AsRef<std::path::Path>>,
155 ) -> Result<Self, crate::config::ConfigError> {
156 let config = crate::config::Config::load(config_path)?;
157 Ok(config.to_engine_config(None))
158 }
159}
160
/// Prompt-formatting convention used to wrap user and system text before
/// tokenization.
#[derive(Debug, Clone, PartialEq)]
pub enum ChatTemplate {
    /// `<|user|>` / `<|assistant|>` (and `<|system|>`) markers.
    UserAssistant,
    /// ChatML: `<|im_start|>role ... <|im_end|>` blocks.
    ChatML,
    /// Llama-2 style `[INST] ... [/INST]` with an optional `<<SYS>>` section.
    Llama2,
    /// No template markers; prompts are passed through or use plain
    /// `System:` / `User:` / `Assistant:` labels.
    None,
}
177
178impl ChatTemplate {
179 pub fn from_tokenizer_config(path: &std::path::Path) -> Option<Self> {
184 let data = std::fs::read_to_string(path).ok()?;
185 let json: serde_json::Value = serde_json::from_str(&data).ok()?;
186 let template = json.get("chat_template")?.as_str()?;
187
188 if template.contains("<|user|>") {
189 Some(ChatTemplate::UserAssistant)
190 } else if template.contains("<|im_start|>") {
191 Some(ChatTemplate::ChatML)
192 } else if template.contains("[INST]") {
193 Some(ChatTemplate::Llama2)
194 } else {
195 Some(ChatTemplate::None)
196 }
197 }
198
199 pub fn detect_from_model_type(model_type: Option<&str>) -> Self {
201 match model_type {
202 Some("qwen2" | "qwen") => ChatTemplate::ChatML,
203 Some("llama" | "codellama") => ChatTemplate::Llama2,
204 Some("mistral" | "mixtral") => ChatTemplate::Llama2,
205 _ => ChatTemplate::None,
206 }
207 }
208
209 pub fn detect(gguf: &GgufFile) -> Self {
211 if let Some(template) = gguf.data.get_string("tokenizer.chat_template") {
212 if template.contains("<|user|>") {
213 ChatTemplate::UserAssistant
214 } else if template.contains("<|im_start|>") {
215 ChatTemplate::ChatML
216 } else if template.contains("[INST]") {
217 ChatTemplate::Llama2
218 } else {
219 ChatTemplate::None
220 }
221 } else if let Some(arch) = gguf.data.get_string("general.architecture") {
222 match arch.to_lowercase().as_str() {
223 "qwen2" | "qwen" | "qwen3" | "qwen35" | "qwen3moe" | "qwen3next" => {
224 ChatTemplate::ChatML
225 }
226 _ => ChatTemplate::None,
227 }
228 } else {
229 ChatTemplate::None
230 }
231 }
232
233 pub fn wrap_prompt(&self, prompt: &str) -> String {
235 if prompt.contains("<|user|>")
237 || prompt.contains("<|im_start|>")
238 || prompt.contains("[INST]")
239 {
240 return prompt.to_string();
241 }
242
243 match self {
244 ChatTemplate::UserAssistant => {
245 format!("<|user|>\n{}<|assistant|>\n", prompt)
246 }
247 ChatTemplate::ChatML => {
248 format!(
249 "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
250 prompt
251 )
252 }
253 ChatTemplate::Llama2 => {
254 format!("[INST] {} [/INST]", prompt)
255 }
256 ChatTemplate::None => prompt.to_string(),
257 }
258 }
259
260 pub fn format_first_turn(&self, system_prompt: &str, user_message: &str) -> String {
262 match self {
263 ChatTemplate::UserAssistant => {
264 format!(
265 "<|system|>\n{}<|user|>\n{}<|assistant|>\n",
266 system_prompt, user_message
267 )
268 }
269 ChatTemplate::ChatML => {
270 format!(
271 "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
272 system_prompt, user_message
273 )
274 }
275 ChatTemplate::Llama2 => {
276 format!(
277 "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]",
278 system_prompt, user_message
279 )
280 }
281 ChatTemplate::None => {
282 format!(
283 "System: {}\n\nUser: {}\n\nAssistant:",
284 system_prompt, user_message
285 )
286 }
287 }
288 }
289
290 pub fn format_continuation(&self, user_message: &str) -> String {
292 match self {
293 ChatTemplate::UserAssistant => {
294 format!("<|user|>\n{}<|assistant|>\n", user_message)
295 }
296 ChatTemplate::ChatML => {
297 format!(
298 "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
299 user_message
300 )
301 }
302 ChatTemplate::Llama2 => {
303 format!(" [INST] {} [/INST]", user_message)
304 }
305 ChatTemplate::None => {
306 format!("\n\nUser: {}\n\nAssistant:", user_message)
307 }
308 }
309 }
310
311 pub fn stop_patterns(&self) -> &[&str] {
314 match self {
315 ChatTemplate::UserAssistant => &["<|user|>", "<|end|>"],
316 ChatTemplate::ChatML => &["<|im_end|>", "<|im_start|>"],
317 ChatTemplate::Llama2 => &["[INST]", "</s>"],
318 ChatTemplate::None => &["User:", "\nUser:"],
319 }
320 }
321}
322
/// A loaded model together with everything needed to run it: tokenizer,
/// backend, sampler settings and chat template.
pub struct Engine {
    /// The opened GGUF file; `None` for SafeTensors and ONNX models.
    gguf: Option<GgufFile>,
    /// The model used for forward passes.
    model: Box<dyn Model>,
    /// Tokenizer used to encode prompts and decode sampled tokens.
    tokenizer: Tokenizer,
    /// Model hyperparameters taken from the loader.
    config: ModelConfig,
    /// Compute backend handed to inference contexts.
    backend: Arc<dyn Backend>,
    /// Sampler settings derived from the engine configuration at load time.
    sampler_config: SamplerConfig,
    /// Chat template detected at load time.
    chat_template: ChatTemplate,
    /// Whether prompts are encoded with a leading BOS token.
    add_bos: bool,
    /// The configuration this engine was loaded with.
    engine_config: EngineConfig,
}
342
343impl Engine {
344 pub fn load(config: EngineConfig) -> Result<Self, EngineError> {
353 if config.model_path.is_empty() {
354 return Err(EngineError::Other("model_path is required".into()));
355 }
356
357 let path = std::path::Path::new(&config.model_path);
358
359 match path.extension().and_then(|e| e.to_str()) {
361 #[cfg(feature = "onnx")]
362 Some("onnx") => Self::load_onnx(config),
363 #[cfg(not(feature = "onnx"))]
364 Some("onnx") => Err(EngineError::Other(
365 "ONNX support requires the `onnx` feature. Build with: cargo build --features onnx"
366 .into(),
367 )),
368 Some("safetensors") => Self::load_safetensors(config),
369 _ if path.is_dir() && path.join("config.json").exists() => {
370 Self::load_safetensors(config)
371 }
372 _ => Self::load_gguf(config),
373 }
374 }
375
376 fn load_gguf(config: EngineConfig) -> Result<Self, EngineError> {
378 tracing::info!("Loading GGUF model from: {}", config.model_path);
379
380 let gguf = GgufFile::open(&config.model_path)?;
382
383 tracing::info!("Loading tokenizer...");
385 let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
386 if tok_path.ends_with(".json") {
388 Tokenizer::from_hf_json(tok_path)?
389 } else {
390 let tok_gguf = GgufFile::open(tok_path)?;
391 Tokenizer::from_gguf(&tok_gguf)?
392 }
393 } else {
394 Tokenizer::from_gguf(&gguf)?
395 };
396 tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
397
398 tracing::info!("Loading model weights...");
400 let loader = ModelLoader::load(&config.model_path)?;
401 let model_config = loader.config().clone();
402 tracing::info!(
403 "Model: {} layers, {} heads, {} hidden dim, {} ctx",
404 model_config.num_layers,
405 model_config.num_heads,
406 model_config.hidden_size,
407 model_config.max_seq_len,
408 );
409
410 let arch = loader.architecture();
411
412 let (backend, model): (Arc<dyn Backend>, Box<dyn Model>) = if arch.is_encoder_only() {
413 tracing::info!("Detected encoder-only architecture: {:?}", arch);
414 let bert_model = loader.build_bert_model()?;
415 (
416 Arc::new(crate::backend::cpu::CpuBackend::new()),
417 Box::new(bert_model),
418 )
419 } else {
420 let concrete_model = loader.build_model()?;
421
422 if config.use_gpu {
428 Self::select_gpu_model(concrete_model, &model_config, &config)
429 } else {
430 (
431 Arc::new(crate::backend::cpu::CpuBackend::new()),
432 Box::new(concrete_model),
433 )
434 }
435 };
436
437 let chat_template = ChatTemplate::detect(&gguf);
439 tracing::info!("Chat template: {:?}", chat_template);
440
441 let add_bos = gguf
445 .data
446 .get_bool("tokenizer.ggml.add_bos_token")
447 .unwrap_or(tokenizer.has_explicit_bos);
448
449 let sampler_config = SamplerConfig {
451 temperature: config.temperature,
452 top_k: config.top_k,
453 top_p: config.top_p,
454 repeat_penalty: config.repeat_penalty,
455 seed: config.seed,
456 ..Default::default()
457 };
458
459 tracing::info!("Engine ready");
460
461 Ok(Self {
462 gguf: Some(gguf),
463 model,
464 tokenizer,
465 config: model_config,
466 backend,
467 sampler_config,
468 chat_template,
469 add_bos,
470 engine_config: config,
471 })
472 }
473
474 fn load_safetensors(config: EngineConfig) -> Result<Self, EngineError> {
480 tracing::info!("Loading SafeTensors model from: {}", config.model_path);
481
482 let path = std::path::Path::new(&config.model_path);
483 let dir = if path.is_dir() {
484 path
485 } else {
486 path.parent().unwrap_or(std::path::Path::new("."))
487 };
488
489 tracing::info!("Loading tokenizer...");
491 let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
492 if tok_path.ends_with(".json") {
494 Tokenizer::from_hf_json(tok_path)?
495 } else {
496 let tok_gguf = GgufFile::open(tok_path)?;
497 Tokenizer::from_gguf(&tok_gguf)?
498 }
499 } else {
500 let tok_path = dir.join("tokenizer.json");
502 if tok_path.exists() {
503 tracing::info!("Using tokenizer.json from: {}", tok_path.display());
504 Tokenizer::from_hf_json(&tok_path)?
505 } else {
506 return Err(EngineError::Other(format!(
507 "No tokenizer.json found in {}. Use --tokenizer to specify one.",
508 dir.display()
509 )));
510 }
511 };
512 tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
513
514 tracing::info!("Loading model weights...");
516 let mut loader = SafeTensorsLoader::load(path)?;
517
518 if let Some(cap) = config.max_context_len {
520 if cap > 0 && cap < loader.config().max_seq_len {
521 tracing::info!(
522 "Capping context length from {} to {}",
523 loader.config().max_seq_len,
524 cap
525 );
526 loader.config_mut().max_seq_len = cap;
527 }
528 }
529
530 let model_config = loader.config().clone();
531 let architecture = loader.architecture();
532
533 tracing::info!(
534 "Model: {} layers, {} heads, {} hidden dim, {} ctx, arch={:?}",
535 model_config.num_layers,
536 model_config.num_heads,
537 model_config.hidden_size,
538 model_config.max_seq_len,
539 architecture,
540 );
541
542 let concrete_model = build_llama_model(&loader)?;
544
545 let (backend, model): (Arc<dyn Backend>, Box<dyn Model>) = if config.use_gpu {
546 Self::select_gpu_model(concrete_model, &model_config, &config)
547 } else {
548 (
549 Arc::new(crate::backend::cpu::CpuBackend::new()),
550 Box::new(concrete_model),
551 )
552 };
553
554 let chat_template = {
557 let tc_path = dir.join("tokenizer_config.json");
558 ChatTemplate::from_tokenizer_config(&tc_path)
559 }
560 .unwrap_or_else(|| {
561 let config_path = dir.join("config.json");
563 let model_type = std::fs::read_to_string(&config_path)
564 .ok()
565 .and_then(|s| {
566 let v: serde_json::Value = serde_json::from_str(&s).ok()?;
567 v.get("model_type")?.as_str().map(|s| s.to_string())
568 });
569 ChatTemplate::detect_from_model_type(model_type.as_deref())
570 });
571 tracing::info!("Chat template: {:?}", chat_template);
572
573 let add_bos = true;
575
576 let sampler_config = SamplerConfig {
577 temperature: config.temperature,
578 top_k: config.top_k,
579 top_p: config.top_p,
580 repeat_penalty: config.repeat_penalty,
581 seed: config.seed,
582 ..Default::default()
583 };
584
585 tracing::info!("Engine ready (SafeTensors)");
586
587 Ok(Self {
588 gguf: None,
589 model,
590 tokenizer,
591 config: model_config,
592 backend,
593 sampler_config,
594 chat_template,
595 add_bos,
596 engine_config: config,
597 })
598 }
599
/// Loads an ONNX model (only compiled with the `onnx` feature).
///
/// The tokenizer comes from `tokenizer_path` when given, otherwise from
/// `tokenizer.json` next to the model file. The chat template is derived from
/// the HuggingFace config's `model_type`.
///
/// # Errors
/// Returns `EngineError::Other` for ONNX load/build failures and for a
/// missing tokenizer; propagates tokenizer and GGUF errors.
#[cfg(feature = "onnx")]
fn load_onnx(config: EngineConfig) -> Result<Self, EngineError> {
    use crate::onnx::OnnxModelLoader;

    tracing::info!("Loading ONNX model from: {}", config.model_path);

    let model_dir = std::path::Path::new(&config.model_path)
        .parent()
        .unwrap_or(std::path::Path::new("."));

    let loader = match OnnxModelLoader::load(&config.model_path) {
        Ok(loader) => loader,
        Err(e) => return Err(EngineError::Other(format!("ONNX load error: {}", e))),
    };
    let model_config = loader.config().clone();
    let hf_config = loader.hf_config().clone();

    tracing::info!(
        "Model: {} layers, {} heads, {} hidden dim, {} ctx",
        model_config.num_layers,
        model_config.num_heads,
        model_config.hidden_size,
        model_config.max_seq_len,
    );

    let concrete_model = match loader.build_model() {
        Ok(built) => built,
        Err(e) => {
            return Err(EngineError::Other(format!(
                "ONNX model build error: {}",
                e
            )));
        }
    };

    tracing::info!("Loading tokenizer...");
    let tokenizer = match &config.tokenizer_path {
        // An explicit `.json` path is a HuggingFace tokenizer file.
        Some(tok_path) if tok_path.ends_with(".json") => Tokenizer::from_hf_json(tok_path)?,
        // Any other explicit path is read as a GGUF file carrying a tokenizer.
        Some(tok_path) => {
            let tokenizer_gguf = GgufFile::open(tok_path)?;
            Tokenizer::from_gguf(&tokenizer_gguf)?
        }
        // No override: look next to the model file.
        None => {
            let tokenizer_path = model_dir.join("tokenizer.json");
            if !tokenizer_path.exists() {
                return Err(EngineError::Other(format!(
                    "No tokenizer found. ONNX models require a tokenizer.json file \
                     in the same directory as the model, or specify --tokenizer <path>. \
                     Looked for: {}",
                    tokenizer_path.display()
                )));
            }
            tracing::info!("Using tokenizer.json from: {}", tokenizer_path.display());
            Tokenizer::from_hf_json(&tokenizer_path)?
        }
    };
    tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);

    let backend: Arc<dyn Backend> = match config.use_gpu {
        true => Self::select_gpu_backend(&concrete_model),
        false => Arc::new(crate::backend::cpu::CpuBackend::new()),
    };

    let model: Box<dyn Model> = Box::new(concrete_model);

    let chat_template = ChatTemplate::detect_from_model_type(hf_config.model_type.as_deref());
    tracing::info!("Chat template: {:?}", chat_template);

    let add_bos = true;

    let sampler_config = SamplerConfig {
        seed: config.seed,
        temperature: config.temperature,
        top_k: config.top_k,
        top_p: config.top_p,
        repeat_penalty: config.repeat_penalty,
        ..Default::default()
    };

    tracing::info!("Engine ready (ONNX)");

    Ok(Self {
        gguf: None,
        model,
        tokenizer,
        config: model_config,
        backend,
        sampler_config,
        chat_template,
        add_bos,
        engine_config: config,
    })
}
694
/// Tries to move `model` onto an accelerator and returns the backend/model
/// pair to use for inference.
///
/// The compiled-in accelerators are probed in this order: CUDA, Vulkan,
/// Metal (macOS), DX12 (Windows), then Hailo (only when
/// `engine_config.hailo_config` is set). The first one whose device probe
/// succeeds takes ownership of the model:
///
/// * on success the model is wrapped in a `GpuModelWrapper` and returned
///   together with a `CpuBackend`;
/// * on failure the process prints an error and exits with status 1, because
///   the model has already been moved into the failed initializer.
///
/// If no accelerator is available (or none is compiled in), the untouched
/// model is returned with whatever `select_gpu_backend` picks.
// `unused_variables`: depending on the enabled features, some parameters and
// locals (e.g. `gpu_seq_len`, `engine_config`) are not referenced.
#[allow(unused_variables)]
fn select_gpu_model(
    model: crate::model::LlamaModel,
    config: &ModelConfig,
    engine_config: &EngineConfig,
) -> (Arc<dyn Backend>, Box<dyn Model>) {
    // Sequence length handed to the accelerator: the model's limit, lowered
    // to `max_context_len` when that cap is positive and smaller.
    let gpu_seq_len = match engine_config.max_context_len {
        Some(cap) if cap > 0 && cap < config.max_seq_len => {
            tracing::info!(
                "Capping GPU context length from {} to {} (max_context_len)",
                config.max_seq_len,
                cap
            );
            cap
        }
        _ => config.max_seq_len,
    };

    #[cfg(feature = "cuda")]
    {
        // Probe for a device first so the model is only consumed when CUDA
        // is actually usable.
        if cudarc::driver::CudaContext::new(0).is_ok() {
            let architecture = model.architecture();
            match crate::backend::cuda::gpu_only::GpuOnlyInference::from_model(
                model,
                gpu_seq_len,
            ) {
                Ok(gpu) => {
                    tracing::info!(
                        "Using full GPU inference (attention + DeltaNet + MoE all on CUDA)"
                    );
                    let wrapper = crate::backend::GpuModelWrapper::new(
                        gpu,
                        config.clone(),
                        architecture,
                    );
                    return (
                        Arc::new(crate::backend::cpu::CpuBackend::new()),
                        Box::new(wrapper),
                    );
                }
                Err(e) => {
                    // `model` was moved into `from_model`, so there is nothing
                    // left to fall back with.
                    eprintln!("Error: CUDA GPU inference init failed: {}", e);
                    eprintln!("The model was consumed during init. Please restart without --gpu.");
                    std::process::exit(1);
                }
            }
        } else {
            tracing::info!("No CUDA device available, trying other GPU backends...");
        }
    }

    #[cfg(feature = "vulkan")]
    {
        if crate::backend::vulkan::VulkanBackend::new().is_ok() {
            let architecture = model.architecture();
            match crate::backend::vulkan::gpu_only::VulkanGpuInference::from_model(
                model,
                gpu_seq_len,
            ) {
                Ok(gpu) => {
                    tracing::info!("Using full GPU inference on Vulkan");
                    let wrapper = crate::backend::GpuModelWrapper::new(
                        gpu,
                        config.clone(),
                        architecture,
                    );
                    return (
                        Arc::new(crate::backend::cpu::CpuBackend::new()),
                        Box::new(wrapper),
                    );
                }
                Err(e) => {
                    eprintln!("Error: Vulkan GPU inference init failed: {}", e);
                    eprintln!("The model was consumed during init. Please restart without --gpu.");
                    std::process::exit(1);
                }
            }
        } else {
            tracing::info!("No Vulkan device available, trying other GPU backends...");
        }
    }

    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        if crate::backend::metal::MetalBackend::new().is_ok() {
            let architecture = model.architecture();
            match crate::backend::metal::gpu_only::MetalGpuInference::from_model(
                model,
                gpu_seq_len,
            ) {
                Ok(gpu) => {
                    tracing::info!("Using full GPU inference on Metal");
                    let wrapper = crate::backend::GpuModelWrapper::new(
                        gpu,
                        config.clone(),
                        architecture,
                    );
                    return (
                        Arc::new(crate::backend::cpu::CpuBackend::new()),
                        Box::new(wrapper),
                    );
                }
                Err(e) => {
                    eprintln!("Error: Metal GPU inference init failed: {}", e);
                    eprintln!("The model was consumed during init. Please restart without --gpu.");
                    std::process::exit(1);
                }
            }
        } else {
            tracing::info!("No Metal device available, trying other GPU backends...");
        }
    }

    #[cfg(all(feature = "dx12", target_os = "windows"))]
    {
        if crate::backend::dx12::Dx12Backend::new().is_ok() {
            let architecture = model.architecture();
            match crate::backend::dx12::gpu_only::Dx12GpuInference::from_model(
                model,
                gpu_seq_len,
            ) {
                Ok(gpu) => {
                    tracing::info!("Using full GPU inference on DX12");
                    let wrapper = crate::backend::GpuModelWrapper::new(
                        gpu,
                        config.clone(),
                        architecture,
                    );
                    return (
                        Arc::new(crate::backend::cpu::CpuBackend::new()),
                        Box::new(wrapper),
                    );
                }
                Err(e) => {
                    eprintln!("Error: DX12 GPU inference init failed: {}", e);
                    eprintln!("The model was consumed during init. Please restart without --gpu.");
                    std::process::exit(1);
                }
            }
        } else {
            tracing::info!("No DX12 device available");
        }
    }

    #[cfg(feature = "hailo")]
    {
        // Hailo is opt-in: it is only attempted when a Hailo config was given.
        if let Some(ref hailo_config) = engine_config.hailo_config {
            if crate::backend::hailo::context::check_device_available().is_ok() {
                let architecture = model.architecture();
                match crate::backend::hailo::gpu_only::HailoGpuInference::from_model(
                    model,
                    gpu_seq_len,
                    hailo_config.clone(),
                ) {
                    Ok(gpu) => {
                        tracing::info!("Using hybrid CPU+Hailo inference");
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: Hailo inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --hailo.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No Hailo device available, falling back to CPU...");
            }
        }
    }

    // No accelerator took the model: keep it as-is and let
    // `select_gpu_backend` choose the backend (it ends at the CPU backend).
    let backend = Self::select_gpu_backend(&model);
    (backend, Box::new(model))
}
885
/// Picks the first available compiled-in GPU backend, falling back to the
/// CPU backend.
///
/// Backends are tried in this order: CUDA, Metal (macOS), DX12 (Windows),
/// Vulkan. For CUDA the model weights are additionally uploaded; a failed
/// upload is logged as a warning and the CUDA backend is still returned.
/// When no GPU feature is compiled in, a warning is logged before the CPU
/// backend is returned.
// `unused_variables`: `model` is only read by the CUDA branch.
#[allow(unused_variables)]
pub fn select_gpu_backend(model: &crate::model::LlamaModel) -> Arc<dyn Backend> {
    #[cfg(feature = "cuda")]
    {
        match crate::backend::cuda::CudaBackend::new() {
            Ok(mut cuda) => {
                tracing::info!("Using CUDA backend: {}", cuda.device_name());
                // Best effort: a failed weight upload does not disqualify CUDA.
                if let Err(e) = cuda.load_model_weights(model) {
                    tracing::warn!("Failed to load GPU weights ({}), using quantized ops", e);
                }
                return Arc::new(cuda);
            }
            Err(e) => {
                tracing::info!("CUDA not available ({}), trying Metal...", e);
            }
        }
    }

    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        match crate::backend::metal::MetalBackend::new() {
            Ok(metal) => {
                tracing::info!("Using Metal backend: {}", metal.device_name());
                return Arc::new(metal);
            }
            Err(e) => {
                tracing::info!("Metal not available ({}), trying DX12...", e);
            }
        }
    }

    #[cfg(all(feature = "dx12", target_os = "windows"))]
    {
        match crate::backend::dx12::Dx12Backend::new() {
            Ok(dx12) => {
                tracing::info!("Using DX12 backend: {}", dx12.device_name());
                return Arc::new(dx12);
            }
            Err(e) => {
                tracing::info!("DX12 not available ({}), trying Vulkan...", e);
            }
        }
    }

    #[cfg(feature = "vulkan")]
    {
        match crate::backend::vulkan::VulkanBackend::new() {
            Ok(vk) => {
                tracing::info!("Using Vulkan backend: {}", vk.device_name());
                return Arc::new(vk);
            }
            Err(e) => {
                tracing::warn!("Vulkan not available ({}), falling back to CPU", e);
            }
        }
    }

    // Only compiled when no GPU backend feature applies to this target.
    #[cfg(not(any(
        feature = "cuda",
        feature = "vulkan",
        all(feature = "metal", target_os = "macos"),
        all(feature = "dx12", target_os = "windows")
    )))]
    {
        tracing::warn!(
            "No GPU backend compiled. Build with --features cuda, --features metal, --features dx12, or --features vulkan"
        );
    }

    Arc::new(crate::backend::cpu::CpuBackend::new())
}
965
/// Returns the model configuration taken from the loader.
pub fn model_config(&self) -> &ModelConfig {
    &self.config
}

/// Returns the chat template detected at load time.
pub fn chat_template(&self) -> &ChatTemplate {
    &self.chat_template
}

/// Returns the opened GGUF file, or `None` for SafeTensors/ONNX models.
pub fn gguf(&self) -> Option<&GgufFile> {
    self.gguf.as_ref()
}

/// Returns the tokenizer loaded with the model.
pub fn tokenizer(&self) -> &Tokenizer {
    &self.tokenizer
}

/// Returns the configuration this engine was loaded with.
pub fn engine_config(&self) -> &EngineConfig {
    &self.engine_config
}

/// Returns the loaded model as a trait object.
pub fn model(&self) -> &dyn Model {
    &*self.model
}

/// Returns the shared compute backend.
pub fn backend(&self) -> &Arc<dyn Backend> {
    &self.backend
}

/// Returns whether prompts are encoded with a leading BOS token.
pub fn add_bos(&self) -> bool {
    self.add_bos
}
1005
1006 pub fn create_inference_context(&self) -> InferenceContext {
1013 if self.engine_config.kv_cache_type.is_turboquant() {
1014 InferenceContext::new_with_cache_type(
1015 &self.config,
1016 self.backend.clone(),
1017 self.engine_config.kv_cache_type,
1018 )
1019 } else {
1020 self.model.create_context(self.backend.clone())
1021 }
1022 }
1023
1024 pub fn generate(&self, prompt: &str, max_tokens: usize) -> Result<String, EngineError> {
1026 let mut ctx = self.create_inference_context();
1027 let mut sampler = Sampler::new(self.sampler_config.clone(), self.config.vocab_size);
1028
1029 let formatted = self.chat_template.wrap_prompt(prompt);
1031 let mut tokens = self.tokenizer.encode(&formatted, self.add_bos)?;
1032
1033 let mut output = String::new();
1034
1035 for _ in 0..max_tokens {
1036 if let Some(&last) = tokens.last()
1038 && last == self.tokenizer.special_tokens.eos_token_id
1039 {
1040 break;
1041 }
1042
1043 let input_tokens = if ctx.position == 0 {
1045 &tokens[..]
1046 } else {
1047 &tokens[tokens.len() - 1..]
1048 };
1049
1050 let logits = self.model.forward(input_tokens, &mut ctx)?;
1051 let next_token = sampler.sample(&logits, &tokens);
1052
1053 if next_token == self.tokenizer.special_tokens.eos_token_id {
1055 break;
1056 }
1057
1058 if let Ok(text) = self.tokenizer.decode(&[next_token]) {
1060 let combined = format!("{}{}", output, text);
1062 let stop = self
1063 .chat_template
1064 .stop_patterns()
1065 .iter()
1066 .any(|p| combined.contains(p));
1067
1068 if stop {
1069 for pattern in self.chat_template.stop_patterns() {
1071 if let Some(idx) = combined.find(pattern) {
1072 output = combined[..idx].to_string();
1073 return Ok(output.trim().to_string());
1074 }
1075 }
1076 break;
1077 }
1078
1079 output.push_str(&text);
1080 }
1081
1082 tokens.push(next_token);
1083 }
1084
1085 Ok(output.trim().to_string())
1086 }
1087
/// Returns an iterator that generates text for `prompt` incrementally,
/// yielding decoded text chunks until `max_tokens` tokens have been
/// produced, the end-of-sequence token is sampled, or a stop pattern appears.
pub fn generate_streaming(&self, prompt: &str, max_tokens: usize) -> GenerationStream<'_> {
    GenerationStream::new(self, prompt, max_tokens)
}
1095
1096 pub fn embed(&self, text: &str) -> Result<Vec<f32>, EngineError> {
1098 let mut ctx = self.create_inference_context();
1099 let embed_config = EmbeddingConfig::default();
1100 let extractor = EmbeddingExtractor::new(embed_config, &self.config);
1101 let embedding =
1102 extractor.embed_text(self.model.as_ref(), &self.tokenizer, &mut ctx, text)?;
1103 Ok(embedding)
1104 }
1105}
1106
/// Iterator state for streaming generation; created by
/// `Engine::generate_streaming`. Each call to `next` runs the model once per
/// sampled token and yields the newly decoded text.
pub struct GenerationStream<'a> {
    /// Engine providing the model, tokenizer and chat template.
    engine: &'a Engine,
    /// Inference context owned by this stream.
    ctx: InferenceContext,
    /// Sampler owned by this stream.
    sampler: Sampler,
    /// Prompt tokens followed by every token sampled so far.
    tokens: Vec<u32>,
    /// Number of tokens that may still be sampled.
    remaining: usize,
    /// Set once generation has finished (EOS, stop pattern, or model error).
    done: bool,
    /// Text seen so far, used to detect stop patterns.
    accumulated: String,
    /// Scratch buffer handed to `decode_token_streaming` on every call.
    pending_bytes: Vec<u8>,
}
1125
1126impl<'a> GenerationStream<'a> {
1127 fn new(engine: &'a Engine, prompt: &str, max_tokens: usize) -> Self {
1128 let ctx = engine.create_inference_context();
1129 let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
1130
1131 let formatted = engine.chat_template.wrap_prompt(prompt);
1132 if std::env::var("LLAMA_DEBUG").is_ok() {
1133 eprintln!("[DEBUG] formatted prompt: {:?}", formatted);
1134 eprintln!("[DEBUG] add_bos: {}", engine.add_bos);
1135 }
1136 let tokens = engine
1137 .tokenizer
1138 .encode(&formatted, engine.add_bos)
1139 .unwrap_or_default();
1140 if std::env::var("LLAMA_DEBUG").is_ok() {
1141 eprintln!("[DEBUG] encoded {} tokens: {:?}", tokens.len(), &tokens[..tokens.len().min(50)]);
1142 for (i, &tid) in tokens.iter().enumerate() {
1143 if let Some(s) = engine.tokenizer.get_token(tid) {
1144 eprintln!("[DEBUG] token[{}] = {} -> {:?}", i, tid, s);
1145 }
1146 }
1147 }
1148
1149 Self {
1150 engine,
1151 ctx,
1152 sampler,
1153 tokens,
1154 remaining: max_tokens,
1155 done: false,
1156 accumulated: String::new(),
1157 pending_bytes: Vec::new(),
1158 }
1159 }
1160}
1161
1162impl<'a> Iterator for GenerationStream<'a> {
1163 type Item = Result<String, EngineError>;
1164
1165 fn next(&mut self) -> Option<Self::Item> {
1166 if self.done || self.remaining == 0 {
1167 return None;
1168 }
1169
1170 if let Some(&last) = self.tokens.last()
1172 && last == self.engine.tokenizer.special_tokens.eos_token_id
1173 {
1174 self.done = true;
1175 return None;
1176 }
1177
1178 let input_tokens = if self.ctx.position == 0 {
1180 &self.tokens[..]
1181 } else {
1182 &self.tokens[self.tokens.len() - 1..]
1183 };
1184
1185 let logits = match self.engine.model.forward(input_tokens, &mut self.ctx) {
1186 Ok(l) => l,
1187 Err(e) => {
1188 self.done = true;
1189 return Some(Err(EngineError::Model(e)));
1190 }
1191 };
1192
1193 let next_token = self.sampler.sample(&logits, &self.tokens);
1194
1195 if std::env::var("LLAMA_DEBUG_LOGITS").is_ok() {
1196 let logit_data = logits.as_f32().unwrap();
1197 let mut indexed: Vec<(usize, f32)> = logit_data.iter().copied().enumerate().collect();
1198 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1199 let step = self.tokens.len();
1200 eprint!("[LOGIT] step={} top5:", step);
1201 for (id, score) in indexed.iter().take(5) {
1202 let tok_str = self.engine.tokenizer.get_token(*id as u32).unwrap_or_default();
1203 eprint!(" {}({:.2})={:?}", id, score, tok_str);
1204 }
1205 let chosen_str = self.engine.tokenizer.get_token(next_token).unwrap_or_default();
1206 eprintln!(" → chosen={}({:?})", next_token, chosen_str);
1207 }
1208
1209 if next_token == self.engine.tokenizer.special_tokens.eos_token_id {
1211 self.done = true;
1212 return None;
1213 }
1214
1215 match self
1217 .engine
1218 .tokenizer
1219 .decode_token_streaming(next_token, &mut self.pending_bytes)
1220 {
1221 Ok(text) => {
1222 self.tokens.push(next_token);
1223 self.remaining -= 1;
1224
1225 if text.is_empty() {
1226 return self.next();
1228 }
1229
1230 let combined = format!("{}{}", self.accumulated, text);
1232 for pattern in self.engine.chat_template.stop_patterns() {
1233 if combined.contains(pattern) {
1234 self.done = true;
1235 if let Some(idx) = combined.find(pattern) {
1236 if idx > self.accumulated.len() {
1237 let before = &combined[self.accumulated.len()..idx];
1238 return Some(Ok(before.to_string()));
1239 }
1240 }
1241 return None;
1242 }
1243 }
1244
1245 self.accumulated.push_str(&text);
1246 Some(Ok(text))
1247 }
1248 Err(e) => {
1249 self.tokens.push(next_token);
1250 self.remaining -= 1;
1251 Some(Err(EngineError::Tokenizer(e)))
1252 }
1253 }
1254 }
1255}
1256
/// Multi-turn chat session on top of an [`Engine`]. It keeps one inference
/// context and the running token history across calls to `chat`.
pub struct ChatEngine {
    /// The engine that owns the model and tokenizer.
    engine: Engine,
    /// System prompt inserted into the first turn.
    system_prompt: String,
    /// All tokens of the conversation so far (prompts and responses).
    conversation_tokens: Vec<u32>,
    /// Inference context reused across turns.
    ctx: InferenceContext,
    /// Sampler reused across turns.
    sampler: Sampler,
    /// `true` until the first call to a chat method completes.
    is_first_turn: bool,
}
1273
1274impl ChatEngine {
1275 pub fn new(engine: Engine, system_prompt: Option<String>) -> Self {
1277 let ctx = engine.create_inference_context();
1278 let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
1279
1280 Self {
1281 system_prompt: system_prompt
1282 .unwrap_or_else(|| "You are a helpful AI assistant.".to_string()),
1283 conversation_tokens: Vec::new(),
1284 ctx,
1285 sampler,
1286 is_first_turn: true,
1287 engine,
1288 }
1289 }
1290
/// Returns the underlying engine.
pub fn engine(&self) -> &Engine {
    &self.engine
}

/// Returns the system prompt used for the first turn.
pub fn system_prompt(&self) -> &str {
    &self.system_prompt
}

/// Returns the number of tokens currently held in the conversation history.
pub fn context_len(&self) -> usize {
    self.conversation_tokens.len()
}
1305
1306 pub fn chat(&mut self, message: &str) -> Result<String, EngineError> {
1308 let max_tokens = self.engine.engine_config.max_tokens;
1309
1310 let formatted = if self.is_first_turn {
1312 self.engine
1313 .chat_template
1314 .format_first_turn(&self.system_prompt, message)
1315 } else {
1316 self.engine.chat_template.format_continuation(message)
1317 };
1318
1319 let new_tokens = self
1321 .engine
1322 .tokenizer
1323 .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;
1324
1325 self.ensure_context_space(new_tokens.len(), max_tokens);
1327
1328 self.conversation_tokens.extend(&new_tokens);
1330
1331 let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
1335 let mut response_text = String::new();
1336
1337 if new_tokens.is_empty() {
1338 self.is_first_turn = false;
1339 return Ok(response_text);
1340 }
1341
1342 let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
1343 let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);
1344
1345 if first_token == eos_id {
1346 self.is_first_turn = false;
1347 return Ok(response_text);
1348 }
1349
1350 if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
1351 response_text.push_str(&text);
1352 }
1353 self.conversation_tokens.push(first_token);
1354
1355 for _ in 1..max_tokens {
1357 let should_stop = self
1359 .engine
1360 .chat_template
1361 .stop_patterns()
1362 .iter()
1363 .any(|p| response_text.contains(p));
1364 if should_stop {
1365 for pattern in self.engine.chat_template.stop_patterns() {
1366 if let Some(idx) = response_text.find(pattern) {
1367 response_text.truncate(idx);
1368 break;
1369 }
1370 }
1371 break;
1372 }
1373
1374 let last_token = *self
1375 .conversation_tokens
1376 .last()
1377 .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);
1378
1379 let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
1380 let next_token = self.sampler.sample(&logits, &self.conversation_tokens);
1381
1382 if next_token == eos_id {
1383 break;
1384 }
1385
1386 if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
1387 response_text.push_str(&text);
1388 }
1389
1390 self.conversation_tokens.push(next_token);
1391 }
1392
1393 self.is_first_turn = false;
1394 Ok(response_text.trim().to_string())
1395 }
1396
1397 pub fn chat_with_prefix(
1405 &mut self,
1406 message: &str,
1407 prefix: &str,
1408 ) -> Result<String, EngineError> {
1409 let max_tokens = self.engine.engine_config.max_tokens;
1410
1411 let formatted = if self.is_first_turn {
1412 self.engine
1413 .chat_template
1414 .format_first_turn(&self.system_prompt, message)
1415 } else {
1416 self.engine.chat_template.format_continuation(message)
1417 };
1418
1419 let formatted_with_prefix = format!("{}{}", formatted, prefix);
1421
1422 let new_tokens = self
1423 .engine
1424 .tokenizer
1425 .encode(&formatted_with_prefix, self.is_first_turn && self.engine.add_bos)?;
1426
1427 self.ensure_context_space(new_tokens.len(), max_tokens);
1428 self.conversation_tokens.extend(&new_tokens);
1429
1430 let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
1431 let mut response_text = prefix.to_string();
1432
1433 if new_tokens.is_empty() {
1434 self.is_first_turn = false;
1435 return Ok(response_text);
1436 }
1437
1438 let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
1439 let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);
1440
1441 if first_token == eos_id {
1442 self.is_first_turn = false;
1443 return Ok(response_text);
1444 }
1445
1446 if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
1447 response_text.push_str(&text);
1448 }
1449 self.conversation_tokens.push(first_token);
1450
1451 for _ in 1..max_tokens {
1452 let should_stop = self
1453 .engine
1454 .chat_template
1455 .stop_patterns()
1456 .iter()
1457 .any(|p| response_text.contains(p));
1458 if should_stop {
1459 for pattern in self.engine.chat_template.stop_patterns() {
1460 if let Some(idx) = response_text.find(pattern) {
1461 response_text.truncate(idx);
1462 break;
1463 }
1464 }
1465 break;
1466 }
1467
1468 let last_token = *self
1469 .conversation_tokens
1470 .last()
1471 .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);
1472
1473 let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
1474 let next_token = self.sampler.sample(&logits, &self.conversation_tokens);
1475
1476 if next_token == eos_id {
1477 break;
1478 }
1479
1480 if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
1481 response_text.push_str(&text);
1482 }
1483
1484 self.conversation_tokens.push(next_token);
1485 }
1486
1487 self.is_first_turn = false;
1488 Ok(response_text.trim().to_string())
1489 }
1490
1491 pub fn chat_streaming(&mut self, message: &str) -> Result<ChatStream<'_>, EngineError> {
1496 let max_tokens = self.engine.engine_config.max_tokens;
1497
1498 let formatted = if self.is_first_turn {
1500 self.engine
1501 .chat_template
1502 .format_first_turn(&self.system_prompt, message)
1503 } else {
1504 self.engine.chat_template.format_continuation(message)
1505 };
1506
1507 let new_tokens = self
1509 .engine
1510 .tokenizer
1511 .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;
1512
1513 self.ensure_context_space(new_tokens.len(), max_tokens);
1515
1516 self.conversation_tokens.extend(&new_tokens);
1518
1519 let prefill_logits = if !new_tokens.is_empty() {
1521 Some(self.engine.model.forward(&new_tokens, &mut self.ctx)?)
1522 } else {
1523 None
1524 };
1525
1526 self.is_first_turn = false;
1527
1528 Ok(ChatStream {
1529 chat_engine: self,
1530 remaining: max_tokens,
1531 done: false,
1532 accumulated: String::new(),
1533 prefill_logits,
1534 })
1535 }
1536
1537 pub fn clear_history(&mut self) {
1539 self.conversation_tokens.clear();
1540 self.ctx.reset();
1541 self.sampler.reset();
1542 self.is_first_turn = true;
1543 }
1544
1545 fn ensure_context_space(&mut self, new_token_count: usize, max_gen_tokens: usize) {
1547 let total_len = self.conversation_tokens.len() + new_token_count + max_gen_tokens;
1548
1549 if total_len > self.engine.config.max_seq_len {
1550 let excess = total_len - self.engine.config.max_seq_len + 100;
1551
1552 if excess >= self.conversation_tokens.len() {
1553 tracing::warn!("Context full, resetting conversation");
1554 self.conversation_tokens.clear();
1555 self.ctx.reset();
1556 } else {
1557 tracing::info!("Trimming {} tokens from context", excess);
1558 self.conversation_tokens = self.conversation_tokens[excess..].to_vec();
1559 self.ctx.kv_cache.shift_left(excess);
1560 self.ctx.position = self.ctx.position.saturating_sub(excess);
1561 }
1562 }
1563 }
1564}
1565
/// Iterator over the pieces of a streamed chat reply.
///
/// Created by `ChatEngine::chat_streaming`. Each call to `next` samples one
/// token and yields its decoded text; the conversation history of the borrowed
/// engine is updated as tokens are produced.
pub struct ChatStream<'a> {
    /// Chat session being driven; mutably borrowed for the stream's lifetime.
    chat_engine: &'a mut ChatEngine,
    /// Tokens that may still be yielded; the stream ends when this reaches zero.
    remaining: usize,
    /// Set once the stream has finished (EOS, stop pattern or model error).
    done: bool,
    /// Reply text yielded so far, used to detect stop patterns.
    accumulated: String,
    /// Logits from the prompt forward pass, consumed by the first `next` call.
    prefill_logits: Option<crate::tensor::Tensor>,
}
1581
1582impl<'a> Iterator for ChatStream<'a> {
1583 type Item = Result<String, EngineError>;
1584
1585 fn next(&mut self) -> Option<Self::Item> {
1586 if self.done || self.remaining == 0 {
1587 return None;
1588 }
1589
1590 for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
1592 if self.accumulated.contains(pattern) {
1593 self.done = true;
1594 return None;
1595 }
1596 }
1597
1598 let logits = if let Some(prefill) = self.prefill_logits.take() {
1601 prefill
1602 } else {
1603 let last_token = *self.chat_engine.conversation_tokens.last().unwrap_or(
1604 &self
1605 .chat_engine
1606 .engine
1607 .tokenizer
1608 .special_tokens
1609 .bos_token_id,
1610 );
1611
1612 match self
1613 .chat_engine
1614 .engine
1615 .model
1616 .forward(&[last_token], &mut self.chat_engine.ctx)
1617 {
1618 Ok(l) => l,
1619 Err(e) => {
1620 self.done = true;
1621 return Some(Err(EngineError::Model(e)));
1622 }
1623 }
1624 };
1625
1626 let next_token = self
1627 .chat_engine
1628 .sampler
1629 .sample(&logits, &self.chat_engine.conversation_tokens);
1630
1631 if next_token
1633 == self
1634 .chat_engine
1635 .engine
1636 .tokenizer
1637 .special_tokens
1638 .eos_token_id
1639 {
1640 self.done = true;
1641 return None;
1642 }
1643
1644 match self.chat_engine.engine.tokenizer.decode(&[next_token]) {
1645 Ok(text) => {
1646 let combined = format!("{}{}", self.accumulated, text);
1648 for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
1649 if combined.contains(pattern) {
1650 self.done = true;
1651 if let Some(idx) = combined.find(pattern) {
1652 let before = &combined[self.accumulated.len()..idx];
1653 self.chat_engine.conversation_tokens.push(next_token);
1654 if !before.is_empty() {
1655 return Some(Ok(before.to_string()));
1656 }
1657 }
1658 return None;
1659 }
1660 }
1661
1662 self.accumulated.push_str(&text);
1663 self.chat_engine.conversation_tokens.push(next_token);
1664 self.remaining -= 1;
1665 Some(Ok(text))
1666 }
1667 Err(e) => {
1668 self.chat_engine.conversation_tokens.push(next_token);
1669 self.remaining -= 1;
1670 Some(Err(EngineError::Tokenizer(e)))
1671 }
1672 }
1673 }
1674}