use std::sync::OnceLock;
use serde::{Deserialize, Serialize};
use crate::quantization::QuantizationLevel;
use crate::registry::AcceleratorRegistry;
const MODELS_JSON: &str = include_str!("../data/models.json");
static PARSED_MODELS: OnceLock<Vec<ModelProfile>> = OnceLock::new();
/// One entry in the embedded model catalogue (`data/models.json`):
/// a named model plus the metadata needed for memory estimates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProfile {
/// Display name; used as the lookup key in `find_model`.
pub name: String,
/// Model family (e.g. "llama"); matched ASCII-case-insensitively by `models_by_family`.
pub family: String,
/// Parameter count in billions; converted to an absolute count by `param_count`.
pub params_billions: f64,
// NOTE(review): not read anywhere in this module — presumably consumed by callers.
pub default_dtype: String,
/// Distribution formats; defaults to empty when the key is absent from the JSON.
#[serde(default)]
pub formats: Vec<String>,
/// Supported context lengths; defaults to empty when the key is absent from the JSON.
#[serde(default)]
pub context_lengths: Vec<u64>,
}
impl ModelProfile {
    /// Absolute parameter count, converted from billions.
    ///
    /// Truncates toward zero after the float multiply; saturates at the
    /// `u64` bounds per Rust's float-to-int `as` cast semantics.
    #[must_use]
    #[inline]
    pub fn param_count(&self) -> u64 {
        let params = self.params_billions * 1e9;
        params as u64
    }

    /// Bytes of accelerator memory needed at quantization `quant`,
    /// as estimated by [`AcceleratorRegistry::estimate_memory`].
    #[must_use]
    #[inline]
    pub fn memory_bytes(&self, quant: &QuantizationLevel) -> u64 {
        let params = self.param_count();
        AcceleratorRegistry::estimate_memory(params, quant)
    }

    /// Same estimate as [`Self::memory_bytes`], expressed in GiB.
    #[must_use]
    #[inline]
    pub fn memory_gb(&self, quant: &QuantizationLevel) -> f64 {
        let bytes = self.memory_bytes(quant) as f64;
        bytes / crate::units::BYTES_PER_GIB
    }
}
/// Outcome of a compatibility check: a catalogue model that fits within a
/// memory budget, plus the numbers behind that verdict.
#[derive(Debug, Clone)]
pub struct CompatResult<'a> {
/// The matching catalogue entry (borrowed from the cached static list).
pub model: &'a ModelProfile,
/// Bytes the model needs at the requested quantization.
pub memory_required_bytes: u64,
/// The memory budget the check was run against.
pub memory_available_bytes: u64,
/// Unused share of the budget, as a percentage in 0–100.
pub headroom_pct: f64,
}
/// All models from the embedded catalogue, parsed on first call and cached
/// in `PARSED_MODELS` for the lifetime of the process.
///
/// A malformed catalogue is logged as a warning and yields an empty slice
/// instead of panicking.
#[must_use]
pub fn all_models() -> &'static [ModelProfile] {
    /// Mirror of the top-level JSON document: `{ "models": [...] }`.
    #[derive(Deserialize)]
    struct ModelData {
        #[serde(default)]
        models: Vec<ModelProfile>,
    }
    PARSED_MODELS.get_or_init(|| {
        match serde_json::from_str::<ModelData>(MODELS_JSON) {
            Ok(data) => data.models,
            Err(e) => {
                tracing::warn!(error = %e, "failed to parse embedded model catalogue");
                Vec::new()
            }
        }
    })
}
/// Whether `model` fits within `available_memory_bytes` at quantization `quant`.
#[must_use]
#[inline]
pub fn can_run(
    model: &ModelProfile,
    quant: &QuantizationLevel,
    available_memory_bytes: u64,
) -> bool {
    let required = model.memory_bytes(quant);
    available_memory_bytes >= required
}
/// All catalogue models that fit within `available_memory_bytes` at the
/// given quantization, sorted largest-first by parameter count.
///
/// `headroom_pct` is the unused share of the budget in 0–100; it is defined
/// as 0 when the budget itself is 0 (avoiding a division by zero).
#[must_use]
pub fn compatible_models(
    quant: &QuantizationLevel,
    available_memory_bytes: u64,
) -> Vec<CompatResult<'static>> {
    let mut results: Vec<CompatResult<'static>> = all_models()
        .iter()
        .filter_map(|model| {
            let needed = model.memory_bytes(quant);
            if needed > available_memory_bytes {
                return None;
            }
            // `needed <= available` holds past the guard, so the subtraction
            // below cannot underflow.
            let headroom = if available_memory_bytes == 0 {
                0.0
            } else {
                (available_memory_bytes - needed) as f64 / available_memory_bytes as f64 * 100.0
            };
            Some(CompatResult {
                model,
                memory_required_bytes: needed,
                memory_available_bytes: available_memory_bytes,
                headroom_pct: headroom,
            })
        })
        .collect();
    // `total_cmp` is a deterministic total order even if a catalogue entry
    // carries a NaN parameter count, unlike the previous
    // `partial_cmp(..).unwrap_or(Equal)`, which made the comparator
    // inconsistent in that case. Stable sort keeps catalogue order for
    // equally sized models.
    results.sort_by(|a, b| b.model.params_billions.total_cmp(&a.model.params_billions));
    results
}
/// Looks up a catalogue model by name.
///
/// Tries an exact match first; if none exists, falls back to an
/// ASCII-case-insensitive match, mirroring the case handling of
/// [`models_by_family`]. Exact matches always win, so existing callers
/// see unchanged results.
#[must_use]
pub fn find_model(name: &str) -> Option<&'static ModelProfile> {
    all_models()
        .iter()
        .find(|m| m.name == name)
        .or_else(|| {
            all_models()
                .iter()
                .find(|m| m.name.eq_ignore_ascii_case(name))
        })
}
/// All catalogue models whose family equals `family`, ignoring ASCII case.
#[must_use]
pub fn models_by_family(family: &str) -> Vec<&'static ModelProfile> {
    let mut matches = Vec::new();
    for model in all_models() {
        if model.family.eq_ignore_ascii_case(family) {
            matches.push(model);
        }
    }
    matches
}
/// Convenience wrapper over [`compatible_models`], using the total
/// accelerator memory reported by `registry` as the budget.
#[must_use]
pub fn compatible_with_registry(
    registry: &AcceleratorRegistry,
    quant: &QuantizationLevel,
) -> Vec<CompatResult<'static>> {
    compatible_models(quant, registry.total_accelerator_memory())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One binary gibibyte, for readable memory budgets below.
    const GIB: u64 = 1024 * 1024 * 1024;

    /// The embedded catalogue must parse to at least one entry.
    #[test]
    fn all_models_not_empty() {
        assert!(!all_models().is_empty());
    }

    /// Every catalogue entry must carry a positive parameter count.
    #[test]
    fn all_models_have_valid_params() {
        for model in all_models() {
            assert!(
                model.params_billions > 0.0,
                "model {} has zero params",
                model.name
            );
            assert!(
                model.param_count() > 0,
                "model {} param_count is zero",
                model.name
            );
        }
    }

    #[test]
    fn find_model_exists() {
        assert!(find_model("Llama 3.1 70B").is_some());
        assert!(find_model("Nonexistent Model").is_none());
    }

    #[test]
    fn models_by_family_llama() {
        let llamas = models_by_family("llama");
        assert!(llamas.len() >= 3);
        assert!(llamas.iter().all(|m| m.family == "llama"));
    }

    #[test]
    fn can_run_small_model_large_memory() {
        let small = find_model("Llama 3.2 1B").unwrap();
        assert!(can_run(small, &QuantizationLevel::Int4, 24 * GIB));
    }

    #[test]
    fn cannot_run_huge_model_small_memory() {
        let huge = find_model("Llama 3.1 405B").unwrap();
        assert!(!can_run(huge, &QuantizationLevel::None, 24 * GIB));
    }

    #[test]
    fn compatible_models_24gb_fp16() {
        let results = compatible_models(&QuantizationLevel::Float16, 24 * GIB);
        assert!(!results.is_empty());
        // Nothing at or above 70B should fit in 24 GiB at FP16.
        assert!(!results.iter().any(|r| r.model.params_billions >= 70.0));
        // Results come back ordered largest-first.
        for (bigger, smaller) in results.iter().zip(results.iter().skip(1)) {
            assert!(bigger.model.params_billions >= smaller.model.params_billions);
        }
    }

    #[test]
    fn compatible_models_headroom_is_valid() {
        for r in &compatible_models(&QuantizationLevel::Int4, 80 * GIB) {
            assert!(r.headroom_pct >= 0.0);
            assert!(r.headroom_pct <= 100.0);
        }
    }

    #[test]
    fn memory_gb_reasonable() {
        let model = find_model("Llama 3.1 8B").unwrap();
        let gb = model.memory_gb(&QuantizationLevel::Float16);
        assert!(gb > 15.0 && gb < 25.0, "8B FP16 memory: {gb} GB");
    }

    #[test]
    fn model_profile_serde_roundtrip() {
        let original = find_model("Mistral 7B").unwrap();
        let json = serde_json::to_string(original).unwrap();
        let back: ModelProfile = serde_json::from_str(&json).unwrap();
        assert_eq!(original.name, back.name);
        assert_eq!(original.param_count(), back.param_count());
    }
}