use crate::ffi;
/// Strategy used to pick the next token during generation.
///
/// The enum is `#[non_exhaustive]`, so downstream code must keep a wildcard arm
/// when matching on it.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
#[non_exhaustive]
pub enum SamplingMode {
    /// Do not request a specific strategy (this is the `Default` value).
    #[default]
    Default,
    /// Greedy decoding; presumably always takes the most likely token — the
    /// native side defines the exact behavior.
    Greedy,
    /// Top-k sampling with the given `k`.
    TopK(u32),
    /// Top-p (nucleus) sampling with the given probability threshold
    /// (presumably in `0.0..=1.0`; no validation happens here).
    TopP(f64),
}
/// Configuration for a generation request.
///
/// Every setting is optional. A `None` field, or [`SamplingMode::Default`], leaves
/// that setting unset. Start from `GenerationOptions::new()` or
/// `GenerationOptions::default()`, then chain the `with_*` builder methods.
///
/// `PartialEq` is derived so callers can compare configurations. This matches
/// [`SamplingMode`], which already derives it. `Eq` cannot be derived because
/// the type holds `f64` values.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct GenerationOptions {
    /// Sampling temperature; `None` leaves it unset.
    temperature: Option<f64>,
    /// Upper bound on the number of response tokens; `None` leaves it unset.
    max_tokens: Option<u32>,
    /// Token sampling strategy.
    sampling: SamplingMode,
}
impl GenerationOptions {
    /// Creates options with every setting left unset.
    ///
    /// Produces the same value as `GenerationOptions::default()`, but can be
    /// used in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            sampling: SamplingMode::Default,
            max_tokens: None,
            temperature: None,
        }
    }

    /// Sets the sampling temperature and returns the updated options.
    ///
    /// The value is stored as given; no range check is performed here.
    #[must_use]
    pub const fn with_temperature(self, temperature: f64) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Sets the maximum number of response tokens and returns the updated options.
    #[must_use]
    pub const fn with_maximum_response_tokens(self, tokens: u32) -> Self {
        Self {
            max_tokens: Some(tokens),
            ..self
        }
    }

    /// Sets the token sampling strategy and returns the updated options.
    #[must_use]
    pub const fn with_sampling(self, mode: SamplingMode) -> Self {
        Self {
            sampling: mode,
            ..self
        }
    }

    /// Converts these options into the `ffi::FFIGenerationOptions` struct passed
    /// across the FFI boundary.
    ///
    /// Encoding used:
    /// - an unset temperature becomes `f64::NAN`;
    /// - an unset token limit becomes `0`;
    /// - the sampling mode becomes code `0` (Default), `1` (Greedy), `2` (TopK),
    ///   or `3` (TopP); whichever of `top_k` / `top_p` is unused is zeroed.
    ///
    /// Token counts and `k` values above `i32::MAX` saturate to `i32::MAX`.
    ///
    /// NOTE(review): `Some(0)` tokens encodes the same as `None`. Confirm that
    /// the native side is meant to treat `0` as "unset".
    pub(crate) fn to_ffi(self) -> ffi::FFIGenerationOptions {
        // Saturate rather than wrap when narrowing into the FFI's `i32` fields.
        fn saturating_i32(value: u32) -> i32 {
            i32::try_from(value).unwrap_or(i32::MAX)
        }

        let temperature = self.temperature.unwrap_or(f64::NAN);
        let maximum_response_tokens = self.max_tokens.map_or(0, saturating_i32);
        let (sampling_mode, top_k, top_p) = match self.sampling {
            SamplingMode::Default => (0, 0, 0.0),
            SamplingMode::Greedy => (1, 0, 0.0),
            SamplingMode::TopK(k) => (2, saturating_i32(k), 0.0),
            SamplingMode::TopP(p) => (3, 0, p),
        };

        ffi::FFIGenerationOptions {
            temperature,
            maximum_response_tokens,
            sampling_mode,
            top_k,
            top_p,
        }
    }
}