use bumpalo::Bump;
use std::cell::RefCell;
use crate::token::TokenId;
/// Per-generation bump arena for short-lived scratch buffers.
///
/// Memory is handed out by an inner [`Bump`] and reclaimed all at once by
/// [`GenerationArena::reset`]. Per the `bumpalo` docs, `Bump` does not run `Drop`
/// for values stored in it, so values that need cleanup should not be placed here.
pub struct GenerationArena {
    arena: Bump,
    // Byte capacity requested at construction; never updated afterwards.
    capacity: usize,
}

impl GenerationArena {
    /// Byte capacity used by the [`Default`] impl (64 KiB).
    const DEFAULT_CAPACITY: usize = 64 * 1024;

    /// Builds an arena backed by `Bump::with_capacity(capacity)`.
    pub fn new(capacity: usize) -> Self {
        let arena = Bump::with_capacity(capacity);
        GenerationArena { arena, capacity }
    }

    /// Allocates `count` token ids, each set to `TokenId::default()`.
    #[inline]
    pub fn alloc_tokens(&self, count: usize) -> &mut [TokenId] {
        self.alloc_slice::<TokenId>(count)
    }

    /// Allocates `count` `f32` values, each `0.0`.
    #[inline]
    pub fn alloc_f32(&self, count: usize) -> &mut [f32] {
        self.alloc_slice::<f32>(count)
    }

    /// Allocates `count` values of `T`, each set to `T::default()`.
    #[inline]
    pub fn alloc_slice<T: Default + Copy>(&self, count: usize) -> &mut [T] {
        self.arena.alloc_slice_fill_default(count)
    }

    /// Copies `src` into the arena and returns the arena-owned copy.
    #[inline]
    pub fn alloc_slice_copy<T: Copy>(&self, src: &[T]) -> &mut [T] {
        self.arena.alloc_slice_copy(src)
    }

    /// Moves `value` into the arena and returns a reference to it.
    ///
    /// The arena does not drop it, so any `Drop` impl on `T` will not run.
    #[inline]
    pub fn alloc<T>(&self, value: T) -> &mut T {
        self.arena.alloc(value)
    }

    /// Releases every allocation at once.
    ///
    /// Takes `&mut self`, so no slice handed out earlier can still be borrowed.
    #[inline]
    pub fn reset(&mut self) {
        self.arena.reset()
    }

    /// Returns `Bump::allocated_bytes` for the inner arena (see bumpalo for what it counts).
    pub fn allocated_bytes(&self) -> usize {
        self.arena.allocated_bytes()
    }

    /// Returns the capacity passed to [`new`](Self::new).
    ///
    /// This value is not updated if the arena later grows.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}

impl Default for GenerationArena {
    /// Equivalent to `GenerationArena::new(64 * 1024)`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_CAPACITY)
    }
}
thread_local! {
static GENERATION_ARENA: RefCell<GenerationArena> = RefCell::new(GenerationArena::default());
}
/// Runs `f` with shared access to this thread's generation arena, then resets it.
///
/// Allocations made through the arena are reclaimed when the outermost
/// `with_generation_arena` call on this thread returns. `R` is fixed independently
/// of the borrow handed to `f`, so the result cannot hold references into the
/// arena and the reset cannot leave them dangling.
///
/// Calls may be nested. An inner call uses the same arena and skips the reset
/// while an enclosing call still holds its borrow. The outermost call performs
/// the reset.
///
/// # Panics
///
/// Panics if called while `with_generation_arena_mut` holds the arena on this
/// thread, because the `RefCell` is already mutably borrowed. If `f` panics, this
/// call does not reset the arena.
pub fn with_generation_arena<F, R>(f: F) -> R
where
    F: FnOnce(&GenerationArena) -> R,
{
    GENERATION_ARENA.with(|arena| {
        let result = {
            let arena_ref = arena.borrow();
            f(&arena_ref)
        };
        // `try_borrow_mut` fails only if an enclosing `with_generation_arena` still
        // holds a shared borrow. The old `borrow_mut()` panicked in that case. Skip
        // the reset and let the outermost call free everything.
        if let Ok(mut arena_mut) = arena.try_borrow_mut() {
            arena_mut.reset();
        }
        result
    })
}
/// Runs `f` with exclusive access to this thread's generation arena, then resets it.
///
/// The reset happens after `f` returns, so nothing `f` allocated survives the
/// call. `R` is fixed independently of the `&mut` borrow handed to `f`, so it
/// cannot hold references into the arena.
///
/// # Panics
///
/// Panics if the arena is already borrowed on this thread, for example when called
/// from inside a `with_generation_arena` or `with_generation_arena_mut` closure.
/// If `f` panics, the reset is skipped.
pub fn with_generation_arena_mut<F, R>(f: F) -> R
where
    F: FnOnce(&mut GenerationArena) -> R,
{
    GENERATION_ARENA.with(|cell| {
        let mut guard = cell.borrow_mut();
        let output = f(&mut *guard);
        guard.reset();
        output
    })
}
/// One scored token considered during sampling.
#[derive(Clone, Copy, Debug, Default)]
pub struct ArenaTokenCandidate {
    /// Token this candidate refers to.
    pub id: TokenId,
    /// Raw (pre-softmax) score for the token.
    pub logit: f32,
    /// Probability mass. `ArenaCandidates` sets it to 0.0 on push and fills it in
    /// `apply_softmax`.
    pub probability: f32,
}
/// Fixed-capacity list of sampling candidates backed by arena memory.
///
/// The backing slice is allocated once in [`ArenaCandidates::new`]. Only the first
/// `len` entries are live. `sorted` records whether the live entries are known to
/// be in [`ArenaCandidates::sort_by_logit`] order, so a repeated sort can be skipped.
pub struct ArenaCandidates<'a> {
    candidates: &'a mut [ArenaTokenCandidate],
    len: usize,
    sorted: bool,
}

impl<'a> ArenaCandidates<'a> {
    /// Creates an empty list with room for `capacity` candidates, allocated from `arena`.
    pub fn new(arena: &'a GenerationArena, capacity: usize) -> Self {
        let candidates = arena.alloc_slice::<ArenaTokenCandidate>(capacity);
        Self {
            candidates,
            len: 0,
            sorted: false,
        }
    }

    /// Appends a candidate with `probability` set to 0.0.
    ///
    /// If the list is already at capacity, the candidate is silently dropped.
    /// Compare [`len`](Self::len) before and after if the caller needs to know.
    #[inline]
    pub fn push(&mut self, id: TokenId, logit: f32) {
        if let Some(slot) = self.candidates.get_mut(self.len) {
            *slot = ArenaTokenCandidate {
                id,
                logit,
                probability: 0.0,
            };
            self.len += 1;
            self.sorted = false;
        }
    }

    /// Number of live candidates.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if no candidate has been pushed.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// The live candidates, in their current order.
    #[inline]
    pub fn as_slice(&self) -> &[ArenaTokenCandidate] {
        &self.candidates[..self.len]
    }

    /// The live candidates, mutably.
    ///
    /// The caller may rewrite logits through this slice, for example to apply a
    /// temperature or penalty. The cached sorted state is therefore cleared, and
    /// the next [`sort_by_logit`](Self::sort_by_logit) re-sorts instead of being
    /// skipped.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [ArenaTokenCandidate] {
        self.sorted = false;
        &mut self.candidates[..self.len]
    }

    /// Sorts the live candidates by descending logit.
    ///
    /// The sort is stable, so equal logits keep push order, and NaN logits go last.
    /// Does nothing if the list is already known to be sorted.
    pub fn sort_by_logit(&mut self) {
        if !self.sorted {
            self.candidates[..self.len].sort_by(Self::by_logit_desc);
            self.sorted = true;
        }
    }

    /// Ordering used for ranking: higher logits first, NaN after every number.
    ///
    /// This is a total order. The previous `partial_cmp(..).unwrap_or(Equal)` was
    /// not a total order when NaN was present, and std's sort may panic on such
    /// comparators.
    fn by_logit_desc(a: &ArenaTokenCandidate, b: &ArenaTokenCandidate) -> std::cmp::Ordering {
        b.logit
            .partial_cmp(&a.logit)
            .unwrap_or_else(|| a.logit.is_nan().cmp(&b.logit.is_nan()))
    }

    /// The highest-ranked candidate, or `None` if the list is empty.
    ///
    /// When sorted, this is the first entry. Otherwise the live entries are scanned
    /// with the same ordering as [`sort_by_logit`](Self::sort_by_logit). Ties resolve
    /// to the earliest entry, which matches what the stable sort would put first.
    #[inline]
    pub fn top(&self) -> Option<&ArenaTokenCandidate> {
        let live = self.as_slice();
        if self.sorted {
            live.first()
        } else {
            live.iter().reduce(|best, c| {
                if Self::by_logit_desc(c, best).is_lt() {
                    c
                } else {
                    best
                }
            })
        }
    }

    /// Sets each live candidate's `probability` to the softmax of the logits.
    ///
    /// Subtracts the maximum logit before exponentiating for numerical stability.
    /// Logits are unchanged, so the sorted state is kept. Non-finite inputs are
    /// not special-cased:
    /// - NaN, `+inf`, or all-`-inf` logits make the sum NaN.
    /// - In that case normalisation is skipped and probabilities may be NaN.
    pub fn apply_softmax(&mut self) {
        if self.len == 0 {
            return;
        }
        let live = &mut self.candidates[..self.len];
        let max_logit = live
            .iter()
            .map(|c| c.logit)
            .fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for c in live.iter_mut() {
            c.probability = (c.logit - max_logit).exp();
            sum += c.probability;
        }
        if sum > 0.0 {
            for c in live.iter_mut() {
                c.probability /= sum;
            }
        }
    }

    /// Picks a candidate by inverse-CDF over `probability`, in current order.
    ///
    /// Returns the first candidate whose running probability total exceeds
    /// `random`. If none does (for example `random` is at least the total mass, or
    /// NaN), returns the last candidate. Returns `None` only for an empty list.
    /// `random` is presumably a uniform draw in `[0, 1)`; confirm with callers.
    pub fn sample(&self, random: f32) -> Option<TokenId> {
        let live = self.as_slice();
        let last = live.last()?;
        let mut cumulative = 0.0f32;
        for c in live {
            cumulative += c.probability;
            if random < cumulative {
                return Some(c.id);
            }
        }
        Some(last.id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arena_allocation() {
        let arena = GenerationArena::new(4096);
        let tokens = arena.alloc_tokens(100);
        assert_eq!(tokens.len(), 100);
        let floats = arena.alloc_f32(50);
        assert_eq!(floats.len(), 50);
        assert!(arena.allocated_bytes() > 0);
    }

    #[test]
    fn test_arena_reset() {
        let mut arena = GenerationArena::new(4096);
        let first = arena.alloc_tokens(100).as_ptr();
        let bytes_before = arena.allocated_bytes();
        arena.reset();
        let second = arena.alloc_tokens(100).as_ptr();
        // After a reset, an identical request should be served from the same spot
        // in the retained chunk. This shows the memory was actually reclaimed. The
        // old `bytes_before + 100` slack could not detect a no-op reset.
        assert_eq!(first, second);
        assert!(arena.allocated_bytes() <= bytes_before);
    }

    #[test]
    fn test_thread_local_arena() {
        let result = with_generation_arena(|arena| {
            let buffer = arena.alloc_f32(100);
            buffer[0] = 42.0;
            buffer[0]
        });
        assert_eq!(result, 42.0);
    }

    #[test]
    fn test_arena_candidates() {
        let arena = GenerationArena::new(4096);
        let mut candidates = ArenaCandidates::new(&arena, 10);
        candidates.push(1, 1.0);
        candidates.push(2, 2.0);
        candidates.push(3, 0.5);
        assert_eq!(candidates.len(), 3);
        candidates.sort_by_logit();
        assert_eq!(candidates.as_slice()[0].id, 2);
        candidates.apply_softmax();
        let sum: f32 = candidates.as_slice().iter().map(|c| c.probability).sum();
        assert!((sum - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_candidates_top_sample_and_capacity() {
        let arena = GenerationArena::new(4096);
        let mut candidates = ArenaCandidates::new(&arena, 4);
        assert!(candidates.top().is_none());
        assert_eq!(candidates.sample(0.5), None);

        candidates.push(1, 1.0);
        candidates.push(2, 2.0);
        candidates.push(3, 0.5);
        candidates.sort_by_logit();
        assert_eq!(candidates.top().map(|c| c.id), Some(2));

        candidates.apply_softmax();
        assert_eq!(candidates.sample(0.0), Some(2));
        // A draw at or past the total mass falls back to the last candidate.
        assert_eq!(candidates.sample(1.0), Some(3));

        let mut small = ArenaCandidates::new(&arena, 1);
        small.push(7, 0.0);
        small.push(8, 1.0);
        assert_eq!(small.len(), 1);
        assert_eq!(small.as_slice()[0].id, 7);
    }
}