use svod_tensor::Tensor;
use crate::state::StateDict;
use super::GigaAmConfig;
use super::error::{Error, Result};
use super::ConvNormType;
pub fn remap_pytorch(sd: StateDict, config: &GigaAmConfig) -> Result<StateDict> {
let mut out = StateDict::new();
let mut bn_var_keys: Vec<(String, Tensor)> = Vec::new();
for (key, tensor) in sd {
let Some(mapped) = remap_key(&key, config) else {
continue;
};
if mapped.starts_with("__bn_var__.") {
let layer_idx = mapped.strip_prefix("__bn_var__.").unwrap().to_string();
bn_var_keys.push((layer_idx, tensor));
continue;
}
out.insert(mapped, tensor);
}
if matches!(config.conv_norm_type, ConvNormType::BatchNorm) {
for (layer_idx, var_tensor) in bn_var_keys {
let data = var_tensor.as_vec::<f32>().map_err(|e| Error::Tensor { source: Box::new(e) })?;
let invstd: Vec<f32> = data.iter().map(|&v| 1.0 / (v + 1e-5).sqrt()).collect();
let invstd_tensor = Tensor::from_slice(&invstd);
out.insert(format!("layers.{layer_idx}.conv.bn_invstd"), invstd_tensor);
}
}
Ok(out)
}
/// Translates one PyTorch parameter name into this crate's naming scheme.
///
/// A leading `model.` prefix is ignored. Returns `None` for any key with no
/// counterpart here, and the caller then drops that tensor.
fn remap_key(key: &str, config: &GigaAmConfig) -> Option<String> {
    let name = key.strip_prefix("model.").unwrap_or(key);
    let segments: Vec<&str> = name.split('.').collect();

    match segments.as_slice() {
        // Subsampling convolutions: only Sequential slots "0" and "2" are mapped.
        ["encoder", "pre_encode", "conv", slot, param] => {
            let conv = match *slot {
                "0" => "conv1",
                "2" => "conv2",
                _ => return None,
            };
            Some(format!("subsampling.{conv}_{param}"))
        }
        ["encoder", "pre_encode", "out", param] => Some(format!("subsampling.linear_{param}")),
        ["encoder", "layers", layer, tail @ ..] if !tail.is_empty() => {
            remap_encoder_layer(layer, tail, config)
        }
        // The Sequential index (third segment) is not carried into the output name.
        ["head", "decoder_layers", _, tail @ ..] if !tail.is_empty() => {
            Some(format!("head.{}", tail.join(".")))
        }
        ["head", "decoder", "embed", "weight"] => Some("head.predictor.embed".to_string()),
        ["head", "decoder", "lstm", param] => remap_rnnt_lstm_param(param),
        ["head", "joint", side @ ("enc" | "pred"), param] => {
            let suffix = match *param {
                "weight" => "w",
                "bias" => "b",
                _ => return None,
            };
            Some(format!("head.joint.{side}_{suffix}"))
        }
        // Only Sequential slot "1" of `joint_net` is mapped, to the output projection.
        ["head", "joint", "joint_net", "1", param] => {
            let suffix = match *param {
                "weight" => "out_w",
                "bias" => "out_b",
                _ => return None,
            };
            Some(format!("head.joint.{suffix}"))
        }
        _ => None,
    }
}
/// Maps the part of an encoder parameter name that follows `encoder.layers.{i}.`.
///
/// `rest` holds the remaining dot-separated segments. Returns `None` when they
/// do not name a parameter this crate loads. This includes an empty `rest` and a
/// bare passthrough module name with no parameter segment after it.
fn remap_encoder_layer(i: &str, rest: &[&str], config: &GigaAmConfig) -> Option<String> {
    // PyTorch submodule name -> local submodule path. Segments after it are kept verbatim.
    const PASSTHROUGH: &[(&str, &str)] = &[
        ("norm_feed_forward1", "ffn1.norm"),
        ("feed_forward1", "ffn1"),
        ("norm_self_att", "mhsa.norm"),
        ("norm_conv", "conv.norm"),
        ("norm_feed_forward2", "ffn2.norm"),
        ("feed_forward2", "ffn2"),
        ("norm_out", "final_norm"),
    ];

    // Return None on an empty slice instead of panicking on `rest[0]`.
    let (&module, tail) = rest.split_first()?;

    if let Some(&(_, target)) = PASSTHROUGH.iter().find(|(src, _)| *src == module) {
        // Without a parameter segment the result would end in '.', which is never a
        // valid parameter name.
        if tail.is_empty() {
            return None;
        }
        return Some(format!("layers.{i}.{target}.{}", tail.join(".")));
    }

    match (module, tail) {
        ("self_attn", [proj, param]) => {
            let base = match *proj {
                "linear_q" => "q",
                "linear_k" => "k",
                "linear_v" => "v",
                "linear_out" => "out",
                _ => return None,
            };
            let suffix = match *param {
                "weight" => "proj",
                "bias" => "bias",
                _ => return None,
            };
            Some(format!("layers.{i}.mhsa.{base}_{suffix}"))
        }
        ("conv", [sub, param]) => match *sub {
            "pointwise_conv1" => Some(format!("layers.{i}.conv.pw1_{param}")),
            "depthwise_conv" => Some(format!("layers.{i}.conv.dw_{param}")),
            "pointwise_conv2" => Some(format!("layers.{i}.conv.pw2_{param}")),
            "batch_norm" => remap_bn_key(i, param, config),
            _ => None,
        },
        _ => None,
    }
}
/// Maps a PyTorch LSTM parameter name such as `weight_ih_l0` to
/// `head.predictor.lstm.{layer}.{w_ih|w_hh|b_ih|b_hh}`.
///
/// Returns `None` in three cases: an unknown parameter kind, a layer suffix that
/// is not a plain decimal number (e.g. `weight_ih_l0_reverse`), or a missing
/// layer index (e.g. `weight_ih_l`).
fn remap_rnnt_lstm_param(token: &str) -> Option<String> {
    let (base, layer) = token.rsplit_once("_l")?;
    // `all` is vacuously true for an empty string, so an empty index must be rejected
    // explicitly; otherwise the output would contain `lstm..`.
    if layer.is_empty() || !layer.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let mapped = match base {
        "weight_ih" => "w_ih",
        "weight_hh" => "w_hh",
        "bias_ih" => "b_ih",
        "bias_hh" => "b_hh",
        _ => return None,
    };
    Some(format!("head.predictor.lstm.{layer}.{mapped}"))
}
/// Maps a `conv.batch_norm.{param}` entry of encoder layer `layer` according to
/// the configured convolution normalisation type.
///
/// - `ConvNormType::LayerNorm`: only `weight`/`bias` are kept, as `conv_norm.*`.
/// - `ConvNormType::BatchNorm`: `weight`, `bias` and `running_mean` become
///   `bn_scale`, `bn_bias` and `bn_mean`. `running_var` yields the marker key
///   `__bn_var__.{layer}`, which `remap_pytorch` converts separately. Any other
///   buffer (e.g. `num_batches_tracked`) is dropped.
fn remap_bn_key(layer: &str, param: &str, config: &GigaAmConfig) -> Option<String> {
    let leaf = match &config.conv_norm_type {
        ConvNormType::LayerNorm => match param {
            "weight" => "conv_norm.weight",
            "bias" => "conv_norm.bias",
            _ => return None,
        },
        ConvNormType::BatchNorm => match param {
            "weight" => "bn_scale",
            "bias" => "bn_bias",
            "running_mean" => "bn_mean",
            // Not a direct parameter: return a marker key so the variance can be post-processed.
            "running_var" => return Some(format!("__bn_var__.{layer}")),
            // `num_batches_tracked` and anything else have no counterpart here.
            _ => return None,
        },
    };
    Some(format!("layers.{layer}.conv.{leaf}"))
}