dynamo_llm/protocols/openai/
nvext.rs1use derive_builder::Builder;
17use serde::{Deserialize, Serialize};
18use validator::{Validate, ValidationError};
19
20pub trait NvExtProvider {
21 fn nvext(&self) -> Option<&NvExt>;
22 fn raw_prompt(&self) -> Option<String>;
23}
24
25#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
27#[validate(schema(function = "validate_nv_ext"))]
28pub struct NvExt {
29 #[serde(default, skip_serializing_if = "Option::is_none")]
31 #[builder(default, setter(strip_option))]
32 pub ignore_eos: Option<bool>,
33
34 #[builder(default, setter(strip_option))] #[validate(custom(function = "validate_top_k"))]
36 #[serde(default, skip_serializing_if = "Option::is_none")]
37 pub top_k: Option<i64>,
38
39 #[builder(default, setter(strip_option))]
42 #[validate(range(exclusive_min = 0.0, max = 2.0))]
43 pub repetition_penalty: Option<f64>,
44
45 #[serde(default, skip_serializing_if = "Option::is_none")]
49 #[builder(default, setter(strip_option))]
50 pub greed_sampling: Option<bool>,
51
52 #[serde(default, skip_serializing_if = "Option::is_none")]
55 #[builder(default, setter(strip_option))]
56 pub use_raw_prompt: Option<bool>,
57
58 #[serde(default, skip_serializing_if = "Option::is_none")]
62 #[builder(default, setter(strip_option))]
63 pub annotations: Option<Vec<String>>,
64}
65
66impl Default for NvExt {
67 fn default() -> Self {
68 NvExt::builder().build().unwrap()
69 }
70}
71
72impl NvExt {
73 pub fn builder() -> NvExtBuilder {
74 NvExtBuilder::default()
75 }
76}
77
78fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
79 Ok(())
80}
81
82fn validate_top_k(top_k: i64) -> Result<(), ValidationError> {
83 if top_k == -1 || (top_k >= 1) {
84 return Ok(());
85 }
86 let mut error = ValidationError::new("top_k");
87 error.message = Some("top_k must be -1 or greater than or equal to 1".into());
88 Err(error)
89}
90
91impl NvExtBuilder {
92 pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
93 self.annotations
94 .get_or_insert_with(|| Some(vec![]))
95 .as_mut()
96 .expect("stop should always be Some(Vec)")
97 .push(annotation.into());
98 self
99 }
100}
101
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use validator::Validate;

    use super::*;

    /// A builder on which nothing was set must produce `None` for every field.
    #[test]
    fn test_nv_ext_builder_default() {
        let nv_ext = NvExt::builder().build().unwrap();
        assert_eq!(nv_ext.ignore_eos, None);
        assert_eq!(nv_ext.top_k, None);
        assert_eq!(nv_ext.repetition_penalty, None);
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.use_raw_prompt, None);
        assert_eq!(nv_ext.annotations, None);
    }

    /// Values passed to the builder setters must come back unchanged and validate.
    #[test]
    fn test_nv_ext_builder_custom() {
        let nv_ext = NvExt::builder()
            .ignore_eos(true)
            .top_k(10)
            .repetition_penalty(1.5)
            .greed_sampling(true)
            .build()
            .unwrap();

        assert_eq!(nv_ext.ignore_eos, Some(true));
        assert_eq!(nv_ext.top_k, Some(10));
        assert_eq!(nv_ext.repetition_penalty, Some(1.5));
        assert_eq!(nv_ext.greed_sampling, Some(true));

        assert!(nv_ext.validate().is_ok());
    }

    /// `add_annotation` must create the list on first use and preserve insertion order.
    #[test]
    fn test_add_annotation_accumulates() {
        let nv_ext = NvExt::builder()
            .add_annotation("first")
            .add_annotation("second")
            .build()
            .unwrap();

        assert_eq!(
            nv_ext.annotations,
            Some(vec!["first".to_string(), "second".to_string()])
        );
    }

    proptest! {
        // Invalid integers are exactly 0 and everything below -1. Zero is generated
        // explicitly because a uniformly random i64 would practically never hit it.
        #[test]
        fn test_invalid_top_k_value(top_k in prop_oneof![Just(0i64), i64::MIN..-1i64]) {
            let nv_ext = NvExt::builder()
                .top_k(top_k)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            assert!(validation_result.is_err(), "top_k should fail validation if it is 0 or less than -1");
        }
    }

    /// The sentinel `-1`, the lower bound `1`, and a typical value must all validate.
    #[test]
    fn test_valid_top_k_values() {
        let nv_ext = NvExt::builder().top_k(-1).build().unwrap();
        assert!(nv_ext.validate().is_ok());

        let nv_ext = NvExt::builder().top_k(1).build().unwrap();
        assert!(nv_ext.validate().is_ok());

        let nv_ext = NvExt::builder().top_k(10).build().unwrap();
        assert!(nv_ext.validate().is_ok());
    }

    proptest! {
        #[test]
        fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f64..=2.0f64) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            assert!(validation_result.is_ok(), "repetition_penalty should be valid within the range (0, 2]");
        }
    }

    proptest! {
        #[test]
        fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f64..0.0f64) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            assert!(validation_result.is_err(), "repetition_penalty should fail validation when outside the range (0, 2]");
        }
    }

    /// The lower bound is exclusive and the upper bound inclusive, so `0.0` and anything
    /// above `2.0` must be rejected.
    #[test]
    fn test_repetition_penalty_boundaries() {
        let nv_ext = NvExt::builder().repetition_penalty(0.0).build().unwrap();
        assert!(nv_ext.validate().is_err());

        let nv_ext = NvExt::builder().repetition_penalty(2.5).build().unwrap();
        assert!(nv_ext.validate().is_err());
    }
}