1use super::common::*;
4use crate::error::{Error, Result};
5use crate::types::completion::CompletionRequest;
6
7const MAX_PROMPT_LENGTH: usize = 1_000_000;
9
10pub fn validate_completion_request(request: &CompletionRequest) -> Result<()> {
12 validate_model_id(&request.model)?;
14
15 validate_non_empty_string(&request.prompt, "prompt")?;
17 validate_string_length(&request.prompt, "prompt", 1, MAX_PROMPT_LENGTH)?;
18
19 if let serde_json::Value::Object(params) = &request.extra_params {
21 validate_extra_params(params)?;
22 }
23
24 Ok(())
25}
26
27fn validate_extra_params(params: &serde_json::Map<String, serde_json::Value>) -> Result<()> {
29 validate_optional_numeric_param(params, "temperature", 0.0, 2.0)?;
31
32 if let Some(value) = params.get("top_p") {
34 if let Some(top_p) = value.as_f64() {
35 if top_p <= 0.0 || top_p > 1.0 {
36 return Err(Error::ConfigError(format!(
37 "Top P must be between 0.0 (exclusive) and 1.0 (inclusive), got {}",
38 top_p
39 )));
40 }
41 } else {
42 return Err(Error::ConfigError(
43 "Parameter 'top_p' must be a number".to_string(),
44 ));
45 }
46 }
47
48 if let Some(value) = params.get("max_tokens") {
50 if let Some(tokens) = value.as_u64() {
51 if tokens != 0 && !(1..=8192).contains(&tokens) {
52 return Err(Error::ConfigError(format!(
53 "Max tokens must be 0 (unlimited) or between 1 and 8192, got {}",
54 tokens
55 )));
56 }
57 } else {
58 return Err(Error::ConfigError(
59 "Parameter 'max_tokens' must be an integer".to_string(),
60 ));
61 }
62 }
63
64 validate_optional_numeric_param(params, "frequency_penalty", -2.0, 2.0)?;
66
67 validate_optional_numeric_param(params, "presence_penalty", -2.0, 2.0)?;
69
70 if let Some(value) = params.get("stop") {
72 validate_stop_sequence(value)?;
73 }
74
75 if let Some(value) = params.get("logit_bias") {
77 validate_logit_bias(value)?;
78 }
79
80 if let Some(value) = params.get("echo") {
82 if !value.is_boolean() {
83 return Err(Error::ConfigError(
84 "Parameter 'echo' must be a boolean".to_string(),
85 ));
86 }
87 }
88
89 if let Some(value) = params.get("suffix") {
91 if let Some(suffix) = value.as_str() {
92 validate_string_length(suffix, "suffix", 0, 1000)?;
93 } else if !value.is_null() {
94 return Err(Error::ConfigError(
95 "Parameter 'suffix' must be a string or null".to_string(),
96 ));
97 }
98 }
99
100 if let Some(value) = params.get("best_of") {
102 if let Some(best_of) = value.as_u64() {
103 validate_numeric_range(best_of, "best_of", 1, 20)?;
104 } else {
105 return Err(Error::ConfigError(
106 "Parameter 'best_of' must be an integer".to_string(),
107 ));
108 }
109 }
110
111 if let Some(value) = params.get("logprobs") {
113 if let Some(logprobs) = value.as_u64() {
114 validate_numeric_range(logprobs, "logprobs", 0, 5)?;
115 } else {
116 return Err(Error::ConfigError(
117 "Parameter 'logprobs' must be an integer".to_string(),
118 ));
119 }
120 }
121
122 Ok(())
123}
124
125fn validate_stop_sequence(value: &serde_json::Value) -> Result<()> {
127 match value {
128 serde_json::Value::String(stop) => {
129 validate_string_length(stop, "stop", 1, 100)?;
131 }
132 serde_json::Value::Array(stops) => {
133 validate_collection_size(stops, "stop", 1, 4)?;
135
136 for (index, stop_val) in stops.iter().enumerate() {
137 if let Some(stop_str) = stop_val.as_str() {
138 validate_string_length(stop_str, &format!("stop[{}]", index), 1, 100)?;
139 } else {
140 return Err(Error::ConfigError(format!(
141 "Stop sequence at index {} must be a string",
142 index
143 )));
144 }
145 }
146 }
147 _ => {
148 return Err(Error::ConfigError(
149 "Parameter 'stop' must be a string or array of strings".to_string(),
150 ));
151 }
152 }
153 Ok(())
154}
155
156fn validate_logit_bias(value: &serde_json::Value) -> Result<()> {
158 if let serde_json::Value::Object(bias_map) = value {
159 for (token_str, bias_val) in bias_map {
161 if token_str.parse::<i32>().is_err() {
163 return Err(Error::ConfigError(format!(
164 "Logit bias token '{}' must be a valid integer",
165 token_str
166 )));
167 }
168
169 if !bias_val.is_number() {
171 return Err(Error::ConfigError(format!(
172 "Logit bias for token '{}' must be a number",
173 token_str
174 )));
175 }
176
177 if let Some(bias) = bias_val.as_f64() {
179 if !(-100.0..=100.0).contains(&bias) {
180 return Err(Error::ConfigError(format!(
181 "Logit bias for token '{}' must be between -100 and 100, got {}",
182 token_str, bias
183 )));
184 }
185 }
186 }
187 } else {
188 return Err(Error::ConfigError(
189 "Parameter 'logit_bias' must be a JSON object".to_string(),
190 ));
191 }
192 Ok(())
193}
194
/// Estimates the number of tokens in `prompt` with a rough one-token-per-four-bytes
/// heuristic, rounded up.
///
/// The length used is `prompt.len()`, i.e. UTF-8 bytes rather than characters, so
/// non-ASCII text yields a higher estimate than its character count would suggest.
/// The result saturates at `u32::MAX` for inputs too large to fit.
pub fn estimate_prompt_tokens(prompt: &str) -> u32 {
    // Heuristic ratio of bytes to tokens.
    const BYTES_PER_TOKEN: usize = 4;

    let len = prompt.len();
    // Exact integer ceiling division. Floating point would be wrong here: f32
    // cannot represent every integer above 2^24, which skews large estimates.
    let tokens = len / BYTES_PER_TOKEN + usize::from(len % BYTES_PER_TOKEN != 0);
    u32::try_from(tokens).unwrap_or(u32::MAX)
}
201
202pub fn check_prompt_token_limits(prompt: &str, model: &str) -> Result<()> {
204 let estimated_tokens = estimate_prompt_tokens(prompt);
205
206 const MAX_COMPLETION_TOKENS: u32 = 200_000;
208
209 if estimated_tokens > MAX_COMPLETION_TOKENS {
210 return Err(Error::ContextLengthExceeded {
211 model: model.to_string(),
212 message: format!(
213 "Estimated prompt token count ({}) exceeds maximum recommended limit ({})",
214 estimated_tokens, MAX_COMPLETION_TOKENS
215 ),
216 });
217 }
218
219 Ok(())
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a request that passes validation; tests mutate one field at a time.
    fn create_valid_completion_request() -> CompletionRequest {
        CompletionRequest {
            model: "openai/gpt-4".to_string(),
            prompt: "Once upon a time,".to_string(),
            extra_params: serde_json::json!({}),
        }
    }

    #[test]
    fn test_validate_completion_request_valid() {
        let request = create_valid_completion_request();
        assert!(validate_completion_request(&request).is_ok());
    }

    #[test]
    fn test_validate_completion_request_empty_model() {
        let mut request = create_valid_completion_request();
        request.model = "".to_string();
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_invalid_model_format() {
        let mut request = create_valid_completion_request();
        request.model = "invalid-model-name".to_string();
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_empty_prompt() {
        let mut request = create_valid_completion_request();
        request.prompt = "".to_string();
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_whitespace_prompt() {
        let mut request = create_valid_completion_request();
        request.prompt = " ".to_string();
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_prompt_too_long() {
        let mut request = create_valid_completion_request();
        request.prompt = "a".repeat(1_000_001);
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_valid_extra_params() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({
            "temperature": 0.7,
            "max_tokens": 100,
            "top_p": 0.9,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.3
        });
        assert!(validate_completion_request(&request).is_ok());
    }

    #[test]
    fn test_validate_completion_request_temperature_bounds() {
        let test_cases = [
            (-0.1, false), // below minimum
            (0.0, true),   // minimum
            (1.0, true),
            (2.0, true),   // maximum
            (2.1, false),  // above maximum
        ];

        for (temp, should_pass) in test_cases {
            let mut request = create_valid_completion_request();
            request.extra_params = json!({"temperature": temp});

            let result = validate_completion_request(&request);
            if should_pass {
                assert!(result.is_ok(), "Temperature {} should be valid", temp);
            } else {
                assert!(result.is_err(), "Temperature {} should be invalid", temp);
            }
        }
    }

    #[test]
    fn test_validate_completion_request_top_p_bounds() {
        let test_cases = [
            (0.0, false), // exclusive lower bound
            (0.1, true),
            (1.0, true),  // inclusive upper bound
            (1.1, false),
        ];

        for (top_p, should_pass) in test_cases {
            let mut request = create_valid_completion_request();
            request.extra_params = json!({"top_p": top_p});

            let result = validate_completion_request(&request);
            if should_pass {
                assert!(result.is_ok(), "Top P {} should be valid", top_p);
            } else {
                assert!(result.is_err(), "Top P {} should be invalid", top_p);
            }
        }
    }

    #[test]
    fn test_validate_completion_request_top_p_non_number() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({"top_p": "0.5"});
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_completion_request_max_tokens_bounds() {
        let test_cases = [
            (0, true),     // unlimited
            (1, true),
            (8192, true),  // maximum
            (8193, false), // above maximum
        ];

        for (max_tokens, should_pass) in test_cases {
            let mut request = create_valid_completion_request();
            request.extra_params = json!({"max_tokens": max_tokens});

            let result = validate_completion_request(&request);
            if should_pass {
                assert!(result.is_ok(), "Max tokens {} should be valid", max_tokens);
            } else {
                assert!(
                    result.is_err(),
                    "Max tokens {} should be invalid",
                    max_tokens
                );
            }
        }
    }

    #[test]
    fn test_validate_integer_params_reject_non_integers() {
        let bad_values = [json!(-1), json!(1.5), json!("10"), json!(null)];

        for name in ["max_tokens", "best_of", "logprobs"] {
            for bad in &bad_values {
                let mut request = create_valid_completion_request();
                let mut params = serde_json::Map::new();
                params.insert(name.to_string(), bad.clone());
                request.extra_params = serde_json::Value::Object(params);

                assert!(
                    validate_completion_request(&request).is_err(),
                    "{} = {} should be invalid",
                    name,
                    bad
                );
            }
        }
    }

    #[test]
    fn test_validate_completion_request_penalty_bounds() {
        let test_cases = [
            (-2.1, false), // below minimum
            (-2.0, true),  // minimum
            (-1.0, true),
            (0.0, true),
            (1.0, true),
            (2.0, true),   // maximum
            (2.1, false),  // above maximum
        ];

        for (penalty, should_pass) in test_cases {
            let mut request = create_valid_completion_request();
            request.extra_params = json!({
                "frequency_penalty": penalty,
                "presence_penalty": penalty
            });

            let result = validate_completion_request(&request);
            if should_pass {
                assert!(result.is_ok(), "Penalty {} should be valid", penalty);
            } else {
                assert!(result.is_err(), "Penalty {} should be invalid", penalty);
            }
        }
    }

    #[test]
    fn test_validate_stop_sequence_string() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({"stop": "END"});
        assert!(validate_completion_request(&request).is_ok());
    }

    #[test]
    fn test_validate_stop_sequence_array() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({"stop": ["END", "STOP", "FINISHED"]});
        assert!(validate_completion_request(&request).is_ok());
    }

    #[test]
    fn test_validate_stop_sequence_too_many() {
        let mut request = create_valid_completion_request();
        // Five entries exceed the maximum of four.
        request.extra_params = json!({"stop": ["A", "B", "C", "D", "E"]});
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_stop_sequence_empty() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({"stop": ""});
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_stop_sequence_wrong_types() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({"stop": ["END", 42]});
        assert!(validate_completion_request(&request).is_err());

        request.extra_params = json!({"stop": 5});
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_logit_bias_valid() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({
            "logit_bias": {
                "1000": -10.0,
                "2000": 5.0,
                "3000": 0.0
            }
        });
        assert!(validate_completion_request(&request).is_ok());
    }

    #[test]
    fn test_validate_logit_bias_invalid_range() {
        let test_cases = [
            (-100.1, false), // below minimum
            (-100.0, true),  // minimum
            (0.0, true),
            (100.0, true),   // maximum
            (100.1, false),  // above maximum
        ];

        for (bias, should_pass) in test_cases {
            let mut request = create_valid_completion_request();
            request.extra_params = json!({
                "logit_bias": {
                    "1000": bias
                }
            });

            let result = validate_completion_request(&request);
            if should_pass {
                assert!(result.is_ok(), "Bias {} should be valid", bias);
            } else {
                assert!(result.is_err(), "Bias {} should be invalid", bias);
            }
        }
    }

    #[test]
    fn test_validate_logit_bias_invalid_token() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({
            "logit_bias": {
                "invalid_token": 5.0
            }
        });
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_logit_bias_wrong_types() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({"logit_bias": [1, 2]});
        assert!(validate_completion_request(&request).is_err());

        request.extra_params = json!({"logit_bias": {"100": "high"}});
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_echo_parameter() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({"echo": true});
        assert!(validate_completion_request(&request).is_ok());

        request.extra_params = json!({"echo": false});
        assert!(validate_completion_request(&request).is_ok());

        request.extra_params = json!({"echo": "invalid"});
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_suffix_parameter() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({"suffix": "completed"});
        assert!(validate_completion_request(&request).is_ok());

        request.extra_params = json!({"suffix": ""});
        assert!(validate_completion_request(&request).is_ok());

        request.extra_params = json!({"suffix": null});
        assert!(validate_completion_request(&request).is_ok());

        request.extra_params = json!({"suffix": 123});
        assert!(validate_completion_request(&request).is_err());
    }

    #[test]
    fn test_validate_best_of_parameter() {
        let test_cases = [
            (0, false), // below minimum
            (1, true),  // minimum
            (10, true),
            (20, true), // maximum
            (21, false), // above maximum
        ];

        for (best_of, should_pass) in test_cases {
            let mut request = create_valid_completion_request();
            request.extra_params = json!({"best_of": best_of});

            let result = validate_completion_request(&request);
            if should_pass {
                assert!(result.is_ok(), "Best of {} should be valid", best_of);
            } else {
                assert!(result.is_err(), "Best of {} should be invalid", best_of);
            }
        }
    }

    #[test]
    fn test_validate_logprobs_parameter() {
        let test_cases = [
            (0, true), // minimum
            (1, true),
            (5, true),  // maximum
            (6, false), // above maximum
        ];

        for (logprobs, should_pass) in test_cases {
            let mut request = create_valid_completion_request();
            request.extra_params = json!({"logprobs": logprobs});

            let result = validate_completion_request(&request);
            if should_pass {
                assert!(result.is_ok(), "Logprobs {} should be valid", logprobs);
            } else {
                assert!(result.is_err(), "Logprobs {} should be invalid", logprobs);
            }
        }
    }

    #[test]
    fn test_estimate_prompt_tokens() {
        // Expected values are ceil(byte_length / 4), stated explicitly so the test
        // does not merely recompute the implementation's own formula.
        let test_cases = [
            ("Hello", 2),
            ("Hello, world!", 4),
            ("This is a longer sentence with more words.", 11),
            ("", 0),
        ];

        for (prompt, expected) in test_cases {
            let tokens = estimate_prompt_tokens(prompt);
            assert_eq!(tokens, expected, "Unexpected estimate for {:?}", prompt);
            assert!(
                tokens as usize <= prompt.len(),
                "Should be less than or equal to character count"
            );
        }
    }

    #[test]
    fn test_check_prompt_token_limits() {
        let short_prompt = "Hello, world!";
        assert!(check_prompt_token_limits(short_prompt, "openai/gpt-4").is_ok());

        let medium_prompt = "word ".repeat(1000);
        assert!(check_prompt_token_limits(&medium_prompt, "openai/gpt-4").is_ok());

        // 1,000,000 bytes estimate to 250,000 tokens, above the limit.
        let long_prompt = "word ".repeat(200_000);
        assert!(check_prompt_token_limits(&long_prompt, "openai/gpt-4").is_err());
    }

    #[test]
    fn test_check_prompt_token_limits_boundary() {
        // 800,000 bytes estimate to exactly 200,000 tokens: the limit is inclusive.
        let at_limit = "a".repeat(800_000);
        assert!(check_prompt_token_limits(&at_limit, "openai/gpt-4").is_ok());

        let over_limit = "a".repeat(800_001);
        assert!(check_prompt_token_limits(&over_limit, "openai/gpt-4").is_err());
    }

    #[test]
    fn test_validate_completion_request_complex_params() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({
            "temperature": 0.8,
            "max_tokens": 150,
            "top_p": 0.95,
            "frequency_penalty": 0.1,
            "presence_penalty": 0.1,
            "stop": ["END", "STOP"],
            "logit_bias": {
                "100": -5.0,
                "200": 3.0
            },
            "echo": false,
            "suffix": null,
            "best_of": 1,
            "logprobs": 2
        });

        assert!(validate_completion_request(&request).is_ok());
    }

    #[test]
    fn test_validate_completion_request_mixed_valid_invalid() {
        let mut request = create_valid_completion_request();
        request.extra_params = json!({
            "temperature": 0.8,      // valid
            "max_tokens": 25000,     // invalid: above 8192
            "top_p": 0.95,           // valid
            "frequency_penalty": 0.1 // valid
        });

        assert!(validate_completion_request(&request).is_err());
    }
}