1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
//! Sampling functions for the context.

use crate::context::LlamaContext;
use crate::grammar::LlamaGrammar;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::LlamaToken;

#[cfg(feature = "sampler")]
pub mod sampler;

impl LlamaContext<'_> {
    /// Accept a token into the grammar.
    ///
    /// Informs `grammar` that `token` has been emitted, so that subsequent
    /// grammar-constrained sampling only permits valid continuations.
    pub fn grammar_accept_token(&mut self, grammar: &mut LlamaGrammar, token: LlamaToken) {
        // SAFETY: both pointers are valid (backed by live `self` / `grammar`)
        // for the duration of the call. The C API takes (ctx, grammar, token)
        // in that order.
        unsafe {
            llama_cpp_sys_2::llama_grammar_accept_token(
                self.context.as_ptr(),
                grammar.grammar.as_ptr(),
                token.0,
            );
        }
    }

    /// Perform grammar sampling: zero out the logits of tokens that would
    /// violate `llama_grammar`.
    pub fn sample_grammar(
        &mut self,
        llama_token_data_array: &mut LlamaTokenDataArray,
        llama_grammar: &LlamaGrammar,
    ) {
        // SAFETY: `modify_as_c_llama_token_data_array` hands us a C view of
        // the token data that is valid only inside the closure; the context
        // and grammar pointers are valid for the duration of the call.
        unsafe {
            llama_token_data_array.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
                llama_cpp_sys_2::llama_sample_grammar(
                    self.context.as_ptr(),
                    c_llama_token_data_array,
                    llama_grammar.grammar.as_ptr(),
                );
            });
        }
    }

    /// See [`LlamaTokenDataArray::sample_temp`]
    pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
        token_data.sample_temp(Some(self), temperature);
    }

    /// Sample a token greedily. Note that this *does not* take into account anything that has modified the probabilities - it only looks at logits.
    ///
    /// Most of the time [`LlamaTokenDataArray::sample_softmax`] or [`LlamaTokenDataArray::sample_token`] should be used instead.
    ///
    /// # Panics
    ///
    /// - if `token_data` is empty
    #[must_use]
    pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
        assert!(!token_data.data.is_empty(), "no tokens");
        // `token_data` is consumed by value, so we can build a C view of its
        // buffer directly; any mutation the C side performs is discarded with
        // `token_data` when this function returns.
        let mut data_arr = llama_cpp_sys_2::llama_token_data_array {
            data: token_data
                .data
                .as_mut_ptr()
                .cast::<llama_cpp_sys_2::llama_token_data>(),
            size: token_data.data.len(),
            sorted: token_data.sorted,
        };
        // SAFETY: `data_arr` points into `token_data.data`, which is kept
        // alive (and the buffer unmoved) until after the call; the size field
        // matches the buffer length, and the assert above guarantees it is
        // non-empty.
        let token = unsafe {
            llama_cpp_sys_2::llama_sample_token_greedy(
                self.context.as_ptr(),
                std::ptr::addr_of_mut!(data_arr),
            )
        };
        LlamaToken(token)
    }

    /// See [`LlamaTokenDataArray::sample_tail_free`]
    pub fn sample_tail_free(
        &mut self,
        token_data: &mut LlamaTokenDataArray,
        z: f32,
        min_keep: usize,
    ) {
        token_data.sample_tail_free(Some(self), z, min_keep);
    }

    /// See [`LlamaTokenDataArray::sample_typical`]
    pub fn sample_typical(
        &mut self,
        token_data: &mut LlamaTokenDataArray,
        p: f32,
        min_keep: usize,
    ) {
        token_data.sample_typical(Some(self), p, min_keep);
    }

    /// See [`LlamaTokenDataArray::sample_top_p`]
    pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
        token_data.sample_top_p(Some(self), p, min_keep);
    }

    /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841)
    pub fn sample_min_p(
        &mut self,
        llama_token_data: &mut LlamaTokenDataArray,
        p: f32,
        min_keep: usize,
    ) {
        // Capture the raw context pointer before the closure so the closure
        // does not need to borrow `self`.
        let ctx = self.context.as_ptr();
        // SAFETY: `ctx` remains valid for the duration of the call, and the
        // C token-data view is only used inside the closure.
        unsafe {
            llama_token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
                llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep);
            });
        }
    }

    /// See [`LlamaTokenDataArray::sample_top_k`]
    pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
        token_data.sample_top_k(Some(self), k, min_keep);
    }

    /// See [`LlamaTokenDataArray::sample_softmax`]
    pub fn sample_token_softmax(&mut self, token_data: &mut LlamaTokenDataArray) {
        token_data.sample_softmax(Some(self));
    }

    /// See [`LlamaTokenDataArray::sample_repetition_penalty`]
    pub fn sample_repetition_penalty(
        &mut self,
        token_data: &mut LlamaTokenDataArray,
        last_tokens: &[LlamaToken],
        penalty_last_n: usize,
        penalty_repeat: f32,
        penalty_freq: f32,
        penalty_present: f32,
    ) {
        token_data.sample_repetition_penalty(
            Some(self),
            last_tokens,
            penalty_last_n,
            penalty_repeat,
            penalty_freq,
            penalty_present,
        );
    }
}