use crate::error::{CuttleError, Result};
use crate::model::{Model, ModelConfig};
use crate::tensor::Tensor3D;
use crate::tokenizer::Tokenizer;
use log::{debug, info};
use std::path::Path;
use vectra::CmpExt;
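/// Sampling and decoding parameters for text generation.
///
/// Illustrative construction (a sketch; the values shown are arbitrary and
/// rely only on the `Default` impl below):
///
/// ```ignore
/// let config = InferenceConfig {
///     temperature: 0.7,
///     top_p: 0.95,
///     do_sample: true,
///     ..InferenceConfig::default()
/// };
/// ```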
#[derive(Debug, Clone)]
pub struct InferenceConfig {
pub max_length: usize,
pub temperature: f32,
pub top_p: f32,
pub top_k: usize,
pub do_sample: bool,
pub repetition_penalty: f32,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self {
max_length: 512,
temperature: 1.0,
top_p: 0.9,
top_k: 50,
do_sample: true,
repetition_penalty: 1.1,
}
}
}
#[derive(Debug)]
pub struct InferenceEngine {
model: Model,
tokenizer: Tokenizer,
config: InferenceConfig,
}
impl InferenceEngine {
pub fn new(model: Model, tokenizer: Tokenizer) -> Self {
Self {
model,
tokenizer,
config: InferenceConfig::default(),
}
}
pub fn with_config(model: Model, tokenizer: Tokenizer, config: InferenceConfig) -> Self {
Self {
model,
tokenizer,
config,
}
}
pub fn from_config_files<P1, P2>(model_config_path: P1, tokenizer_path: P2) -> Result<Self>
where
P1: AsRef<Path>,
P2: AsRef<Path>,
{
info!(
"Loading model from config file: {:?}",
model_config_path.as_ref()
);
let model = Model::from_config_file(model_config_path)?;
info!("Loading tokenizer from file: {:?}", tokenizer_path.as_ref());
let tokenizer = Tokenizer::load(tokenizer_path)?;
Ok(Self::new(model, tokenizer))
}
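    /// Generates a continuation for `prompt` with the configured sampling
    /// strategy, returning only the newly generated text (the prompt is not
    /// echoed back).
    ///
    /// Illustrative usage (a sketch; the file paths are placeholders and the
    /// output depends entirely on the loaded model):
    ///
    /// ```ignore
    /// let engine = InferenceEngine::from_config_files("model.json", "tokenizer.json")?;
    /// let text = engine.generate("Once upon a time")?;
    /// println!("{}", text);
    /// ```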
pub fn generate(&self, prompt: &str) -> Result<String> {
info!("Starting text generation for prompt: {}", prompt);
let input_ids = self.tokenizer.encode(prompt)?;
debug!("Encoded input IDs: {:?}", input_ids);
if input_ids.is_empty() {
return Err(CuttleError::InferenceError(
"Empty input after tokenization".to_string(),
));
}
let generated_ids = self.generate_tokens(&input_ids)?;
let generated_text = self.tokenizer.decode(&generated_ids)?;
info!("Generated text: {}", generated_text);
Ok(generated_text)
}
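    /// Autoregressive decode loop: each step runs the model over the current
    /// sequence, adjusts the last-position logits (temperature, repetition
    /// penalty), filters them (top-k / top-p), and samples one token. The
    /// loop stops at the EOS token or once the sequence reaches `max_length`
    /// tokens in total; only the newly generated token IDs are returned.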
fn generate_tokens(&self, input_ids: &[usize]) -> Result<Vec<usize>> {
let mut current_ids = input_ids.to_vec();
let max_new_tokens = self.config.max_length.saturating_sub(input_ids.len());
debug!("Generating up to {} new tokens", max_new_tokens);
for step in 0..max_new_tokens {
debug!("Generation step {}/{}", step + 1, max_new_tokens);
            let logits = self.model.forward(&current_ids)?;
let last_logits = self.extract_last_logits(&logits);
            let processed_logits = self.process_logits(&last_logits, &current_ids)?;
let next_token = self.sample_next_token(&processed_logits)?;
if let Some(eos_id) = self.tokenizer.eos_token_id() {
if next_token == eos_id {
debug!("Generated EOS token, stopping generation");
break;
}
}
current_ids.push(next_token);
debug!("Generated token: {}", next_token);
}
Ok(current_ids[input_ids.len()..].to_vec())
}
    fn extract_last_logits(&self, logits: &Tensor3D) -> Tensor3D {
        // Placeholder: a full implementation would slice out the logits for
        // the final sequence position (shape [1, 1, vocab_size]). Until the
        // tensor slicing API is available, this returns a random tensor of
        // the right shape so the generation loop stays runnable.
        let shape = logits.shape();
        let vocab_size = shape[2];
        Tensor3D::randn([1, 1, vocab_size])
    }
fn process_logits(&self, logits: &Tensor3D, generated_ids: &[usize]) -> Result<Tensor3D> {
let mut processed = logits.clone();
if self.config.temperature != 1.0 {
processed = processed.mul_scalar(1.0 / self.config.temperature);
}
if self.config.repetition_penalty != 1.0 {
processed = self.apply_repetition_penalty(&processed, generated_ids)?;
}
Ok(processed)
}
    fn apply_repetition_penalty(
        &self,
        logits: &Tensor3D,
        generated_ids: &[usize],
    ) -> Result<Tensor3D> {
        // Placeholder: the penalty is logged but not yet applied to the logit
        // values. A full implementation would divide the logit of every
        // previously generated token by `repetition_penalty` (and multiply
        // negative logits by it) to discourage repeats.
        let penalized = logits.clone();
        let vocab_size = logits.shape()[2];
        for &token_id in generated_ids {
            if token_id < vocab_size {
                debug!("Applying repetition penalty to token {}", token_id);
            }
        }
        Ok(penalized)
    }
fn sample_next_token(&self, logits: &Tensor3D) -> Result<usize> {
if !self.config.do_sample {
return self.greedy_sample(logits);
}
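        // Note: a full implementation would apply softmax here so that the
        // values passed on to filtering and sampling form a probability
        // distribution; currently the raw logits are forwarded unchanged.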
let probs = logits.clone();
let filtered_probs = self.apply_top_k_top_p_filtering(&probs)?;
self.multinomial_sample(&filtered_probs)
}
fn greedy_sample(&self, logits: &Tensor3D) -> Result<usize> {
let max_idx = logits
.data()
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.cmp_ext(b))
.map(|(idx, _)| idx)
.ok_or_else(|| CuttleError::InferenceError("Empty logits tensor".to_string()))?;
Ok(max_idx)
}
    fn apply_top_k_top_p_filtering(&self, probs: &Tensor3D) -> Result<Tensor3D> {
        // Placeholder: top-k would keep only the k highest-scoring tokens and
        // top-p (nucleus) would keep the smallest set whose cumulative
        // probability exceeds p; neither filter modifies the values yet.
        let filtered = probs.clone();
        debug!(
            "Applying top_k={}, top_p={} filtering",
            self.config.top_k, self.config.top_p
        );
        Ok(filtered)
    }
    fn multinomial_sample(&self, probs: &Tensor3D) -> Result<usize> {
        // Placeholder: deterministically picks the middle index rather than
        // drawing from the distribution; a real multinomial sample would walk
        // the cumulative probabilities with a uniform random draw.
        let data = probs.data();
        let random_val = (data.len() as f32 * 0.5) as usize % data.len();
        Ok(random_val)
    }
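    /// Generates one completion per prompt, processing the prompts
    /// sequentially. Illustrative call (a sketch; the outputs depend on the
    /// loaded model):
    ///
    /// ```ignore
    /// let prompts = vec!["Hello".to_string(), "Goodbye".to_string()];
    /// let outputs = engine.generate_batch(&prompts)?;
    /// assert_eq!(outputs.len(), 2);
    /// ```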
pub fn generate_batch(&self, prompts: &[String]) -> Result<Vec<String>> {
prompts.iter().map(|prompt| self.generate(prompt)).collect()
}
pub fn set_config(&mut self, config: InferenceConfig) {
self.config = config;
}
pub fn config(&self) -> &InferenceConfig {
&self.config
}
pub fn model(&self) -> &Model {
&self.model
}
pub fn tokenizer(&self) -> &Tokenizer {
&self.tokenizer
}
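    /// Computes the perplexity of `text` under the model:
    /// `exp(-(1/N) * sum(ln p(token_i | tokens_1..i)))` over the `N` scored
    /// tokens, so lower values indicate a better fit. Requires at least two
    /// tokens after encoding.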
pub fn perplexity(&self, text: &str) -> Result<f32> {
let input_ids = self.tokenizer.encode(text)?;
if input_ids.len() < 2 {
return Err(CuttleError::InferenceError(
"Text too short for perplexity calculation".to_string(),
));
}
let mut total_log_prob = 0.0;
let mut count = 0;
for i in 1..input_ids.len() {
let context = &input_ids[..i];
let target = input_ids[i];
let logits = self.model.forward(context)?;
let last_logits = self.extract_last_logits(&logits);
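            // Note: a softmax over the logits is still missing here, so the
            // values treated as probabilities below are raw placeholder
            // logit values.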
let probs = last_logits;
let prob_data = probs.data();
if target < prob_data.len() {
                let prob = prob_data[target].max(1e-10);
                total_log_prob += prob.ln();
count += 1;
}
}
if count == 0 {
return Err(CuttleError::InferenceError(
"No valid tokens for perplexity".to_string(),
));
}
let avg_log_prob = total_log_prob / count as f32;
Ok((-avg_log_prob).exp())
}
pub fn model_info(&self) -> ModelInfo {
let config = self.model.config();
ModelInfo {
vocab_size: config.vocab_size,
hidden_size: config.hidden_size,
num_layers: config.num_layers,
num_attention_heads: config.num_attention_heads,
max_position_embeddings: config.max_position_embeddings,
tokenizer_vocab_size: self.tokenizer.vocab_size(),
}
}
}
#[derive(Debug, Clone)]
pub struct ModelInfo {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_layers: usize,
pub num_attention_heads: usize,
pub max_position_embeddings: usize,
pub tokenizer_vocab_size: usize,
}
impl std::fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Model Info:")?;
        writeln!(f, " Vocabulary Size: {}", self.vocab_size)?;
        writeln!(f, " Hidden Size: {}", self.hidden_size)?;
        writeln!(f, " Number of Layers: {}", self.num_layers)?;
        writeln!(f, " Attention Heads: {}", self.num_attention_heads)?;
        writeln!(f, " Max Position Embeddings: {}", self.max_position_embeddings)?;
        write!(f, " Tokenizer Vocab Size: {}", self.tokenizer_vocab_size)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::create_default_tokenizer;
#[test]
fn test_inference_engine_creation() {
let model_config = ModelConfig::default();
let model = Model::new(model_config).unwrap();
let tokenizer = create_default_tokenizer();
let engine = InferenceEngine::new(model, tokenizer);
assert_eq!(engine.config().max_length, 512);
}
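    // A second test sketched under the same assumptions as the one above
    // (namely that `Model::new(ModelConfig::default())` and
    // `create_default_tokenizer()` succeed); it exercises only the config
    // override path.
    #[test]
    fn test_inference_engine_with_custom_config() {
        let model_config = ModelConfig::default();
        let model = Model::new(model_config).unwrap();
        let tokenizer = create_default_tokenizer();
        let config = InferenceConfig {
            max_length: 64,
            do_sample: false,
            ..InferenceConfig::default()
        };
        let engine = InferenceEngine::with_config(model, tokenizer, config);
        assert_eq!(engine.config().max_length, 64);
        assert!(!engine.config().do_sample);
    }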
}