Skip to main content

oxibonsai_runtime/
presets.rs

1//! Predefined sampling parameter presets for common use cases.
2
3use crate::sampling::SamplingParams;
4
/// Named preset for sampling behavior.
///
/// Each variant maps to a fixed [`SamplingParams`] configuration through
/// [`SamplingPreset::params`] (or `SamplingParams::from(preset)`). Only
/// temperature, top-k, top-p and repetition penalty are set per preset; every
/// other field comes from `SamplingParams::default()`.
///
/// # Example
///
/// ```
/// use oxibonsai_runtime::presets::SamplingPreset;
/// use oxibonsai_runtime::sampling::SamplingParams;
///
/// // Use a preset directly
/// let params: SamplingParams = SamplingPreset::Balanced.into();
/// assert!((params.temperature - 0.7).abs() < f32::EPSILON);
///
/// // Iterate over all presets
/// for preset in SamplingPreset::all() {
///     let p = preset.params();
///     assert!(p.temperature >= 0.0);
///     assert!(p.top_p >= 0.0 && p.top_p <= 1.0);
/// }
/// ```
// `Hash` is derived alongside `Eq` so presets can key hash maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingPreset {
    /// General use: temp=0.7, top_p=0.9, top_k=40, rep_pen=1.1
    Balanced,
    /// Creative writing: temp=1.0, top_p=0.95, top_k=0 (disabled), rep_pen=1.0
    Creative,
    /// Factual/code generation: temp=0.1, top_p=0.5, top_k=10, rep_pen=1.2
    Precise,
    /// Deterministic output (greedy decoding): temp=0.0, top_p=1.0, top_k=0, rep_pen=1.0
    Greedy,
    /// Chat/conversation: temp=0.8, top_p=0.9, top_k=50, rep_pen=1.1
    Conversational,
}
37
38/// All available presets in a static array.
39static ALL_PRESETS: [SamplingPreset; 5] = [
40    SamplingPreset::Balanced,
41    SamplingPreset::Creative,
42    SamplingPreset::Precise,
43    SamplingPreset::Greedy,
44    SamplingPreset::Conversational,
45];
46
47impl SamplingPreset {
48    /// Get the sampling parameters for this preset.
49    pub fn params(&self) -> SamplingParams {
50        match self {
51            Self::Balanced => SamplingParams {
52                temperature: 0.7,
53                top_k: 40,
54                top_p: 0.9,
55                repetition_penalty: 1.1,
56                ..SamplingParams::default()
57            },
58            Self::Creative => SamplingParams {
59                temperature: 1.0,
60                top_k: 0,
61                top_p: 0.95,
62                repetition_penalty: 1.0,
63                ..SamplingParams::default()
64            },
65            Self::Precise => SamplingParams {
66                temperature: 0.1,
67                top_k: 10,
68                top_p: 0.5,
69                repetition_penalty: 1.2,
70                ..SamplingParams::default()
71            },
72            Self::Greedy => SamplingParams {
73                temperature: 0.0,
74                top_k: 0,
75                top_p: 1.0,
76                repetition_penalty: 1.0,
77                ..SamplingParams::default()
78            },
79            Self::Conversational => SamplingParams {
80                temperature: 0.8,
81                top_k: 50,
82                top_p: 0.9,
83                repetition_penalty: 1.1,
84                ..SamplingParams::default()
85            },
86        }
87    }
88
89    /// Human-readable name of this preset.
90    pub fn name(&self) -> &'static str {
91        match self {
92            Self::Balanced => "Balanced",
93            Self::Creative => "Creative",
94            Self::Precise => "Precise",
95            Self::Greedy => "Greedy",
96            Self::Conversational => "Conversational",
97        }
98    }
99
100    /// Description of this preset's intended use case.
101    pub fn description(&self) -> &'static str {
102        match self {
103            Self::Balanced => "General-purpose: moderate creativity with good coherence",
104            Self::Creative => "Creative writing: high diversity and novel outputs",
105            Self::Precise => "Factual/code: low randomness for accurate outputs",
106            Self::Greedy => "Deterministic: always picks the most likely token",
107            Self::Conversational => "Chat: natural-sounding conversation with personality",
108        }
109    }
110
111    /// Get all available presets.
112    pub fn all() -> &'static [SamplingPreset] {
113        &ALL_PRESETS
114    }
115}
116
117impl From<SamplingPreset> for SamplingParams {
118    fn from(preset: SamplingPreset) -> Self {
119        preset.params()
120    }
121}
122
123impl std::fmt::Display for SamplingPreset {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        write!(f, "{}", self.name())
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares two `f32` values with a tolerance of one machine epsilon.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    #[test]
    fn balanced_preset() {
        let params = SamplingPreset::Balanced.params();
        assert!(approx_eq(params.temperature, 0.7));
        assert_eq!(params.top_k, 40);
        assert!(approx_eq(params.top_p, 0.9));
        assert!(approx_eq(params.repetition_penalty, 1.1));
    }

    #[test]
    fn creative_preset() {
        let params = SamplingPreset::Creative.params();
        assert!(approx_eq(params.temperature, 1.0));
        assert_eq!(params.top_k, 0); // disabled
        assert!(approx_eq(params.top_p, 0.95));
        assert!(approx_eq(params.repetition_penalty, 1.0));
    }

    #[test]
    fn precise_preset() {
        let params = SamplingPreset::Precise.params();
        assert!(approx_eq(params.temperature, 0.1));
        assert_eq!(params.top_k, 10);
        assert!(approx_eq(params.top_p, 0.5));
        assert!(approx_eq(params.repetition_penalty, 1.2));
    }

    #[test]
    fn greedy_preset() {
        let params = SamplingPreset::Greedy.params();
        assert!(params.temperature < f32::EPSILON);
        assert_eq!(params.top_k, 0);
        assert!(approx_eq(params.top_p, 1.0));
        assert!(approx_eq(params.repetition_penalty, 1.0));
    }

    #[test]
    fn conversational_preset() {
        let params = SamplingPreset::Conversational.params();
        assert!(approx_eq(params.temperature, 0.8));
        assert_eq!(params.top_k, 50);
        assert!(approx_eq(params.top_p, 0.9));
        assert!(approx_eq(params.repetition_penalty, 1.1));
    }

    #[test]
    fn all_presets_covers_all_variants() {
        let all = SamplingPreset::all();
        assert_eq!(all.len(), 5);
        assert!(all.contains(&SamplingPreset::Balanced));
        assert!(all.contains(&SamplingPreset::Creative));
        assert!(all.contains(&SamplingPreset::Precise));
        assert!(all.contains(&SamplingPreset::Greedy));
        assert!(all.contains(&SamplingPreset::Conversational));
    }

    #[test]
    fn preset_names_non_empty() {
        for preset in SamplingPreset::all() {
            assert!(!preset.name().is_empty());
            assert!(!preset.description().is_empty());
        }
    }

    #[test]
    fn preset_names_are_unique() {
        let names: std::collections::HashSet<&str> =
            SamplingPreset::all().iter().map(|p| p.name()).collect();
        assert_eq!(names.len(), SamplingPreset::all().len());
    }

    #[test]
    fn preset_into_sampling_params() {
        let params: SamplingParams = SamplingPreset::Balanced.into();
        assert!(approx_eq(params.temperature, 0.7));

        // `From` must agree with `params()` for every preset.
        for &preset in SamplingPreset::all() {
            let via_into: SamplingParams = preset.into();
            let via_params = preset.params();
            assert!(approx_eq(via_into.temperature, via_params.temperature));
            assert_eq!(via_into.top_k, via_params.top_k);
            assert!(approx_eq(via_into.top_p, via_params.top_p));
            assert!(approx_eq(
                via_into.repetition_penalty,
                via_params.repetition_penalty
            ));
        }
    }

    #[test]
    fn preset_display() {
        assert_eq!(format!("{}", SamplingPreset::Balanced), "Balanced");
        assert_eq!(format!("{}", SamplingPreset::Greedy), "Greedy");
        for preset in SamplingPreset::all() {
            assert_eq!(preset.to_string(), preset.name());
        }
    }

    #[test]
    fn all_presets_produce_valid_params() {
        for preset in SamplingPreset::all() {
            let params = preset.params();
            assert!(params.temperature >= 0.0);
            assert!(params.top_p >= 0.0 && params.top_p <= 1.0);
            assert!(params.repetition_penalty >= 1.0);
        }
    }

    #[test]
    fn preset_clone_and_copy() {
        let p = SamplingPreset::Creative;
        let p2 = p;
        let p3 = p;
        assert_eq!(p, p2);
        assert_eq!(p, p3);
    }
}