#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/tensor_names_generated.rs"));
use crate::error::{RealizarError, Result};
use crate::safetensors_infer::TensorSource;
/// Resolves a model-wide tensor (embedding, output norm, ...) to f32 data.
///
/// Lookup order: architecture-specific names, then cross-architecture
/// fallback names, then the architecture names with a leading `"model."`
/// prefix stripped (some checkpoints omit the wrapper module prefix).
/// Returns an `UnsupportedOperation` error listing every name tried and a
/// small sample of the tensors actually present when nothing matches.
pub(crate) fn resolve_global<S: TensorSource>(
    source: &S,
    arch: &str,
    role: GlobalTensorRole,
) -> Result<Vec<f32>> {
    let arch_key = normalize_architecture(arch);
    // Names that failed to resolve, accumulated for the diagnostic below.
    let mut tried: Vec<&str> = Vec::new();

    // Pass 1: architecture-specific names.
    for &name in global_names(arch_key, role) {
        match source.get_tensor_auto(name) {
            Ok(tensor) => return Ok(tensor),
            Err(_) => tried.push(name),
        }
    }
    // Pass 2: architecture-independent fallback names.
    for &name in global_fallback_names(role) {
        match source.get_tensor_auto(name) {
            Ok(tensor) => return Ok(tensor),
            Err(_) => tried.push(name),
        }
    }
    // Pass 3: architecture names again, minus the "model." wrapper prefix.
    for bare in global_names(arch_key, role)
        .iter()
        .filter_map(|name| name.strip_prefix("model."))
    {
        match source.get_tensor_auto(bare) {
            Ok(tensor) => return Ok(tensor),
            Err(_) => tried.push(bare),
        }
    }

    // Nothing matched: report what was tried against what exists.
    let available = source.tensor_names();
    let sample: Vec<&str> = available.iter().take(5).copied().collect();
    Err(RealizarError::UnsupportedOperation {
        operation: "tensor_names::resolve_global".to_string(),
        reason: format!(
            "Tensor not found for {:?} (arch='{}'). Tried: {:?}. \
             Available tensors ({} total): {:?}{}",
            role,
            arch,
            tried,
            available.len(),
            sample,
            if available.len() > 5 { ", ..." } else { "" }
        ),
    })
}
/// Like `resolve_global`, but yields `None` instead of an error when the
/// tensor cannot be found under any known name.
pub(crate) fn resolve_global_optional<S: TensorSource>(
    source: &S,
    arch: &str,
    role: GlobalTensorRole,
) -> Option<Vec<f32>> {
    match resolve_global(source, arch, role) {
        Ok(tensor) => Some(tensor),
        Err(_) => None,
    }
}
/// Reports whether a model-wide tensor exists under any of the names that
/// `resolve_global` would try, in the same order: arch-specific names,
/// shared fallbacks, then arch names with the `"model."` prefix stripped.
pub(crate) fn has_global<S: TensorSource>(source: &S, arch: &str, role: GlobalTensorRole) -> bool {
    let arch_key = normalize_architecture(arch);
    // Direct names: architecture-specific first, then shared fallbacks.
    for &name in global_names(arch_key, role)
        .iter()
        .chain(global_fallback_names(role))
    {
        if source.has_tensor(name) {
            return true;
        }
    }
    // Last chance: arch names without the "model." wrapper prefix.
    global_names(arch_key, role)
        .iter()
        .filter_map(|name| name.strip_prefix("model."))
        .any(|bare| source.has_tensor(bare))
}
/// Resolves a per-layer tensor for `layer_idx` to f32 data.
///
/// Each template contains a `{n}` placeholder that is substituted with the
/// layer index. Lookup order mirrors `resolve_global`: architecture-specific
/// templates, cross-architecture fallback templates, then the architecture
/// templates with a leading `"model."` prefix stripped. On failure, returns
/// an `UnsupportedOperation` error listing every expanded name tried plus a
/// sample of the tensors actually present.
pub(crate) fn resolve_layer<S: TensorSource>(
    source: &S,
    arch: &str,
    layer_idx: usize,
    role: LayerTensorRole,
) -> Result<Vec<f32>> {
    let arch_key = normalize_architecture(arch);
    // Render the index once; every template substitutes the same string.
    let layer = layer_idx.to_string();
    // Expanded names that failed, kept for the diagnostic below.
    let mut tried: Vec<String> = Vec::new();

    // Pass 1: architecture-specific templates.
    for template in layer_templates(arch_key, role) {
        let name = template.replace("{n}", &layer);
        match source.get_tensor_auto(&name) {
            Ok(tensor) => return Ok(tensor),
            Err(_) => tried.push(name),
        }
    }
    // Pass 2: architecture-independent fallback templates.
    for template in layer_fallback_templates(role) {
        let name = template.replace("{n}", &layer);
        match source.get_tensor_auto(&name) {
            Ok(tensor) => return Ok(tensor),
            Err(_) => tried.push(name),
        }
    }
    // Pass 3: architecture templates again, minus the "model." prefix.
    for template in layer_templates(arch_key, role) {
        let name = template.replace("{n}", &layer);
        if let Some(bare) = name.strip_prefix("model.") {
            match source.get_tensor_auto(bare) {
                Ok(tensor) => return Ok(tensor),
                Err(_) => tried.push(bare.to_string()),
            }
        }
    }

    // Nothing matched: report what was tried against what exists.
    let available = source.tensor_names();
    let sample: Vec<&str> = available.iter().take(5).copied().collect();
    Err(RealizarError::UnsupportedOperation {
        operation: "tensor_names::resolve_layer".to_string(),
        reason: format!(
            "Tensor not found for {:?} layer {} (arch='{}'). Tried: {:?}. \
             Available tensors ({} total): {:?}{}",
            role,
            layer_idx,
            arch,
            tried,
            available.len(),
            sample,
            if available.len() > 5 { ", ..." } else { "" }
        ),
    })
}
/// Like `resolve_layer`, but yields `None` instead of an error when the
/// tensor cannot be found under any known name.
pub(crate) fn resolve_layer_optional<S: TensorSource>(
    source: &S,
    arch: &str,
    layer_idx: usize,
    role: LayerTensorRole,
) -> Option<Vec<f32>> {
    match resolve_layer(source, arch, layer_idx, role) {
        Ok(tensor) => Some(tensor),
        Err(_) => None,
    }
}
/// Reports whether a per-layer tensor exists under any of the names that
/// `resolve_layer` would try, in the same order: arch-specific templates,
/// shared fallbacks, then arch templates with the `"model."` prefix stripped.
pub(crate) fn has_layer<S: TensorSource>(
    source: &S,
    arch: &str,
    layer_idx: usize,
    role: LayerTensorRole,
) -> bool {
    let arch_key = normalize_architecture(arch);
    // Render the index once; every template substitutes the same string.
    let layer = layer_idx.to_string();

    // Direct templates: architecture-specific first, then shared fallbacks.
    for template in layer_templates(arch_key, role)
        .iter()
        .chain(layer_fallback_templates(role))
    {
        if source.has_tensor(&template.replace("{n}", &layer)) {
            return true;
        }
    }
    // Last chance: arch templates without the "model." wrapper prefix.
    layer_templates(arch_key, role).iter().any(|template| {
        match template.replace("{n}", &layer).strip_prefix("model.") {
            Some(bare) => source.has_tensor(bare),
            None => false,
        }
    })
}
/// Reports whether a fused tensor (e.g. packed QKV) exists for `layer_idx`
/// under any architecture-specific or fallback template. Unlike the layer
/// lookups, no bare-name ("model."-stripped) pass is attempted, matching
/// the behavior of `resolve_fused`.
pub(crate) fn has_fused<S: TensorSource>(
    source: &S,
    arch: &str,
    layer_idx: usize,
    role: FusedTensorRole,
) -> bool {
    let arch_key = normalize_architecture(arch);
    // Render the index once; every template substitutes the same string.
    let layer = layer_idx.to_string();
    fused_templates(arch_key, role)
        .iter()
        .chain(fused_fallback_templates(role))
        .any(|template| source.has_tensor(&template.replace("{n}", &layer)))
}
/// Resolves a fused tensor (e.g. packed QKV) for `layer_idx` to f32 data.
///
/// Tries architecture-specific templates first, then fallback templates,
/// returning the first tensor that loads. Fused tensors are optional by
/// nature, so absence is reported as `None` rather than an error.
pub(crate) fn resolve_fused<S: TensorSource>(
    source: &S,
    arch: &str,
    layer_idx: usize,
    role: FusedTensorRole,
) -> Option<Vec<f32>> {
    let arch_key = normalize_architecture(arch);
    // Render the index once; every template substitutes the same string.
    let layer = layer_idx.to_string();
    fused_templates(arch_key, role)
        .iter()
        .chain(fused_fallback_templates(role))
        .find_map(|template| source.get_tensor_auto(&template.replace("{n}", &layer)).ok())
}
#[cfg(test)]
mod tests {
    use super::*;
    // The Phi family maps to two distinct keys: plain "Phi" (Phi-2 layout)
    // vs. "Phi3*" variants, which share the "phi" key.
    #[test]
    fn test_phi_architecture_distinction() {
        assert_eq!(normalize_architecture("PhiForCausalLM"), "phi2");
        assert_eq!(normalize_architecture("Phi3ForCausalLM"), "phi");
        assert_eq!(normalize_architecture("Phi3SmallForCausalLM"), "phi");
    }
    // Unrecognized (or empty) architecture strings fall back to "llama",
    // the most common layout.
    #[test]
    fn test_unknown_architecture_defaults_to_llama() {
        assert_eq!(normalize_architecture("FutureArch2027"), "llama");
        assert_eq!(normalize_architecture(""), "llama");
        assert_eq!(normalize_architecture("SomeRandomModel"), "llama");
    }
    // GPT-2 tensor names are unprefixed (no "model." wrapper module).
    #[test]
    fn test_gpt2_global_names() {
        let names = global_names("gpt2", GlobalTensorRole::Embedding);
        assert!(names.contains(&"wte.weight"));
        assert!(!names.iter().any(|n| n.starts_with("model.")));
    }
    // GPT-2 packs Q/K/V into a single fused tensor, so it must expose fused
    // templates and no separate per-projection template.
    #[test]
    fn test_gpt2_fused_qkv() {
        let fused = fused_templates("gpt2", FusedTensorRole::FusedQkv);
        assert!(!fused.is_empty(), "GPT-2 should have fused QKV templates");
        let q_templates = layer_templates("gpt2", LayerTensorRole::QProjWeight);
        assert!(
            q_templates.is_empty(),
            "GPT-2 should not have separate Q template"
        );
    }
    // Every architecture key the generated tables claim to support must
    // define at least one global name (embedding or output norm).
    #[test]
    fn test_architecture_map_completeness() {
        let known_archs = [
            "llama", "qwen2", "qwen3", "mistral", "gemma", "phi", "phi2", "deepseek", "gpt2",
            "gpt_neox", "bert", "openelm", "falcon", "stablelm",
        ];
        for arch in known_archs {
            let embed = global_names(arch, GlobalTensorRole::Embedding);
            let norm = global_names(arch, GlobalTensorRole::OutputNormWeight);
            assert!(
                !embed.is_empty() || !norm.is_empty(),
                "Architecture '{}' has no global names defined",
                arch
            );
        }
    }
    // Roles every model needs must have cross-architecture fallback names,
    // so resolution can still succeed for unmapped checkpoints.
    #[test]
    fn test_required_roles_have_fallbacks() {
        let fb = global_fallback_names(GlobalTensorRole::Embedding);
        assert!(!fb.is_empty(), "Embedding must have fallback names");
        let fb = global_fallback_names(GlobalTensorRole::OutputNormWeight);
        assert!(!fb.is_empty(), "OutputNormWeight must have fallback names");
        let fb = layer_fallback_templates(LayerTensorRole::AttnNormWeight);
        assert!(
            !fb.is_empty(),
            "AttnNormWeight must have fallback templates"
        );
        let fb = layer_fallback_templates(LayerTensorRole::FfnUpWeight);
        assert!(!fb.is_empty(), "FfnUpWeight must have fallback templates");
    }
    // Canonical HF Llama names must remain present so existing callers keep
    // resolving the same tensors.
    #[test]
    fn test_llama_names_backward_compatible() {
        let embed = global_names("llama", GlobalTensorRole::Embedding);
        assert!(embed.contains(&"model.embed_tokens.weight"));
        let norm = global_names("llama", GlobalTensorRole::OutputNormWeight);
        assert!(norm.contains(&"model.norm.weight"));
        let q = layer_templates("llama", LayerTensorRole::QProjWeight);
        assert!(q.contains(&"model.layers.{n}.self_attn.q_proj.weight"));
    }
    // Phi-2 uses fc1/fc2 naming for its MLP and has no gate projection
    // (non-gated FFN).
    #[test]
    fn test_phi2_mlp_names() {
        let up = layer_templates("phi2", LayerTensorRole::FfnUpWeight);
        assert!(
            up.iter().any(|t| t.contains("fc1")),
            "Phi-2 should use fc1 for FFN up: {:?}",
            up
        );
        let down = layer_templates("phi2", LayerTensorRole::FfnDownWeight);
        assert!(
            down.iter().any(|t| t.contains("fc2")),
            "Phi-2 should use fc2 for FFN down: {:?}",
            down
        );
        let gate = layer_templates("phi2", LayerTensorRole::FfnGateWeight);
        assert!(gate.is_empty(), "Phi-2 should have no gate projection");
    }
    // Phi-2 names its output norm "final_layernorm" rather than "norm".
    #[test]
    fn test_phi2_output_norm() {
        let norm = global_names("phi2", GlobalTensorRole::OutputNormWeight);
        assert!(
            norm.iter().any(|n| n.contains("final_layernorm")),
            "Phi-2 should use final_layernorm: {:?}",
            norm
        );
    }
    // GPT-NeoX fuses Q/K/V into a "query_key_value" tensor.
    #[test]
    fn test_gpt_neox_fused_qkv() {
        let fused = fused_templates("gpt_neox", FusedTensorRole::FusedQkv);
        assert!(
            fused.iter().any(|t| t.contains("query_key_value")),
            "GPT-NeoX should have query_key_value fused template"
        );
    }
    // HuggingFace config class names normalize to the internal arch keys.
    #[test]
    fn test_hf_class_name_mapping() {
        assert_eq!(normalize_architecture("LlamaForCausalLM"), "llama");
        assert_eq!(normalize_architecture("Qwen2ForCausalLM"), "qwen2");
        assert_eq!(normalize_architecture("Qwen3ForCausalLM"), "qwen3");
        assert_eq!(normalize_architecture("MistralForCausalLM"), "mistral");
        assert_eq!(normalize_architecture("GemmaForCausalLM"), "gemma");
        assert_eq!(normalize_architecture("GPT2LMHeadModel"), "gpt2");
        assert_eq!(normalize_architecture("GPTNeoXForCausalLM"), "gpt_neox");
        assert_eq!(normalize_architecture("BertModel"), "bert");
        assert_eq!(normalize_architecture("FalconForCausalLM"), "falcon");
    }
}