use std::ptr;
use llama_cpp_sys_4::{llama_sampler_apply, llama_token_data, llama_token_data_array};
use crate::{sampling::LlamaSampler, token::data::LlamaTokenData};
use super::LlamaToken;
#[derive(Debug, Clone, PartialEq)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaTokenDataArray {
pub data: Vec<LlamaTokenData>,
pub selected: Option<usize>,
pub sorted: bool,
}
impl LlamaTokenDataArray {
#[must_use]
pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
Self {
data,
selected: None,
sorted,
}
}
pub fn from_iter<T>(data: T, sorted: bool) -> LlamaTokenDataArray
where
T: IntoIterator<Item = LlamaTokenData>,
{
Self::new(data.into_iter().collect(), sorted)
}
#[must_use]
pub fn selected_token(&self) -> Option<LlamaToken> {
self.data.get(self.selected?).map(LlamaTokenData::id)
}
}
impl LlamaTokenDataArray {
pub(crate) unsafe fn modify_as_c_llama_token_data_array<T>(
&mut self,
modify: impl FnOnce(&mut llama_token_data_array) -> T,
) -> T {
let size = self.data.len();
let data = self.data.as_mut_ptr().cast::<llama_token_data>();
let mut c_llama_token_data_array = llama_token_data_array {
data,
size,
selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
sorted: self.sorted,
};
let result = modify(&mut c_llama_token_data_array);
assert!(
c_llama_token_data_array.size <= self.data.capacity(),
"Size of the returned array exceeds the data buffer's capacity!"
);
if !ptr::eq(c_llama_token_data_array.data, data) {
ptr::copy(
c_llama_token_data_array.data,
data,
c_llama_token_data_array.size,
);
}
self.data.set_len(c_llama_token_data_array.size);
self.sorted = c_llama_token_data_array.sorted;
self.selected = c_llama_token_data_array
.selected
.try_into()
.ok()
.filter(|&s| s < self.data.len());
result
}
pub fn apply_sampler(&mut self, sampler: &mut LlamaSampler) {
unsafe {
self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_sampler_apply(sampler.sampler.as_ptr(), c_llama_token_data_array);
});
}
}
#[must_use]
pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
self.apply_sampler(sampler);
self
}
pub fn sample_token(&mut self, seed: u32) -> LlamaToken {
self.apply_sampler(&mut LlamaSampler::dist(seed));
self.selected_token()
.expect("Dist sampler failed to select a token!")
}
pub fn sample_token_greedy(&mut self) -> LlamaToken {
self.apply_sampler(&mut LlamaSampler::greedy());
self.selected_token()
.expect("Greedy sampler failed to select a token!")
}
}