oxibonsai_runtime/
presets.rs1use crate::sampling::SamplingParams;
/// Named bundles of sampling hyper-parameters for common generation styles.
///
/// Each variant maps to a concrete [`SamplingParams`] via
/// [`SamplingPreset::params`] (or `SamplingParams::from(preset)`). A preset sets
/// `temperature`, `top_k`, `top_p`, and `repetition_penalty`; every other field
/// comes from `SamplingParams::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingPreset {
    /// General-purpose: moderate creativity with good coherence.
    Balanced,
    /// Creative writing: high diversity and novel outputs.
    Creative,
    /// Factual answers and code: low randomness for accurate outputs.
    Precise,
    /// Deterministic decoding (temperature 0.0): always picks the most likely token.
    Greedy,
    /// Chat: natural-sounding conversation with personality.
    Conversational,
}

38static ALL_PRESETS: [SamplingPreset; 5] = [
40 SamplingPreset::Balanced,
41 SamplingPreset::Creative,
42 SamplingPreset::Precise,
43 SamplingPreset::Greedy,
44 SamplingPreset::Conversational,
45];
46
47impl SamplingPreset {
48 pub fn params(&self) -> SamplingParams {
50 match self {
51 Self::Balanced => SamplingParams {
52 temperature: 0.7,
53 top_k: 40,
54 top_p: 0.9,
55 repetition_penalty: 1.1,
56 ..SamplingParams::default()
57 },
58 Self::Creative => SamplingParams {
59 temperature: 1.0,
60 top_k: 0,
61 top_p: 0.95,
62 repetition_penalty: 1.0,
63 ..SamplingParams::default()
64 },
65 Self::Precise => SamplingParams {
66 temperature: 0.1,
67 top_k: 10,
68 top_p: 0.5,
69 repetition_penalty: 1.2,
70 ..SamplingParams::default()
71 },
72 Self::Greedy => SamplingParams {
73 temperature: 0.0,
74 top_k: 0,
75 top_p: 1.0,
76 repetition_penalty: 1.0,
77 ..SamplingParams::default()
78 },
79 Self::Conversational => SamplingParams {
80 temperature: 0.8,
81 top_k: 50,
82 top_p: 0.9,
83 repetition_penalty: 1.1,
84 ..SamplingParams::default()
85 },
86 }
87 }
88
89 pub fn name(&self) -> &'static str {
91 match self {
92 Self::Balanced => "Balanced",
93 Self::Creative => "Creative",
94 Self::Precise => "Precise",
95 Self::Greedy => "Greedy",
96 Self::Conversational => "Conversational",
97 }
98 }
99
100 pub fn description(&self) -> &'static str {
102 match self {
103 Self::Balanced => "General-purpose: moderate creativity with good coherence",
104 Self::Creative => "Creative writing: high diversity and novel outputs",
105 Self::Precise => "Factual/code: low randomness for accurate outputs",
106 Self::Greedy => "Deterministic: always picks the most likely token",
107 Self::Conversational => "Chat: natural-sounding conversation with personality",
108 }
109 }
110
111 pub fn all() -> &'static [SamplingPreset] {
113 &ALL_PRESETS
114 }
115}
116
117impl From<SamplingPreset> for SamplingParams {
118 fn from(preset: SamplingPreset) -> Self {
119 preset.params()
120 }
121}
122
123impl std::fmt::Display for SamplingPreset {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 write!(f, "{}", self.name())
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within `f32::EPSILON` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn balanced_preset() {
        let params = SamplingPreset::Balanced.params();
        assert_close(params.temperature, 0.7);
        assert_eq!(params.top_k, 40);
        assert_close(params.top_p, 0.9);
        assert_close(params.repetition_penalty, 1.1);
    }

    #[test]
    fn creative_preset() {
        let params = SamplingPreset::Creative.params();
        assert_close(params.temperature, 1.0);
        assert_eq!(params.top_k, 0);
        assert_close(params.top_p, 0.95);
        assert_close(params.repetition_penalty, 1.0);
    }

    #[test]
    fn precise_preset() {
        let params = SamplingPreset::Precise.params();
        assert_close(params.temperature, 0.1);
        assert_eq!(params.top_k, 10);
        assert_close(params.top_p, 0.5);
        assert_close(params.repetition_penalty, 1.2);
    }

    #[test]
    fn greedy_preset() {
        let params = SamplingPreset::Greedy.params();
        assert!(params.temperature < f32::EPSILON);
        assert_eq!(params.top_k, 0);
        assert_close(params.top_p, 1.0);
        assert_close(params.repetition_penalty, 1.0);
    }

    #[test]
    fn conversational_preset() {
        let params = SamplingPreset::Conversational.params();
        assert_close(params.temperature, 0.8);
        assert_eq!(params.top_k, 50);
        assert_close(params.top_p, 0.9);
        assert_close(params.repetition_penalty, 1.1);
    }

    #[test]
    fn all_presets_covers_all_variants() {
        let all = SamplingPreset::all();
        assert_eq!(all.len(), 5);
        assert!(all.contains(&SamplingPreset::Balanced));
        assert!(all.contains(&SamplingPreset::Creative));
        assert!(all.contains(&SamplingPreset::Precise));
        assert!(all.contains(&SamplingPreset::Greedy));
        assert!(all.contains(&SamplingPreset::Conversational));
    }

    #[test]
    fn preset_names_non_empty() {
        for preset in SamplingPreset::all() {
            assert!(!preset.name().is_empty());
            assert!(!preset.description().is_empty());
        }
    }

    #[test]
    fn preset_names_unique() {
        let names: std::collections::HashSet<&str> =
            SamplingPreset::all().iter().map(|p| p.name()).collect();
        assert_eq!(names.len(), SamplingPreset::all().len());
    }

    #[test]
    fn preset_into_sampling_params() {
        let params: SamplingParams = SamplingPreset::Balanced.into();
        assert_close(params.temperature, 0.7);

        // `From` must agree with `params()` for every preset.
        for &preset in SamplingPreset::all() {
            let via_from = SamplingParams::from(preset);
            let direct = preset.params();
            assert_close(via_from.temperature, direct.temperature);
            assert_eq!(via_from.top_k, direct.top_k);
            assert_close(via_from.top_p, direct.top_p);
            assert_close(via_from.repetition_penalty, direct.repetition_penalty);
        }
    }

    #[test]
    fn preset_display() {
        assert_eq!(format!("{}", SamplingPreset::Balanced), "Balanced");
        assert_eq!(format!("{}", SamplingPreset::Greedy), "Greedy");
        for preset in SamplingPreset::all() {
            assert_eq!(preset.to_string(), preset.name());
        }
    }

    #[test]
    fn all_presets_produce_valid_params() {
        for preset in SamplingPreset::all() {
            let params = preset.params();
            assert!(params.temperature >= 0.0);
            assert!((0.0..=1.0).contains(&params.top_p));
            assert!(params.repetition_penalty >= 1.0);
        }
    }

    #[test]
    fn preset_clone_and_copy() {
        let p = SamplingPreset::Creative;
        let p2 = p;
        let p3 = p;
        assert_eq!(p, p2);
        assert_eq!(p, p3);
    }
}