1use std::path::Path;
4
/// Threshold value of 512 associated with flash attention.
///
/// NOTE(review): presumably a sequence length (in tokens) at which a
/// flash-attention code path is selected; the constant is not referenced in
/// this part of the file, so confirm against its users.
pub const FLASH_ATTN_THRESHOLD: usize = 512;
17
18use oxillama_arch::config::ModelConfig;
19use oxillama_arch::traits::{ForwardPass, KvCacheAccess};
20use oxillama_gguf::GgufModel;
21
22use crate::embedding::{pool_hidden_states, PoolingMode};
23use crate::error::{RuntimeError, RuntimeResult};
24use crate::kv_cache::{KvCache, KvCacheSnapshot};
25use crate::metrics::{EngineMetrics, MetricsSnapshot};
26use crate::offload::{LayerPager, OffloadPolicy};
27use crate::sampling::{Sampler, SamplerConfig};
28use crate::tokenizer_bridge::TokenizerBridge;
29use std::sync::Arc;
30use std::time::Instant;
31
/// User-supplied settings for an [`InferenceEngine`].
#[derive(Debug, Clone)]
pub struct EngineConfig {
    /// Path of the GGUF model file opened by `InferenceEngine::load_model`.
    pub model_path: String,
    /// Explicit tokenizer file. When `None`, `load_tokenizer` falls back to
    /// GGUF metadata and then to a `tokenizer.json` next to the model file.
    pub tokenizer_path: Option<String>,
    /// When set, overrides the maximum context length read from the model
    /// metadata at load time.
    pub context_size: Option<usize>,
    /// Requested thread count.
    /// NOTE(review): not read anywhere in this part of the file — confirm
    /// where it is consumed.
    pub num_threads: usize,
    /// Sampler settings used by `InferenceEngine::generate`.
    pub sampler: SamplerConfig,
    /// Maximum number of prompt tokens per forward call during prefill in
    /// `generate`; `0` means the whole prompt is processed in one call.
    pub prefill_chunk_size: usize,

    /// Layer offload policy.
    /// NOTE(review): only stored in this part of the file — confirm where it
    /// is applied.
    pub offload_policy: OffloadPolicy,
}
62
63impl Default for EngineConfig {
64 fn default() -> Self {
65 Self {
66 model_path: String::new(),
67 tokenizer_path: None,
68 context_size: None,
69 num_threads: 4,
70 sampler: SamplerConfig::default(),
71 prefill_chunk_size: 512,
72 offload_policy: OffloadPolicy::None,
73 }
74 }
75}
76
77impl EngineConfig {
78 pub fn with_offload(mut self, policy: OffloadPolicy) -> Self {
81 self.offload_policy = policy;
82 self
83 }
84}
85
/// Inference runtime holding a loaded model, its tokenizer and KV cache.
///
/// All model-dependent fields are `None` until `load_model` or
/// `load_model_from_bytes` succeeds.
pub struct InferenceEngine {
    /// Configuration supplied at construction.
    config: EngineConfig,
    /// Parsed GGUF model, retained after a successful load.
    gguf_model: Option<GgufModel>,
    /// Model configuration derived from the GGUF metadata (with the optional
    /// context-size override applied).
    model_config: Option<ModelConfig>,
    /// Architecture-specific forward pass produced by `build_forward_pass`.
    forward_pass: Option<Box<dyn ForwardPass>>,
    /// Key/value cache created at load time.
    kv_cache: Option<KvCache>,
    /// Tokenizer used for encoding prompts and decoding tokens.
    tokenizer: Option<TokenizerBridge>,
    /// End-of-sequence token id reported by the tokenizer, if any.
    eos_token_id: Option<u32>,
    /// Shared metrics recorder; handed out via `metrics()`.
    metrics: Arc<EngineMetrics>,
    /// LoRA adapters queued with `push_lora` and applied by `apply_lora_stack`.
    lora_stack: oxillama_arch::LoraStack,
    /// Optional layer pager installed with `set_layer_pager`.
    layer_pager: Option<Arc<LayerPager>>,
}
121
122impl InferenceEngine {
123 pub fn new(config: EngineConfig) -> Self {
125 Self {
126 config,
127 gguf_model: None,
128 model_config: None,
129 forward_pass: None,
130 kv_cache: None,
131 tokenizer: None,
132 eos_token_id: None,
133 metrics: EngineMetrics::new(),
134 lora_stack: oxillama_arch::LoraStack::new(),
135 layer_pager: None,
136 }
137 }
138
139 pub fn layer_pager(&self) -> Option<&Arc<LayerPager>> {
146 self.layer_pager.as_ref()
147 }
148
149 pub fn set_layer_pager(&mut self, pager: Arc<LayerPager>) {
156 self.layer_pager = Some(pager);
157 }
158
159 pub fn load_model_from_bytes(
173 &mut self,
174 model_bytes: &[u8],
175 tokenizer_json: &str,
176 ) -> RuntimeResult<()> {
177 let gguf = GgufModel::from_bytes(model_bytes.to_vec())?;
179 tracing::info!(
180 arch = gguf.architecture().unwrap_or("unknown"),
181 tensors = gguf.file.header.tensor_count,
182 "GGUF file parsed from bytes"
183 );
184
185 let mut model_config = ModelConfig::from_metadata(&gguf.file.metadata)?;
187 if let Some(ctx) = self.config.context_size {
188 model_config.max_context_length = ctx;
189 }
190
191 tracing::info!(
192 arch = %model_config.architecture,
193 layers = model_config.num_layers,
194 hidden = model_config.hidden_size,
195 heads = model_config.num_attention_heads,
196 kv_heads = model_config.num_kv_heads,
197 vocab = model_config.vocab_size,
198 ctx = model_config.max_context_length,
199 "model config loaded from bytes"
200 );
201
202 let forward_pass = build_forward_pass(&gguf, &model_config)?;
204
205 let kv_dim = model_config.num_kv_heads * model_config.head_dim;
207 let kv_cache = KvCache::new(
208 model_config.num_layers,
209 model_config.max_context_length,
210 kv_dim,
211 );
212 tracing::info!(
213 layers = model_config.num_layers,
214 max_ctx = model_config.max_context_length,
215 kv_dim = kv_dim,
216 "KV cache initialized (from-bytes path)"
217 );
218
219 let tokenizer = TokenizerBridge::from_bytes(tokenizer_json.as_bytes())?;
221 let eos_token_id = tokenizer.eos_token_id();
222 tracing::info!(
223 vocab_size = tokenizer.vocab_size(),
224 eos = ?eos_token_id,
225 "tokenizer loaded from JSON string"
226 );
227
228 self.model_config = Some(model_config);
229 self.forward_pass = Some(forward_pass);
230 self.kv_cache = Some(kv_cache);
231 self.tokenizer = Some(tokenizer);
232 self.eos_token_id = eos_token_id;
233 self.gguf_model = Some(gguf);
234
235 Ok(())
236 }
237
238 pub fn load_model(&mut self) -> RuntimeResult<()> {
247 let path = Path::new(&self.config.model_path);
248 if !path.exists() {
249 return Err(RuntimeError::ModelLoadError {
250 message: format!("model file not found: {}", self.config.model_path),
251 });
252 }
253
254 tracing::info!(path = %self.config.model_path, "loading GGUF model");
255
256 let gguf = GgufModel::load(&self.config.model_path)?;
258 tracing::info!(
259 arch = gguf.architecture().unwrap_or("unknown"),
260 tensors = gguf.file.header.tensor_count,
261 "GGUF file parsed"
262 );
263
264 let mut model_config = ModelConfig::from_metadata(&gguf.file.metadata)?;
266
267 if let Some(ctx) = self.config.context_size {
269 model_config.max_context_length = ctx;
270 }
271
272 tracing::info!(
273 arch = %model_config.architecture,
274 layers = model_config.num_layers,
275 hidden = model_config.hidden_size,
276 heads = model_config.num_attention_heads,
277 kv_heads = model_config.num_kv_heads,
278 vocab = model_config.vocab_size,
279 ctx = model_config.max_context_length,
280 "model config loaded"
281 );
282
283 let forward_pass = build_forward_pass(&gguf, &model_config)?;
285
286 let kv_dim = model_config.num_kv_heads * model_config.head_dim;
288 let kv_cache = KvCache::new(
289 model_config.num_layers,
290 model_config.max_context_length,
291 kv_dim,
292 );
293 tracing::info!(
294 layers = model_config.num_layers,
295 max_ctx = model_config.max_context_length,
296 kv_dim = kv_dim,
297 "KV cache initialized"
298 );
299
300 let tokenizer = load_tokenizer(&self.config, &gguf)?;
302 let eos_token_id = tokenizer.eos_token_id();
303 tracing::info!(
304 vocab_size = tokenizer.vocab_size(),
305 eos = ?eos_token_id,
306 "tokenizer loaded"
307 );
308
309 self.model_config = Some(model_config);
310 self.forward_pass = Some(forward_pass);
311 self.kv_cache = Some(kv_cache);
312 self.tokenizer = Some(tokenizer);
313 self.eos_token_id = eos_token_id;
314 self.gguf_model = Some(gguf);
315
316 Ok(())
317 }
318
319 pub fn generate(
328 &mut self,
329 prompt: &str,
330 max_tokens: usize,
331 mut callback: impl FnMut(&str),
332 ) -> RuntimeResult<String> {
333 let tokenizer = self
334 .tokenizer
335 .as_ref()
336 .ok_or(RuntimeError::ModelNotLoaded)?;
337 let forward_pass = self
338 .forward_pass
339 .as_mut()
340 .ok_or(RuntimeError::ModelNotLoaded)?;
341 let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
342
343 let prompt_tokens = tokenizer.encode(prompt)?;
345 if prompt_tokens.is_empty() {
346 return Ok(String::new());
347 }
348
349 tracing::debug!(n_tokens = prompt_tokens.len(), "prompt tokenized");
350
351 let mut recent_tokens = prompt_tokens.clone();
353 let mut generated_tokens: Vec<u32> = Vec::new();
354 let mut output_text = String::new();
355
356 let chunk_size = if self.config.prefill_chunk_size == 0 {
366 prompt_tokens.len()
367 } else {
368 self.config.prefill_chunk_size
369 };
370
371 let mut logits = if prompt_tokens.len() <= chunk_size {
372 tracing::debug!(
374 chunk = 1,
375 tokens = prompt_tokens.len(),
376 "prefill: single batch"
377 );
378 let prefill_start = Instant::now();
379 let result = forward_pass.forward(&prompt_tokens, kv_cache)?;
380 self.metrics
381 .record_prefill(prompt_tokens.len() as u64, prefill_start.elapsed());
382 result
383 } else {
384 let n_chunks = prompt_tokens.len().div_ceil(chunk_size);
386 tracing::debug!(
387 n_chunks = n_chunks,
388 chunk_size = chunk_size,
389 total = prompt_tokens.len(),
390 "prefill: chunked"
391 );
392
393 let prefill_start = Instant::now();
394 let mut last_logits = Vec::new();
395 for (i, chunk) in prompt_tokens.chunks(chunk_size).enumerate() {
396 tracing::trace!(
397 chunk_idx = i,
398 chunk_len = chunk.len(),
399 kv_pos = kv_cache.seq_len(),
400 "prefill chunk"
401 );
402 last_logits = forward_pass.forward(chunk, kv_cache)?;
403 }
404 self.metrics
405 .record_prefill(prompt_tokens.len() as u64, prefill_start.elapsed());
406 last_logits
407 };
408
409 let mut sampler = Sampler::new(self.config.sampler.clone());
411
412 self.metrics.record_request_start();
414 for _step in 0..max_tokens {
415 let next_token = sampler.sample(&logits, &recent_tokens);
417
418 if Some(next_token) == self.eos_token_id {
420 tracing::debug!("EOS token generated, stopping");
421 break;
422 }
423
424 if kv_cache.seq_len() >= forward_pass.max_context_length() {
426 tracing::warn!("context length reached, stopping generation");
427 break;
428 }
429
430 let token_text = tokenizer.decode(&[next_token])?;
432 callback(&token_text);
433 output_text.push_str(&token_text);
434
435 recent_tokens.push(next_token);
437 generated_tokens.push(next_token);
438
439 let decode_start = Instant::now();
441 logits = forward_pass.forward(&[next_token], kv_cache)?;
442 self.metrics.record_decode_token(decode_start.elapsed());
443 }
444 self.metrics.record_request_complete();
445
446 tracing::info!(
447 prompt_tokens = prompt_tokens.len(),
448 generated_tokens = generated_tokens.len(),
449 "generation complete"
450 );
451
452 Ok(output_text)
453 }
454
455 pub fn generate_with_config(
460 &mut self,
461 prompt: &str,
462 max_tokens: usize,
463 sampler_config: SamplerConfig,
464 mut callback: impl FnMut(&str),
465 ) -> RuntimeResult<String> {
466 let tokenizer = self
467 .tokenizer
468 .as_ref()
469 .ok_or(RuntimeError::ModelNotLoaded)?;
470 let forward_pass = self
471 .forward_pass
472 .as_mut()
473 .ok_or(RuntimeError::ModelNotLoaded)?;
474 let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
475
476 let prompt_tokens = tokenizer.encode(prompt)?;
477 if prompt_tokens.is_empty() {
478 return Ok(String::new());
479 }
480
481 let mut recent_tokens = prompt_tokens.clone();
482 let mut generated_tokens: Vec<u32> = Vec::new();
483 let mut output_text = String::new();
484
485 for &token in &prompt_tokens[..prompt_tokens.len() - 1] {
486 forward_pass.forward(&[token], kv_cache)?;
487 }
488
489 let last = *prompt_tokens.last().ok_or(RuntimeError::ModelNotLoaded)?;
490 let mut logits = forward_pass.forward(&[last], kv_cache)?;
491
492 let mut sampler = Sampler::new(sampler_config);
493 self.metrics.record_request_start();
494 for _step in 0..max_tokens {
495 let next_token = sampler.sample(&logits, &recent_tokens);
496
497 if Some(next_token) == self.eos_token_id {
498 tracing::debug!("EOS token generated, stopping");
499 break;
500 }
501
502 if kv_cache.seq_len() >= forward_pass.max_context_length() {
503 tracing::warn!("context length reached, stopping generation");
504 break;
505 }
506
507 let token_text = tokenizer.decode(&[next_token])?;
508 callback(&token_text);
509 output_text.push_str(&token_text);
510
511 recent_tokens.push(next_token);
512 generated_tokens.push(next_token);
513
514 let decode_start = Instant::now();
515 logits = forward_pass.forward(&[next_token], kv_cache)?;
516 self.metrics.record_decode_token(decode_start.elapsed());
517 }
518 self.metrics.record_request_complete();
519
520 tracing::info!(
521 prompt_tokens = prompt_tokens.len(),
522 generated_tokens = generated_tokens.len(),
523 "generation (with custom config) complete"
524 );
525
526 Ok(output_text)
527 }
528
529 pub fn vocab_bytes(&self) -> Option<Vec<(u32, Vec<u8>)>> {
533 self.tokenizer.as_ref().map(|t| t.vocab_bytes())
534 }
535
536 pub fn apply_lora_adapters(
547 &mut self,
548 lora: &oxillama_arch::lora::LoadedLora,
549 ) -> RuntimeResult<()> {
550 let fp = self
551 .forward_pass
552 .as_mut()
553 .ok_or(RuntimeError::ModelNotLoaded)?;
554 fp.apply_lora(lora).map_err(RuntimeError::Arch)?;
555 Ok(())
556 }
557
558 pub fn push_lora(&mut self, lora: std::sync::Arc<oxillama_arch::lora::LoadedLora>, scale: f32) {
567 self.lora_stack.push(lora, scale);
568 }
569
570 pub fn pop_lora(&mut self) -> Option<(std::sync::Arc<oxillama_arch::lora::LoadedLora>, f32)> {
574 self.lora_stack.pop()
575 }
576
577 pub fn clear_loras(&mut self) {
579 self.lora_stack.clear();
580 }
581
582 pub fn lora_stack(&self) -> &oxillama_arch::LoraStack {
584 &self.lora_stack
585 }
586
587 pub fn apply_lora_stack(&mut self) -> RuntimeResult<()> {
594 if self.lora_stack.is_empty() {
595 return Ok(());
596 }
597 let fp = self
598 .forward_pass
599 .as_mut()
600 .ok_or(RuntimeError::ModelNotLoaded)?;
601 fp.apply_lora_stack(&self.lora_stack)
602 .map_err(RuntimeError::Arch)?;
603 Ok(())
604 }
605
606 pub fn unapply_all_loras(&mut self) {
616 self.lora_stack.clear();
617 if let Some(fp) = self.forward_pass.as_mut() {
618 fp.unapply_all_loras();
619 }
620 }
621
622 pub fn prime_with_prefix(
638 &mut self,
639 cached: &crate::kv_cache::prefix::CachedKvState,
640 restore_to: usize,
641 suffix_tokens: &[u32],
642 ) -> RuntimeResult<Vec<f32>> {
643 if suffix_tokens.is_empty() {
644 return Err(RuntimeError::ModelLoadError {
645 message: "prime_with_prefix: suffix_tokens must contain at least one token"
646 .to_string(),
647 });
648 }
649 {
653 let kv = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
654 kv.restore_from_snapshot(cached.keys(), cached.values(), restore_to);
655 }
656 let forward_pass = self
659 .forward_pass
660 .as_mut()
661 .ok_or(RuntimeError::ModelNotLoaded)?;
662 let kv = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
663 let logits = forward_pass
664 .forward(suffix_tokens, kv)
665 .map_err(RuntimeError::Arch)?;
666 Ok(logits)
667 }
668
669 pub fn generate_with_logits(
679 &mut self,
680 prompt_tokens: &[u32],
681 initial_logits: Vec<f32>,
682 max_tokens: usize,
683 sampler_config: SamplerConfig,
684 mut callback: impl FnMut(&str),
685 ) -> RuntimeResult<String> {
686 let tokenizer = self
687 .tokenizer
688 .as_ref()
689 .ok_or(RuntimeError::ModelNotLoaded)?;
690 let forward_pass = self
691 .forward_pass
692 .as_mut()
693 .ok_or(RuntimeError::ModelNotLoaded)?;
694 let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
695 let max_ctx = forward_pass.max_context_length();
696 let eos_token_id = self.eos_token_id;
697
698 let mut recent_tokens: Vec<u32> = prompt_tokens.to_vec();
699 let mut output_text = String::new();
700 let mut logits = initial_logits;
701
702 let mut sampler = Sampler::new(sampler_config);
703 self.metrics.record_request_start();
704
705 for _step in 0..max_tokens {
706 let next_token = sampler.sample(&logits, &recent_tokens);
707
708 if Some(next_token) == eos_token_id {
709 tracing::debug!("EOS token generated, stopping (primed path)");
710 break;
711 }
712
713 if kv_cache.seq_len() >= max_ctx {
714 tracing::warn!("context length reached, stopping generation (primed path)");
715 break;
716 }
717
718 let token_text = tokenizer.decode(&[next_token])?;
719 callback(&token_text);
720 output_text.push_str(&token_text);
721 recent_tokens.push(next_token);
722
723 let decode_start = Instant::now();
724 logits = forward_pass
725 .forward(&[next_token], kv_cache)
726 .map_err(RuntimeError::Arch)?;
727 self.metrics.record_decode_token(decode_start.elapsed());
728 }
729
730 self.metrics.record_request_complete();
731 Ok(output_text)
732 }
733
734 pub fn is_loaded(&self) -> bool {
736 self.forward_pass.is_some()
737 }
738
739 pub fn config(&self) -> &EngineConfig {
741 &self.config
742 }
743
744 pub fn model_config(&self) -> Option<&ModelConfig> {
746 self.model_config.as_ref()
747 }
748
749 pub(crate) fn kv_cache_ref(&self) -> Option<&KvCache> {
751 self.kv_cache.as_ref()
752 }
753
754 pub(crate) fn kv_cache_mut(&mut self) -> Option<&mut KvCache> {
756 self.kv_cache.as_mut()
757 }
758
759 pub fn store_kv_in_prefix_cache(
767 &mut self,
768 tokens: &[u32],
769 prefix_cache: &mut crate::kv_cache::prefix::PrefixKvCache,
770 ) {
771 if let Some(kv) = self.kv_cache.as_mut() {
772 let seq_len = kv.seq_len();
773 let kv_dim = kv.kv_dim();
774 let num_layers = kv.num_layers();
775 prefix_cache.store(tokens, kv, seq_len, kv_dim, num_layers);
776 }
777 }
778
779 pub fn reset(&mut self) {
781 if let Some(ref mut cache) = self.kv_cache {
782 cache.clear();
783 }
784 }
785
786 pub fn tokenize(&self, text: &str) -> RuntimeResult<Vec<u32>> {
794 let tokenizer = self
795 .tokenizer
796 .as_ref()
797 .ok_or(RuntimeError::ModelNotLoaded)?;
798 tokenizer.encode(text)
799 }
800
801 pub fn prefill(&mut self, tokens: &[u32]) -> RuntimeResult<()> {
807 if tokens.is_empty() {
808 return Ok(());
809 }
810 let forward_pass = self
811 .forward_pass
812 .as_mut()
813 .ok_or(RuntimeError::ModelNotLoaded)?;
814 let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
815 for &token in tokens {
816 forward_pass.forward(&[token], kv_cache)?;
817 }
818 Ok(())
819 }
820
821 pub fn forward_prefill(&mut self, tokens: &[u32], pos_start: usize) -> RuntimeResult<Vec<f32>> {
842 if tokens.is_empty() {
843 return Err(RuntimeError::ModelLoadError {
844 message: "forward_prefill called with empty token slice".to_string(),
845 });
846 }
847 let forward_pass = self
848 .forward_pass
849 .as_mut()
850 .ok_or(RuntimeError::ModelNotLoaded)?;
851 let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
852
853 debug_assert_eq!(
854 kv_cache.seq_len(),
855 pos_start,
856 "forward_prefill: pos_start ({pos_start}) must equal kv_cache.seq_len() ({})",
857 kv_cache.seq_len(),
858 );
859
860 let logits = forward_pass.forward(tokens, kv_cache)?;
861 Ok(logits)
862 }
863
864 pub fn forward_decode(&mut self, token: u32, pos: usize) -> RuntimeResult<Vec<f32>> {
878 let forward_pass = self
879 .forward_pass
880 .as_mut()
881 .ok_or(RuntimeError::ModelNotLoaded)?;
882 let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
883
884 debug_assert_eq!(
885 kv_cache.seq_len(),
886 pos,
887 "forward_decode: pos ({pos}) must equal kv_cache.seq_len() ({})",
888 kv_cache.seq_len(),
889 );
890
891 let logits = forward_pass.forward(&[token], kv_cache)?;
892 Ok(logits)
893 }
894
895 pub fn forward_one(&mut self, token: u32) -> RuntimeResult<Vec<f32>> {
899 let forward_pass = self
900 .forward_pass
901 .as_mut()
902 .ok_or(RuntimeError::ModelNotLoaded)?;
903 let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
904 let logits = forward_pass.forward(&[token], kv_cache)?;
905 Ok(logits)
906 }
907
908 pub fn is_eos(&self, token: u32) -> bool {
910 self.eos_token_id == Some(token)
911 }
912
913 pub fn decode_token(&self, token: u32) -> RuntimeResult<String> {
915 let tokenizer = self
916 .tokenizer
917 .as_ref()
918 .ok_or(RuntimeError::ModelNotLoaded)?;
919 tokenizer.decode(&[token])
920 }
921
922 pub fn metrics(&self) -> Arc<EngineMetrics> {
924 Arc::clone(&self.metrics)
925 }
926
927 pub fn metrics_snapshot(&self) -> MetricsSnapshot {
929 self.metrics.snapshot()
930 }
931
932 pub fn kv_snapshot(&self) -> Option<KvCacheSnapshot> {
936 self.kv_cache.as_ref().map(|c| c.snapshot())
937 }
938
939 pub fn kv_restore(&mut self, snapshot: &KvCacheSnapshot) -> RuntimeResult<()> {
943 let kv = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
944 kv.restore_from_snapshot(&snapshot.keys, &snapshot.values, snapshot.seq_len);
945 Ok(())
946 }
947
948 pub fn truncate(&mut self, n: usize) -> RuntimeResult<()> {
958 let kv = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
959 kv.truncate(n);
960 Ok(())
961 }
962
963 pub fn kv_cache_seq_len(&self) -> usize {
967 self.kv_cache.as_ref().map(|c| c.seq_len()).unwrap_or(0)
968 }
969
970 pub fn hidden_size(&self) -> Option<usize> {
972 self.model_config.as_ref().map(|c| c.hidden_size)
973 }
974
975 pub fn embed(&mut self, text: &str) -> RuntimeResult<Vec<f32>> {
984 self.embed_with(text, PoolingMode::Last)
985 }
986
987 pub fn embed_with(&mut self, text: &str, mode: PoolingMode) -> RuntimeResult<Vec<f32>> {
1004 self.reset();
1007
1008 let forward_pass = self
1010 .forward_pass
1011 .as_mut()
1012 .ok_or(RuntimeError::ModelNotLoaded)?;
1013 let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
1014
1015 let tokens = {
1018 let tok = self
1019 .tokenizer
1020 .as_ref()
1021 .ok_or(RuntimeError::ModelNotLoaded)?;
1022 tok.encode(text)?
1023 };
1024
1025 if tokens.is_empty() {
1026 let dim = self
1029 .model_config
1030 .as_ref()
1031 .map(|c| c.hidden_size)
1032 .unwrap_or(0);
1033 return Ok(vec![0.0f32; dim]);
1034 }
1035
1036 let seq_len = tokens.len();
1042 let hidden_size = forward_pass.hidden_size();
1043
1044 let all_hidden = forward_pass.embed_all(&tokens, kv_cache);
1045 let hidden = match all_hidden {
1046 Ok(states) if states.len() == seq_len * hidden_size && seq_len > 0 => {
1047 pool_hidden_states(&states, seq_len, hidden_size, mode)?
1049 }
1050 _ => {
1051 forward_pass.embed(&tokens, kv_cache)?
1053 }
1054 };
1055
1056 let norm: f32 = hidden.iter().map(|x| x * x).sum::<f32>().sqrt();
1058 if norm > 1e-9 {
1059 Ok(hidden.into_iter().map(|x| x / norm).collect())
1060 } else {
1061 Ok(hidden)
1062 }
1063 }
1064
1065 pub fn embed_batch(&mut self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
1070 let str_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
1071 self.embed_batch_with(&str_refs, PoolingMode::Last)
1072 }
1073
1074 pub fn embed_batch_with(
1082 &mut self,
1083 texts: &[&str],
1084 mode: PoolingMode,
1085 ) -> RuntimeResult<Vec<Vec<f32>>> {
1086 {
1089 let _fp = self
1090 .forward_pass
1091 .as_ref()
1092 .ok_or(RuntimeError::ModelNotLoaded)?;
1093 let _tok = self
1094 .tokenizer
1095 .as_ref()
1096 .ok_or(RuntimeError::ModelNotLoaded)?;
1097 }
1098
1099 let mut embeddings = Vec::with_capacity(texts.len());
1100 for &text in texts {
1101 embeddings.push(self.embed_with(text, mode)?);
1102 }
1103 Ok(embeddings)
1104 }
1105}
1106
1107fn build_forward_pass(
1109 gguf: &GgufModel,
1110 config: &ModelConfig,
1111) -> RuntimeResult<Box<dyn ForwardPass>> {
1112 match config.architecture.as_str() {
1113 #[cfg(feature = "llama")]
1114 "llama" => {
1115 let model = oxillama_arch::llama::load_llama_from_gguf(gguf, config)?;
1116 Ok(Box::new(model))
1117 }
1118 #[cfg(feature = "qwen3")]
1119 "qwen3" => {
1120 let model = oxillama_arch::qwen3::load_qwen3_from_gguf(gguf, config)?;
1121 Ok(Box::new(model))
1122 }
1123 #[cfg(feature = "mistral")]
1124 "mistral" => {
1125 let model = oxillama_arch::mistral::load_mistral_from_gguf(gguf, config)?;
1126 Ok(Box::new(model))
1127 }
1128 #[cfg(feature = "gemma")]
1129 "gemma" | "gemma2" | "gemma3" => {
1130 let model = oxillama_arch::gemma::load_gemma_from_gguf(gguf, config)?;
1131 Ok(Box::new(model))
1132 }
1133 #[cfg(feature = "phi")]
1134 "phi3" | "phi" => {
1135 let model = oxillama_arch::phi::load_phi_from_gguf(gguf, config)?;
1136 Ok(Box::new(model))
1137 }
1138 #[cfg(feature = "command-r")]
1139 "command-r" => {
1140 let model = oxillama_arch::command_r::load_command_r_from_gguf(gguf, config)?;
1141 Ok(Box::new(model))
1142 }
1143 #[cfg(feature = "starcoder")]
1144 "starcoder" => {
1145 let model = oxillama_arch::starcoder::load_starcoder_from_gguf(gguf, config)?;
1146 Ok(Box::new(model))
1147 }
1148 arch => Err(RuntimeError::ModelLoadError {
1149 message: format!("unsupported architecture: '{arch}'"),
1150 }),
1151 }
1152}
1153
1154fn load_tokenizer(config: &EngineConfig, gguf: &GgufModel) -> RuntimeResult<TokenizerBridge> {
1156 if let Some(ref path) = config.tokenizer_path {
1158 return TokenizerBridge::from_file(path);
1159 }
1160
1161 if let Some(tokenizer_json) = gguf
1163 .file
1164 .metadata
1165 .get("tokenizer.ggml.tokens")
1166 .and_then(|_| {
1167 gguf.file
1169 .metadata
1170 .get("tokenizer.huggingface.json")
1171 .and_then(|v| v.as_str())
1172 })
1173 {
1174 return TokenizerBridge::from_bytes(tokenizer_json.as_bytes());
1175 }
1176
1177 let model_dir = Path::new(&config.model_path)
1179 .parent()
1180 .unwrap_or(Path::new("."));
1181 let tokenizer_path = model_dir.join("tokenizer.json");
1182 if tokenizer_path.exists() {
1183 return TokenizerBridge::from_file(tokenizer_path.to_str().unwrap_or("tokenizer.json"));
1184 }
1185
1186 Err(RuntimeError::TokenizerError {
1187 message: "no tokenizer found: provide --tokenizer path or place tokenizer.json next to the model file".to_string(),
1188 })
1189}
1190
1191#[cfg(test)]
1192mod tests {
1193 use super::*;
1194
1195 #[test]
1199 fn test_forward_prefill_errors_when_not_loaded() {
1200 let mut engine = InferenceEngine::new(EngineConfig::default());
1201 let result = engine.forward_prefill(&[1, 2, 3], 0);
1202 assert!(
1203 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1204 "expected ModelNotLoaded from forward_prefill, got {result:?}"
1205 );
1206 }
1207
1208 #[test]
1211 fn test_forward_prefill_empty_slice_errors() {
1212 let mut engine = InferenceEngine::new(EngineConfig::default());
1213 let result = engine.forward_prefill(&[], 0);
1217 assert!(
1218 result.is_err(),
1219 "forward_prefill with empty slice must return Err, got Ok"
1220 );
1221 }
1222
1223 #[test]
1225 fn test_forward_decode_errors_when_not_loaded() {
1226 let mut engine = InferenceEngine::new(EngineConfig::default());
1227 let result = engine.forward_decode(42, 0);
1228 assert!(
1229 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1230 "expected ModelNotLoaded from forward_decode, got {result:?}"
1231 );
1232 }
1233
1234 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1239 #[test]
1240 fn test_forward_prefill_returns_logits_after_load() {
1241 let mut engine = make_loaded_engine();
1242 let result = engine.forward_prefill(&[3, 4, 5], 0);
1244 assert!(
1245 result.is_ok(),
1246 "forward_prefill must return Ok when model is loaded, got {result:?}"
1247 );
1248 let logits = result.expect("forward_prefill Ok");
1249 assert_eq!(
1250 logits.len(),
1251 32,
1252 "logits length must equal vocab_size=32, got {}",
1253 logits.len()
1254 );
1255 }
1256
1257 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1260 #[test]
1261 fn test_forward_decode_returns_logits_after_load() {
1262 let mut engine = make_loaded_engine();
1263 engine
1265 .forward_prefill(&[3], 0)
1266 .expect("prefill must succeed");
1267 let result = engine.forward_decode(4, 1);
1269 assert!(
1270 result.is_ok(),
1271 "forward_decode must return Ok when model is loaded, got {result:?}"
1272 );
1273 let logits = result.expect("forward_decode Ok");
1274 assert_eq!(
1275 logits.len(),
1276 32,
1277 "logits length must equal vocab_size=32, got {}",
1278 logits.len()
1279 );
1280 }
1281
1282 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1288 #[test]
1289 fn chunked_prefill_kv_matches_singleshot() {
1290 let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
1291 let tokenizer_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
1292 let prompt_tokens = vec![3u32, 4, 5, 6];
1293
1294 let mut engine_single = InferenceEngine::new(EngineConfig::default());
1296 engine_single
1297 .load_model_from_bytes(&model_bytes, tokenizer_json)
1298 .expect("single-shot load");
1299 let logits_single = engine_single
1301 .forward_prefill(&prompt_tokens, 0)
1302 .expect("single-shot prefill");
1303
1304 let mut engine_chunked = InferenceEngine::new(EngineConfig::default());
1306 engine_chunked
1307 .load_model_from_bytes(&model_bytes, tokenizer_json)
1308 .expect("chunked load");
1309
1310 let mut logits_chunked = Vec::new();
1311 let chunk_size = 2usize;
1312 let mut pos = 0usize;
1313 for slice in prompt_tokens.chunks(chunk_size) {
1314 logits_chunked = engine_chunked
1315 .forward_prefill(slice, pos)
1316 .expect("chunked prefill");
1317 pos += slice.len();
1318 }
1319
1320 assert_eq!(
1322 logits_single.len(),
1323 logits_chunked.len(),
1324 "logit vector lengths must match"
1325 );
1326 let tol = 1e-4f32;
1327 let max_diff = logits_single
1328 .iter()
1329 .zip(logits_chunked.iter())
1330 .map(|(a, b)| (a - b).abs())
1331 .fold(0.0f32, f32::max);
1332 assert!(
1333 max_diff < tol,
1334 "chunked and single-shot prefill logits differ by {max_diff} > tolerance {tol}"
1335 );
1336 }
1337
1338 #[test]
1343 fn test_embed_returns_err_when_not_loaded() {
1344 let mut engine = InferenceEngine::new(EngineConfig::default());
1345 let result = engine.embed("hello world");
1346 assert!(
1347 result.is_err(),
1348 "embed() should return Err when no model is loaded"
1349 );
1350 }
1351
1352 #[test]
1354 fn test_hidden_size_none_when_not_loaded() {
1355 let engine = InferenceEngine::new(EngineConfig::default());
1356 assert!(
1357 engine.hidden_size().is_none(),
1358 "hidden_size() should be None before load_model()"
1359 );
1360 }
1361
1362 #[test]
1364 fn test_is_loaded_false_initially() {
1365 let engine = InferenceEngine::new(EngineConfig::default());
1366 assert!(!engine.is_loaded());
1367 }
1368
1369 #[test]
1370 fn test_model_config_none_when_not_loaded() {
1371 let engine = InferenceEngine::new(EngineConfig::default());
1372 assert!(engine.model_config().is_none());
1373 }
1374
1375 #[test]
1376 fn test_config_roundtrip() {
1377 let cfg = EngineConfig {
1378 model_path: "test.gguf".to_string(),
1379 num_threads: 8,
1380 ..EngineConfig::default()
1381 };
1382 let engine = InferenceEngine::new(cfg);
1383 assert_eq!(engine.config().model_path, "test.gguf");
1384 assert_eq!(engine.config().num_threads, 8);
1385 }
1386
1387 #[test]
1388 fn test_generate_errors_when_not_loaded() {
1389 let mut engine = InferenceEngine::new(EngineConfig::default());
1390 let result = engine.generate("hello", 10, |_| {});
1391 assert!(
1392 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1393 "expected ModelNotLoaded, got {result:?}"
1394 );
1395 }
1396
1397 #[test]
1398 fn test_generate_with_config_errors_when_not_loaded() {
1399 let mut engine = InferenceEngine::new(EngineConfig::default());
1400 let result = engine.generate_with_config("hello", 5, SamplerConfig::greedy(), |_| {});
1401 assert!(
1402 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1403 "expected ModelNotLoaded, got {result:?}"
1404 );
1405 }
1406
1407 #[test]
1408 fn test_tokenize_errors_when_not_loaded() {
1409 let engine = InferenceEngine::new(EngineConfig::default());
1410 let result = engine.tokenize("hello world");
1411 assert!(
1412 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1413 "expected ModelNotLoaded, got {result:?}"
1414 );
1415 }
1416
1417 #[test]
1418 fn test_prefill_errors_when_not_loaded() {
1419 let mut engine = InferenceEngine::new(EngineConfig::default());
1420 let result = engine.prefill(&[1, 2, 3]);
1421 assert!(
1422 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1423 "expected ModelNotLoaded, got {result:?}"
1424 );
1425 }
1426
1427 #[test]
1428 fn test_prefill_empty_slice_ok_when_no_model() {
1429 let mut engine = InferenceEngine::new(EngineConfig::default());
1430 let result = engine.prefill(&[]);
1432 assert!(result.is_ok(), "empty prefill should be Ok, got {result:?}");
1433 }
1434
1435 #[test]
1436 fn test_forward_one_errors_when_not_loaded() {
1437 let mut engine = InferenceEngine::new(EngineConfig::default());
1438 let result = engine.forward_one(42);
1439 assert!(
1440 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1441 "expected ModelNotLoaded, got {result:?}"
1442 );
1443 }
1444
1445 #[test]
1446 fn test_decode_token_errors_when_not_loaded() {
1447 let engine = InferenceEngine::new(EngineConfig::default());
1448 let result = engine.decode_token(1);
1449 assert!(
1450 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1451 "expected ModelNotLoaded, got {result:?}"
1452 );
1453 }
1454
1455 #[test]
1456 fn test_is_eos_false_when_not_loaded() {
1457 let engine = InferenceEngine::new(EngineConfig::default());
1458 assert!(!engine.is_eos(0));
1459 assert!(!engine.is_eos(u32::MAX));
1460 }
1461
1462 #[test]
1463 fn test_vocab_bytes_none_when_not_loaded() {
1464 let engine = InferenceEngine::new(EngineConfig::default());
1465 assert!(engine.vocab_bytes().is_none());
1466 }
1467
1468 #[test]
1469 fn test_reset_does_not_panic_when_no_kv_cache() {
1470 let mut engine = InferenceEngine::new(EngineConfig::default());
1471 engine.reset(); }
1473
1474 #[test]
1475 fn test_apply_lora_adapters_errors_when_not_loaded() {
1476 use oxillama_arch::lora::LoadedLora;
1477 let mut engine = InferenceEngine::new(EngineConfig::default());
1478 let lora = LoadedLora {
1479 rank: 8,
1480 alpha: 1.0,
1481 adapters: std::collections::HashMap::new(),
1482 };
1483 let result = engine.apply_lora_adapters(&lora);
1484 assert!(
1485 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1486 "expected ModelNotLoaded, got {result:?}"
1487 );
1488 }
1489
1490 #[test]
1491 fn test_load_model_missing_file_errors() {
1492 let cfg = EngineConfig {
1493 model_path: "/nonexistent/path/model_abc_xyz.gguf".to_string(),
1494 ..EngineConfig::default()
1495 };
1496 let mut engine = InferenceEngine::new(cfg);
1497 let result = engine.load_model();
1498 assert!(
1499 matches!(result, Err(RuntimeError::ModelLoadError { .. })),
1500 "expected ModelLoadError for missing file, got {result:?}"
1501 );
1502 }
1503
1504 #[test]
1505 fn test_load_model_from_bytes_bad_magic_errors() {
1506 let cfg = EngineConfig::default();
1507 let mut engine = InferenceEngine::new(cfg);
1508 let bad_bytes = b"THIS IS NOT A GGUF FILE AT ALL";
1510 let result = engine.load_model_from_bytes(bad_bytes, "{}");
1511 assert!(
1512 result.is_err(),
1513 "load_model_from_bytes with garbage bytes should error, got Ok(())"
1514 );
1515 }
1516
1517 #[test]
1518 fn test_load_model_from_bytes_empty_errors() {
1519 let cfg = EngineConfig::default();
1520 let mut engine = InferenceEngine::new(cfg);
1521 let result = engine.load_model_from_bytes(&[], "{}");
1522 assert!(
1523 result.is_err(),
1524 "load_model_from_bytes with empty bytes should error"
1525 );
1526 }
1527
1528 #[test]
1529 fn test_engine_config_default_fields() {
1530 let cfg = EngineConfig::default();
1531 assert!(
1532 cfg.model_path.is_empty(),
1533 "default model_path should be empty"
1534 );
1535 assert!(
1536 cfg.tokenizer_path.is_none(),
1537 "default tokenizer_path should be None"
1538 );
1539 assert!(
1540 cfg.context_size.is_none(),
1541 "default context_size should be None"
1542 );
1543 assert_eq!(cfg.num_threads, 4, "default num_threads should be 4");
1544 }
1545
1546 #[test]
1547 fn test_engine_config_context_override() {
1548 let cfg = EngineConfig {
1549 context_size: Some(2048),
1550 ..EngineConfig::default()
1551 };
1552 assert_eq!(cfg.context_size, Some(2048));
1553 }
1554
1555 #[test]
1556 fn test_generate_with_config_errors_when_not_loaded_variant() {
1557 let mut engine = InferenceEngine::new(EngineConfig::default());
1559 let sc = SamplerConfig {
1560 temperature: 0.7,
1561 top_k: 40,
1562 ..SamplerConfig::default()
1563 };
1564 let result = engine.generate_with_config("test prompt", 5, sc, |_| {});
1565 assert!(
1566 matches!(result, Err(RuntimeError::ModelNotLoaded)),
1567 "expected ModelNotLoaded, got {result:?}"
1568 );
1569 }
1570
1571 #[test]
1574 fn test_load_model_existing_invalid_file_errors() {
1575 let mut tmp = std::env::temp_dir();
1576 tmp.push("oxillama_engine_bad_magic_test.gguf");
1577 std::fs::write(&tmp, b"NOT A GGUF FILE AT ALL - GARBAGE BYTES 0123456789")
1579 .expect("write temp file");
1580 let cfg = EngineConfig {
1581 model_path: tmp
1582 .to_str()
1583 .expect("temp path must be valid UTF-8")
1584 .to_string(),
1585 ..EngineConfig::default()
1586 };
1587 let mut engine = InferenceEngine::new(cfg);
1588 let result = engine.load_model();
1589 let _ = std::fs::remove_file(&tmp);
1591 assert!(
1592 result.is_err(),
1593 "load_model with invalid GGUF content should return Err"
1594 );
1595 }
1596
1597 #[test]
1599 fn test_is_loaded_remains_false_after_failed_load() {
1600 let cfg = EngineConfig {
1601 model_path: "/nonexistent/guaranteed_missing_model.gguf".to_string(),
1602 ..EngineConfig::default()
1603 };
1604 let mut engine = InferenceEngine::new(cfg);
1605 let _ = engine.load_model();
1607 assert!(
1608 !engine.is_loaded(),
1609 "is_loaded() must be false after a failed load_model()"
1610 );
1611 }
1612
1613 #[test]
1615 fn test_engine_config_clone_is_independent() {
1616 let original = EngineConfig {
1617 model_path: "original.gguf".to_string(),
1618 num_threads: 16,
1619 context_size: Some(4096),
1620 ..EngineConfig::default()
1621 };
1622 let mut cloned = original.clone();
1623 cloned.model_path = "cloned.gguf".to_string();
1624 cloned.num_threads = 1;
1625 assert_eq!(original.model_path, "original.gguf");
1627 assert_eq!(original.num_threads, 16);
1628 assert_eq!(original.context_size, Some(4096));
1629 }
1630
1631 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1638 fn make_loaded_engine() -> InferenceEngine {
1639 let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
1640 let tokenizer_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
1641 let mut engine = InferenceEngine::new(EngineConfig::default());
1642 engine
1643 .load_model_from_bytes(&model_bytes, tokenizer_json)
1644 .expect("synthetic GGUF must load successfully");
1645 engine
1646 }
1647
1648 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1651 #[test]
1652 fn test_load_model_from_bytes_succeeds() {
1653 let engine = make_loaded_engine();
1654 assert!(
1655 engine.is_loaded(),
1656 "is_loaded() must be true after a successful load_model_from_bytes()"
1657 );
1658 }
1659
1660 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1662 #[test]
1663 fn test_hidden_size_after_load() {
1664 let engine = make_loaded_engine();
1665 let hs = engine.hidden_size();
1666 assert_eq!(
1667 hs,
1668 Some(32),
1669 "hidden_size() must be Some(32) after loading the synthetic model, got {hs:?}"
1670 );
1671 }
1672
1673 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1675 #[test]
1676 fn test_tokenize_after_load() {
1677 let engine = make_loaded_engine();
1678 let result = engine.tokenize("abc");
1679 assert!(
1680 result.is_ok(),
1681 "tokenize() must return Ok after model is loaded, got {result:?}"
1682 );
1683 let tokens = result.expect("tokenize succeeded");
1684 assert!(
1685 !tokens.is_empty(),
1686 "tokenize('abc') must produce at least one token"
1687 );
1688 }
1689
1690 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1692 #[test]
1693 fn test_is_eos_after_load() {
1694 let engine = make_loaded_engine();
1695 assert!(
1696 engine.is_eos(2),
1697 "is_eos(2) must be true — </s> is the EOS token in the synthetic tokenizer"
1698 );
1699 assert!(
1700 !engine.is_eos(3),
1701 "is_eos(3) must be false — token 3 ('a') is not EOS"
1702 );
1703 }
1704
1705 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1707 #[test]
1708 fn test_decode_token_after_load() {
1709 let engine = make_loaded_engine();
1710 let result = engine.decode_token(3);
1711 assert!(
1712 result.is_ok(),
1713 "decode_token(3) must return Ok, got {result:?}"
1714 );
1715 }
1716
1717 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1720 #[test]
1721 fn test_generate_after_load() {
1722 let mut engine = make_loaded_engine();
1723 let result = engine.generate("a", 3, |_| {});
1724 assert!(
1725 result.is_ok(),
1726 "generate() must return Ok after model is loaded, got {result:?}"
1727 );
1728 }
1729
1730 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1732 #[test]
1733 fn test_generate_respects_max_tokens() {
1734 let mut engine = make_loaded_engine();
1735 let max = 5usize;
1736 let mut count = 0usize;
1738 let result = engine.generate("a", max, |_tok| {
1739 count += 1;
1740 });
1741 assert!(result.is_ok(), "generate() must return Ok, got {result:?}");
1742 assert!(
1743 count <= max,
1744 "callback was invoked {count} times but max_tokens={max}"
1745 );
1746 }
1747
1748 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1750 #[test]
1751 fn test_generate_streaming_calls_callback() {
1752 let mut engine = make_loaded_engine();
1753 let mut invocations = 0usize;
1754 let max_tokens = 4;
1755 let result = engine.generate("a", max_tokens, |_piece| {
1756 invocations += 1;
1757 });
1758 assert!(
1759 result.is_ok(),
1760 "generate() streaming path must return Ok, got {result:?}"
1761 );
1762 assert!(
1764 invocations <= max_tokens,
1765 "streaming callback fired {invocations} > max_tokens={max_tokens}"
1766 );
1767 }
1768
1769 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1771 #[test]
1772 fn test_embed_after_load() {
1773 let mut engine = make_loaded_engine();
1774 let result = engine.embed("a");
1775 assert!(
1776 result.is_ok(),
1777 "embed() must return Ok after model is loaded, got {result:?}"
1778 );
1779 let vec = result.expect("embed succeeded");
1780 assert!(!vec.is_empty(), "embed() must return a non-empty vector");
1781 }
1782
1783 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1785 #[test]
1786 fn test_embed_returns_hidden_size_vector() {
1787 let mut engine = make_loaded_engine();
1788 let vec = engine
1789 .embed("a")
1790 .expect("embed() must succeed after loading");
1791 assert_eq!(
1792 vec.len(),
1793 32,
1794 "embed() vector length must equal hidden_size=32, got {}",
1795 vec.len()
1796 );
1797 }
1798
1799 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1801 #[test]
1802 fn test_reload_model_succeeds() {
1803 let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
1804 let tokenizer_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
1805 let mut engine = InferenceEngine::new(EngineConfig::default());
1806
1807 engine
1809 .load_model_from_bytes(&model_bytes, tokenizer_json)
1810 .expect("first load must succeed");
1811 assert!(engine.is_loaded(), "is_loaded() after first load");
1812
1813 engine
1815 .load_model_from_bytes(&model_bytes, tokenizer_json)
1816 .expect("second (re)load must succeed");
1817 assert!(
1818 engine.is_loaded(),
1819 "is_loaded() after reload must still be true"
1820 );
1821 }
1822
1823 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1825 #[test]
1826 fn test_vocab_bytes_some_after_load() {
1827 let engine = make_loaded_engine();
1828 let vb = engine.vocab_bytes();
1829 assert!(
1830 vb.is_some(),
1831 "vocab_bytes() must be Some after model is loaded"
1832 );
1833 let entries = vb.expect("vocab_bytes is Some");
1834 assert!(
1835 !entries.is_empty(),
1836 "vocab_bytes() must contain at least one entry"
1837 );
1838 }
1839
1840 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1842 #[test]
1843 fn test_model_config_some_after_load() {
1844 let engine = make_loaded_engine();
1845 let cfg = engine.model_config();
1846 assert!(cfg.is_some(), "model_config() must be Some after loading");
1847 let mc = cfg.expect("model_config is Some");
1848 assert_eq!(mc.architecture, "llama", "architecture must be 'llama'");
1849 assert_eq!(
1850 mc.num_layers, 1,
1851 "num_layers must be 1 for the synthetic model"
1852 );
1853 assert_eq!(mc.vocab_size, 32, "vocab_size must be 32");
1854 }
1855
1856 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1858 #[test]
1859 fn test_reset_when_loaded_does_not_panic() {
1860 let mut engine = make_loaded_engine();
1861 engine.reset(); assert!(
1864 engine.is_loaded(),
1865 "is_loaded() must still be true after reset()"
1866 );
1867 assert_eq!(engine.hidden_size(), Some(32));
1868 }
1869
1870 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1880 #[test]
1881 fn test_generate_qwen3_arch() {
1882 use oxillama_gguf::test_utils::{build_minimal_qwen3_gguf, minimal_tokenizer_json};
1883
1884 let bytes = build_minimal_qwen3_gguf();
1885 let json = minimal_tokenizer_json();
1886 let mut engine = InferenceEngine::new(EngineConfig::default());
1887 engine
1888 .load_model_from_bytes(&bytes, json)
1889 .expect("test: load qwen3");
1890 assert!(engine.is_loaded(), "qwen3: is_loaded() must be true");
1891 let _out = engine
1892 .generate("abc", 2, |_| {})
1893 .expect("test: generate qwen3");
1894 }
1895
1896 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1898 #[test]
1899 fn test_embed_qwen3_arch() {
1900 use oxillama_gguf::test_utils::{build_minimal_qwen3_gguf, minimal_tokenizer_json};
1901
1902 let bytes = build_minimal_qwen3_gguf();
1903 let json = minimal_tokenizer_json();
1904 let mut engine = InferenceEngine::new(EngineConfig::default());
1905 engine
1906 .load_model_from_bytes(&bytes, json)
1907 .expect("test: load qwen3 for embed");
1908 let vec = engine.embed("abc").expect("test: embed qwen3");
1909 assert_eq!(
1910 vec.len(),
1911 32,
1912 "qwen3 embed must return hidden_size=32 vector"
1913 );
1914 }
1915
1916 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1918 #[test]
1919 fn test_generate_mistral_arch() {
1920 use oxillama_gguf::test_utils::{build_minimal_mistral_gguf, minimal_tokenizer_json};
1921
1922 let bytes = build_minimal_mistral_gguf();
1923 let json = minimal_tokenizer_json();
1924 let mut engine = InferenceEngine::new(EngineConfig::default());
1925 engine
1926 .load_model_from_bytes(&bytes, json)
1927 .expect("test: load mistral");
1928 assert!(engine.is_loaded(), "mistral: is_loaded() must be true");
1929 let _out = engine
1930 .generate("abc", 2, |_| {})
1931 .expect("test: generate mistral");
1932 }
1933
1934 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1936 #[test]
1937 fn test_generate_gemma_arch() {
1938 use oxillama_gguf::test_utils::{build_minimal_gemma_gguf, minimal_tokenizer_json};
1939
1940 let bytes = build_minimal_gemma_gguf();
1941 let json = minimal_tokenizer_json();
1942 let mut engine = InferenceEngine::new(EngineConfig::default());
1943 engine
1944 .load_model_from_bytes(&bytes, json)
1945 .expect("test: load gemma");
1946 assert!(engine.is_loaded(), "gemma: is_loaded() must be true");
1947 let _out = engine
1948 .generate("abc", 2, |_| {})
1949 .expect("test: generate gemma");
1950 }
1951
1952 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1954 #[test]
1955 fn test_embed_gemma_arch() {
1956 use oxillama_gguf::test_utils::{build_minimal_gemma_gguf, minimal_tokenizer_json};
1957
1958 let bytes = build_minimal_gemma_gguf();
1959 let json = minimal_tokenizer_json();
1960 let mut engine = InferenceEngine::new(EngineConfig::default());
1961 engine
1962 .load_model_from_bytes(&bytes, json)
1963 .expect("test: load gemma for embed");
1964 let vec = engine.embed("abc").expect("test: embed gemma");
1965 assert_eq!(
1966 vec.len(),
1967 32,
1968 "gemma embed must return hidden_size=32 vector"
1969 );
1970 }
1971
1972 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1974 #[test]
1975 fn test_generate_phi3_arch() {
1976 use oxillama_gguf::test_utils::{build_minimal_phi3_gguf, minimal_tokenizer_json};
1977
1978 let bytes = build_minimal_phi3_gguf();
1979 let json = minimal_tokenizer_json();
1980 let mut engine = InferenceEngine::new(EngineConfig::default());
1981 engine
1982 .load_model_from_bytes(&bytes, json)
1983 .expect("test: load phi3");
1984 assert!(engine.is_loaded(), "phi3: is_loaded() must be true");
1985 let _out = engine
1986 .generate("abc", 2, |_| {})
1987 .expect("test: generate phi3");
1988 }
1989
1990 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
1992 #[test]
1993 fn test_generate_command_r_arch() {
1994 use oxillama_gguf::test_utils::{build_minimal_command_r_gguf, minimal_tokenizer_json};
1995
1996 let bytes = build_minimal_command_r_gguf();
1997 let json = minimal_tokenizer_json();
1998 let mut engine = InferenceEngine::new(EngineConfig::default());
1999 engine
2000 .load_model_from_bytes(&bytes, json)
2001 .expect("test: load command-r");
2002 assert!(engine.is_loaded(), "command-r: is_loaded() must be true");
2003 let _out = engine
2004 .generate("abc", 2, |_| {})
2005 .expect("test: generate command-r");
2006 }
2007
2008 #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
2010 #[test]
2011 fn test_generate_starcoder_arch() {
2012 use oxillama_gguf::test_utils::{build_minimal_starcoder_gguf, minimal_tokenizer_json};
2013
2014 let bytes = build_minimal_starcoder_gguf();
2015 let json = minimal_tokenizer_json();
2016 let mut engine = InferenceEngine::new(EngineConfig::default());
2017 engine
2018 .load_model_from_bytes(&bytes, json)
2019 .expect("test: load starcoder");
2020 assert!(engine.is_loaded(), "starcoder: is_loaded() must be true");
2021 let _out = engine
2022 .generate("abc", 2, |_| {})
2023 .expect("test: generate starcoder");
2024 }
2025
2026 #[test]
2031 fn lora_stack_push_pop() {
2032 use oxillama_arch::lora::LoadedLora;
2033 use oxillama_quant::LoraAdapter;
2034 use std::collections::HashMap;
2035 use std::sync::Arc;
2036
2037 fn make_lora() -> Arc<LoadedLora> {
2038 let adapter = LoraAdapter::new(vec![0.0f32; 4 * 8], vec![0.0f32; 8 * 4], 4, 1.0, 8, 8)
2039 .expect("valid lora adapter");
2040 let mut adapters = HashMap::new();
2041 adapters.insert("test.weight".to_string(), Arc::new(adapter));
2042 Arc::new(LoadedLora {
2043 adapters,
2044 rank: 4,
2045 alpha: 1.0,
2046 })
2047 }
2048
2049 let mut engine = InferenceEngine::new(EngineConfig::default());
2050
2051 assert!(engine.lora_stack().is_empty());
2053 assert_eq!(engine.lora_stack().len(), 0);
2054
2055 engine.push_lora(make_lora(), 1.0);
2057 engine.push_lora(make_lora(), 0.5);
2058 assert_eq!(engine.lora_stack().len(), 2);
2059 assert!(!engine.lora_stack().is_empty());
2060
2061 let popped = engine.pop_lora();
2063 assert!(popped.is_some());
2064 let (_, scale) = popped.expect("pop must return Some");
2065 assert!((scale - 0.5).abs() < 1e-6);
2066 assert_eq!(engine.lora_stack().len(), 1);
2067
2068 engine.clear_loras();
2070 assert!(engine.lora_stack().is_empty());
2071
2072 assert!(engine.pop_lora().is_none());
2074 }
2075
2076 #[test]
2077 fn lora_apply_stack_errors_when_not_loaded() {
2078 use oxillama_arch::lora::LoadedLora;
2079 use oxillama_quant::LoraAdapter;
2080 use std::collections::HashMap;
2081 use std::sync::Arc;
2082
2083 let adapter = LoraAdapter::new(vec![0.0f32; 4 * 8], vec![0.0f32; 8 * 4], 4, 1.0, 8, 8)
2084 .expect("valid lora adapter");
2085 let mut adapters = HashMap::new();
2086 adapters.insert("test.weight".to_string(), Arc::new(adapter));
2087 let lora = Arc::new(LoadedLora {
2088 adapters,
2089 rank: 4,
2090 alpha: 1.0,
2091 });
2092
2093 let mut engine = InferenceEngine::new(EngineConfig::default());
2094 engine.push_lora(lora, 1.0);
2095 let result = engine.apply_lora_stack();
2096 assert!(
2097 matches!(result, Err(RuntimeError::ModelNotLoaded)),
2098 "expected ModelNotLoaded, got {:?}",
2099 result
2100 );
2101 }
2102
2103 #[test]
2105 fn unapply_all_loras_noop_when_unloaded() {
2106 let mut engine = InferenceEngine::new(EngineConfig::default());
2107 engine.unapply_all_loras(); assert!(!engine.is_loaded());
2109 }
2110
    /// `prime_with_prefix` on an engine without a loaded model must report
    /// `RuntimeError::ModelNotLoaded`.
    ///
    /// The test stores a KV cache for the token sequence `[1, 2, 3]` in a
    /// `PrefixKvCache`, looks the same sequence up again, and hands the
    /// cached entry to the unloaded engine.
    ///
    /// NOTE(review): the assertion only runs when `lookup` returns `Some`;
    /// if it returns `None` the test passes without checking anything.
    /// Confirm that `store` followed by `lookup` is guaranteed to hit here.
    #[test]
    fn prime_with_prefix_returns_model_not_loaded() {
        use crate::kv_cache::prefix::{PrefixCacheConfig, PrefixKvCache};
        use crate::kv_cache::KvCache;

        let mut engine = InferenceEngine::new(EngineConfig::default());

        let mut prefix_cache = PrefixKvCache::new(PrefixCacheConfig {
            max_entries: 16,
            max_memory_bytes: 1024 * 1024,
            min_prefix_len: 1,
        });
        // NOTE(review): the meaning of the numeric arguments to
        // `KvCache::new` and `store` is not visible from this test —
        // verify against their definitions.
        let kv = KvCache::new(1, 4, 32);
        let tokens: Vec<u32> = vec![1, 2, 3];
        prefix_cache.store(&tokens, &kv, 3, 4, 1);

        if let Some((match_len, cached)) = prefix_cache.lookup(&tokens) {
            // Clamp the start index to the last token so the suffix slice
            // always contains at least one element.
            let suffix = &tokens[match_len.min(tokens.len() - 1)..];
            // `saturating_sub` keeps the second argument from underflowing
            // when `match_len` is 0.
            let result = engine.prime_with_prefix(cached, match_len.saturating_sub(1), suffix);
            assert!(
                matches!(result, Err(RuntimeError::ModelNotLoaded)),
                "unloaded engine must return ModelNotLoaded, got {:?}",
                result
            );
        }
    }
2143}