dynamo_llm/protocols/openai/
common_ext.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::nvext::validate_top_k;
5use derive_builder::Builder;
6use serde::{Deserialize, Serialize};
7use validator::Validate;
8
9/// Common extensions for OpenAI API requests that are not part of the standard OpenAI spec
10/// but are commonly needed across different request types.
11#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone, Default)]
12pub struct CommonExt {
13    /// If true, the model will ignore the end of string token and generate to max_tokens.
14    /// This field can also be specified in nvext, but the root-level value takes precedence.
15    #[serde(default, skip_serializing_if = "Option::is_none")]
16    #[builder(default, setter(strip_option))]
17    pub ignore_eos: Option<bool>,
18
19    /// The minimum number of tokens to generate.
20    /// This is a common parameter needed across different request types.
21    #[serde(default, skip_serializing_if = "Option::is_none")]
22    #[builder(default, setter(strip_option))]
23    pub min_tokens: Option<u32>,
24
25    /// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    #[builder(default, setter(strip_option))]
28    #[validate(custom(function = "validate_top_k"))]
29    pub top_k: Option<i32>,
30
31    /// Relative probability floor
32    #[serde(default, skip_serializing_if = "Option::is_none")]
33    #[builder(default, setter(strip_option))]
34    #[validate(range(min = 0.0, max = 1.0))]
35    pub min_p: Option<f32>,
36
37    /// How much to penalize tokens based on how frequently they occur in the text.
38    /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    #[builder(default, setter(strip_option))]
41    #[validate(range(exclusive_min = 0.0, max = 2.0))]
42    pub repetition_penalty: Option<f32>,
43
44    /// include_stop_str_in_output
45    #[serde(default, skip_serializing_if = "Option::is_none")]
46    #[builder(default, setter(strip_option))]
47    pub include_stop_str_in_output: Option<bool>,
48
49    /// Guided Decoding Options
50    /// If specified, the output will be a JSON object. Can be a string, an object, or null.
51    #[serde(default, skip_serializing_if = "Option::is_none")]
52    #[builder(default, setter(strip_option))]
53    pub guided_json: Option<serde_json::Value>,
54
55    /// If specified, the output will follow the regex pattern. Can be a string or null.
56    #[serde(default, skip_serializing_if = "Option::is_none")]
57    #[builder(default, setter(strip_option))]
58    pub guided_regex: Option<String>,
59
60    /// If specified, the output will follow the context-free grammar. Can be a string or null.
61    #[serde(default, skip_serializing_if = "Option::is_none")]
62    #[builder(default, setter(strip_option))]
63    pub guided_grammar: Option<String>,
64
65    /// If specified, the output will be exactly one of the choices.
66    #[serde(default, skip_serializing_if = "Option::is_none")]
67    #[builder(default, setter(strip_option))]
68    pub guided_choice: Option<Vec<String>>,
69
70    /// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend
71    #[serde(default, skip_serializing_if = "Option::is_none")]
72    #[builder(default, setter(strip_option))]
73    pub guided_decoding_backend: Option<String>,
74}
75
76impl CommonExt {
77    pub fn builder() -> CommonExtBuilder {
78        CommonExtBuilder::default()
79    }
80}
81
82/// Trait for types that provide CommonExt fields
83pub trait CommonExtProvider {
84    /// Get a reference to the CommonExt struct if available
85    fn common_ext(&self) -> Option<&CommonExt>;
86
87    /// Guided Decoding Options
88    fn get_guided_json(&self) -> Option<&serde_json::Value>;
89    fn get_guided_regex(&self) -> Option<String>;
90    fn get_guided_grammar(&self) -> Option<String>;
91    fn get_guided_choice(&self) -> Option<Vec<String>>;
92    fn get_guided_decoding_backend(&self) -> Option<String>;
93
94    /// Other sampling Options
95    fn get_top_k(&self) -> Option<i32>;
96    fn get_min_p(&self) -> Option<f32>;
97    fn get_repetition_penalty(&self) -> Option<f32>;
98    fn get_include_stop_str_in_output(&self) -> Option<bool>;
99}
100
101/// Helper function to emit deprecation warnings for nvext parameters
102pub fn emit_nvext_deprecation_warning(
103    field_name: &str,
104    nvext_has_value: bool,
105    common_has_value: bool,
106) {
107    if nvext_has_value && !common_has_value {
108        tracing::warn!(
109            "DEPRECATION WARNING: 'nvext.{field_name}' is deprecated and will be removed in a future release. Use '{field_name}' at the top level or in 'extra_body' instead."
110        );
111    } else if nvext_has_value && common_has_value {
112        tracing::warn!(
113            "DEPRECATION WARNING: 'nvext.{field_name}' is deprecated and will be removed in a future release. Top-level '{field_name}' takes precedence. Use '{field_name}' at the top level or in 'extra_body' instead."
114        );
115    }
116}
117
118/// Helper function to choose between common and nvext values with deprecation warnings
119pub fn choose_with_deprecation<T: Clone>(
120    field: &'static str,
121    common: Option<&T>,
122    nv: Option<&T>,
123) -> Option<T> {
124    if nv.is_some() {
125        emit_nvext_deprecation_warning(field, true, common.is_some());
126    }
127    common.cloned().or_else(|| nv.cloned())
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_common_ext_builder_default() {
        let common_ext = CommonExt::builder().build().unwrap();
        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.guided_json, None);
        assert_eq!(common_ext.guided_regex, None);
        assert_eq!(common_ext.guided_grammar, None);
        assert_eq!(common_ext.guided_choice, None);
        assert_eq!(common_ext.guided_decoding_backend, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
    }

    #[test]
    fn test_common_ext_builder_with_values() {
        let common_ext = CommonExt::builder()
            .ignore_eos(true)
            .min_tokens(10)
            .top_k(50)
            .min_p(0.05)
            .repetition_penalty(1.2)
            .include_stop_str_in_output(true)
            .guided_json(serde_json::json!({"key": "value"}))
            .guided_regex("regex".to_string())
            .guided_grammar("grammar".to_string())
            .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
            .guided_decoding_backend("backend".to_string())
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(true));
        assert_eq!(common_ext.min_tokens, Some(10));
        assert_eq!(common_ext.top_k, Some(50));
        assert_eq!(common_ext.min_p, Some(0.05));
        assert_eq!(common_ext.repetition_penalty, Some(1.2));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
        assert_eq!(
            common_ext.guided_json.as_ref(),
            Some(&serde_json::json!({"key": "value"}))
        );
        assert_eq!(common_ext.guided_regex, Some("regex".to_string()));
        assert_eq!(common_ext.guided_grammar, Some("grammar".to_string()));
        assert_eq!(
            common_ext.guided_choice,
            Some(vec!["choice1".to_string(), "choice2".to_string()])
        );
        assert_eq!(
            common_ext.guided_decoding_backend,
            Some("backend".to_string())
        );
    }

    #[test]
    fn test_common_ext_fields() {
        // Test that CommonExt fields can be set and retrieved correctly
        let common_ext = CommonExt::builder()
            .ignore_eos(false)
            .min_tokens(5)
            .include_stop_str_in_output(true)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(false));
        assert_eq!(common_ext.min_tokens, Some(5));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
    }

    #[test]
    fn test_validation_min_tokens() {
        // min_tokens of 0 is valid (u32 cannot go lower and no extra bound is declared)
        let common_ext = CommonExt {
            min_tokens: Some(0),
            ..Default::default()
        };
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_validation_min_p_range() {
        // min_p is declared with range(min = 0.0, max = 1.0): both bounds inclusive.
        for ok in [0.0_f32, 0.5, 1.0] {
            let common_ext = CommonExt {
                min_p: Some(ok),
                ..Default::default()
            };
            assert!(common_ext.validate().is_ok(), "min_p={ok} should be valid");
        }
        for bad in [-0.1_f32, 1.1] {
            let common_ext = CommonExt {
                min_p: Some(bad),
                ..Default::default()
            };
            assert!(
                common_ext.validate().is_err(),
                "min_p={bad} should be rejected"
            );
        }
    }

    #[test]
    fn test_validation_repetition_penalty_range() {
        // repetition_penalty is declared with range(exclusive_min = 0.0, max = 2.0).
        for ok in [0.1_f32, 1.0, 2.0] {
            let common_ext = CommonExt {
                repetition_penalty: Some(ok),
                ..Default::default()
            };
            assert!(
                common_ext.validate().is_ok(),
                "repetition_penalty={ok} should be valid"
            );
        }
        for bad in [0.0_f32, -1.0, 2.1] {
            let common_ext = CommonExt {
                repetition_penalty: Some(bad),
                ..Default::default()
            };
            assert!(
                common_ext.validate().is_err(),
                "repetition_penalty={bad} should be rejected"
            );
        }
    }

    #[test]
    fn test_common_ext_neither_specified() {
        // Test that neither ignore_eos nor min_tokens specified works
        let common_ext = CommonExt::builder().build().unwrap();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_common_ext_default() {
        // Test that Default trait implementation works correctly
        let common_ext = CommonExt::default();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_choose_with_deprecation() {
        // Common takes precedence
        let result = choose_with_deprecation(
            "test_field",
            Some(&"common_value".to_string()),
            Some(&"nvext_value".to_string()),
        );
        assert_eq!(result, Some("common_value".to_string()));

        // Fallback to nvext
        let result = choose_with_deprecation("test_field", None, Some(&"nvext_value".to_string()));
        assert_eq!(result, Some("nvext_value".to_string()));

        // Only common set
        let result = choose_with_deprecation("test_field", Some(&"common_value".to_string()), None);
        assert_eq!(result, Some("common_value".to_string()));

        // Both None
        let result: Option<String> = choose_with_deprecation("test_field", None, None);
        assert_eq!(result, None);
    }
}