use std::fmt;
use serde_json::Value;
use crate::config::{get_bool_or, get_f64_or, get_usize, get_usize_or};
use crate::error::{MIError, Result};
/// HF `model_type` values accepted by [`RwkvConfig::from_hf_config`].
pub const SUPPORTED_RWKV_MODEL_TYPES: &[&str] = &["rwkv6", "rwkv7"];
/// RWKV architecture generation, selected by the HF `model_type` field.
///
/// Marked `#[non_exhaustive]` so future generations can be added without a
/// breaking change for downstream matchers.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RwkvVersion {
    /// RWKV-6, parsed from `model_type = "rwkv6"`.
    V6,
    /// RWKV-7, parsed from `model_type = "rwkv7"`.
    V7,
}
impl fmt::Display for RwkvVersion {
    /// Writes the human-readable generation name plus its codename.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::V6 => "RWKV-6 (Finch)",
            Self::V7 => "RWKV-7 (Goose)",
        };
        f.write_str(label)
    }
}
/// Low-rank projection dimensions used inside the RWKV time-mix blocks.
///
/// The two parsers populate disjoint subsets: RWKV-6 sets the `time_*` pair
/// and zeroes the `*_low_rank_dim` fields; RWKV-7 does the opposite. A value
/// of 0 therefore means "not used by this architecture version".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RwkvLoraDims {
    // RWKV-6 pair (fixed to 32/64 by the parser; 0 for RWKV-7).
    pub time_mix_extra_dim: usize,
    pub time_decay_extra_dim: usize,
    // RWKV-7 quartet (read from the config keys of the same names, with
    // hidden-size-derived defaults; 0 for RWKV-6).
    pub decay_low_rank_dim: usize,
    pub a_low_rank_dim: usize,
    pub v_low_rank_dim: usize,
    pub gate_low_rank_dim: usize,
}
/// Normalized RWKV hyperparameters built from a HuggingFace `config.json`
/// via [`RwkvConfig::from_hf_config`].
#[derive(Debug, Clone)]
pub struct RwkvConfig {
    /// Architecture generation (determines which optional fields are `Some`).
    pub version: RwkvVersion,
    /// Embedding/model width (`hidden_size`).
    pub hidden_size: usize,
    /// Number of layers (`num_hidden_layers`).
    pub num_layers: usize,
    /// Per-head size. For RWKV-6 this is read from the `num_attention_heads`
    /// key; for RWKV-7 from `head_dim` (default 64).
    pub head_dim: usize,
    /// Head count, derived as `hidden_size / head_dim`.
    pub num_heads: usize,
    /// Token vocabulary size (`vocab_size`).
    pub vocab_size: usize,
    /// Layer-norm epsilon (`layer_norm_epsilon` for v6, `norm_eps` for v7;
    /// default 1e-5 in both).
    pub norm_eps: f64,
    /// Feed-forward width (`intermediate_size`, or a hidden-size-derived
    /// default rounded to a multiple of 32).
    pub intermediate_size: usize,
    /// RWKV-6 only (`rescale_every`, default 6); `None` for RWKV-7.
    pub rescale_every: Option<usize>,
    /// RWKV-6 only (`head_size_divisor`, default 8); its square scales the
    /// group-norm epsilon — see [`RwkvConfig::group_norm_eps`]. `None` for v7.
    pub head_size_divisor: Option<usize>,
    /// Version-specific low-rank dimensions (see [`RwkvLoraDims`]).
    pub lora_dims: RwkvLoraDims,
    /// RWKV-7 only: FFN width ratio used for the `intermediate_size`
    /// default (default 4.0); `None` for RWKV-6.
    pub hidden_ratio: Option<f64>,
    /// Whether the embedding and LM-head weights are tied
    /// (`tie_word_embeddings`, default false).
    pub tie_word_embeddings: bool,
}
impl RwkvConfig {
pub fn from_hf_config(config: &Value) -> Result<Self> {
let model_type = config
.get("model_type")
.and_then(Value::as_str)
.ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;
match model_type {
"rwkv6" => Self::parse_rwkv6(config),
"rwkv7" => Self::parse_rwkv7(config),
other => Err(MIError::Config(format!(
"unsupported RWKV model_type: '{other}'"
))),
}
}
#[must_use]
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
pub fn group_norm_eps(&self) -> f64 {
self.head_size_divisor
.map_or(self.norm_eps, |d| self.norm_eps * (d as f64).powi(2))
}
}
impl RwkvConfig {
fn parse_rwkv6(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let head_dim = get_usize(config, "num_attention_heads")?;
if head_dim == 0 {
return Err(MIError::Config(
"head_dim (num_attention_heads) is 0".into(),
));
}
let num_heads = hidden_size / head_dim;
let intermediate_size =
get_usize_or(config, "intermediate_size", (hidden_size * 7 / 2) / 32 * 32);
Ok(Self {
version: RwkvVersion::V6,
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
head_dim,
num_heads,
vocab_size: get_usize(config, "vocab_size")?,
norm_eps: get_f64_or(config, "layer_norm_epsilon", 1e-5),
intermediate_size,
rescale_every: Some(get_usize_or(config, "rescale_every", 6)),
head_size_divisor: Some(get_usize_or(config, "head_size_divisor", 8)),
lora_dims: RwkvLoraDims {
time_mix_extra_dim: 32,
time_decay_extra_dim: 64,
decay_low_rank_dim: 0,
a_low_rank_dim: 0,
v_low_rank_dim: 0,
gate_low_rank_dim: 0,
},
hidden_ratio: None,
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),
})
}
#[allow(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::cast_precision_loss,
clippy::as_conversions
)]
fn parse_rwkv7(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let head_dim = get_usize_or(config, "head_dim", 64);
if head_dim == 0 {
return Err(MIError::Config("head_dim is 0".into()));
}
let num_heads = hidden_size / head_dim;
let hidden_ratio = config
.get("hidden_ratio")
.and_then(Value::as_f64)
.unwrap_or(4.0);
let intermediate_size = get_usize_or(
config,
"intermediate_size",
Self::round_to_32((hidden_size as f64 * hidden_ratio) as usize),
);
let factor = head_dim as f64 / 64.0;
let sqrt_h = (hidden_size as f64).sqrt();
let decay_low_rank_dim = get_usize_or(
config,
"decay_low_rank_dim",
Self::fla_lora_default(2.5, sqrt_h, factor),
);
let a_low_rank_dim = get_usize_or(
config,
"a_low_rank_dim",
Self::fla_lora_default(2.5, sqrt_h, factor),
);
let v_low_rank_dim = get_usize_or(
config,
"v_low_rank_dim",
Self::fla_lora_default(1.7, sqrt_h, factor),
);
let gate_low_rank_dim = get_usize_or(
config,
"gate_low_rank_dim",
Self::fla_lora_default(5.0, sqrt_h, 1.0),
);
Ok(Self {
version: RwkvVersion::V7,
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
head_dim,
num_heads,
vocab_size: get_usize(config, "vocab_size")?,
norm_eps: get_f64_or(config, "norm_eps", 1e-5),
intermediate_size,
rescale_every: None,
head_size_divisor: None,
lora_dims: RwkvLoraDims {
time_mix_extra_dim: 0,
time_decay_extra_dim: 0,
decay_low_rank_dim,
a_low_rank_dim,
v_low_rank_dim,
gate_low_rank_dim,
},
hidden_ratio: Some(hidden_ratio),
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),
})
}
const fn round_to_32(n: usize) -> usize {
n.div_ceil(32) * 32
}
#[allow(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::as_conversions
)]
fn fla_lora_default(scale: f64, sqrt_h: f64, factor: f64) -> usize {
let raw = (scale * sqrt_h * factor / 32.0).round() as usize * 32;
raw.max(32)
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    /// Representative RWKV-6 HF config exercising every key the v6 parser reads.
    fn rwkv6_config_json() -> Value {
        serde_json::json!({
            "model_type": "rwkv6",
            "hidden_size": 2048,
            "num_hidden_layers": 24,
            "num_attention_heads": 64,
            "vocab_size": 65536,
            "layer_norm_epsilon": 1e-5,
            "head_size_divisor": 8,
            "rescale_every": 6,
            "tie_word_embeddings": false
        })
    }
    #[test]
    fn parse_rwkv6_basic() {
        let config = RwkvConfig::from_hf_config(&rwkv6_config_json()).unwrap();
        assert_eq!(config.version, RwkvVersion::V6);
        assert_eq!(config.hidden_size, 2048);
        assert_eq!(config.num_layers, 24);
        // `num_attention_heads` is the per-head size for v6; 2048 / 64 = 32 heads.
        assert_eq!(config.head_dim, 64);
        assert_eq!(config.num_heads, 32);
        assert_eq!(config.vocab_size, 65536);
        assert!((config.norm_eps - 1e-5).abs() < f64::EPSILON);
        // Default FFN width: 2048 * 3.5 = 7168 (already a multiple of 32).
        assert_eq!(config.intermediate_size, 7168);
        assert_eq!(config.rescale_every, Some(6));
        assert_eq!(config.head_size_divisor, Some(8));
        // Fixed v6 LoRA sizes.
        assert_eq!(config.lora_dims.time_mix_extra_dim, 32);
        assert_eq!(config.lora_dims.time_decay_extra_dim, 64);
        assert!(config.hidden_ratio.is_none());
        assert!(!config.tie_word_embeddings);
    }
    #[test]
    fn rwkv6_group_norm_eps() {
        let config = RwkvConfig::from_hf_config(&rwkv6_config_json()).unwrap();
        // group_norm_eps = norm_eps * head_size_divisor^2 = 1e-5 * 8^2.
        let expected = 1e-5 * 64.0;
        assert!((config.group_norm_eps() - expected).abs() < f64::EPSILON);
    }
    #[test]
    fn rwkv6_explicit_intermediate_size() {
        // An explicit `intermediate_size` must override the derived default.
        let json = serde_json::json!({
            "model_type": "rwkv6",
            "hidden_size": 2048,
            "num_hidden_layers": 24,
            "num_attention_heads": 64,
            "vocab_size": 65536,
            "intermediate_size": 8192
        });
        let config = RwkvConfig::from_hf_config(&json).unwrap();
        assert_eq!(config.intermediate_size, 8192);
    }
    /// Representative RWKV-7 HF config with all low-rank dims given explicitly.
    fn rwkv7_config_json() -> Value {
        serde_json::json!({
            "model_type": "rwkv7",
            "hidden_size": 2048,
            "num_hidden_layers": 24,
            "head_dim": 64,
            "vocab_size": 65536,
            "norm_eps": 1e-5,
            "intermediate_size": 8192,
            "hidden_ratio": 4.0,
            "decay_low_rank_dim": 96,
            "a_low_rank_dim": 96,
            "v_low_rank_dim": 64,
            "gate_low_rank_dim": 256,
            "tie_word_embeddings": false
        })
    }
    #[test]
    fn parse_rwkv7_basic() {
        let config = RwkvConfig::from_hf_config(&rwkv7_config_json()).unwrap();
        assert_eq!(config.version, RwkvVersion::V7);
        assert_eq!(config.hidden_size, 2048);
        assert_eq!(config.num_layers, 24);
        assert_eq!(config.head_dim, 64);
        assert_eq!(config.num_heads, 32);
        assert_eq!(config.vocab_size, 65536);
        assert!((config.norm_eps - 1e-5).abs() < f64::EPSILON);
        assert_eq!(config.intermediate_size, 8192);
        // v6-only options are absent for v7.
        assert!(config.rescale_every.is_none());
        assert!(config.head_size_divisor.is_none());
        // Explicit low-rank dims pass through unchanged.
        assert_eq!(config.lora_dims.decay_low_rank_dim, 96);
        assert_eq!(config.lora_dims.a_low_rank_dim, 96);
        assert_eq!(config.lora_dims.v_low_rank_dim, 64);
        assert_eq!(config.lora_dims.gate_low_rank_dim, 256);
        // v6-only LoRA fields are zeroed for v7.
        assert_eq!(config.lora_dims.time_mix_extra_dim, 0);
        assert_eq!(config.lora_dims.time_decay_extra_dim, 0);
        assert_eq!(config.hidden_ratio, Some(4.0));
        assert!(!config.tie_word_embeddings);
    }
    #[test]
    fn rwkv7_group_norm_eps() {
        // No head_size_divisor for v7, so norm_eps is returned unscaled.
        let config = RwkvConfig::from_hf_config(&rwkv7_config_json()).unwrap();
        assert!((config.group_norm_eps() - 1e-5).abs() < f64::EPSILON);
    }
    #[test]
    fn rwkv7_default_lora_dims() {
        // With no low-rank dims in the config, sqrt(hidden)-based defaults
        // apply and must respect the floor of 32.
        let json = serde_json::json!({
            "model_type": "rwkv7",
            "hidden_size": 2048,
            "num_hidden_layers": 24,
            "vocab_size": 65536
        });
        let config = RwkvConfig::from_hf_config(&json).unwrap();
        assert!(config.lora_dims.decay_low_rank_dim >= 32);
        assert!(config.lora_dims.gate_low_rank_dim >= 32);
    }
    #[test]
    fn unsupported_model_type_errors() {
        let json = serde_json::json!({ "model_type": "gpt2" });
        let result = RwkvConfig::from_hf_config(&json);
        assert!(result.is_err());
    }
    #[test]
    fn missing_model_type_errors() {
        let json = serde_json::json!({ "hidden_size": 2048 });
        let result = RwkvConfig::from_hf_config(&json);
        assert!(result.is_err());
    }
    #[test]
    fn version_display() {
        assert_eq!(RwkvVersion::V6.to_string(), "RWKV-6 (Finch)");
        assert_eq!(RwkvVersion::V7.to_string(), "RWKV-7 (Goose)");
    }
}