use serde::{Deserialize, Deserializer};
/// Rope-scaling configuration values.
///
/// Every `Option` field is `None` when the corresponding key is missing from
/// the deserialized input. Deserialization goes through the hand-written
/// `Deserialize` impl in this file, which accepts the scaling kind under
/// either a `rope_type` or a `type` key.
#[derive(Debug, Clone)]
pub struct RopeScaling {
/// Name of the scaling kind. On deserialization this is the `rope_type`
/// key if present, otherwise the `type` key. `Default` sets it to
/// `"default"`.
pub rope_type: Option<String>,
/// Optional `factor` value.
/// NOTE(review): presumably the scaling multiplier; confirm against the consumer.
pub factor: Option<f64>,
/// Optional `original_max_position_embeddings` value.
/// NOTE(review): presumably the pre-scaling context length; confirm.
pub original_max_position_embeddings: Option<usize>,
/// Optional `attention_factor` value.
pub attention_factor: Option<f64>,
/// Optional `beta_fast` value.
/// NOTE(review): `beta_fast` / `beta_slow` look like parameters of one
/// specific scaling kind; which one is not visible here, so confirm.
pub beta_fast: Option<f64>,
/// Optional `beta_slow` value (see note on `beta_fast`).
pub beta_slow: Option<f64>,
/// Optional list of `short_factor` values.
/// NOTE(review): expected length and meaning are not visible here; confirm.
pub short_factor: Option<Vec<f64>>,
/// Optional list of `long_factor` values (see note on `short_factor`).
pub long_factor: Option<Vec<f64>>,
/// Optional `low_freq_factor` value.
pub low_freq_factor: Option<f64>,
/// Optional `high_freq_factor` value.
pub high_freq_factor: Option<f64>,
/// `mrope_section` entries. Empty when the key is missing from deserialized
/// input, whereas `Default` uses `[16, 24, 24]`.
pub mrope_section: Vec<usize>,
/// `interleaved` flag. `false` both when the key is missing from
/// deserialized input and in `Default`.
pub interleaved: bool,
}
/// Private derive target used by the `Deserialize` impl for `RopeScaling`.
///
/// It mirrors the fields of `RopeScaling` but keeps the `type` and
/// `rope_type` keys as two separate optional fields, so that input containing
/// both keys is accepted and the impl can choose between them afterwards.
#[derive(Deserialize)]
struct RopeScalingHelper {
/// Value of the `type` key. The field cannot simply be named `type`
/// because that is a reserved word in Rust, hence the rename.
#[serde(rename = "type")]
type_field: Option<String>,
/// Value of the `rope_type` key; preferred over `type_field` when both are set.
rope_type: Option<String>,
factor: Option<f64>,
original_max_position_embeddings: Option<usize>,
attention_factor: Option<f64>,
beta_fast: Option<f64>,
beta_slow: Option<f64>,
short_factor: Option<Vec<f64>>,
long_factor: Option<Vec<f64>>,
low_freq_factor: Option<f64>,
high_freq_factor: Option<f64>,
/// Falls back to an empty vector when the key is missing.
#[serde(default)]
mrope_section: Vec<usize>,
/// Falls back to `false` when the key is missing.
#[serde(default)]
interleaved: bool,
}
impl<'de> Deserialize<'de> for RopeScaling {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let helper = RopeScalingHelper::deserialize(deserializer)?;
let rope_type = helper.rope_type.or(helper.type_field);
Ok(RopeScaling {
rope_type,
factor: helper.factor,
original_max_position_embeddings: helper.original_max_position_embeddings,
attention_factor: helper.attention_factor,
beta_fast: helper.beta_fast,
beta_slow: helper.beta_slow,
short_factor: helper.short_factor,
long_factor: helper.long_factor,
low_freq_factor: helper.low_freq_factor,
high_freq_factor: helper.high_freq_factor,
mrope_section: helper.mrope_section,
interleaved: helper.interleaved,
})
}
}
impl Default for RopeScaling {
fn default() -> Self {
Self {
rope_type: Some("default".to_string()),
factor: None,
original_max_position_embeddings: None,
attention_factor: None,
beta_fast: None,
beta_slow: None,
short_factor: None,
long_factor: None,
low_freq_factor: None,
high_freq_factor: None,
mrope_section: vec![16, 24, 24],
interleaved: false,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `json` into a `RopeScaling`, panicking if the input is rejected.
    fn parse(json: &str) -> RopeScaling {
        serde_json::from_str(json).unwrap()
    }

    /// Only the `rope_type` key is present.
    #[test]
    fn test_deserialize_with_rope_type_only() {
        let parsed = parse(r#"{"rope_type": "linear", "factor": 2.0}"#);
        assert_eq!(parsed.rope_type.as_deref(), Some("linear"));
        assert_eq!(parsed.factor, Some(2.0));
    }

    /// Only the `type` key is present; it must populate `rope_type`.
    #[test]
    fn test_deserialize_with_type_only() {
        let parsed = parse(r#"{"type": "dynamic", "factor": 1.5}"#);
        assert_eq!(parsed.rope_type.as_deref(), Some("dynamic"));
        assert_eq!(parsed.factor, Some(1.5));
    }

    /// Both keys are present, `type` first; `rope_type` must win.
    #[test]
    fn test_deserialize_with_both_type_and_rope_type() {
        let parsed = parse(r#"{"type": "old_value", "rope_type": "linear", "factor": 2.0}"#);
        assert_eq!(parsed.rope_type.as_deref(), Some("linear"));
        assert_eq!(parsed.factor, Some(2.0));
    }

    /// Both keys are present, `rope_type` first; key order must not matter.
    #[test]
    fn test_deserialize_with_both_type_and_rope_type_reversed_order() {
        let parsed = parse(r#"{"rope_type": "yarn", "type": "ignored", "factor": 3.0}"#);
        assert_eq!(parsed.rope_type.as_deref(), Some("yarn"));
        assert_eq!(parsed.factor, Some(3.0));
    }

    /// `mrope_section` and `interleaved` are read when supplied.
    #[test]
    fn test_deserialize_with_mrope_section() {
        let json = r#"{
"rope_type": "mrope",
"mrope_section": [16, 24, 24],
"interleaved": true
}"#;
        let parsed = parse(json);
        assert_eq!(parsed.rope_type.as_deref(), Some("mrope"));
        assert_eq!(parsed.mrope_section, [16, 24, 24]);
        assert!(parsed.interleaved);
    }

    /// An empty object leaves the type unset and uses the serde defaults
    /// (empty section list, `interleaved` off).
    #[test]
    fn test_deserialize_empty() {
        let parsed = parse(r#"{}"#);
        assert!(parsed.rope_type.is_none());
        assert!(parsed.mrope_section.is_empty());
        assert!(!parsed.interleaved);
    }
}