Skip to main content

llama_cpp_bindings/token/
data.rs

1use crate::token::LlamaToken;
2
3#[derive(Clone, Copy, Debug, PartialEq)]
4#[repr(transparent)]
5pub struct LlamaTokenData {
6    data: llama_cpp_bindings_sys::llama_token_data,
7}
8
9impl LlamaTokenData {
10    #[must_use]
11    pub const fn new(LlamaToken(id): LlamaToken, logit: f32, p: f32) -> Self {
12        Self {
13            data: llama_cpp_bindings_sys::llama_token_data { id, logit, p },
14        }
15    }
16    #[must_use]
17    pub const fn id(&self) -> LlamaToken {
18        LlamaToken(self.data.id)
19    }
20
21    #[must_use]
22    pub const fn logit(&self) -> f32 {
23        self.data.logit
24    }
25
26    #[must_use]
27    pub const fn p(&self) -> f32 {
28        self.data.p
29    }
30
31    pub const fn set_id(&mut self, id: LlamaToken) {
32        self.data.id = id.0;
33    }
34
35    pub const fn set_logit(&mut self, logit: f32) {
36        self.data.logit = logit;
37    }
38
39    pub const fn set_p(&mut self, p: f32) {
40        self.data.p = p;
41    }
42}
43
44#[cfg(test)]
45mod tests {
46    use super::LlamaTokenData;
47    use crate::token::LlamaToken;
48
49    #[test]
50    fn new_stores_all_fields() {
51        let token = LlamaToken::new(7);
52        let data = LlamaTokenData::new(token, 2.5, 0.8);
53        assert_eq!(data.id(), token);
54        assert!((data.logit() - 2.5_f32).abs() < f32::EPSILON);
55        assert!((data.p() - 0.8_f32).abs() < f32::EPSILON);
56    }
57
58    #[test]
59    fn set_id_updates_token() {
60        let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
61        data.set_id(LlamaToken::new(42));
62        assert_eq!(data.id(), LlamaToken::new(42));
63    }
64
65    #[test]
66    fn set_logit_updates_logit() {
67        let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
68        data.set_logit(-1.5);
69        assert!((data.logit() - (-1.5_f32)).abs() < f32::EPSILON);
70    }
71
72    #[test]
73    fn set_p_updates_probability() {
74        let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
75        data.set_p(0.95);
76        assert!((data.p() - 0.95_f32).abs() < f32::EPSILON);
77    }
78}