use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};
/// Implemented by types that may carry [`NvExt`] extension fields and a raw prompt.
pub trait NvExtProvider {
    /// Returns an owned copy of the raw prompt, or `None` if the implementor has none.
    fn raw_prompt(&self) -> Option<String>;

    /// Borrows the attached [`NvExt`] block, or `None` when no extension is present.
    fn nvext(&self) -> Option<&NvExt>;
}
/// Extension parameters that can be attached to a request and exposed through
/// [`NvExtProvider::nvext`].
///
/// Every field is optional. Unset (`None`) fields are omitted when serialized and
/// deserialize to `None` when absent from the input. Field-level rules are checked by
/// [`Validate::validate`]:
/// - `top_k` must be `-1` or `>= 1`;
/// - `repetition_penalty` must lie in the half-open interval `(0, 2]`.
///
/// Construct instances with [`NvExt::builder`] or [`NvExt::default`].
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
pub struct NvExt {
    /// Whether generation should ignore the end-of-sequence token.
    /// NOTE(review): the exact effect is defined by whoever consumes this struct.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub ignore_eos: Option<bool>,

    /// Top-k sampling cutoff; must be `-1` or any value `>= 1` (checked by `validate_top_k`).
    /// Presumably `-1` means "no cutoff" — confirm with the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    #[validate(custom(function = "validate_top_k"))]
    pub top_k: Option<i64>,

    /// Repetition penalty, restricted to `(0, 2]`: `0` itself is rejected, `2` is allowed.
    // Same serde attributes as the sibling fields so an unset penalty is omitted from the
    // serialized output instead of being emitted as `null`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    #[validate(range(exclusive_min = 0.0, max = 2.0))]
    pub repetition_penalty: Option<f64>,

    /// Requests greedy sampling. The field name (and therefore the serialized key) is
    /// `greed_sampling`; renaming it would change the wire format.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub greed_sampling: Option<bool>,

    /// Whether the raw prompt should be used.
    /// NOTE(review): presumably pairs with [`NvExtProvider::raw_prompt`] — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub use_raw_prompt: Option<bool>,

    /// Free-form annotation strings. `NvExtBuilder::add_annotation` appends to this list
    /// one entry at a time.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub annotations: Option<Vec<String>>,
}
impl Default for NvExt {
fn default() -> Self {
NvExt::builder().build().unwrap()
}
}
impl NvExt {
pub fn builder() -> NvExtBuilder {
NvExtBuilder::default()
}
}
/// Struct-level validation hook registered on [`NvExt`] via `#[validate(schema(...))]`.
///
/// It currently accepts every value; no cross-field rules are enforced here. Per-field
/// rules live in the field attributes.
fn validate_nv_ext(_: &NvExt) -> Result<(), ValidationError> {
    Ok(())
}
/// Custom validator for `NvExt::top_k`.
///
/// Accepts `-1` and every value `>= 1`. Rejects `0` and anything below `-1` with a
/// `ValidationError` whose code is `"top_k"` and which carries a human-readable message.
fn validate_top_k(top_k: i64) -> Result<(), ValidationError> {
    match top_k {
        -1 | 1.. => Ok(()),
        _ => {
            let mut rejection = ValidationError::new("top_k");
            rejection.message = Some("top_k must be -1 or greater than or equal to 1".into());
            Err(rejection)
        }
    }
}
impl NvExtBuilder {
    /// Appends one annotation to the builder's `annotations` list, creating the list if it
    /// has not been set yet.
    ///
    /// Calls can be chained. If the generated `annotations(vec![..])` setter was called
    /// first, the new entry is appended to that vector rather than replacing it.
    pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
        // The builder slot is `Option<Option<Vec<String>>>`: the outer `Option` is the
        // builder's own "field set?" marker, the inner one is `NvExt::annotations`.
        // Filling both layers with `get_or_insert_with` (rather than `expect`-ing the inner
        // one) means a `Some(None)` state is recovered instead of panicking.
        self.annotations
            .get_or_insert_with(|| Some(Vec::new()))
            .get_or_insert_with(Vec::new)
            .push(annotation.into());
        self
    }
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use validator::Validate;

    use super::*;

    /// An empty builder yields an all-`None` `NvExt` that passes validation.
    #[test]
    fn test_nv_ext_builder_default() {
        let nv_ext = NvExt::builder().build().unwrap();
        assert_eq!(nv_ext.ignore_eos, None);
        assert_eq!(nv_ext.top_k, None);
        assert_eq!(nv_ext.repetition_penalty, None);
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.use_raw_prompt, None);
        assert_eq!(nv_ext.annotations, None);
        assert!(nv_ext.validate().is_ok());
    }

    /// `NvExt::default()` matches an empty builder.
    #[test]
    fn test_nv_ext_default_is_empty() {
        let nv_ext = NvExt::default();
        assert_eq!(nv_ext.ignore_eos, None);
        assert_eq!(nv_ext.top_k, None);
        assert_eq!(nv_ext.repetition_penalty, None);
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.use_raw_prompt, None);
        assert_eq!(nv_ext.annotations, None);
    }

    /// Every generated setter stores its value wrapped in `Some`.
    #[test]
    fn test_nv_ext_builder_custom() {
        let nv_ext = NvExt::builder()
            .ignore_eos(true)
            .top_k(10)
            .repetition_penalty(1.5)
            .greed_sampling(true)
            .use_raw_prompt(false)
            .build()
            .unwrap();
        assert_eq!(nv_ext.ignore_eos, Some(true));
        assert_eq!(nv_ext.top_k, Some(10));
        assert_eq!(nv_ext.repetition_penalty, Some(1.5));
        assert_eq!(nv_ext.greed_sampling, Some(true));
        assert_eq!(nv_ext.use_raw_prompt, Some(false));
        assert!(nv_ext.validate().is_ok());
    }

    /// `add_annotation` accumulates entries in call order.
    #[test]
    fn test_add_annotation_appends_in_order() {
        let nv_ext = NvExt::builder()
            .add_annotation("first")
            .add_annotation(String::from("second"))
            .build()
            .unwrap();
        assert_eq!(
            nv_ext.annotations,
            Some(vec!["first".to_string(), "second".to_string()])
        );
    }

    /// `add_annotation` extends a list previously set through the generated setter.
    #[test]
    fn test_add_annotation_extends_explicit_annotations() {
        let nv_ext = NvExt::builder()
            .annotations(vec!["preset".to_string()])
            .add_annotation("extra")
            .build()
            .unwrap();
        assert_eq!(
            nv_ext.annotations,
            Some(vec!["preset".to_string(), "extra".to_string()])
        );
    }

    #[test]
    fn test_valid_top_k_values() {
        for &top_k in [-1i64, 1, 10, i64::MAX].iter() {
            let nv_ext = NvExt::builder().top_k(top_k).build().unwrap();
            assert!(nv_ext.validate().is_ok(), "top_k = {} should be valid", top_k);
        }
    }

    /// `0` sits between the two allowed regions and must be rejected.
    #[test]
    fn test_top_k_zero_is_invalid() {
        let nv_ext = NvExt::builder().top_k(0).build().unwrap();
        assert!(nv_ext.validate().is_err());
    }

    /// `exclusive_min = 0.0` rejects exactly zero; `max = 2.0` is inclusive.
    #[test]
    fn test_repetition_penalty_boundaries() {
        let at_zero = NvExt::builder().repetition_penalty(0.0).build().unwrap();
        assert!(at_zero.validate().is_err());

        let tiny = NvExt::builder()
            .repetition_penalty(f64::MIN_POSITIVE)
            .build()
            .unwrap();
        assert!(tiny.validate().is_ok());

        let at_max = NvExt::builder().repetition_penalty(2.0).build().unwrap();
        assert!(at_max.validate().is_ok());
    }

    proptest! {
        // Invalid integers are exactly `0` and everything below `-1`.
        #[test]
        fn test_invalid_top_k_value(top_k in prop_oneof![Just(0i64), i64::MIN..-1i64]) {
            let nv_ext = NvExt::builder().top_k(top_k).build().unwrap();
            prop_assert!(
                nv_ext.validate().is_err(),
                "top_k = {} should fail validation (allowed: -1 or >= 1)",
                top_k
            );
        }

        #[test]
        fn test_valid_positive_top_k_values(top_k in 1i64..=i64::MAX) {
            let nv_ext = NvExt::builder().top_k(top_k).build().unwrap();
            prop_assert!(nv_ext.validate().is_ok(), "top_k = {} should be valid", top_k);
        }

        #[test]
        fn test_valid_repetition_penalty_values(
            repetition_penalty in (0.0f64..=2.0f64)
                .prop_filter("must be strictly positive", |&p| p > 0.0)
        ) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();
            prop_assert!(
                nv_ext.validate().is_ok(),
                "repetition_penalty = {} should be valid within (0, 2]",
                repetition_penalty
            );
        }

        // Covers both sides of the interval: `<= 0` and `> 2`.
        #[test]
        fn test_invalid_repetition_penalty_values(
            repetition_penalty in prop_oneof![
                -10.0f64..=0.0f64,
                (2.0f64..10.0f64).prop_filter("must exceed the maximum of 2", |&p| p > 2.0)
            ]
        ) {
            let nv_ext = NvExt::builder()
                .repetition_penalty(repetition_penalty)
                .build()
                .unwrap();
            prop_assert!(
                nv_ext.validate().is_err(),
                "repetition_penalty = {} should fail validation outside (0, 2]",
                repetition_penalty
            );
        }
    }
}