infernum_core/
sampling.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct SamplingParams {
8 #[serde(default = "default_temperature")]
11 pub temperature: f32,
12
13 #[serde(default = "default_top_p")]
16 pub top_p: f32,
17
18 #[serde(default)]
21 pub top_k: u32,
22
23 #[serde(default)]
26 pub min_p: f32,
27
28 #[serde(default = "default_repetition_penalty")]
31 pub repetition_penalty: f32,
32
33 #[serde(default)]
36 pub presence_penalty: f32,
37
38 #[serde(default)]
41 pub frequency_penalty: f32,
42
43 #[serde(default)]
45 pub stop_sequences: Vec<String>,
46
47 #[serde(default = "default_max_tokens")]
50 pub max_tokens: u32,
51
52 #[serde(default)]
54 pub seed: Option<u64>,
55}
56
/// Serde fallback for `SamplingParams::temperature` when the field is absent.
fn default_temperature() -> f32 {
    1.0_f32
}
60
/// Serde fallback for `SamplingParams::top_p` when the field is absent.
fn default_top_p() -> f32 {
    1.0_f32
}
64
/// Serde fallback for `SamplingParams::repetition_penalty` when the field is absent.
fn default_repetition_penalty() -> f32 {
    1.0_f32
}
68
/// Serde fallback for `SamplingParams::max_tokens` when the field is absent.
fn default_max_tokens() -> u32 {
    256_u32
}
72
73impl Default for SamplingParams {
74 fn default() -> Self {
75 Self {
76 temperature: 1.0,
77 top_p: 1.0,
78 top_k: 0,
79 min_p: 0.0,
80 repetition_penalty: 1.0,
81 presence_penalty: 0.0,
82 frequency_penalty: 0.0,
83 stop_sequences: Vec::new(),
84 max_tokens: 256,
85 seed: None,
86 }
87 }
88}
89
90impl SamplingParams {
91 #[must_use]
93 pub fn greedy() -> Self {
94 Self {
95 temperature: 0.0,
96 ..Default::default()
97 }
98 }
99
100 #[must_use]
102 pub fn balanced() -> Self {
103 Self {
104 temperature: 0.7,
105 top_p: 0.9,
106 ..Default::default()
107 }
108 }
109
110 #[must_use]
112 pub fn creative() -> Self {
113 Self {
114 temperature: 1.0,
115 top_p: 0.95,
116 top_k: 50,
117 ..Default::default()
118 }
119 }
120
121 #[must_use]
123 pub fn with_temperature(mut self, temperature: f32) -> Self {
124 self.temperature = temperature;
125 self
126 }
127
128 #[must_use]
130 pub fn with_top_p(mut self, top_p: f32) -> Self {
131 self.top_p = top_p;
132 self
133 }
134
135 #[must_use]
137 pub fn with_top_k(mut self, top_k: u32) -> Self {
138 self.top_k = top_k;
139 self
140 }
141
142 #[must_use]
144 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
145 self.max_tokens = max_tokens;
146 self
147 }
148
149 #[must_use]
151 pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
152 self.stop_sequences.push(stop.into());
153 self
154 }
155
156 #[must_use]
158 pub fn with_seed(mut self, seed: u64) -> Self {
159 self.seed = Some(seed);
160 self
161 }
162
163 pub fn validate(&self) -> Result<(), String> {
169 if self.temperature < 0.0 {
170 return Err("temperature must be non-negative".to_string());
171 }
172 if !(0.0..=1.0).contains(&self.top_p) {
173 return Err("top_p must be between 0.0 and 1.0".to_string());
174 }
175 if !(0.0..=1.0).contains(&self.min_p) {
176 return Err("min_p must be between 0.0 and 1.0".to_string());
177 }
178 if self.repetition_penalty < 0.0 {
179 return Err("repetition_penalty must be non-negative".to_string());
180 }
181 if !(-2.0..=2.0).contains(&self.presence_penalty) {
182 return Err("presence_penalty must be between -2.0 and 2.0".to_string());
183 }
184 if !(-2.0..=2.0).contains(&self.frequency_penalty) {
185 return Err("frequency_penalty must be between -2.0 and 2.0".to_string());
186 }
187 if self.max_tokens == 0 {
188 return Err("max_tokens must be greater than 0".to_string());
189 }
190 Ok(())
191 }
192}