dynamo_llm/protocols/openai/
nvext.rs1use derive_builder::Builder;
5use serde::{Deserialize, Serialize};
6use validator::{Validate, ValidationError};
7
8pub trait NvExtProvider {
9 fn nvext(&self) -> Option<&NvExt>;
10 fn raw_prompt(&self) -> Option<String>;
11}
12
13#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
15#[validate(schema(function = "validate_nv_ext"))]
16pub struct NvExt {
17 #[serde(default, skip_serializing_if = "Option::is_none")]
19 #[builder(default, setter(strip_option))]
20 pub ignore_eos: Option<bool>,
21
22 #[builder(default, setter(strip_option))] #[validate(custom(function = "validate_top_k"))]
24 #[serde(default, skip_serializing_if = "Option::is_none")]
25 pub top_k: Option<i32>,
26
27 #[serde(default, skip_serializing_if = "Option::is_none")]
29 #[builder(default, setter(strip_option))]
30 #[validate(range(min = 0.0, max = 1.0))]
31 pub min_p: Option<f32>,
32
33 #[builder(default, setter(strip_option))]
36 #[validate(range(exclusive_min = 0.0, max = 2.0))]
37 pub repetition_penalty: Option<f32>,
38
39 #[serde(default, skip_serializing_if = "Option::is_none")]
43 #[builder(default, setter(strip_option))]
44 pub greed_sampling: Option<bool>,
45
46 #[serde(default, skip_serializing_if = "Option::is_none")]
49 #[builder(default, setter(strip_option))]
50 pub use_raw_prompt: Option<bool>,
51
52 #[serde(default, skip_serializing_if = "Option::is_none")]
56 #[builder(default, setter(strip_option))]
57 pub annotations: Option<Vec<String>>,
58
59 #[builder(default, setter(strip_option))]
63 #[serde(default, skip_serializing_if = "Option::is_none")]
64 pub backend_instance_id: Option<i64>,
65
66 #[builder(default, setter(strip_option))]
70 #[serde(default, skip_serializing_if = "Option::is_none")]
71 pub token_data: Option<Vec<u32>>,
72 #[serde(default, skip_serializing_if = "Option::is_none")]
75 #[builder(default, setter(strip_option))]
76 pub guided_json: Option<serde_json::Value>,
77
78 #[serde(default, skip_serializing_if = "Option::is_none")]
80 #[builder(default, setter(strip_option))]
81 pub guided_regex: Option<String>,
82
83 #[serde(default, skip_serializing_if = "Option::is_none")]
85 #[builder(default, setter(strip_option))]
86 pub guided_grammar: Option<String>,
87
88 #[serde(default, skip_serializing_if = "Option::is_none")]
90 #[builder(default, setter(strip_option))]
91 pub guided_choice: Option<Vec<String>>,
92
93 #[serde(default, skip_serializing_if = "Option::is_none")]
95 #[builder(default, setter(strip_option))]
96 pub guided_decoding_backend: Option<String>,
97
98 #[serde(default, skip_serializing_if = "Option::is_none")]
101 #[builder(default, setter(strip_option))]
102 pub max_thinking_tokens: Option<u32>,
103}
104
105impl Default for NvExt {
106 fn default() -> Self {
107 NvExt::builder().build().unwrap()
108 }
109}
110
111impl NvExt {
112 pub fn builder() -> NvExtBuilder {
113 NvExtBuilder::default()
114 }
115}
116
117fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
118 Ok(())
119}
120
121pub fn validate_top_k(top_k: i32) -> Result<(), ValidationError> {
122 if top_k == -1 || (top_k >= 1) {
123 return Ok(());
124 }
125 let mut error = ValidationError::new("top_k");
126 error.message = Some("top_k must be -1 or greater than or equal to 1".into());
127 Err(error)
128}
129
130impl NvExtBuilder {
131 pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
132 self.annotations
133 .get_or_insert_with(|| Some(vec![]))
134 .as_mut()
135 .expect("stop should always be Some(Vec)")
136 .push(annotation.into());
137 self
138 }
139}
140
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use validator::Validate;

    use super::*;

    /// An untouched builder leaves every field unset.
    #[test]
    fn test_nv_ext_builder_default() {
        let nv_ext = NvExt::builder().build().unwrap();
        assert_eq!(nv_ext.ignore_eos, None);
        assert_eq!(nv_ext.top_k, None);
        assert_eq!(nv_ext.min_p, None);
        assert_eq!(nv_ext.repetition_penalty, None);
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.use_raw_prompt, None);
        assert_eq!(nv_ext.annotations, None);
        assert_eq!(nv_ext.backend_instance_id, None);
        assert_eq!(nv_ext.token_data, None);
        assert_eq!(nv_ext.guided_json, None);
        assert_eq!(nv_ext.guided_regex, None);
        assert_eq!(nv_ext.guided_grammar, None);
        assert_eq!(nv_ext.guided_choice, None);
        assert_eq!(nv_ext.guided_decoding_backend, None);
        assert_eq!(nv_ext.max_thinking_tokens, None);
    }

    /// `Default` yields an all-unset value that passes validation.
    #[test]
    fn test_nv_ext_default_is_valid() {
        let nv_ext = NvExt::default();
        assert_eq!(nv_ext.top_k, None);
        assert!(nv_ext.validate().is_ok());
    }

    /// Values set through the builder are stored verbatim and validate cleanly.
    #[test]
    fn test_nv_ext_builder_custom() {
        let nv_ext = NvExt::builder()
            .ignore_eos(true)
            .top_k(10)
            .repetition_penalty(1.5)
            .greed_sampling(true)
            .guided_json(serde_json::json!({"type": "object"}))
            .guided_regex("^[0-9]+$".to_string())
            .guided_grammar("S -> 'a' S 'b' | 'c'".to_string())
            .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
            .guided_decoding_backend("xgrammar".to_string())
            .max_thinking_tokens(1024)
            .build()
            .unwrap();

        assert_eq!(nv_ext.ignore_eos, Some(true));
        assert_eq!(nv_ext.top_k, Some(10));
        assert_eq!(nv_ext.repetition_penalty, Some(1.5));
        assert_eq!(nv_ext.greed_sampling, Some(true));
        assert_eq!(
            nv_ext.guided_json,
            Some(serde_json::json!({"type": "object"}))
        );
        assert_eq!(nv_ext.guided_regex, Some("^[0-9]+$".to_string()));
        assert_eq!(
            nv_ext.guided_grammar,
            Some("S -> 'a' S 'b' | 'c'".to_string())
        );
        assert_eq!(
            nv_ext.guided_choice,
            Some(vec!["choice1".to_string(), "choice2".to_string()])
        );
        assert_eq!(nv_ext.guided_decoding_backend, Some("xgrammar".to_string()));
        assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
        assert!(nv_ext.validate().is_ok());
    }

    /// `add_annotation` accumulates entries across chained calls.
    #[test]
    fn test_add_annotation_accumulates() {
        let nv_ext = NvExt::builder()
            .add_annotation("first")
            .add_annotation(String::from("second"))
            .build()
            .unwrap();
        assert_eq!(
            nv_ext.annotations,
            Some(vec!["first".to_string(), "second".to_string()])
        );
    }

    proptest! {
        /// Every invalid `top_k` (anything below -1, plus 0) must be rejected.
        // The generator is explicit because an integer filter like `k > 0 && k < 1`
        // can never match, which previously left 0 untested.
        #[test]
        fn test_invalid_top_k_value(top_k in prop_oneof![i32::MIN..-1, Just(0)]) {
            let nv_ext = NvExt::builder()
                .top_k(top_k)
                .build()
                .unwrap();

            prop_assert!(
                nv_ext.validate().is_err(),
                "top_k should fail validation when it is 0 or less than -1"
            );
        }
    }

    /// Deterministic check for the single non-negative invalid value.
    #[test]
    fn test_zero_top_k_is_invalid() {
        let nv_ext = NvExt::builder().top_k(0).build().unwrap();
        assert!(nv_ext.validate().is_err());
    }

    #[test]
    fn test_valid_top_k_values() {
        for top_k in [-1, 1, 10, i32::MAX] {
            let nv_ext = NvExt::builder().top_k(top_k).build().unwrap();
            assert!(nv_ext.validate().is_ok(), "top_k = {top_k} should be valid");
        }
    }

    proptest! {
        #[test]
        fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f32..=2.0f32) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            prop_assert!(
                nv_ext.validate().is_ok(),
                "repetition_penalty should be valid within the range (0, 2]"
            );
        }
    }

    proptest! {
        #[test]
        fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f32..0.0f32) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            prop_assert!(
                nv_ext.validate().is_err(),
                "repetition_penalty should fail validation when outside the range (0, 2]"
            );
        }
    }

    /// The `(0, 2]` interval excludes 0.0 and includes 2.0.
    #[test]
    fn test_repetition_penalty_bounds() {
        let at_zero = NvExt::builder().repetition_penalty(0.0).build().unwrap();
        assert!(at_zero.validate().is_err(), "0.0 is excluded from (0, 2]");

        let at_max = NvExt::builder().repetition_penalty(2.0).build().unwrap();
        assert!(at_max.validate().is_ok(), "2.0 is included in (0, 2]");

        let above = NvExt::builder().repetition_penalty(2.5).build().unwrap();
        assert!(above.validate().is_err(), "2.5 is above (0, 2]");
    }

    /// `min_p` must lie in the closed interval `[0, 1]`.
    #[test]
    fn test_min_p_bounds() {
        for ok in [0.0f32, 0.5, 1.0] {
            let nv_ext = NvExt::builder().min_p(ok).build().unwrap();
            assert!(nv_ext.validate().is_ok(), "min_p = {ok} should be valid");
        }
        for bad in [-0.1f32, 1.5] {
            let nv_ext = NvExt::builder().min_p(bad).build().unwrap();
            assert!(nv_ext.validate().is_err(), "min_p = {bad} should be invalid");
        }
    }
}