// llama_runner/sample.rs
1use llama_cpp_2::{GrammarError, model::LlamaModel, sampling::LlamaSampler};
2use serde::{Deserialize, Serialize};
3use strum::Display;
4
/// Optional sampling parameters for unconstrained token sampling.
///
/// Every field is `Option`al; unset fields are skipped when the sampler
/// chain is built in `to_llama`, falling back to llama.cpp defaults.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SimpleSamplingParams {
    // Passed to `LlamaSampler::top_p` when set (presumably nucleus-sampling
    // cumulative-probability cutoff — confirm against llama-cpp-2 docs).
    pub top_p: Option<f32>,
    // Passed to `LlamaSampler::top_k` when set.
    pub top_k: Option<i32>,
    pub temperature: Option<f32>,
    // RNG seed for the final `dist` sampler; a random seed is drawn when unset.
    pub seed: Option<u32>,
    // Fed to `LlamaSampler::penalties` — note in the current code BOTH
    // penalty fields must be set for either to take effect.
    pub presence_penalty: Option<f32>,
    pub repetition_penalty: Option<f32>,
}
14
15impl SimpleSamplingParams {
16    pub fn to_llama(&self) -> LlamaSampler {
17        let mut samplers = Vec::new();
18        if let Some(k) = self.top_k {
19            samplers.push(LlamaSampler::top_k(k));
20        }
21        if let Some(p) = self.top_p {
22            samplers.push(LlamaSampler::top_p(p, 8));
23        }
24        if let Some(p) = self.presence_penalty
25            && let Some(r) = self.repetition_penalty
26        {
27            samplers.push(LlamaSampler::penalties(-1, p, 0.0, r));
28        }
29        samplers.push(LlamaSampler::dist(
30            self.seed.unwrap_or_else(|| rand::random()),
31        ));
32        LlamaSampler::chain_simple(samplers)
33    }
34}
35
/// Parameters for grammar-constrained sampling via llguidance.
#[derive(Debug, Clone)]
pub struct LlguidanceSamplingParams {
    // Which grammar dialect `data` is written in ("json" or "lark").
    pub schema: LlguidanceSchema,
    // The grammar/schema text itself, passed verbatim to
    // `LlamaSampler::llguidance`.
    pub data: String,
}
41
42impl LlguidanceSamplingParams {
43    pub fn to_llama(&self, model: &LlamaModel) -> Result<LlamaSampler, GrammarError> {
44        LlamaSampler::llguidance(model, self.schema.to_string().as_str(), &self.data)
45    }
46}
47
/// Grammar dialect accepted by llguidance; the `Display` impl (via strum)
/// yields the lowercase identifier the llguidance API expects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
pub enum LlguidanceSchema {
    // Renders as "json" — JSON-schema-constrained output.
    #[strum(to_string = "json")]
    Json,
    // Renders as "lark" — Lark-style grammar.
    #[strum(to_string = "lark")]
    Lark,
}
54}