use crate::{config::ModelConfig, gpt::GPTModel, weights::ModelWeights, Result};
use dotenv::dotenv;
use hf_hub::api::sync::ApiBuilder;
use serde_json::Value;
use std::env;
use tokenizers::Tokenizer;
/// End-to-end text-generation pipeline: fetches model artifacts from the
/// Hugging Face Hub and wires together the config, tokenizer, weights, and
/// model graph behind a single handle.
pub struct InferenceEngine {
// Model architecture instantiated from the repo's config.json dimensions.
model: GPTModel,
// Parameters loaded from the repo's model.safetensors file.
weights: ModelWeights,
// Tokenizer loaded from the repo's tokenizer.json file.
tokenizer: Tokenizer,
// Parsed Hugging Face config; exposed read-only via `config()`.
config: ModelConfig,
}
impl InferenceEngine {
    /// Builds an engine for `model_name`, reading an optional Hugging Face
    /// token from the `HF_TOKEN` environment variable (or a `.env` file).
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_token(model_name, None)
    }

    /// Builds an engine for `model_name`, downloading (or reusing from the
    /// local cache) `tokenizer.json`, `model.safetensors`, and `config.json`
    /// from the Hub.
    ///
    /// An explicit `token` takes precedence over the `HF_TOKEN` environment
    /// variable; either allows access to gated or private repositories.
    ///
    /// # Errors
    /// Returns an error if any artifact cannot be downloaded, if the config
    /// or tokenizer cannot be parsed, or if the weights fail to load.
    pub fn new_with_token(model_name: &str, token: Option<String>) -> Result<Self> {
        // Best effort: a missing .env file is not an error.
        dotenv().ok();
        let token = token.or_else(|| env::var("HF_TOKEN").ok());

        let api = ApiBuilder::new().with_token(token).build()?;
        let repo = api.model(model_name.to_string());

        let tokenizer_path = repo.get("tokenizer.json")?;
        let model_path = repo.get("model.safetensors")?;
        let config_path = repo.get("config.json")?;

        let config_str = std::fs::read_to_string(config_path)?;
        let config_json: Value = serde_json::from_str(&config_str)?;
        let config = ModelConfig::from_hf_config(&config_json)
            .map_err(|e| format!("Failed to parse config: {}", e))?;

        let tokenizer = Tokenizer::from_file(tokenizer_path)?;
        let weights = ModelWeights::load_from_safetensors(&model_path)
            .map_err(|e| format!("Failed to load weights: {}", e))?;

        let model = GPTModel::new(
            config.num_layers as usize,
            config.vocab_size as usize,
            config.max_position_embeddings as usize,
            config.hidden_size as usize,
            config.num_attention_heads as usize,
            // Conventional GPT feed-forward expansion: 4 * hidden_size.
            config.hidden_size as usize * 4,
        );

        Ok(Self {
            model,
            weights,
            tokenizer,
            config,
        })
    }

    /// Generates up to `max_tokens` new tokens conditioned on `prompt` and
    /// returns the decoded text.
    ///
    /// # Errors
    /// Returns an error if the prompt cannot be encoded or the produced
    /// token ids cannot be decoded.
    pub fn generate(&self, prompt: &str, max_tokens: usize) -> Result<String> {
        // Delegate to tokenize()/decode() so the encode/decode settings
        // (no added special tokens; special tokens skipped on output) are
        // defined in exactly one place.
        let token_ids = self.tokenize(prompt)?;
        let generated_tokens = self.model.generate(&token_ids, max_tokens, &self.weights);
        self.decode(&generated_tokens)
    }

    /// The model configuration parsed from the repo's `config.json`.
    pub fn config(&self) -> &ModelConfig {
        &self.config
    }

    /// Encodes `text` into token ids, without adding special tokens.
    ///
    /// # Errors
    /// Returns an error if the tokenizer fails to encode the input.
    pub fn tokenize(&self, text: &str) -> Result<Vec<u32>> {
        let encoded = self.tokenizer.encode(text, false)?;
        Ok(encoded.get_ids().to_vec())
    }

    /// Decodes `tokens` back into text, skipping special tokens.
    ///
    /// # Errors
    /// Returns an error if the tokenizer fails to decode the ids.
    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
        Ok(self.tokenizer.decode(tokens, true)?)
    }
}