1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::{FerrumError, Result, TokenId};
7
/// Per-request sampling configuration.
///
/// Nothing is range-checked at construction or deserialization time; call
/// [`SamplingParams::validate`] before use. When deserialized with serde, omitted
/// `Option` fields become `None` and an omitted `response_format` falls back to its
/// `Default`; every other field must be present in the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingParams {
    /// Maximum number of tokens to generate.
    pub max_tokens: usize,
    /// Sampling temperature; `validate` rejects negative values. `greedy()` sets 0.0.
    pub temperature: f32,
    /// Top-p cutoff; `validate` requires a value in `(0, 1]`.
    pub top_p: f32,
    /// Optional top-k cutoff; `validate` rejects `Some(0)`.
    pub top_k: Option<usize>,
    /// Repetition penalty; `validate` requires it to be strictly positive.
    pub repetition_penalty: f32,
    /// Presence penalty.
    pub presence_penalty: f32,
    /// Frequency penalty.
    pub frequency_penalty: f32,
    /// Stop strings for generation. NOTE(review): matching is done by the consumer
    /// of these params, not in this type — confirm semantics there.
    pub stop_sequences: Vec<String>,
    /// Optional RNG seed; presumably used for reproducible sampling — TODO confirm.
    pub seed: Option<u64>,
    /// Optional min-p cutoff; when set, `validate` requires `(0, 1]`.
    pub min_p: Option<f32>,
    /// Optional tail-free-sampling parameter; when set, `validate` requires `(0, 1]`.
    pub tfs: Option<f32>,
    /// Optional typical-p cutoff; when set, `validate` requires `(0, 1]`.
    pub typical_p: Option<f32>,
    /// Optional Mirostat settings.
    pub mirostat: Option<MirostatParams>,
    /// Requested output format; missing in serialized input means the default.
    #[serde(default)]
    pub response_format: ResponseFormat,
}
41
/// Requested shape of the generated output.
///
/// Serialized adjacently tagged: the variant name is stored under `"type"` and the
/// `JsonSchema` payload under `"schema"`, e.g. `{"type":"Text"}` or
/// `{"type":"JsonSchema","schema":"..."}`. Defaults to `Text`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", content = "schema")]
#[derive(Default)]
pub enum ResponseFormat {
    /// Unconstrained text (the default).
    #[default]
    Text,
    /// Any JSON object.
    JsonObject,
    /// Output constrained by a JSON schema carried as a string — presumably the raw
    /// schema document; TODO confirm against the consumer.
    JsonSchema(String),
}
63
64impl Default for SamplingParams {
65 fn default() -> Self {
66 Self {
67 max_tokens: 512,
68 temperature: 1.0,
69 top_p: 1.0,
70 top_k: None,
71 repetition_penalty: 1.0,
72 presence_penalty: 0.0,
73 frequency_penalty: 0.0,
74 stop_sequences: vec![],
75 seed: None,
76 min_p: None,
77 tfs: None,
78 typical_p: None,
79 mirostat: None,
80 response_format: ResponseFormat::default(),
81 }
82 }
83}
84
85impl SamplingParams {
86 pub fn greedy() -> Self {
88 Self {
89 temperature: 0.0,
90 top_p: 1.0,
91 top_k: None,
92 ..Default::default()
93 }
94 }
95
96 pub fn with_temperature(temperature: f32) -> Self {
98 Self {
99 temperature,
100 ..Default::default()
101 }
102 }
103
104 pub fn validate(&self) -> Result<()> {
106 if self.temperature < 0.0 {
107 return Err(FerrumError::invalid_request(
108 "Temperature must be non-negative".to_string(),
109 ));
110 }
111 if self.top_p <= 0.0 || self.top_p > 1.0 {
112 return Err(FerrumError::invalid_request(
113 "top_p must be in range (0, 1]".to_string(),
114 ));
115 }
116 if let Some(top_k) = self.top_k {
117 if top_k == 0 {
118 return Err(FerrumError::invalid_request(
119 "top_k must be positive".to_string(),
120 ));
121 }
122 }
123 if self.repetition_penalty <= 0.0 {
124 return Err(FerrumError::invalid_request(
125 "Repetition penalty must be positive".to_string(),
126 ));
127 }
128 if let Some(min_p) = self.min_p {
129 if min_p <= 0.0 || min_p > 1.0 {
130 return Err(FerrumError::invalid_request(
131 "min_p must be in range (0, 1]".to_string(),
132 ));
133 }
134 }
135 if let Some(tfs) = self.tfs {
136 if tfs <= 0.0 || tfs > 1.0 {
137 return Err(FerrumError::invalid_request(
138 "tfs must be in range (0, 1]".to_string(),
139 ));
140 }
141 }
142 if let Some(typical_p) = self.typical_p {
143 if typical_p <= 0.0 || typical_p > 1.0 {
144 return Err(FerrumError::invalid_request(
145 "typical_p must be in range (0, 1]".to_string(),
146 ));
147 }
148 }
149 Ok(())
150 }
151}
152
/// Mirostat sampling settings.
///
/// NOTE(review): nothing in `SamplingParams::validate` checks these fields. By the
/// common llama.cpp convention `mode` is 1 or 2, `tau` is the target entropy and
/// `eta` the learning rate — confirm against the sampler that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirostatParams {
    /// Mirostat variant selector (allowed values not checked here).
    pub mode: u8,
    /// Presumably the target surprise/entropy — TODO confirm.
    pub tau: f32,
    /// Presumably the update rate for the controller — TODO confirm.
    pub eta: f32,
}
163
/// Named sets of [`SamplingParams`], keyed by preset name.
///
/// `SamplingPresets::default()` provides the `"greedy"`, `"creative"` and
/// `"precise"` presets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingPresets {
    /// Preset name -> parameters.
    pub presets: HashMap<String, SamplingParams>,
}
169
170impl Default for SamplingPresets {
171 fn default() -> Self {
172 let mut presets = HashMap::new();
173 presets.insert("greedy".to_string(), SamplingParams::greedy());
174 presets.insert(
175 "creative".to_string(),
176 SamplingParams {
177 temperature: 1.2,
178 top_p: 0.9,
179 top_k: Some(50),
180 repetition_penalty: 1.1,
181 ..Default::default()
182 },
183 );
184 presets.insert(
185 "precise".to_string(),
186 SamplingParams {
187 temperature: 0.3,
188 top_p: 0.95,
189 top_k: Some(20),
190 repetition_penalty: 1.05,
191 ..Default::default()
192 },
193 );
194 Self { presets }
195 }
196}
197
/// Request priority.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `Low < Normal < High < Critical`; the explicit discriminants 0–3 match that
/// order. Defaults to `Normal`.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum Priority {
    /// Lowest priority (discriminant 0).
    Low = 0,
    /// Default priority (discriminant 1).
    #[default]
    Normal = 1,
    /// Discriminant 2.
    High = 2,
    /// Highest priority (discriminant 3).
    Critical = 3,
}
209
/// Why generation for a request ended.
///
/// With serde's default (externally tagged) representation each variant
/// serializes as its bare name, e.g. `"Length"` or `"EOS"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// A length limit was reached — presumably `max_tokens`; confirm at the producer.
    Length,
    /// A stop condition matched — presumably one of `stop_sequences`.
    Stop,
    /// An end-of-sequence token was produced.
    // NOTE(review): `EOS` trips clippy::upper_case_acronyms. Renaming it to `Eos`
    // changes the public API and the serialized name, so a rename would also need
    // `#[serde(rename = "EOS")]`.
    EOS,
    /// The request was cancelled.
    Cancelled,
    /// Generation ended because of an error.
    Error,
    /// Output was stopped by a content filter (presumably; confirm at the producer).
    ContentFilter,
}
226
/// Ids of a tokenizer's special tokens.
///
/// Every entry is optional; `Default` leaves all of them `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SpecialTokens {
    /// Beginning-of-sequence token id.
    pub bos_token: Option<TokenId>,
    /// End-of-sequence token id.
    pub eos_token: Option<TokenId>,
    /// Unknown-token id.
    pub unk_token: Option<TokenId>,
    /// Padding token id.
    pub pad_token: Option<TokenId>,
    /// Separator token id.
    pub sep_token: Option<TokenId>,
    /// Classification (CLS) token id.
    pub cls_token: Option<TokenId>,
    /// Mask token id.
    pub mask_token: Option<TokenId>,
}