//! model-rs 0.1.1
//!
//! A Rust CLI tool for downloading Hugging Face models and running local LLM inference.
use crate::error::{ModelError, Result};
use std::time::{SystemTime, UNIX_EPOCH};

/// Picks the next token id from raw `logits` using temperature, top-k and top-p sampling.
///
/// The steps are: divide every logit by `temperature`, keep the `top_k` highest-scoring
/// candidates, apply a numerically stable softmax over them, trim the candidates to the
/// smallest prefix whose cumulative probability reaches `top_p`, then draw one candidate
/// at random according to the remaining probabilities.
///
/// # Arguments
///
/// * `logits` - Unnormalized scores; the returned id is an index into this slice.
/// * `temperature` - Values `<= 0.0` select greedy decoding (the arg-max of `logits`,
///   the last one on ties). Larger values flatten the distribution.
/// * `top_p` - Nucleus threshold. Values `>= 1.0` disable nucleus filtering; values
///   `<= 0.0` keep only the single best candidate.
/// * `top_k` - Maximum number of candidates to keep. `None` and `Some(0)` both keep the
///   whole vocabulary; values larger than `logits.len()` are clamped.
///
/// # Errors
///
/// Returns `ModelError::LocalModelError` when `logits` is empty, when the selected index
/// does not fit in a `u32`, or when the system clock reports a time before the Unix epoch.
///
/// # Notes
///
/// The random draw is seeded from the wall clock on every call. It is neither
/// reproducible nor cryptographically secure.
pub fn do_sample(
    logits: &[f32],
    temperature: f32,
    top_p: f32,
    top_k: Option<usize>,
) -> Result<u32> {
    // Converts a vocabulary index into a token id, failing instead of silently truncating.
    fn token_id(index: usize) -> Result<u32> {
        // Explicit import keeps this working on editions where `TryFrom` is not in the prelude.
        use std::convert::TryFrom;
        u32::try_from(index).map_err(|_| {
            ModelError::LocalModelError(format!("Token index {} does not fit in u32", index))
        })
    }

    if logits.is_empty() {
        return Err(ModelError::LocalModelError(
            "Cannot sample from empty logits".to_string(),
        ));
    }
    let vocab_size = logits.len();

    // Greedy decoding. A negative temperature would invert the distribution, so it is
    // treated the same way as zero.
    if temperature <= 0.0 {
        let max_idx = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(0);
        return token_id(max_idx);
    }

    // NaN scores are mapped to negative infinity so they sort last and get zero probability.
    let scaled: Vec<f32> = logits
        .iter()
        .map(|&logit| {
            let value = logit / temperature;
            if value.is_nan() {
                f32::NEG_INFINITY
            } else {
                value
            }
        })
        .collect();

    // `Some(0)` would leave no candidates at all; treat it as "no limit".
    let keep = match top_k {
        Some(k) if k > 0 => k.min(vocab_size),
        _ => vocab_size,
    };

    // Descending by score. `total_cmp` is a total order, which `sort_by` requires; the
    // sort is stable, so equal scores keep their original index order.
    let mut sorted_indices: Vec<usize> = (0..vocab_size).collect();
    sorted_indices.sort_by(|&a, &b| scaled[b].total_cmp(&scaled[a]));
    sorted_indices.truncate(keep);

    // Subtracting the maximum before `exp` keeps every exponent <= 0, so the softmax
    // cannot overflow to infinity for large logits or small temperatures.
    let best_index = sorted_indices[0];
    let max_scaled = scaled[best_index];
    let mut probs: Vec<f32> = sorted_indices
        .iter()
        .map(|&i| (scaled[i] - max_scaled).exp())
        .collect();

    let sum: f32 = probs.iter().sum();
    if !sum.is_finite() || sum <= 0.0 {
        // Degenerate distribution (e.g. infinite scores): fall back to the best candidate.
        return token_id(best_index);
    }

    for prob in probs.iter_mut() {
        *prob /= sum;
    }

    if top_p < 1.0 {
        // Keep the shortest prefix whose cumulative probability reaches `top_p`.
        let mut cumulative = 0.0;
        let mut cutoff_idx = sorted_indices.len();

        for (idx, &prob) in probs.iter().enumerate() {
            cumulative += prob;
            if cumulative >= top_p {
                cutoff_idx = idx + 1;
                break;
            }
        }

        sorted_indices.truncate(cutoff_idx);
        probs.truncate(cutoff_idx);

        // Renormalize so the surviving candidates sum to one again.
        let sum: f32 = probs.iter().sum();
        if sum > 0.0 {
            for prob in probs.iter_mut() {
                *prob /= sum;
            }
        }
    }

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| ModelError::LocalModelError(format!("Failed to get time: {}", e)))?
        .as_nanos();

    // Stay in integer arithmetic: an f32 cannot represent nanoseconds since the epoch
    // precisely, so converting first would discard the fast-changing low digits.
    // The low 64 bits are scrambled with the SplitMix64 finalizer so that nearby
    // timestamps yield unrelated values.
    let mut mixed = (nanos as u64).wrapping_add(0x9E37_79B9_7F4A_7C15);
    mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    mixed ^= mixed >> 31;

    // The top 24 bits are exactly representable in an f32, giving a value in [0, 1).
    let random_value = (mixed >> 40) as f32 / (1u32 << 24) as f32;

    let mut cumulative = 0.0;
    for (idx, &prob) in probs.iter().enumerate() {
        cumulative += prob;
        if random_value <= cumulative {
            return token_id(sorted_indices[idx]);
        }
    }

    // Rounding can leave the cumulative sum slightly below `random_value`; use the last
    // remaining candidate in that case.
    token_id(sorted_indices[sorted_indices.len() - 1])
}