llama_cpp_bindings/token/
logit_bias.rs1use crate::token::LlamaToken;
2
3#[derive(Clone, Copy, Debug, PartialEq)]
4#[repr(transparent)]
5pub struct LlamaLogitBias {
6 logit_bias: llama_cpp_bindings_sys::llama_logit_bias,
7}
8
9impl LlamaLogitBias {
10 #[must_use]
11 pub const fn new(LlamaToken(token): LlamaToken, bias: f32) -> Self {
12 Self {
13 logit_bias: llama_cpp_bindings_sys::llama_logit_bias { token, bias },
14 }
15 }
16
17 #[must_use]
18 pub const fn token(&self) -> LlamaToken {
19 LlamaToken(self.logit_bias.token)
20 }
21
22 #[must_use]
23 pub const fn bias(&self) -> f32 {
24 self.logit_bias.bias
25 }
26
27 pub const fn set_token(&mut self, token: LlamaToken) {
28 self.logit_bias.token = token.0;
29 }
30
31 pub const fn set_bias(&mut self, bias: f32) {
32 self.logit_bias.bias = bias;
33 }
34}
35
36#[cfg(test)]
37mod tests {
38 use super::LlamaLogitBias;
39 use crate::token::LlamaToken;
40
41 #[test]
42 fn new_stores_token_and_bias() {
43 let token = LlamaToken::new(42);
44 let logit_bias = LlamaLogitBias::new(token, 1.5);
45 assert_eq!(logit_bias.token(), token);
46 assert!((logit_bias.bias() - 1.5_f32).abs() < f32::EPSILON);
47 }
48
49 #[test]
50 fn set_token_updates_token() {
51 let mut logit_bias = LlamaLogitBias::new(LlamaToken::new(1), 0.5);
52 let new_token = LlamaToken::new(99);
53 logit_bias.set_token(new_token);
54 assert_eq!(logit_bias.token(), new_token);
55 }
56
57 #[test]
58 fn set_bias_updates_bias() {
59 let mut logit_bias = LlamaLogitBias::new(LlamaToken::new(1), 0.5);
60 logit_bias.set_bias(-3.0);
61 assert!((logit_bias.bias() - (-3.0_f32)).abs() < f32::EPSILON);
62 }
63}