llama_cpp_2/token/data_array.rs
//! A Rust equivalent of `llama_token_data_array`.
2use std::ptr;
3
4use crate::{sampling::LlamaSampler, token::data::LlamaTokenData};
5
6use super::LlamaToken;
7
8/// a safe wrapper around `llama_token_data_array`.
9#[derive(Debug, Clone, PartialEq)]
10#[allow(clippy::module_name_repetitions)]
11pub struct LlamaTokenDataArray {
12 /// the underlying data
13 pub data: Vec<LlamaTokenData>,
14 /// the index of the selected token in ``data``
15 pub selected: Option<usize>,
16 /// is the data sorted?
17 pub sorted: bool,
18}
19
20impl LlamaTokenDataArray {
21 /// Create a new `LlamaTokenDataArray` from a vector and whether or not the data is sorted.
22 ///
23 /// ```
24 /// # use llama_cpp_2::token::data::LlamaTokenData;
25 /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
26 /// # use llama_cpp_2::token::LlamaToken;
27 /// let array = LlamaTokenDataArray::new(vec![
28 /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
29 /// LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
30 /// ], false);
31 /// assert_eq!(array.data.len(), 2);
32 /// assert_eq!(array.sorted, false);
33 /// ```
34 #[must_use]
35 pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
36 Self {
37 data,
38 selected: None,
39 sorted,
40 }
41 }
42
43 /// Create a new `LlamaTokenDataArray` from an iterator and whether or not the data is sorted.
44 /// ```
45 /// # use llama_cpp_2::token::data::LlamaTokenData;
46 /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
47 /// # use llama_cpp_2::token::LlamaToken;
48 /// let array = LlamaTokenDataArray::from_iter([
49 /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
50 /// LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
51 /// ], false);
52 /// assert_eq!(array.data.len(), 2);
53 /// assert_eq!(array.sorted, false);
54 pub fn from_iter<T>(data: T, sorted: bool) -> LlamaTokenDataArray
55 where
56 T: IntoIterator<Item = LlamaTokenData>,
57 {
58 Self::new(data.into_iter().collect(), sorted)
59 }
60
61 /// Returns the current selected token, if one exists.
62 #[must_use]
63 pub fn selected_token(&self) -> Option<LlamaToken> {
64 self.data.get(self.selected?).map(LlamaTokenData::id)
65 }
66}
67
68impl LlamaTokenDataArray {
69 /// Modify the underlying data as a `llama_token_data_array`. and reconstruct the `LlamaTokenDataArray`.
70 ///
71 /// # Panics
72 ///
73 /// Panics if some of the safety conditions are not met. (we cannot check all of them at
74 /// runtime so breaking them is UB)
75 ///
76 /// SAFETY:
77 /// The returned array formed by the data pointer and the length must entirely consist of
78 /// initialized token data and the length must be less than the capacity of this array's data
79 /// buffer.
80 /// if the data is not sorted, sorted must be false.
81 pub(crate) unsafe fn modify_as_c_llama_token_data_array<T>(
82 &mut self,
83 modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T,
84 ) -> T {
85 let size = self.data.len();
86 let data = self
87 .data
88 .as_mut_ptr()
89 .cast::<llama_cpp_sys_2::llama_token_data>();
90
91 let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array {
92 data,
93 size,
94 selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
95 sorted: self.sorted,
96 };
97
98 let result = modify(&mut c_llama_token_data_array);
99
100 assert!(
101 c_llama_token_data_array.size <= self.data.capacity(),
102 "Size of the returned array exceeds the data buffer's capacity!"
103 );
104 if !ptr::eq(c_llama_token_data_array.data, data) {
105 ptr::copy(
106 c_llama_token_data_array.data,
107 data,
108 c_llama_token_data_array.size,
109 );
110 }
111 self.data.set_len(c_llama_token_data_array.size);
112
113 self.sorted = c_llama_token_data_array.sorted;
114 self.selected = c_llama_token_data_array
115 .selected
116 .try_into()
117 .ok()
118 .filter(|&s| s < self.data.len());
119
120 result
121 }
122
123 /// Modifies the data array by applying a sampler to it
124 pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
125 unsafe {
126 self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
127 llama_cpp_sys_2::llama_sampler_apply(sampler.sampler, c_llama_token_data_array);
128 });
129 }
130 }
131
132 /// Modifies the data array by applying a sampler to it
133 #[must_use]
134 pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
135 self.apply_sampler(sampler);
136 self
137 }
138
139 /// Randomly selects a token from the candidates based on their probabilities.
140 ///
141 /// # Panics
142 /// If the internal llama.cpp sampler fails to select a token.
143 pub fn sample_token(&mut self, seed: u32) -> LlamaToken {
144 self.apply_sampler(&LlamaSampler::dist(seed));
145 self.selected_token()
146 .expect("Dist sampler failed to select a token!")
147 }
148
149 /// Selects the token with the highest probability.
150 ///
151 /// # Panics
152 /// If the internal llama.cpp sampler fails to select a token.
153 pub fn sample_token_greedy(&mut self) -> LlamaToken {
154 self.apply_sampler(&LlamaSampler::greedy());
155 self.selected_token()
156 .expect("Greedy sampler failed to select a token!")
157 }
158}