Skip to main content

ferrum_types/
sampling.rs

1//! Sampling and generation parameters
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::{FerrumError, Result, TokenId};
7
8/// Sampling parameters for generation
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SamplingParams {
11    /// Maximum number of tokens to generate
12    pub max_tokens: usize,
13    /// Temperature for randomness (0.0 = deterministic, higher = more random)
14    pub temperature: f32,
15    /// Nucleus sampling probability threshold
16    pub top_p: f32,
17    /// Top-k sampling - consider only top k tokens
18    pub top_k: Option<usize>,
19    /// Repetition penalty to reduce repetitive text
20    pub repetition_penalty: f32,
21    /// Presence penalty for token diversity
22    pub presence_penalty: f32,
23    /// Frequency penalty based on token frequency
24    pub frequency_penalty: f32,
25    /// Stop sequences to end generation
26    pub stop_sequences: Vec<String>,
27    /// Random seed for reproducible generation
28    pub seed: Option<u64>,
29    /// Minimum probability threshold for tokens
30    pub min_p: Option<f32>,
31    /// Tail free sampling parameter
32    pub tfs: Option<f32>,
33    /// Typical sampling parameter
34    pub typical_p: Option<f32>,
35    /// Mirostat sampling parameters
36    pub mirostat: Option<MirostatParams>,
37    /// Response format constraint (JSON mode, schema-constrained, etc.)
38    #[serde(default)]
39    pub response_format: ResponseFormat,
40}
41
42/// Response format for structured output.
43///
44/// Controls how the model output is constrained:
45/// - `Text`: no constraint (default)
46/// - `JsonObject`: output must be valid JSON
47/// - `JsonSchema`: output must conform to a JSON schema
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
49#[serde(tag = "type", content = "schema")]
50pub enum ResponseFormat {
51    /// No constraint — raw text output.
52    Text,
53    /// Output must be a valid JSON object.
54    JsonObject,
55    /// Output must conform to the given JSON schema (as a JSON string).
56    JsonSchema(String),
57}
58
59impl Default for ResponseFormat {
60    fn default() -> Self {
61        Self::Text
62    }
63}
64
65impl Default for SamplingParams {
66    fn default() -> Self {
67        Self {
68            max_tokens: 512,
69            temperature: 1.0,
70            top_p: 1.0,
71            top_k: None,
72            repetition_penalty: 1.0,
73            presence_penalty: 0.0,
74            frequency_penalty: 0.0,
75            stop_sequences: vec![],
76            seed: None,
77            min_p: None,
78            tfs: None,
79            typical_p: None,
80            mirostat: None,
81            response_format: ResponseFormat::default(),
82        }
83    }
84}
85
86impl SamplingParams {
87    /// Create greedy sampling parameters (deterministic)
88    pub fn greedy() -> Self {
89        Self {
90            temperature: 0.0,
91            top_p: 1.0,
92            top_k: None,
93            ..Default::default()
94        }
95    }
96
97    /// Create default sampling parameters with temperature
98    pub fn with_temperature(temperature: f32) -> Self {
99        Self {
100            temperature,
101            ..Default::default()
102        }
103    }
104
105    /// Validate sampling parameters
106    pub fn validate(&self) -> Result<()> {
107        if self.temperature < 0.0 {
108            return Err(FerrumError::invalid_request(
109                "Temperature must be non-negative".to_string(),
110            ));
111        }
112        if self.top_p <= 0.0 || self.top_p > 1.0 {
113            return Err(FerrumError::invalid_request(
114                "top_p must be in range (0, 1]".to_string(),
115            ));
116        }
117        if let Some(top_k) = self.top_k {
118            if top_k == 0 {
119                return Err(FerrumError::invalid_request(
120                    "top_k must be positive".to_string(),
121                ));
122            }
123        }
124        if self.repetition_penalty <= 0.0 {
125            return Err(FerrumError::invalid_request(
126                "Repetition penalty must be positive".to_string(),
127            ));
128        }
129        if let Some(min_p) = self.min_p {
130            if min_p <= 0.0 || min_p > 1.0 {
131                return Err(FerrumError::invalid_request(
132                    "min_p must be in range (0, 1]".to_string(),
133                ));
134            }
135        }
136        if let Some(tfs) = self.tfs {
137            if tfs <= 0.0 || tfs > 1.0 {
138                return Err(FerrumError::invalid_request(
139                    "tfs must be in range (0, 1]".to_string(),
140                ));
141            }
142        }
143        if let Some(typical_p) = self.typical_p {
144            if typical_p <= 0.0 || typical_p > 1.0 {
145                return Err(FerrumError::invalid_request(
146                    "typical_p must be in range (0, 1]".to_string(),
147                ));
148            }
149        }
150        Ok(())
151    }
152}
153
154/// Mirostat sampling parameters
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct MirostatParams {
157    /// Mirostat mode (1 or 2)
158    pub mode: u8,
159    /// Target entropy
160    pub tau: f32,
161    /// Learning rate
162    pub eta: f32,
163}
164
165/// Sampling presets
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct SamplingPresets {
168    pub presets: HashMap<String, SamplingParams>,
169}
170
171impl Default for SamplingPresets {
172    fn default() -> Self {
173        let mut presets = HashMap::new();
174        presets.insert("greedy".to_string(), SamplingParams::greedy());
175        presets.insert(
176            "creative".to_string(),
177            SamplingParams {
178                temperature: 1.2,
179                top_p: 0.9,
180                top_k: Some(50),
181                repetition_penalty: 1.1,
182                ..Default::default()
183            },
184        );
185        presets.insert(
186            "precise".to_string(),
187            SamplingParams {
188                temperature: 0.3,
189                top_p: 0.95,
190                top_k: Some(20),
191                repetition_penalty: 1.05,
192                ..Default::default()
193            },
194        );
195        Self { presets }
196    }
197}
198
199/// Request priority levels
200#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
201pub enum Priority {
202    Low = 0,
203    Normal = 1,
204    High = 2,
205    Critical = 3,
206}
207
208impl Default for Priority {
209    fn default() -> Self {
210        Priority::Normal
211    }
212}
213
214/// Reason for completion
215#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
216pub enum FinishReason {
217    /// Hit maximum token limit
218    Length,
219    /// Hit stop sequence
220    Stop,
221    /// Hit end-of-sequence token
222    EOS,
223    /// Request was cancelled
224    Cancelled,
225    /// Error occurred during generation
226    Error,
227    /// Content filter triggered
228    ContentFilter,
229}
230
231/// Special tokens configuration
232#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct SpecialTokens {
234    /// Beginning of sequence token
235    pub bos_token: Option<TokenId>,
236    /// End of sequence token
237    pub eos_token: Option<TokenId>,
238    /// Unknown token
239    pub unk_token: Option<TokenId>,
240    /// Padding token
241    pub pad_token: Option<TokenId>,
242    /// Separator token
243    pub sep_token: Option<TokenId>,
244    /// Classification token
245    pub cls_token: Option<TokenId>,
246    /// Mask token
247    pub mask_token: Option<TokenId>,
248}
249
250impl Default for SpecialTokens {
251    fn default() -> Self {
252        Self {
253            bos_token: None,
254            eos_token: None,
255            unk_token: None,
256            pad_token: None,
257            sep_token: None,
258            cls_token: None,
259            mask_token: None,
260        }
261    }
262}