1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::{FerrumError, Result, TokenId};
7
/// Per-request sampling configuration.
///
/// Ranges enforced by [`SamplingParams::validate`] are noted on each field;
/// fields without such a note are not checked there. Default values come from
/// the `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingParams {
    /// Generation token budget (presumably the cap on new tokens). Default: 512.
    pub max_tokens: usize,
    /// Sampling temperature; `validate` requires it to be non-negative.
    /// Default: 1.0 (`greedy()` uses 0.0).
    pub temperature: f32,
    /// Top-p cutoff; `validate` requires it to lie in `(0, 1]`. Default: 1.0.
    pub top_p: f32,
    /// Optional top-k cutoff; `validate` requires it to be positive when set.
    /// Default: `None`.
    pub top_k: Option<usize>,
    /// Repetition penalty; `validate` requires it to be positive. Default: 1.0.
    pub repetition_penalty: f32,
    /// Presence penalty; not range-checked by `validate`. Default: 0.0.
    pub presence_penalty: f32,
    /// Frequency penalty; not range-checked by `validate`. Default: 0.0.
    pub frequency_penalty: f32,
    /// Stop strings (presumably matched against generated text by the caller).
    /// Default: empty.
    pub stop_sequences: Vec<String>,
    /// Optional RNG seed. Default: `None`.
    pub seed: Option<u64>,
    /// Optional min-p threshold; `validate` requires `(0, 1]` when set.
    pub min_p: Option<f32>,
    /// Optional tail-free-sampling parameter; `validate` requires `(0, 1]` when set.
    pub tfs: Option<f32>,
    /// Optional typical-p parameter; `validate` requires `(0, 1]` when set.
    pub typical_p: Option<f32>,
    /// Optional Mirostat settings; not checked by `validate`.
    pub mirostat: Option<MirostatParams>,
    /// Requested output format. `#[serde(default)]` lets the field be omitted
    /// on deserialization, in which case `ResponseFormat::default()` is used.
    #[serde(default)]
    pub response_format: ResponseFormat,
}
41
/// Output format requested for a generation.
///
/// Serialized adjacently tagged: the variant name is written to a `"type"`
/// field and the `JsonSchema` payload to a `"schema"` field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", content = "schema")]
pub enum ResponseFormat {
    /// Plain text output; this is the `Default` variant.
    Text,
    /// JSON object output (presumably enforced by a constrained decoder elsewhere).
    JsonObject,
    /// Output matching the given schema, carried as a string
    /// (presumably serialized JSON Schema — confirm with the consumer).
    JsonSchema(String),
}
58
59impl Default for ResponseFormat {
60 fn default() -> Self {
61 Self::Text
62 }
63}
64
65impl Default for SamplingParams {
66 fn default() -> Self {
67 Self {
68 max_tokens: 512,
69 temperature: 1.0,
70 top_p: 1.0,
71 top_k: None,
72 repetition_penalty: 1.0,
73 presence_penalty: 0.0,
74 frequency_penalty: 0.0,
75 stop_sequences: vec![],
76 seed: None,
77 min_p: None,
78 tfs: None,
79 typical_p: None,
80 mirostat: None,
81 response_format: ResponseFormat::default(),
82 }
83 }
84}
85
86impl SamplingParams {
87 pub fn greedy() -> Self {
89 Self {
90 temperature: 0.0,
91 top_p: 1.0,
92 top_k: None,
93 ..Default::default()
94 }
95 }
96
97 pub fn with_temperature(temperature: f32) -> Self {
99 Self {
100 temperature,
101 ..Default::default()
102 }
103 }
104
105 pub fn validate(&self) -> Result<()> {
107 if self.temperature < 0.0 {
108 return Err(FerrumError::invalid_request(
109 "Temperature must be non-negative".to_string(),
110 ));
111 }
112 if self.top_p <= 0.0 || self.top_p > 1.0 {
113 return Err(FerrumError::invalid_request(
114 "top_p must be in range (0, 1]".to_string(),
115 ));
116 }
117 if let Some(top_k) = self.top_k {
118 if top_k == 0 {
119 return Err(FerrumError::invalid_request(
120 "top_k must be positive".to_string(),
121 ));
122 }
123 }
124 if self.repetition_penalty <= 0.0 {
125 return Err(FerrumError::invalid_request(
126 "Repetition penalty must be positive".to_string(),
127 ));
128 }
129 if let Some(min_p) = self.min_p {
130 if min_p <= 0.0 || min_p > 1.0 {
131 return Err(FerrumError::invalid_request(
132 "min_p must be in range (0, 1]".to_string(),
133 ));
134 }
135 }
136 if let Some(tfs) = self.tfs {
137 if tfs <= 0.0 || tfs > 1.0 {
138 return Err(FerrumError::invalid_request(
139 "tfs must be in range (0, 1]".to_string(),
140 ));
141 }
142 }
143 if let Some(typical_p) = self.typical_p {
144 if typical_p <= 0.0 || typical_p > 1.0 {
145 return Err(FerrumError::invalid_request(
146 "typical_p must be in range (0, 1]".to_string(),
147 ));
148 }
149 }
150 Ok(())
151 }
152}
153
/// Mirostat sampler settings.
///
/// NOTE(review): none of these fields are range-checked by
/// `SamplingParams::validate`. Presumably `mode` selects the Mirostat variant
/// and `tau`/`eta` are its target surprise and learning rate — confirm against
/// the sampler that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirostatParams {
    /// Variant selector (presumably 1 or 2 — verify).
    pub mode: u8,
    /// Presumably the target surprise (tau).
    pub tau: f32,
    /// Presumably the learning rate (eta).
    pub eta: f32,
}
164
/// Named collection of reusable [`SamplingParams`].
///
/// The `Default` impl populates the `"greedy"`, `"creative"` and `"precise"` entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingPresets {
    /// Preset name mapped to its parameters.
    pub presets: HashMap<String, SamplingParams>,
}
170
171impl Default for SamplingPresets {
172 fn default() -> Self {
173 let mut presets = HashMap::new();
174 presets.insert("greedy".to_string(), SamplingParams::greedy());
175 presets.insert(
176 "creative".to_string(),
177 SamplingParams {
178 temperature: 1.2,
179 top_p: 0.9,
180 top_k: Some(50),
181 repetition_penalty: 1.1,
182 ..Default::default()
183 },
184 );
185 presets.insert(
186 "precise".to_string(),
187 SamplingParams {
188 temperature: 0.3,
189 top_p: 0.95,
190 top_k: Some(20),
191 repetition_penalty: 1.05,
192 ..Default::default()
193 },
194 );
195 Self { presets }
196 }
197}
198
/// Request priority.
///
/// Variants have explicit discriminants and derive `Ord`, so they order as
/// `Low < Normal < High < Critical`. The `Default` is `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Priority {
    /// Lowest priority (discriminant 0).
    Low = 0,
    /// Default priority (discriminant 1).
    Normal = 1,
    /// Elevated priority (discriminant 2).
    High = 2,
    /// Highest priority (discriminant 3).
    Critical = 3,
}
207
208impl Default for Priority {
209 fn default() -> Self {
210 Priority::Normal
211 }
212}
213
/// Why a generation stopped.
///
/// NOTE(review): the per-variant meanings below are inferred from the names;
/// confirm them against the code that produces each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// Presumably the `max_tokens` budget was exhausted.
    Length,
    /// Presumably one of the stop sequences was matched.
    Stop,
    /// Presumably the model emitted its end-of-sequence token.
    // All-caps name kept on purpose: renaming would change both the public API
    // and the serialized variant name.
    EOS,
    /// Presumably the request was cancelled before completion.
    Cancelled,
    /// Presumably generation aborted because of an error.
    Error,
    /// Presumably output was withheld by a content filter.
    ContentFilter,
}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct SpecialTokens {
234 pub bos_token: Option<TokenId>,
236 pub eos_token: Option<TokenId>,
238 pub unk_token: Option<TokenId>,
240 pub pad_token: Option<TokenId>,
242 pub sep_token: Option<TokenId>,
244 pub cls_token: Option<TokenId>,
246 pub mask_token: Option<TokenId>,
248}
249
250impl Default for SpecialTokens {
251 fn default() -> Self {
252 Self {
253 bos_token: None,
254 eos_token: None,
255 unk_token: None,
256 pad_token: None,
257 sep_token: None,
258 cls_token: None,
259 mask_token: None,
260 }
261 }
262}