oxify_connect_llm/
validation.rs

1//! Request validation utilities
2//!
3//! This module provides validation utilities to catch errors before making API calls,
4//! helping to save costs and improve user experience.
5
6use crate::{EmbeddingRequest, LlmError, LlmRequest, Result};
7
/// Validation rules for LLM requests
///
/// Each field configures one of the checks performed by
/// `ValidationRules::validate_llm_request` and
/// `ValidationRules::validate_embedding_request`. The type is comparable with
/// `==` so that rule sets can be checked against each other.
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationRules {
    /// Maximum prompt length, intended as a character count.
    ///
    /// `None` disables the upper bound. The same limit is applied to every
    /// text of an embedding request.
    pub max_prompt_length: Option<usize>,
    /// Minimum prompt length, intended as a character count
    pub min_prompt_length: usize,
    /// Upper limit for a request's `max_tokens`; `None` disables the limit
    pub max_tokens_limit: Option<u32>,
    /// When `true`, a prompt that is empty or whitespace-only is rejected
    pub require_prompt: bool,
    /// Maximum temperature value (inclusive)
    pub max_temperature: f64,
    /// Minimum temperature value (inclusive)
    pub min_temperature: f64,
    /// Maximum number of images
    pub max_images: usize,
    /// Maximum number of tools
    pub max_tools: usize,
}
28
29impl Default for ValidationRules {
30    fn default() -> Self {
31        Self {
32            max_prompt_length: Some(1_000_000), // 1M chars
33            min_prompt_length: 1,
34            max_tokens_limit: Some(200_000), // 200K tokens
35            require_prompt: true,
36            max_temperature: 2.0,
37            min_temperature: 0.0,
38            max_images: 20,
39            max_tools: 100,
40        }
41    }
42}
43
44impl ValidationRules {
45    /// Create strict validation rules
46    pub fn strict() -> Self {
47        Self {
48            max_prompt_length: Some(100_000), // 100K chars
49            min_prompt_length: 1,
50            max_tokens_limit: Some(100_000), // 100K tokens
51            require_prompt: true,
52            max_temperature: 1.5,
53            min_temperature: 0.0,
54            max_images: 10,
55            max_tools: 50,
56        }
57    }
58
59    /// Create lenient validation rules
60    pub fn lenient() -> Self {
61        Self {
62            max_prompt_length: None,
63            min_prompt_length: 0,
64            max_tokens_limit: None,
65            require_prompt: false,
66            max_temperature: 2.0,
67            min_temperature: 0.0,
68            max_images: 100,
69            max_tools: 200,
70        }
71    }
72
73    /// Validate an LLM request
74    pub fn validate_llm_request(&self, request: &LlmRequest) -> Result<()> {
75        // Validate prompt
76        if self.require_prompt && request.prompt.trim().is_empty() {
77            return Err(LlmError::InvalidRequest(
78                "Prompt cannot be empty".to_string(),
79            ));
80        }
81
82        if request.prompt.len() < self.min_prompt_length {
83            return Err(LlmError::InvalidRequest(format!(
84                "Prompt too short: {} chars (minimum: {})",
85                request.prompt.len(),
86                self.min_prompt_length
87            )));
88        }
89
90        if let Some(max_len) = self.max_prompt_length {
91            if request.prompt.len() > max_len {
92                return Err(LlmError::InvalidRequest(format!(
93                    "Prompt too long: {} chars (maximum: {})",
94                    request.prompt.len(),
95                    max_len
96                )));
97            }
98        }
99
100        // Validate temperature
101        if let Some(temp) = request.temperature {
102            if temp < self.min_temperature || temp > self.max_temperature {
103                return Err(LlmError::InvalidRequest(format!(
104                    "Temperature out of range: {} (must be between {} and {})",
105                    temp, self.min_temperature, self.max_temperature
106                )));
107            }
108        }
109
110        // Validate max_tokens
111        if let Some(max_tokens) = request.max_tokens {
112            if max_tokens == 0 {
113                return Err(LlmError::InvalidRequest(
114                    "max_tokens must be greater than 0".to_string(),
115                ));
116            }
117
118            if let Some(limit) = self.max_tokens_limit {
119                if max_tokens > limit {
120                    return Err(LlmError::InvalidRequest(format!(
121                        "max_tokens too large: {} (maximum: {})",
122                        max_tokens, limit
123                    )));
124                }
125            }
126        }
127
128        // Validate images
129        if request.images.len() > self.max_images {
130            return Err(LlmError::InvalidRequest(format!(
131                "Too many images: {} (maximum: {})",
132                request.images.len(),
133                self.max_images
134            )));
135        }
136
137        // Validate tools
138        if request.tools.len() > self.max_tools {
139            return Err(LlmError::InvalidRequest(format!(
140                "Too many tools: {} (maximum: {})",
141                request.tools.len(),
142                self.max_tools
143            )));
144        }
145
146        // Validate tool definitions
147        for tool in &request.tools {
148            if tool.name.trim().is_empty() {
149                return Err(LlmError::InvalidRequest(
150                    "Tool name cannot be empty".to_string(),
151                ));
152            }
153            if tool.description.trim().is_empty() {
154                return Err(LlmError::InvalidRequest(format!(
155                    "Tool '{}' must have a description",
156                    tool.name
157                )));
158            }
159        }
160
161        Ok(())
162    }
163
164    /// Validate an embedding request
165    pub fn validate_embedding_request(&self, request: &EmbeddingRequest) -> Result<()> {
166        if request.texts.is_empty() {
167            return Err(LlmError::InvalidRequest(
168                "Embedding request must contain at least one text".to_string(),
169            ));
170        }
171
172        for (i, text) in request.texts.iter().enumerate() {
173            if text.trim().is_empty() {
174                return Err(LlmError::InvalidRequest(format!(
175                    "Text at index {} cannot be empty",
176                    i
177                )));
178            }
179
180            if let Some(max_len) = self.max_prompt_length {
181                if text.len() > max_len {
182                    return Err(LlmError::InvalidRequest(format!(
183                        "Text at index {} too long: {} chars (maximum: {})",
184                        i,
185                        text.len(),
186                        max_len
187                    )));
188                }
189            }
190        }
191
192        Ok(())
193    }
194}
195
196/// Validates LLM requests before sending them to providers
197pub struct RequestValidator {
198    rules: ValidationRules,
199}
200
201impl Default for RequestValidator {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207impl RequestValidator {
208    /// Create a new validator with default rules
209    pub fn new() -> Self {
210        Self {
211            rules: ValidationRules::default(),
212        }
213    }
214
215    /// Create a validator with custom rules
216    pub fn with_rules(rules: ValidationRules) -> Self {
217        Self { rules }
218    }
219
220    /// Validate an LLM request
221    pub fn validate(&self, request: &LlmRequest) -> Result<()> {
222        self.rules.validate_llm_request(request)
223    }
224
225    /// Validate an embedding request
226    pub fn validate_embedding(&self, request: &EmbeddingRequest) -> Result<()> {
227        self.rules.validate_embedding_request(request)
228    }
229
230    /// Get a reference to the validation rules
231    pub fn rules(&self) -> &ValidationRules {
232        &self.rules
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tool;

    /// Builds a request that carries only `prompt`; every optional field is
    /// unset and the tool and image lists are empty.
    fn request_with_prompt(prompt: &str) -> LlmRequest {
        LlmRequest {
            prompt: prompt.to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        }
    }

    /// Builds a tool with the given name and description and an empty JSON
    /// object as its parameters.
    fn tool(name: &str, description: &str) -> Tool {
        Tool {
            name: name.to_string(),
            description: description.to_string(),
            parameters: serde_json::json!({}),
        }
    }

    #[test]
    fn test_validate_valid_request() {
        let request = LlmRequest {
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..request_with_prompt("Hello, world!")
        };

        assert!(RequestValidator::new().validate(&request).is_ok());
    }

    #[test]
    fn test_validate_empty_prompt() {
        let outcome = RequestValidator::new().validate(&request_with_prompt(""));

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), LlmError::InvalidRequest(_)));
    }

    #[test]
    fn test_validate_temperature_out_of_range() {
        let request = LlmRequest {
            temperature: Some(3.0),
            ..request_with_prompt("Test")
        };

        let outcome = RequestValidator::new().validate(&request);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_zero_max_tokens() {
        let request = LlmRequest {
            max_tokens: Some(0),
            ..request_with_prompt("Test")
        };

        let outcome = RequestValidator::new().validate(&request);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_too_many_tools() {
        // Allow only two tools, then submit three.
        let validator = RequestValidator::with_rules(ValidationRules {
            max_tools: 2,
            ..ValidationRules::default()
        });

        let request = LlmRequest {
            tools: vec![
                tool("tool1", "desc1"),
                tool("tool2", "desc2"),
                tool("tool3", "desc3"),
            ],
            ..request_with_prompt("Test")
        };

        let outcome = validator.validate(&request);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_tool_without_name() {
        let request = LlmRequest {
            tools: vec![tool("", "description")],
            ..request_with_prompt("Test")
        };

        let outcome = RequestValidator::new().validate(&request);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_embedding_request() {
        let request = EmbeddingRequest {
            texts: vec!["Hello".to_string(), "World".to_string()],
            model: None,
        };

        assert!(RequestValidator::new().validate_embedding(&request).is_ok());
    }

    #[test]
    fn test_validate_empty_embedding_request() {
        let request = EmbeddingRequest {
            texts: vec![],
            model: None,
        };

        let outcome = RequestValidator::new().validate_embedding(&request);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validation_rules_strict() {
        let strict = ValidationRules::strict();
        assert!(strict.max_prompt_length.is_some());
        assert_eq!(strict.max_prompt_length.unwrap(), 100_000);
    }

    #[test]
    fn test_validation_rules_lenient() {
        let lenient = ValidationRules::lenient();
        assert!(lenient.max_prompt_length.is_none());
        assert_eq!(lenient.min_prompt_length, 0);
    }
}