Skip to main content

llama_crab/
token_data.rs

1//! [`LlamaTokenData`] and [`LlamaTokenDataArray`].
2
3use llama_crab_sys as sys;
4
5use crate::token::LlamaToken;
6
7/// A `(id, logit, p)` triple representing a token in the candidate set.
8#[derive(Clone, Copy, PartialEq)]
9#[repr(transparent)]
10pub struct LlamaTokenData(pub sys::llama_token_data);
11
12impl LlamaTokenData {
13    /// Construct from id, logit and probability.
14    #[must_use]
15    pub fn new(id: LlamaToken, logit: f32, p: f32) -> Self {
16        Self(sys::llama_token_data { id: id.0, logit, p })
17    }
18
19    /// Token id.
20    #[must_use]
21    pub fn id(&self) -> LlamaToken {
22        LlamaToken(self.0.id)
23    }
24
25    /// Set the token id.
26    pub fn set_id(&mut self, id: LlamaToken) {
27        self.0.id = id.0;
28    }
29
30    /// Logit value.
31    #[must_use]
32    pub fn logit(&self) -> f32 {
33        self.0.logit
34    }
35
36    /// Set the logit value.
37    pub fn set_logit(&mut self, logit: f32) {
38        self.0.logit = logit;
39    }
40
41    /// Probability in `[0, 1]` (filled in by `apply_sampler`).
42    #[must_use]
43    pub fn p(&self) -> f32 {
44        self.0.p
45    }
46
47    /// Set the probability.
48    pub fn set_p(&mut self, p: f32) {
49        self.0.p = p;
50    }
51}
52
53impl std::fmt::Debug for LlamaTokenData {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("LlamaTokenData")
56            .field("id", &self.id())
57            .field("logit", &self.logit())
58            .field("p", &self.p())
59            .finish()
60    }
61}
62
63/// A mutable array of [`LlamaTokenData`] with selection/sort metadata.
64#[derive(Debug)]
65pub struct LlamaTokenDataArray {
66    inner: sys::llama_token_data_array,
67    // The boxed slice keeps the backing storage alive for the C struct.
68    _data: Box<[LlamaTokenData]>,
69}
70
71impl LlamaTokenDataArray {
72    /// Construct a new `LlamaTokenDataArray` with `sorted = false` and
73    /// `selected = -1`.
74    #[must_use]
75    pub fn new(data: Vec<LlamaTokenData>) -> Self {
76        let mut data: Box<[LlamaTokenData]> = data.into_boxed_slice();
77        let ptr = data.as_mut_ptr().cast::<sys::llama_token_data>();
78        let inner = sys::llama_token_data_array {
79            data: ptr,
80            size: data.len(),
81            selected: -1,
82            sorted: false,
83        };
84        Self { inner, _data: data }
85    }
86
87    /// Number of elements.
88    #[must_use]
89    pub fn len(&self) -> usize {
90        self.inner.size
91    }
92
93    /// True if empty.
94    #[must_use]
95    pub fn is_empty(&self) -> bool {
96        self.inner.size == 0
97    }
98
99    /// Index of the selected token (-1 if none).
100    #[must_use]
101    pub fn selected(&self) -> i64 {
102        self.inner.selected
103    }
104
105    /// Borrow the underlying `&[sys::llama_token_data]`.
106    #[must_use]
107    pub fn as_raw(&self) -> &[sys::llama_token_data] {
108        // Safety: the boxed slice has the same length as the C array.
109        unsafe { std::slice::from_raw_parts(self.inner.data, self.inner.size) }
110    }
111
112    /// Borrow the `&[LlamaTokenData]` wrapper.
113    #[must_use]
114    pub fn data(&self) -> &[LlamaTokenData] {
115        &self._data
116    }
117
118    /// Mutable borrow of the inner C struct (private — use higher-level
119    /// wrappers when adding new operations).
120    #[allow(dead_code)]
121    pub(crate) fn inner_mut(&mut self) -> &mut sys::llama_token_data_array {
122        &mut self.inner
123    }
124
125    /// Borrow a raw `*mut` for the C API.
126    pub(crate) fn as_mut_ptr(&mut self) -> *mut sys::llama_token_data_array {
127        &mut self.inner
128    }
129}