// openai_protocol/sampling_params.rs
use serde::{Deserialize, Serialize};
use validator::Validate;

use super::common::StringOrArray;

6#[serde_with::skip_serializing_none]
8#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)]
9#[validate(schema(function = "validate_sampling_params"))]
10pub struct SamplingParams {
11 #[validate(range(min = 0.0))]
13 pub temperature: Option<f32>,
14 #[validate(range(min = 0))]
16 pub max_new_tokens: Option<u32>,
17 #[validate(custom(function = "validate_top_p_value"))]
19 pub top_p: Option<f32>,
20 #[validate(custom(function = "validate_top_k_value"))]
22 pub top_k: Option<i32>,
23 #[validate(range(min = -2.0, max = 2.0))]
24 pub frequency_penalty: Option<f32>,
25 #[validate(range(min = -2.0, max = 2.0))]
26 pub presence_penalty: Option<f32>,
27 #[validate(range(min = 0.0, max = 2.0))]
28 pub repetition_penalty: Option<f32>,
29 pub stop: Option<StringOrArray>,
30 pub ignore_eos: Option<bool>,
31 pub skip_special_tokens: Option<bool>,
32 pub json_schema: Option<String>,
33 pub regex: Option<String>,
34 pub ebnf: Option<String>,
35 #[validate(range(min = 0.0, max = 1.0))]
36 pub min_p: Option<f32>,
37 pub min_new_tokens: Option<u32>,
39 pub stop_token_ids: Option<Vec<u32>>,
40 pub no_stop_trim: Option<bool>,
41 pub n: Option<u32>,
42 pub sampling_seed: Option<u64>,
43}
44
45pub fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> {
51 if !(top_p > 0.0 && top_p <= 1.0) {
52 return Err(validator::ValidationError::new(
53 "top_p must be in (0, 1] - greater than 0.0 and at most 1.0",
54 ));
55 }
56 Ok(())
57}
58
59pub fn validate_top_k_value(top_k: i32) -> Result<(), validator::ValidationError> {
61 if top_k != -1 && top_k < 1 {
62 return Err(validator::ValidationError::new(
63 "top_k must be -1 (disabled) or at least 1",
64 ));
65 }
66 Ok(())
67}
68
69fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::ValidationError> {
75 if let (Some(min), Some(max)) = (params.min_new_tokens, params.max_new_tokens) {
77 if min > max {
78 return Err(validator::ValidationError::new(
79 "min_new_tokens cannot exceed max_new_tokens",
80 ));
81 }
82 }
83
84 let constraint_count = [
86 params.regex.is_some(),
87 params.ebnf.is_some(),
88 params.json_schema.is_some(),
89 ]
90 .iter()
91 .filter(|&&x| x)
92 .count();
93
94 if constraint_count > 1 {
95 return Err(validator::ValidationError::new(
96 "only one of regex, ebnf, or json_schema can be set",
97 ));
98 }
99
100 Ok(())
101}