/// Infer the model's hidden size from the embedding tensor: of its two
/// dimensions the smaller one is taken as hidden_size, since the vocab
/// dimension is almost always larger. Returns `(0, false)` when no
/// embedding tensor is present.
fn infer_hidden_size(tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>) -> (usize, bool) {
tensors
.iter()
.find(|(name, _)| name.contains("embed_tokens") || name.contains("token_embd"))
.map(|(name, (_, shape))| {
let dim = if shape.len() >= 2 {
let inferred = shape[0].min(shape[1]);
eprintln!(
"[GH-197] Inferred hidden_size={inferred} from tensor '{name}' \
(shape={shape:?}, picked smaller dim)"
);
inferred
} else {
shape.last().copied().unwrap_or(0)
};
(dim, true)
})
.unwrap_or((0, false))
}
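/// Infer the layer count by scanning tensor names for `layers.N` or
/// `blk.N` segments and taking the highest index plus one.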
fn infer_num_layers(tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>) -> usize {
let max_layer: Option<usize> = tensors
.keys()
.filter_map(|name| {
if name.contains("layers.") || name.contains("blk.") {
let parts: Vec<&str> = name.split(&['.', '_'][..]).collect();
for (i, part) in parts.iter().enumerate() {
if (*part == "layers" || *part == "blk") && i + 1 < parts.len() {
return parts[i + 1].parse::<usize>().ok();
}
}
}
None
})
.max();
if let Some(max) = max_layer {
let count = max + 1;
eprintln!("[GH-197] Inferred num_layers={count} from layer indices 0..{max}");
count
    } else {
        // No layer-indexed tensor names were found; fall back to a
        // common small-model default.
        12
    }
}
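/// Infer the vocabulary size from the `lm_head`/`output.weight` tensor,
/// falling back to the embedding tensor. The larger of the two
/// dimensions is taken, mirroring `infer_hidden_size` taking the smaller.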
fn infer_vocab_size(tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>) -> (usize, bool) {
tensors
.iter()
.find(|(name, _)| name.contains("lm_head") || name.contains("output.weight"))
.or_else(|| {
tensors
.iter()
.find(|(name, _)| name.contains("embed_tokens") || name.contains("token_embd"))
})
.map(|(name, (_, shape))| {
let dim = if shape.len() >= 2 {
let inferred = shape[0].max(shape[1]);
eprintln!(
"[GH-197] Inferred vocab_size={inferred} from tensor '{name}' \
(shape={shape:?}, picked larger dim)"
);
inferred
} else {
shape.first().copied().unwrap_or(0)
};
(dim, true)
})
.unwrap_or((0, false))
}
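/// Build a HuggingFace-style `config.json` string from inferred tensor
/// dimensions. Fields that cannot be recovered from shapes alone
/// (rope_theta, max_position_embeddings, norm eps, ...) use Qwen2
/// defaults, matching the hardcoded `Qwen2ForCausalLM` architecture.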
fn infer_model_config(tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>) -> String {
let (hidden_size, hidden_inferred) = infer_hidden_size(tensors);
let num_layers = infer_num_layers(tensors);
let (vocab_size, vocab_inferred) = infer_vocab_size(tensors);
if vocab_inferred && hidden_inferred && vocab_size < hidden_size {
eprintln!(
"[GH-197] WARNING: vocab_size ({vocab_size}) < hidden_size ({hidden_size}). \
This is unusual for LLMs - dimensions may be swapped!"
);
}
let num_attention_heads = tensors
.iter()
.find(|(name, _)| {
name.contains("self_attn.q_proj")
|| name.contains("attn.q_proj")
|| name.contains("attention.wq")
})
        .map(|(_, (_, _shape))| {
            // A q_proj tensor exists, but its shape alone does not reveal
            // the head count, so assume a conventional head_dim for the
            // model size and derive the count from hidden_size.
            let head_dim = if hidden_size >= 4096 { 128 } else { 64 };
            hidden_size / head_dim
        })
        .unwrap_or_else(|| match hidden_size {
            // Known hidden_size -> head-count pairs for common model
            // families; otherwise assume 128-wide heads.
            896 => 14,
            1536 => 12,
            2048 => 16,
            4096 => 32,
            5120 => 40,
            8192 => 64,
            _ => (hidden_size / 128).max(1),
        });
let intermediate_size = tensors
.iter()
.find(|(name, _)| {
name.contains("mlp.gate_proj")
|| name.contains("mlp.up_proj")
|| name.contains("feed_forward.w1")
})
.map(|(_, (_, shape))| shape.first().copied().unwrap_or(hidden_size * 4))
.unwrap_or(hidden_size * 4);
    let head_dim = if num_attention_heads > 0 {
        hidden_size / num_attention_heads
    } else {
        64
    };
let num_key_value_heads = tensors
.iter()
.find(|(name, _)| {
name.contains("self_attn.k_proj")
|| name.contains("attn.k_proj")
|| name.contains("attention.wk")
})
.map(|(_, (_, shape))| {
let kv_dim = shape.first().copied().unwrap_or(hidden_size);
if head_dim > 0 {
(kv_dim / head_dim).max(1)
} else {
1
}
})
.unwrap_or(num_attention_heads);
let tokens = crate::demo::SpecialTokens::qwen2();
let bos_id = tokens.bos_id;
let eos_id = tokens.eos_id;
format!(
r#"{{
"architectures": ["Qwen2ForCausalLM"],
"bos_token_id": {bos_id},
"eos_token_id": {eos_id},
"hidden_act": "silu",
"hidden_size": {hidden_size},
"initializer_range": 0.02,
"intermediate_size": {intermediate_size},
"max_position_embeddings": 32768,
"model_type": "qwen2",
"num_attention_heads": {num_attention_heads},
"num_hidden_layers": {num_layers},
"num_key_value_heads": {num_key_value_heads},
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": {vocab_size}
}}"#
)
}
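/// Return a tokenizer-hint JSON string for `.apr` inputs; any other
/// extension yields an empty string.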
fn infer_tokenizer_json(input_path: &Path) -> String {
if input_path.extension().and_then(|e| e.to_str()) != Some("apr") {
return String::new();
}
extract_apr_tokenizer_hint(input_path).unwrap_or_default()
}
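/// Scan the APR header's JSON metadata block for tokenizer/vocabulary
/// keys and, when found, emit a minimal BPE tokenizer stub.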
fn extract_apr_tokenizer_hint(input_path: &Path) -> Option<String> {
let data = fs::read(input_path).ok()?;
if data.len() <= 44 {
return None;
}
    let metadata_start = 44;
    // Heuristic: the JSON metadata block is assumed to end at the first
    // `}` followed by blank lines (`}\r\n\r` is the CRLF analogue,
    // truncated to fit the 4-byte window).
    let metadata_end = data[metadata_start..]
        .windows(4)
        .position(|w| w == b"}\n\n\n" || w == b"}\r\n\r")
        .map(|p| metadata_start + p + 1)?;
let metadata_str = std::str::from_utf8(&data[metadata_start..metadata_end]).ok()?;
if metadata_str.contains("\"tokenizer\"") || metadata_str.contains("\"vocabulary\"") {
Some(r#"{"version": "1.0", "model": {"type": "BPE"}}"#.to_string())
} else {
None
}
}
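/// Parse the structured APR v2 metadata block, returning `None` for
/// non-`.apr` paths or unreadable files.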
fn read_apr_metadata(apr_path: &Path) -> Option<crate::format::v2::AprV2Metadata> {
if apr_path.extension().and_then(|e| e.to_str()) != Some("apr") {
return None;
}
let data = fs::read(apr_path).ok()?;
let reader = crate::format::v2::AprV2Reader::from_bytes(&data).ok()?;
Some(reader.metadata().clone())
}
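/// Split a fused `qkv_proj.weight` tensor into separate q/k/v projection
/// weights. Assumes the fused tensor concatenates q (`hidden_size` rows),
/// k (`kv_dim` rows), and v (`kv_dim` rows) along the output dimension.
/// Returns `false` (leaving `result` untouched) when the data is too
/// short for that layout.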
fn split_qkv_weight(
name: &str,
data: &[f32],
shape: &[usize],
hidden_size: usize,
kv_dim: usize,
result: &mut BTreeMap<String, (Vec<f32>, Vec<usize>)>,
) -> bool {
let hidden_dim = if shape.len() >= 2 { shape[1] } else { hidden_size };
let q_elements = hidden_size * hidden_dim;
let kv_elements = kv_dim * hidden_dim;
if data.len() < q_elements + 2 * kv_elements {
return false;
}
let prefix = name.strip_suffix("qkv_proj.weight").unwrap_or(name);
result.insert(
format!("{prefix}q_proj.weight"),
(data[..q_elements].to_vec(), vec![hidden_size, hidden_dim]),
);
result.insert(
format!("{prefix}k_proj.weight"),
(
data[q_elements..q_elements + kv_elements].to_vec(),
vec![kv_dim, hidden_dim],
),
);
result.insert(
format!("{prefix}v_proj.weight"),
(
data[q_elements + kv_elements..q_elements + 2 * kv_elements].to_vec(),
vec![kv_dim, hidden_dim],
),
);
true
}
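/// Split a fused `qkv_proj.bias` vector of length
/// `hidden_size + 2 * kv_dim` into separate q/k/v bias tensors.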
fn split_qkv_bias(
name: &str,
data: &[f32],
hidden_size: usize,
kv_dim: usize,
result: &mut BTreeMap<String, (Vec<f32>, Vec<usize>)>,
) -> bool {
let qkv_dim = hidden_size + 2 * kv_dim;
if data.len() != qkv_dim {
return false;
}
let prefix = name.strip_suffix("qkv_proj.bias").unwrap_or(name);
result.insert(
format!("{prefix}q_proj.bias"),
(data[..hidden_size].to_vec(), vec![hidden_size]),
);
result.insert(
format!("{prefix}k_proj.bias"),
(
data[hidden_size..hidden_size + kv_dim].to_vec(),
vec![kv_dim],
),
);
result.insert(
format!("{prefix}v_proj.bias"),
(data[hidden_size + kv_dim..].to_vec(), vec![kv_dim]),
);
true
}
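/// Expand fused QKV tensors into per-projection tensors so downstream
/// consumers see the canonical q/k/v naming. Requires `hidden_size` and
/// `num_heads` from the APR metadata; without them (or when nothing is
/// fused) the tensors pass through unchanged.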
fn unfuse_qkv_tensors(
tensors: BTreeMap<String, (Vec<f32>, Vec<usize>)>,
apr_path: &Path,
) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
let has_fused = tensors.keys().any(|k| k.contains("qkv_proj."));
if !has_fused {
return tensors;
}
let metadata = read_apr_metadata(apr_path);
let (hidden_size, num_heads, num_kv_heads) = match &metadata {
Some(m) => {
let hs = m.hidden_size.unwrap_or(0);
let nh = m.num_heads.unwrap_or(0);
let nkv = m.num_kv_heads.unwrap_or(nh);
(hs, nh, nkv)
}
None => return tensors,
};
if hidden_size == 0 || num_heads == 0 {
return tensors;
}
let head_dim = hidden_size / num_heads;
let kv_dim = num_kv_heads * head_dim;
let mut result = BTreeMap::new();
for (name, (data, shape)) in tensors {
if name.contains("qkv_proj.weight") {
if !split_qkv_weight(&name, &data, &shape, hidden_size, kv_dim, &mut result) {
result.insert(name, (data, shape));
}
} else if name.contains("qkv_proj.bias") {
if !split_qkv_bias(&name, &data, hidden_size, kv_dim, &mut result) {
result.insert(name, (data, shape));
}
} else {
result.insert(name, (data, shape));
}
}
result
}
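/// Drop `lm_head.weight` when the APR metadata's custom
/// `tied_embeddings` flag is set, matching the `tie_word_embeddings`
/// field emitted in the generated config.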
fn remove_tied_lm_head(
mut tensors: BTreeMap<String, (Vec<f32>, Vec<usize>)>,
apr_path: &Path,
) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
let metadata = read_apr_metadata(apr_path);
let is_tied = metadata
.as_ref()
.and_then(|m| m.custom.get("tied_embeddings"))
.and_then(|v| v.as_bool())
.unwrap_or(false);
if is_tied {
tensors.remove("lm_head.weight");
}
tensors
}
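/// Pull user-supplied key/value strings out of the
/// `custom.source_metadata` object in the APR header JSON; any read or
/// parse failure yields an empty map.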
fn extract_user_metadata(apr_path: &Path) -> UserMetadata {
let data = match fs::read(apr_path) {
Ok(d) => d,
Err(_) => return UserMetadata::new(),
};
    if data.len() < 16 {
        return UserMetadata::new();
    }
    // Bytes 8..16 of the header hold the metadata JSON length
    // (little-endian u64).
    let metadata_len = u64::from_le_bytes(data[8..16].try_into().unwrap_or([0u8; 8])) as usize;
if data.len() < 16 + metadata_len {
return UserMetadata::new();
}
let metadata_json = match std::str::from_utf8(&data[16..16 + metadata_len]) {
Ok(s) => s,
Err(_) => return UserMetadata::new(),
};
let parsed: serde_json::Value = match serde_json::from_str(metadata_json) {
Ok(v) => v,
Err(_) => return UserMetadata::new(),
};
if let Some(serde_json::Value::Object(map)) =
parsed.get("custom").and_then(|c| c.get("source_metadata"))
{
let mut result = UserMetadata::new();
for (k, v) in map {
if let serde_json::Value::String(s) = v {
result.insert(k.clone(), s.clone());
}
}
return result;
}
UserMetadata::new()
}
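/// Classify an APR file's quantization by majority vote over the dtypes
/// of its 2-D (matrix) tensors; 1-D tensors like norms are ignored.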
pub(crate) fn detect_apr_quantization(apr_path: &Path) -> Option<QuantizationType> {
use crate::format::v2::{AprV2Reader, TensorDType};
let data = fs::read(apr_path).ok()?;
let reader = AprV2Reader::from_bytes(&data).ok()?;
let mut q4k_count = 0usize;
let mut q6k_count = 0usize;
let mut other_count = 0usize;
for name in reader.tensor_names() {
if let Some(entry) = reader.get_tensor(name) {
if entry.shape.len() >= 2 {
match entry.dtype {
TensorDType::Q4K => q4k_count += 1,
TensorDType::Q6K => q6k_count += 1,
_ => other_count += 1,
}
}
}
}
let total = q4k_count + q6k_count + other_count;
if total == 0 {
return None;
}
    // Only a Q4K-dominant tensor set maps to a known quantization type
    // here; mixed or Q6K-heavy files fall through to None.
    if q4k_count > q6k_count && q4k_count > other_count {
        return Some(QuantizationType::Q4K);
    }
    None
}
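// A minimal test sketch for the shape-based inference helpers above. The
// tensor names and shapes are synthetic, mirroring Qwen2-style checkpoints;
// the expected values follow from the heuristics implemented here, not from
// any real model file.
#[cfg(test)]
mod infer_tests {
    use super::*;

    // The inference helpers only read names and shapes, so data can stay empty.
    fn tensor(shape: &[usize]) -> (Vec<f32>, Vec<usize>) {
        (Vec::new(), shape.to_vec())
    }

    #[test]
    fn hidden_and_vocab_from_embedding() {
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "model.embed_tokens.weight".to_string(),
            tensor(&[151_936, 896]),
        );
        assert_eq!(infer_hidden_size(&tensors), (896, true));
        assert_eq!(infer_vocab_size(&tensors), (151_936, true));
    }

    #[test]
    fn layer_count_is_max_index_plus_one() {
        let mut tensors = BTreeMap::new();
        for i in [0usize, 5, 23] {
            tensors.insert(
                format!("model.layers.{i}.mlp.gate_proj.weight"),
                tensor(&[4864, 896]),
            );
        }
        assert_eq!(infer_num_layers(&tensors), 24);
    }
}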