dynamo_llm/protocols/openai/
nvext.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use derive_builder::Builder;
17use serde::{Deserialize, Serialize};
18use validator::{Validate, ValidationError};
19
20pub trait NvExtProvider {
21    fn nvext(&self) -> Option<&NvExt>;
22    fn raw_prompt(&self) -> Option<String>;
23}
24
25/// NVIDIA LLM extensions to the OpenAI API
26#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
27#[validate(schema(function = "validate_nv_ext"))]
28pub struct NvExt {
29    /// If true, the model will ignore the end of string token and generate to max_tokens.
30    #[serde(default, skip_serializing_if = "Option::is_none")]
31    #[builder(default, setter(strip_option))]
32    pub ignore_eos: Option<bool>,
33
34    #[builder(default, setter(strip_option))] // NIM LLM might default to -1
35    #[validate(custom(function = "validate_top_k"))]
36    #[serde(default, skip_serializing_if = "Option::is_none")]
37    pub top_k: Option<i64>,
38
39    /// How much to penalize tokens based on how frequently they occur in the text.
40    /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
41    #[builder(default, setter(strip_option))]
42    #[validate(range(exclusive_min = 0.0, max = 2.0))]
43    pub repetition_penalty: Option<f64>,
44
45    /// If true, sampling will be forced to be greedy.
46    /// The backend is responsible for selecting the correct backend-specific options to
47    /// implement this.
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    #[builder(default, setter(strip_option))]
50    pub greed_sampling: Option<bool>,
51
52    /// If true, the preproessor will try to bypass the prompt template and pass the prompt directly to
53    /// to the tokenizer.
54    #[serde(default, skip_serializing_if = "Option::is_none")]
55    #[builder(default, setter(strip_option))]
56    pub use_raw_prompt: Option<bool>,
57
58    /// Annotations
59    /// User requests triggers which result in the request issue back out-of-band information in the SSE
60    /// stream using the `event:` field.
61    #[serde(default, skip_serializing_if = "Option::is_none")]
62    #[builder(default, setter(strip_option))]
63    pub annotations: Option<Vec<String>>,
64}
65
66impl Default for NvExt {
67    fn default() -> Self {
68        NvExt::builder().build().unwrap()
69    }
70}
71
72impl NvExt {
73    pub fn builder() -> NvExtBuilder {
74        NvExtBuilder::default()
75    }
76}
77
78fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
79    Ok(())
80}
81
82fn validate_top_k(top_k: i64) -> Result<(), ValidationError> {
83    if top_k == -1 || (top_k >= 1) {
84        return Ok(());
85    }
86    let mut error = ValidationError::new("top_k");
87    error.message = Some("top_k must be -1 or greater than or equal to 1".into());
88    Err(error)
89}
90
91impl NvExtBuilder {
92    pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
93        self.annotations
94            .get_or_insert_with(|| Some(vec![]))
95            .as_mut()
96            .expect("stop should always be Some(Vec)")
97            .push(annotation.into());
98        self
99    }
100}
101
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use validator::Validate;

    use super::*;

    // A builder with nothing set leaves every field as `None`.
    #[test]
    fn test_nv_ext_builder_default() {
        let nv_ext = NvExt::builder().build().unwrap();
        assert_eq!(nv_ext.ignore_eos, None);
        assert_eq!(nv_ext.top_k, None);
        assert_eq!(nv_ext.repetition_penalty, None);
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.use_raw_prompt, None);
        assert_eq!(nv_ext.annotations, None);
    }

    // Values passed to the setters come back out of the built struct.
    #[test]
    fn test_nv_ext_builder_custom() {
        let nv_ext = NvExt::builder()
            .ignore_eos(true)
            .top_k(10)
            .repetition_penalty(1.5)
            .greed_sampling(true)
            .build()
            .unwrap();

        assert_eq!(nv_ext.ignore_eos, Some(true));
        assert_eq!(nv_ext.top_k, Some(10));
        assert_eq!(nv_ext.repetition_penalty, Some(1.5));
        assert_eq!(nv_ext.greed_sampling, Some(true));

        // Validate the built struct
        assert!(nv_ext.validate().is_ok());
    }

    // `add_annotation` accumulates entries in call order.
    #[test]
    fn test_add_annotation_accumulates() {
        let nv_ext = NvExt::builder()
            .add_annotation("first")
            .add_annotation("second")
            .build()
            .unwrap();

        assert_eq!(
            nv_ext.annotations,
            Some(vec!["first".to_string(), "second".to_string()])
        );
    }

    // Invalid `top_k` values: everything below -1, plus 0. For an integer these are the
    // only values that are neither -1 nor >= 1.
    proptest! {
        #[test]
        fn test_invalid_top_k_value(top_k in any::<i64>().prop_filter("Invalid top_k", |&k| k < -1 || k == 0)) {
            let nv_ext = NvExt::builder()
                .top_k(top_k)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            prop_assert!(validation_result.is_err(), "top_k should fail validation if it is 0 or less than -1");
        }
    }

    // 0 is the one invalid value a random i64 is unlikely to hit, so pin it explicitly.
    #[test]
    fn test_zero_top_k_is_invalid() {
        let nv_ext = NvExt::builder().top_k(0).build().unwrap();
        assert!(nv_ext.validate().is_err());
    }

    // Test valid `top_k` values
    #[test]
    fn test_valid_top_k_values() {
        let nv_ext = NvExt::builder().top_k(-1).build().unwrap();
        assert!(nv_ext.validate().is_ok());

        let nv_ext = NvExt::builder().top_k(1).build().unwrap();
        assert!(nv_ext.validate().is_ok());

        let nv_ext = NvExt::builder().top_k(10).build().unwrap();
        assert!(nv_ext.validate().is_ok());
    }

    // Test valid repetition_penalty values
    proptest! {
        #[test]
        fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f64..=2.0f64) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            prop_assert!(validation_result.is_ok(), "repetition_penalty should be valid within the range (0, 2]");
        }
    }

    // Test invalid repetition_penalty values below the allowed range
    proptest! {
        #[test]
        fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f64..0.0f64) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            prop_assert!(validation_result.is_err(), "repetition_penalty should fail validation when outside the range (0, 2]");
        }
    }

    // Test invalid repetition_penalty values above the allowed range
    proptest! {
        #[test]
        fn test_repetition_penalty_above_max_is_invalid(repetition_penalty in 2.01f64..=10.0f64) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();

            let validation_result = nv_ext.validate();
            prop_assert!(validation_result.is_err(), "repetition_penalty should fail validation when outside the range (0, 2]");
        }
    }

    // Boundaries of (0, 2]: the lower bound is exclusive, the upper bound inclusive.
    #[test]
    fn test_repetition_penalty_boundaries() {
        let at_zero = NvExt::builder().repetition_penalty(0.0).build().unwrap();
        assert!(at_zero.validate().is_err());

        let at_max = NvExt::builder().repetition_penalty(2.0).build().unwrap();
        assert!(at_max.validate().is_ok());
    }
}