1use llama_cpp_2::{GrammarError, model::LlamaModel, sampling::LlamaSampler};
2use serde::{Deserialize, Serialize};
3use strum::Display;
4
/// Classic stochastic-sampling knobs for llama.cpp text generation.
///
/// Every field is optional; a `None` leaves the corresponding sampler stage
/// out of the chain built by [`SimpleSamplingParams::to_llama`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SimpleSamplingParams {
    /// Nucleus sampling: keep the smallest token set whose cumulative
    /// probability exceeds this value.
    pub top_p: Option<f32>,
    /// Keep only the `k` most probable tokens.
    pub top_k: Option<i32>,
    /// Softmax temperature; lower values are more deterministic.
    pub temperature: Option<f32>,
    /// RNG seed for the final distribution sampler; a random seed is drawn
    /// when `None`.
    pub seed: Option<u32>,
    /// Penalty for tokens already present in the context (presence penalty).
    pub presence_penalty: Option<f32>,
    /// Multiplicative penalty for repeated tokens (repeat penalty).
    pub repetition_penalty: Option<f32>,
}
14
15impl SimpleSamplingParams {
16 pub fn to_llama(&self) -> LlamaSampler {
17 let mut samplers = Vec::new();
18 if let Some(k) = self.top_k {
19 samplers.push(LlamaSampler::top_k(k));
20 }
21 if let Some(p) = self.top_p {
22 samplers.push(LlamaSampler::top_p(p, 8));
23 }
24 if let Some(p) = self.presence_penalty
25 && let Some(r) = self.repetition_penalty
26 {
27 samplers.push(LlamaSampler::penalties(-1, p, 0.0, r));
28 }
29 samplers.push(LlamaSampler::dist(
30 self.seed.unwrap_or_else(|| rand::random()),
31 ));
32 LlamaSampler::chain_simple(samplers)
33 }
34}
35
/// Grammar-constrained sampling parameters backed by llguidance.
#[derive(Debug, Clone)]
pub struct LlguidanceSamplingParams {
    /// Which grammar flavor `data` is written in (JSON schema or Lark).
    pub schema: LlguidanceSchema,
    /// The grammar/schema source text itself.
    pub data: String,
}
41
42impl LlguidanceSamplingParams {
43 pub fn to_llama(&self, model: &LlamaModel) -> Result<LlamaSampler, GrammarError> {
44 LlamaSampler::llguidance(model, self.schema.to_string().as_str(), &self.data)
45 }
46}
47
/// Grammar flavor accepted by llguidance: a JSON schema or a Lark grammar.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlguidanceSchema {
    /// JSON-schema constraint; displays as "json".
    Json,
    /// Lark-grammar constraint; displays as "lark".
    Lark,
}

// Hand-written Display producing the same tags the strum derive emitted.
impl std::fmt::Display for LlguidanceSchema {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match self {
            Self::Json => "json",
            Self::Lark => "lark",
        };
        f.write_str(tag)
    }
}