use std::sync::Arc;
use crate::backend::Backend;
use crate::gguf::GgufFile;
use crate::model::{
EmbeddingConfig, EmbeddingExtractor, InferenceContext, Model, ModelConfig, ModelLoader,
};
use crate::sampling::{Sampler, SamplerConfig};
use crate::tokenizer::Tokenizer;
/// Unified error type for the inference engine.
///
/// Wraps errors from every subsystem (I/O, GGUF parsing, model loading,
/// tokenization, embedding extraction) via `#[from]` conversions so callers
/// can propagate with `?` throughout.
#[derive(thiserror::Error, Debug)]
pub enum EngineError {
    /// Filesystem/stream failure while reading model or tokenizer files.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Malformed or unsupported GGUF container.
    #[error("GGUF error: {0}")]
    Gguf(#[from] crate::gguf::GgufError),
    /// Failure while building or running the model.
    #[error("Model error: {0}")]
    Model(#[from] crate::model::ModelError),
    /// Failure while loading the tokenizer or encoding/decoding text.
    #[error("Tokenizer error: {0}")]
    Tokenizer(#[from] crate::tokenizer::TokenizerError),
    /// Failure during embedding extraction.
    #[error("Embedding error: {0}")]
    Embedding(#[from] crate::model::EmbeddingError),
    /// Catch-all for engine-level failures (e.g. missing configuration,
    /// unsupported file type).
    #[error("Engine error: {0}")]
    Other(String),
}
/// User-facing engine configuration.
///
/// `#[serde(default)]` means any key missing from a config file falls back
/// to the value from `Default::default()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct EngineConfig {
    /// Path to the model file (GGUF, or ONNX when built with the `onnx`
    /// feature). Required: `Engine::load` rejects an empty path.
    pub model_path: String,
    /// Optional external tokenizer: a HuggingFace `tokenizer.json` or a
    /// GGUF file. When `None`, the tokenizer embedded in the model is used.
    pub tokenizer_path: Option<String>,
    /// Softmax sampling temperature.
    pub temperature: f32,
    /// Top-k sampling cutoff.
    pub top_k: usize,
    /// Nucleus (top-p) sampling probability mass.
    pub top_p: f32,
    /// Penalty applied to repeated tokens during sampling.
    pub repeat_penalty: f32,
    /// Maximum number of tokens to generate per request/turn.
    pub max_tokens: usize,
    /// Sampler RNG seed; `None` presumably means nondeterministic — see Sampler.
    pub seed: Option<u64>,
    /// Try GPU backends when true; falls back to CPU if none is available.
    pub use_gpu: bool,
    /// Optional cap on the GPU context length (see `select_gpu_model`).
    pub max_context_len: Option<usize>,
    /// Hailo accelerator settings (only with the `hailo` feature).
    #[cfg(feature = "hailo")]
    pub hailo_config: Option<crate::backend::hailo::HailoConfig>,
    /// KV-cache storage type (e.g. F32 or a TurboQuant variant).
    pub kv_cache_type: crate::model::KVCacheType,
}
impl Default for EngineConfig {
    /// Conservative sampling defaults; `model_path` is intentionally empty
    /// and must be supplied by the caller (`Engine::load` rejects it otherwise).
    fn default() -> Self {
        Self {
            model_path: String::new(),
            tokenizer_path: None,
            temperature: 0.7,
            top_k: 40,
            top_p: 0.95,
            repeat_penalty: 1.1,
            max_tokens: 512,
            seed: None,
            use_gpu: false,
            max_context_len: None,
            #[cfg(feature = "hailo")]
            hailo_config: None,
            kv_cache_type: crate::model::KVCacheType::F32,
        }
    }
}
impl EngineConfig {
    /// Builds an `EngineConfig` from an explicit config file path.
    pub fn from_config_file(
        path: impl AsRef<std::path::Path>,
    ) -> Result<Self, crate::config::ConfigError> {
        let config = crate::config::Config::from_file(path)?;
        Ok(config.to_engine_config(None))
    }
    /// Builds an `EngineConfig` via `Config::load`, which accepts an
    /// optional path (falling back to its own default lookup when `None`).
    pub fn from_config(
        config_path: Option<impl AsRef<std::path::Path>>,
    ) -> Result<Self, crate::config::ConfigError> {
        let config = crate::config::Config::load(config_path)?;
        Ok(config.to_engine_config(None))
    }
}
/// Prompt-formatting dialect used to wrap system/user messages.
#[derive(Debug, Clone, PartialEq)]
pub enum ChatTemplate {
    /// `<|user|>` / `<|assistant|>` tag style.
    UserAssistant,
    /// ChatML: `<|im_start|>` / `<|im_end|>` (Qwen family).
    ChatML,
    /// Llama-2 style: `[INST]` / `<<SYS>>`.
    Llama2,
    /// No special tokens; plain `System:` / `User:` framing.
    None,
}
impl ChatTemplate {
    /// Maps a HuggingFace `model_type` string onto a template family.
    /// Unknown or absent model types fall back to `ChatTemplate::None`.
    pub fn detect_from_model_type(model_type: Option<&str>) -> Self {
        let Some(kind) = model_type else {
            return ChatTemplate::None;
        };
        match kind {
            "qwen2" | "qwen" => ChatTemplate::ChatML,
            "llama" | "codellama" | "mistral" | "mixtral" => ChatTemplate::Llama2,
            _ => ChatTemplate::None,
        }
    }
    /// Detects the template from GGUF metadata: first by inspecting the
    /// embedded Jinja chat template for known markers, then by falling back
    /// to the declared architecture name.
    pub fn detect(gguf: &GgufFile) -> Self {
        if let Some(template) = gguf.data.get_string("tokenizer.chat_template") {
            return if template.contains("<|user|>") {
                ChatTemplate::UserAssistant
            } else if template.contains("<|im_start|>") {
                ChatTemplate::ChatML
            } else if template.contains("[INST]") {
                ChatTemplate::Llama2
            } else {
                ChatTemplate::None
            };
        }
        match gguf.data.get_string("general.architecture") {
            Some(arch) => match arch.to_lowercase().as_str() {
                "qwen2" | "qwen" | "qwen3" | "qwen35" | "qwen3moe" | "qwen3next" => {
                    ChatTemplate::ChatML
                }
                _ => ChatTemplate::None,
            },
            None => ChatTemplate::None,
        }
    }
    /// Wraps a bare prompt as a single user turn. Prompts that already
    /// contain template markers are passed through untouched.
    pub fn wrap_prompt(&self, prompt: &str) -> String {
        let already_formatted = ["<|user|>", "<|im_start|>", "[INST]"]
            .iter()
            .any(|marker| prompt.contains(marker));
        if already_formatted {
            return prompt.to_string();
        }
        match self {
            ChatTemplate::UserAssistant => format!("<|user|>\n{}<|assistant|>\n", prompt),
            ChatTemplate::ChatML => format!(
                "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
                prompt
            ),
            ChatTemplate::Llama2 => format!("[INST] {} [/INST]", prompt),
            ChatTemplate::None => prompt.to_string(),
        }
    }
    /// Formats the opening turn of a conversation, embedding the system
    /// prompt in the dialect's preferred position.
    pub fn format_first_turn(&self, system_prompt: &str, user_message: &str) -> String {
        match self {
            ChatTemplate::UserAssistant => format!(
                "<|system|>\n{}<|user|>\n{}<|assistant|>\n",
                system_prompt, user_message
            ),
            ChatTemplate::ChatML => format!(
                "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
                system_prompt, user_message
            ),
            ChatTemplate::Llama2 => format!(
                "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]",
                system_prompt, user_message
            ),
            ChatTemplate::None => format!(
                "System: {}\n\nUser: {}\n\nAssistant:",
                system_prompt, user_message
            ),
        }
    }
    /// Formats a follow-up user turn (no system prompt).
    pub fn format_continuation(&self, user_message: &str) -> String {
        match self {
            ChatTemplate::UserAssistant => format!("<|user|>\n{}<|assistant|>\n", user_message),
            ChatTemplate::ChatML => format!(
                "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
                user_message
            ),
            ChatTemplate::Llama2 => format!(" [INST] {} [/INST]", user_message),
            ChatTemplate::None => format!("\n\nUser: {}\n\nAssistant:", user_message),
        }
    }
    /// Substrings that mark the end of an assistant reply; generation stops
    /// once any of these appears in the output.
    pub fn stop_patterns(&self) -> &[&str] {
        match self {
            ChatTemplate::UserAssistant => &["<|user|>", "<|end|>"],
            ChatTemplate::ChatML => &["<|im_end|>", "<|im_start|>"],
            ChatTemplate::Llama2 => &["[INST]", "</s>"],
            ChatTemplate::None => &["User:", "\nUser:"],
        }
    }
}
/// High-level inference engine: owns one loaded model, its tokenizer,
/// compute backend, sampling configuration and detected chat template.
pub struct Engine {
    /// Original GGUF container; `None` when loaded from ONNX.
    gguf: Option<GgufFile>,
    /// The model behind a trait object (GGUF decoder, BERT, ONNX, or a GPU wrapper).
    model: Box<dyn Model>,
    tokenizer: Tokenizer,
    /// Structural model configuration (layers, heads, dims, max context).
    config: ModelConfig,
    /// Compute backend used for forward passes.
    backend: Arc<dyn Backend>,
    /// Sampling parameters derived from the `EngineConfig`.
    sampler_config: SamplerConfig,
    chat_template: ChatTemplate,
    /// Whether a BOS token is prepended when encoding prompts.
    add_bos: bool,
    /// The configuration this engine was created with.
    engine_config: EngineConfig,
}
impl Engine {
/// Loads a model and builds an `Engine`, dispatching on the file extension:
/// `.onnx` goes to the ONNX loader (when compiled with the `onnx` feature),
/// everything else is treated as GGUF.
///
/// # Errors
/// Returns `EngineError::Other` when `model_path` is empty or ONNX support
/// is not compiled in; otherwise propagates loader errors.
pub fn load(config: EngineConfig) -> Result<Self, EngineError> {
    if config.model_path.is_empty() {
        return Err(EngineError::Other("model_path is required".into()));
    }
    let path = std::path::Path::new(&config.model_path);
    let extension = path.extension().and_then(std::ffi::OsStr::to_str);
    if extension == Some("onnx") {
        #[cfg(feature = "onnx")]
        return Self::load_onnx(config);
        #[cfg(not(feature = "onnx"))]
        return Err(EngineError::Other(
            "ONNX support requires the `onnx` feature. Build with: cargo build --features onnx"
                .into(),
        ));
    }
    Self::load_gguf(config)
}
/// Loads a GGUF model from `config.model_path` and wires up the tokenizer,
/// backend, sampler and chat template.
///
/// Tokenizer resolution: an explicit `tokenizer_path` (HF `tokenizer.json`
/// or a GGUF file) wins; otherwise the tokenizer embedded in the model's
/// own GGUF metadata is used.
///
/// # Errors
/// Propagates GGUF parsing, tokenizer and model-build failures.
fn load_gguf(config: EngineConfig) -> Result<Self, EngineError> {
    tracing::info!("Loading GGUF model from: {}", config.model_path);
    let gguf = GgufFile::open(&config.model_path)?;
    tracing::info!("Loading tokenizer...");
    let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
        // `.json` is treated as a HuggingFace tokenizer; anything else as GGUF.
        if tok_path.ends_with(".json") {
            Tokenizer::from_hf_json(tok_path)?
        } else {
            let tok_gguf = GgufFile::open(tok_path)?;
            Tokenizer::from_gguf(&tok_gguf)?
        }
    } else {
        Tokenizer::from_gguf(&gguf)?
    };
    tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
    tracing::info!("Loading model weights...");
    let loader = ModelLoader::load(&config.model_path)?;
    let model_config = loader.config().clone();
    tracing::info!(
        "Model: {} layers, {} heads, {} hidden dim, {} ctx",
        model_config.num_layers,
        model_config.num_heads,
        model_config.hidden_size,
        model_config.max_seq_len,
    );
    let arch = loader.architecture();
    // Encoder-only models (BERT-style) always run on the CPU backend;
    // decoder models may be promoted to GPU when `use_gpu` is set.
    let (backend, model): (Arc<dyn Backend>, Box<dyn Model>) = if arch.is_encoder_only() {
        tracing::info!("Detected encoder-only architecture: {:?}", arch);
        let bert_model = loader.build_bert_model()?;
        (
            Arc::new(crate::backend::cpu::CpuBackend::new()),
            Box::new(bert_model),
        )
    } else {
        let concrete_model = loader.build_model()?;
        if config.use_gpu {
            Self::select_gpu_model(concrete_model, &model_config, &config)
        } else {
            (
                Arc::new(crate::backend::cpu::CpuBackend::new()),
                Box::new(concrete_model),
            )
        }
    };
    let chat_template = ChatTemplate::detect(&gguf);
    tracing::info!("Chat template: {:?}", chat_template);
    // Honor the model's explicit BOS preference from metadata; otherwise
    // fall back to what the tokenizer itself declared.
    let add_bos = gguf
        .data
        .get_bool("tokenizer.ggml.add_bos_token")
        .unwrap_or(tokenizer.has_explicit_bos);
    let sampler_config = SamplerConfig {
        temperature: config.temperature,
        top_k: config.top_k,
        top_p: config.top_p,
        repeat_penalty: config.repeat_penalty,
        seed: config.seed,
        ..Default::default()
    };
    tracing::info!("Engine ready");
    Ok(Self {
        gguf: Some(gguf),
        model,
        tokenizer,
        config: model_config,
        backend,
        sampler_config,
        chat_template,
        add_bos,
        engine_config: config,
    })
}
/// Loads an ONNX model (feature `onnx` only) and assembles an `Engine`.
///
/// The tokenizer must be supplied explicitly or exist as `tokenizer.json`
/// next to the model file; ONNX carries no embedded tokenizer metadata.
#[cfg(feature = "onnx")]
fn load_onnx(config: EngineConfig) -> Result<Self, EngineError> {
    use crate::onnx::OnnxModelLoader;
    tracing::info!("Loading ONNX model from: {}", config.model_path);
    // Directory containing the model; used to locate a sibling tokenizer.json.
    let model_dir = std::path::Path::new(&config.model_path)
        .parent()
        .unwrap_or(std::path::Path::new("."));
    let loader = OnnxModelLoader::load(&config.model_path)
        .map_err(|e| EngineError::Other(format!("ONNX load error: {}", e)))?;
    let model_config = loader.config().clone();
    let hf_config = loader.hf_config().clone();
    tracing::info!(
        "Model: {} layers, {} heads, {} hidden dim, {} ctx",
        model_config.num_layers,
        model_config.num_heads,
        model_config.hidden_size,
        model_config.max_seq_len,
    );
    let concrete_model = loader
        .build_model()
        .map_err(|e| EngineError::Other(format!("ONNX model build error: {}", e)))?;
    tracing::info!("Loading tokenizer...");
    let tokenizer = if let Some(ref tok_path) = config.tokenizer_path {
        if tok_path.ends_with(".json") {
            Tokenizer::from_hf_json(tok_path)?
        } else {
            let tok_gguf = GgufFile::open(tok_path)?;
            Tokenizer::from_gguf(&tok_gguf)?
        }
    } else {
        // No explicit tokenizer: look for tokenizer.json beside the model.
        let tokenizer_path = model_dir.join("tokenizer.json");
        if tokenizer_path.exists() {
            tracing::info!("Using tokenizer.json from: {}", tokenizer_path.display());
            Tokenizer::from_hf_json(&tokenizer_path)?
        } else {
            return Err(EngineError::Other(format!(
                "No tokenizer found. ONNX models require a tokenizer.json file \
                 in the same directory as the model, or specify --tokenizer <path>. \
                 Looked for: {}",
                tokenizer_path.display()
            )));
        }
    };
    tracing::info!("Vocabulary size: {}", tokenizer.vocab_size);
    let backend: Arc<dyn Backend> = if config.use_gpu {
        Self::select_gpu_backend(&concrete_model)
    } else {
        Arc::new(crate::backend::cpu::CpuBackend::new())
    };
    let model: Box<dyn Model> = Box::new(concrete_model);
    // Template comes from the HF config, not GGUF metadata.
    let chat_template = ChatTemplate::detect_from_model_type(hf_config.model_type.as_deref());
    tracing::info!("Chat template: {:?}", chat_template);
    // ONNX path has no metadata flag for BOS; assume it is wanted.
    let add_bos = true;
    let sampler_config = SamplerConfig {
        temperature: config.temperature,
        top_k: config.top_k,
        top_p: config.top_p,
        repeat_penalty: config.repeat_penalty,
        seed: config.seed,
        ..Default::default()
    };
    tracing::info!("Engine ready (ONNX)");
    Ok(Self {
        gguf: None,
        model,
        tokenizer,
        config: model_config,
        backend,
        sampler_config,
        chat_template,
        add_bos,
        engine_config: config,
    })
}
/// Tries to promote a freshly built model to a full-GPU inference wrapper,
/// attempting CUDA, Vulkan, Metal, DX12 and Hailo in that order; if no
/// full-GPU path is taken, falls back to `select_gpu_backend` with the
/// model kept on the CPU side.
///
/// NOTE(review): each `from_model` call consumes `model` by value, so on an
/// init failure there is nothing left to fall back to — the process prints
/// an error and exits instead of returning.
#[allow(unused_variables)]
fn select_gpu_model(
    model: crate::model::LlamaModel,
    config: &ModelConfig,
    engine_config: &EngineConfig,
) -> (Arc<dyn Backend>, Box<dyn Model>) {
    // Optionally cap the GPU context length below the model's maximum
    // (saves device memory for the KV cache).
    let gpu_seq_len = match engine_config.max_context_len {
        Some(cap) if cap > 0 && cap < config.max_seq_len => {
            tracing::info!(
                "Capping GPU context length from {} to {} (max_context_len)",
                config.max_seq_len,
                cap
            );
            cap
        }
        _ => config.max_seq_len,
    };
    // --- CUDA -----------------------------------------------------------
    #[cfg(feature = "cuda")]
    {
        if cudarc::driver::CudaDevice::new(0).is_ok() {
            let architecture = model.architecture();
            match crate::backend::cuda::gpu_only::GpuOnlyInference::from_model(
                model,
                gpu_seq_len,
            ) {
                Ok(gpu) => {
                    tracing::info!(
                        "Using full GPU inference (attention + DeltaNet + MoE all on CUDA)"
                    );
                    let wrapper = crate::backend::GpuModelWrapper::new(
                        gpu,
                        config.clone(),
                        architecture,
                    );
                    // CPU backend only covers the wrapper's host-side work.
                    return (
                        Arc::new(crate::backend::cpu::CpuBackend::new()),
                        Box::new(wrapper),
                    );
                }
                Err(e) => {
                    eprintln!("Error: CUDA GPU inference init failed: {}", e);
                    eprintln!("The model was consumed during init. Please restart without --gpu.");
                    std::process::exit(1);
                }
            }
        } else {
            tracing::info!("No CUDA device available, trying other GPU backends...");
        }
    }
    // --- Vulkan ---------------------------------------------------------
    #[cfg(feature = "vulkan")]
    {
        if crate::backend::vulkan::VulkanBackend::new().is_ok() {
            let architecture = model.architecture();
            match crate::backend::vulkan::gpu_only::VulkanGpuInference::from_model(
                model,
                gpu_seq_len,
            ) {
                Ok(gpu) => {
                    tracing::info!("Using full GPU inference on Vulkan");
                    let wrapper = crate::backend::GpuModelWrapper::new(
                        gpu,
                        config.clone(),
                        architecture,
                    );
                    return (
                        Arc::new(crate::backend::cpu::CpuBackend::new()),
                        Box::new(wrapper),
                    );
                }
                Err(e) => {
                    eprintln!("Error: Vulkan GPU inference init failed: {}", e);
                    eprintln!("The model was consumed during init. Please restart without --gpu.");
                    std::process::exit(1);
                }
            }
        } else {
            tracing::info!("No Vulkan device available, trying other GPU backends...");
        }
    }
    // --- Metal (macOS only) ---------------------------------------------
    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        if crate::backend::metal::MetalBackend::new().is_ok() {
            let architecture = model.architecture();
            match crate::backend::metal::gpu_only::MetalGpuInference::from_model(
                model,
                gpu_seq_len,
            ) {
                Ok(gpu) => {
                    tracing::info!("Using full GPU inference on Metal");
                    let wrapper = crate::backend::GpuModelWrapper::new(
                        gpu,
                        config.clone(),
                        architecture,
                    );
                    return (
                        Arc::new(crate::backend::cpu::CpuBackend::new()),
                        Box::new(wrapper),
                    );
                }
                Err(e) => {
                    eprintln!("Error: Metal GPU inference init failed: {}", e);
                    eprintln!("The model was consumed during init. Please restart without --gpu.");
                    std::process::exit(1);
                }
            }
        } else {
            tracing::info!("No Metal device available, trying other GPU backends...");
        }
    }
    // --- DirectX 12 (Windows only) --------------------------------------
    #[cfg(all(feature = "dx12", target_os = "windows"))]
    {
        if crate::backend::dx12::Dx12Backend::new().is_ok() {
            let architecture = model.architecture();
            match crate::backend::dx12::gpu_only::Dx12GpuInference::from_model(
                model,
                gpu_seq_len,
            ) {
                Ok(gpu) => {
                    tracing::info!("Using full GPU inference on DX12");
                    let wrapper = crate::backend::GpuModelWrapper::new(
                        gpu,
                        config.clone(),
                        architecture,
                    );
                    return (
                        Arc::new(crate::backend::cpu::CpuBackend::new()),
                        Box::new(wrapper),
                    );
                }
                Err(e) => {
                    eprintln!("Error: DX12 GPU inference init failed: {}", e);
                    eprintln!("The model was consumed during init. Please restart without --gpu.");
                    std::process::exit(1);
                }
            }
        } else {
            tracing::info!("No DX12 device available");
        }
    }
    // --- Hailo accelerator (requires explicit config) --------------------
    #[cfg(feature = "hailo")]
    {
        if let Some(ref hailo_config) = engine_config.hailo_config {
            if crate::backend::hailo::context::check_device_available().is_ok() {
                let architecture = model.architecture();
                match crate::backend::hailo::gpu_only::HailoGpuInference::from_model(
                    model,
                    gpu_seq_len,
                    hailo_config.clone(),
                ) {
                    Ok(gpu) => {
                        tracing::info!("Using hybrid CPU+Hailo inference");
                        let wrapper = crate::backend::GpuModelWrapper::new(
                            gpu,
                            config.clone(),
                            architecture,
                        );
                        return (
                            Arc::new(crate::backend::cpu::CpuBackend::new()),
                            Box::new(wrapper),
                        );
                    }
                    Err(e) => {
                        eprintln!("Error: Hailo inference init failed: {}", e);
                        eprintln!("The model was consumed during init. Please restart without --hailo.");
                        std::process::exit(1);
                    }
                }
            } else {
                tracing::info!("No Hailo device available, falling back to CPU...");
            }
        }
    }
    // No full-GPU path taken: keep the model on the CPU side and pick the
    // best available backend for its forward passes.
    let backend = Self::select_gpu_backend(&model);
    (backend, Box::new(model))
}
/// Picks the best available compute backend for a CPU-resident model,
/// trying CUDA → Metal → DX12 → Vulkan, and falling back to the CPU
/// backend when none initializes (or none was compiled in).
#[allow(unused_variables)]
pub fn select_gpu_backend(model: &crate::model::LlamaModel) -> Arc<dyn Backend> {
    #[cfg(feature = "cuda")]
    {
        match crate::backend::cuda::CudaBackend::new() {
            Ok(mut cuda) => {
                tracing::info!("Using CUDA backend: {}", cuda.device_name());
                // Weight upload failure is non-fatal: the backend still
                // works with quantized ops.
                if let Err(e) = cuda.load_model_weights(model) {
                    tracing::warn!("Failed to load GPU weights ({}), using quantized ops", e);
                }
                return Arc::new(cuda);
            }
            Err(e) => {
                tracing::info!("CUDA not available ({}), trying Metal...", e);
            }
        }
    }
    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        match crate::backend::metal::MetalBackend::new() {
            Ok(metal) => {
                tracing::info!("Using Metal backend: {}", metal.device_name());
                return Arc::new(metal);
            }
            Err(e) => {
                tracing::info!("Metal not available ({}), trying DX12...", e);
            }
        }
    }
    #[cfg(all(feature = "dx12", target_os = "windows"))]
    {
        match crate::backend::dx12::Dx12Backend::new() {
            Ok(dx12) => {
                tracing::info!("Using DX12 backend: {}", dx12.device_name());
                return Arc::new(dx12);
            }
            Err(e) => {
                tracing::info!("DX12 not available ({}), trying Vulkan...", e);
            }
        }
    }
    #[cfg(feature = "vulkan")]
    {
        match crate::backend::vulkan::VulkanBackend::new() {
            Ok(vk) => {
                tracing::info!("Using Vulkan backend: {}", vk.device_name());
                return Arc::new(vk);
            }
            Err(e) => {
                tracing::warn!("Vulkan not available ({}), falling back to CPU", e);
            }
        }
    }
    // Compiled without any GPU backend at all: tell the user how to get one.
    #[cfg(not(any(
        feature = "cuda",
        feature = "vulkan",
        all(feature = "metal", target_os = "macos"),
        all(feature = "dx12", target_os = "windows")
    )))]
    {
        tracing::warn!(
            "No GPU backend compiled. Build with --features cuda, --features metal, --features dx12, or --features vulkan"
        );
    }
    Arc::new(crate::backend::cpu::CpuBackend::new())
}
/// The model's structural configuration (layers, heads, dims, context).
pub fn model_config(&self) -> &ModelConfig {
    &self.config
}
/// Chat template detected at load time.
pub fn chat_template(&self) -> &ChatTemplate {
    &self.chat_template
}
/// The raw GGUF container, if this engine was loaded from GGUF.
pub fn gguf(&self) -> Option<&GgufFile> {
    self.gguf.as_ref()
}
/// The tokenizer in use.
pub fn tokenizer(&self) -> &Tokenizer {
    &self.tokenizer
}
/// The `EngineConfig` this engine was constructed with.
pub fn engine_config(&self) -> &EngineConfig {
    &self.engine_config
}
/// The loaded model as a trait object.
pub fn model(&self) -> &dyn Model {
    &*self.model
}
/// The compute backend used for forward passes.
pub fn backend(&self) -> &Arc<dyn Backend> {
    &self.backend
}
/// Whether a BOS token is prepended when encoding prompts.
pub fn add_bos(&self) -> bool {
    self.add_bos
}
/// Builds a fresh inference context for one generation pass.
///
/// TurboQuant KV caches are constructed directly from the model config;
/// for any other cache type the model decides its own context layout.
pub fn create_inference_context(&self) -> InferenceContext {
    let cache_type = self.engine_config.kv_cache_type;
    if cache_type.is_turboquant() {
        InferenceContext::new_with_cache_type(&self.config, self.backend.clone(), cache_type)
    } else {
        self.model.create_context(self.backend.clone())
    }
}
/// Generates a completion for `prompt`, blocking until EOS, a stop pattern,
/// or `max_tokens` is reached. The returned text is trimmed of surrounding
/// whitespace and truncated at the first stop pattern.
pub fn generate(&self, prompt: &str, max_tokens: usize) -> Result<String, EngineError> {
    let mut ctx = self.create_inference_context();
    let mut sampler = Sampler::new(self.sampler_config.clone(), self.config.vocab_size);
    let formatted = self.chat_template.wrap_prompt(prompt);
    let mut tokens = self.tokenizer.encode(&formatted, self.add_bos)?;
    let mut output = String::new();
    for _ in 0..max_tokens {
        // Stop if the last token (from the prompt or a prior step) is EOS.
        if let Some(&last) = tokens.last()
            && last == self.tokenizer.special_tokens.eos_token_id
        {
            break;
        }
        // First pass prefills the entire prompt; afterwards only the most
        // recent token is fed (the KV cache in `ctx` holds the rest).
        let input_tokens = if ctx.position == 0 {
            &tokens[..]
        } else {
            &tokens[tokens.len() - 1..]
        };
        let logits = self.model.forward(input_tokens, &mut ctx)?;
        let next_token = sampler.sample(&logits, &tokens);
        if next_token == self.tokenizer.special_tokens.eos_token_id {
            break;
        }
        // Decode failures are silently skipped here, but the token is still
        // appended below so generation state stays consistent.
        if let Ok(text) = self.tokenizer.decode(&[next_token]) {
            let combined = format!("{}{}", output, text);
            let stop = self
                .chat_template
                .stop_patterns()
                .iter()
                .any(|p| combined.contains(p));
            if stop {
                // Return everything that precedes the first matching pattern.
                for pattern in self.chat_template.stop_patterns() {
                    if let Some(idx) = combined.find(pattern) {
                        output = combined[..idx].to_string();
                        return Ok(output.trim().to_string());
                    }
                }
                break;
            }
            output.push_str(&text);
        }
        tokens.push(next_token);
    }
    Ok(output.trim().to_string())
}
/// Returns a lazy, token-by-token stream for `prompt`; see [`GenerationStream`].
pub fn generate_streaming(&self, prompt: &str, max_tokens: usize) -> GenerationStream<'_> {
    GenerationStream::new(self, prompt, max_tokens)
}
/// Computes an embedding vector for `text` using the default
/// `EmbeddingConfig` and a fresh inference context.
pub fn embed(&self, text: &str) -> Result<Vec<f32>, EngineError> {
    let extractor = EmbeddingExtractor::new(EmbeddingConfig::default(), &self.config);
    let mut ctx = self.create_inference_context();
    let embedding = extractor.embed_text(self.model.as_ref(), &self.tokenizer, &mut ctx, text)?;
    Ok(embedding)
}
}
/// Iterator yielding decoded text fragments for a one-shot generation.
/// Created via [`Engine::generate_streaming`].
pub struct GenerationStream<'a> {
    engine: &'a Engine,
    ctx: InferenceContext,
    sampler: Sampler,
    /// Prompt tokens plus every token generated so far.
    tokens: Vec<u32>,
    /// Remaining token budget.
    remaining: usize,
    /// Set once EOS, a stop pattern, or an error ends the stream.
    done: bool,
    /// Text emitted so far; used for cross-chunk stop-pattern detection.
    accumulated: String,
    /// Buffers incomplete UTF-8 byte sequences between tokens.
    pending_bytes: Vec<u8>,
}
impl<'a> GenerationStream<'a> {
    /// Prepares a streaming generation: formats the prompt with the chat
    /// template, encodes it, and snapshots a fresh context and sampler.
    ///
    /// NOTE(review): encoding errors are swallowed via `unwrap_or_default`,
    /// producing an empty token list — the stream then runs on empty input.
    fn new(engine: &'a Engine, prompt: &str, max_tokens: usize) -> Self {
        let ctx = engine.create_inference_context();
        let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
        let formatted = engine.chat_template.wrap_prompt(prompt);
        // Query the LLAMA_DEBUG flag once instead of hitting the environment
        // on both debug sites.
        let debug = std::env::var("LLAMA_DEBUG").is_ok();
        if debug {
            eprintln!("[DEBUG] formatted prompt: {:?}", formatted);
            eprintln!("[DEBUG] add_bos: {}", engine.add_bos);
        }
        let tokens = engine
            .tokenizer
            .encode(&formatted, engine.add_bos)
            .unwrap_or_default();
        if debug {
            // Show at most the first 50 token ids, then a per-token dump.
            eprintln!("[DEBUG] encoded {} tokens: {:?}", tokens.len(), &tokens[..tokens.len().min(50)]);
            for (i, &tid) in tokens.iter().enumerate() {
                if let Some(s) = engine.tokenizer.get_token(tid) {
                    eprintln!("[DEBUG] token[{}] = {} -> {:?}", i, tid, s);
                }
            }
        }
        Self {
            engine,
            ctx,
            sampler,
            tokens,
            remaining: max_tokens,
            done: false,
            accumulated: String::new(),
            pending_bytes: Vec::new(),
        }
    }
}
impl<'a> Iterator for GenerationStream<'a> {
    type Item = Result<String, EngineError>;
    /// Samples one token and yields its decoded text.
    ///
    /// Returns `None` when the token budget is spent, EOS is sampled, or a
    /// stop pattern is reached. Tokens that decode to an empty string
    /// (incomplete UTF-8 buffered in `pending_bytes`) are skipped by
    /// recursing into the next iteration.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done || self.remaining == 0 {
            return None;
        }
        // Stop if the token list already ends in EOS (e.g. from the prompt).
        if let Some(&last) = self.tokens.last()
            && last == self.engine.tokenizer.special_tokens.eos_token_id
        {
            self.done = true;
            return None;
        }
        // Prefill the whole prompt on the first call; single-token decode after.
        let input_tokens = if self.ctx.position == 0 {
            &self.tokens[..]
        } else {
            &self.tokens[self.tokens.len() - 1..]
        };
        let logits = match self.engine.model.forward(input_tokens, &mut self.ctx) {
            Ok(l) => l,
            Err(e) => {
                self.done = true;
                return Some(Err(EngineError::Model(e)));
            }
        };
        let next_token = self.sampler.sample(&logits, &self.tokens);
        // Optional logit introspection, gated by the LLAMA_DEBUG_LOGITS env var.
        if std::env::var("LLAMA_DEBUG_LOGITS").is_ok() {
            let logit_data = logits.as_f32().unwrap();
            let mut indexed: Vec<(usize, f32)> = logit_data.iter().copied().enumerate().collect();
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let step = self.tokens.len();
            eprint!("[LOGIT] step={} top5:", step);
            for (id, score) in indexed.iter().take(5) {
                let tok_str = self.engine.tokenizer.get_token(*id as u32).unwrap_or_default();
                eprint!(" {}({:.2})={:?}", id, score, tok_str);
            }
            let chosen_str = self.engine.tokenizer.get_token(next_token).unwrap_or_default();
            eprintln!(" → chosen={}({:?})", next_token, chosen_str);
        }
        if next_token == self.engine.tokenizer.special_tokens.eos_token_id {
            self.done = true;
            return None;
        }
        match self
            .engine
            .tokenizer
            .decode_token_streaming(next_token, &mut self.pending_bytes)
        {
            Ok(text) => {
                self.tokens.push(next_token);
                self.remaining -= 1;
                if text.is_empty() {
                    // Partial UTF-8 sequence: nothing printable yet.
                    return self.next();
                }
                let combined = format!("{}{}", self.accumulated, text);
                for pattern in self.engine.chat_template.stop_patterns() {
                    if combined.contains(pattern) {
                        self.done = true;
                        if let Some(idx) = combined.find(pattern) {
                            // Only emit text that precedes the pattern AND was
                            // not already emitted; `idx` can fall inside
                            // `accumulated` when the pattern straddles chunks.
                            if idx > self.accumulated.len() {
                                let before = &combined[self.accumulated.len()..idx];
                                return Some(Ok(before.to_string()));
                            }
                        }
                        return None;
                    }
                }
                self.accumulated.push_str(&text);
                Some(Ok(text))
            }
            Err(e) => {
                // Token still counts toward state/budget even if undecodable.
                self.tokens.push(next_token);
                self.remaining -= 1;
                Some(Err(EngineError::Tokenizer(e)))
            }
        }
    }
}
/// Multi-turn chat session: keeps the conversation token history and a
/// persistent inference context (KV cache) across turns.
pub struct ChatEngine {
    engine: Engine,
    system_prompt: String,
    /// Every token of the conversation so far (prompt + replies).
    conversation_tokens: Vec<u32>,
    /// Long-lived context whose KV cache spans the whole conversation.
    ctx: InferenceContext,
    sampler: Sampler,
    /// True until the first turn completes; controls system-prompt wrapping
    /// and BOS insertion.
    is_first_turn: bool,
}
impl ChatEngine {
/// Creates a chat session around `engine` with an optional system prompt
/// (defaults to "You are a helpful AI assistant.").
pub fn new(engine: Engine, system_prompt: Option<String>) -> Self {
    // Context and sampler must be built before `engine` is moved into Self.
    let ctx = engine.create_inference_context();
    let sampler = Sampler::new(engine.sampler_config.clone(), engine.config.vocab_size);
    let system_prompt =
        system_prompt.unwrap_or_else(|| "You are a helpful AI assistant.".to_string());
    Self {
        system_prompt,
        conversation_tokens: Vec::new(),
        ctx,
        sampler,
        is_first_turn: true,
        engine,
    }
}
/// Read-only access to the underlying engine.
pub fn engine(&self) -> &Engine {
    &self.engine
}
/// The active system prompt.
pub fn system_prompt(&self) -> &str {
    &self.system_prompt
}
/// Number of tokens currently held in the conversation history.
pub fn context_len(&self) -> usize {
    self.conversation_tokens.len()
}
/// Runs one blocking chat turn and returns the assistant's reply (trimmed,
/// truncated at the first stop pattern).
///
/// The first turn is wrapped with the system prompt; later turns use the
/// continuation form. All tokens are appended to `conversation_tokens` so
/// the persistent KV cache stays aligned across turns.
pub fn chat(&mut self, message: &str) -> Result<String, EngineError> {
    let max_tokens = self.engine.engine_config.max_tokens;
    let formatted = if self.is_first_turn {
        self.engine
            .chat_template
            .format_first_turn(&self.system_prompt, message)
    } else {
        self.engine.chat_template.format_continuation(message)
    };
    // BOS only on the very first turn, and only if the model wants it.
    let new_tokens = self
        .engine
        .tokenizer
        .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;
    self.ensure_context_space(new_tokens.len(), max_tokens);
    self.conversation_tokens.extend(&new_tokens);
    let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
    let mut response_text = String::new();
    if new_tokens.is_empty() {
        self.is_first_turn = false;
        return Ok(response_text);
    }
    // Prefill the formatted turn in one pass, then sample the first reply token.
    let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
    let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);
    if first_token == eos_id {
        self.is_first_turn = false;
        return Ok(response_text);
    }
    if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
        response_text.push_str(&text);
    }
    self.conversation_tokens.push(first_token);
    // Decode loop: one token per iteration, trimming at the first stop pattern.
    for _ in 1..max_tokens {
        let should_stop = self
            .engine
            .chat_template
            .stop_patterns()
            .iter()
            .any(|p| response_text.contains(p));
        if should_stop {
            for pattern in self.engine.chat_template.stop_patterns() {
                if let Some(idx) = response_text.find(pattern) {
                    response_text.truncate(idx);
                    break;
                }
            }
            break;
        }
        let last_token = *self
            .conversation_tokens
            .last()
            .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);
        let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
        let next_token = self.sampler.sample(&logits, &self.conversation_tokens);
        if next_token == eos_id {
            break;
        }
        if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
            response_text.push_str(&text);
        }
        self.conversation_tokens.push(next_token);
    }
    self.is_first_turn = false;
    Ok(response_text.trim().to_string())
}
/// Like [`Self::chat`], but seeds the assistant's reply with `prefix`:
/// the prefix is appended to the formatted prompt before encoding, and the
/// returned text starts with it.
pub fn chat_with_prefix(
    &mut self,
    message: &str,
    prefix: &str,
) -> Result<String, EngineError> {
    let max_tokens = self.engine.engine_config.max_tokens;
    let formatted = if self.is_first_turn {
        self.engine
            .chat_template
            .format_first_turn(&self.system_prompt, message)
    } else {
        self.engine.chat_template.format_continuation(message)
    };
    // The prefix becomes part of the prompt, forcing the model to continue it.
    let formatted_with_prefix = format!("{}{}", formatted, prefix);
    let new_tokens = self
        .engine
        .tokenizer
        .encode(&formatted_with_prefix, self.is_first_turn && self.engine.add_bos)?;
    self.ensure_context_space(new_tokens.len(), max_tokens);
    self.conversation_tokens.extend(&new_tokens);
    let eos_id = self.engine.tokenizer.special_tokens.eos_token_id;
    // The reply text starts with the prefix itself.
    let mut response_text = prefix.to_string();
    if new_tokens.is_empty() {
        self.is_first_turn = false;
        return Ok(response_text);
    }
    // Prefill the turn (including the prefix), then sample the first token.
    let prefill_logits = self.engine.model.forward(&new_tokens, &mut self.ctx)?;
    let first_token = self.sampler.sample(&prefill_logits, &self.conversation_tokens);
    if first_token == eos_id {
        self.is_first_turn = false;
        return Ok(response_text);
    }
    if let Ok(text) = self.engine.tokenizer.decode(&[first_token]) {
        response_text.push_str(&text);
    }
    self.conversation_tokens.push(first_token);
    // Decode loop, identical in shape to `chat`.
    for _ in 1..max_tokens {
        let should_stop = self
            .engine
            .chat_template
            .stop_patterns()
            .iter()
            .any(|p| response_text.contains(p));
        if should_stop {
            for pattern in self.engine.chat_template.stop_patterns() {
                if let Some(idx) = response_text.find(pattern) {
                    response_text.truncate(idx);
                    break;
                }
            }
            break;
        }
        let last_token = *self
            .conversation_tokens
            .last()
            .unwrap_or(&self.engine.tokenizer.special_tokens.bos_token_id);
        let logits = self.engine.model.forward(&[last_token], &mut self.ctx)?;
        let next_token = self.sampler.sample(&logits, &self.conversation_tokens);
        if next_token == eos_id {
            break;
        }
        if let Ok(text) = self.engine.tokenizer.decode(&[next_token]) {
            response_text.push_str(&text);
        }
        self.conversation_tokens.push(next_token);
    }
    self.is_first_turn = false;
    Ok(response_text.trim().to_string())
}
/// Begins a streaming chat turn and returns an iterator of text fragments.
///
/// The user message is formatted and prefilled through the model here; the
/// resulting logits are handed to the returned [`ChatStream`], which
/// samples one token per `next()` call.
pub fn chat_streaming(&mut self, message: &str) -> Result<ChatStream<'_>, EngineError> {
    let max_tokens = self.engine.engine_config.max_tokens;
    let formatted = if self.is_first_turn {
        self.engine
            .chat_template
            .format_first_turn(&self.system_prompt, message)
    } else {
        self.engine.chat_template.format_continuation(message)
    };
    let new_tokens = self
        .engine
        .tokenizer
        .encode(&formatted, self.is_first_turn && self.engine.add_bos)?;
    self.ensure_context_space(new_tokens.len(), max_tokens);
    self.conversation_tokens.extend(&new_tokens);
    // Prefill the whole turn in one forward pass; the stream's first sample
    // comes from these logits.
    let prefill_logits = if !new_tokens.is_empty() {
        Some(self.engine.model.forward(&new_tokens, &mut self.ctx)?)
    } else {
        None
    };
    self.is_first_turn = false;
    Ok(ChatStream {
        chat_engine: self,
        remaining: max_tokens,
        done: false,
        accumulated: String::new(),
        prefill_logits,
    })
}
/// Wipes the conversation: clears the token history, resets the KV cache
/// and sampler state, and re-arms the first-turn (system prompt) path.
pub fn clear_history(&mut self) {
    self.conversation_tokens.clear();
    self.ctx.reset();
    self.sampler.reset();
    self.is_first_turn = true;
}
/// Makes room in the context window for `new_token_count` prompt tokens
/// plus up to `max_gen_tokens` of generation.
///
/// If the projected total exceeds the model's `max_seq_len`, the oldest
/// tokens are trimmed (with an extra 100-token margin — presumably to avoid
/// re-trimming every turn; confirm); if even that cannot fit, the whole
/// conversation is reset.
fn ensure_context_space(&mut self, new_token_count: usize, max_gen_tokens: usize) {
    let total_len = self.conversation_tokens.len() + new_token_count + max_gen_tokens;
    if total_len > self.engine.config.max_seq_len {
        let excess = total_len - self.engine.config.max_seq_len + 100;
        if excess >= self.conversation_tokens.len() {
            tracing::warn!("Context full, resetting conversation");
            self.conversation_tokens.clear();
            self.ctx.reset();
        } else {
            tracing::info!("Trimming {} tokens from context", excess);
            self.conversation_tokens = self.conversation_tokens[excess..].to_vec();
            // Keep the KV cache and position counter in sync with the trim.
            self.ctx.kv_cache.shift_left(excess);
            self.ctx.position = self.ctx.position.saturating_sub(excess);
        }
    }
}
}
/// Streaming iterator over one chat turn's reply; borrows the `ChatEngine`
/// mutably so the conversation state advances as tokens are produced.
pub struct ChatStream<'a> {
    chat_engine: &'a mut ChatEngine,
    /// Remaining token budget for this turn.
    remaining: usize,
    /// Set once EOS, a stop pattern, or an error terminates the stream.
    done: bool,
    /// Text emitted so far; used for cross-chunk stop-pattern detection.
    accumulated: String,
    /// Logits from the prompt prefill; consumed by the first `next()` call.
    prefill_logits: Option<crate::tensor::Tensor>,
}
impl<'a> Iterator for ChatStream<'a> {
    type Item = Result<String, EngineError>;
    /// Samples one token of the assistant reply and yields its decoded text.
    ///
    /// Returns `None` when the token budget is spent, EOS is sampled, or a
    /// stop pattern is reached. The first call consumes the prefill logits
    /// computed by `chat_streaming`; later calls run a single-token forward.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done || self.remaining == 0 {
            return None;
        }
        // Stop if a previous chunk already completed a stop pattern.
        for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
            if self.accumulated.contains(pattern) {
                self.done = true;
                return None;
            }
        }
        let logits = if let Some(prefill) = self.prefill_logits.take() {
            prefill
        } else {
            let last_token = *self.chat_engine.conversation_tokens.last().unwrap_or(
                &self
                    .chat_engine
                    .engine
                    .tokenizer
                    .special_tokens
                    .bos_token_id,
            );
            match self
                .chat_engine
                .engine
                .model
                .forward(&[last_token], &mut self.chat_engine.ctx)
            {
                Ok(l) => l,
                Err(e) => {
                    self.done = true;
                    return Some(Err(EngineError::Model(e)));
                }
            }
        };
        let next_token = self
            .chat_engine
            .sampler
            .sample(&logits, &self.chat_engine.conversation_tokens);
        if next_token
            == self
                .chat_engine
                .engine
                .tokenizer
                .special_tokens
                .eos_token_id
        {
            self.done = true;
            return None;
        }
        match self.chat_engine.engine.tokenizer.decode(&[next_token]) {
            Ok(text) => {
                let combined = format!("{}{}", self.accumulated, text);
                for pattern in self.chat_engine.engine.chat_template.stop_patterns() {
                    if combined.contains(pattern) {
                        self.done = true;
                        if let Some(idx) = combined.find(pattern) {
                            self.chat_engine.conversation_tokens.push(next_token);
                            // BUGFIX: the pattern may begin inside text that was
                            // already emitted (it straddles the chunk boundary),
                            // making `idx < accumulated.len()`. Slicing
                            // `combined[accumulated.len()..idx]` would then
                            // panic (start > end). Guard like
                            // `GenerationStream::next` does and only emit the
                            // unseen text preceding the pattern.
                            if idx > self.accumulated.len() {
                                let before = &combined[self.accumulated.len()..idx];
                                if !before.is_empty() {
                                    return Some(Ok(before.to_string()));
                                }
                            }
                        }
                        return None;
                    }
                }
                self.accumulated.push_str(&text);
                self.chat_engine.conversation_tokens.push(next_token);
                self.remaining -= 1;
                Some(Ok(text))
            }
            Err(e) => {
                // Token still counts toward history/budget even if undecodable.
                self.chat_engine.conversation_tokens.push(next_token);
                self.remaining -= 1;
                Some(Err(EngineError::Tokenizer(e)))
            }
        }
    }
}