impl ModelFamily for DynModelFamily {
    /// Family identifier taken from the config's `family` field.
    fn family_name(&self) -> &str {
        &self.config.family
    }

    /// Human-readable name taken from the config's `display_name` field.
    fn display_name(&self) -> &str {
        &self.config.display_name
    }

    /// The full configuration backing this family.
    fn config(&self) -> &ModelFamilyConfig {
        &self.config
    }

    /// Looks up a size variant by name; `None` if the name is not configured.
    fn size_config(&self, size: &str) -> Option<&ModelSizeConfig> {
        self.config.size_variants.get(size)
    }

    /// Returns the name of a size variant whose `hidden_dim` and `num_layers`
    /// both match the given values, or `None` if no variant matches.
    ///
    /// NOTE(review): if several variants share the same dimensions, the one
    /// returned depends on the iteration order of `size_variants`.
    fn detect_size(&self, hidden_dim: usize, num_layers: usize) -> Option<String> {
        self.config
            .size_variants
            .iter()
            .find(|(_, variant)| {
                variant.hidden_dim == hidden_dim && variant.num_layers == num_layers
            })
            .map(|(name, _)| name.clone())
    }

    /// The constraints declared in the family config.
    fn constraints(&self) -> &ModelConstraints {
        &self.config.constraints
    }

    /// Number of tensors a checkpoint of the given size should contain, or
    /// `None` if `size` is not a configured variant.
    ///
    /// Counts the embedding (when its pattern is non-empty), the optional
    /// `lm_head` and `final_norm`, plus every present per-layer pattern times
    /// the variant's layer count.
    fn expected_tensor_count(&self, size: &str) -> Option<usize> {
        let variant = self.config.size_variants.get(size)?;
        let template = &self.config.tensor_template;
        let global_tensors = usize::from(!template.embedding.is_empty())
            + usize::from(template.lm_head.is_some())
            + usize::from(template.final_norm.is_some());
        let tensors_per_layer = template.per_layer.values().filter(|p| p.is_some()).count();
        Some(global_tensors + tensors_per_layer * variant.num_layers)
    }

    /// Checks that `names` is exactly the tensor set implied by the template
    /// for the given size variant.
    ///
    /// # Errors
    ///
    /// Returns a [`ContractError`] if `size` is not a configured variant, or if
    /// any expected tensor is missing / any unexpected tensor is present. The
    /// offending names are listed in sorted order so the message is stable.
    fn validate_tensor_names(
        &self,
        names: &[&str],
        size: &str,
    ) -> std::result::Result<(), ContractError> {
        let variant = self
            .config
            .size_variants
            .get(size)
            .ok_or_else(|| ContractError {
                family: self.config.family.clone(),
                message: format!("Unknown size variant: {size}"),
            })?;
        let template = &self.config.tensor_template;

        let mut expected: Vec<String> = Vec::new();
        // An empty embedding pattern means "no embedding tensor"; this mirrors
        // `expected_tensor_count`, which does not count it in that case.
        if !template.embedding.is_empty() {
            expected.push(template.embedding.clone());
        }
        if let Some(lm_head) = &template.lm_head {
            expected.push(lm_head.clone());
        }
        if let Some(final_norm) = &template.final_norm {
            expected.push(final_norm.clone());
        }
        for layer_idx in 0..variant.num_layers {
            // Format the index once per layer rather than once per pattern.
            let layer = layer_idx.to_string();
            for pat in template.per_layer.values().flatten() {
                expected.push(pat.replace("{n}", &layer));
            }
        }

        let expected_set: std::collections::HashSet<&str> =
            expected.iter().map(String::as_str).collect();
        let actual_set: std::collections::HashSet<&str> = names.iter().copied().collect();
        let mut missing: Vec<&str> = expected_set.difference(&actual_set).copied().collect();
        let mut unexpected: Vec<&str> = actual_set.difference(&expected_set).copied().collect();
        if missing.is_empty() && unexpected.is_empty() {
            return Ok(());
        }

        // HashSet iteration order is unspecified; sort for a deterministic message.
        missing.sort_unstable();
        unexpected.sort_unstable();

        let mut parts: Vec<String> = Vec::with_capacity(2);
        if !missing.is_empty() {
            parts.push(format!("Missing tensors: {}", missing.join(", ")));
        }
        if !unexpected.is_empty() {
            parts.push(format!("Unexpected tensors: {}", unexpected.join(", ")));
        }
        Err(ContractError {
            family: self.config.family.clone(),
            message: parts.join("; "),
        })
    }
}
/// A collection of [`ModelFamily`] implementations that can be looked up by
/// name or detected from tensor names / a model-type string.
#[derive(Debug)]
pub struct FamilyRegistry {
    /// Registered families, kept in registration order (lookups scan front to back).
    families: Vec<Box<dyn ModelFamily>>,
}
impl FamilyRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            families: Vec::new(),
        }
    }

    /// Appends a family to the registry.
    ///
    /// Families are consulted in registration order. Duplicate names are not
    /// rejected; [`FamilyRegistry::get`] returns the first one registered.
    pub fn register(&mut self, family: Box<dyn ModelFamily>) {
        self.families.push(family);
    }

    /// Names of all registered families, in registration order.
    #[must_use]
    pub fn family_names(&self) -> Vec<&str> {
        self.families.iter().map(|f| f.family_name()).collect()
    }

    /// Returns the first registered family whose name equals `family_name`
    /// exactly (case-sensitive).
    #[must_use]
    pub fn get(&self, family_name: &str) -> Option<&dyn ModelFamily> {
        self.families
            .iter()
            .find(|f| f.family_name() == family_name)
            .map(|f| f.as_ref())
    }

    /// Guesses the family from a checkpoint's tensor names.
    ///
    /// A family is a candidate only if its embedding tensor name is present
    /// and at least one of its per-layer patterns, instantiated for layer 0,
    /// is present. Among candidates, the one with the most layer-0 hits wins.
    /// On a tie, the earliest registered family is kept.
    #[must_use]
    pub fn detect_family(&self, tensor_names: &[&str]) -> Option<&dyn ModelFamily> {
        // Build the lookup set once instead of scanning the slice for every
        // pattern of every family.
        let present: std::collections::HashSet<&str> = tensor_names.iter().copied().collect();
        let mut best: Option<(usize, &dyn ModelFamily)> = None;
        for family in &self.families {
            let template = &family.config().tensor_template;
            if !present.contains(template.embedding.as_str()) {
                continue;
            }
            let layer0_hits = template
                .per_layer
                .values()
                .flatten()
                .filter(|pattern| present.contains(pattern.replace("{n}", "0").as_str()))
                .count();
            if layer0_hits == 0 {
                continue;
            }
            // The embedding match contributes one point on top of the layer hits.
            let score = 1 + layer0_hits;
            let is_better = match best {
                Some((best_score, _)) => score > best_score,
                None => true,
            };
            if is_better {
                best = Some((score, family.as_ref()));
            }
        }
        best.map(|(_, family)| family)
    }

    /// Resolves a family from a model-type string, case-insensitively.
    ///
    /// The first pass looks for a family whose name equals `model_type`. The
    /// second pass falls back to the first family that matches in either of
    /// two ways:
    /// - `model_type` contains the family name, or
    /// - one of the family's architecture names contains `model_type`.
    ///
    /// Returns `None` for an empty `model_type`. An empty string is a
    /// substring of everything and would otherwise match an arbitrary family.
    #[must_use]
    pub fn detect_from_model_type(&self, model_type: &str) -> Option<&dyn ModelFamily> {
        let wanted = model_type.to_lowercase();
        if wanted.is_empty() {
            return None;
        }

        if let Some(exact) = self
            .families
            .iter()
            .find(|f| f.config().family.to_lowercase() == wanted)
        {
            return Some(exact.as_ref());
        }

        self.families
            .iter()
            .find(|f| {
                let config = f.config();
                let family_lower = config.family.to_lowercase();
                // Checked independently of the architecture list, so families
                // without declared architectures can still match by name.
                let name_match = !family_lower.is_empty() && wanted.contains(&family_lower);
                name_match
                    || config
                        .architectures
                        .iter()
                        .any(|arch| arch.to_lowercase().contains(&wanted))
            })
            .map(|f| f.as_ref())
    }

    /// Number of registered families.
    #[must_use]
    pub fn len(&self) -> usize {
        self.families.len()
    }

    /// `true` if no families are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.families.is_empty()
    }
}
impl Default for FamilyRegistry {
fn default() -> Self {
Self::new()
}
}
// Splice in Rust source generated at build time. Cargo sets OUT_DIR for crates
// that have a build script, and the generated file is written there.
// NOTE(review): presumably this defines the built-in model families / their
// registration — confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/model_families_generated.rs"));