use crate::nn::sampling::SamplingConfig;
#[derive(Debug, Clone)]
pub struct GenerateConfig {
pub do_sample: bool,
pub top_k: usize,
pub top_p: f64,
pub temperature: f64,
pub repetition_penalty: f64,
pub subtalker_do_sample: bool,
pub subtalker_top_k: usize,
pub subtalker_top_p: f64,
pub subtalker_temperature: f64,
pub max_new_tokens: usize,
}
impl Default for GenerateConfig {
fn default() -> Self {
Self {
do_sample: true,
top_k: 50,
top_p: 1.0,
temperature: 0.9,
repetition_penalty: 1.05,
subtalker_do_sample: true,
subtalker_top_k: 50,
subtalker_top_p: 1.0,
subtalker_temperature: 0.9,
max_new_tokens: 2048,
}
}
}
impl GenerateConfig {
pub fn from_json(json: &str) -> std::result::Result<Self, serde_json::Error> {
#[derive(serde::Deserialize)]
struct RawConfig {
do_sample: Option<bool>,
top_k: Option<usize>,
top_p: Option<f64>,
temperature: Option<f64>,
repetition_penalty: Option<f64>,
subtalker_dosample: Option<bool>,
subtalker_top_k: Option<usize>,
subtalker_top_p: Option<f64>,
subtalker_temperature: Option<f64>,
max_new_tokens: Option<usize>,
}
let raw: RawConfig = serde_json::from_str(json)?;
let defaults = Self::default();
Ok(Self {
do_sample: raw.do_sample.unwrap_or(defaults.do_sample),
top_k: raw.top_k.unwrap_or(defaults.top_k),
top_p: raw.top_p.unwrap_or(defaults.top_p),
temperature: raw.temperature.unwrap_or(defaults.temperature),
repetition_penalty: raw
.repetition_penalty
.unwrap_or(defaults.repetition_penalty),
subtalker_do_sample: raw
.subtalker_dosample
.unwrap_or(defaults.subtalker_do_sample),
subtalker_top_k: raw.subtalker_top_k.unwrap_or(defaults.subtalker_top_k),
subtalker_top_p: raw.subtalker_top_p.unwrap_or(defaults.subtalker_top_p),
subtalker_temperature: raw
.subtalker_temperature
.unwrap_or(defaults.subtalker_temperature),
max_new_tokens: raw.max_new_tokens.unwrap_or(defaults.max_new_tokens),
})
}
pub fn from_file(path: impl AsRef<std::path::Path>) -> Self {
let path = path.as_ref();
match std::fs::read_to_string(path) {
Ok(content) => Self::from_json(&content).unwrap_or_else(|e| {
tracing::warn!(
"Failed to parse generate_config.json at {}: {}, using defaults",
path.display(),
e
);
Self::default()
}),
Err(e) => {
tracing::debug!(
"No generate_config.json at {}: {}, using defaults",
path.display(),
e
);
Self::default()
}
}
}
pub fn from_model_dir(model_dir: impl AsRef<std::path::Path>) -> Self {
let config_path = model_dir.as_ref().join("generate_config.json");
Self::from_file(config_path)
}
#[allow(clippy::too_many_arguments)]
pub fn merge(
&self,
do_sample: Option<bool>,
top_k: Option<usize>,
top_p: Option<f64>,
temperature: Option<f64>,
repetition_penalty: Option<f64>,
subtalker_do_sample: Option<bool>,
subtalker_top_k: Option<usize>,
subtalker_top_p: Option<f64>,
subtalker_temperature: Option<f64>,
) -> SamplingConfig {
SamplingConfig {
do_sample: do_sample.unwrap_or(self.do_sample),
top_k: top_k.unwrap_or(self.top_k),
top_p: top_p.unwrap_or(self.top_p),
temperature: temperature.unwrap_or(self.temperature),
repetition_penalty: repetition_penalty.unwrap_or(self.repetition_penalty),
subtalker_do_sample: subtalker_do_sample.unwrap_or(self.subtalker_do_sample),
subtalker_top_k: subtalker_top_k.unwrap_or(self.subtalker_top_k),
subtalker_top_p: subtalker_top_p.unwrap_or(self.subtalker_top_p),
subtalker_temperature: subtalker_temperature.unwrap_or(self.subtalker_temperature),
..Default::default()
}
}
pub fn effective_max_tokens(&self, user_max: Option<usize>) -> usize {
user_max.unwrap_or(self.max_new_tokens)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_config_default() {
let config = GenerateConfig::default();
assert!(config.do_sample);
assert_eq!(config.top_k, 50);
assert_eq!(config.top_p, 1.0);
assert_eq!(config.temperature, 0.9);
assert_eq!(config.repetition_penalty, 1.05);
assert!(config.subtalker_do_sample);
assert_eq!(config.max_new_tokens, 2048);
}
#[test]
fn test_generate_config_from_json_full() {
let json = r#"{
"do_sample": false,
"top_k": 100,
"top_p": 0.95,
"temperature": 0.7,
"repetition_penalty": 1.2,
"subtalker_dosample": false,
"subtalker_top_k": 30,
"subtalker_top_p": 0.8,
"subtalker_temperature": 0.5,
"max_new_tokens": 4096
}"#;
let config = GenerateConfig::from_json(json).unwrap();
assert!(!config.do_sample);
assert_eq!(config.top_k, 100);
assert_eq!(config.top_p, 0.95);
assert_eq!(config.temperature, 0.7);
assert_eq!(config.repetition_penalty, 1.2);
assert!(!config.subtalker_do_sample);
assert_eq!(config.subtalker_top_k, 30);
assert_eq!(config.subtalker_top_p, 0.8);
assert_eq!(config.subtalker_temperature, 0.5);
assert_eq!(config.max_new_tokens, 4096);
}
#[test]
fn test_generate_config_from_json_partial() {
let json = r#"{
"temperature": 0.5,
"top_k": 25
}"#;
let config = GenerateConfig::from_json(json).unwrap();
assert!(config.do_sample); assert_eq!(config.top_k, 25); assert_eq!(config.top_p, 1.0); assert_eq!(config.temperature, 0.5); assert_eq!(config.max_new_tokens, 2048); }
#[test]
fn test_generate_config_from_json_empty() {
let json = "{}";
let config = GenerateConfig::from_json(json).unwrap();
let defaults = GenerateConfig::default();
assert_eq!(config.do_sample, defaults.do_sample);
assert_eq!(config.top_k, defaults.top_k);
assert_eq!(config.temperature, defaults.temperature);
}
#[test]
fn test_generate_config_merge_all_none() {
let config = GenerateConfig::default();
let merged = config.merge(None, None, None, None, None, None, None, None, None);
assert_eq!(merged.do_sample, config.do_sample);
assert_eq!(merged.top_k, config.top_k);
assert_eq!(merged.temperature, config.temperature);
}
#[test]
fn test_generate_config_merge_with_overrides() {
let config = GenerateConfig::default();
let merged = config.merge(
Some(false), Some(100), Some(0.8), Some(0.5), Some(1.5), Some(false), Some(25), Some(0.9), Some(0.3), );
assert!(!merged.do_sample);
assert_eq!(merged.top_k, 100);
assert_eq!(merged.top_p, 0.8);
assert_eq!(merged.temperature, 0.5);
assert_eq!(merged.repetition_penalty, 1.5);
assert!(!merged.subtalker_do_sample);
assert_eq!(merged.subtalker_top_k, 25);
assert_eq!(merged.subtalker_top_p, 0.9);
assert_eq!(merged.subtalker_temperature, 0.3);
}
#[test]
fn test_effective_max_tokens() {
let config = GenerateConfig::default();
assert_eq!(config.effective_max_tokens(Some(1000)), 1000);
assert_eq!(config.effective_max_tokens(None), 2048);
}
}