Skip to main content

llama_cpp_4/token/
data_array.rs

//! A Rust-idiomatic equivalent of `llama_token_data_array`.
2use std::ptr;
3
4use llama_cpp_sys_4::{llama_sampler_apply, llama_token_data, llama_token_data_array};
5
6use crate::{sampling::LlamaSampler, token::data::LlamaTokenData};
7
8use super::LlamaToken;
9
10/// a safe wrapper around `llama_token_data_array`.
11#[derive(Debug, Clone, PartialEq)]
12#[allow(clippy::module_name_repetitions)]
13pub struct LlamaTokenDataArray {
14    /// the underlying data
15    pub data: Vec<LlamaTokenData>,
16    /// the index of the selected token in ``data``
17    pub selected: Option<usize>,
18    /// is the data sorted?
19    pub sorted: bool,
20}
21
22impl LlamaTokenDataArray {
23    /// Create a new `LlamaTokenDataArray` from a vector and whether or not the data is sorted.
24    ///
25    /// ```
26    /// # use llama_cpp_2::token::data::LlamaTokenData;
27    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
28    /// # use llama_cpp_2::token::LlamaToken;
29    /// let array = LlamaTokenDataArray::new(vec![
30    ///         LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
31    ///         LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
32    ///    ], false);
33    /// assert_eq!(array.data.len(), 2);
34    /// assert_eq!(array.sorted, false);
35    /// ```
36    #[must_use]
37    pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
38        Self {
39            data,
40            selected: None,
41            sorted,
42        }
43    }
44
45    /// Create a new `LlamaTokenDataArray` from an iterator and whether or not the data is sorted.
46    /// ```
47    /// # use llama_cpp_2::token::data::LlamaTokenData;
48    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
49    /// # use llama_cpp_2::token::LlamaToken;
50    /// let array = LlamaTokenDataArray::from_iter([
51    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
52    ///     LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
53    /// ], false);
54    /// assert_eq!(array.data.len(), 2);
55    /// assert_eq!(array.sorted, false);
56    pub fn from_iter<T>(data: T, sorted: bool) -> LlamaTokenDataArray
57    where
58        T: IntoIterator<Item = LlamaTokenData>,
59    {
60        Self::new(data.into_iter().collect(), sorted)
61    }
62
63    /// Returns the current selected token, if one exists.
64    #[must_use]
65    pub fn selected_token(&self) -> Option<LlamaToken> {
66        self.data.get(self.selected?).map(LlamaTokenData::id)
67    }
68}
69
70impl LlamaTokenDataArray {
71    /// Modify the underlying data as a `llama_token_data_array`. and reconstruct the `LlamaTokenDataArray`.
72    ///
73    /// # Panics
74    ///
75    /// Panics if some of the safety conditions are not met. (we cannot check all of them at
76    /// runtime so breaking them is UB)
77    ///
78    /// SAFETY:
79    /// The returned array formed by the data pointer and the length must entirely consist of
80    /// initialized token data and the length must be less than the capacity of this array's data
81    /// buffer.
82    /// if the data is not sorted, sorted must be false.
83    pub(crate) unsafe fn modify_as_c_llama_token_data_array<T>(
84        &mut self,
85        modify: impl FnOnce(&mut llama_token_data_array) -> T,
86    ) -> T {
87        let size = self.data.len();
88        let data = self.data.as_mut_ptr().cast::<llama_token_data>();
89
90        let mut c_llama_token_data_array = llama_token_data_array {
91            data,
92            size,
93            selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
94            sorted: self.sorted,
95        };
96
97        let result = modify(&mut c_llama_token_data_array);
98
99        assert!(
100            c_llama_token_data_array.size <= self.data.capacity(),
101            "Size of the returned array exceeds the data buffer's capacity!"
102        );
103        if !ptr::eq(c_llama_token_data_array.data, data) {
104            ptr::copy(
105                c_llama_token_data_array.data,
106                data,
107                c_llama_token_data_array.size,
108            );
109        }
110        self.data.set_len(c_llama_token_data_array.size);
111
112        self.sorted = c_llama_token_data_array.sorted;
113        self.selected = c_llama_token_data_array
114            .selected
115            .try_into()
116            .ok()
117            .filter(|&s| s < self.data.len());
118
119        result
120    }
121
122    /// Modifies the data array by applying a sampler to it
123    pub fn apply_sampler(&mut self, sampler: &mut LlamaSampler) {
124        unsafe {
125            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
126                llama_sampler_apply(sampler.sampler.as_ptr(), c_llama_token_data_array);
127            });
128        }
129    }
130
131    /// Modifies the data array by applying a sampler to it
132    #[must_use]
133    pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
134        self.apply_sampler(sampler);
135        self
136    }
137
138    /// Randomly selects a token from the candidates based on their probabilities.
139    ///
140    /// # Panics
141    /// If the internal llama.cpp sampler fails to select a token.
142    pub fn sample_token(&mut self, seed: u32) -> LlamaToken {
143        self.apply_sampler(&mut LlamaSampler::dist(seed));
144        self.selected_token()
145            .expect("Dist sampler failed to select a token!")
146    }
147
148    /// Selects the token with the highest probability.
149    ///
150    /// # Panics
151    /// If the internal llama.cpp sampler fails to select a token.
152    pub fn sample_token_greedy(&mut self) -> LlamaToken {
153        self.apply_sampler(&mut LlamaSampler::greedy());
154        self.selected_token()
155            .expect("Greedy sampler failed to select a token!")
156    }
157}