dynamo_llm/protocols/openai/nvext.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use derive_builder::Builder;
5use serde::{Deserialize, Serialize};
6use validator::{Validate, ValidationError};
7
8pub trait NvExtProvider {
9    fn nvext(&self) -> Option<&NvExt>;
10    fn raw_prompt(&self) -> Option<String>;
11}
12
13/// NVIDIA LLM extensions to the OpenAI API
14#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
15#[validate(schema(function = "validate_nv_ext"))]
16pub struct NvExt {
17    /// If true, the model will ignore the end of string token and generate to max_tokens.
18    #[serde(default, skip_serializing_if = "Option::is_none")]
19    #[builder(default, setter(strip_option))]
20    pub ignore_eos: Option<bool>,
21
22    #[builder(default, setter(strip_option))] // NIM LLM might default to -1
23    #[validate(custom(function = "validate_top_k"))]
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub top_k: Option<i32>,
26
27    /// Relative probability floor
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    #[builder(default, setter(strip_option))]
30    #[validate(range(min = 0.0, max = 1.0))]
31    pub min_p: Option<f32>,
32
33    /// How much to penalize tokens based on how frequently they occur in the text.
34    /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
35    #[builder(default, setter(strip_option))]
36    #[validate(range(exclusive_min = 0.0, max = 2.0))]
37    pub repetition_penalty: Option<f32>,
38
39    /// If true, sampling will be forced to be greedy.
40    /// The backend is responsible for selecting the correct backend-specific options to
41    /// implement this.
42    #[serde(default, skip_serializing_if = "Option::is_none")]
43    #[builder(default, setter(strip_option))]
44    pub greed_sampling: Option<bool>,
45
46    /// If true, the preproessor will try to bypass the prompt template and pass the prompt directly to
47    /// to the tokenizer.
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    #[builder(default, setter(strip_option))]
50    pub use_raw_prompt: Option<bool>,
51
52    /// Annotations
53    /// User requests triggers which result in the request issue back out-of-band information in the SSE
54    /// stream using the `event:` field.
55    #[serde(default, skip_serializing_if = "Option::is_none")]
56    #[builder(default, setter(strip_option))]
57    pub annotations: Option<Vec<String>>,
58
59    /// Targeted backend instance ID for the request
60    /// If set, the request will be routed to backend instance with the given ID.
61    /// If not set, the request will be routed to the best matching instance.
62    #[builder(default, setter(strip_option))]
63    #[serde(default, skip_serializing_if = "Option::is_none")]
64    pub backend_instance_id: Option<i64>,
65
66    /// Pre-tokenized data to use instead of tokenizing the prompt
67    /// If provided along with backend_instance_id, these tokens will be used directly
68    /// and tokenization will be skipped.
69    #[builder(default, setter(strip_option))]
70    #[serde(default, skip_serializing_if = "Option::is_none")]
71    pub token_data: Option<Vec<u32>>,
72    /// Guided Decoding Options
73    /// If specified, the output will be a JSON object. Can be a string, an object, or null.
74    #[serde(default, skip_serializing_if = "Option::is_none")]
75    #[builder(default, setter(strip_option))]
76    pub guided_json: Option<serde_json::Value>,
77
78    /// If specified, the output will follow the regex pattern. Can be a string or null.
79    #[serde(default, skip_serializing_if = "Option::is_none")]
80    #[builder(default, setter(strip_option))]
81    pub guided_regex: Option<String>,
82
83    /// If specified, the output will follow the context-free grammar. Can be a string or null.
84    #[serde(default, skip_serializing_if = "Option::is_none")]
85    #[builder(default, setter(strip_option))]
86    pub guided_grammar: Option<String>,
87
88    /// If specified, the output will be exactly one of the choices.
89    #[serde(default, skip_serializing_if = "Option::is_none")]
90    #[builder(default, setter(strip_option))]
91    pub guided_choice: Option<Vec<String>>,
92
93    /// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend
94    #[serde(default, skip_serializing_if = "Option::is_none")]
95    #[builder(default, setter(strip_option))]
96    pub guided_decoding_backend: Option<String>,
97
98    /// Maximum number of thinking tokens allowed
99    /// NOTE: Currently passed through to backends as a no-op for future implementation
100    #[serde(default, skip_serializing_if = "Option::is_none")]
101    #[builder(default, setter(strip_option))]
102    pub max_thinking_tokens: Option<u32>,
103}
104
105impl Default for NvExt {
106    fn default() -> Self {
107        NvExt::builder().build().unwrap()
108    }
109}
110
111impl NvExt {
112    pub fn builder() -> NvExtBuilder {
113        NvExtBuilder::default()
114    }
115}
116
117fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
118    Ok(())
119}
120
121pub fn validate_top_k(top_k: i32) -> Result<(), ValidationError> {
122    if top_k == -1 || (top_k >= 1) {
123        return Ok(());
124    }
125    let mut error = ValidationError::new("top_k");
126    error.message = Some("top_k must be -1 or greater than or equal to 1".into());
127    Err(error)
128}
129
130impl NvExtBuilder {
131    pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
132        self.annotations
133            .get_or_insert_with(|| Some(vec![]))
134            .as_mut()
135            .expect("stop should always be Some(Vec)")
136            .push(annotation.into());
137        self
138    }
139}
140
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use validator::Validate;

    use super::*;

    // A freshly built `NvExt` must leave every extension unset.
    #[test]
    fn test_nv_ext_builder_default() {
        let nv_ext = NvExt::builder().build().unwrap();
        assert_eq!(nv_ext.ignore_eos, None);
        assert_eq!(nv_ext.top_k, None);
        assert_eq!(nv_ext.min_p, None);
        assert_eq!(nv_ext.repetition_penalty, None);
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.use_raw_prompt, None);
        assert_eq!(nv_ext.annotations, None);
        assert_eq!(nv_ext.backend_instance_id, None);
        assert_eq!(nv_ext.token_data, None);
        assert_eq!(nv_ext.guided_json, None);
        assert_eq!(nv_ext.guided_regex, None);
        assert_eq!(nv_ext.guided_grammar, None);
        assert_eq!(nv_ext.guided_choice, None);
        assert_eq!(nv_ext.guided_decoding_backend, None);
        assert_eq!(nv_ext.max_thinking_tokens, None);
    }

    // Test valid builder configurations
    #[test]
    fn test_nv_ext_builder_custom() {
        let nv_ext = NvExt::builder()
            .ignore_eos(true)
            .top_k(10)
            .repetition_penalty(1.5)
            .greed_sampling(true)
            .guided_json(serde_json::json!({"type": "object"}))
            .guided_regex("^[0-9]+$".to_string())
            .guided_grammar("S -> 'a' S 'b' | 'c'".to_string())
            .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
            .guided_decoding_backend("xgrammar".to_string())
            .max_thinking_tokens(1024)
            .build()
            .unwrap();

        assert_eq!(nv_ext.ignore_eos, Some(true));
        assert_eq!(nv_ext.top_k, Some(10));
        assert_eq!(nv_ext.repetition_penalty, Some(1.5));
        assert_eq!(nv_ext.greed_sampling, Some(true));
        assert_eq!(
            nv_ext.guided_json,
            Some(serde_json::json!({"type": "object"}))
        );
        assert_eq!(nv_ext.guided_regex, Some("^[0-9]+$".to_string()));
        assert_eq!(
            nv_ext.guided_grammar,
            Some("S -> 'a' S 'b' | 'c'".to_string())
        );
        assert_eq!(
            nv_ext.guided_choice,
            Some(vec!["choice1".to_string(), "choice2".to_string()])
        );
        assert_eq!(nv_ext.guided_decoding_backend, Some("xgrammar".to_string()));
        assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
        // Validate the built struct
        assert!(nv_ext.validate().is_ok());
    }

    // `add_annotation` accumulates annotations in call order.
    #[test]
    fn test_add_annotation_appends_in_order() {
        let nv_ext = NvExt::builder()
            .add_annotation("first")
            .add_annotation(String::from("second"))
            .build()
            .unwrap();
        assert_eq!(
            nv_ext.annotations,
            Some(vec!["first".to_string(), "second".to_string()])
        );
    }

    // Invalid `top_k` values are exactly 0 and anything below -1. Generating `0`
    // explicitly matters: a uniform i32 strategy would almost never hit it.
    proptest! {
        #[test]
        fn test_invalid_top_k_value(top_k in prop_oneof![Just(0i32), i32::MIN..-1i32]) {
            let nv_ext = NvExt::builder()
                .top_k(top_k)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            prop_assert!(validation_result.is_err(), "top_k should fail validation when it is 0 or less than -1");
        }
    }

    // Test valid `top_k` values
    #[test]
    fn test_valid_top_k_values() {
        let nv_ext = NvExt::builder().top_k(-1).build().unwrap();
        assert!(nv_ext.validate().is_ok());

        let nv_ext = NvExt::builder().top_k(1).build().unwrap();
        assert!(nv_ext.validate().is_ok());

        let nv_ext = NvExt::builder().top_k(10).build().unwrap();
        assert!(nv_ext.validate().is_ok());
    }

    // `min_p` must lie in the closed range [0, 1].
    #[test]
    fn test_min_p_bounds() {
        assert!(NvExt::builder().min_p(0.0).build().unwrap().validate().is_ok());
        assert!(NvExt::builder().min_p(1.0).build().unwrap().validate().is_ok());
        assert!(NvExt::builder().min_p(1.5).build().unwrap().validate().is_err());
        assert!(NvExt::builder().min_p(-0.1).build().unwrap().validate().is_err());
    }

    // Test valid repetition_penalty values
    proptest! {
        #[test]
        fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f32..=2.0f32) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            prop_assert!(validation_result.is_ok(), "repetition_penalty should be valid within the range (0, 2]");
        }
    }

    // Invalid repetition_penalty values at or below the exclusive lower bound 0.
    proptest! {
        #[test]
        fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f32..=0.0f32) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            prop_assert!(validation_result.is_err(), "repetition_penalty should fail validation when outside the range (0, 2]");
        }
    }

    // Invalid repetition_penalty values above the inclusive upper bound 2.
    proptest! {
        #[test]
        fn test_repetition_penalty_above_max(repetition_penalty in 2.01f32..10.0f32) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            prop_assert!(validation_result.is_err(), "repetition_penalty should fail validation when outside the range (0, 2]");
        }
    }

    // Pin the exact boundaries: 2.0 is allowed, 0.0 is not.
    #[test]
    fn test_repetition_penalty_bounds() {
        let at_max = NvExt::builder().repetition_penalty(2.0).build().unwrap();
        assert!(at_max.validate().is_ok());

        let at_zero = NvExt::builder().repetition_penalty(0.0).build().unwrap();
        assert!(at_zero.validate().is_err());
    }
}