// openai_mock/validators/optional_fields.rs
use serde::{Serialize, Deserialize};
2
/// Validates the optional `temperature` sampling parameter.
///
/// `None` (field omitted) is always accepted. A present value must lie in the inclusive
/// range `0.0..=2.0`.
///
/// # Errors
///
/// Returns a human-readable message when the value is outside `0.0..=2.0` or is NaN.
pub fn validate_temperature(temperature: Option<f32>) -> Result<(), String> {
    if let Some(temp) = temperature {
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected too; the
        // `temp < 0.0 || temp > 2.0` form would silently accept it.
        if !(0.0..=2.0).contains(&temp) {
            return Err(format!("Temperature must be between 0.0 and 2.0, got {}", temp));
        }
    }
    Ok(())
}
11
/// Validates the optional `top_p` (nucleus sampling) parameter.
///
/// `None` is always accepted. A present value must lie in the inclusive range `0.0..=1.0`.
///
/// # Errors
///
/// Returns a human-readable message when the value is outside `0.0..=1.0` or is NaN.
pub fn validate_top_p(top_p: Option<f32>) -> Result<(), String> {
    if let Some(p) = top_p {
        // `contains` is false for NaN, so NaN is rejected along with out-of-range values.
        if !(0.0..=1.0).contains(&p) {
            return Err(format!("Top_p must be between 0.0 and 1.0, got {}", p));
        }
    }
    Ok(())
}
20
/// Validates the optional `n` parameter (number of choices to generate).
///
/// Omitting `n` is fine; a supplied value must be strictly positive.
///
/// # Errors
///
/// Returns a message when `n` is zero or negative.
pub fn validate_n(n: Option<i32>) -> Result<(), String> {
    match n {
        Some(count) if count <= 0 => Err(format!("n must be a positive integer, got {}", count)),
        _ => Ok(()),
    }
}
29
/// Validates the optional `max_tokens` parameter.
///
/// `None` is accepted; a supplied value must be at least 1.
///
/// # Errors
///
/// Returns a message when `max_tokens` is `Some(0)`.
pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), String> {
    if let Some(value) = max_tokens {
        // `u32` cannot be negative, so zero is the only non-positive value. Writing this as
        // `value <= 0` trips clippy's deny-by-default `absurd_extreme_comparisons`.
        if value == 0 {
            return Err(format!("max_tokens must be a positive integer, got {}", value));
        }
    }
    Ok(())
}
38
/// Validates the optional `presence_penalty` parameter.
///
/// `None` is accepted; a supplied value must lie in the inclusive range `-2.0..=2.0`.
///
/// # Errors
///
/// Returns a message when the value is outside `-2.0..=2.0` or is NaN.
pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), String> {
    if let Some(value) = presence_penalty {
        // `contains` is false for NaN, so NaN is rejected along with out-of-range values.
        if !(-2.0..=2.0).contains(&value) {
            return Err(format!("Presence penalty must be between -2.0 and 2.0, got {}", value));
        }
    }
    Ok(())
}
47
/// Validates the optional `frequency_penalty` parameter.
///
/// `None` is accepted; a supplied value must lie in the inclusive range `-2.0..=2.0`.
///
/// # Errors
///
/// Returns a message when the value is outside `-2.0..=2.0` or is NaN.
pub fn validate_frequency_penalty(frequency_penalty: Option<f32>) -> Result<(), String> {
    if let Some(value) = frequency_penalty {
        // `contains` is false for NaN, so NaN is rejected along with out-of-range values.
        if !(-2.0..=2.0).contains(&value) {
            return Err(format!("Frequency penalty must be between -2.0 and 2.0, got {}", value));
        }
    }
    Ok(())
}
56
/// Validates the optional `best_of` parameter, cross-checked against `n` when both are set.
///
/// Nothing is checked when `best_of` is `None`. Otherwise `best_of` must be strictly
/// positive, and if `n` is also supplied, `best_of` must be at least `n`. `n` on its own
/// is not validated here (that is [`validate_n`]'s job).
///
/// # Errors
///
/// Returns a message for a non-positive `best_of`, or for `best_of < n`.
pub fn validate_best_of(best_of: Option<i32>, n: Option<i32>) -> Result<(), String> {
    let candidates = match best_of {
        None => return Ok(()),
        Some(count) => count,
    };

    if candidates <= 0 {
        return Err(format!("best_of must be a positive integer, got {}", candidates));
    }

    match n {
        Some(returned) if candidates < returned => Err(format!(
            "best_of must be greater than or equal to n, got best_of={} and n={}",
            candidates, returned
        )),
        _ => Ok(()),
    }
}
74
/// Validates the optional `logprobs` parameter.
///
/// The constraint this validator expresses is "non-negative integer", and the parameter
/// type `u32` already guarantees it, so every input (including `None`) is accepted.
///
/// # Errors
///
/// Never returns `Err` with the current signature.
pub fn validate_logprobs(logprobs: Option<u32>) -> Result<(), String> {
    // A runtime `value < 0` test on a `u32` is always false: rustc's `unused_comparisons`
    // and clippy's deny-by-default `absurd_extreme_comparisons` both flag it. The type
    // system enforces non-negativity, so there is nothing left to check.
    let _ = logprobs;
    Ok(())
}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub enum StopSequence {
87 Single(String),
88 Multiple(Vec<String>),
89}
90
91pub fn validate_stop(stop: Option<StopSequence>) -> Result<(), String> {
92 if let Some(stop_value) = stop {
93 match stop_value {
94 StopSequence::Single(s) => {
95 if s.is_empty() {
96 return Err("Stop sequence cannot be empty".to_string());
97 }
98 }
99 StopSequence::Multiple(sequences) => {
100 if sequences.is_empty() {
101 return Err("Stop sequences array cannot be empty".to_string());
102 }
103 for (i, sequence) in sequences.iter().enumerate() {
104 if sequence.is_empty() {
105 return Err(format!("Stop sequence at index {} cannot be empty", i));
106 }
107 }
108 }
109 }
110 }
111 Ok(())
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Some(StopSequence::Single)` from a string slice.
    fn single(text: &str) -> Option<StopSequence> {
        Some(StopSequence::Single(text.to_string()))
    }

    /// Builds a `Some(StopSequence::Multiple)` from a slice of string slices.
    fn many(items: &[&str]) -> Option<StopSequence> {
        Some(StopSequence::Multiple(
            items.iter().map(|item| item.to_string()).collect(),
        ))
    }

    #[test]
    fn test_validate_temperature() {
        assert!(validate_temperature(None).is_ok());
        for &accepted in &[0.0, 1.0, 2.0] {
            assert!(validate_temperature(Some(accepted)).is_ok(), "{} should pass", accepted);
        }
        for &rejected in &[-0.1, 2.1] {
            assert!(validate_temperature(Some(rejected)).is_err(), "{} should fail", rejected);
        }
    }

    #[test]
    fn test_validate_top_p() {
        assert!(validate_top_p(None).is_ok());
        for &accepted in &[0.0, 0.5, 1.0] {
            assert!(validate_top_p(Some(accepted)).is_ok(), "{} should pass", accepted);
        }
        for &rejected in &[-0.1, 1.1] {
            assert!(validate_top_p(Some(rejected)).is_err(), "{} should fail", rejected);
        }
    }

    #[test]
    fn test_validate_n() {
        assert!(validate_n(None).is_ok());
        for &accepted in &[1, 100] {
            assert!(validate_n(Some(accepted)).is_ok(), "{} should pass", accepted);
        }
        for &rejected in &[0, -1] {
            assert!(validate_n(Some(rejected)).is_err(), "{} should fail", rejected);
        }
    }

    #[test]
    fn test_validate_max_tokens() {
        assert!(validate_max_tokens(None).is_ok());
        for &accepted in &[1, 100] {
            assert!(validate_max_tokens(Some(accepted)).is_ok(), "{} should pass", accepted);
        }
        assert!(validate_max_tokens(Some(0)).is_err());
    }

    #[test]
    fn test_validate_presence_penalty() {
        assert!(validate_presence_penalty(None).is_ok());
        for &accepted in &[-2.0, 0.0, 2.0] {
            assert!(
                validate_presence_penalty(Some(accepted)).is_ok(),
                "{} should pass",
                accepted
            );
        }
        for &rejected in &[-2.1, 2.1] {
            assert!(
                validate_presence_penalty(Some(rejected)).is_err(),
                "{} should fail",
                rejected
            );
        }
    }

    #[test]
    fn test_validate_frequency_penalty() {
        assert!(validate_frequency_penalty(None).is_ok());
        for &accepted in &[-2.0, 0.0, 2.0] {
            assert!(
                validate_frequency_penalty(Some(accepted)).is_ok(),
                "{} should pass",
                accepted
            );
        }
        for &rejected in &[-2.1, 2.1] {
            assert!(
                validate_frequency_penalty(Some(rejected)).is_err(),
                "{} should fail",
                rejected
            );
        }
    }

    #[test]
    fn test_validate_best_of() {
        // best_of on its own.
        assert!(validate_best_of(None, None).is_ok());
        assert!(validate_best_of(Some(1), None).is_ok());
        assert!(validate_best_of(Some(0), None).is_err());

        // best_of checked against n.
        assert!(validate_best_of(Some(5), Some(3)).is_ok());
        assert!(validate_best_of(Some(5), Some(5)).is_ok());
        assert!(validate_best_of(Some(3), Some(5)).is_err());
    }

    #[test]
    fn test_validate_logprobs() {
        assert!(validate_logprobs(None).is_ok());
        for &accepted in &[0, 1, 100] {
            assert!(validate_logprobs(Some(accepted)).is_ok(), "{} should pass", accepted);
        }
    }

    #[test]
    fn test_validate_stop() {
        assert!(validate_stop(None).is_ok());

        // Single stop string.
        assert!(validate_stop(single("stop")).is_ok());
        assert!(validate_stop(single("")).is_err());

        // List of stop strings.
        assert!(validate_stop(many(&["stop1", "stop2"])).is_ok());
        assert!(validate_stop(many(&[])).is_err());
        assert!(validate_stop(many(&["valid", ""])).is_err());
    }
}