use crate::{Model, MullamaError, SamplerChain, SamplerParams};
use std::sync::Arc;
/// Fluent, consuming builder that collects sampling settings and turns them into a
/// [`SamplerChain`] via [`SamplerBuilder::build`].
///
/// Every setter takes the builder by value and returns the updated builder, so the
/// returned value must be kept (or passed on to `build`); dropping it discards the
/// configuration.
#[must_use = "SamplerBuilder setters return the updated builder; dropping it discards the configuration"]
#[derive(Debug, Clone, PartialEq)]
pub struct SamplerBuilder {
    /// Set by `temperature`; forwarded as `SamplerParams::temperature`.
    temperature: f32,
    /// Set by `top_k`; forwarded as `SamplerParams::top_k`.
    top_k: i32,
    /// Set by `nucleus`; forwarded as `SamplerParams::top_p`.
    top_p: f32,
    /// Set by `min_probability`; forwarded as `SamplerParams::min_p`.
    min_p: f32,
    /// Set by `penalties` (`PenaltyBuilder::repetition`); forwarded as `SamplerParams::penalty_repeat`.
    penalty_repeat: f32,
    /// Set by `penalties` (`PenaltyBuilder::frequency`); forwarded as `SamplerParams::penalty_freq`.
    penalty_freq: f32,
    /// Set by `penalties` (`PenaltyBuilder::presence`); forwarded as `SamplerParams::penalty_present`.
    penalty_present: f32,
    /// Set by `penalties` (`PenaltyBuilder::lookback`); forwarded as `SamplerParams::penalty_last_n`.
    penalty_last_n: i32,
    /// Set by `seed`; forwarded as `SamplerParams::seed`.
    seed: u32,
}
impl SamplerBuilder {
    /// Creates a builder pre-loaded with the crate's default settings:
    /// temperature 0.8, top-k 40, top-p 0.95, min-p 0.05, repetition penalty 1.1
    /// with a look-back of 64, frequency and presence penalties 0.0, and seed 0.
    pub fn new() -> Self {
        Self {
            temperature: 0.8,
            top_k: 40,
            top_p: 0.95,
            min_p: 0.05,
            penalty_repeat: 1.1,
            penalty_freq: 0.0,
            penalty_present: 0.0,
            penalty_last_n: 64,
            seed: 0,
        }
    }

    /// Replaces the temperature (forwarded as `SamplerParams::temperature`).
    pub fn temperature(self, temp: f32) -> Self {
        Self {
            temperature: temp,
            ..self
        }
    }

    /// Replaces the top-k value (forwarded as `SamplerParams::top_k`).
    pub fn top_k(self, k: i32) -> Self {
        Self { top_k: k, ..self }
    }

    /// Replaces the nucleus / top-p value (forwarded as `SamplerParams::top_p`).
    pub fn nucleus(self, p: f32) -> Self {
        Self { top_p: p, ..self }
    }

    /// Replaces the min-p value (forwarded as `SamplerParams::min_p`).
    pub fn min_probability(self, min_p: f32) -> Self {
        Self { min_p, ..self }
    }

    /// Edits the penalty settings through a [`PenaltyBuilder`].
    ///
    /// The closure is handed a `PenaltyBuilder` seeded with this builder's current
    /// penalty values. Anything the closure does not change is therefore preserved.
    /// Every value of the builder it returns is copied back.
    pub fn penalties<F>(self, config: F) -> Self
    where
        F: FnOnce(PenaltyBuilder) -> PenaltyBuilder,
    {
        let seeded = PenaltyBuilder {
            repeat: self.penalty_repeat,
            freq: self.penalty_freq,
            present: self.penalty_present,
            last_n: self.penalty_last_n,
        };
        let PenaltyBuilder {
            repeat,
            freq,
            present,
            last_n,
        } = config(seeded);
        Self {
            penalty_repeat: repeat,
            penalty_freq: freq,
            penalty_present: present,
            penalty_last_n: last_n,
            ..self
        }
    }

    /// Replaces the seed (forwarded as `SamplerParams::seed`).
    pub fn seed(self, seed: u32) -> Self {
        Self { seed, ..self }
    }

    /// Runs `preset` on this builder and returns its result.
    ///
    /// A preset is any `FnOnce(SamplerBuilder) -> SamplerBuilder`, typically a
    /// function that chains several setters.
    pub fn preset<F>(self, preset: F) -> Self
    where
        F: FnOnce(Self) -> Self,
    {
        preset(self)
    }

    /// Produces the sampler chain for `model`.
    ///
    /// Every configured value is copied into a [`SamplerParams`]. Fields this builder
    /// does not expose take their `SamplerParams::default()` values. Construction is
    /// then delegated to `SamplerParams::build_chain`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SamplerParams::build_chain`.
    pub fn build(self, model: Arc<Model>) -> Result<SamplerChain, MullamaError> {
        let Self {
            temperature,
            top_k,
            top_p,
            min_p,
            penalty_repeat,
            penalty_freq,
            penalty_present,
            penalty_last_n,
            seed,
        } = self;
        SamplerParams {
            temperature,
            top_k,
            top_p,
            min_p,
            penalty_repeat,
            penalty_freq,
            penalty_present,
            penalty_last_n,
            seed,
            ..Default::default()
        }
        .build_chain(model)
    }
}
impl Default for SamplerBuilder {
fn default() -> Self {
Self::new()
}
}
/// Penalty settings handed to the closure passed to `SamplerBuilder::penalties`.
///
/// It has no public constructor. In this file it is created only by
/// `SamplerBuilder::penalties`, which seeds it with the builder's current penalty
/// values and copies the returned values back. Each setter consumes and returns the
/// builder, so the result must be returned from the closure.
#[must_use = "PenaltyBuilder setters return the updated builder; dropping it discards the change"]
#[derive(Debug, Clone, PartialEq)]
pub struct PenaltyBuilder {
    /// Copied back into `SamplerBuilder::penalty_repeat`.
    repeat: f32,
    /// Copied back into `SamplerBuilder::penalty_freq`.
    freq: f32,
    /// Copied back into `SamplerBuilder::penalty_present`.
    present: f32,
    /// Copied back into `SamplerBuilder::penalty_last_n`.
    last_n: i32,
}

impl PenaltyBuilder {
    /// Sets the repetition penalty.
    pub fn repetition(mut self, penalty: f32) -> Self {
        self.repeat = penalty;
        self
    }

    /// Sets the frequency penalty.
    pub fn frequency(mut self, penalty: f32) -> Self {
        self.freq = penalty;
        self
    }

    /// Sets the presence penalty.
    pub fn presence(mut self, penalty: f32) -> Self {
        self.present = penalty;
        self
    }

    /// Sets the look-back window (number of tokens) used for the penalties.
    pub fn lookback(mut self, tokens: i32) -> Self {
        self.last_n = tokens;
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::presets;

    #[test]
    fn test_sampler_builder() {
        let builder = SamplerBuilder::new()
            .temperature(0.8)
            .top_k(50)
            .nucleus(0.95);
        assert_eq!(builder.temperature, 0.8);
        assert_eq!(builder.top_k, 50);
        assert_eq!(builder.top_p, 0.95);
    }

    #[test]
    fn test_min_probability_and_seed() {
        let builder = SamplerBuilder::new().min_probability(0.1).seed(42);
        assert_eq!(builder.min_p, 0.1);
        assert_eq!(builder.seed, 42);
    }

    #[test]
    fn test_default_matches_new() {
        // Compared via Debug output so this test only relies on the derived Debug impl.
        assert_eq!(
            format!("{:?}", SamplerBuilder::default()),
            format!("{:?}", SamplerBuilder::new())
        );
    }

    #[test]
    fn test_penalty_builder() {
        let builder = SamplerBuilder::new()
            .penalties(|p| p.repetition(1.2).frequency(0.1).presence(0.1).lookback(128));
        assert_eq!(builder.penalty_repeat, 1.2);
        assert_eq!(builder.penalty_freq, 0.1);
        assert_eq!(builder.penalty_present, 0.1);
        assert_eq!(builder.penalty_last_n, 128);
    }

    #[test]
    fn test_partial_penalties_keep_other_values() {
        // A closure that sets one penalty must leave the others at their prior values.
        let defaults = SamplerBuilder::new();
        let builder = SamplerBuilder::new().penalties(|p| p.frequency(0.2));
        assert_eq!(builder.penalty_freq, 0.2);
        assert_eq!(builder.penalty_repeat, defaults.penalty_repeat);
        assert_eq!(builder.penalty_present, defaults.penalty_present);
        assert_eq!(builder.penalty_last_n, defaults.penalty_last_n);
        // Non-penalty settings are untouched as well.
        assert_eq!(builder.temperature, defaults.temperature);
        assert_eq!(builder.top_k, defaults.top_k);
    }

    #[test]
    fn test_repeated_penalties_accumulate() {
        let builder = SamplerBuilder::new()
            .penalties(|p| p.repetition(1.3))
            .penalties(|p| p.lookback(256));
        assert_eq!(builder.penalty_repeat, 1.3);
        assert_eq!(builder.penalty_last_n, 256);
    }

    #[test]
    fn test_presets() {
        let creative = SamplerBuilder::new().preset(presets::creative_sampling);
        assert!(creative.temperature > 0.8);
        assert!(creative.top_k > 50);
        let precise = SamplerBuilder::new().preset(presets::precise_sampling);
        assert!(precise.temperature < 0.3);
        assert!(precise.top_k < 20);
    }
}