1use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
31pub struct GenerationOptions {
32 pub max_tokens: Option<u32>,
34 pub temperature: Option<f32>,
36 pub top_p: Option<f32>,
38 pub top_k: Option<u32>,
40 pub frequency_penalty: Option<f32>,
42 pub presence_penalty: Option<f32>,
44 pub stop: Option<Vec<String>>,
46 pub seed: Option<u64>,
48}
49
50impl GenerationOptions {
51 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
58 self.max_tokens = Some(max_tokens);
59 self
60 }
61
62 pub fn with_temperature(mut self, temperature: f32) -> Self {
64 self.temperature = Some(temperature);
65 self
66 }
67
68 pub fn with_top_p(mut self, top_p: f32) -> Self {
70 self.top_p = Some(top_p);
71 self
72 }
73
74 pub fn with_top_k(mut self, top_k: u32) -> Self {
76 self.top_k = Some(top_k);
77 self
78 }
79
80 pub fn with_frequency_penalty(mut self, penalty: f32) -> Self {
82 self.frequency_penalty = Some(penalty);
83 self
84 }
85
86 pub fn with_presence_penalty(mut self, penalty: f32) -> Self {
88 self.presence_penalty = Some(penalty);
89 self
90 }
91
92 pub fn with_stop(mut self, stop: Vec<String>) -> Self {
94 self.stop = Some(stop);
95 self
96 }
97
98 pub fn with_seed(mut self, seed: u64) -> Self {
100 self.seed = Some(seed);
101 self
102 }
103
104 pub fn merge(&self, other: &GenerationOptions) -> GenerationOptions {
108 GenerationOptions {
109 max_tokens: other.max_tokens.or(self.max_tokens),
110 temperature: other.temperature.or(self.temperature),
111 top_p: other.top_p.or(self.top_p),
112 top_k: other.top_k.or(self.top_k),
113 frequency_penalty: other.frequency_penalty.or(self.frequency_penalty),
114 presence_penalty: other.presence_penalty.or(self.presence_penalty),
115 stop: other.stop.clone().or_else(|| self.stop.clone()),
116 seed: other.seed.or(self.seed),
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `opts` to a JSON string and parses it straight back.
    fn json_roundtrip(opts: &GenerationOptions) -> GenerationOptions {
        let encoded = serde_json::to_string(opts).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// A freshly constructed value has every field unset.
    #[test]
    fn test_default_is_all_none() {
        // Exhaustive destructuring: adding a field without updating this test
        // becomes a compile error instead of a silently unchecked field.
        let GenerationOptions {
            max_tokens,
            temperature,
            top_p,
            top_k,
            frequency_penalty,
            presence_penalty,
            stop,
            seed,
        } = GenerationOptions::new();

        assert!(max_tokens.is_none());
        assert!(temperature.is_none());
        assert!(top_p.is_none());
        assert!(top_k.is_none());
        assert!(frequency_penalty.is_none());
        assert!(presence_penalty.is_none());
        assert!(stop.is_none());
        assert!(seed.is_none());
    }

    /// Each `with_*` builder stores its argument in the matching field.
    #[test]
    fn test_builder_methods() {
        let built = GenerationOptions::new()
            .with_max_tokens(4096)
            .with_temperature(0.7)
            .with_top_p(0.9)
            .with_top_k(40)
            .with_frequency_penalty(0.5)
            .with_presence_penalty(-0.5)
            .with_stop(vec!["END".into()])
            .with_seed(42);

        let expected = GenerationOptions {
            max_tokens: Some(4096),
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.5),
            presence_penalty: Some(-0.5),
            stop: Some(vec!["END".to_string()]),
            seed: Some(42),
        };
        assert_eq!(built, expected);
    }

    /// Values set in the argument of `merge` win; others fall back to the receiver.
    #[test]
    fn test_merge_other_overrides_self() {
        let defaults = GenerationOptions::new()
            .with_max_tokens(1024)
            .with_temperature(0.5);
        let overrides = GenerationOptions::new()
            .with_max_tokens(4096)
            .with_top_p(0.9);

        let combined = defaults.merge(&overrides);

        // Set on both sides: the argument wins.
        assert_eq!(combined.max_tokens, Some(4096));
        // Set only on the receiver: kept.
        assert_eq!(combined.temperature, Some(0.5));
        // Set only on the argument: taken.
        assert_eq!(combined.top_p, Some(0.9));
    }

    /// Merging with an all-`None` value leaves the receiver's settings in place.
    #[test]
    fn test_merge_none_does_not_override() {
        let defaults = GenerationOptions::new()
            .with_max_tokens(2048)
            .with_seed(123);

        let combined = defaults.merge(&GenerationOptions::new());

        assert_eq!(combined.max_tokens, Some(2048));
        assert_eq!(combined.seed, Some(123));
    }

    /// A value survives a JSON encode/decode cycle unchanged.
    #[test]
    fn test_serde_roundtrip() {
        let original = GenerationOptions::new()
            .with_max_tokens(4096)
            .with_temperature(0.8)
            .with_stop(vec!["<|end|>".into(), "STOP".into()]);

        assert_eq!(original, json_roundtrip(&original));
    }

    /// A value with mostly unset fields survives a JSON round trip.
    ///
    /// NOTE(review): despite the name, this test does not inspect the JSON
    /// text for omitted keys; it only checks the decoded value.
    #[test]
    fn test_serde_skips_none_fields() {
        let decoded = json_roundtrip(&GenerationOptions::new().with_max_tokens(100));

        assert_eq!(decoded.max_tokens, Some(100));
        assert_eq!(decoded.temperature, None);
    }
}