// openai_protocol/sampling_params.rs

1use serde::{Deserialize, Serialize};
2use validator::Validate;
3
4use super::common::StringOrArray;
5
6/// Sampling parameters for text generation
7#[serde_with::skip_serializing_none]
8#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate, schemars::JsonSchema)]
9#[validate(schema(function = "validate_sampling_params"))]
10pub struct SamplingParams {
11    /// Temperature for sampling (must be >= 0.0, no upper limit)
12    #[validate(range(min = 0.0))]
13    pub temperature: Option<f32>,
14    /// Maximum number of new tokens to generate (must be >= 0)
15    #[validate(range(min = 0))]
16    pub max_new_tokens: Option<u32>,
17    /// Top-p nucleus sampling (0.0 < top_p <= 1.0)
18    #[validate(custom(function = "validate_top_p_value"))]
19    pub top_p: Option<f32>,
20    /// Top-k sampling (-1 to disable, or >= 1)
21    #[validate(custom(function = "validate_top_k_value"))]
22    pub top_k: Option<i32>,
23    #[validate(range(min = -2.0, max = 2.0))]
24    pub frequency_penalty: Option<f32>,
25    #[validate(range(min = -2.0, max = 2.0))]
26    pub presence_penalty: Option<f32>,
27    #[validate(range(min = 0.0, max = 2.0))]
28    pub repetition_penalty: Option<f32>,
29    pub stop: Option<StringOrArray>,
30    pub ignore_eos: Option<bool>,
31    pub skip_special_tokens: Option<bool>,
32    pub json_schema: Option<String>,
33    pub regex: Option<String>,
34    pub ebnf: Option<String>,
35    #[validate(range(min = 0.0, max = 1.0))]
36    pub min_p: Option<f32>,
37    /// Minimum number of new tokens (validated in schema function for cross-field check with max_new_tokens)
38    pub min_new_tokens: Option<u32>,
39    pub stop_token_ids: Option<Vec<u32>>,
40    pub no_stop_trim: Option<bool>,
41    pub n: Option<u32>,
42    pub sampling_seed: Option<u64>,
43}
44
// ============================================================================
// Shared Validation Functions
// ============================================================================
48
49/// Validates top_p: 0.0 < top_p <= 1.0 (can't use range validator for open interval)
50pub fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> {
51    if !(top_p > 0.0 && top_p <= 1.0) {
52        return Err(validator::ValidationError::new(
53            "top_p must be in (0, 1] - greater than 0.0 and at most 1.0",
54        ));
55    }
56    Ok(())
57}
58
59/// Validates top_k: -1 (disabled) or >= 1 (special -1 case - can't use range validator)
60pub fn validate_top_k_value(top_k: i32) -> Result<(), validator::ValidationError> {
61    if top_k != -1 && top_k < 1 {
62        return Err(validator::ValidationError::new(
63            "top_k must be -1 (disabled) or at least 1",
64        ));
65    }
66    Ok(())
67}
68
// ============================================================================
// SamplingParams-Specific Validation
// ============================================================================
72
73/// Validation function for SamplingParams - cross-field validation only
74fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::ValidationError> {
75    // 1. Cross-field validation: min_new_tokens <= max_new_tokens
76    if let (Some(min), Some(max)) = (params.min_new_tokens, params.max_new_tokens) {
77        if min > max {
78            return Err(validator::ValidationError::new(
79                "min_new_tokens cannot exceed max_new_tokens",
80            ));
81        }
82    }
83
84    // 2. Validate mutually exclusive structured output constraints
85    let constraint_count = [
86        params.regex.is_some(),
87        params.ebnf.is_some(),
88        params.json_schema.is_some(),
89    ]
90    .iter()
91    .filter(|&&x| x)
92    .count();
93
94    if constraint_count > 1 {
95        return Err(validator::ValidationError::new(
96            "only one of regex, ebnf, or json_schema can be set",
97        ));
98    }
99
100    Ok(())
101}