use std::path::Path;
use oxillama_arch::config::ModelConfig;
use oxillama_arch::traits::{ForwardPass, KvCacheAccess};
use oxillama_gguf::GgufModel;
use crate::error::{RuntimeError, RuntimeResult};
use crate::kv_cache::KvCache;
use crate::sampling::{Sampler, SamplerConfig};
use crate::tokenizer_bridge::TokenizerBridge;
/// Settings that control how an [`InferenceEngine`] loads a model and runs
/// generation.
#[derive(Debug, Clone)]
pub struct EngineConfig {
    /// Path to the GGUF model file read by `InferenceEngine::load_model`.
    /// Its parent directory is also searched for a sibling `tokenizer.json`
    /// when no other tokenizer source is available.
    pub model_path: String,
    /// Explicit tokenizer file. When set, `load_model` uses it in preference to
    /// any tokenizer embedded in the GGUF metadata or found next to the model.
    /// Not consulted by `load_model_from_bytes`, which takes the tokenizer JSON
    /// directly.
    pub tokenizer_path: Option<String>,
    /// Overrides the model's `max_context_length` (and therefore the KV cache
    /// capacity) when `Some`; `None` keeps the value from the GGUF metadata.
    pub context_size: Option<usize>,
    /// Requested worker thread count.
    /// NOTE(review): not read anywhere in this file — confirm it is consumed
    /// elsewhere, otherwise setting it has no effect.
    pub num_threads: usize,
    /// Sampling parameters used by `InferenceEngine::generate`.
    pub sampler: SamplerConfig,
    /// Maximum number of prompt tokens passed to the forward pass per call when
    /// a prompt is prefilled in chunks; `0` disables chunking so the whole
    /// prompt goes through in a single call.
    pub prefill_chunk_size: usize,
}
impl Default for EngineConfig {
fn default() -> Self {
Self {
model_path: String::new(),
tokenizer_path: None,
context_size: None,
num_threads: 4,
sampler: SamplerConfig::default(),
prefill_chunk_size: 512,
}
}
}
/// An inference session holding one loaded model, its KV cache and its
/// tokenizer.
///
/// The `Option` fields start as `None` and are all assigned together at the
/// end of a successful load (`eos_token_id` stays `None` if the tokenizer
/// reports no EOS token).
pub struct InferenceEngine {
    /// Configuration supplied to `InferenceEngine::new`.
    config: EngineConfig,
    /// Parsed GGUF file for the loaded model.
    /// NOTE(review): not read by any method in this file — confirm why it is
    /// retained after the forward pass has been built.
    gguf_model: Option<GgufModel>,
    /// Hyper-parameters read from GGUF metadata, with `context_size` applied.
    model_config: Option<ModelConfig>,
    /// Architecture-specific model implementation.
    forward_pass: Option<Box<dyn ForwardPass>>,
    /// Key/value cache sized from `model_config` (layers × context × kv_dim).
    kv_cache: Option<KvCache>,
    /// Tokenizer used to encode prompts and decode generated tokens.
    tokenizer: Option<TokenizerBridge>,
    /// EOS id reported by the tokenizer; sampling it ends generation.
    eos_token_id: Option<u32>,
}
impl InferenceEngine {
pub fn new(config: EngineConfig) -> Self {
Self {
config,
gguf_model: None,
model_config: None,
forward_pass: None,
kv_cache: None,
tokenizer: None,
eos_token_id: None,
}
}
pub fn load_model_from_bytes(
&mut self,
model_bytes: &[u8],
tokenizer_json: &str,
) -> RuntimeResult<()> {
let gguf = GgufModel::from_bytes(model_bytes.to_vec())?;
tracing::info!(
arch = gguf.architecture().unwrap_or("unknown"),
tensors = gguf.file.header.tensor_count,
"GGUF file parsed from bytes"
);
let mut model_config = ModelConfig::from_metadata(&gguf.file.metadata)?;
if let Some(ctx) = self.config.context_size {
model_config.max_context_length = ctx;
}
tracing::info!(
arch = %model_config.architecture,
layers = model_config.num_layers,
hidden = model_config.hidden_size,
heads = model_config.num_attention_heads,
kv_heads = model_config.num_kv_heads,
vocab = model_config.vocab_size,
ctx = model_config.max_context_length,
"model config loaded from bytes"
);
let forward_pass = build_forward_pass(&gguf, &model_config)?;
let kv_dim = model_config.num_kv_heads * model_config.head_dim;
let kv_cache = KvCache::new(
model_config.num_layers,
model_config.max_context_length,
kv_dim,
);
tracing::info!(
layers = model_config.num_layers,
max_ctx = model_config.max_context_length,
kv_dim = kv_dim,
"KV cache initialized (from-bytes path)"
);
let tokenizer = TokenizerBridge::from_bytes(tokenizer_json.as_bytes())?;
let eos_token_id = tokenizer.eos_token_id();
tracing::info!(
vocab_size = tokenizer.vocab_size(),
eos = ?eos_token_id,
"tokenizer loaded from JSON string"
);
self.model_config = Some(model_config);
self.forward_pass = Some(forward_pass);
self.kv_cache = Some(kv_cache);
self.tokenizer = Some(tokenizer);
self.eos_token_id = eos_token_id;
self.gguf_model = Some(gguf);
Ok(())
}
pub fn load_model(&mut self) -> RuntimeResult<()> {
let path = Path::new(&self.config.model_path);
if !path.exists() {
return Err(RuntimeError::ModelLoadError {
message: format!("model file not found: {}", self.config.model_path),
});
}
tracing::info!(path = %self.config.model_path, "loading GGUF model");
let gguf = GgufModel::load(&self.config.model_path)?;
tracing::info!(
arch = gguf.architecture().unwrap_or("unknown"),
tensors = gguf.file.header.tensor_count,
"GGUF file parsed"
);
let mut model_config = ModelConfig::from_metadata(&gguf.file.metadata)?;
if let Some(ctx) = self.config.context_size {
model_config.max_context_length = ctx;
}
tracing::info!(
arch = %model_config.architecture,
layers = model_config.num_layers,
hidden = model_config.hidden_size,
heads = model_config.num_attention_heads,
kv_heads = model_config.num_kv_heads,
vocab = model_config.vocab_size,
ctx = model_config.max_context_length,
"model config loaded"
);
let forward_pass = build_forward_pass(&gguf, &model_config)?;
let kv_dim = model_config.num_kv_heads * model_config.head_dim;
let kv_cache = KvCache::new(
model_config.num_layers,
model_config.max_context_length,
kv_dim,
);
tracing::info!(
layers = model_config.num_layers,
max_ctx = model_config.max_context_length,
kv_dim = kv_dim,
"KV cache initialized"
);
let tokenizer = load_tokenizer(&self.config, &gguf)?;
let eos_token_id = tokenizer.eos_token_id();
tracing::info!(
vocab_size = tokenizer.vocab_size(),
eos = ?eos_token_id,
"tokenizer loaded"
);
self.model_config = Some(model_config);
self.forward_pass = Some(forward_pass);
self.kv_cache = Some(kv_cache);
self.tokenizer = Some(tokenizer);
self.eos_token_id = eos_token_id;
self.gguf_model = Some(gguf);
Ok(())
}
pub fn generate(
&mut self,
prompt: &str,
max_tokens: usize,
mut callback: impl FnMut(&str),
) -> RuntimeResult<String> {
let tokenizer = self
.tokenizer
.as_ref()
.ok_or(RuntimeError::ModelNotLoaded)?;
let forward_pass = self
.forward_pass
.as_mut()
.ok_or(RuntimeError::ModelNotLoaded)?;
let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
let prompt_tokens = tokenizer.encode(prompt)?;
if prompt_tokens.is_empty() {
return Ok(String::new());
}
tracing::debug!(n_tokens = prompt_tokens.len(), "prompt tokenized");
let mut recent_tokens = prompt_tokens.clone();
let mut generated_tokens: Vec<u32> = Vec::new();
let mut output_text = String::new();
let chunk_size = if self.config.prefill_chunk_size == 0 {
prompt_tokens.len()
} else {
self.config.prefill_chunk_size
};
let mut logits = if prompt_tokens.len() <= chunk_size {
tracing::debug!(
chunk = 1,
tokens = prompt_tokens.len(),
"prefill: single batch"
);
forward_pass.forward(&prompt_tokens, kv_cache)?
} else {
let n_chunks = prompt_tokens.len().div_ceil(chunk_size);
tracing::debug!(
n_chunks = n_chunks,
chunk_size = chunk_size,
total = prompt_tokens.len(),
"prefill: chunked"
);
let mut last_logits = Vec::new();
for (i, chunk) in prompt_tokens.chunks(chunk_size).enumerate() {
tracing::trace!(
chunk_idx = i,
chunk_len = chunk.len(),
kv_pos = kv_cache.seq_len(),
"prefill chunk"
);
last_logits = forward_pass.forward(chunk, kv_cache)?;
}
last_logits
};
let mut sampler = Sampler::new(self.config.sampler.clone());
for _step in 0..max_tokens {
let next_token = sampler.sample(&logits, &recent_tokens);
if Some(next_token) == self.eos_token_id {
tracing::debug!("EOS token generated, stopping");
break;
}
if kv_cache.seq_len() >= forward_pass.max_context_length() {
tracing::warn!("context length reached, stopping generation");
break;
}
let token_text = tokenizer.decode(&[next_token])?;
callback(&token_text);
output_text.push_str(&token_text);
recent_tokens.push(next_token);
generated_tokens.push(next_token);
logits = forward_pass.forward(&[next_token], kv_cache)?;
}
tracing::info!(
prompt_tokens = prompt_tokens.len(),
generated_tokens = generated_tokens.len(),
"generation complete"
);
Ok(output_text)
}
pub fn generate_with_config(
&mut self,
prompt: &str,
max_tokens: usize,
sampler_config: SamplerConfig,
mut callback: impl FnMut(&str),
) -> RuntimeResult<String> {
let tokenizer = self
.tokenizer
.as_ref()
.ok_or(RuntimeError::ModelNotLoaded)?;
let forward_pass = self
.forward_pass
.as_mut()
.ok_or(RuntimeError::ModelNotLoaded)?;
let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
let prompt_tokens = tokenizer.encode(prompt)?;
if prompt_tokens.is_empty() {
return Ok(String::new());
}
let mut recent_tokens = prompt_tokens.clone();
let mut generated_tokens: Vec<u32> = Vec::new();
let mut output_text = String::new();
for &token in &prompt_tokens[..prompt_tokens.len() - 1] {
forward_pass.forward(&[token], kv_cache)?;
}
let last = *prompt_tokens.last().ok_or(RuntimeError::ModelNotLoaded)?;
let mut logits = forward_pass.forward(&[last], kv_cache)?;
let mut sampler = Sampler::new(sampler_config);
for _step in 0..max_tokens {
let next_token = sampler.sample(&logits, &recent_tokens);
if Some(next_token) == self.eos_token_id {
tracing::debug!("EOS token generated, stopping");
break;
}
if kv_cache.seq_len() >= forward_pass.max_context_length() {
tracing::warn!("context length reached, stopping generation");
break;
}
let token_text = tokenizer.decode(&[next_token])?;
callback(&token_text);
output_text.push_str(&token_text);
recent_tokens.push(next_token);
generated_tokens.push(next_token);
logits = forward_pass.forward(&[next_token], kv_cache)?;
}
tracing::info!(
prompt_tokens = prompt_tokens.len(),
generated_tokens = generated_tokens.len(),
"generation (with custom config) complete"
);
Ok(output_text)
}
pub fn vocab_bytes(&self) -> Option<Vec<(u32, Vec<u8>)>> {
self.tokenizer.as_ref().map(|t| t.vocab_bytes())
}
pub fn apply_lora_adapters(
&mut self,
lora: &oxillama_arch::lora::LoadedLora,
) -> RuntimeResult<()> {
let fp = self
.forward_pass
.as_mut()
.ok_or(RuntimeError::ModelNotLoaded)?;
fp.apply_lora(lora).map_err(RuntimeError::Arch)?;
Ok(())
}
pub fn is_loaded(&self) -> bool {
self.forward_pass.is_some()
}
pub fn config(&self) -> &EngineConfig {
&self.config
}
pub fn model_config(&self) -> Option<&ModelConfig> {
self.model_config.as_ref()
}
pub fn reset(&mut self) {
if let Some(ref mut cache) = self.kv_cache {
cache.clear();
}
}
pub fn tokenize(&self, text: &str) -> RuntimeResult<Vec<u32>> {
let tokenizer = self
.tokenizer
.as_ref()
.ok_or(RuntimeError::ModelNotLoaded)?;
tokenizer.encode(text)
}
pub fn prefill(&mut self, tokens: &[u32]) -> RuntimeResult<()> {
if tokens.is_empty() {
return Ok(());
}
let forward_pass = self
.forward_pass
.as_mut()
.ok_or(RuntimeError::ModelNotLoaded)?;
let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
for &token in tokens {
forward_pass.forward(&[token], kv_cache)?;
}
Ok(())
}
pub fn forward_one(&mut self, token: u32) -> RuntimeResult<Vec<f32>> {
let forward_pass = self
.forward_pass
.as_mut()
.ok_or(RuntimeError::ModelNotLoaded)?;
let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
let logits = forward_pass.forward(&[token], kv_cache)?;
Ok(logits)
}
pub fn is_eos(&self, token: u32) -> bool {
self.eos_token_id == Some(token)
}
pub fn decode_token(&self, token: u32) -> RuntimeResult<String> {
let tokenizer = self
.tokenizer
.as_ref()
.ok_or(RuntimeError::ModelNotLoaded)?;
tokenizer.decode(&[token])
}
pub fn hidden_size(&self) -> Option<usize> {
self.model_config.as_ref().map(|c| c.hidden_size)
}
pub fn embed(&mut self, text: &str) -> RuntimeResult<Vec<f32>> {
self.reset();
let forward_pass = self
.forward_pass
.as_mut()
.ok_or(RuntimeError::ModelNotLoaded)?;
let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
let tokens = {
let tok = self
.tokenizer
.as_ref()
.ok_or(RuntimeError::ModelNotLoaded)?;
tok.encode(text)?
};
if tokens.is_empty() {
let dim = self
.model_config
.as_ref()
.map(|c| c.hidden_size)
.unwrap_or(0);
return Ok(vec![0.0f32; dim]);
}
let hidden = forward_pass.embed(&tokens, kv_cache)?;
let norm: f32 = hidden.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-9 {
Ok(hidden.into_iter().map(|x| x / norm).collect())
} else {
Ok(hidden)
}
}
pub fn embed_batch(&mut self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
let tokenizer = self
.tokenizer
.as_ref()
.ok_or(RuntimeError::ModelNotLoaded)?;
let forward_pass = self
.forward_pass
.as_mut()
.ok_or(RuntimeError::ModelNotLoaded)?;
let kv_cache = self.kv_cache.as_mut().ok_or(RuntimeError::ModelNotLoaded)?;
let hidden_size = forward_pass.hidden_size();
let mut embeddings = Vec::with_capacity(texts.len());
for text in texts {
kv_cache.clear();
let tokens = tokenizer.encode(text)?;
if tokens.is_empty() {
embeddings.push(vec![0.0f32; hidden_size]);
continue;
}
let hidden_state = forward_pass.embed(&tokens, kv_cache)?;
let norm: f32 = hidden_state.iter().map(|x| x * x).sum::<f32>().sqrt();
let embedding = if norm > 1e-12 {
hidden_state.iter().map(|x| x / norm).collect()
} else {
hidden_state
};
embeddings.push(embedding);
}
Ok(embeddings)
}
}
/// Instantiates the architecture-specific forward pass named by
/// `config.architecture`.
///
/// Each architecture arm is compiled only when its Cargo feature is enabled;
/// any name without a compiled-in loader falls through to the
/// "unsupported architecture" error.
///
/// # Errors
/// `RuntimeError::ModelLoadError` for unknown or feature-disabled
/// architectures; loader errors are propagated via `?`.
fn build_forward_pass(
    gguf: &GgufModel,
    config: &ModelConfig,
) -> RuntimeResult<Box<dyn ForwardPass>> {
    let arch_name = config.architecture.as_str();
    match arch_name {
        #[cfg(feature = "llama")]
        "llama" => Ok(Box::new(oxillama_arch::llama::load_llama_from_gguf(
            gguf, config,
        )?)),
        #[cfg(feature = "qwen3")]
        "qwen3" => Ok(Box::new(oxillama_arch::qwen3::load_qwen3_from_gguf(
            gguf, config,
        )?)),
        #[cfg(feature = "mistral")]
        "mistral" => Ok(Box::new(oxillama_arch::mistral::load_mistral_from_gguf(
            gguf, config,
        )?)),
        #[cfg(feature = "gemma")]
        "gemma" | "gemma2" | "gemma3" => Ok(Box::new(
            oxillama_arch::gemma::load_gemma_from_gguf(gguf, config)?,
        )),
        #[cfg(feature = "phi")]
        "phi3" | "phi" => Ok(Box::new(oxillama_arch::phi::load_phi_from_gguf(
            gguf, config,
        )?)),
        #[cfg(feature = "command-r")]
        "command-r" => Ok(Box::new(
            oxillama_arch::command_r::load_command_r_from_gguf(gguf, config)?,
        )),
        #[cfg(feature = "starcoder")]
        "starcoder" => Ok(Box::new(
            oxillama_arch::starcoder::load_starcoder_from_gguf(gguf, config)?,
        )),
        unsupported => Err(RuntimeError::ModelLoadError {
            message: format!("unsupported architecture: '{unsupported}'"),
        }),
    }
}
/// Resolves the tokenizer for a file-based model load.
///
/// Sources are tried in order:
/// 1. `config.tokenizer_path`, if set (its errors are returned; no fallback).
/// 2. HuggingFace tokenizer JSON embedded in the GGUF metadata under
///    `tokenizer.huggingface.json`.
/// 3. A `tokenizer.json` in the same directory as `config.model_path`.
///
/// # Errors
/// `RuntimeError::TokenizerError` when no source is available (or the sibling
/// path is not valid UTF-8); otherwise the chosen constructor's error.
fn load_tokenizer(config: &EngineConfig, gguf: &GgufModel) -> RuntimeResult<TokenizerBridge> {
    if let Some(ref path) = config.tokenizer_path {
        return TokenizerBridge::from_file(path);
    }
    // The embedded JSON is self-contained, so use it whenever it is present;
    // previously it was only honored if `tokenizer.ggml.tokens` also existed.
    if let Some(tokenizer_json) = gguf
        .file
        .metadata
        .get("tokenizer.huggingface.json")
        .and_then(|v| v.as_str())
    {
        return TokenizerBridge::from_bytes(tokenizer_json.as_bytes());
    }
    let model_dir = Path::new(&config.model_path)
        .parent()
        .unwrap_or(Path::new("."));
    let tokenizer_path = model_dir.join("tokenizer.json");
    if tokenizer_path.exists() {
        // Never substitute a different path: a silent fallback would load
        // whatever `tokenizer.json` happens to be in the working directory.
        let Some(path_str) = tokenizer_path.to_str() else {
            return Err(RuntimeError::TokenizerError {
                message: format!(
                    "tokenizer path is not valid UTF-8: {}",
                    tokenizer_path.display()
                ),
            });
        };
        return TokenizerBridge::from_file(path_str);
    }
    Err(RuntimeError::TokenizerError {
        message: "no tokenizer found: provide --tokenizer path or place tokenizer.json next to the model file".to_string(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embed_returns_err_when_not_loaded() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.embed("hello world");
        assert!(
            result.is_err(),
            "embed() should return Err when no model is loaded"
        );
    }

    #[test]
    fn test_embed_batch_errors_when_not_loaded() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.embed_batch(&["hello".to_string()]);
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_hidden_size_none_when_not_loaded() {
        let engine = InferenceEngine::new(EngineConfig::default());
        assert!(
            engine.hidden_size().is_none(),
            "hidden_size() should be None before load_model()"
        );
    }

    #[test]
    fn test_is_loaded_false_initially() {
        let engine = InferenceEngine::new(EngineConfig::default());
        assert!(!engine.is_loaded());
    }

    #[test]
    fn test_model_config_none_when_not_loaded() {
        let engine = InferenceEngine::new(EngineConfig::default());
        assert!(engine.model_config().is_none());
    }

    #[test]
    fn test_config_roundtrip() {
        let cfg = EngineConfig {
            model_path: "test.gguf".to_string(),
            num_threads: 8,
            ..EngineConfig::default()
        };
        let engine = InferenceEngine::new(cfg);
        assert_eq!(engine.config().model_path, "test.gguf");
        assert_eq!(engine.config().num_threads, 8);
    }

    #[test]
    fn test_generate_errors_when_not_loaded() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.generate("hello", 10, |_| {});
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_generate_with_config_errors_when_not_loaded() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.generate_with_config("hello", 5, SamplerConfig::greedy(), |_| {});
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_tokenize_errors_when_not_loaded() {
        let engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.tokenize("hello world");
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_prefill_errors_when_not_loaded() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.prefill(&[1, 2, 3]);
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_prefill_empty_slice_ok_when_no_model() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.prefill(&[]);
        assert!(result.is_ok(), "empty prefill should be Ok, got {result:?}");
    }

    #[test]
    fn test_forward_one_errors_when_not_loaded() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.forward_one(42);
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_decode_token_errors_when_not_loaded() {
        let engine = InferenceEngine::new(EngineConfig::default());
        let result = engine.decode_token(1);
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_is_eos_false_when_not_loaded() {
        let engine = InferenceEngine::new(EngineConfig::default());
        assert!(!engine.is_eos(0));
        assert!(!engine.is_eos(u32::MAX));
    }

    #[test]
    fn test_vocab_bytes_none_when_not_loaded() {
        let engine = InferenceEngine::new(EngineConfig::default());
        assert!(engine.vocab_bytes().is_none());
    }

    #[test]
    fn test_reset_does_not_panic_when_no_kv_cache() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine.reset();
    }

    #[test]
    fn test_apply_lora_adapters_errors_when_not_loaded() {
        use oxillama_arch::lora::LoadedLora;
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let lora = LoadedLora {
            rank: 8,
            alpha: 1.0,
            adapters: std::collections::HashMap::new(),
        };
        let result = engine.apply_lora_adapters(&lora);
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_load_model_missing_file_errors() {
        let cfg = EngineConfig {
            model_path: "/nonexistent/path/model_abc_xyz.gguf".to_string(),
            ..EngineConfig::default()
        };
        let mut engine = InferenceEngine::new(cfg);
        let result = engine.load_model();
        assert!(
            matches!(result, Err(RuntimeError::ModelLoadError { .. })),
            "expected ModelLoadError for missing file, got {result:?}"
        );
    }

    #[test]
    fn test_load_model_from_bytes_bad_magic_errors() {
        let cfg = EngineConfig::default();
        let mut engine = InferenceEngine::new(cfg);
        let bad_bytes = b"THIS IS NOT A GGUF FILE AT ALL";
        let result = engine.load_model_from_bytes(bad_bytes, "{}");
        assert!(
            result.is_err(),
            "load_model_from_bytes with garbage bytes should error, got Ok(())"
        );
    }

    #[test]
    fn test_load_model_from_bytes_empty_errors() {
        let cfg = EngineConfig::default();
        let mut engine = InferenceEngine::new(cfg);
        let result = engine.load_model_from_bytes(&[], "{}");
        assert!(
            result.is_err(),
            "load_model_from_bytes with empty bytes should error"
        );
    }

    #[test]
    fn test_engine_config_default_fields() {
        let cfg = EngineConfig::default();
        assert!(
            cfg.model_path.is_empty(),
            "default model_path should be empty"
        );
        assert!(
            cfg.tokenizer_path.is_none(),
            "default tokenizer_path should be None"
        );
        assert!(
            cfg.context_size.is_none(),
            "default context_size should be None"
        );
        assert_eq!(cfg.num_threads, 4, "default num_threads should be 4");
        assert_eq!(
            cfg.prefill_chunk_size, 512,
            "default prefill_chunk_size should be 512"
        );
    }

    #[test]
    fn test_engine_config_context_override() {
        let cfg = EngineConfig {
            context_size: Some(2048),
            ..EngineConfig::default()
        };
        assert_eq!(cfg.context_size, Some(2048));
    }

    #[test]
    fn test_generate_with_config_errors_when_not_loaded_variant() {
        let mut engine = InferenceEngine::new(EngineConfig::default());
        let sc = SamplerConfig {
            temperature: 0.7,
            top_k: 40,
            ..SamplerConfig::default()
        };
        let result = engine.generate_with_config("test prompt", 5, sc, |_| {});
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded, got {result:?}"
        );
    }

    #[test]
    fn test_load_model_existing_invalid_file_errors() {
        // Include the process id so concurrent test processes never share
        // (and race on) the same temp file.
        let mut tmp = std::env::temp_dir();
        tmp.push(format!(
            "oxillama_engine_bad_magic_test_{}.gguf",
            std::process::id()
        ));
        std::fs::write(&tmp, b"NOT A GGUF FILE AT ALL - GARBAGE BYTES 0123456789")
            .expect("write temp file");
        let cfg = EngineConfig {
            model_path: tmp
                .to_str()
                .expect("temp path must be valid UTF-8")
                .to_string(),
            ..EngineConfig::default()
        };
        let mut engine = InferenceEngine::new(cfg);
        let result = engine.load_model();
        let _ = std::fs::remove_file(&tmp);
        assert!(
            result.is_err(),
            "load_model with invalid GGUF content should return Err"
        );
    }

    #[test]
    fn test_is_loaded_remains_false_after_failed_load() {
        let cfg = EngineConfig {
            model_path: "/nonexistent/guaranteed_missing_model.gguf".to_string(),
            ..EngineConfig::default()
        };
        let mut engine = InferenceEngine::new(cfg);
        let _ = engine.load_model();
        assert!(
            !engine.is_loaded(),
            "is_loaded() must be false after a failed load_model()"
        );
    }

    #[test]
    fn test_engine_config_clone_is_independent() {
        let original = EngineConfig {
            model_path: "original.gguf".to_string(),
            num_threads: 16,
            context_size: Some(4096),
            ..EngineConfig::default()
        };
        let mut cloned = original.clone();
        cloned.model_path = "cloned.gguf".to_string();
        cloned.num_threads = 1;
        assert_eq!(original.model_path, "original.gguf");
        assert_eq!(original.num_threads, 16);
        assert_eq!(original.context_size, Some(4096));
    }

    /// Loads the synthetic llama GGUF + tokenizer into an engine built from `cfg`.
    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    fn make_loaded_engine_with(cfg: EngineConfig) -> InferenceEngine {
        let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
        let tokenizer_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(cfg);
        engine
            .load_model_from_bytes(&model_bytes, tokenizer_json)
            .expect("synthetic GGUF must load successfully");
        engine
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    fn make_loaded_engine() -> InferenceEngine {
        make_loaded_engine_with(EngineConfig::default())
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_load_model_from_bytes_succeeds() {
        let engine = make_loaded_engine();
        assert!(
            engine.is_loaded(),
            "is_loaded() must be true after a successful load_model_from_bytes()"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_hidden_size_after_load() {
        let engine = make_loaded_engine();
        let hs = engine.hidden_size();
        assert_eq!(
            hs,
            Some(32),
            "hidden_size() must be Some(32) after loading the synthetic model, got {hs:?}"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_tokenize_after_load() {
        let engine = make_loaded_engine();
        let result = engine.tokenize("abc");
        assert!(
            result.is_ok(),
            "tokenize() must return Ok after model is loaded, got {result:?}"
        );
        let tokens = result.expect("tokenize succeeded");
        assert!(
            !tokens.is_empty(),
            "tokenize('abc') must produce at least one token"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_is_eos_after_load() {
        let engine = make_loaded_engine();
        assert!(
            engine.is_eos(2),
            "is_eos(2) must be true — </s> is the EOS token in the synthetic tokenizer"
        );
        assert!(
            !engine.is_eos(3),
            "is_eos(3) must be false — token 3 ('a') is not EOS"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_decode_token_after_load() {
        let engine = make_loaded_engine();
        let result = engine.decode_token(3);
        assert!(
            result.is_ok(),
            "decode_token(3) must return Ok, got {result:?}"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_after_load() {
        let mut engine = make_loaded_engine();
        let result = engine.generate("a", 3, |_| {});
        assert!(
            result.is_ok(),
            "generate() must return Ok after model is loaded, got {result:?}"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_respects_max_tokens() {
        let mut engine = make_loaded_engine();
        let max = 5usize;
        let mut count = 0usize;
        let result = engine.generate("a", max, |_tok| {
            count += 1;
        });
        assert!(result.is_ok(), "generate() must return Ok, got {result:?}");
        assert!(
            count <= max,
            "callback was invoked {count} times but max_tokens={max}"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_streaming_calls_callback() {
        let mut engine = make_loaded_engine();
        let mut invocations = 0usize;
        let max_tokens = 4;
        let result = engine.generate("a", max_tokens, |_piece| {
            invocations += 1;
        });
        assert!(
            result.is_ok(),
            "generate() streaming path must return Ok, got {result:?}"
        );
        assert!(
            invocations <= max_tokens,
            "streaming callback fired {invocations} > max_tokens={max_tokens}"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_with_unchunked_prefill() {
        let cfg = EngineConfig {
            prefill_chunk_size: 0,
            ..EngineConfig::default()
        };
        let mut engine = make_loaded_engine_with(cfg);
        let result = engine.generate("abc", 2, |_| {});
        assert!(
            result.is_ok(),
            "prefill_chunk_size = 0 must mean 'no chunking', got {result:?}"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_with_single_token_chunks() {
        let cfg = EngineConfig {
            prefill_chunk_size: 1,
            ..EngineConfig::default()
        };
        let mut engine = make_loaded_engine_with(cfg);
        let result = engine.generate("abc", 2, |_| {});
        assert!(
            result.is_ok(),
            "chunked prefill with chunk size 1 must succeed, got {result:?}"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_with_config_matches_generate_under_greedy() {
        let cfg = EngineConfig {
            sampler: SamplerConfig::greedy(),
            ..EngineConfig::default()
        };
        let mut engine = make_loaded_engine_with(cfg);
        let via_generate = engine
            .generate("abc", 4, |_| {})
            .expect("test: greedy generate");
        engine.reset();
        let via_config = engine
            .generate_with_config("abc", 4, SamplerConfig::greedy(), |_| {})
            .expect("test: greedy generate_with_config");
        assert_eq!(
            via_generate, via_config,
            "both generation entry points must follow the same model path"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_prefill_then_forward_one_after_load() {
        let mut engine = make_loaded_engine();
        let tokens = engine.tokenize("abc").expect("test: tokenize");
        let first = *tokens.first().expect("test: 'abc' yields tokens");
        engine.prefill(&tokens).expect("test: prefill");
        let logits = engine.forward_one(first).expect("test: forward_one");
        assert!(!logits.is_empty(), "forward_one must return logits");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_embed_after_load() {
        let mut engine = make_loaded_engine();
        let result = engine.embed("a");
        assert!(
            result.is_ok(),
            "embed() must return Ok after model is loaded, got {result:?}"
        );
        let vec = result.expect("embed succeeded");
        assert!(!vec.is_empty(), "embed() must return a non-empty vector");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_embed_returns_hidden_size_vector() {
        let mut engine = make_loaded_engine();
        let vec = engine
            .embed("a")
            .expect("embed() must succeed after loading");
        assert_eq!(
            vec.len(),
            32,
            "embed() vector length must equal hidden_size=32, got {}",
            vec.len()
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_embed_batch_after_load() {
        let mut engine = make_loaded_engine();
        let single = engine.embed("a").expect("test: embed");
        let batch = engine
            .embed_batch(&["a".to_string(), "abc".to_string()])
            .expect("test: embed_batch");
        assert_eq!(batch.len(), 2, "one embedding per input text");
        for v in &batch {
            assert_eq!(v.len(), 32, "each embedding must have hidden_size=32");
            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                norm == 0.0 || (norm - 1.0).abs() < 1e-3,
                "embeddings must be unit-norm (or zero), got norm {norm}"
            );
        }
        assert_eq!(
            batch[0], single,
            "embed() and embed_batch() must agree for the same text"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_reload_model_succeeds() {
        let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
        let tokenizer_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&model_bytes, tokenizer_json)
            .expect("first load must succeed");
        assert!(engine.is_loaded(), "is_loaded() after first load");
        engine
            .load_model_from_bytes(&model_bytes, tokenizer_json)
            .expect("second (re)load must succeed");
        assert!(
            engine.is_loaded(),
            "is_loaded() after reload must still be true"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_vocab_bytes_some_after_load() {
        let engine = make_loaded_engine();
        let vb = engine.vocab_bytes();
        assert!(
            vb.is_some(),
            "vocab_bytes() must be Some after model is loaded"
        );
        let entries = vb.expect("vocab_bytes is Some");
        assert!(
            !entries.is_empty(),
            "vocab_bytes() must contain at least one entry"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_model_config_some_after_load() {
        let engine = make_loaded_engine();
        let cfg = engine.model_config();
        assert!(cfg.is_some(), "model_config() must be Some after loading");
        let mc = cfg.expect("model_config is Some");
        assert_eq!(mc.architecture, "llama", "architecture must be 'llama'");
        assert_eq!(
            mc.num_layers, 1,
            "num_layers must be 1 for the synthetic model"
        );
        assert_eq!(mc.vocab_size, 32, "vocab_size must be 32");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_reset_when_loaded_does_not_panic() {
        let mut engine = make_loaded_engine();
        engine.reset();
        assert!(
            engine.is_loaded(),
            "is_loaded() must still be true after reset()"
        );
        assert_eq!(engine.hidden_size(), Some(32));
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_qwen3_arch() {
        use oxillama_gguf::test_utils::{build_minimal_qwen3_gguf, minimal_tokenizer_json};
        let bytes = build_minimal_qwen3_gguf();
        let json = minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&bytes, json)
            .expect("test: load qwen3");
        assert!(engine.is_loaded(), "qwen3: is_loaded() must be true");
        let _out = engine
            .generate("abc", 2, |_| {})
            .expect("test: generate qwen3");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_embed_qwen3_arch() {
        use oxillama_gguf::test_utils::{build_minimal_qwen3_gguf, minimal_tokenizer_json};
        let bytes = build_minimal_qwen3_gguf();
        let json = minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&bytes, json)
            .expect("test: load qwen3 for embed");
        let vec = engine.embed("abc").expect("test: embed qwen3");
        assert_eq!(
            vec.len(),
            32,
            "qwen3 embed must return hidden_size=32 vector"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_mistral_arch() {
        use oxillama_gguf::test_utils::{build_minimal_mistral_gguf, minimal_tokenizer_json};
        let bytes = build_minimal_mistral_gguf();
        let json = minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&bytes, json)
            .expect("test: load mistral");
        assert!(engine.is_loaded(), "mistral: is_loaded() must be true");
        let _out = engine
            .generate("abc", 2, |_| {})
            .expect("test: generate mistral");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_gemma_arch() {
        use oxillama_gguf::test_utils::{build_minimal_gemma_gguf, minimal_tokenizer_json};
        let bytes = build_minimal_gemma_gguf();
        let json = minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&bytes, json)
            .expect("test: load gemma");
        assert!(engine.is_loaded(), "gemma: is_loaded() must be true");
        let _out = engine
            .generate("abc", 2, |_| {})
            .expect("test: generate gemma");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_embed_gemma_arch() {
        use oxillama_gguf::test_utils::{build_minimal_gemma_gguf, minimal_tokenizer_json};
        let bytes = build_minimal_gemma_gguf();
        let json = minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&bytes, json)
            .expect("test: load gemma for embed");
        let vec = engine.embed("abc").expect("test: embed gemma");
        assert_eq!(
            vec.len(),
            32,
            "gemma embed must return hidden_size=32 vector"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_phi3_arch() {
        use oxillama_gguf::test_utils::{build_minimal_phi3_gguf, minimal_tokenizer_json};
        let bytes = build_minimal_phi3_gguf();
        let json = minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&bytes, json)
            .expect("test: load phi3");
        assert!(engine.is_loaded(), "phi3: is_loaded() must be true");
        let _out = engine
            .generate("abc", 2, |_| {})
            .expect("test: generate phi3");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_command_r_arch() {
        use oxillama_gguf::test_utils::{build_minimal_command_r_gguf, minimal_tokenizer_json};
        let bytes = build_minimal_command_r_gguf();
        let json = minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&bytes, json)
            .expect("test: load command-r");
        assert!(engine.is_loaded(), "command-r: is_loaded() must be true");
        let _out = engine
            .generate("abc", 2, |_| {})
            .expect("test: generate command-r");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_generate_starcoder_arch() {
        use oxillama_gguf::test_utils::{build_minimal_starcoder_gguf, minimal_tokenizer_json};
        let bytes = build_minimal_starcoder_gguf();
        let json = minimal_tokenizer_json();
        let mut engine = InferenceEngine::new(EngineConfig::default());
        engine
            .load_model_from_bytes(&bytes, json)
            .expect("test: load starcoder");
        assert!(engine.is_loaded(), "starcoder: is_loaded() must be true");
        let _out = engine
            .generate("abc", 2, |_| {})
            .expect("test: generate starcoder");
    }
}