Skip to main content

ferrum_types/
sampling.rs

1//! Sampling and generation parameters
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::{FerrumError, Result, TokenId};
7
8/// Sampling parameters for generation
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SamplingParams {
11    /// Maximum number of tokens to generate
12    pub max_tokens: usize,
13    /// Temperature for randomness (0.0 = deterministic, higher = more random)
14    pub temperature: f32,
15    /// Nucleus sampling probability threshold
16    pub top_p: f32,
17    /// Top-k sampling - consider only top k tokens
18    pub top_k: Option<usize>,
19    /// Repetition penalty to reduce repetitive text
20    pub repetition_penalty: f32,
21    /// Presence penalty for token diversity
22    pub presence_penalty: f32,
23    /// Frequency penalty based on token frequency
24    pub frequency_penalty: f32,
25    /// Stop sequences to end generation
26    pub stop_sequences: Vec<String>,
27    /// Random seed for reproducible generation
28    pub seed: Option<u64>,
29    /// Minimum probability threshold for tokens
30    pub min_p: Option<f32>,
31    /// Tail free sampling parameter
32    pub tfs: Option<f32>,
33    /// Typical sampling parameter
34    pub typical_p: Option<f32>,
35    /// Mirostat sampling parameters
36    pub mirostat: Option<MirostatParams>,
37    /// Response format constraint (JSON mode, schema-constrained, etc.)
38    #[serde(default)]
39    pub response_format: ResponseFormat,
40}
41
42/// Response format for structured output. Mirrors OpenAI's
43/// `response_format` API — no proprietary extensions.
44///
45/// - `Text`: no constraint (default)
46/// - `JsonObject`: output must be a valid JSON object (matches OpenAI's
47///   `{"type": "json_object"}`)
48/// - `JsonSchema(schema)`: output must conform to the given JSON Schema
49///   (matches OpenAI's `{"type": "json_schema", "json_schema": {...}}`).
50///   Internally compiled to a regex FSM for per-token hard masking.
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
52#[serde(tag = "type", content = "schema")]
53#[derive(Default)]
54pub enum ResponseFormat {
55    /// No constraint — raw text output.
56    #[default]
57    Text,
58    /// Output must be a valid JSON object.
59    JsonObject,
60    /// Output must conform to the given JSON schema (as a JSON string).
61    JsonSchema(String),
62}
63
64impl Default for SamplingParams {
65    fn default() -> Self {
66        Self {
67            max_tokens: 512,
68            temperature: 1.0,
69            top_p: 1.0,
70            top_k: None,
71            repetition_penalty: 1.0,
72            presence_penalty: 0.0,
73            frequency_penalty: 0.0,
74            stop_sequences: vec![],
75            seed: None,
76            min_p: None,
77            tfs: None,
78            typical_p: None,
79            mirostat: None,
80            response_format: ResponseFormat::default(),
81        }
82    }
83}
84
85impl SamplingParams {
86    /// Create greedy sampling parameters (deterministic)
87    pub fn greedy() -> Self {
88        Self {
89            temperature: 0.0,
90            top_p: 1.0,
91            top_k: None,
92            ..Default::default()
93        }
94    }
95
96    /// Create default sampling parameters with temperature
97    pub fn with_temperature(temperature: f32) -> Self {
98        Self {
99            temperature,
100            ..Default::default()
101        }
102    }
103
104    /// Validate sampling parameters
105    pub fn validate(&self) -> Result<()> {
106        if self.temperature < 0.0 {
107            return Err(FerrumError::invalid_request(
108                "Temperature must be non-negative".to_string(),
109            ));
110        }
111        if self.top_p <= 0.0 || self.top_p > 1.0 {
112            return Err(FerrumError::invalid_request(
113                "top_p must be in range (0, 1]".to_string(),
114            ));
115        }
116        if let Some(top_k) = self.top_k {
117            if top_k == 0 {
118                return Err(FerrumError::invalid_request(
119                    "top_k must be positive".to_string(),
120                ));
121            }
122        }
123        if self.repetition_penalty <= 0.0 {
124            return Err(FerrumError::invalid_request(
125                "Repetition penalty must be positive".to_string(),
126            ));
127        }
128        if let Some(min_p) = self.min_p {
129            if min_p <= 0.0 || min_p > 1.0 {
130                return Err(FerrumError::invalid_request(
131                    "min_p must be in range (0, 1]".to_string(),
132                ));
133            }
134        }
135        if let Some(tfs) = self.tfs {
136            if tfs <= 0.0 || tfs > 1.0 {
137                return Err(FerrumError::invalid_request(
138                    "tfs must be in range (0, 1]".to_string(),
139                ));
140            }
141        }
142        if let Some(typical_p) = self.typical_p {
143            if typical_p <= 0.0 || typical_p > 1.0 {
144                return Err(FerrumError::invalid_request(
145                    "typical_p must be in range (0, 1]".to_string(),
146                ));
147            }
148        }
149        Ok(())
150    }
151}
152
153/// Mirostat sampling parameters
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct MirostatParams {
156    /// Mirostat mode (1 or 2)
157    pub mode: u8,
158    /// Target entropy
159    pub tau: f32,
160    /// Learning rate
161    pub eta: f32,
162}
163
164/// Sampling presets
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct SamplingPresets {
167    pub presets: HashMap<String, SamplingParams>,
168}
169
170impl Default for SamplingPresets {
171    fn default() -> Self {
172        let mut presets = HashMap::new();
173        presets.insert("greedy".to_string(), SamplingParams::greedy());
174        presets.insert(
175            "creative".to_string(),
176            SamplingParams {
177                temperature: 1.2,
178                top_p: 0.9,
179                top_k: Some(50),
180                repetition_penalty: 1.1,
181                ..Default::default()
182            },
183        );
184        presets.insert(
185            "precise".to_string(),
186            SamplingParams {
187                temperature: 0.3,
188                top_p: 0.95,
189                top_k: Some(20),
190                repetition_penalty: 1.05,
191                ..Default::default()
192            },
193        );
194        Self { presets }
195    }
196}
197
198/// Request priority levels
199#[derive(
200    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Default,
201)]
202pub enum Priority {
203    Low = 0,
204    #[default]
205    Normal = 1,
206    High = 2,
207    Critical = 3,
208}
209
210/// Reason for completion
211#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
212pub enum FinishReason {
213    /// Hit maximum token limit
214    Length,
215    /// Hit stop sequence
216    Stop,
217    /// Hit end-of-sequence token
218    EOS,
219    /// Request was cancelled
220    Cancelled,
221    /// Error occurred during generation
222    Error,
223    /// Content filter triggered
224    ContentFilter,
225}
226
227/// Special tokens configuration
228#[derive(Debug, Clone, Serialize, Deserialize, Default)]
229pub struct SpecialTokens {
230    /// Beginning of sequence token
231    pub bos_token: Option<TokenId>,
232    /// End of sequence token
233    pub eos_token: Option<TokenId>,
234    /// Unknown token
235    pub unk_token: Option<TokenId>,
236    /// Padding token
237    pub pad_token: Option<TokenId>,
238    /// Separator token
239    pub sep_token: Option<TokenId>,
240    /// Classification token
241    pub cls_token: Option<TokenId>,
242    /// Mask token
243    pub mask_token: Option<TokenId>,
244}