use crate::error::{RealizarError, Result};
use crate::format::{detect_format, ModelFormat};
use std::path::PathBuf;
use std::time::Instant;
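/// Resolves a raw GGUF quantization type id to its dtype name, or
/// "Unknown" when the id is not a recognized `GgmlQuantType`.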
pub(crate) fn qtype_to_dtype_str(qtype: u32) -> &'static str {
    crate::gguf::GgmlQuantType::from_id(qtype)
        .map_or("Unknown", crate::gguf::GgmlQuantType::as_str)
}
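/// Configuration for a single inference run: model path, prompt or
/// pre-tokenized input, sampling parameters, and trace/debug options.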
#[derive(Debug, Clone)]
pub struct InferenceConfig {
pub model_path: PathBuf,
pub prompt: Option<String>,
pub input_tokens: Option<Vec<u32>>,
pub max_tokens: usize,
pub temperature: f32,
pub top_k: usize,
pub no_gpu: bool,
pub trace: bool,
pub trace_verbose: bool,
pub trace_output: Option<PathBuf>,
pub trace_steps: Option<Vec<String>>,
pub verbose: bool,
pub stop_tokens: Vec<u32>,
#[doc(hidden)]
pub use_mock_backend: bool,
}
impl InferenceConfig {
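    /// Creates a config with greedy-decoding defaults: `max_tokens = 32`,
    /// `temperature = 0.0`, `top_k = 1`, GPU enabled, tracing off.
    ///
    /// Illustrative builder usage (path and values are placeholders):
    ///
    /// ```ignore
    /// let config = InferenceConfig::new("models/tiny.gguf")
    ///     .with_prompt("Hello")
    ///     .with_max_tokens(64)
    ///     .with_temperature(0.7)
    ///     .with_top_k(40);
    /// ```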
#[must_use]
pub fn new(model_path: impl Into<PathBuf>) -> Self {
Self {
model_path: model_path.into(),
prompt: None,
input_tokens: None,
max_tokens: 32,
            temperature: 0.0,
            top_k: 1,
no_gpu: false,
trace: false,
trace_verbose: false,
trace_output: None,
trace_steps: None,
verbose: false,
stop_tokens: Vec::new(),
use_mock_backend: false,
}
}
#[must_use]
pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
self.prompt = Some(prompt.into());
self
}
#[must_use]
pub fn with_input_tokens(mut self, tokens: Vec<u32>) -> Self {
self.input_tokens = Some(tokens);
self
}
#[must_use]
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = max_tokens;
self
}
#[must_use]
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
#[must_use]
pub fn with_top_k(mut self, top_k: usize) -> Self {
self.top_k = top_k;
self
}
#[must_use]
pub fn without_gpu(mut self) -> Self {
self.no_gpu = true;
self
}
#[must_use]
pub fn with_verbose(mut self, verbose: bool) -> Self {
self.verbose = verbose;
self
}
#[must_use]
pub fn with_trace(mut self, trace: bool) -> Self {
self.trace = trace;
self
}
#[must_use]
pub fn with_trace_output(mut self, path: impl Into<PathBuf>) -> Self {
self.trace_output = Some(path.into());
self
}
#[must_use]
pub fn with_stop_tokens(mut self, stop_tokens: Vec<u32>) -> Self {
self.stop_tokens = stop_tokens;
self
}
}
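/// A tokenized prompt plus the count of input tokens, so downstream code
/// can distinguish prompt tokens from generated ones.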
#[derive(Debug, Clone)]
pub struct PreparedTokens {
tokens: Vec<u32>,
input_count: usize,
}
impl PreparedTokens {
#[must_use]
pub fn tokens(&self) -> &[u32] {
&self.tokens
}
#[must_use]
pub fn input_count(&self) -> usize {
self.input_count
}
}
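/// Prepares input tokens for inference, dispatching on model format.
///
/// Explicit `input_tokens` bypass tokenization entirely; with no prompt
/// either, a single fallback token (`1`, a common BOS id) is returned so
/// generation still has a starting point.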
pub fn prepare_tokens(config: &InferenceConfig, format: &ModelFormat) -> Result<PreparedTokens> {
if let Some(ref tokens) = config.input_tokens {
return Ok(PreparedTokens {
input_count: tokens.len(),
tokens: tokens.clone(),
});
}
let prompt = match config.prompt {
Some(ref p) => p.clone(),
None => {
return Ok(PreparedTokens {
tokens: vec![1u32],
input_count: 1,
})
},
};
match format {
ModelFormat::Gguf => prepare_tokens_gguf(config, &prompt),
ModelFormat::SafeTensors => prepare_tokens_safetensors(config, &prompt),
ModelFormat::Apr => prepare_tokens_apr(config, &prompt),
}
}
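/// GGUF path: chat-formats the prompt when the file embeds a
/// `tokenizer.chat_template` or the filename marks an instruct/chat model,
/// then encodes with the GGUF-embedded tokenizer and applies BOS policy.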
fn prepare_tokens_gguf(config: &InferenceConfig, prompt: &str) -> Result<PreparedTokens> {
use crate::chat_template::{format_messages, ChatMessage};
use crate::gguf::{GGUFValue, MappedGGUFModel};
let mapped = MappedGGUFModel::from_path(&config.model_path)?;
let gguf_arch = mapped.model.architecture().unwrap_or("transformer");
let has_chat_template = mapped
.model
.metadata
.get("tokenizer.chat_template")
.is_some_and(|v| matches!(v, GGUFValue::String(s) if !s.is_empty()));
let model_name = config
.model_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("");
let filename_instruct = model_name.to_lowercase().contains("instruct")
|| model_name.to_lowercase().contains("-chat");
let formatted_prompt = if has_chat_template || filename_instruct {
let template_hint = apr_arch_to_template_hint(gguf_arch, model_name);
let messages = vec![ChatMessage::user(prompt)];
format_messages(&messages, Some(template_hint)).unwrap_or_else(|_| prompt.to_string())
} else {
prompt.to_string()
};
if config.verbose {
eprintln!(
"[DEBUG] has_chat_template={}, filename_instruct={}",
has_chat_template, filename_instruct
);
        // Truncate on a char boundary: byte-slicing at a fixed offset can
        // panic inside a multi-byte UTF-8 sequence.
        let preview: String = formatted_prompt.chars().take(200).collect();
        eprintln!("[DEBUG] formatted_prompt={:?}", preview);
}
let mut tokens = mapped.model.encode(&formatted_prompt).ok_or_else(|| {
RealizarError::InferenceError(format!(
"Tokenizer encode failed for GGUF model (no tokenizer data in GGUF file?). \
Prompt length: {} chars",
formatted_prompt.len()
))
})?;
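    // BOS policy: honor the explicit metadata flag when present; otherwise
    // fall back to an architecture heuristic (architectures with absolute
    // positional encodings conventionally omit BOS).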
let add_bos = match mapped
.model
.metadata
.get(crate::gguf::keys::TOKENIZER_ADD_BOS)
{
Some(GGUFValue::Bool(b)) => *b,
_ => {
            let arch = match mapped.model.metadata.get(crate::gguf::keys::GENERAL_ARCHITECTURE) {
                Some(GGUFValue::String(s)) => s.as_str(),
                _ => "unknown",
            };
let constraints = crate::gguf::ArchConstraints::from_architecture(arch);
constraints.positional_encoding != crate::gguf::PositionalEncoding::Absolute
},
};
if add_bos {
if let Some(bos_id) = mapped.model.bos_token_id() {
if tokens.first() != Some(&bos_id) {
tokens.insert(0, bos_id);
}
}
}
if config.verbose {
eprintln!(
"[DEBUG] add_bos={}, encoded {} tokens: {:?}",
add_bos,
tokens.len(),
&tokens[..tokens.len().min(30)]
);
}
Ok(PreparedTokens {
input_count: tokens.len(),
tokens,
})
}
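/// SafeTensors path: reads the architecture from the sibling config file
/// and chat-formats the prompt for known instruct architectures before
/// encoding via the sibling `tokenizer.json`.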
fn prepare_tokens_safetensors(config: &InferenceConfig, prompt: &str) -> Result<PreparedTokens> {
use crate::apr::AprV2Model;
use crate::chat_template::{format_messages, ChatMessage};
use crate::safetensors::SafetensorsConfig;
let st_config = SafetensorsConfig::load_from_sibling(&config.model_path);
let architecture = st_config
.as_ref()
.map(SafetensorsConfig::architecture)
.unwrap_or_default();
let model_name = config
.model_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("");
let arch_lower = architecture.to_lowercase();
let is_instruct = arch_lower.contains("instruct")
|| model_name.to_lowercase().contains("instruct")
|| matches!(
arch_lower.as_str(),
"qwen2forcausallm" | "llamaforcausallm" | "mistralforcausallm" | "phiforcausallm"
);
let formatted_prompt = if is_instruct {
let template_hint = safetensors_arch_to_template_hint(&architecture, model_name);
let messages = vec![ChatMessage::user(prompt)];
format_messages(&messages, Some(template_hint)).unwrap_or_else(|_| prompt.to_string())
} else {
prompt.to_string()
};
let tokens =
AprV2Model::encode_text(&config.model_path, &formatted_prompt).ok_or_else(|| {
RealizarError::InferenceError(format!(
"Tokenizer encode failed for SafeTensors model (no tokenizer.json sibling?). \
Prompt length: {} chars",
formatted_prompt.len()
))
})?;
Ok(PreparedTokens {
input_count: tokens.len(),
tokens,
})
}
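/// APR path: probes the APR metadata for the architecture and for the
/// ChatML marker token `<|im_start|>` in the embedded vocabulary to decide
/// whether the prompt should be chat-formatted before encoding.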
fn prepare_tokens_apr(config: &InferenceConfig, prompt: &str) -> Result<PreparedTokens> {
use crate::apr::AprV2Model;
use crate::chat_template::{format_messages, ChatMessage};
let model_name = config
.model_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("");
let (apr_arch, has_chatml_tokens) =
if config.model_path.extension().is_some_and(|e| e == "apr") {
match AprV2Model::load(&config.model_path) {
Ok(model) => {
let arch = model.metadata().architecture.clone().unwrap_or_default();
let has_chatml = model.metadata().get_embedded_vocabulary().is_some_and(
|vocab: Vec<String>| vocab.iter().any(|t| t == "<|im_start|>"),
);
(arch, has_chatml)
},
Err(_) => (String::new(), false),
}
} else {
(String::new(), false)
};
let is_instruct_arch = matches!(
apr_arch.to_lowercase().as_str(),
"qwen2" | "qwen" | "llama" | "mistral" | "phi" | "phi3"
);
let filename_instruct = model_name.to_lowercase().contains("instruct");
let is_instruct = is_instruct_arch || has_chatml_tokens || filename_instruct;
let formatted_prompt = if is_instruct {
let template_hint = apr_arch_to_template_hint(&apr_arch, model_name);
let messages = vec![ChatMessage::user(prompt)];
format_messages(&messages, Some(template_hint)).unwrap_or_else(|_| prompt.to_string())
} else {
prompt.to_string()
};
let tokens =
AprV2Model::encode_text(&config.model_path, &formatted_prompt).ok_or_else(|| {
RealizarError::InferenceError(format!(
"Tokenizer encode failed for APR model (no tokenizer in APR metadata?). \
Prompt length: {} chars",
formatted_prompt.len()
))
})?;
Ok(PreparedTokens {
input_count: tokens.len(),
tokens,
})
}
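/// Maps a SafeTensors architecture string to the normalized name used as a
/// chat-template hint by `format_messages`.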
fn safetensors_arch_to_template_hint(architecture: &str, _model_name: &str) -> &'static str {
crate::tensor_names::normalize_architecture(architecture)
}
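// The rest of this module (result types, GPU generation, EOS logging, and
// batching) is split across sibling files and spliced in with `include!`.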
include!("inference_result.rs");
include!("gguf_gpu_generate.rs");
include!("mod_log_transformer_eos.rs");
include!("mod_05.rs");
include!("batch.rs");