llm_samplers/
types.rs

1use std::{
2    fmt::Debug,
3    ops::{Deref, DerefMut},
4    sync::{Arc, Mutex},
5};
6
7use anyhow::Result;
8use thiserror::Error;
9
10pub use crate::{chain::*, resource::*};
11
12/// Type for token IDs.
13pub type TID = u32;
14
15/// Type for logits.
16pub type L = f32;
17
18#[derive(Debug, Error)]
19/// Sampler errors
20pub enum SamplerError {
21    #[error("internal error: {0}")]
22    /// General internal error type.
23    InternalError(String),
24
25    #[error("missing resource error: {0}")]
26    /// Missing resource error type.
27    MissingResource(String),
28
29    #[error("logits error: {0}")]
30    /// Container for errors that occured while processing logits.
31    LogitsError(LogitsError),
32
33    #[error("rand error: {0}")]
34    /// RNG-related errors
35    RandError(rand::Error),
36
37    #[error("rand weights error: {0}")]
38    /// RNG weights-related errors
39    RandWeightedError(rand::distributions::WeightedError),
40}
41
42#[derive(Debug, Clone, Error)]
43/// Logit errors
44pub enum LogitsError {
45    #[error("Invalid logit for token id {0}")]
46    /// Contains the position (AKA token id) of the offending logit.
47    /// Logits cannot be NaN.
48    InvalidLogit(usize),
49    #[error("internal logits error: {0}")]
50    /// General internal error type.
51    InternalError(String),
52}
53
54impl From<LogitsError> for SamplerError {
55    fn from(value: LogitsError) -> Self {
56        SamplerError::LogitsError(value)
57    }
58}
59
60#[derive(Debug, Clone, PartialEq)]
61/// An individual logit with some additional metadata for use by the samplers.
62pub struct Logit {
63    /// The token id.
64    pub token_id: TID,
65    /// The logit value.
66    pub logit: L,
67    /// Computed probability.
68    pub prob: L,
69}
70
71#[derive(Debug, Clone, Default)]
72/// A collection of [Logit]s. You normally will need to build this from the result of
73/// evaluating the LLM.
74///
75/// For convenience, this can [Deref] to the internal [Vec].
76pub struct Logits {
77    sorted: bool,
78    has_softmax: bool,
79    logits: Vec<Logit>,
80}
81
82impl Deref for Logits {
83    type Target = Vec<Logit>;
84
85    fn deref(&self) -> &Self::Target {
86        &self.logits
87    }
88}
89
90impl DerefMut for Logits {
91    fn deref_mut(&mut self) -> &mut Self::Target {
92        &mut self.logits
93    }
94}
95
96impl Logits {
97    /// Make a new [Logits] from an iterator of `L`. We'd like to
98    /// write this as [TryFrom] but unfortunately the types make this impossible.
99    pub fn try_from_iter<I: IntoIterator<Item = L>>(it: I) -> Result<Self, LogitsError> {
100        let mut tid = 0;
101        Ok(Self {
102            sorted: false,
103            has_softmax: false,
104            logits: it
105                .into_iter()
106                .enumerate()
107                .map(|(idx, logit)| {
108                    if logit.is_nan() {
109                        Err(LogitsError::InvalidLogit(idx))?
110                    }
111                    let result = Logit {
112                        token_id: tid,
113                        logit,
114                        prob: 0f32,
115                    };
116                    tid += 1;
117                    Ok(result)
118                })
119                .collect::<Result<Vec<_>, LogitsError>>()?,
120        })
121    }
122
123    /// Make a new [Logits] from an iterator of `L` while only keeping the top `k`
124    /// values and maintaining sorted order. This may be faster than building the
125    /// full logits and then later sorting/pruning them. Set `k` high enough that
126    /// the logits it prunes aren't ones that would be considered with normal
127    /// sampling. Something like 500 to 2,000 is probably reasonable.
128    ///
129    /// Note: Infinite and NaN values will also be filtered.
130    pub fn try_from_iter_top_k<I: IntoIterator<Item = L>>(
131        it: I,
132        k: usize,
133    ) -> Result<Self, LogitsError> {
134        if k == 0 {
135            return Ok(Self::default());
136        }
137
138        Ok(Logits {
139            sorted: true,
140            has_softmax: false,
141            logits: (0u32..)
142                .zip(it)
143                .filter(|(_tid, logit)| logit.is_finite())
144                .fold(Vec::with_capacity(k), |mut logits, (tid, logit)| {
145                    if logits.len() == k {
146                        // The Vec is guaranteed not to be empty at this point.
147                        if logit > unsafe { logits.last().unwrap_unchecked().logit } {
148                            logits.truncate(k - 1);
149                        } else {
150                            return logits;
151                        }
152                    }
153                    logits.insert(
154                        logits.partition_point(|l| logit < l.logit),
155                        Logit {
156                            token_id: tid,
157                            logit,
158                            prob: 0f32,
159                        },
160                    );
161                    logits
162                }),
163        })
164    }
165}
166
167impl TryFrom<Vec<L>> for Logits {
168    type Error = LogitsError;
169
170    fn try_from(value: Vec<L>) -> Result<Self, Self::Error> {
171        Self::try_from_iter(value)
172    }
173}
174
175impl Logits {
176    /// Get the sorted flag.
177    pub fn get_sorted(&self) -> bool {
178        self.sorted
179    }
180
181    /// Set the sorted flag.
182    pub fn set_sorted(&mut self, is_sorted: bool) -> &mut Self {
183        self.sorted = is_sorted;
184        self
185    }
186
187    /// Get the softmax flag.
188    pub fn get_softmax(&self) -> bool {
189        self.has_softmax
190    }
191
192    /// Set the softmax flag.
193    pub fn set_softmax(&mut self, has_softmax: bool) -> &mut Self {
194        self.has_softmax = has_softmax;
195        self
196    }
197
198    /// Ensure the [Logits] are sorted. Generally not necessary to call this directly.
199    pub fn ensure_sorted(&mut self) -> Result<&mut Self> {
200        if self.get_sorted() {
201            return Ok(self);
202        }
203
204        let mut sort_err = Ok(());
205        self.logits.as_mut_slice().sort_by(|a, b| {
206            b.logit.partial_cmp(&a.logit).unwrap_or_else(|| {
207                sort_err = Err(LogitsError::InternalError(String::from(
208                    "Impossible: logit comparison failed?",
209                )));
210                std::cmp::Ordering::Less
211            })
212        });
213        sort_err?;
214        self.set_sorted(true);
215        Ok(self)
216    }
217
218    /// Ensure the softmax function has been applied to the [Logits].
219    pub fn ensure_softmax(&mut self) -> Result<&mut Self> {
220        if self.is_empty() || self.has_softmax {
221            self.has_softmax = true;
222            self.sorted = true;
223            return Ok(self);
224        }
225        self.ensure_sorted()?;
226        let max_l = self[0].logit;
227        let cum_sum = self.iter_mut().fold(0f32, |cs, l| {
228            l.prob = (l.logit - max_l).exp();
229            cs + l.prob
230        });
231        self.iter_mut().for_each(|l| l.prob /= cum_sum);
232        self.has_softmax = true;
233        Ok(self)
234    }
235
236    /// Convenience method
237    pub fn sample<S: Sampler>(
238        &mut self,
239        res: &mut dyn HasSamplerResources,
240        sampler: &mut S,
241    ) -> Result<&mut Self> {
242        sampler.sample(res, self)
243    }
244
245    /// Convenience method
246    pub fn sample_token<S: Sampler>(
247        &mut self,
248        res: &mut dyn HasSamplerResources,
249        sampler: &mut S,
250    ) -> Result<Option<TID>> {
251        sampler.sample_token(res, self)
252    }
253}
254
255/// The main sampler trait.
256pub trait Sampler: Debug + Send + Sync {
257    /// Runs the [Sampler]. Depending on the type of [Sampler], this may produce a token id.
258    fn sample<'a>(
259        &mut self,
260        res: &mut dyn HasSamplerResources,
261        logits: &'a mut Logits,
262    ) -> Result<&'a mut Logits>;
263
264    /// Returns the last sampled token id if available.
265    ///
266    /// A default implemenation is provided which simply returns [None].
267    fn sampled_token_id(&self) -> Option<TID> {
268        None
269    }
270
271    /// Run the sampler and return the last sampled token id if available.
272    ///
273    /// A default implementation is provided which just calls [Sampler::sample] followed by
274    /// [Sampler::sampled_token_id()].
275    fn sample_token(
276        &mut self,
277        res: &mut dyn HasSamplerResources,
278        logits: &mut Logits,
279    ) -> Result<Option<TID>> {
280        let _ = self.sample(res, logits)?;
281        Ok(self.sampled_token_id())
282    }
283}
284
285impl Sampler for Box<dyn Sampler> {
286    fn sampled_token_id(&self) -> Option<TID> {
287        (**self).sampled_token_id()
288    }
289
290    fn sample_token(
291        &mut self,
292        res: &mut dyn HasSamplerResources,
293        logits: &mut Logits,
294    ) -> Result<Option<TID>> {
295        (**self).sample_token(res, logits)
296    }
297
298    fn sample<'a>(
299        &mut self,
300        res: &mut dyn HasSamplerResources,
301        logits: &'a mut Logits,
302    ) -> Result<&'a mut Logits> {
303        (**self).sample(res, logits)
304    }
305}
306
307impl Sampler for Arc<Mutex<dyn Sampler>> {
308    fn sampled_token_id(&self) -> Option<TID> {
309        self.lock().ok()?.sampled_token_id()
310    }
311
312    fn sample_token(
313        &mut self,
314        res: &mut dyn HasSamplerResources,
315        logits: &mut Logits,
316    ) -> Result<Option<TID>> {
317        self.lock()
318            .map_err(|e| SamplerError::InternalError(format!("Couldn't acquire lock: {e}")))?
319            .sample_token(res, logits)
320    }
321
322    fn sample<'a>(
323        &mut self,
324        res: &mut dyn HasSamplerResources,
325        logits: &'a mut Logits,
326    ) -> Result<&'a mut Logits> {
327        self.lock()
328            .map_err(|e| SamplerError::InternalError(format!("Couldn't acquire lock: {e}")))?
329            .sample(res, logits)
330    }
331}