1use std::{
2 fmt::Debug,
3 ops::{Deref, DerefMut},
4 sync::{Arc, Mutex},
5};
6
7use anyhow::Result;
8use thiserror::Error;
9
10pub use crate::{chain::*, resource::*};
11
/// Integer type used for token ids throughout this module.
pub type TID = u32;

/// Floating point type used for both raw logit values and probabilities.
pub type L = f32;
17
18#[derive(Debug, Error)]
19pub enum SamplerError {
21 #[error("internal error: {0}")]
22 InternalError(String),
24
25 #[error("missing resource error: {0}")]
26 MissingResource(String),
28
29 #[error("logits error: {0}")]
30 LogitsError(LogitsError),
32
33 #[error("rand error: {0}")]
34 RandError(rand::Error),
36
37 #[error("rand weights error: {0}")]
38 RandWeightedError(rand::distributions::WeightedError),
40}
41
42#[derive(Debug, Clone, Error)]
43pub enum LogitsError {
45 #[error("Invalid logit for token id {0}")]
46 InvalidLogit(usize),
49 #[error("internal logits error: {0}")]
50 InternalError(String),
52}
53
54impl From<LogitsError> for SamplerError {
55 fn from(value: LogitsError) -> Self {
56 SamplerError::LogitsError(value)
57 }
58}
59
60#[derive(Debug, Clone, PartialEq)]
61pub struct Logit {
63 pub token_id: TID,
65 pub logit: L,
67 pub prob: L,
69}
70
71#[derive(Debug, Clone, Default)]
72pub struct Logits {
77 sorted: bool,
78 has_softmax: bool,
79 logits: Vec<Logit>,
80}
81
82impl Deref for Logits {
83 type Target = Vec<Logit>;
84
85 fn deref(&self) -> &Self::Target {
86 &self.logits
87 }
88}
89
90impl DerefMut for Logits {
91 fn deref_mut(&mut self) -> &mut Self::Target {
92 &mut self.logits
93 }
94}
95
96impl Logits {
97 pub fn try_from_iter<I: IntoIterator<Item = L>>(it: I) -> Result<Self, LogitsError> {
100 let mut tid = 0;
101 Ok(Self {
102 sorted: false,
103 has_softmax: false,
104 logits: it
105 .into_iter()
106 .enumerate()
107 .map(|(idx, logit)| {
108 if logit.is_nan() {
109 Err(LogitsError::InvalidLogit(idx))?
110 }
111 let result = Logit {
112 token_id: tid,
113 logit,
114 prob: 0f32,
115 };
116 tid += 1;
117 Ok(result)
118 })
119 .collect::<Result<Vec<_>, LogitsError>>()?,
120 })
121 }
122
123 pub fn try_from_iter_top_k<I: IntoIterator<Item = L>>(
131 it: I,
132 k: usize,
133 ) -> Result<Self, LogitsError> {
134 if k == 0 {
135 return Ok(Self::default());
136 }
137
138 Ok(Logits {
139 sorted: true,
140 has_softmax: false,
141 logits: (0u32..)
142 .zip(it)
143 .filter(|(_tid, logit)| logit.is_finite())
144 .fold(Vec::with_capacity(k), |mut logits, (tid, logit)| {
145 if logits.len() == k {
146 if logit > unsafe { logits.last().unwrap_unchecked().logit } {
148 logits.truncate(k - 1);
149 } else {
150 return logits;
151 }
152 }
153 logits.insert(
154 logits.partition_point(|l| logit < l.logit),
155 Logit {
156 token_id: tid,
157 logit,
158 prob: 0f32,
159 },
160 );
161 logits
162 }),
163 })
164 }
165}
166
167impl TryFrom<Vec<L>> for Logits {
168 type Error = LogitsError;
169
170 fn try_from(value: Vec<L>) -> Result<Self, Self::Error> {
171 Self::try_from_iter(value)
172 }
173}
174
175impl Logits {
176 pub fn get_sorted(&self) -> bool {
178 self.sorted
179 }
180
181 pub fn set_sorted(&mut self, is_sorted: bool) -> &mut Self {
183 self.sorted = is_sorted;
184 self
185 }
186
187 pub fn get_softmax(&self) -> bool {
189 self.has_softmax
190 }
191
192 pub fn set_softmax(&mut self, has_softmax: bool) -> &mut Self {
194 self.has_softmax = has_softmax;
195 self
196 }
197
198 pub fn ensure_sorted(&mut self) -> Result<&mut Self> {
200 if self.get_sorted() {
201 return Ok(self);
202 }
203
204 let mut sort_err = Ok(());
205 self.logits.as_mut_slice().sort_by(|a, b| {
206 b.logit.partial_cmp(&a.logit).unwrap_or_else(|| {
207 sort_err = Err(LogitsError::InternalError(String::from(
208 "Impossible: logit comparison failed?",
209 )));
210 std::cmp::Ordering::Less
211 })
212 });
213 sort_err?;
214 self.set_sorted(true);
215 Ok(self)
216 }
217
218 pub fn ensure_softmax(&mut self) -> Result<&mut Self> {
220 if self.is_empty() || self.has_softmax {
221 self.has_softmax = true;
222 self.sorted = true;
223 return Ok(self);
224 }
225 self.ensure_sorted()?;
226 let max_l = self[0].logit;
227 let cum_sum = self.iter_mut().fold(0f32, |cs, l| {
228 l.prob = (l.logit - max_l).exp();
229 cs + l.prob
230 });
231 self.iter_mut().for_each(|l| l.prob /= cum_sum);
232 self.has_softmax = true;
233 Ok(self)
234 }
235
236 pub fn sample<S: Sampler>(
238 &mut self,
239 res: &mut dyn HasSamplerResources,
240 sampler: &mut S,
241 ) -> Result<&mut Self> {
242 sampler.sample(res, self)
243 }
244
245 pub fn sample_token<S: Sampler>(
247 &mut self,
248 res: &mut dyn HasSamplerResources,
249 sampler: &mut S,
250 ) -> Result<Option<TID>> {
251 sampler.sample_token(res, self)
252 }
253}
254
255pub trait Sampler: Debug + Send + Sync {
257 fn sample<'a>(
259 &mut self,
260 res: &mut dyn HasSamplerResources,
261 logits: &'a mut Logits,
262 ) -> Result<&'a mut Logits>;
263
264 fn sampled_token_id(&self) -> Option<TID> {
268 None
269 }
270
271 fn sample_token(
276 &mut self,
277 res: &mut dyn HasSamplerResources,
278 logits: &mut Logits,
279 ) -> Result<Option<TID>> {
280 let _ = self.sample(res, logits)?;
281 Ok(self.sampled_token_id())
282 }
283}
284
285impl Sampler for Box<dyn Sampler> {
286 fn sampled_token_id(&self) -> Option<TID> {
287 (**self).sampled_token_id()
288 }
289
290 fn sample_token(
291 &mut self,
292 res: &mut dyn HasSamplerResources,
293 logits: &mut Logits,
294 ) -> Result<Option<TID>> {
295 (**self).sample_token(res, logits)
296 }
297
298 fn sample<'a>(
299 &mut self,
300 res: &mut dyn HasSamplerResources,
301 logits: &'a mut Logits,
302 ) -> Result<&'a mut Logits> {
303 (**self).sample(res, logits)
304 }
305}
306
307impl Sampler for Arc<Mutex<dyn Sampler>> {
308 fn sampled_token_id(&self) -> Option<TID> {
309 self.lock().ok()?.sampled_token_id()
310 }
311
312 fn sample_token(
313 &mut self,
314 res: &mut dyn HasSamplerResources,
315 logits: &mut Logits,
316 ) -> Result<Option<TID>> {
317 self.lock()
318 .map_err(|e| SamplerError::InternalError(format!("Couldn't acquire lock: {e}")))?
319 .sample_token(res, logits)
320 }
321
322 fn sample<'a>(
323 &mut self,
324 res: &mut dyn HasSamplerResources,
325 logits: &'a mut Logits,
326 ) -> Result<&'a mut Logits> {
327 self.lock()
328 .map_err(|e| SamplerError::InternalError(format!("Couldn't acquire lock: {e}")))?
329 .sample(res, logits)
330 }
331}