openai_protocol/
sampling_params.rs

1use serde::{Deserialize, Serialize};
2use validator::Validate;
3
4use super::common::StringOrArray;
5
6/// Sampling parameters for text generation
7#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)]
8#[validate(schema(function = "validate_sampling_params"))]
9pub struct SamplingParams {
10    /// Temperature for sampling (must be >= 0.0, no upper limit)
11    #[serde(skip_serializing_if = "Option::is_none")]
12    #[validate(range(min = 0.0))]
13    pub temperature: Option<f32>,
14    /// Maximum number of new tokens to generate (must be >= 0)
15    #[serde(skip_serializing_if = "Option::is_none")]
16    #[validate(range(min = 0))]
17    pub max_new_tokens: Option<u32>,
18    /// Top-p nucleus sampling (0.0 < top_p <= 1.0)
19    #[serde(skip_serializing_if = "Option::is_none")]
20    #[validate(custom(function = "validate_top_p_value"))]
21    pub top_p: Option<f32>,
22    /// Top-k sampling (-1 to disable, or >= 1)
23    #[serde(skip_serializing_if = "Option::is_none")]
24    #[validate(custom(function = "validate_top_k_value"))]
25    pub top_k: Option<i32>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    #[validate(range(min = -2.0, max = 2.0))]
28    pub frequency_penalty: Option<f32>,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    #[validate(range(min = -2.0, max = 2.0))]
31    pub presence_penalty: Option<f32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    #[validate(range(min = 0.0, max = 2.0))]
34    pub repetition_penalty: Option<f32>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub stop: Option<StringOrArray>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub ignore_eos: Option<bool>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub skip_special_tokens: Option<bool>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub json_schema: Option<String>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub regex: Option<String>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub ebnf: Option<String>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    #[validate(range(min = 0.0, max = 1.0))]
49    pub min_p: Option<f32>,
50    /// Minimum number of new tokens (validated in schema function for cross-field check with max_new_tokens)
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub min_new_tokens: Option<u32>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub stop_token_ids: Option<Vec<u32>>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub no_stop_trim: Option<bool>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub n: Option<u32>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub sampling_seed: Option<u64>,
61}
62
63// ============================================================================
64// Shared Validation Functions
65// ============================================================================
66
67/// Validates top_p: 0.0 < top_p <= 1.0 (can't use range validator for open interval)
68pub fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> {
69    if !(top_p > 0.0 && top_p <= 1.0) {
70        return Err(validator::ValidationError::new(
71            "top_p must be in (0, 1] - greater than 0.0 and at most 1.0",
72        ));
73    }
74    Ok(())
75}
76
77/// Validates top_k: -1 (disabled) or >= 1 (special -1 case - can't use range validator)
78pub fn validate_top_k_value(top_k: i32) -> Result<(), validator::ValidationError> {
79    if top_k != -1 && top_k < 1 {
80        return Err(validator::ValidationError::new(
81            "top_k must be -1 (disabled) or at least 1",
82        ));
83    }
84    Ok(())
85}
86
87// ============================================================================
88// SamplingParams-Specific Validation
89// ============================================================================
90
91/// Validation function for SamplingParams - cross-field validation only
92fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::ValidationError> {
93    // 1. Cross-field validation: min_new_tokens <= max_new_tokens
94    if let (Some(min), Some(max)) = (params.min_new_tokens, params.max_new_tokens) {
95        if min > max {
96            return Err(validator::ValidationError::new(
97                "min_new_tokens cannot exceed max_new_tokens",
98            ));
99        }
100    }
101
102    // 2. Validate mutually exclusive structured output constraints
103    let constraint_count = [
104        params.regex.is_some(),
105        params.ebnf.is_some(),
106        params.json_schema.is_some(),
107    ]
108    .iter()
109    .filter(|&&x| x)
110    .count();
111
112    if constraint_count > 1 {
113        return Err(validator::ValidationError::new(
114            "only one of regex, ebnf, or json_schema can be set",
115        ));
116    }
117
118    Ok(())
119}