llama_cpp_2/token/
data_array.rs

//! A Rust equivalent of llama.cpp's `llama_token_data_array`.
2use std::ptr;
3
4use crate::{sampling::LlamaSampler, token::data::LlamaTokenData};
5
6use super::LlamaToken;
7
8/// a safe wrapper around `llama_token_data_array`.
9#[derive(Debug, Clone, PartialEq)]
10#[allow(clippy::module_name_repetitions)]
11pub struct LlamaTokenDataArray {
12    /// the underlying data
13    pub data: Vec<LlamaTokenData>,
14    /// the index of the selected token in ``data``
15    pub selected: Option<usize>,
16    /// is the data sorted?
17    pub sorted: bool,
18}
19
20impl LlamaTokenDataArray {
21    /// Create a new `LlamaTokenDataArray` from a vector and whether or not the data is sorted.
22    ///
23    /// ```
24    /// # use llama_cpp_2::token::data::LlamaTokenData;
25    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
26    /// # use llama_cpp_2::token::LlamaToken;
27    /// let array = LlamaTokenDataArray::new(vec![
28    ///         LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
29    ///         LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
30    ///    ], false);
31    /// assert_eq!(array.data.len(), 2);
32    /// assert_eq!(array.sorted, false);
33    /// ```
34    #[must_use]
35    pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
36        Self {
37            data,
38            selected: None,
39            sorted,
40        }
41    }
42
43    /// Create a new `LlamaTokenDataArray` from an iterator and whether or not the data is sorted.
44    /// ```
45    /// # use llama_cpp_2::token::data::LlamaTokenData;
46    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
47    /// # use llama_cpp_2::token::LlamaToken;
48    /// let array = LlamaTokenDataArray::from_iter([
49    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
50    ///     LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
51    /// ], false);
52    /// assert_eq!(array.data.len(), 2);
53    /// assert_eq!(array.sorted, false);
54    pub fn from_iter<T>(data: T, sorted: bool) -> LlamaTokenDataArray
55    where
56        T: IntoIterator<Item = LlamaTokenData>,
57    {
58        Self::new(data.into_iter().collect(), sorted)
59    }
60
61    /// Returns the current selected token, if one exists.
62    #[must_use]
63    pub fn selected_token(&self) -> Option<LlamaToken> {
64        self.data.get(self.selected?).map(LlamaTokenData::id)
65    }
66}
67
68impl LlamaTokenDataArray {
69    /// Modify the underlying data as a `llama_token_data_array`. and reconstruct the `LlamaTokenDataArray`.
70    ///
71    /// # Panics
72    ///
73    /// Panics if some of the safety conditions are not met. (we cannot check all of them at
74    /// runtime so breaking them is UB)
75    ///
76    /// SAFETY:
77    /// The returned array formed by the data pointer and the length must entirely consist of
78    /// initialized token data and the length must be less than the capacity of this array's data
79    /// buffer.
80    /// if the data is not sorted, sorted must be false.
81    pub(crate) unsafe fn modify_as_c_llama_token_data_array<T>(
82        &mut self,
83        modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T,
84    ) -> T {
85        let size = self.data.len();
86        let data = self
87            .data
88            .as_mut_ptr()
89            .cast::<llama_cpp_sys_2::llama_token_data>();
90
91        let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array {
92            data,
93            size,
94            selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
95            sorted: self.sorted,
96        };
97
98        let result = modify(&mut c_llama_token_data_array);
99
100        assert!(
101            c_llama_token_data_array.size <= self.data.capacity(),
102            "Size of the returned array exceeds the data buffer's capacity!"
103        );
104        if !ptr::eq(c_llama_token_data_array.data, data) {
105            ptr::copy(
106                c_llama_token_data_array.data,
107                data,
108                c_llama_token_data_array.size,
109            );
110        }
111        self.data.set_len(c_llama_token_data_array.size);
112
113        self.sorted = c_llama_token_data_array.sorted;
114        self.selected = c_llama_token_data_array
115            .selected
116            .try_into()
117            .ok()
118            .filter(|&s| s < self.data.len());
119
120        result
121    }
122
123    /// Modifies the data array by applying a sampler to it
124    pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
125        unsafe {
126            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
127                llama_cpp_sys_2::llama_sampler_apply(sampler.sampler, c_llama_token_data_array);
128            });
129        }
130    }
131
132    /// Modifies the data array by applying a sampler to it
133    #[must_use]
134    pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
135        self.apply_sampler(sampler);
136        self
137    }
138
139    /// Randomly selects a token from the candidates based on their probabilities.
140    ///
141    /// # Panics
142    /// If the internal llama.cpp sampler fails to select a token.
143    pub fn sample_token(&mut self, seed: u32) -> LlamaToken {
144        self.apply_sampler(&LlamaSampler::dist(seed));
145        self.selected_token()
146            .expect("Dist sampler failed to select a token!")
147    }
148
149    /// Selects the token with the highest probability.
150    ///
151    /// # Panics
152    /// If the internal llama.cpp sampler fails to select a token.
153    pub fn sample_token_greedy(&mut self) -> LlamaToken {
154        self.apply_sampler(&LlamaSampler::greedy());
155        self.selected_token()
156            .expect("Greedy sampler failed to select a token!")
157    }
158}