use std::ptr;
use crate::error::TokenSamplingError;
use crate::sampling::LlamaSampler;
use crate::token::data::LlamaTokenData;
use super::LlamaToken;
#[derive(Debug, Clone, PartialEq)]
pub struct LlamaTokenDataArray {
pub data: Vec<LlamaTokenData>,
pub selected: Option<usize>,
pub sorted: bool,
}
impl LlamaTokenDataArray {
#[must_use]
pub const fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
Self {
data,
selected: None,
sorted,
}
}
pub fn from_iter<TIterator>(data: TIterator, sorted: bool) -> Self
where
TIterator: IntoIterator<Item = LlamaTokenData>,
{
Self::new(data.into_iter().collect(), sorted)
}
#[must_use]
pub fn selected_token(&self) -> Option<LlamaToken> {
self.data.get(self.selected?).map(LlamaTokenData::id)
}
}
impl LlamaTokenDataArray {
pub unsafe fn modify_as_c_llama_token_data_array<TResult>(
&mut self,
modify: impl FnOnce(&mut llama_cpp_bindings_sys::llama_token_data_array) -> TResult,
) -> TResult {
let size = self.data.len();
let data = self
.data
.as_mut_ptr()
.cast::<llama_cpp_bindings_sys::llama_token_data>();
let mut c_llama_token_data_array = llama_cpp_bindings_sys::llama_token_data_array {
data,
size,
selected: self
.selected
.and_then(|selected_index| selected_index.try_into().ok())
.unwrap_or(-1),
sorted: self.sorted,
};
let result = modify(&mut c_llama_token_data_array);
assert!(c_llama_token_data_array.size <= self.data.capacity());
unsafe {
if !ptr::eq(c_llama_token_data_array.data, data) {
ptr::copy(
c_llama_token_data_array.data,
data,
c_llama_token_data_array.size,
);
}
self.data.set_len(c_llama_token_data_array.size);
}
self.sorted = c_llama_token_data_array.sorted;
self.selected = c_llama_token_data_array
.selected
.try_into()
.ok()
.filter(|&s| s < self.data.len());
result
}
pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
unsafe {
self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_bindings_sys::llama_sampler_apply(
sampler.sampler,
c_llama_token_data_array,
);
});
}
}
#[must_use]
pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
self.apply_sampler(sampler);
self
}
pub fn sample_token(&mut self, seed: u32) -> Result<LlamaToken, TokenSamplingError> {
self.apply_sampler(&LlamaSampler::dist(seed));
self.selected_token()
.ok_or(TokenSamplingError::NoTokenSelected)
}
pub fn sample_token_greedy(&mut self) -> Result<LlamaToken, TokenSamplingError> {
self.apply_sampler(&LlamaSampler::greedy());
self.selected_token()
.ok_or(TokenSamplingError::NoTokenSelected)
}
}
#[cfg(test)]
mod tests {
use crate::token::LlamaToken;
use crate::token::data::LlamaTokenData;
use super::LlamaTokenDataArray;
#[test]
fn apply_greedy_sampler_selects_highest_logit() {
use crate::sampling::LlamaSampler;
let mut array = LlamaTokenDataArray::new(
vec![
LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
],
false,
);
array.apply_sampler(&LlamaSampler::greedy());
assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
}
#[test]
fn with_sampler_builder_pattern() {
use crate::sampling::LlamaSampler;
let array = LlamaTokenDataArray::new(
vec![
LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
],
false,
)
.with_sampler(&mut LlamaSampler::greedy());
assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
}
#[test]
fn sample_token_greedy_returns_highest() {
let mut array = LlamaTokenDataArray::new(
vec![
LlamaTokenData::new(LlamaToken::new(10), 0.1, 0.0),
LlamaTokenData::new(LlamaToken::new(20), 9.9, 0.0),
],
false,
);
let token = array
.sample_token_greedy()
.expect("test: greedy sampler should select a token");
assert_eq!(token, LlamaToken::new(20));
}
#[test]
fn from_iter_creates_array_from_iterator() {
let array = LlamaTokenDataArray::from_iter(
[
LlamaTokenData::new(LlamaToken::new(0), 0.0, 0.0),
LlamaTokenData::new(LlamaToken::new(1), 1.0, 0.0),
LlamaTokenData::new(LlamaToken::new(2), 2.0, 0.0),
],
false,
);
assert_eq!(array.data.len(), 3);
assert!(!array.sorted);
assert!(array.selected.is_none());
}
#[test]
fn sample_token_with_seed_selects_a_token() {
let mut array = LlamaTokenDataArray::new(
vec![
LlamaTokenData::new(LlamaToken::new(10), 1.0, 0.0),
LlamaTokenData::new(LlamaToken::new(20), 1.0, 0.0),
],
false,
);
let token = array
.sample_token(42)
.expect("test: dist sampler should select a token");
assert!(token == LlamaToken::new(10) || token == LlamaToken::new(20));
}
#[test]
fn selected_token_returns_none_when_no_selection() {
let array = LlamaTokenDataArray::new(
vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
false,
);
assert!(array.selected_token().is_none());
}
#[test]
fn selected_token_returns_none_when_index_out_of_bounds() {
let array = LlamaTokenDataArray {
data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
selected: Some(5),
sorted: false,
};
assert!(array.selected_token().is_none());
}
#[test]
fn modify_as_c_llama_token_data_array_copies_when_data_pointer_changes() {
let mut array = LlamaTokenDataArray::new(
vec![
LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
],
false,
);
let replacement = [
llama_cpp_bindings_sys::llama_token_data {
id: 10,
logit: 5.0,
p: 0.0,
},
llama_cpp_bindings_sys::llama_token_data {
id: 20,
logit: 6.0,
p: 0.0,
},
];
unsafe {
array.modify_as_c_llama_token_data_array(|c_array| {
c_array.data = replacement.as_ptr().cast_mut();
c_array.size = replacement.len();
c_array.selected = 0;
});
}
assert_eq!(array.data.len(), 2);
assert_eq!(array.data[0].id(), LlamaToken::new(10));
assert_eq!(array.data[1].id(), LlamaToken::new(20));
assert_eq!(array.selected, Some(0));
}
#[test]
fn selected_overflow_uses_negative_one() {
let mut array = LlamaTokenDataArray {
data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
selected: Some(usize::MAX),
sorted: false,
};
unsafe {
array.modify_as_c_llama_token_data_array(|c_array| {
assert_eq!(c_array.selected, -1);
});
}
}
}