gllm 0.10.6

Pure Rust library for local embeddings, reranking, and text generation with MoE-optimized inference and aggressive performance tuning
Documentation
use burn::tensor::activation::softmax;
use burn::tensor::backend::Backend;
use burn::tensor::{ElementConversion, Int, Tensor};
use rand::distributions::WeightedIndex;
use rand::prelude::*;

pub struct SamplingConfig {
    pub temperature: f32,
    pub top_p: f32,
    pub top_k: usize,
}

pub fn sample_next_token<B: Backend>(
    logits: Tensor<B, 2>,
    config: &SamplingConfig,
    device: &B::Device,
) -> Vec<i64> {
    let [batch_size, vocab_size] = logits.dims();
    if config.temperature <= 0.0 {
        let indices = match logits.argmax(1).into_data().into_vec::<B::IntElem>() {
            Ok(data) => data,
            Err(_) => return vec![0; batch_size],
        };
        return indices.into_iter().map(|v| v.elem::<i64>()).collect();
    }

    let logits = if (config.temperature - 1.0).abs() > f32::EPSILON {
        logits / config.temperature
    } else {
        logits
    };

    let apply_top_p = config.top_p > 0.0 && config.top_p < 1.0;
    let k = if config.top_k > 0 {
        config.top_k.min(vocab_size)
    } else {
        vocab_size
    };

    let (values, indices) = if config.top_k > 0 || apply_top_p {
        logits.topk_with_indices(k, 1)
    } else {
        let indices = Tensor::<B, 1, Int>::arange(0..vocab_size as i64, device)
            .unsqueeze_dim::<2>(0)
            .repeat(&[batch_size, 1]);
        (logits, indices)
    };
    let probs = softmax(values, 1);

    let prob_data = match probs.into_data().into_vec::<f32>() {
        Ok(data) => data,
        Err(_) => return vec![0; batch_size],
    };
    let index_data = match indices.into_data().into_vec::<B::IntElem>() {
        Ok(data) => data,
        Err(_) => return vec![0; batch_size],
    };

    let mut rng = thread_rng();
    let mut outputs = Vec::with_capacity(batch_size);
    let stride = if config.top_k > 0 || apply_top_p {
        k
    } else {
        vocab_size
    };

    for batch in 0..batch_size {
        let start = batch * stride;
        let end = start + stride;
        let row_probs = &prob_data[start..end];
        let row_indices = &index_data[start..end];
        let mut candidates: Vec<(i64, f32)> = row_indices
            .iter()
            .zip(row_probs.iter())
            .map(|(idx, &prob)| (idx.elem::<i64>(), prob))
            .collect();

        if apply_top_p {
            let mut filtered = Vec::new();
            let mut filtered_weights = Vec::new();
            let mut cumulative = 0.0;

            for (token, weight) in candidates.iter() {
                filtered.push(*token);
                filtered_weights.push(*weight);
                cumulative += *weight;
                if cumulative >= config.top_p {
                    break;
                }
            }

            candidates = filtered.into_iter().map(|id| (id, 0.0)).collect();
            let weights = filtered_weights;

            let sample = WeightedIndex::new(&weights)
                .ok()
                .map(|dist| candidates[dist.sample(&mut rng)].0)
                .unwrap_or_else(|| candidates.first().map(|(id, _)| *id).unwrap_or(0));
            outputs.push(sample);
            continue;
        }

        let weights: Vec<f32> = candidates.iter().map(|(_, prob)| *prob).collect();
        let sample = WeightedIndex::new(&weights)
            .ok()
            .map(|dist| candidates[dist.sample(&mut rng)].0)
            .unwrap_or_else(|| candidates.first().map(|(id, _)| *id).unwrap_or(0));

        outputs.push(sample);
    }

    outputs
}