use crate::sampling::SamplingParams;
/// Named, ready-made sampling configurations.
///
/// Each variant maps to a fixed set of [`SamplingParams`] values through
/// [`SamplingPreset::params`]. The enum is a plain fieldless `Copy` type and
/// derives `Hash` so it can be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingPreset {
    /// General-purpose preset: moderate creativity with good coherence.
    Balanced,
    /// Creative-writing preset: high diversity and novel outputs.
    Creative,
    /// Factual/code preset: low randomness for accurate outputs.
    Precise,
    /// Deterministic preset (temperature `0.0`).
    Greedy,
    /// Chat preset: natural-sounding conversation.
    Conversational,
}
static ALL_PRESETS: [SamplingPreset; 5] = [
SamplingPreset::Balanced,
SamplingPreset::Creative,
SamplingPreset::Precise,
SamplingPreset::Greedy,
SamplingPreset::Conversational,
];
impl SamplingPreset {
    /// Builds the [`SamplingParams`] associated with this preset.
    ///
    /// Only `temperature`, `top_k`, `top_p` and `repetition_penalty` are set
    /// per preset; every other field comes from `SamplingParams::default()`.
    pub fn params(&self) -> SamplingParams {
        // (temperature, top_k, top_p, repetition_penalty) for each variant.
        let (temperature, top_k, top_p, repetition_penalty) = match *self {
            Self::Balanced => (0.7, 40, 0.9, 1.1),
            Self::Creative => (1.0, 0, 0.95, 1.0),
            Self::Precise => (0.1, 10, 0.5, 1.2),
            Self::Greedy => (0.0, 0, 1.0, 1.0),
            Self::Conversational => (0.8, 50, 0.9, 1.1),
        };

        SamplingParams {
            temperature,
            top_k,
            top_p,
            repetition_penalty,
            ..SamplingParams::default()
        }
    }

    /// Returns the short display name of the preset.
    ///
    /// This is the same text produced by the `Display` implementation.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Balanced => "Balanced",
            Self::Creative => "Creative",
            Self::Precise => "Precise",
            Self::Greedy => "Greedy",
            Self::Conversational => "Conversational",
        }
    }

    /// Returns a one-line, human-readable description of the preset.
    pub fn description(&self) -> &'static str {
        match *self {
            Self::Balanced => "General-purpose: moderate creativity with good coherence",
            Self::Creative => "Creative writing: high diversity and novel outputs",
            Self::Precise => "Factual/code: low randomness for accurate outputs",
            Self::Greedy => "Deterministic: always picks the most likely token",
            Self::Conversational => "Chat: natural-sounding conversation with personality",
        }
    }

    /// Returns a slice containing every preset, in declaration order.
    pub fn all() -> &'static [SamplingPreset] {
        &ALL_PRESETS[..]
    }
}
impl From<SamplingPreset> for SamplingParams {
fn from(preset: SamplingPreset) -> Self {
preset.params()
}
}
impl std::fmt::Display for SamplingPreset {
    /// Writes the preset's display name (see [`SamplingPreset::name`]).
    ///
    /// The name is emitted through [`std::fmt::Formatter::pad`] so that
    /// formatting flags supplied by the caller (width, fill, alignment,
    /// precision — e.g. `{:>14}` or `{:<10}`) are applied. A nested
    /// `write!(f, "{}", ..)` would start from a fresh format spec and silently
    /// drop them. Output for a plain `{}` is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two `f32` values differ by less than `f32::EPSILON`.
    ///
    /// Uses the absolute difference so that values on either side of
    /// `expected` (including negative ones) are rejected.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn balanced_preset() {
        let params = SamplingPreset::Balanced.params();
        assert_close(params.temperature, 0.7);
        assert_eq!(params.top_k, 40);
        assert_close(params.top_p, 0.9);
        assert_close(params.repetition_penalty, 1.1);
    }

    #[test]
    fn creative_preset() {
        let params = SamplingPreset::Creative.params();
        assert_close(params.temperature, 1.0);
        assert_eq!(params.top_k, 0);
        assert_close(params.top_p, 0.95);
        assert_close(params.repetition_penalty, 1.0);
    }

    #[test]
    fn precise_preset() {
        let params = SamplingPreset::Precise.params();
        assert_close(params.temperature, 0.1);
        assert_eq!(params.top_k, 10);
        assert_close(params.top_p, 0.5);
        assert_close(params.repetition_penalty, 1.2);
    }

    #[test]
    fn greedy_preset() {
        let params = SamplingPreset::Greedy.params();
        // Compare against zero with an absolute difference: a bare
        // `temperature < EPSILON` check would also accept negative values.
        assert_close(params.temperature, 0.0);
        assert_eq!(params.top_k, 0);
        assert_close(params.top_p, 1.0);
        assert_close(params.repetition_penalty, 1.0);
    }

    #[test]
    fn conversational_preset() {
        let params = SamplingPreset::Conversational.params();
        assert_close(params.temperature, 0.8);
        assert_eq!(params.top_k, 50);
        assert_close(params.top_p, 0.9);
        assert_close(params.repetition_penalty, 1.1);
    }

    #[test]
    fn all_presets_covers_all_variants() {
        let all = SamplingPreset::all();
        assert_eq!(all.len(), 5);
        assert!(all.contains(&SamplingPreset::Balanced));
        assert!(all.contains(&SamplingPreset::Creative));
        assert!(all.contains(&SamplingPreset::Precise));
        assert!(all.contains(&SamplingPreset::Greedy));
        assert!(all.contains(&SamplingPreset::Conversational));
    }

    #[test]
    fn preset_names_non_empty() {
        for preset in SamplingPreset::all() {
            assert!(!preset.name().is_empty());
            assert!(!preset.description().is_empty());
        }
    }

    #[test]
    fn preset_into_sampling_params() {
        let params: SamplingParams = SamplingPreset::Balanced.into();
        assert_close(params.temperature, 0.7);
    }

    #[test]
    fn preset_display() {
        assert_eq!(format!("{}", SamplingPreset::Balanced), "Balanced");
        assert_eq!(format!("{}", SamplingPreset::Greedy), "Greedy");
        // Display must agree with `name()` for every variant.
        for preset in SamplingPreset::all() {
            assert_eq!(format!("{}", preset), preset.name());
        }
    }

    #[test]
    fn all_presets_produce_valid_params() {
        for preset in SamplingPreset::all() {
            let params = preset.params();
            assert!(params.temperature >= 0.0);
            assert!(params.top_p >= 0.0 && params.top_p <= 1.0);
            assert!(params.repetition_penalty >= 1.0);
        }
    }

    // The explicit `.clone()` on a `Copy` type is deliberate: this test exists
    // to exercise the derived `Clone` impl alongside the implicit copy.
    #[allow(clippy::clone_on_copy)]
    #[test]
    fn preset_clone_and_copy() {
        let p = SamplingPreset::Creative;
        let cloned = p.clone();
        let copied = p;
        assert_eq!(p, cloned);
        assert_eq!(p, copied);
    }
}