1use crate::{EmbeddingRequest, LlmError, LlmRequest, Result};
7
/// Limits applied by `RequestValidator` to LLM and embedding requests.
///
/// Start from `ValidationRules::default()`, `ValidationRules::strict()` or
/// `ValidationRules::lenient()` and override individual fields with
/// struct-update syntax, e.g. `ValidationRules { max_tools: 2, ..ValidationRules::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationRules {
    /// Upper bound on the prompt length; the same bound is applied to every
    /// embedding text. `None` disables the check.
    pub max_prompt_length: Option<usize>,
    /// Lower bound on the prompt length (not applied to embedding texts).
    pub min_prompt_length: usize,
    /// Upper bound on a request's `max_tokens`; `None` disables the check.
    /// A `max_tokens` of 0 is rejected regardless of this setting.
    pub max_tokens_limit: Option<u32>,
    /// When `true`, a prompt that is empty or whitespace-only is rejected.
    pub require_prompt: bool,
    /// Inclusive upper bound on the request temperature.
    pub max_temperature: f64,
    /// Inclusive lower bound on the request temperature.
    pub min_temperature: f64,
    /// Maximum number of images a request may carry.
    pub max_images: usize,
    /// Maximum number of tools a request may carry.
    pub max_tools: usize,
}
28
29impl Default for ValidationRules {
30 fn default() -> Self {
31 Self {
32 max_prompt_length: Some(1_000_000), min_prompt_length: 1,
34 max_tokens_limit: Some(200_000), require_prompt: true,
36 max_temperature: 2.0,
37 min_temperature: 0.0,
38 max_images: 20,
39 max_tools: 100,
40 }
41 }
42}
43
44impl ValidationRules {
45 pub fn strict() -> Self {
47 Self {
48 max_prompt_length: Some(100_000), min_prompt_length: 1,
50 max_tokens_limit: Some(100_000), require_prompt: true,
52 max_temperature: 1.5,
53 min_temperature: 0.0,
54 max_images: 10,
55 max_tools: 50,
56 }
57 }
58
59 pub fn lenient() -> Self {
61 Self {
62 max_prompt_length: None,
63 min_prompt_length: 0,
64 max_tokens_limit: None,
65 require_prompt: false,
66 max_temperature: 2.0,
67 min_temperature: 0.0,
68 max_images: 100,
69 max_tools: 200,
70 }
71 }
72
73 pub fn validate_llm_request(&self, request: &LlmRequest) -> Result<()> {
75 if self.require_prompt && request.prompt.trim().is_empty() {
77 return Err(LlmError::InvalidRequest(
78 "Prompt cannot be empty".to_string(),
79 ));
80 }
81
82 if request.prompt.len() < self.min_prompt_length {
83 return Err(LlmError::InvalidRequest(format!(
84 "Prompt too short: {} chars (minimum: {})",
85 request.prompt.len(),
86 self.min_prompt_length
87 )));
88 }
89
90 if let Some(max_len) = self.max_prompt_length {
91 if request.prompt.len() > max_len {
92 return Err(LlmError::InvalidRequest(format!(
93 "Prompt too long: {} chars (maximum: {})",
94 request.prompt.len(),
95 max_len
96 )));
97 }
98 }
99
100 if let Some(temp) = request.temperature {
102 if temp < self.min_temperature || temp > self.max_temperature {
103 return Err(LlmError::InvalidRequest(format!(
104 "Temperature out of range: {} (must be between {} and {})",
105 temp, self.min_temperature, self.max_temperature
106 )));
107 }
108 }
109
110 if let Some(max_tokens) = request.max_tokens {
112 if max_tokens == 0 {
113 return Err(LlmError::InvalidRequest(
114 "max_tokens must be greater than 0".to_string(),
115 ));
116 }
117
118 if let Some(limit) = self.max_tokens_limit {
119 if max_tokens > limit {
120 return Err(LlmError::InvalidRequest(format!(
121 "max_tokens too large: {} (maximum: {})",
122 max_tokens, limit
123 )));
124 }
125 }
126 }
127
128 if request.images.len() > self.max_images {
130 return Err(LlmError::InvalidRequest(format!(
131 "Too many images: {} (maximum: {})",
132 request.images.len(),
133 self.max_images
134 )));
135 }
136
137 if request.tools.len() > self.max_tools {
139 return Err(LlmError::InvalidRequest(format!(
140 "Too many tools: {} (maximum: {})",
141 request.tools.len(),
142 self.max_tools
143 )));
144 }
145
146 for tool in &request.tools {
148 if tool.name.trim().is_empty() {
149 return Err(LlmError::InvalidRequest(
150 "Tool name cannot be empty".to_string(),
151 ));
152 }
153 if tool.description.trim().is_empty() {
154 return Err(LlmError::InvalidRequest(format!(
155 "Tool '{}' must have a description",
156 tool.name
157 )));
158 }
159 }
160
161 Ok(())
162 }
163
164 pub fn validate_embedding_request(&self, request: &EmbeddingRequest) -> Result<()> {
166 if request.texts.is_empty() {
167 return Err(LlmError::InvalidRequest(
168 "Embedding request must contain at least one text".to_string(),
169 ));
170 }
171
172 for (i, text) in request.texts.iter().enumerate() {
173 if text.trim().is_empty() {
174 return Err(LlmError::InvalidRequest(format!(
175 "Text at index {} cannot be empty",
176 i
177 )));
178 }
179
180 if let Some(max_len) = self.max_prompt_length {
181 if text.len() > max_len {
182 return Err(LlmError::InvalidRequest(format!(
183 "Text at index {} too long: {} chars (maximum: {})",
184 i,
185 text.len(),
186 max_len
187 )));
188 }
189 }
190 }
191
192 Ok(())
193 }
194}
195
196pub struct RequestValidator {
198 rules: ValidationRules,
199}
200
201impl Default for RequestValidator {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
207impl RequestValidator {
208 pub fn new() -> Self {
210 Self {
211 rules: ValidationRules::default(),
212 }
213 }
214
215 pub fn with_rules(rules: ValidationRules) -> Self {
217 Self { rules }
218 }
219
220 pub fn validate(&self, request: &LlmRequest) -> Result<()> {
222 self.rules.validate_llm_request(request)
223 }
224
225 pub fn validate_embedding(&self, request: &EmbeddingRequest) -> Result<()> {
227 self.rules.validate_embedding_request(request)
228 }
229
230 pub fn rules(&self) -> &ValidationRules {
232 &self.rules
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tool;

    /// A request with the given prompt and every optional setting unset;
    /// individual tests override fields with struct-update syntax.
    fn base_request(prompt: &str) -> LlmRequest {
        LlmRequest {
            prompt: prompt.to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        }
    }

    /// A tool with an empty JSON-object parameter schema.
    fn make_tool(name: &str, description: &str) -> Tool {
        Tool {
            name: name.to_string(),
            description: description.to_string(),
            parameters: serde_json::json!({}),
        }
    }

    #[test]
    fn test_validate_valid_request() {
        let validator = RequestValidator::new();
        let req = LlmRequest {
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..base_request("Hello, world!")
        };

        assert!(validator.validate(&req).is_ok());
    }

    #[test]
    fn test_validate_empty_prompt() {
        let validator = RequestValidator::new();
        let outcome = validator.validate(&base_request(""));

        assert!(matches!(outcome, Err(LlmError::InvalidRequest(_))));
    }

    #[test]
    fn test_validate_temperature_out_of_range() {
        let validator = RequestValidator::new();
        let req = LlmRequest {
            temperature: Some(3.0),
            ..base_request("Test")
        };

        assert!(validator.validate(&req).is_err());
    }

    #[test]
    fn test_validate_zero_max_tokens() {
        let validator = RequestValidator::new();
        let req = LlmRequest {
            max_tokens: Some(0),
            ..base_request("Test")
        };

        assert!(validator.validate(&req).is_err());
    }

    #[test]
    fn test_validate_too_many_tools() {
        let validator = RequestValidator::with_rules(ValidationRules {
            max_tools: 2,
            ..ValidationRules::default()
        });
        let req = LlmRequest {
            tools: (1..=3)
                .map(|n| make_tool(&format!("tool{}", n), &format!("desc{}", n)))
                .collect(),
            ..base_request("Test")
        };

        assert!(validator.validate(&req).is_err());
    }

    #[test]
    fn test_validate_tool_without_name() {
        let validator = RequestValidator::new();
        let req = LlmRequest {
            tools: vec![make_tool("", "description")],
            ..base_request("Test")
        };

        assert!(validator.validate(&req).is_err());
    }

    #[test]
    fn test_validate_embedding_request() {
        let validator = RequestValidator::new();
        let req = EmbeddingRequest {
            texts: vec!["Hello".to_string(), "World".to_string()],
            model: None,
        };

        assert!(validator.validate_embedding(&req).is_ok());
    }

    #[test]
    fn test_validate_empty_embedding_request() {
        let validator = RequestValidator::new();
        let req = EmbeddingRequest {
            texts: vec![],
            model: None,
        };

        assert!(validator.validate_embedding(&req).is_err());
    }

    #[test]
    fn test_validation_rules_strict() {
        let strict = ValidationRules::strict();
        assert_eq!(strict.max_prompt_length, Some(100_000));
    }

    #[test]
    fn test_validation_rules_lenient() {
        let lenient = ValidationRules::lenient();
        assert_eq!(lenient.max_prompt_length, None);
        assert_eq!(lenient.min_prompt_length, 0);
    }
}