openai_protocol/
sampling_params.rs1use serde::{Deserialize, Serialize};
2use validator::Validate;
3
4use super::common::StringOrArray;
5
6#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)]
8#[validate(schema(function = "validate_sampling_params"))]
9pub struct SamplingParams {
10 #[serde(skip_serializing_if = "Option::is_none")]
12 #[validate(range(min = 0.0))]
13 pub temperature: Option<f32>,
14 #[serde(skip_serializing_if = "Option::is_none")]
16 #[validate(range(min = 0))]
17 pub max_new_tokens: Option<u32>,
18 #[serde(skip_serializing_if = "Option::is_none")]
20 #[validate(custom(function = "validate_top_p_value"))]
21 pub top_p: Option<f32>,
22 #[serde(skip_serializing_if = "Option::is_none")]
24 #[validate(custom(function = "validate_top_k_value"))]
25 pub top_k: Option<i32>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 #[validate(range(min = -2.0, max = 2.0))]
28 pub frequency_penalty: Option<f32>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 #[validate(range(min = -2.0, max = 2.0))]
31 pub presence_penalty: Option<f32>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 #[validate(range(min = 0.0, max = 2.0))]
34 pub repetition_penalty: Option<f32>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub stop: Option<StringOrArray>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub ignore_eos: Option<bool>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub skip_special_tokens: Option<bool>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub json_schema: Option<String>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub regex: Option<String>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub ebnf: Option<String>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 #[validate(range(min = 0.0, max = 1.0))]
49 pub min_p: Option<f32>,
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub min_new_tokens: Option<u32>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub stop_token_ids: Option<Vec<u32>>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub no_stop_trim: Option<bool>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub n: Option<u32>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub sampling_seed: Option<u64>,
61}
62
63pub fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> {
69 if !(top_p > 0.0 && top_p <= 1.0) {
70 return Err(validator::ValidationError::new(
71 "top_p must be in (0, 1] - greater than 0.0 and at most 1.0",
72 ));
73 }
74 Ok(())
75}
76
77pub fn validate_top_k_value(top_k: i32) -> Result<(), validator::ValidationError> {
79 if top_k != -1 && top_k < 1 {
80 return Err(validator::ValidationError::new(
81 "top_k must be -1 (disabled) or at least 1",
82 ));
83 }
84 Ok(())
85}
86
87fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::ValidationError> {
93 if let (Some(min), Some(max)) = (params.min_new_tokens, params.max_new_tokens) {
95 if min > max {
96 return Err(validator::ValidationError::new(
97 "min_new_tokens cannot exceed max_new_tokens",
98 ));
99 }
100 }
101
102 let constraint_count = [
104 params.regex.is_some(),
105 params.ebnf.is_some(),
106 params.json_schema.is_some(),
107 ]
108 .iter()
109 .filter(|&&x| x)
110 .count();
111
112 if constraint_count > 1 {
113 return Err(validator::ValidationError::new(
114 "only one of regex, ebnf, or json_schema can be set",
115 ));
116 }
117
118 Ok(())
119}