1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Sampling {
9 Greedy,
11 #[default]
13 Random,
14}
15
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
29pub struct GenerationOptions {
30 #[serde(skip_serializing_if = "Option::is_none")]
33 pub temperature: Option<f64>,
34
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub sampling: Option<Sampling>,
38
39 #[serde(skip_serializing_if = "Option::is_none")]
41 #[serde(rename = "maximumResponseTokens")]
42 pub max_response_tokens: Option<u32>,
43
44 #[serde(skip_serializing_if = "Option::is_none")]
49 pub seed: Option<u64>,
50}
51
52impl GenerationOptions {
53 pub fn builder() -> GenerationOptionsBuilder {
55 GenerationOptionsBuilder::default()
56 }
57
58 pub fn to_json(&self) -> String {
60 serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string())
61 }
62}
63
64#[derive(Debug, Default)]
66pub struct GenerationOptionsBuilder {
67 temperature: Option<f64>,
68 sampling: Option<Sampling>,
69 max_response_tokens: Option<u32>,
70 seed: Option<u64>,
71}
72
73impl GenerationOptionsBuilder {
74 pub fn temperature(mut self, temp: f64) -> Self {
83 if (0.0..=2.0).contains(&temp) {
84 self.temperature = Some(temp);
85 }
86 self
87 }
88
89 pub fn try_temperature(mut self, temp: f64) -> Result<Self, crate::Error> {
98 if (0.0..=2.0).contains(&temp) {
99 self.temperature = Some(temp);
100 Ok(self)
101 } else {
102 Err(crate::Error::InvalidInput(format!(
103 "Temperature must be between 0.0 and 2.0, got {temp}"
104 )))
105 }
106 }
107
108 pub fn sampling(mut self, sampling: Sampling) -> Self {
110 self.sampling = Some(sampling);
111 self
112 }
113
114 pub fn max_response_tokens(mut self, tokens: u32) -> Self {
119 if tokens > 0 {
120 self.max_response_tokens = Some(tokens);
121 }
122 self
123 }
124
125 pub fn seed(mut self, seed: u64) -> Self {
130 self.seed = Some(seed);
131 self
132 }
133
134 pub fn build(self) -> GenerationOptions {
136 GenerationOptions {
137 temperature: self.temperature,
138 sampling: self.sampling,
139 max_response_tokens: self.max_response_tokens,
140 seed: self.seed,
141 }
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_options() {
        let options = GenerationOptions::default();
        assert!(options.temperature.is_none());
        assert!(options.sampling.is_none());
        assert!(options.max_response_tokens.is_none());
        assert!(options.seed.is_none());
    }

    #[test]
    fn test_default_sampling_is_random() {
        assert_eq!(Sampling::default(), Sampling::Random);
    }

    #[test]
    fn test_builder() {
        let options = GenerationOptions::builder()
            .temperature(0.7)
            .sampling(Sampling::Random)
            .max_response_tokens(500)
            .seed(42)
            .build();

        assert_eq!(options.temperature, Some(0.7));
        assert_eq!(options.sampling, Some(Sampling::Random));
        assert_eq!(options.max_response_tokens, Some(500));
        assert_eq!(options.seed, Some(42));
    }

    #[test]
    fn test_temperature_bounds() {
        let options = GenerationOptions::builder().temperature(1.5).build();
        assert_eq!(options.temperature, Some(1.5));

        let options = GenerationOptions::builder().temperature(-0.5).build();
        assert!(options.temperature.is_none());

        let options = GenerationOptions::builder().temperature(3.0).build();
        assert!(options.temperature.is_none());

        // Both ends of the range are inclusive.
        let options = GenerationOptions::builder().temperature(0.0).build();
        assert_eq!(options.temperature, Some(0.0));
        let options = GenerationOptions::builder().temperature(2.0).build();
        assert_eq!(options.temperature, Some(2.0));

        // NaN is never in range.
        let options = GenerationOptions::builder().temperature(f64::NAN).build();
        assert!(options.temperature.is_none());

        // A rejected value leaves the previously accepted one in place.
        let options = GenerationOptions::builder()
            .temperature(0.5)
            .temperature(5.0)
            .build();
        assert_eq!(options.temperature, Some(0.5));
    }

    #[test]
    fn test_try_temperature() {
        let options = match GenerationOptions::builder().try_temperature(2.0) {
            Ok(builder) => builder.build(),
            Err(_) => panic!("2.0 is within the accepted temperature range"),
        };
        assert_eq!(options.temperature, Some(2.0));

        assert!(GenerationOptions::builder().try_temperature(2.5).is_err());
        assert!(GenerationOptions::builder().try_temperature(-0.1).is_err());
        assert!(GenerationOptions::builder().try_temperature(f64::NAN).is_err());
    }

    #[test]
    fn test_zero_max_response_tokens_is_ignored() {
        let options = GenerationOptions::builder().max_response_tokens(0).build();
        assert!(options.max_response_tokens.is_none());

        let options = GenerationOptions::builder()
            .max_response_tokens(10)
            .max_response_tokens(0)
            .build();
        assert_eq!(options.max_response_tokens, Some(10));
    }

    #[test]
    fn test_json_serialization() {
        let options = GenerationOptions::builder()
            .temperature(0.7)
            .max_response_tokens(100)
            .build();

        let json = options.to_json();
        assert!(json.contains("temperature"));
        assert!(json.contains("0.7"));
        // Unset fields are omitted and the token limit uses the camelCase key.
        assert_eq!(json, r#"{"temperature":0.7,"maximumResponseTokens":100}"#);
    }

    #[test]
    fn test_empty_options_serialize_to_empty_object() {
        assert_eq!(GenerationOptions::default().to_json(), "{}");
    }

    #[test]
    fn test_sampling_serializes_lowercase() {
        assert_eq!(serde_json::to_string(&Sampling::Greedy).unwrap(), "\"greedy\"");
        assert_eq!(serde_json::to_string(&Sampling::Random).unwrap(), "\"random\"");
    }

    #[test]
    fn test_json_round_trip() {
        let json = r#"{"temperature":1.0,"sampling":"greedy","maximumResponseTokens":42,"seed":7}"#;
        let options: GenerationOptions = serde_json::from_str(json).unwrap();

        assert_eq!(options.temperature, Some(1.0));
        assert_eq!(options.sampling, Some(Sampling::Greedy));
        assert_eq!(options.max_response_tokens, Some(42));
        assert_eq!(options.seed, Some(7));
        assert_eq!(options.to_json(), json);
    }
}