//! llm-samplers 0.0.3
//!
//! Token samplers for large language models.
use num_traits::Float;
use rand::Rng;

use crate::{
    rand::*,
    samplers::{rand_distrib::*, top_k::*},
    types::*,
};

/// Mirostat v2 sampling.
///
/// Maintains a running surprise threshold `mu` that is adjusted after every
/// sampled token so that the observed surprise moves toward `tau`.
pub struct SampleMirostat2<TID, L, R> {
    /// Target surprise value the `mu` update steers toward.
    tau: L,
    /// Learning rate applied when updating `mu`.
    eta: L,
    /// Current surprise threshold; updated after each sampled token.
    mu: L,
    /// Token selected by the most recent `sample` call, if any.
    token: Option<TID>,
    /// Sampler that draws the final token from the truncated distribution.
    rd_sampler: RandDistribSampler<TID, R>,
}

impl<TID: CanTokenId, L: Float, R: Rng> SampleMirostat2<TID, L, R> {
    /// Creates a Mirostat v2 sampler.
    ///
    /// * `tau` - target surprise.
    /// * `eta` - learning rate for `mu` updates.
    /// * `initial_mu` - starting value of the surprise threshold `mu`.
    /// * `rng` - random source handed to the internal [`RandDistribSampler`].
    pub fn new<WR: WithRng<Rng = R, Output = usize> + Send + Sync + 'static>(
        tau: L,
        eta: L,
        initial_mu: L,
        rng: Box<WR>,
    ) -> Self {
        let rd_sampler = RandDistribSampler::<TID, R>::new(rng);
        Self {
            token: None,
            mu: initial_mu,
            tau,
            eta,
            rd_sampler,
        }
    }
}

// FIXME: Support logit types other than f32?
impl<TID: CanTokenId, R: Rng + Send + Sync> Sampler<TID, f32> for SampleMirostat2<TID, f32, R> {
    /// Runs one Mirostat v2 step over `logits`.
    ///
    /// 1. Softmax the logits.
    /// 2. Keep only the leading tokens whose surprise `-log2(p)` does not exceed
    ///    `mu`. If no token exceeds `mu`, all tokens are kept. At least one token
    ///    is always kept.
    /// 3. Re-softmax the survivors and draw one token with the internal random
    ///    distribution sampler.
    /// 4. Update `mu -= eta * (surprise(token) - tau)`.
    ///
    /// After the call, the drawn token (if any) is available from
    /// [`Sampler::sampled_token_id`]. Empty `logits` are returned unchanged and
    /// no token is selected.
    ///
    /// # Errors
    ///
    /// Propagates errors from `softmax` and from the random distribution sampler.
    /// Returns `SamplerError::InternalError` if the drawn token is not present in
    /// the truncated logits.
    fn sample<'a>(
        &mut self,
        logits: &'a mut Logits<TID, f32>,
    ) -> Result<&'a mut Logits<TID, f32>, SamplerError> {
        self.token = None;
        if logits.is_empty() {
            return Ok(logits);
        }

        let Self { tau, eta, mu, .. } = *self;

        logits.softmax()?;
        // Cut at the first token whose surprise exceeds `mu`. This presumes
        // `softmax` leaves the logits ordered by descending probability
        // (NOTE(review): confirm in `Logits::softmax`). When no token exceeds
        // `mu`, nothing is cut; the `max(1)` only guards the case where even the
        // most probable token is too surprising.
        let total = logits.len();
        let new_size = logits
            .iter()
            .position(|l| -l.prob.log2() > mu)
            .unwrap_or(total)
            .max(1);
        logits.truncate(new_size);
        logits.softmax()?;

        // Draw exactly once. `sample_token` is assumed to run `sample`
        // internally (TODO confirm against `Sampler::sample_token`). A separate
        // `sample` call beforehand would waste a random draw.
        if let Some(tid) = self.rd_sampler.sample_token(logits)? {
            let logit = logits.iter().find(|l| l.token_id == tid).ok_or_else(|| {
                SamplerError::InternalError(String::from("Impossible: sample token not in logits?"))
            })?;

            // Move `mu` against the error between observed and target surprise.
            self.mu -= eta * (-logit.prob.log2() - tau);
            self.token = Some(tid);
        }
        Ok(logits)
    }

    fn sampled_token_id(&self) -> Option<TID> {
        self.token
    }
}

/// Mirostat v1 sampling.
///
/// Estimates the slope of the token probability distribution from the most
/// probable tokens, derives a top-k cutoff from it, and adjusts the running
/// surprise threshold `mu` after every sampled token.
pub struct SampleMirostat1<TID, L, R> {
    /// Vocabulary size, used when computing the top-k cutoff.
    n_vocab: usize,
    /// Target surprise value the `mu` update steers toward.
    tau: L,
    /// Learning rate applied when updating `mu`.
    eta: L,
    /// Number of most probable tokens considered when estimating `s_hat`.
    m: usize,
    /// Current surprise threshold; updated after each sampled token.
    mu: L,
    /// Token selected by the most recent `sample` call, if any.
    token: Option<TID>,
    /// Sampler that draws the final token from the truncated distribution.
    rd_sampler: RandDistribSampler<TID, R>,
}

impl<TID: CanTokenId, L: Float, R: Rng> SampleMirostat1<TID, L, R> {
    /// Creates a Mirostat v1 sampler.
    ///
    /// * `n_vocab` - vocabulary size.
    /// * `tau` - target surprise.
    /// * `eta` - learning rate for `mu` updates.
    /// * `m` - number of top tokens used to estimate `s_hat`.
    /// * `initial_mu` - starting value of the surprise threshold `mu`.
    /// * `rng` - random source handed to the internal [`RandDistribSampler`].
    pub fn new<WR: WithRng<Rng = R, Output = usize> + Send + Sync + 'static>(
        n_vocab: usize,
        tau: L,
        eta: L,
        m: usize,
        initial_mu: L,
        rng: Box<WR>,
    ) -> Self {
        let rd_sampler = RandDistribSampler::<TID, R>::new(rng);
        Self {
            token: None,
            mu: initial_mu,
            m,
            eta,
            tau,
            n_vocab,
            rd_sampler,
        }
    }
}

// FIXME: Support logit types other than f32?
impl<TID: CanTokenId, R: Rng + Send + Sync> Sampler<TID, f32> for SampleMirostat1<TID, f32, R> {
    /// Runs one Mirostat v1 step over `logits`.
    ///
    /// 1. Softmax the logits.
    /// 2. Estimate the exponent `s_hat` by least squares over at most `m - 1`
    ///    adjacent pairs of the most probable tokens. The pairs use
    ///    `t_i = ln((i + 2) / (i + 1))` and `b_i = ln(p_i / p_{i+1})`, so
    ///    `s_hat = sum(t_i * b_i) / sum(t_i^2)`.
    /// 3. Compute `k = ((eps_hat * 2^mu) / (1 - N^(-eps_hat)))^(1 / s_hat)`,
    ///    where `eps_hat = s_hat - 1` and `N = n_vocab`. Then apply top-k,
    ///    keeping at least one token.
    /// 4. Draw a token and update `mu -= eta * (surprise(token) - tau)`.
    ///
    /// Empty `logits`, or `m == 0`, are returned unchanged and no token is
    /// selected.
    ///
    /// # Errors
    ///
    /// Propagates errors from `softmax`, from top-k sampling and from the random
    /// distribution sampler. Returns `SamplerError::InternalError` if the drawn
    /// token is not present in the truncated logits.
    fn sample<'a>(
        &mut self,
        logits: &'a mut Logits<TID, f32>,
    ) -> Result<&'a mut Logits<TID, f32>, SamplerError> {
        let Self {
            n_vocab,
            tau,
            eta,
            m,
            mu,
            ..
        } = *self;
        self.token = None;
        if logits.is_empty() || m < 1 {
            return Ok(logits);
        }
        let n_vocab = n_vocab as f32;

        logits.softmax()?;
        // `zip` with the shifted iterator already stops at `len - 1` pairs.
        let (sum_ti_bi, sum_ti_sq) = logits
            .iter()
            .zip(logits.iter().skip(1))
            .take(m - 1)
            .enumerate()
            .fold(
                (0.0f32, 0.0f32),
                |(sum_ti_bi, sum_ti_sq), (idx, (l, l_next))| {
                    let t_i = ((idx + 2) as f32 / (idx + 1) as f32).ln();
                    // Log of the ratio of adjacent probabilities, not the raw ratio.
                    let b_i = (l.prob / l_next.prob).ln();
                    (sum_ti_bi + t_i * b_i, sum_ti_sq + t_i * t_i)
                },
            );
        let s_hat = sum_ti_bi / sum_ti_sq;
        let epsilon_hat = s_hat - 1.0;
        // `mu` is the exponent of 2 here, and the whole `1 - N^-eps` term is the
        // denominator.
        let k = ((epsilon_hat * 2f32.powf(mu)) / (1.0 - n_vocab.powf(-epsilon_hat)))
            .powf(1.0 / s_hat);
        // The float-to-int `as` cast saturates. NaN or negative `k` becomes 0, and
        // huge `k` becomes `usize::MAX`. `SampleTopK` is given `min_keep = 1`;
        // presumably it clamps to `[1, len]` (NOTE(review): confirm).
        logits.sample(&mut SampleTopK::new(k as usize, 1))?;

        if let Some(tid) = self.rd_sampler.sample_token(logits)? {
            let logit = logits.iter().find(|l| l.token_id == tid).ok_or_else(|| {
                SamplerError::InternalError(String::from("Impossible: sample token not in logits?"))
            })?;

            // Move `mu` against the error between observed and target surprise.
            self.mu -= eta * (-logit.prob.log2() - tau);
            self.token = Some(tid);
        }
        Ok(logits)
    }

    fn sampled_token_id(&self) -> Option<TID> {
        self.token
    }
}