use crate::error::{AumateError, Result};
use candle_core::Tensor;
use tokenizers::Tokenizer;
/// Options that steer token selection during decoding.
#[derive(Debug, Clone)]
pub struct DecodingConfig {
    /// Token budget for a decoding run (not consumed in this file;
    /// presumably enforced by the decoding loop — verify against caller).
    pub max_tokens: usize,
    /// Divisor applied to the logits before softmax when sampling.
    pub temperature: f64,
    /// Optional cumulative-probability threshold for nucleus (top-p)
    /// filtering; `None` disables the filter.
    pub top_p: Option<f64>,
    /// End-of-sequence token id (presumably the stop signal for the
    /// decoding loop — verify against caller).
    pub eos_token_id: u32,
    /// Optional beginning-of-sequence token id.
    pub bos_token_id: Option<u32>,
    /// Optional padding token id.
    pub pad_token_id: Option<u32>,
    /// When `true`, the highest-scoring token is always chosen and
    /// `temperature` / `top_p` are not consulted.
    pub greedy: bool,
}

impl Default for DecodingConfig {
    /// Greedy decoding, 448-token budget, temperature 1.0, no top-p filter,
    /// end-of-sequence id 50257 and no BOS / PAD ids.
    fn default() -> Self {
        DecodingConfig {
            max_tokens: 448,
            temperature: 1.0,
            top_p: None,
            eos_token_id: 50257,
            bos_token_id: None,
            pad_token_id: None,
            greedy: true,
        }
    }
}

impl DecodingConfig {
    /// Builds a greedy configuration with the given token budget and
    /// end-of-sequence id; every other field keeps its default value.
    pub fn greedy(max_tokens: usize, eos_token_id: u32) -> Self {
        let base = Self::default();
        Self {
            greedy: true,
            eos_token_id,
            max_tokens,
            ..base
        }
    }

    /// Builds a sampling (non-greedy) configuration that uses `temperature`;
    /// every other field keeps its default value.
    pub fn with_temperature(max_tokens: usize, eos_token_id: u32, temperature: f64) -> Self {
        let base = Self::default();
        Self {
            greedy: false,
            temperature,
            eos_token_id,
            max_tokens,
            ..base
        }
    }
}
/// Thin wrapper around a [`Tokenizer`] that maps its errors into
/// [`AumateError::Other`].
pub struct TextDecoder {
    /// The wrapped tokenizer; every method delegates to it.
    tokenizer: Tokenizer,
}

impl TextDecoder {
    /// Loads a tokenizer definition from `path` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns [`AumateError::Other`] when the tokenizer cannot be loaded.
    pub fn from_file(path: &std::path::Path) -> Result<Self> {
        Tokenizer::from_file(path)
            .map(|tokenizer| Self { tokenizer })
            .map_err(|err| AumateError::Other(format!("Failed to load tokenizer: {}", err)))
    }

    /// Wraps an already constructed tokenizer.
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self { tokenizer }
    }

    /// Decodes a single token id into text.
    ///
    /// Passes `false` as the tokenizer's second `decode` argument
    /// (skip-special-tokens in the `tokenizers` API), unlike [`Self::decode`].
    ///
    /// # Errors
    ///
    /// Returns [`AumateError::Other`] when the tokenizer reports a failure.
    pub fn decode_token(&self, token_id: u32) -> Result<String> {
        let single = [token_id];
        let decoded = self.tokenizer.decode(&single, false);
        decoded.map_err(|err| AumateError::Other(format!("Failed to decode token: {}", err)))
    }

    /// Decodes a sequence of token ids into text.
    ///
    /// Passes `true` as the tokenizer's second `decode` argument
    /// (skip-special-tokens in the `tokenizers` API).
    ///
    /// # Errors
    ///
    /// Returns [`AumateError::Other`] when the tokenizer reports a failure.
    pub fn decode(&self, token_ids: &[u32]) -> Result<String> {
        let decoded = self.tokenizer.decode(token_ids, true);
        decoded.map_err(|err| AumateError::Other(format!("Failed to decode tokens: {}", err)))
    }

    /// Encodes `text` and returns the resulting token ids.
    ///
    /// Passes `false` as the tokenizer's second `encode` argument
    /// (add-special-tokens in the `tokenizers` API).
    ///
    /// # Errors
    ///
    /// Returns [`AumateError::Other`] when the tokenizer reports a failure.
    pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
        self.tokenizer
            .encode(text, false)
            .map(|encoding| encoding.get_ids().to_vec())
            .map_err(|err| AumateError::Other(format!("Failed to encode text: {}", err)))
    }

    /// Vocabulary size as reported by the tokenizer's `get_vocab_size(true)`
    /// (the flag requests that added tokens be counted in the `tokenizers` API).
    pub fn vocab_size(&self) -> usize {
        let include_added_tokens = true;
        self.tokenizer.get_vocab_size(include_added_tokens)
    }

    /// Looks up the id of `token`, or `None` if the tokenizer does not know it.
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token)
    }

    /// Looks up the token string for `id`, or `None` if the tokenizer has no
    /// entry for it.
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_token(id)
    }
}
#[allow(dead_code)]
pub fn sample_from_logits(logits: &Tensor, config: &DecodingConfig) -> Result<u32> {
let logits =
logits.squeeze(0).map_err(|e| AumateError::Ml(format!("Squeeze failed: {}", e)))?;
let logits =
logits.squeeze(0).map_err(|e| AumateError::Ml(format!("Squeeze failed: {}", e)))?;
if config.greedy {
let token_id = logits
.argmax(0)
.map_err(|e| AumateError::Ml(format!("Argmax failed: {}", e)))?
.to_scalar::<u32>()
.map_err(|e| AumateError::Ml(format!("To scalar failed: {}", e)))?;
return Ok(token_id);
}
let logits = if config.temperature != 1.0 {
(logits / config.temperature)
.map_err(|e| AumateError::Ml(format!("Temperature scaling failed: {}", e)))?
} else {
logits
};
let probs = candle_nn::ops::softmax(&logits, 0)
.map_err(|e| AumateError::Ml(format!("Softmax failed: {}", e)))?;
let probs = if let Some(top_p) = config.top_p { apply_top_p(&probs, top_p)? } else { probs };
let probs_vec: Vec<f32> =
probs.to_vec1().map_err(|e| AumateError::Ml(format!("To vec failed: {}", e)))?;
let token_id = sample_from_distribution(&probs_vec)?;
Ok(token_id)
}
/// Applies nucleus (top-p) filtering to a one-dimensional probability tensor.
///
/// Entries are ranked from most to least probable; the shortest prefix of
/// that ranking whose cumulative probability reaches `top_p` is kept (all
/// entries if the threshold is never reached). Everything else is zeroed and
/// the survivors are rescaled to sum to one, unless their sum is not positive.
/// The result has the same shape and device as `probs`.
///
/// # Errors
///
/// Returns [`AumateError::Ml`] when the tensor cannot be read as a vector of
/// `f32` or the output tensor cannot be built.
#[allow(dead_code)]
fn apply_top_p(probs: &Tensor, top_p: f64) -> Result<Tensor> {
    let original: Vec<f32> =
        probs.to_vec1().map_err(|e| AumateError::Ml(format!("To vec failed: {}", e)))?;

    // Pair each probability with its position, then rank by probability,
    // highest first. Incomparable values (NaN) are treated as equal.
    let mut ranked: Vec<(usize, f32)> = original.iter().copied().enumerate().collect();
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));

    // Number of leading entries to keep: up to and including the first one at
    // which the running total reaches `top_p`.
    let mut mass = 0.0f64;
    let keep = ranked
        .iter()
        .position(|&(_, p)| {
            mass += p as f64;
            mass >= top_p
        })
        .map_or(ranked.len(), |pos| pos + 1);

    // Scatter the survivors back to their original positions.
    let mut filtered = vec![0.0f32; original.len()];
    for &(slot, p) in &ranked[..keep] {
        filtered[slot] = p;
    }

    // Renormalise so the kept entries form a distribution again.
    let total: f32 = filtered.iter().sum();
    if total > 0.0 {
        filtered.iter_mut().for_each(|p| *p /= total);
    }

    Tensor::from_vec(filtered, probs.shape(), probs.device())
        .map_err(|e| AumateError::Ml(format!("Failed to create tensor: {}", e)))
}
#[allow(dead_code)]
fn sample_from_distribution(probs: &[f32]) -> Result<u32> {
use rand::Rng;
let mut rng = rand::rng();
let r: f32 = rng.random();
let mut cumsum = 0.0;
for (i, &p) in probs.iter().enumerate() {
cumsum += p;
if r < cumsum {
return Ok(i as u32);
}
}
Ok((probs.len() - 1) as u32)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration decodes greedily with a 448-token budget.
    #[test]
    fn test_decoding_config_default() {
        let DecodingConfig { greedy, max_tokens, .. } = DecodingConfig::default();
        assert!(greedy);
        assert_eq!(max_tokens, 448);
    }

    /// `DecodingConfig::greedy` stores its arguments and enables greedy mode.
    #[test]
    fn test_decoding_config_greedy() {
        let (budget, eos) = (100, 50256);
        let cfg = DecodingConfig::greedy(budget, eos);
        assert!(cfg.greedy);
        assert_eq!(cfg.max_tokens, 100);
        assert_eq!(cfg.eos_token_id, 50256);
    }

    /// A distribution with all its mass on one index always yields that index.
    #[test]
    fn test_sample_from_distribution() {
        let one_hot = [0.0f32, 0.0, 1.0, 0.0];
        let picked = sample_from_distribution(&one_hot).unwrap();
        assert_eq!(picked, 2);
    }
}