Skip to main content

llama_cpp_bindings/token/
logit_bias.rs

1use crate::token::LlamaToken;
2
3#[derive(Clone, Copy, Debug, PartialEq)]
4#[repr(transparent)]
5pub struct LlamaLogitBias {
6    logit_bias: llama_cpp_bindings_sys::llama_logit_bias,
7}
8
9impl LlamaLogitBias {
10    #[must_use]
11    pub const fn new(LlamaToken(token): LlamaToken, bias: f32) -> Self {
12        Self {
13            logit_bias: llama_cpp_bindings_sys::llama_logit_bias { token, bias },
14        }
15    }
16
17    #[must_use]
18    pub const fn token(&self) -> LlamaToken {
19        LlamaToken(self.logit_bias.token)
20    }
21
22    #[must_use]
23    pub const fn bias(&self) -> f32 {
24        self.logit_bias.bias
25    }
26
27    pub const fn set_token(&mut self, token: LlamaToken) {
28        self.logit_bias.token = token.0;
29    }
30
31    pub const fn set_bias(&mut self, bias: f32) {
32        self.logit_bias.bias = bias;
33    }
34}
35
36#[cfg(test)]
37mod tests {
38    use super::LlamaLogitBias;
39    use crate::token::LlamaToken;
40
41    #[test]
42    fn new_stores_token_and_bias() {
43        let token = LlamaToken::new(42);
44        let logit_bias = LlamaLogitBias::new(token, 1.5);
45        assert_eq!(logit_bias.token(), token);
46        assert!((logit_bias.bias() - 1.5_f32).abs() < f32::EPSILON);
47    }
48
49    #[test]
50    fn set_token_updates_token() {
51        let mut logit_bias = LlamaLogitBias::new(LlamaToken::new(1), 0.5);
52        let new_token = LlamaToken::new(99);
53        logit_bias.set_token(new_token);
54        assert_eq!(logit_bias.token(), new_token);
55    }
56
57    #[test]
58    fn set_bias_updates_bias() {
59        let mut logit_bias = LlamaLogitBias::new(LlamaToken::new(1), 0.5);
60        logit_bias.set_bias(-3.0);
61        assert!((logit_bias.bias() - (-3.0_f32)).abs() < f32::EPSILON);
62    }
63}