use serde::Serialize;
pub mod config;
pub mod family;
pub mod kernel_ops;
pub mod output;
pub mod proof;
pub mod resolve;
#[cfg(test)]
mod tests;
/// Expands to a `(family name, YAML text)` tuple for one model-family contract.
///
/// The file named by `$file` is embedded at compile time with `include_str!`,
/// looked up under `../../../contracts/model-families/`.
macro_rules! embed_family {
    ($family:expr, $file:expr) => {
        ($family, include_str!(concat!("../../../contracts/model-families/", $file)))
    };
}
/// Model-family contracts bundled into the binary at compile time.
///
/// Each entry is `(family name, YAML text)` as produced by `embed_family!`;
/// the YAML text is the content of `<name>.yaml` under the model-families
/// contracts directory.
// NOTE(review): consumers of this table are not visible here; entry order is
// kept as-is in case iteration order is significant — confirm before sorting.
const FAMILY_YAMLS: &[(&str, &str)] = &[
embed_family!("bert", "bert.yaml"),
embed_family!("deepseek", "deepseek.yaml"),
embed_family!("falcon_h1", "falcon_h1.yaml"),
embed_family!("gemma", "gemma.yaml"),
embed_family!("gpt2", "gpt2.yaml"),
embed_family!("llama", "llama.yaml"),
embed_family!("mamba", "mamba.yaml"),
embed_family!("mistral", "mistral.yaml"),
embed_family!("moonshine", "moonshine.yaml"),
embed_family!("openelm", "openelm.yaml"),
embed_family!("phi", "phi.yaml"),
embed_family!("qwen2", "qwen2.yaml"),
embed_family!("qwen3", "qwen3.yaml"),
embed_family!("qwen3_5", "qwen3_5.yaml"),
embed_family!("rwkv7", "rwkv7.yaml"),
embed_family!("whisper", "whisper.yaml"),
];
/// Expands to a `(contract name, YAML text)` tuple for one kernel contract.
///
/// The file named by `$file` is embedded at compile time with `include_str!`,
/// looked up under `../../../contracts/`.
macro_rules! embed_contract {
    ($contract:expr, $file:expr) => {
        (
            $contract,
            include_str!(concat!("../../../contracts/", $file)),
        )
    };
}
/// Kernel contracts bundled into the binary at compile time.
///
/// Each entry is `(contract name, YAML text)` as produced by `embed_contract!`;
/// the contract name is the file name without its `.yaml` extension.
// NOTE(review): consumers of this table are not visible here; entry order is
// kept as-is in case iteration order is significant — confirm before sorting.
const KERNEL_CONTRACTS: &[(&str, &str)] = &[
embed_contract!("matvec-kernel-v1", "matvec-kernel-v1.yaml"),
embed_contract!("rope-kernel-v1", "rope-kernel-v1.yaml"),
embed_contract!("normalization-kernel-v1", "normalization-kernel-v1.yaml"),
embed_contract!("element-wise-ops-v1", "element-wise-ops-v1.yaml"),
embed_contract!("softmax-kernel-v1", "softmax-kernel-v1.yaml"),
embed_contract!("kernel-fusion-v1", "kernel-fusion-v1.yaml"),
embed_contract!("tensor-layout-v1", "tensor-layout-v1.yaml"),
embed_contract!("quantized-dot-product-v1", "quantized-dot-product-v1.yaml"),
embed_contract!("transpose-kernel-v1", "transpose-kernel-v1.yaml"),
];
/// Classification of a model family by the combination of kernels it uses.
///
/// The per-variant summaries below mirror the strings returned by
/// [`KernelClass::label`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum KernelClass {
    /// GQA + RMSNorm + SiLU + SwiGLU + RoPE.
    A,
    /// MHA + LayerNorm + GELU.
    B,
    /// MQA + LayerNorm + GELU + ALiBi.
    C,
    /// GQA + LayerNorm + GELU/SiLU.
    D,
    /// MoE + GQA + RMSNorm + SwiGLU.
    E,
    /// RMSNorm + GELU + GatedMlp + RoPE.
    F,
    /// State Space Model + RMSNorm + SiLU.
    Ssm,
    /// WKV Recurrence + LayerNorm + GELU.
    Linear,
    /// No known class applies.
    Unknown,
}
impl KernelClass {
pub fn label(self) -> &'static str {
match self {
Self::A => "A (GQA + RMSNorm + SiLU + SwiGLU + RoPE)",
Self::B => "B (MHA + LayerNorm + GELU)",
Self::C => "C (MQA + LayerNorm + GELU + ALiBi)",
Self::D => "D (GQA + LayerNorm + GELU/SiLU)",
Self::E => "E (MoE + GQA + RMSNorm + SwiGLU)",
Self::F => "F (RMSNorm + GELU + GatedMlp + RoPE)",
Self::Ssm => "SSM (State Space Model + RMSNorm + SiLU)",
Self::Linear => "Linear (WKV Recurrence + LayerNorm + GELU)",
Self::Unknown => "Unknown",
}
}
pub fn letter(self) -> &'static str {
match self {
Self::A => "A",
Self::B => "B",
Self::C => "C",
Self::D => "D",
Self::E => "E",
Self::F => "F",
Self::Ssm => "SSM",
Self::Linear => "Linear",
Self::Unknown => "Unknown",
}
}
}
/// One kernel operation entry: an operation together with the kernel and
/// contract it is associated with. All fields are `&'static str`.
#[derive(Debug, Clone, Serialize)]
pub struct KernelOp {
/// Name of the operation.
pub op: &'static str,
/// Name of the kernel associated with `op`.
pub kernel: &'static str,
/// Contract identifier for the kernel.
// NOTE(review): presumably one of the names in `KERNEL_CONTRACTS` — confirm against `kernel_ops`.
pub contract: &'static str,
}
/// Architectural constraints recorded for a model family.
///
/// Derives `Default`, so a default value has empty strings and `false` flags.
// NOTE(review): the accepted string vocabulary for each field is not visible
// here; presumably it comes from the family YAML contracts — confirm.
#[derive(Debug, Clone, Default, Serialize)]
pub struct Constraints {
/// Attention variant identifier.
pub attention_type: String,
/// Activation function identifier.
pub activation: String,
/// Normalization layer identifier.
pub norm_type: String,
/// MLP / feed-forward block identifier.
pub mlp_type: String,
/// Positional encoding identifier.
pub positional_encoding: String,
/// Bias flag. NOTE(review): which layers this refers to is not shown here — confirm.
pub has_bias: bool,
/// Tied-embeddings flag.
pub tied_embeddings: bool,
}
/// Description of a single model family: its names, architectures,
/// constraints and kernel class.
#[derive(Debug, Clone, Serialize)]
pub struct FamilyInfo {
/// Family identifier.
// NOTE(review): presumably matches a key in `FAMILY_YAMLS` (e.g. "llama") — confirm.
pub family: String,
/// Human-readable name of the family.
pub display_name: String,
/// Architecture names belonging to this family.
pub architectures: Vec<String>,
/// Architectural constraints of the family.
pub constraints: Constraints,
/// Kernel class assigned to the family.
pub kernel_class: KernelClass,
}
/// A configuration value paired with the rationale for it.
#[derive(Debug, Clone, Serialize)]
pub struct ConfigField {
/// The field's value, stored as a string.
pub value: String,
/// Explanation accompanying `value`.
pub rationale: String,
}
/// Alias table: `(alias, family)` pairs mapping an alternate model name onto the
/// family it is treated as. Every target is a family name present in
/// `FAMILY_YAMLS`.
///
/// Entry order is kept exactly as originally written; the lookup logic lives
/// outside this table and may depend on it.
const FAMILY_ALIASES: &[(&str, &str)] = &[
    // Treated as `llama`.
    ("olmo", "llama"),
    ("olmo2", "llama"),
    ("granite", "llama"),
    ("internlm2", "llama"),
    ("phi3", "llama"),
    ("phi4", "llama"),
    ("codellama", "llama"),
    ("tinyllama", "llama"),
    ("stablelm", "llama"),
    ("yi", "llama"),
    ("baichuan", "llama"),
    // Treated as `bert`.
    ("gpt_neo", "bert"),
    ("gptneo", "bert"),
    ("gpt_neox", "bert"),
    ("gptneox", "bert"),
    ("gpt_j", "bert"),
    ("gptj", "bert"),
    ("gpt_bigcode", "bert"),
    ("gptbigcode", "bert"),
    ("starcoder1", "bert"),
    ("codegen", "bert"),
    ("xglm", "bert"),
    ("opt", "bert"),
    ("galactica", "bert"),
    ("roberta", "bert"),
    ("deberta", "bert"),
    ("electra", "bert"),
    ("distilbert", "bert"),
    // Treated as `phi`.
    ("phi3small", "phi"),
    // Treated as `gemma`.
    ("codegemma", "gemma"),
    ("gemma2", "gemma"),
    ("gemma3", "gemma"),
    // Treated as `qwen2`.
    ("starcoder2", "qwen2"),
    ("qwen2_moe", "qwen2"),
    ("qwen2moe", "qwen2"),
    // NOTE(review): these qwen3 variants alias to `mistral`, not `qwen3` —
    // confirm this is intentional.
    ("qwen3_moe", "mistral"),
    ("qwen3moe", "mistral"),
    ("qwen3_next", "mistral"),
    ("qwen3next", "mistral"),
    // Treated as `deepseek`.
    ("deepseek_v2", "deepseek"),
    ("deepseekv2", "deepseek"),
    // Treated as `mistral`.
    ("mixtral", "mistral"),
    ("codestral", "mistral"),
    ("mathstral", "mistral"),
    ("pixtral", "mistral"),
    ("zephyr", "mistral"),
    ("openchat", "mistral"),
    ("openhermes", "mistral"),
    // More `llama` aliases.
    ("nemotron", "llama"),
    ("solar", "llama"),
    ("vicuna", "llama"),
    // Treated as `qwen2`.
    ("qwq", "qwen2"),
    // More `llama` aliases.
    ("smollm", "llama"),
    ("smollm2", "llama"),
    // More `bert` aliases.
    ("falcon", "bert"),
    ("bloom", "bert"),
    ("bloomz", "bert"),
    ("bloom_560m", "bert"),
    ("bigscience", "bert"),
];
pub use config::{
extract_config_mapping,
extract_json_string,
};
pub use family::load_families;
pub use output::{build_json_output, print_human_output};
pub use resolve::{family_aliases, resolve_family, resolve_from_config_json};