use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::Validate;
/// Bundle of optional extension parameters (`ignore_eos`, `top_k`, `guided_*`, ...).
///
/// Every field is an `Option`:
/// - a key that is missing from the input deserializes to `None` (`serde(default)`);
/// - a `None` field is left out of the serialized output (`skip_serializing_if`);
/// - the generated `CommonExtBuilder` leaves a field at `None` unless its setter is
///   called, and setters take the inner value directly (`setter(strip_option)`).
///
/// `Validate` is derived, but no field declares a `#[validate(...)]` rule here.
//
// NOTE(review): the meaning of each parameter is not established by this file. The
// `NOTE(review)` lines below record what the names suggest and should be confirmed
// against the code that consumes these values.
#[derive(ToSchema, Serialize, Deserialize, Builder, Validate, Debug, Clone, Default)]
pub struct CommonExt {
    /// Optional boolean `ignore_eos` switch.
    // NOTE(review): presumably "keep generating past the end-of-sequence token" — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub ignore_eos: Option<bool>,

    /// Optional `min_tokens` count (unsigned 32-bit).
    // NOTE(review): presumably a lower bound on generated tokens — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub min_tokens: Option<u32>,

    /// Optional `top_k` value (signed 32-bit).
    // NOTE(review): presumably the top-k sampling cutoff; the signed type suggests a
    // negative sentinel may be accepted — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub top_k: Option<i32>,

    /// Optional `min_p` value (32-bit float).
    // NOTE(review): presumably the min-p sampling threshold; valid range not checked here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub min_p: Option<f32>,

    /// Optional `repetition_penalty` value (32-bit float).
    // NOTE(review): presumably a sampling penalty factor; valid range not checked here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub repetition_penalty: Option<f32>,

    /// Optional boolean `include_stop_str_in_output` switch.
    // NOTE(review): presumably controls whether a matched stop string is kept in the
    // returned text — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub include_stop_str_in_output: Option<bool>,

    /// Optional `guided_json` payload; any JSON value is accepted.
    // NOTE(review): presumably a JSON schema used to constrain the output — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_json: Option<serde_json::Value>,

    /// Optional `guided_regex` string.
    // NOTE(review): presumably a regular expression the output must match — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_regex: Option<String>,

    /// Optional `guided_grammar` string.
    // NOTE(review): presumably a grammar definition the output must follow — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_grammar: Option<String>,

    /// Optional `guided_choice` list of strings.
    // NOTE(review): presumably the output is restricted to one of these entries — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_choice: Option<Vec<String>>,

    /// Optional `guided_decoding_backend` string.
    // NOTE(review): presumably names the backend used for the `guided_*` options — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_decoding_backend: Option<String>,

    /// Optional `guided_whitespace_pattern` string.
    // NOTE(review): the lint suppression below has no recorded reason. Confirm whether
    // this field is actually consumed anywhere and whether the attribute can be dropped.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    #[allow(unused)]
    pub guided_whitespace_pattern: Option<String>,

    /// Optional boolean `skip_special_tokens` switch.
    // NOTE(review): presumably controls whether special tokens are dropped from the
    // decoded text — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub skip_special_tokens: Option<bool>,
}
impl CommonExt {
pub fn builder() -> CommonExtBuilder {
CommonExtBuilder::default()
}
}
/// Accessor interface for the settings that [`CommonExt`] carries.
///
/// No method has a default body, so how each value is resolved is entirely up to the
/// implementing type. All getters except [`common_ext`](Self::common_ext) return owned
/// values.
//
// NOTE(review): there are no getters for `ignore_eos` or `min_tokens` — presumably
// those are resolved through another path; confirm before adding them here.
pub trait CommonExtProvider {
    /// Returns the [`CommonExt`] exposed by this provider, or `None` if it has none.
    fn common_ext(&self) -> Option<&CommonExt>;

    // --- getters matching the `guided_*` fields of `CommonExt` ---

    /// Returns the `guided_json` value, if any.
    fn get_guided_json(&self) -> Option<serde_json::Value>;

    /// Returns the `guided_regex` string, if any.
    fn get_guided_regex(&self) -> Option<String>;

    /// Returns the `guided_grammar` string, if any.
    fn get_guided_grammar(&self) -> Option<String>;

    /// Returns the `guided_choice` list, if any.
    fn get_guided_choice(&self) -> Option<Vec<String>>;

    /// Returns the `guided_decoding_backend` string, if any.
    fn get_guided_decoding_backend(&self) -> Option<String>;

    /// Returns the `guided_whitespace_pattern` string, if any.
    // NOTE(review): the lint suppression has no recorded reason — confirm it is needed.
    #[allow(unused)]
    fn get_guided_whitespace_pattern(&self) -> Option<String>;

    // --- getters matching the remaining scalar fields of `CommonExt` ---

    /// Returns the `top_k` value, if any.
    fn get_top_k(&self) -> Option<i32>;

    /// Returns the `min_p` value, if any.
    fn get_min_p(&self) -> Option<f32>;

    /// Returns the `repetition_penalty` value, if any.
    fn get_repetition_penalty(&self) -> Option<f32>;

    /// Returns the `include_stop_str_in_output` flag, if any.
    fn get_include_stop_str_in_output(&self) -> Option<bool>;

    /// Returns the `skip_special_tokens` flag, if any.
    fn get_skip_special_tokens(&self) -> Option<bool>;
}
#[cfg(test)]
mod tests {
    // `super::*` also brings the `Validate` trait into scope for `.validate()`.
    // `serde_json` is referenced by path, so no separate `use serde_json;` is needed.
    use super::*;

    /// A builder on which no setter was called yields a value with every field unset.
    #[test]
    fn test_common_ext_builder_default() {
        let common_ext = CommonExt::builder().build().unwrap();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert_eq!(common_ext.guided_json, None);
        assert_eq!(common_ext.guided_regex, None);
        assert_eq!(common_ext.guided_grammar, None);
        assert_eq!(common_ext.guided_choice, None);
        assert_eq!(common_ext.guided_decoding_backend, None);
        assert_eq!(common_ext.guided_whitespace_pattern, None);
        assert_eq!(common_ext.skip_special_tokens, None);
    }

    /// Every setter stores its argument wrapped in `Some`.
    #[test]
    fn test_common_ext_builder_with_values() {
        let common_ext = CommonExt::builder()
            .ignore_eos(true)
            .min_tokens(10)
            .top_k(50)
            .min_p(0.1)
            .repetition_penalty(1.2)
            .include_stop_str_in_output(true)
            .guided_json(serde_json::json!({"key": "value"}))
            .guided_regex("regex".to_string())
            .guided_grammar("grammar".to_string())
            .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
            .guided_decoding_backend("backend".to_string())
            .guided_whitespace_pattern("pattern".to_string())
            .skip_special_tokens(false)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(true));
        assert_eq!(common_ext.min_tokens, Some(10));
        assert_eq!(common_ext.top_k, Some(50));
        assert_eq!(common_ext.min_p, Some(0.1));
        assert_eq!(common_ext.repetition_penalty, Some(1.2));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
        assert_eq!(
            common_ext.guided_json.as_ref(),
            Some(&serde_json::json!({"key": "value"}))
        );
        assert_eq!(common_ext.guided_regex, Some("regex".to_string()));
        assert_eq!(common_ext.guided_grammar, Some("grammar".to_string()));
        assert_eq!(
            common_ext.guided_choice,
            Some(vec!["choice1".to_string(), "choice2".to_string()])
        );
        assert_eq!(
            common_ext.guided_decoding_backend,
            Some("backend".to_string())
        );
        assert_eq!(
            common_ext.guided_whitespace_pattern,
            Some("pattern".to_string())
        );
        assert_eq!(common_ext.skip_special_tokens, Some(false));
    }

    /// Setting only a subset of fields stores exactly those values.
    #[test]
    fn test_common_ext_fields() {
        let common_ext = CommonExt::builder()
            .ignore_eos(false)
            .min_tokens(5)
            .include_stop_str_in_output(true)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(false));
        assert_eq!(common_ext.min_tokens, Some(5));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
    }

    /// `min_tokens = 0` passes validation.
    #[test]
    fn test_validation_min_tokens() {
        // Struct-update syntax keeps this test focused on the one field under test and
        // avoids touching it whenever a field is added to `CommonExt`.
        let common_ext = CommonExt {
            min_tokens: Some(0),
            ..Default::default()
        };

        assert!(common_ext.validate().is_ok());
    }

    /// A value built without any setter is unset and passes validation.
    #[test]
    fn test_common_ext_neither_specified() {
        let common_ext = CommonExt::builder().build().unwrap();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    /// `CommonExt::default()` is unset and passes validation.
    #[test]
    fn test_common_ext_default() {
        let common_ext = CommonExt::default();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    /// The `skip_special_tokens` setter stores both `true` and `false`.
    #[test]
    fn test_skip_special_tokens_field() {
        for flag in [true, false] {
            let common_ext = CommonExt::builder()
                .skip_special_tokens(flag)
                .build()
                .unwrap();
            assert_eq!(common_ext.skip_special_tokens, Some(flag));
        }
    }

    /// `skip_special_tokens` survives a JSON round trip and is omitted when unset.
    #[test]
    fn test_skip_special_tokens_serialization() {
        for flag in [true, false] {
            let common_ext = CommonExt::builder()
                .skip_special_tokens(flag)
                .build()
                .unwrap();
            let json = serde_json::to_string(&common_ext).unwrap();
            let deserialized: CommonExt = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized.skip_special_tokens, Some(flag));
        }

        // An unset field must not appear in the serialized output at all.
        let common_ext = CommonExt::builder().build().unwrap();
        let json = serde_json::to_string(&common_ext).unwrap();
        assert!(!json.contains("skip_special_tokens"));
    }
}