openai_mock/validators/optional_fields.rs

1use serde::{Serialize, Deserialize};
2
/// Validates the optional `temperature` sampling parameter.
///
/// `None` means the field was omitted and is always accepted. A present value must lie in
/// the inclusive range `0.0..=2.0`.
///
/// # Errors
///
/// Returns a human-readable message when the value is outside `0.0..=2.0` or is NaN.
pub fn validate_temperature(temperature: Option<f32>) -> Result<(), String> {
    if let Some(temp) = temperature {
        // `RangeInclusive::contains` is false for NaN, whereas the `temp < 0.0 || temp > 2.0`
        // form evaluates to false for NaN and would silently accept it.
        if !(0.0..=2.0).contains(&temp) {
            return Err(format!("Temperature must be between 0.0 and 2.0, got {}", temp));
        }
    }
    Ok(())
}
11
/// Validates the optional `top_p` (nucleus sampling) parameter.
///
/// `None` means the field was omitted and is always accepted. A present value must lie in
/// the inclusive range `0.0..=1.0`.
///
/// # Errors
///
/// Returns a human-readable message when the value is outside `0.0..=1.0` or is NaN.
pub fn validate_top_p(top_p: Option<f32>) -> Result<(), String> {
    if let Some(p) = top_p {
        // `contains` rejects NaN; a pair of `<`/`>` comparisons would let NaN through.
        if !(0.0..=1.0).contains(&p) {
            return Err(format!("Top_p must be between 0.0 and 1.0, got {}", p));
        }
    }
    Ok(())
}
20
/// Checks the optional `n` parameter (how many choices to generate).
///
/// An absent value is accepted; a present value must be strictly greater than zero.
///
/// # Errors
///
/// Returns a message naming the offending value when it is zero or negative.
pub fn validate_n(n: Option<i32>) -> Result<(), String> {
    match n {
        Some(count) if count <= 0 => Err(format!("n must be a positive integer, got {}", count)),
        _ => Ok(()),
    }
}
29
/// Validates the optional `max_tokens` parameter.
///
/// `None` is accepted. Because the type is unsigned, the only invalid present value is `0`.
///
/// # Errors
///
/// Returns a human-readable message when `max_tokens` is `Some(0)`.
pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), String> {
    if let Some(value) = max_tokens {
        // For an unsigned integer `value <= 0` can only mean `value == 0`; writing it as
        // equality avoids clippy's `absurd_extreme_comparisons` lint.
        if value == 0 {
            return Err(format!("max_tokens must be a positive integer, got {}", value));
        }
    }
    Ok(())
}
38
/// Validates the optional `presence_penalty` parameter.
///
/// `None` is accepted. A present value must lie in the inclusive range `-2.0..=2.0`.
///
/// # Errors
///
/// Returns a human-readable message when the value is outside `-2.0..=2.0` or is NaN.
pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), String> {
    if let Some(value) = presence_penalty {
        // `contains` rejects NaN; a pair of `<`/`>` comparisons would let NaN through.
        if !(-2.0..=2.0).contains(&value) {
            return Err(format!("Presence penalty must be between -2.0 and 2.0, got {}", value));
        }
    }
    Ok(())
}
47
/// Validates the optional `frequency_penalty` parameter.
///
/// `None` is accepted. A present value must lie in the inclusive range `-2.0..=2.0`.
///
/// # Errors
///
/// Returns a human-readable message when the value is outside `-2.0..=2.0` or is NaN.
pub fn validate_frequency_penalty(frequency_penalty: Option<f32>) -> Result<(), String> {
    if let Some(value) = frequency_penalty {
        // `contains` rejects NaN; a pair of `<`/`>` comparisons would let NaN through.
        if !(-2.0..=2.0).contains(&value) {
            return Err(format!("Frequency penalty must be between -2.0 and 2.0, got {}", value));
        }
    }
    Ok(())
}
56
/// Checks the optional `best_of` parameter, including its relationship to `n`.
///
/// When `best_of` is absent nothing is checked (even if `n` is present). When it is
/// present it must be strictly positive, and, if `n` is also present, `best_of` must be at
/// least `n`. The positivity check runs first.
///
/// # Errors
///
/// Returns a message describing whichever of the two constraints failed first.
pub fn validate_best_of(best_of: Option<i32>, n: Option<i32>) -> Result<(), String> {
    let candidates = match best_of {
        None => return Ok(()),
        Some(value) => value,
    };

    if candidates <= 0 {
        return Err(format!("best_of must be a positive integer, got {}", candidates));
    }

    match n {
        Some(requested) if candidates < requested => Err(format!(
            "best_of must be greater than or equal to n, got best_of={} and n={}",
            candidates, requested
        )),
        _ => Ok(()),
    }
}
74
/// Validates the optional `logprobs` parameter.
///
/// The only constraint this validator enforces is non-negativity, and the `u32` parameter
/// type already guarantees that. Every input, including `None`, is therefore accepted and
/// this function always returns `Ok(())`.
///
/// NOTE(review): the legacy Completions API documents a small upper bound for `logprobs`
/// (5). It is not enforced here, and existing tests expect `Some(100)` to be accepted —
/// TODO confirm whether the mock should reject larger values.
///
/// # Errors
///
/// Never returns an error; the `Result` return type is kept for signature parity with the
/// other validators.
pub fn validate_logprobs(logprobs: Option<u32>) -> Result<(), String> {
    // The former `value < 0` check was always false for `u32` and only compiled cleanly
    // under `#[allow(unused_comparisons)]`; the type system enforces it instead.
    let _ = logprobs;
    Ok(())
}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub enum StopSequence {
87    Single(String),
88    Multiple(Vec<String>),
89}
90
91pub fn validate_stop(stop: Option<StopSequence>) -> Result<(), String> {
92    if let Some(stop_value) = stop {
93        match stop_value {
94            StopSequence::Single(s) => {
95                if s.is_empty() {
96                    return Err("Stop sequence cannot be empty".to_string());
97                }
98            }
99            StopSequence::Multiple(sequences) => {
100                if sequences.is_empty() {
101                    return Err("Stop sequences array cannot be empty".to_string());
102                }
103                for (i, sequence) in sequences.iter().enumerate() {
104                    if sequence.is_empty() {
105                        return Err(format!("Stop sequence at index {} cannot be empty", i));
106                    }
107                }
108            }
109        }
110    }
111    Ok(())
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    // Each test lists the inputs that must be accepted and those that must be rejected,
    // then checks them in a loop so a failing case reports which input it was.

    #[test]
    fn test_validate_temperature() {
        for accepted in [None, Some(0.0), Some(1.0), Some(2.0)] {
            assert!(validate_temperature(accepted).is_ok(), "expected {:?} to pass", accepted);
        }
        for rejected in [Some(-0.1), Some(2.1)] {
            assert!(validate_temperature(rejected).is_err(), "expected {:?} to fail", rejected);
        }
    }

    #[test]
    fn test_validate_top_p() {
        for accepted in [None, Some(0.0), Some(0.5), Some(1.0)] {
            assert!(validate_top_p(accepted).is_ok(), "expected {:?} to pass", accepted);
        }
        for rejected in [Some(-0.1), Some(1.1)] {
            assert!(validate_top_p(rejected).is_err(), "expected {:?} to fail", rejected);
        }
    }

    #[test]
    fn test_validate_n() {
        for accepted in [None, Some(1), Some(100)] {
            assert!(validate_n(accepted).is_ok(), "expected {:?} to pass", accepted);
        }
        for rejected in [Some(0), Some(-1)] {
            assert!(validate_n(rejected).is_err(), "expected {:?} to fail", rejected);
        }
    }

    #[test]
    fn test_validate_max_tokens() {
        for accepted in [None, Some(1), Some(100)] {
            assert!(validate_max_tokens(accepted).is_ok(), "expected {:?} to pass", accepted);
        }
        assert!(validate_max_tokens(Some(0)).is_err());
    }

    #[test]
    fn test_validate_presence_penalty() {
        for accepted in [None, Some(-2.0), Some(0.0), Some(2.0)] {
            assert!(
                validate_presence_penalty(accepted).is_ok(),
                "expected {:?} to pass",
                accepted
            );
        }
        for rejected in [Some(-2.1), Some(2.1)] {
            assert!(
                validate_presence_penalty(rejected).is_err(),
                "expected {:?} to fail",
                rejected
            );
        }
    }

    #[test]
    fn test_validate_frequency_penalty() {
        for accepted in [None, Some(-2.0), Some(0.0), Some(2.0)] {
            assert!(
                validate_frequency_penalty(accepted).is_ok(),
                "expected {:?} to pass",
                accepted
            );
        }
        for rejected in [Some(-2.1), Some(2.1)] {
            assert!(
                validate_frequency_penalty(rejected).is_err(),
                "expected {:?} to fail",
                rejected
            );
        }
    }

    #[test]
    fn test_validate_best_of() {
        // (best_of, n) pairs covering the positivity check and the best_of >= n relationship.
        for (best_of, n) in [(None, None), (Some(1), None), (Some(5), Some(3)), (Some(5), Some(5))] {
            assert!(
                validate_best_of(best_of, n).is_ok(),
                "expected best_of={:?}, n={:?} to pass",
                best_of,
                n
            );
        }
        for (best_of, n) in [(Some(0), None), (Some(3), Some(5))] {
            assert!(
                validate_best_of(best_of, n).is_err(),
                "expected best_of={:?}, n={:?} to fail",
                best_of,
                n
            );
        }
    }

    #[test]
    fn test_validate_logprobs() {
        for accepted in [None, Some(0), Some(1), Some(100)] {
            assert!(validate_logprobs(accepted).is_ok(), "expected {:?} to pass", accepted);
        }
    }

    #[test]
    fn test_validate_stop() {
        // Builds an owned list of stop strings from string literals.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| item.to_string()).collect()
        }

        assert!(validate_stop(None).is_ok());

        // Single-string form.
        assert!(validate_stop(Some(StopSequence::Single("stop".to_string()))).is_ok());
        assert!(validate_stop(Some(StopSequence::Single(String::new()))).is_err());

        // Array form.
        assert!(validate_stop(Some(StopSequence::Multiple(owned(&["stop1", "stop2"])))).is_ok());
        assert!(validate_stop(Some(StopSequence::Multiple(Vec::new()))).is_err());
        assert!(validate_stop(Some(StopSequence::Multiple(owned(&["valid", ""])))).is_err());
    }
}