// boostr 0.1.0
//
// ML framework built on numr: attention, quantization, model architectures.
// Documentation
//! Generic implementations of the sampling operations.
//!
//! The same algorithm is used on every backend. On the CPU backend the host
//! round-trips add no overhead; the CUDA backend provides fused kernels that
//! avoid the D2H/H2D transfers these implementations perform.

use crate::error::Result;
use numr::dtype::DType;
use numr::ops::RandomOps;
use numr::runtime::Runtime;
use numr::tensor::Tensor;

/// Apply sampling penalties to logits in place — generic implementation.
///
/// Pulls `logits`, `token_ids` and `token_counts` to the host, penalizes every listed
/// token, then writes the updated logits back over the original device buffer of
/// `logits`. CUDA backend overrides this with a fused kernel.
///
/// Pairs are formed by zipping `token_ids` with `token_counts`, so extra entries in the
/// longer of the two are ignored. Ids that are negative or not smaller than the number
/// of logits are skipped. For each remaining `(token_id, count)` pair:
/// - repetition penalty (llama.cpp style): a positive logit is divided by
///   `repeat_penalty`, a non-positive one multiplied by it; skipped when `== 1.0`;
/// - frequency penalty: subtracts `frequency_penalty * count`;
/// - presence penalty: subtracts `presence_penalty` once.
///
/// When all three penalties are neutral (`1.0`, `0.0`, `0.0`) the logits cannot change,
/// so the function returns immediately without any device transfer.
///
/// NOTE(review): the host copy is read as `f32` and written back as raw `f32` bytes at
/// `logits.ptr()`, so `logits` is assumed to be a contiguous F32 tensor — confirm at
/// call sites.
///
/// # Errors
/// Returns `Error::Numr` (wrapping an internal error) if the device write-back fails.
pub fn apply_sampling_penalties_impl<R: Runtime>(
    _client: &R::Client,
    logits: &Tensor<R>,
    token_ids: &Tensor<R>,
    token_counts: &Tensor<R>,
    repeat_penalty: f32,
    frequency_penalty: f32,
    presence_penalty: f32,
) -> Result<()> {
    // Neutral penalties leave every logit unchanged: skip the D2H/H2D round-trip.
    if repeat_penalty == 1.0 && frequency_penalty == 0.0 && presence_penalty == 0.0 {
        return Ok(());
    }

    let mut logits_vec: Vec<f32> = logits.to_vec();
    let ids_vec: Vec<i64> = token_ids.to_vec();
    let counts_vec: Vec<i32> = token_counts.to_vec();

    for (&token_id, &count) in ids_vec.iter().zip(&counts_vec) {
        // `try_from` rejects negative ids explicitly instead of relying on `as usize`
        // wrap-around, which can alias a valid index on 32-bit targets.
        let Ok(i) = usize::try_from(token_id) else {
            continue;
        };
        let Some(logit) = logits_vec.get_mut(i) else {
            continue;
        };

        // Repetition penalty (llama.cpp style)
        if repeat_penalty != 1.0 {
            if *logit > 0.0 {
                *logit /= repeat_penalty;
            } else {
                *logit *= repeat_penalty;
            }
        }

        // Frequency penalty: proportional to count
        if frequency_penalty != 0.0 {
            *logit -= frequency_penalty * count as f32;
        }

        // Presence penalty: flat, applied once per listed token
        if presence_penalty != 0.0 {
            *logit -= presence_penalty;
        }
    }

    // SAFETY: `logits_vec` is a live, fully initialized `Vec<f32>` that outlives
    // `bytes`. Viewing its storage as `u8` is sound because `u8` has alignment 1 and
    // every byte of an `f32` is initialized. The length is exactly the slice's size in
    // bytes.
    let bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(
            logits_vec.as_ptr().cast::<u8>(),
            std::mem::size_of_val(logits_vec.as_slice()),
        )
    };

    // Write modified logits back to device
    R::copy_to_device(bytes, logits.ptr(), logits.device()).map_err(|e| {
        crate::error::Error::Numr(numr::error::Error::Internal(format!(
            "Failed to write back penalized logits: {}",
            e
        )))
    })?;

    Ok(())
}

/// Sample a token from logits — generic implementation.
///
/// Pipeline: temperature → softmax → top-k → top-p → min-p → multinomial.
///
/// - `temperature <= 0.0` selects greedy decoding (argmax, first index wins ties), the
///   same rule `logits_to_token_impl` uses for zero temperature. Without this, a zero
///   temperature scales logits by `1.0 / 0.0 = inf`, every probability becomes NaN, and
///   the sampler returns an essentially arbitrary token. A negative temperature would
///   invert the distribution.
/// - `top_k == 0` disables top-k; `top_p >= 1.0` disables top-p; `min_p <= 0.0`
///   disables min-p. Min-p always keeps the most probable candidate, so an out-of-range
///   `min_p > 1.0` cannot empty the candidate set.
///
/// Randomness is generated via `RandomOps::rand` (on-device for GPU backends) and read
/// back as a single scalar. CUDA backend provides a fused kernel that keeps everything
/// in a single launch.
///
/// Returns 0 if `logits` is empty.
///
/// # Errors
/// Returns `Error::Numr` if random number generation fails.
pub fn sample_token_impl<R: Runtime>(
    client: &R::Client,
    logits: &Tensor<R>,
    temperature: f32,
    top_k: usize,
    top_p: f32,
    min_p: f32,
) -> Result<u32>
where
    R::Client: RandomOps<R>,
{
    let mut logits_vec: Vec<f32> = logits.to_vec();

    // Greedy decoding: plain argmax, no randomness consumed.
    if temperature <= 0.0 {
        let mut best_idx = 0usize;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &v) in logits_vec.iter().enumerate() {
            if v > best_val {
                best_val = v;
                best_idx = i;
            }
        }
        return Ok(best_idx as u32);
    }

    // Temperature scaling
    if temperature != 1.0 {
        let inv_temp = 1.0 / temperature;
        for l in logits_vec.iter_mut() {
            *l *= inv_temp;
        }
    }

    // Softmax, shifted by the max logit for numerical stability.
    let max_logit = logits_vec.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits_vec.iter().map(|&l| (l - max_logit).exp()).collect();
    let sum: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= sum;
    }

    // Candidates sorted by descending probability. `total_cmp` is a genuine total order
    // (even if NaN appears), which `slice::sort_by` requires; the stable sort keeps
    // equal probabilities in ascending token order.
    let mut indexed: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));

    // Top-k filter
    if top_k > 0 && top_k < indexed.len() {
        indexed.truncate(top_k);
    }

    // Top-p filter: keep the shortest prefix whose cumulative probability exceeds top_p.
    if top_p < 1.0 {
        let mut cumsum = 0.0f32;
        let mut cutoff = indexed.len();
        for (i, &(_, p)) in indexed.iter().enumerate() {
            cumsum += p;
            if cumsum > top_p {
                cutoff = i + 1;
                break;
            }
        }
        indexed.truncate(cutoff);
    }

    // Min-p filter: drop candidates below `min_p * max_prob`. The list is sorted
    // descending, so survivors form a prefix; the top candidate is always kept.
    if min_p > 0.0 && !indexed.is_empty() {
        let threshold = min_p * indexed[0].1;
        let kept = indexed.partition_point(|&(_, p)| p >= threshold);
        indexed.truncate(kept.max(1));
    }

    // Generate random value on-device via RandomOps, read back scalar
    let rand_tensor = client
        .rand(&[1], numr::dtype::DType::F32)
        .map_err(crate::error::Error::Numr)?;
    let random_val: f32 = rand_tensor.to_vec::<f32>()[0];

    // Renormalize the surviving candidates and walk their cumulative distribution.
    let total: f32 = indexed.iter().map(|&(_, p)| p).sum();
    let mut cumsum = 0.0f32;
    for &(idx, p) in &indexed {
        cumsum += p / total;
        if cumsum > random_val {
            return Ok(idx as u32);
        }
    }

    // Rounding can leave the final cumulative sum just below `random_val`.
    Ok(indexed.last().map_or(0, |&(idx, _)| idx as u32))
}

/// Fused logits-to-token — generic implementation.
///
/// Steps:
/// 1. Casts `logits` to F32 on the device if needed, copies it to the host and keeps
///    the `vocab_size` values at sequence position `seq_len - 1`. Here `shape[1]` is
///    treated as the sequence length and `shape[2]` as the vocabulary size. Assuming a
///    row-major `[batch, seq, vocab]` layout, this is the last position of the first
///    batch entry.
/// 2. Applies repetition/frequency/presence penalties to the first `num_unique`
///    `(token_ids, token_counts)` pairs. Ids that are negative or `>= vocab_size` are
///    skipped, and the step is skipped entirely when every penalty is neutral.
/// 3. Picks a token. When `temperature <= 0.0` it uses greedy argmax (first index wins
///    ties). Otherwise it runs temperature → softmax → top-k → top-p → min-p →
///    multinomial. The uniform draw uses `rand_seeded` when `seed` is `Some`. Min-p
///    always keeps the most probable candidate.
///
/// Returns a `[1]` I64 tensor holding the selected token id, on the device of `logits`.
///
/// NOTE(review): the whole tensor is copied to the host even though only the last
/// position is used. For long prompts, narrowing on the device first would reduce the
/// transfer. Also, dimensions past the third are not interpreted, so rank > 3 input is
/// probably misread — confirm callers only pass rank-3 logits.
///
/// # Errors
/// - `Error::InvalidArgument` if `logits` has rank < 3 or a zero sequence/vocab dim.
/// - `Error::Numr` if the dtype cast or random number generation fails.
#[allow(clippy::too_many_arguments)] // mirrors the fused CUDA kernel's parameter list
pub fn logits_to_token_impl<R: Runtime<DType = numr::dtype::DType>>(
    client: &R::Client,
    logits: &Tensor<R>,
    token_ids: &Tensor<R>,
    token_counts: &Tensor<R>,
    num_unique: usize,
    repeat_penalty: f32,
    frequency_penalty: f32,
    presence_penalty: f32,
    temperature: f32,
    top_k: usize,
    top_p: f32,
    min_p: f32,
    seed: Option<u64>,
) -> Result<Tensor<R>>
where
    R::Client: numr::ops::RandomOps<R> + numr::ops::TypeConversionOps<R>,
{
    // 1. Read logits at last seq position (cast to F32 if needed)
    let logits = if logits.dtype() != DType::F32 {
        use numr::ops::TypeConversionOps;
        client
            .cast(logits, DType::F32)
            .map_err(crate::error::Error::Numr)?
    } else {
        logits.clone()
    };
    let shape = logits.shape();
    if shape.len() < 3 {
        return Err(crate::error::Error::InvalidArgument {
            arg: "logits",
            reason: format!("expected rank >= 3, got rank {}", shape.len()),
        });
    }
    let seq_len = shape[1];
    let vocab_size = shape[2];
    if seq_len == 0 || vocab_size == 0 {
        return Err(crate::error::Error::InvalidArgument {
            arg: "logits",
            reason: format!("seq_len and vocab_size must be > 0, got shape {:?}", shape),
        });
    }
    let all_logits: Vec<f32> = logits.to_vec();
    let offset = (seq_len - 1) * vocab_size;
    let mut last_logits: Vec<f32> = all_logits[offset..offset + vocab_size].to_vec();

    // 2. Apply penalties (skipping the host copies when they cannot change anything)
    let penalties_active =
        repeat_penalty != 1.0 || frequency_penalty != 0.0 || presence_penalty != 0.0;
    if num_unique > 0 && penalties_active {
        let ids_vec: Vec<i64> = token_ids.to_vec();
        let counts_vec: Vec<i32> = token_counts.to_vec();

        for (&id, &count) in ids_vec.iter().zip(&counts_vec).take(num_unique) {
            // Negative ids are rejected explicitly; `get_mut` rejects ids >= vocab_size.
            let Ok(idx) = usize::try_from(id) else {
                continue;
            };
            let Some(logit) = last_logits.get_mut(idx) else {
                continue;
            };

            if repeat_penalty != 1.0 {
                if *logit > 0.0 {
                    *logit /= repeat_penalty;
                } else {
                    *logit *= repeat_penalty;
                }
            }
            if frequency_penalty != 0.0 {
                *logit -= frequency_penalty * count as f32;
            }
            if presence_penalty != 0.0 {
                *logit -= presence_penalty;
            }
        }
    }

    // 3. Greedy or stochastic
    let token_id = if temperature <= 0.0 {
        // Argmax; first index wins ties.
        let mut best_idx = 0usize;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &v) in last_logits.iter().enumerate() {
            if v > best_val {
                best_val = v;
                best_idx = i;
            }
        }
        best_idx as i64
    } else {
        // Temperature scaling
        let inv_temp = 1.0 / temperature;
        for l in last_logits.iter_mut() {
            *l *= inv_temp;
        }

        // Softmax, shifted by the max logit for numerical stability.
        let max_logit = last_logits
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max);
        let mut probs: Vec<f32> = last_logits.iter().map(|&l| (l - max_logit).exp()).collect();
        let sum: f32 = probs.iter().sum();
        for p in probs.iter_mut() {
            *p /= sum;
        }

        // Candidates sorted by descending probability. `total_cmp` is a genuine total
        // order, which `slice::sort_by` requires; the stable sort keeps equal
        // probabilities in ascending token order.
        let mut indexed: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| b.1.total_cmp(&a.1));

        // Top-k
        if top_k > 0 && top_k < indexed.len() {
            indexed.truncate(top_k);
        }

        // Top-p: keep the shortest prefix whose cumulative probability exceeds top_p.
        if top_p < 1.0 {
            let mut cumsum = 0.0f32;
            let mut cutoff = indexed.len();
            for (i, &(_, p)) in indexed.iter().enumerate() {
                cumsum += p;
                if cumsum > top_p {
                    cutoff = i + 1;
                    break;
                }
            }
            indexed.truncate(cutoff);
        }

        // Min-p: survivors form a prefix of the sorted list; always keep the top one so
        // an out-of-range `min_p > 1.0` cannot empty the candidate set.
        if min_p > 0.0 && !indexed.is_empty() {
            let threshold = min_p * indexed[0].1;
            let kept = indexed.partition_point(|&(_, p)| p >= threshold);
            indexed.truncate(kept.max(1));
        }

        // Random value (seeded for reproducibility if seed is provided)
        let rand_tensor = if let Some(s) = seed {
            client
                .rand_seeded(&[1], numr::dtype::DType::F32, s)
                .map_err(crate::error::Error::Numr)?
        } else {
            client
                .rand(&[1], numr::dtype::DType::F32)
                .map_err(crate::error::Error::Numr)?
        };
        let random_val: f32 = rand_tensor.to_vec::<f32>()[0];

        // Renormalize and sample. Default to the last candidate in case rounding leaves
        // the final cumulative sum just below `random_val`.
        let total: f32 = indexed.iter().map(|&(_, p)| p).sum();
        let mut cumsum = 0.0f32;
        let mut sampled = indexed.last().map_or(0, |&(idx, _)| idx);
        for &(idx, p) in &indexed {
            cumsum += p / total;
            if cumsum > random_val {
                sampled = idx;
                break;
            }
        }
        sampled as i64
    };

    Ok(Tensor::from_slice(&[token_id], &[1], logits.device()))
}