infernum_core/
sampling.rs

1//! Sampling parameters for text generation.
2
3use serde::{Deserialize, Serialize};
4
5/// Parameters controlling text generation sampling.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct SamplingParams {
8    /// Temperature for sampling (0.0 = greedy, higher = more random).
9    /// Default: 1.0
10    #[serde(default = "default_temperature")]
11    pub temperature: f32,
12
13    /// Top-p (nucleus) sampling threshold.
14    /// Default: 1.0
15    #[serde(default = "default_top_p")]
16    pub top_p: f32,
17
18    /// Top-k sampling (0 = disabled).
19    /// Default: 0
20    #[serde(default)]
21    pub top_k: u32,
22
23    /// Minimum probability for min-p sampling (0.0 = disabled).
24    /// Default: 0.0
25    #[serde(default)]
26    pub min_p: f32,
27
28    /// Repetition penalty (1.0 = no penalty).
29    /// Default: 1.0
30    #[serde(default = "default_repetition_penalty")]
31    pub repetition_penalty: f32,
32
33    /// Presence penalty (-2.0 to 2.0).
34    /// Default: 0.0
35    #[serde(default)]
36    pub presence_penalty: f32,
37
38    /// Frequency penalty (-2.0 to 2.0).
39    /// Default: 0.0
40    #[serde(default)]
41    pub frequency_penalty: f32,
42
43    /// Stop sequences that halt generation.
44    #[serde(default)]
45    pub stop_sequences: Vec<String>,
46
47    /// Maximum number of tokens to generate.
48    /// Default: 256
49    #[serde(default = "default_max_tokens")]
50    pub max_tokens: u32,
51
52    /// Random seed for reproducibility.
53    #[serde(default)]
54    pub seed: Option<u64>,
55}
56
/// Serde default for the `temperature` field of `SamplingParams`.
fn default_temperature() -> f32 {
    1.0_f32
}
60
/// Serde default for the `top_p` field of `SamplingParams`.
fn default_top_p() -> f32 {
    1.0_f32
}
64
/// Serde default for the `repetition_penalty` field of `SamplingParams`.
fn default_repetition_penalty() -> f32 {
    1.0_f32
}
68
/// Serde default for the `max_tokens` field of `SamplingParams`.
fn default_max_tokens() -> u32 {
    256_u32
}
72
73impl Default for SamplingParams {
74    fn default() -> Self {
75        Self {
76            temperature: 1.0,
77            top_p: 1.0,
78            top_k: 0,
79            min_p: 0.0,
80            repetition_penalty: 1.0,
81            presence_penalty: 0.0,
82            frequency_penalty: 0.0,
83            stop_sequences: Vec::new(),
84            max_tokens: 256,
85            seed: None,
86        }
87    }
88}
89
90impl SamplingParams {
91    /// Creates greedy sampling parameters (temperature = 0).
92    #[must_use]
93    pub fn greedy() -> Self {
94        Self {
95            temperature: 0.0,
96            ..Default::default()
97        }
98    }
99
100    /// Creates balanced sampling parameters.
101    #[must_use]
102    pub fn balanced() -> Self {
103        Self {
104            temperature: 0.7,
105            top_p: 0.9,
106            ..Default::default()
107        }
108    }
109
110    /// Creates creative sampling parameters.
111    #[must_use]
112    pub fn creative() -> Self {
113        Self {
114            temperature: 1.0,
115            top_p: 0.95,
116            top_k: 50,
117            ..Default::default()
118        }
119    }
120
121    /// Sets the temperature.
122    #[must_use]
123    pub fn with_temperature(mut self, temperature: f32) -> Self {
124        self.temperature = temperature;
125        self
126    }
127
128    /// Sets the top-p value.
129    #[must_use]
130    pub fn with_top_p(mut self, top_p: f32) -> Self {
131        self.top_p = top_p;
132        self
133    }
134
135    /// Sets the top-k value.
136    #[must_use]
137    pub fn with_top_k(mut self, top_k: u32) -> Self {
138        self.top_k = top_k;
139        self
140    }
141
142    /// Sets the maximum tokens.
143    #[must_use]
144    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
145        self.max_tokens = max_tokens;
146        self
147    }
148
149    /// Adds a stop sequence.
150    #[must_use]
151    pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
152        self.stop_sequences.push(stop.into());
153        self
154    }
155
156    /// Sets the random seed.
157    #[must_use]
158    pub fn with_seed(mut self, seed: u64) -> Self {
159        self.seed = Some(seed);
160        self
161    }
162
163    /// Validates the sampling parameters.
164    ///
165    /// # Errors
166    ///
167    /// Returns an error if any parameter is out of valid range.
168    pub fn validate(&self) -> Result<(), String> {
169        if self.temperature < 0.0 {
170            return Err("temperature must be non-negative".to_string());
171        }
172        if !(0.0..=1.0).contains(&self.top_p) {
173            return Err("top_p must be between 0.0 and 1.0".to_string());
174        }
175        if !(0.0..=1.0).contains(&self.min_p) {
176            return Err("min_p must be between 0.0 and 1.0".to_string());
177        }
178        if self.repetition_penalty < 0.0 {
179            return Err("repetition_penalty must be non-negative".to_string());
180        }
181        if !(-2.0..=2.0).contains(&self.presence_penalty) {
182            return Err("presence_penalty must be between -2.0 and 2.0".to_string());
183        }
184        if !(-2.0..=2.0).contains(&self.frequency_penalty) {
185            return Err("frequency_penalty must be between -2.0 and 2.0".to_string());
186        }
187        if self.max_tokens == 0 {
188            return Err("max_tokens must be greater than 0".to_string());
189        }
190        Ok(())
191    }
192}