use serde::{Deserialize, Serialize};
/// Sampling strategy requested for generation.
///
/// Each variant is (de)serialized as its lowercase name. [`Sampling::default`]
/// yields [`Sampling::Random`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum Sampling {
    /// Wire form: `"greedy"`.
    #[serde(rename = "greedy")]
    Greedy,
    /// Wire form: `"random"`. This is the `Default` variant.
    #[default]
    #[serde(rename = "random")]
    Random,
}
/// Optional generation parameters.
///
/// Every field is optional, and `None` fields are omitted from the serialized
/// JSON. Prefer [`GenerationOptions::builder`] to construct values.
/// Deserializing or writing a struct literal bypasses the builder's range
/// checks, so values obtained that way are not validated.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct GenerationOptions {
    /// Temperature. The builder only accepts values in `0.0..=2.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Sampling strategy, serialized as a lowercase string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling: Option<Sampling>,
    /// Response-token limit, serialized under the key `maximumResponseTokens`.
    /// The builder ignores `0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "maximumResponseTokens")]
    pub max_response_tokens: Option<u32>,
    /// Seed value. The builder accepts any `u64`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
}
impl GenerationOptions {
    /// Returns a fresh [`GenerationOptionsBuilder`] with no fields set.
    pub fn builder() -> GenerationOptionsBuilder {
        Default::default()
    }

    /// Serializes these options to a compact JSON object string.
    ///
    /// Fields that are `None` are left out. If serialization reports an
    /// error, the empty object `{}` is returned rather than propagating it.
    pub fn to_json(&self) -> String {
        match serde_json::to_string(self) {
            Ok(json) => json,
            Err(_) => String::from("{}"),
        }
    }
}
/// Consuming builder for [`GenerationOptions`].
///
/// Each setter takes the builder by value and returns it, so calls chain.
/// Finish with [`GenerationOptionsBuilder::build`]. Because the builder is
/// moved through every call, discarding a setter's return value loses that
/// setting; the type is therefore `#[must_use]`.
#[derive(Debug, Clone, Default)]
#[must_use = "builder setters return the updated builder; call `build()` to obtain the options"]
pub struct GenerationOptionsBuilder {
    temperature: Option<f64>,
    sampling: Option<Sampling>,
    max_response_tokens: Option<u32>,
    seed: Option<u64>,
}
impl GenerationOptionsBuilder {
pub fn temperature(mut self, temp: f64) -> Self {
if (0.0..=2.0).contains(&temp) {
self.temperature = Some(temp);
}
self
}
pub fn try_temperature(mut self, temp: f64) -> Result<Self, crate::Error> {
if (0.0..=2.0).contains(&temp) {
self.temperature = Some(temp);
Ok(self)
} else {
Err(crate::Error::InvalidInput(format!(
"Temperature must be between 0.0 and 2.0, got {temp}"
)))
}
}
pub fn sampling(mut self, sampling: Sampling) -> Self {
self.sampling = Some(sampling);
self
}
pub fn max_response_tokens(mut self, tokens: u32) -> Self {
if tokens > 0 {
self.max_response_tokens = Some(tokens);
}
self
}
pub fn seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn build(self) -> GenerationOptions {
GenerationOptions {
temperature: self.temperature,
sampling: self.sampling,
max_response_tokens: self.max_response_tokens,
seed: self.seed,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_options() {
        let options = GenerationOptions::default();
        assert!(options.temperature.is_none());
        assert!(options.sampling.is_none());
        assert!(options.max_response_tokens.is_none());
        assert!(options.seed.is_none());
    }

    #[test]
    fn test_sampling_default_is_random() {
        assert_eq!(Sampling::default(), Sampling::Random);
    }

    #[test]
    fn test_builder() {
        let options = GenerationOptions::builder()
            .temperature(0.7)
            .sampling(Sampling::Random)
            .max_response_tokens(500)
            .seed(42)
            .build();
        assert_eq!(options.temperature, Some(0.7));
        assert_eq!(options.sampling, Some(Sampling::Random));
        assert_eq!(options.max_response_tokens, Some(500));
        assert_eq!(options.seed, Some(42));
    }

    #[test]
    fn test_temperature_bounds() {
        let options = GenerationOptions::builder().temperature(1.5).build();
        assert_eq!(options.temperature, Some(1.5));
        let options = GenerationOptions::builder().temperature(-0.5).build();
        assert!(options.temperature.is_none());
        let options = GenerationOptions::builder().temperature(3.0).build();
        assert!(options.temperature.is_none());
        // Both ends of the range are inclusive.
        let options = GenerationOptions::builder().temperature(0.0).build();
        assert_eq!(options.temperature, Some(0.0));
        let options = GenerationOptions::builder().temperature(2.0).build();
        assert_eq!(options.temperature, Some(2.0));
        // NaN is rejected.
        let options = GenerationOptions::builder().temperature(f64::NAN).build();
        assert!(options.temperature.is_none());
        // A rejected value leaves an earlier valid one in place.
        let options = GenerationOptions::builder()
            .temperature(0.5)
            .temperature(5.0)
            .build();
        assert_eq!(options.temperature, Some(0.5));
    }

    #[test]
    fn test_try_temperature() {
        let accepted = GenerationOptions::builder()
            .try_temperature(1.2)
            .map(GenerationOptionsBuilder::build)
            .ok();
        assert_eq!(accepted.and_then(|o| o.temperature), Some(1.2));
        assert!(GenerationOptions::builder().try_temperature(2.5).is_err());
        assert!(GenerationOptions::builder().try_temperature(-0.1).is_err());
        assert!(GenerationOptions::builder()
            .try_temperature(f64::NAN)
            .is_err());
    }

    #[test]
    fn test_zero_max_response_tokens_ignored() {
        let options = GenerationOptions::builder().max_response_tokens(0).build();
        assert!(options.max_response_tokens.is_none());
    }

    #[test]
    fn test_json_serialization() {
        let options = GenerationOptions::builder()
            .temperature(0.7)
            .max_response_tokens(100)
            .build();
        let json = options.to_json();
        assert!(json.contains("temperature"));
        assert!(json.contains("0.7"));
        assert!(json.contains("\"maximumResponseTokens\":100"));
        // Unset fields are skipped entirely.
        assert!(!json.contains("sampling"));
        assert!(!json.contains("seed"));
    }

    #[test]
    fn test_empty_options_serialize_to_empty_object() {
        assert_eq!(GenerationOptions::default().to_json(), "{}");
    }

    #[test]
    fn test_json_round_trip() {
        let json = GenerationOptions::builder()
            .temperature(0.5)
            .sampling(Sampling::Greedy)
            .seed(7)
            .build()
            .to_json();
        assert!(json.contains("\"sampling\":\"greedy\""));
        let parsed: GenerationOptions =
            serde_json::from_str(&json).expect("to_json output must parse back");
        assert_eq!(parsed.temperature, Some(0.5));
        assert_eq!(parsed.sampling, Some(Sampling::Greedy));
        assert!(parsed.max_response_tokens.is_none());
        assert_eq!(parsed.seed, Some(7));
    }
}