burn_dragon_language 0.4.0

Language modeling components for burn_dragon
Documentation
use anyhow::{Result, anyhow};
use burn::tensor::backend::Backend;
use burn::tensor::{Int, Tensor, TensorData};
use rand::distributions::WeightedIndex;
use rand::prelude::*;
use std::cmp::Ordering;

use burn_dragon_core::{BDH, ModelState};

use crate::config::ContextStrategyConfig;
use crate::tokenizer::Tokenizer;
use crate::GenerationConfig;

/// How much past context the model state keeps while generating.
#[derive(Clone, Copy, Debug)]
pub enum ContextStrategy {
    /// Never trim the model state; the whole history is retained.
    Infinite,
    /// Once the state's position exceeds `window`, generation trims the state via
    /// `ModelState::trim(window)`. A `window` of `0` disables trimming in
    /// `generate_tokens`.
    Sliding { window: usize },
}

/// Sampling and context parameters consumed by `generate_tokens`.
#[derive(Clone, Copy, Debug)]
pub struct GenerationSettings {
    /// Maximum number of tokens to generate after the prompt; `None` imposes no limit.
    pub max_new_tokens: Option<usize>,
    /// Divisor applied to the logits before sampling.
    pub temperature: f32,
    /// When `Some(k)` with `0 < k < vocab`, sampling is restricted to the `k` highest logits.
    pub top_k: Option<usize>,
    /// How the model state's context is bounded during generation.
    pub strategy: ContextStrategy,
}

/// Draws a token index from `logits_values` by softmax sampling, optionally restricted to
/// the `top_k` highest logits.
///
/// * NaN logits are treated as `-inf`, so that token is never chosen. If every logit is
///   unusable (all NaN / `-inf`), sampling falls back to uniform.
/// * `top_k = Some(k)` with `0 < k < vocab` masks every logit strictly below the k-th
///   largest value. Ties with the k-th value are kept, so more than `k` tokens may remain.
///   Any other `top_k` leaves the logits unrestricted.
/// * If the largest logit is `+inf`, the tokens holding it share all probability mass.
///
/// # Errors
/// Returns an error if `logits_values` is empty or if `WeightedIndex` rejects the weights.
fn sample_from_logits_values(mut logits_values: Vec<f32>, top_k: Option<usize>) -> Result<i64> {
    let vocab = logits_values.len();
    if vocab == 0 {
        return Err(anyhow!("logits are empty"));
    }

    // A single NaN would otherwise turn the softmax sum into NaN and force a uniform
    // distribution over the entire vocabulary.
    for value in logits_values.iter_mut().filter(|v| v.is_nan()) {
        *value = f32::NEG_INFINITY;
    }

    if let Some(k) = top_k
        && k > 0
        && k < vocab
    {
        // Partial selection finds the k-th largest logit in O(vocab) rather than a full
        // sort. `total_cmp` is a genuine total order, unlike `partial_cmp` with an
        // `Equal` fallback, which std's sort routines are allowed to panic on.
        let descending = |a: &f32, b: &f32| -> Ordering { b.total_cmp(a) };
        let mut scratch = logits_values.clone();
        let (_, &mut threshold, _) = scratch.select_nth_unstable_by(k - 1, descending);
        for value in logits_values.iter_mut() {
            if *value < threshold {
                *value = f32::NEG_INFINITY;
            }
        }
    }

    let max_logit = logits_values
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);
    // Unnormalized softmax weights; WeightedIndex normalizes internally. Entries equal to
    // the maximum get weight 1 directly. This avoids `inf - inf = NaN` when the maximum
    // is infinite, and yields a uniform distribution when every logit is `-inf`.
    let weights: Vec<f32> = logits_values
        .iter()
        .map(|&value| {
            if value == max_logit {
                1.0
            } else {
                (value - max_logit).exp()
            }
        })
        .collect();

    let dist = WeightedIndex::new(&weights).map_err(|err| anyhow!(err.to_string()))?;
    let mut rng = thread_rng();
    Ok(dist.sample(&mut rng) as i64)
}

/// Runs the whole prompt through `model` in one forward pass, starting from a fresh
/// `model.init_state()`.
///
/// Returns the resulting state together with the logits of the final prompt position,
/// flattened to a 1-D tensor of length `vocab`. With the `viz` feature enabled,
/// `clear_viz` is called on the state before it is returned.
///
/// # Errors
/// Returns an error if `prompt_tokens` is empty, or if the model output's time dimension
/// differs from the prompt length.
pub fn prefill_state<B: Backend>(
    model: &BDH<B>,
    prompt_tokens: &[i64],
    device: &B::Device,
) -> Result<(ModelState<B>, Tensor<B, 1>)> {
    if prompt_tokens.is_empty() {
        return Err(anyhow!("prompt must contain at least one token"));
    }
    let prompt_len = prompt_tokens.len();

    // Shape the prompt as a batch of one sequence: [1, prompt_len].
    let input_data = TensorData::new(prompt_tokens.to_vec(), [1, prompt_len]);
    let input = Tensor::<B, 2, Int>::from_data(input_data, device);

    let mut model_state = model.init_state();
    let output = model.forward_with_state(input, &mut model_state);
    let [_, time, vocab_size] = output.shape().dims::<3>();
    if time != prompt_len {
        return Err(anyhow!(
            "prefill produced mismatched length: expected {prompt_len}, got {time}"
        ));
    }

    // Only the last position's logits are needed to sample the first new token.
    let final_logits = output
        .slice_dim(1, (time - 1)..time)
        .reshape([vocab_size]);

    #[cfg(feature = "viz")]
    model_state.clear_viz();

    Ok((model_state, final_logits))
}

/// Samples the next token from `last_logits`, feeds it through `model` (advancing
/// `state`), and returns the sampled token together with the logits for the following
/// position.
///
/// For a strictly positive `temperature`, the logits are divided by it before sampling.
/// `top_k = Some(k)` with `0 < k < vocab` limits sampling to the highest logits. A
/// `temperature` that is not strictly positive (zero, negative, or NaN) selects greedy
/// decoding: only the maximum logit(s) can be chosen and `top_k` is ignored.
///
/// # Errors
/// Fails if the logits cannot be read back as `f32` values, or if sampling fails
/// (e.g. empty logits).
pub fn sample_next_token<B: Backend>(
    model: &BDH<B>,
    state: &mut ModelState<B>,
    last_logits: Tensor<B, 1>,
    temperature: f32,
    top_k: Option<usize>,
    device: &B::Device,
) -> Result<(i64, Tensor<B, 1>)> {
    // Dividing by a zero/negative temperature produces inf/NaN logits (which collapse to
    // uniform sampling) or an inverted distribution. Treat that as greedy decoding
    // instead: `top_k = Some(1)` keeps only the maximum logit.
    let (scaled_logits, effective_top_k) = if temperature > 0.0 {
        (last_logits.div_scalar(temperature), top_k)
    } else {
        (last_logits, Some(1))
    };
    let logits_values = scaled_logits
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .map_err(|err| anyhow!("{err:?}"))?;
    let next = sample_from_logits_values(logits_values, effective_top_k)?;

    let next_tensor = Tensor::<B, 2, Int>::from_data(TensorData::new(vec![next], [1, 1]), device);

    let logits = model.forward_with_state(next_tensor, state);
    let [_, time, vocab] = logits.shape().dims::<3>();
    let new_last_logits = logits.slice_dim(1, (time - 1)..time).reshape([vocab]);

    Ok((next, new_last_logits))
}

#[cfg(feature = "web")]
/// Async variant of `sample_next_token`: the logits are read back with `to_data_async`.
///
/// Samples the next token from `last_logits`, feeds it through `model` (advancing
/// `state`), and returns the sampled token together with the logits for the following
/// position. For a strictly positive `temperature`, the logits are divided by it before
/// sampling. `top_k = Some(k)` with `0 < k < vocab` limits sampling to the highest
/// logits. A `temperature` that is not strictly positive (zero, negative, or NaN)
/// selects greedy decoding: only the maximum logit(s) can be chosen and `top_k` is
/// ignored.
///
/// # Errors
/// Fails if the logits cannot be read back as `f32` values, or if sampling fails
/// (e.g. empty logits).
pub async fn sample_next_token_async<B: Backend>(
    model: &BDH<B>,
    state: &mut ModelState<B>,
    last_logits: Tensor<B, 1>,
    temperature: f32,
    top_k: Option<usize>,
    device: &B::Device,
) -> Result<(i64, Tensor<B, 1>)> {
    // Dividing by a zero/negative temperature produces inf/NaN logits (which collapse to
    // uniform sampling) or an inverted distribution. Treat that as greedy decoding
    // instead: `top_k = Some(1)` keeps only the maximum logit.
    let (scaled_logits, effective_top_k) = if temperature > 0.0 {
        (last_logits.div_scalar(temperature), top_k)
    } else {
        (last_logits, Some(1))
    };
    let logits_values = scaled_logits
        .to_data_async()
        .await
        .convert::<f32>()
        .into_vec::<f32>()
        .map_err(|err| anyhow!("{err:?}"))?;
    let next = sample_from_logits_values(logits_values, effective_top_k)?;

    let next_tensor = Tensor::<B, 2, Int>::from_data(TensorData::new(vec![next], [1, 1]), device);

    let logits = model.forward_with_state(next_tensor, state);
    let [_, time, vocab] = logits.shape().dims::<3>();
    let new_last_logits = logits.slice_dim(1, (time - 1)..time).reshape([vocab]);

    Ok((next, new_last_logits))
}

/// Runs autoregressive generation from `prompt_tokens` and returns the prompt followed by
/// every generated token.
///
/// The prompt is prefilled in a single forward pass. Tokens are then sampled one at a
/// time with `settings.temperature` and `settings.top_k`. Generation stops after
/// `settings.max_new_tokens` tokens; with `None` the loop only ends when sampling
/// returns an error. `on_token`, when provided, is called with each newly sampled token.
///
/// For `ContextStrategy::Sliding` with a non-zero window, the model state is trimmed to
/// the window whenever its position exceeds it. This check runs after prefill and after
/// every sampled token.
///
/// # Errors
/// Returns an error if prefill fails (e.g. an empty prompt) or if sampling fails.
pub fn generate_tokens<B: Backend>(
    model: &BDH<B>,
    prompt_tokens: Vec<i64>,
    device: &B::Device,
    settings: GenerationSettings,
    mut on_token: Option<&mut dyn FnMut(i64)>,
) -> Result<Vec<i64>> {
    let limit = settings.max_new_tokens;
    // Resolve the trimming window once; a zero-sized window means "never trim".
    let trim_window = match settings.strategy {
        ContextStrategy::Sliding { window } if window > 0 => Some(window),
        ContextStrategy::Sliding { .. } | ContextStrategy::Infinite => None,
    };

    let mut tokens = prompt_tokens;
    let (mut model_state, mut logits) = prefill_state(model, &tokens, device)?;
    if let Some(window) = trim_window
        && model_state.position > window
    {
        model_state.trim(window);
    }

    let mut produced = 0usize;
    loop {
        if let Some(max) = limit
            && produced >= max
        {
            break;
        }

        let (token, next_logits) = sample_next_token(
            model,
            &mut model_state,
            logits,
            settings.temperature,
            settings.top_k,
            device,
        )?;
        tokens.push(token);
        logits = next_logits;
        produced = produced.saturating_add(1);

        if let Some(callback) = on_token.as_mut() {
            callback(token);
        }

        if let Some(window) = trim_window
            && model_state.position > window
        {
            model_state.trim(window);
        }
    }

    Ok(tokens)
}

/// Generates text for `generation.prompt` and returns the decoded prompt plus continuation.
///
/// The context strategy comes from `generation.context_strategy`; `block_size` is used as
/// the window when the configured sliding window is `0`. Under a sliding strategy, the
/// encoded prompt is truncated to its most recent `window` tokens before prefill.
/// Sampling uses `generation.temperature` and `generation.top_k`. The number of new
/// tokens is taken from `generation.max_tokens`.
///
/// NOTE(review): a `None` or negative `max_tokens` means "no limit". No per-token
/// callback is passed to `generate_tokens` here, so such a call keeps generating until
/// sampling fails and never returns `Ok`. Confirm that callers always set a bound.
///
/// # Errors
/// Propagates errors from prefill (e.g. a prompt that encodes to zero tokens) and from
/// sampling.
pub fn generate_text<B: Backend>(
    model: &BDH<B>,
    tokenizer: &dyn Tokenizer,
    device: &B::Device,
    block_size: usize,
    generation: &GenerationConfig,
) -> Result<String> {
    let strategy = resolve_context_strategy(&generation.context_strategy, block_size);
    let mut prompt_ids = tokenizer.encode(&generation.prompt, false, false);
    if let ContextStrategy::Sliding { window } = strategy
        && prompt_ids.len() > window
    {
        // Drop the oldest tokens in place instead of copying the tail into a new Vec.
        let excess = prompt_ids.len() - window;
        prompt_ids.drain(..excess);
    }

    let prompt_tokens: Vec<i64> = prompt_ids.iter().map(|&id| id as i64).collect();
    let settings = GenerationSettings {
        max_new_tokens: normalize_max_tokens(generation.max_tokens),
        temperature: generation.temperature,
        top_k: generation.top_k,
        strategy,
    };
    let tokens_all = generate_tokens(model, prompt_tokens, device, settings, None)?;

    // Checked conversion: negative or out-of-range ids are skipped rather than being
    // wrapped by an `as u32` cast.
    let decoded_ids: Vec<u32> = tokens_all
        .iter()
        .filter_map(|&tok| u32::try_from(tok).ok())
        .collect();

    Ok(tokenizer.decode(&decoded_ids))
}

/// Converts the user-facing `max_tokens` setting into a generation limit.
///
/// `None` or a negative value means "no limit" and yields `None`. Non-negative values are
/// converted with a checked conversion. On targets where `usize` is narrower than `i64`
/// (32-bit platforms), values that do not fit saturate to `usize::MAX`; an `as` cast
/// would instead silently truncate them to a small, unrelated limit.
fn normalize_max_tokens(max_tokens: Option<i64>) -> Option<usize> {
    match max_tokens {
        Some(value) if value >= 0 => Some(usize::try_from(value).unwrap_or(usize::MAX)),
        _ => None,
    }
}

/// Converts a `ContextStrategyConfig` into the runtime `ContextStrategy`.
///
/// A configured sliding window of `0` means "use the default". It is replaced by
/// `default_window`, clamped to at least `1` so the resolved window is never empty.
/// Non-zero windows are passed through unchanged.
pub fn resolve_context_strategy(
    config: &ContextStrategyConfig,
    default_window: usize,
) -> ContextStrategy {
    match config {
        ContextStrategyConfig::Infinite => ContextStrategy::Infinite,
        ContextStrategyConfig::Sliding { window: 0 } => ContextStrategy::Sliding {
            window: default_window.max(1),
        },
        ContextStrategyConfig::Sliding { window } => ContextStrategy::Sliding { window: *window },
    }
}