llama_cpp_4/token/data_array.rs
//! A rusty equivalent of `llama_token_data_array`.
2use std::ptr;
3
4use llama_cpp_sys_4::{llama_sampler_apply, llama_token_data, llama_token_data_array};
5
6use crate::{sampling::LlamaSampler, token::data::LlamaTokenData};
7
8use super::LlamaToken;
9
10/// a safe wrapper around `llama_token_data_array`.
11#[derive(Debug, Clone, PartialEq)]
12#[allow(clippy::module_name_repetitions)]
13pub struct LlamaTokenDataArray {
14 /// the underlying data
15 pub data: Vec<LlamaTokenData>,
16 /// the index of the selected token in ``data``
17 pub selected: Option<usize>,
18 /// is the data sorted?
19 pub sorted: bool,
20}
21
22impl LlamaTokenDataArray {
23 /// Create a new `LlamaTokenDataArray` from a vector and whether or not the data is sorted.
24 ///
25 /// ```
26 /// # use llama_cpp_2::token::data::LlamaTokenData;
27 /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
28 /// # use llama_cpp_2::token::LlamaToken;
29 /// let array = LlamaTokenDataArray::new(vec![
30 /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
31 /// LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
32 /// ], false);
33 /// assert_eq!(array.data.len(), 2);
34 /// assert_eq!(array.sorted, false);
35 /// ```
36 #[must_use]
37 pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
38 Self {
39 data,
40 selected: None,
41 sorted,
42 }
43 }
44
45 /// Create a new `LlamaTokenDataArray` from an iterator and whether or not the data is sorted.
46 /// ```
47 /// # use llama_cpp_2::token::data::LlamaTokenData;
48 /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
49 /// # use llama_cpp_2::token::LlamaToken;
50 /// let array = LlamaTokenDataArray::from_iter([
51 /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
52 /// LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
53 /// ], false);
54 /// assert_eq!(array.data.len(), 2);
55 /// assert_eq!(array.sorted, false);
56 pub fn from_iter<T>(data: T, sorted: bool) -> LlamaTokenDataArray
57 where
58 T: IntoIterator<Item = LlamaTokenData>,
59 {
60 Self::new(data.into_iter().collect(), sorted)
61 }
62
63 /// Returns the current selected token, if one exists.
64 #[must_use]
65 pub fn selected_token(&self) -> Option<LlamaToken> {
66 self.data.get(self.selected?).map(LlamaTokenData::id)
67 }
68}
69
70impl LlamaTokenDataArray {
71 /// Modify the underlying data as a `llama_token_data_array`. and reconstruct the `LlamaTokenDataArray`.
72 ///
73 /// # Panics
74 ///
75 /// Panics if some of the safety conditions are not met. (we cannot check all of them at
76 /// runtime so breaking them is UB)
77 ///
78 /// SAFETY:
79 /// The returned array formed by the data pointer and the length must entirely consist of
80 /// initialized token data and the length must be less than the capacity of this array's data
81 /// buffer.
82 /// if the data is not sorted, sorted must be false.
83 pub(crate) unsafe fn modify_as_c_llama_token_data_array<T>(
84 &mut self,
85 modify: impl FnOnce(&mut llama_token_data_array) -> T,
86 ) -> T {
87 let size = self.data.len();
88 let data = self.data.as_mut_ptr().cast::<llama_token_data>();
89
90 let mut c_llama_token_data_array = llama_token_data_array {
91 data,
92 size,
93 selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
94 sorted: self.sorted,
95 };
96
97 let result = modify(&mut c_llama_token_data_array);
98
99 assert!(
100 c_llama_token_data_array.size <= self.data.capacity(),
101 "Size of the returned array exceeds the data buffer's capacity!"
102 );
103 if !ptr::eq(c_llama_token_data_array.data, data) {
104 ptr::copy(
105 c_llama_token_data_array.data,
106 data,
107 c_llama_token_data_array.size,
108 );
109 }
110 self.data.set_len(c_llama_token_data_array.size);
111
112 self.sorted = c_llama_token_data_array.sorted;
113 self.selected = c_llama_token_data_array
114 .selected
115 .try_into()
116 .ok()
117 .filter(|&s| s < self.data.len());
118
119 result
120 }
121
122 /// Modifies the data array by applying a sampler to it
123 pub fn apply_sampler(&mut self, sampler: &mut LlamaSampler) {
124 unsafe {
125 self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
126 llama_sampler_apply(sampler.sampler.as_ptr(), c_llama_token_data_array);
127 });
128 }
129 }
130
131 /// Modifies the data array by applying a sampler to it
132 #[must_use]
133 pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
134 self.apply_sampler(sampler);
135 self
136 }
137
138 /// Randomly selects a token from the candidates based on their probabilities.
139 ///
140 /// # Panics
141 /// If the internal llama.cpp sampler fails to select a token.
142 pub fn sample_token(&mut self, seed: u32) -> LlamaToken {
143 self.apply_sampler(&mut LlamaSampler::dist(seed));
144 self.selected_token()
145 .expect("Dist sampler failed to select a token!")
146 }
147
148 /// Selects the token with the highest probability.
149 ///
150 /// # Panics
151 /// If the internal llama.cpp sampler fails to select a token.
152 pub fn sample_token_greedy(&mut self) -> LlamaToken {
153 self.apply_sampler(&mut LlamaSampler::greedy());
154 self.selected_token()
155 .expect("Greedy sampler failed to select a token!")
156 }
157}