Skip to main content

openrouter_api/utils/validation/
completion.rs

1//! Validation utilities for text completion requests
2
3use super::common::*;
4use crate::error::{Error, Result};
5use crate::types::completion::CompletionRequest;
6
/// Maximum allowed prompt length for completions.
///
/// Passed as the upper bound to `validate_string_length` when checking
/// `CompletionRequest::prompt`.
// NOTE(review): whether this limit counts bytes or Unicode characters depends
// on how `validate_string_length` measures its input - confirm in `common`.
const MAX_PROMPT_LENGTH: usize = 1_000_000;
9
10/// Validates a completion request for common errors
11pub fn validate_completion_request(request: &CompletionRequest) -> Result<()> {
12    // Validate model
13    validate_model_id(&request.model)?;
14
15    // Validate prompt
16    validate_non_empty_string(&request.prompt, "prompt")?;
17    validate_string_length(&request.prompt, "prompt", 1, MAX_PROMPT_LENGTH)?;
18
19    // Validate extra parameters if present
20    if let serde_json::Value::Object(params) = &request.extra_params {
21        validate_extra_params(params)?;
22    }
23
24    Ok(())
25}
26
27/// Validates extra parameters in completion requests
28fn validate_extra_params(params: &serde_json::Map<String, serde_json::Value>) -> Result<()> {
29    // Temperature: [0.0, 2.0]
30    validate_optional_numeric_param(params, "temperature", 0.0, 2.0)?;
31
32    // Top P: (0.0, 1.0]
33    if let Some(value) = params.get("top_p") {
34        if let Some(top_p) = value.as_f64() {
35            if top_p <= 0.0 || top_p > 1.0 {
36                return Err(Error::ConfigError(format!(
37                    "Top P must be between 0.0 (exclusive) and 1.0 (inclusive), got {}",
38                    top_p
39                )));
40            }
41        } else {
42            return Err(Error::ConfigError(
43                "Parameter 'top_p' must be a number".to_string(),
44            ));
45        }
46    }
47
48    // Max tokens: [1, 8192] or 0 for unlimited
49    if let Some(value) = params.get("max_tokens") {
50        if let Some(tokens) = value.as_u64() {
51            if tokens != 0 && !(1..=8192).contains(&tokens) {
52                return Err(Error::ConfigError(format!(
53                    "Max tokens must be 0 (unlimited) or between 1 and 8192, got {}",
54                    tokens
55                )));
56            }
57        } else {
58            return Err(Error::ConfigError(
59                "Parameter 'max_tokens' must be an integer".to_string(),
60            ));
61        }
62    }
63
64    // Frequency Penalty: [-2.0, 2.0]
65    validate_optional_numeric_param(params, "frequency_penalty", -2.0, 2.0)?;
66
67    // Presence Penalty: [-2.0, 2.0]
68    validate_optional_numeric_param(params, "presence_penalty", -2.0, 2.0)?;
69
70    // Validate stop sequences if present
71    if let Some(value) = params.get("stop") {
72        validate_stop_sequence(value)?;
73    }
74
75    // Validate logit bias if present
76    if let Some(value) = params.get("logit_bias") {
77        validate_logit_bias(value)?;
78    }
79
80    // Validate echo parameter if present
81    if let Some(value) = params.get("echo") {
82        if !value.is_boolean() {
83            return Err(Error::ConfigError(
84                "Parameter 'echo' must be a boolean".to_string(),
85            ));
86        }
87    }
88
89    // Validate suffix parameter if present
90    if let Some(value) = params.get("suffix") {
91        if let Some(suffix) = value.as_str() {
92            validate_string_length(suffix, "suffix", 0, 1000)?;
93        } else if !value.is_null() {
94            return Err(Error::ConfigError(
95                "Parameter 'suffix' must be a string or null".to_string(),
96            ));
97        }
98    }
99
100    // Validate best_of parameter if present
101    if let Some(value) = params.get("best_of") {
102        if let Some(best_of) = value.as_u64() {
103            validate_numeric_range(best_of, "best_of", 1, 20)?;
104        } else {
105            return Err(Error::ConfigError(
106                "Parameter 'best_of' must be an integer".to_string(),
107            ));
108        }
109    }
110
111    // Validate logprobs parameter if present
112    if let Some(value) = params.get("logprobs") {
113        if let Some(logprobs) = value.as_u64() {
114            validate_numeric_range(logprobs, "logprobs", 0, 5)?;
115        } else {
116            return Err(Error::ConfigError(
117                "Parameter 'logprobs' must be an integer".to_string(),
118            ));
119        }
120    }
121
122    Ok(())
123}
124
125/// Validates stop sequence parameter
126fn validate_stop_sequence(value: &serde_json::Value) -> Result<()> {
127    match value {
128        serde_json::Value::String(stop) => {
129            // Single stop sequence
130            validate_string_length(stop, "stop", 1, 100)?;
131        }
132        serde_json::Value::Array(stops) => {
133            // Multiple stop sequences
134            validate_collection_size(stops, "stop", 1, 4)?;
135
136            for (index, stop_val) in stops.iter().enumerate() {
137                if let Some(stop_str) = stop_val.as_str() {
138                    validate_string_length(stop_str, &format!("stop[{}]", index), 1, 100)?;
139                } else {
140                    return Err(Error::ConfigError(format!(
141                        "Stop sequence at index {} must be a string",
142                        index
143                    )));
144                }
145            }
146        }
147        _ => {
148            return Err(Error::ConfigError(
149                "Parameter 'stop' must be a string or array of strings".to_string(),
150            ));
151        }
152    }
153    Ok(())
154}
155
156/// Validates logit bias parameter
157fn validate_logit_bias(value: &serde_json::Value) -> Result<()> {
158    if let serde_json::Value::Object(bias_map) = value {
159        // Validate each token-bias pair
160        for (token_str, bias_val) in bias_map {
161            // Validate token is a valid integer string
162            if token_str.parse::<i32>().is_err() {
163                return Err(Error::ConfigError(format!(
164                    "Logit bias token '{}' must be a valid integer",
165                    token_str
166                )));
167            }
168
169            // Validate bias value is a number
170            if !bias_val.is_number() {
171                return Err(Error::ConfigError(format!(
172                    "Logit bias for token '{}' must be a number",
173                    token_str
174                )));
175            }
176
177            // Validate bias range: [-100, 100]
178            if let Some(bias) = bias_val.as_f64() {
179                if !(-100.0..=100.0).contains(&bias) {
180                    return Err(Error::ConfigError(format!(
181                        "Logit bias for token '{}' must be between -100 and 100, got {}",
182                        token_str, bias
183                    )));
184                }
185            }
186        }
187    } else {
188        return Err(Error::ConfigError(
189            "Parameter 'logit_bias' must be a JSON object".to_string(),
190        ));
191    }
192    Ok(())
193}
194
/// Estimates the token count of a completion prompt (rough approximation).
///
/// The heuristic is one token per four **bytes** of the UTF-8 encoded prompt,
/// rounded up. Because it counts bytes rather than characters, text made of
/// multi-byte characters is estimated higher than ASCII text with the same
/// number of characters. An empty prompt yields `0`.
///
/// The result saturates at `u32::MAX` for prompts too large to represent.
pub fn estimate_prompt_tokens(prompt: &str) -> u32 {
    let byte_len = prompt.len();

    // Integer ceiling division. Going through `f32` loses precision once the
    // length exceeds 2^24 bytes, which made large estimates come out too low.
    let tokens = byte_len / 4 + usize::from(byte_len % 4 != 0);

    // Clamp before narrowing so the cast can never wrap.
    tokens.min(u32::MAX as usize) as u32
}
201
202/// Checks if a completion prompt might exceed reasonable token limits
203pub fn check_prompt_token_limits(prompt: &str, model: &str) -> Result<()> {
204    let estimated_tokens = estimate_prompt_tokens(prompt);
205
206    // Use a more conservative limit for completions since context windows vary
207    const MAX_COMPLETION_TOKENS: u32 = 200_000;
208
209    if estimated_tokens > MAX_COMPLETION_TOKENS {
210        return Err(Error::ContextLengthExceeded {
211            model: model.to_string(),
212            message: format!(
213                "Estimated prompt token count ({}) exceeds maximum recommended limit ({})",
214                estimated_tokens, MAX_COMPLETION_TOKENS
215            ),
216        });
217    }
218
219    Ok(())
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal request that the validator accepts; individual tests
    /// mutate one field at a time from this baseline.
    fn create_valid_completion_request() -> CompletionRequest {
        CompletionRequest {
            model: "openai/gpt-4".to_string(),
            prompt: "Once upon a time,".to_string(),
            extra_params: serde_json::json!({}),
        }
    }

    /// Validates the baseline request with `extra_params` replaced.
    fn validate_with_params(extra_params: serde_json::Value) -> Result<()> {
        let mut request = create_valid_completion_request();
        request.extra_params = extra_params;
        validate_completion_request(&request)
    }

    /// Drives a boundary table: for each `(value, should_pass)` pair, sets
    /// every parameter in `param_names` to `value` and checks the outcome.
    /// `label` is only used in the failure message.
    fn assert_param_bounds<T>(param_names: &[&str], label: &str, cases: &[(T, bool)])
    where
        T: Copy + std::fmt::Display,
        serde_json::Value: From<T>,
    {
        for &(value, should_pass) in cases {
            let params: serde_json::Map<String, serde_json::Value> = param_names
                .iter()
                .map(|name| (name.to_string(), serde_json::Value::from(value)))
                .collect();

            let result = validate_with_params(serde_json::Value::Object(params));
            if should_pass {
                assert!(result.is_ok(), "{} {} should be valid", label, value);
            } else {
                assert!(result.is_err(), "{} {} should be invalid", label, value);
            }
        }
    }

    #[test]
    fn test_validate_completion_request_valid() {
        let request = create_valid_completion_request();
        assert!(validate_completion_request(&request).is_ok());
    }

    #[test]
    fn test_validate_completion_request_empty_model() {
        let mut request = create_valid_completion_request();
        request.model = "".to_string();
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_invalid_model_format() {
        let mut request = create_valid_completion_request();
        request.model = "invalid-model-name".to_string();
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_empty_prompt() {
        let mut request = create_valid_completion_request();
        request.prompt = "".to_string();
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_whitespace_prompt() {
        let mut request = create_valid_completion_request();
        request.prompt = "   ".to_string();
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_prompt_too_long() {
        let mut request = create_valid_completion_request();
        request.prompt = "a".repeat(1_000_001);
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_valid_extra_params() {
        let result = validate_with_params(json!({
            "temperature": 0.7,
            "max_tokens": 100,
            "top_p": 0.9,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.3
        }));
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_completion_request_temperature_bounds() {
        assert_param_bounds(
            &["temperature"],
            "Temperature",
            &[
                (-0.1, false), // Too low
                (0.0, true),   // Minimum valid
                (1.0, true),   // Valid
                (2.0, true),   // Maximum valid
                (2.1, false),  // Too high
            ],
        );
    }

    #[test]
    fn test_validate_completion_request_top_p_bounds() {
        assert_param_bounds(
            &["top_p"],
            "Top P",
            &[
                (0.0, false), // Too low (exclusive)
                (0.1, true),  // Valid
                (1.0, true),  // Maximum valid (inclusive)
                (1.1, false), // Too high
            ],
        );
    }

    #[test]
    fn test_validate_completion_request_max_tokens_bounds() {
        assert_param_bounds(
            &["max_tokens"],
            "Max tokens",
            &[
                (0, true),     // 0 means unlimited
                (1, true),     // Minimum valid
                (8192, true),  // Maximum valid
                (8193, false), // Too high
            ],
        );
    }

    #[test]
    fn test_validate_completion_request_penalty_bounds() {
        // Both penalties share the same range, so each value is applied to both.
        assert_param_bounds(
            &["frequency_penalty", "presence_penalty"],
            "Penalty",
            &[
                (-2.1, false), // Too low
                (-2.0, true),  // Minimum valid
                (-1.0, true),  // Valid
                (0.0, true),   // Valid
                (1.0, true),   // Valid
                (2.0, true),   // Maximum valid
                (2.1, false),  // Too high
            ],
        );
    }

    #[test]
    fn test_validate_stop_sequence_string() {
        assert!(validate_with_params(json!({"stop": "END"})).is_ok());
    }

    #[test]
    fn test_validate_stop_sequence_array() {
        assert!(validate_with_params(json!({"stop": ["END", "STOP", "FINISHED"]})).is_ok());
    }

    #[test]
    fn test_validate_stop_sequence_too_many() {
        // 5 items, max is 4
        assert!(validate_with_params(json!({"stop": ["A", "B", "C", "D", "E"]})).is_err());
    }

    #[test]
    fn test_validate_stop_sequence_empty() {
        assert!(validate_with_params(json!({"stop": ""})).is_err());
    }

    #[test]
    fn test_validate_logit_bias_valid() {
        let result = validate_with_params(json!({
            "logit_bias": {
                "1000": -10.0,
                "2000": 5.0,
                "3000": 0.0
            }
        }));
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_logit_bias_invalid_range() {
        let test_cases = [
            (-100.1, false), // Too low
            (-100.0, true),  // Minimum valid
            (0.0, true),     // Valid
            (100.0, true),   // Maximum valid
            (100.1, false),  // Too high
        ];

        for (bias, should_pass) in test_cases {
            let result = validate_with_params(json!({
                "logit_bias": {
                    "1000": bias
                }
            }));

            if should_pass {
                assert!(result.is_ok(), "Bias {} should be valid", bias);
            } else {
                assert!(result.is_err(), "Bias {} should be invalid", bias);
            }
        }
    }

    #[test]
    fn test_validate_logit_bias_invalid_token() {
        let result = validate_with_params(json!({
            "logit_bias": {
                "invalid_token": 5.0
            }
        }));
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_echo_parameter() {
        assert!(validate_with_params(json!({"echo": true})).is_ok());
        assert!(validate_with_params(json!({"echo": false})).is_ok());
        assert!(validate_with_params(json!({"echo": "invalid"})).is_err());
    }

    #[test]
    fn test_validate_suffix_parameter() {
        assert!(validate_with_params(json!({"suffix": "completed"})).is_ok());
        assert!(validate_with_params(json!({"suffix": ""})).is_ok());
        assert!(validate_with_params(json!({"suffix": null})).is_ok());
        assert!(validate_with_params(json!({"suffix": 123})).is_err());
    }

    #[test]
    fn test_validate_best_of_parameter() {
        assert_param_bounds(
            &["best_of"],
            "Best of",
            &[
                (0, false),  // Too low
                (1, true),   // Minimum valid
                (10, true),  // Valid
                (20, true),  // Maximum valid
                (21, false), // Too high
            ],
        );
    }

    #[test]
    fn test_validate_logprobs_parameter() {
        assert_param_bounds(
            &["logprobs"],
            "Logprobs",
            &[
                (0, true),  // Valid
                (1, true),  // Valid
                (5, true),  // Maximum valid
                (6, false), // Too high
            ],
        );
    }

    #[test]
    fn test_estimate_prompt_tokens() {
        // Expected values are ceil(byte length / 4), written out by hand so
        // the test does not simply re-run the implementation's own formula.
        let test_cases = [
            ("Hello", 2),                                        // 5 bytes
            ("Hello, world!", 4),                                // 13 bytes
            ("This is a longer sentence with more words.", 11),  // 42 bytes
            ("", 0),
        ];

        for (prompt, expected) in test_cases {
            let tokens = estimate_prompt_tokens(prompt);
            assert_eq!(
                tokens, expected,
                "Unexpected estimate for prompt {:?}",
                prompt
            );
            assert!(
                tokens <= prompt.len() as u32,
                "Should be less than or equal to character count"
            );
        }
    }

    #[test]
    fn test_check_prompt_token_limits() {
        let short_prompt = "Hello, world!";
        assert!(check_prompt_token_limits(short_prompt, "openai/gpt-4").is_ok());

        let medium_prompt = "word ".repeat(1000);
        assert!(check_prompt_token_limits(&medium_prompt, "openai/gpt-4").is_ok());

        // 1,000,000 bytes -> ~250,000 estimated tokens, above the 200,000 limit
        let long_prompt = "word ".repeat(200_000);
        assert!(check_prompt_token_limits(&long_prompt, "openai/gpt-4").is_err());
    }

    #[test]
    fn test_validate_completion_request_complex_params() {
        let result = validate_with_params(json!({
            "temperature": 0.8,
            "max_tokens": 150,
            "top_p": 0.95,
            "frequency_penalty": 0.1,
            "presence_penalty": 0.1,
            "stop": ["END", "STOP"],
            "logit_bias": {
                "100": -5.0,
                "200": 3.0
            },
            "echo": false,
            "suffix": null,
            "best_of": 1,
            "logprobs": 2
        }));

        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_completion_request_mixed_valid_invalid() {
        // One bad value among valid ones must still fail the whole request.
        let result = validate_with_params(json!({
            "temperature": 0.8,     // valid
            "max_tokens": 25000,    // invalid - too high
            "top_p": 0.95,          // valid
            "frequency_penalty": 0.1 // valid
        }));

        assert!(result.is_err());
    }
}