Skip to main content

provable_contracts/query/
registry.rs

//! Contract tier and kernel equivalence class lookups.
//!
//! Derived from `docs/specifications/sub/registry.md`. Contracts not
//! explicitly listed there default to Tier 7 (performance/KAIZEN).
5
/// Returns the contract tier (1–7) of a contract stem; lower is more foundational.
///
/// Tiers: 1 = foundation kernels, 2 = composite kernels, 3 = system kernels,
/// 4 = training kernels, 5 = classical ML, 6 = model-specific contracts.
/// Any stem that is not listed explicitly (including unknown stems and the
/// empty string) falls through to tier 7.
pub fn tier_of(stem: &str) -> u8 {
    match stem {
        // Tier 1: foundation kernels.
        "softmax-kernel-v1"
        | "rmsnorm-kernel-v1"
        | "rope-kernel-v1"
        | "silu-kernel-v1"
        | "swiglu-kernel-v1"
        | "gelu-kernel-v1"
        | "layernorm-kernel-v1"
        | "batchnorm-kernel-v1"
        | "embedding-lookup-v1"
        | "cross-entropy-kernel-v1"
        | "linear-projection-v1"
        | "dropout-v1"
        | "activation-kernel-v1"
        | "bias-add-v1"
        | "transpose-kernel-v1" => 1,

        // Tier 2: composite kernels.
        "attention-kernel-v1"
        | "gqa-kernel-v1"
        | "matmul-kernel-v1"
        | "flash-attention-v1"
        | "sliding-window-attention-v1"
        | "qk-norm-v1"
        | "attention-scaling-v1"
        | "bidirectional-attention-v1" => 2,

        // Tier 3: system kernels.
        "kv-cache-equivalence-v1"
        | "kv-cache-sizing-v1"
        | "sampling-algorithms-v1"
        | "inference-pipeline-v1"
        | "streaming-tpot-v1"
        | "backend-dispatch-v1"
        | "conversation-generation-v1"
        | "safetensors-cpu-dispatch-v1"
        | "safetensors-header-v1"
        | "weight-loading-v1"
        | "validated-tensor-v1" => 3,

        // Tier 4: training kernels.
        "adamw-kernel-v1"
        | "loss-functions-v1"
        | "lora-algebra-v1"
        | "classification-finetune-v1"
        | "optimization-v1"
        | "lbfgs-kernel-v1"
        | "cmaes-kernel-v1"
        | "gradient-clipping-v1"
        | "learning-rate-scheduling-v1" => 4,

        // Tier 5: classical ML.
        "kmeans-kernel-v1"
        | "pagerank-kernel-v1"
        | "pca-v1"
        | "svm-v1"
        | "decision-tree-v1"
        | "random-forest-v1"
        | "naive-bayes-v1"
        | "gbm-v1"
        | "arima-v1"
        | "bayesian-v1"
        | "calibration-v1"
        | "dbscan-v1"
        | "gaussian-mixture-v1"
        | "isotonic-regression-v1"
        | "knn-v1"
        | "logistic-regression-v1"
        | "linear-probe-classifier-v1"
        | "active-learning-v1" => 5,

        // Tier 6: model-specific contracts.
        "qwen2-shapes-v1"
        | "qwen2-e2e-verification-v1"
        | "qwen3-shapes-v1"
        | "qwen3-e2e-verification-v1"
        | "qwen3moe-shapes-v1"
        | "qwen3moe-e2e-verification-v1"
        | "qwen35-shapes-v1"
        | "qwen35-hybrid-forward-v1" => 6,

        // Tier 7: everything not listed above (performance/KAIZEN default).
        _ => 7,
    }
}
112
/// Returns the *primary* kernel equivalence class (`'A'`–`'E'`) of a contract.
///
/// A contract can belong to several classes; this reports only the one
/// whose primary list names it. For example, `gqa-kernel-v1` yields
/// `Some('A')` even though it is shared with other classes. Use
/// [`classes_of`] for full membership.
///
/// Returns `None` for contracts that are not tied to any architecture class
/// (training kernels, classical ML, unknown stems, ...).
pub fn class_of(stem: &str) -> Option<char> {
    let primary = match stem {
        // Class A: Llama / Mistral / Yi — GQA + RMSNorm + SiLU + SwiGLU + RoPE
        "gqa-kernel-v1"
        | "rmsnorm-kernel-v1"
        | "silu-kernel-v1"
        | "swiglu-kernel-v1"
        | "rope-kernel-v1" => 'A',
        // Class B: GPT-2 / BERT — MHA + LayerNorm + GELU + AbsPos
        "attention-kernel-v1"
        | "layernorm-kernel-v1"
        | "gelu-kernel-v1"
        | "absolute-position-v1"
        | "bidirectional-attention-v1" => 'B',
        // Class C: BLOOM / MPT — MHA + LayerNorm + GELU + ALiBi.
        // Only ALiBi is C-primary; the shared kernels are claimed by B above.
        "alibi-kernel-v1" => 'C',
        // Class D: Gemma — LayerNorm + GELU + SiLU + GQA.
        // No contract is currently D-primary (its kernels are claimed by A
        // and B above), so this function never returns Some('D').
        //
        // Class E: Qwen — RMSNorm + SwiGLU + GQA + model-specific contracts.
        "qwen2-shapes-v1"
        | "qwen2-e2e-verification-v1"
        | "qwen3-shapes-v1"
        | "qwen3-e2e-verification-v1"
        | "qwen3moe-shapes-v1"
        | "qwen3moe-e2e-verification-v1"
        | "qwen35-shapes-v1"
        | "qwen35-hybrid-forward-v1" => 'E',
        // Not an architecture-class contract.
        _ => return None,
    };
    Some(primary)
}
164
/// Returns every kernel equivalence class a contract belongs to, in `'A'`..=`'E'` order.
///
/// Shared kernels appear under every class that uses them. For example,
/// `gqa-kernel-v1` yields `['A', 'D', 'E']`. An empty vector means the
/// contract is not tied to any architecture class.
pub fn classes_of(stem: &str) -> Vec<char> {
    // (class, member contracts). The row order fixes the order of the result.
    const MEMBERSHIP: &[(char, &[&str])] = &[
        // Class A: Llama / Mistral / Yi
        (
            'A',
            &[
                "gqa-kernel-v1",
                "rmsnorm-kernel-v1",
                "silu-kernel-v1",
                "swiglu-kernel-v1",
                "rope-kernel-v1",
            ],
        ),
        // Class B: GPT-2 / BERT
        (
            'B',
            &[
                "attention-kernel-v1",
                "layernorm-kernel-v1",
                "gelu-kernel-v1",
                "absolute-position-v1",
                "bidirectional-attention-v1",
            ],
        ),
        // Class C: BLOOM / MPT (shares attention, layernorm, gelu with B)
        (
            'C',
            &[
                "attention-kernel-v1",
                "layernorm-kernel-v1",
                "gelu-kernel-v1",
                "alibi-kernel-v1",
            ],
        ),
        // Class D: Gemma (shares layernorm, gelu, silu, gqa with A/B)
        (
            'D',
            &[
                "layernorm-kernel-v1",
                "gelu-kernel-v1",
                "silu-kernel-v1",
                "gqa-kernel-v1",
            ],
        ),
        // Class E: Qwen
        (
            'E',
            &[
                "rmsnorm-kernel-v1",
                "swiglu-kernel-v1",
                "gqa-kernel-v1",
                "qwen2-shapes-v1",
                "qwen2-e2e-verification-v1",
                "qwen3-shapes-v1",
                "qwen3-e2e-verification-v1",
                "qwen3moe-shapes-v1",
                "qwen3moe-e2e-verification-v1",
                "qwen35-shapes-v1",
                "qwen35-hybrid-forward-v1",
            ],
        ),
    ];

    MEMBERSHIP
        .iter()
        .filter(|(_, members)| members.contains(&stem))
        .map(|&(class, _)| class)
        .collect()
}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tier1_foundation() {
        assert_eq!(tier_of("softmax-kernel-v1"), 1);
        assert_eq!(tier_of("rmsnorm-kernel-v1"), 1);
        assert_eq!(tier_of("dropout-v1"), 1);
    }

    #[test]
    fn tier2_composite() {
        assert_eq!(tier_of("attention-kernel-v1"), 2);
        assert_eq!(tier_of("flash-attention-v1"), 2);
    }

    #[test]
    fn tier3_system() {
        assert_eq!(tier_of("kv-cache-equivalence-v1"), 3);
        assert_eq!(tier_of("sampling-algorithms-v1"), 3);
    }

    #[test]
    fn tier4_training() {
        assert_eq!(tier_of("adamw-kernel-v1"), 4);
        assert_eq!(tier_of("lora-algebra-v1"), 4);
    }

    #[test]
    fn tier5_classical() {
        assert_eq!(tier_of("kmeans-kernel-v1"), 5);
        assert_eq!(tier_of("pagerank-kernel-v1"), 5);
    }

    #[test]
    fn tier6_model_specific() {
        assert_eq!(tier_of("qwen2-shapes-v1"), 6);
        assert_eq!(tier_of("qwen35-shapes-v1"), 6);
    }

    #[test]
    fn tier7_default() {
        assert_eq!(tier_of("some-unknown-contract-v1"), 7);
        assert_eq!(tier_of("encoder-forward-v1"), 7);
    }

    #[test]
    fn class_a_llama() {
        assert_eq!(class_of("gqa-kernel-v1"), Some('A'));
        assert_eq!(class_of("rmsnorm-kernel-v1"), Some('A'));
        assert_eq!(class_of("rope-kernel-v1"), Some('A'));
    }

    #[test]
    fn class_b_gpt2() {
        assert_eq!(class_of("attention-kernel-v1"), Some('B'));
        assert_eq!(class_of("layernorm-kernel-v1"), Some('B'));
    }

    #[test]
    fn class_c_bloom() {
        assert_eq!(class_of("alibi-kernel-v1"), Some('C'));
    }

    #[test]
    fn class_e_qwen() {
        assert_eq!(class_of("qwen2-shapes-v1"), Some('E'));
        assert_eq!(class_of("qwen35-shapes-v1"), Some('E'));
    }

    #[test]
    fn class_none_for_non_arch() {
        assert_eq!(class_of("adamw-kernel-v1"), None);
        assert_eq!(class_of("kmeans-kernel-v1"), None);
    }

    #[test]
    fn multi_class_membership() {
        // gqa-kernel-v1 is in A, D, E
        let classes = classes_of("gqa-kernel-v1");
        assert!(classes.contains(&'A'));
        assert!(classes.contains(&'D'));
        assert!(classes.contains(&'E'));
    }

    #[test]
    fn multi_class_layernorm() {
        // layernorm is in B, C, D
        let classes = classes_of("layernorm-kernel-v1");
        assert!(classes.contains(&'B'));
        assert!(classes.contains(&'C'));
        assert!(classes.contains(&'D'));
    }

    /// `classes_of` reports memberships in A..E order, which callers can rely on.
    #[test]
    fn multi_class_order_is_stable() {
        assert_eq!(classes_of("gqa-kernel-v1"), vec!['A', 'D', 'E']);
        assert_eq!(classes_of("layernorm-kernel-v1"), vec!['B', 'C', 'D']);
        assert_eq!(classes_of("attention-kernel-v1"), vec!['B', 'C']);
    }

    /// Non-architecture contracts belong to no class at all.
    #[test]
    fn classes_of_empty_for_non_arch() {
        assert!(classes_of("adamw-kernel-v1").is_empty());
        assert!(classes_of("kmeans-kernel-v1").is_empty());
        assert!(classes_of("some-unknown-contract-v1").is_empty());
    }

    /// Every Gemma (class D) kernel is claimed as primary by another class,
    /// so `class_of` never reports 'D' while `classes_of` still includes it.
    #[test]
    fn class_d_is_never_primary() {
        let d_members: &[&str] = &[
            "layernorm-kernel-v1",
            "gelu-kernel-v1",
            "silu-kernel-v1",
            "gqa-kernel-v1",
        ];
        for &stem in d_members {
            assert_ne!(class_of(stem), Some('D'), "unexpected D primary for {}", stem);
            assert!(classes_of(stem).contains(&'D'), "missing D membership for {}", stem);
        }
    }

    /// `class_of` and `classes_of` keep separate membership tables. The primary
    /// class must always be the first entry of the full membership, so the two
    /// tables cannot silently drift apart.
    #[test]
    fn primary_class_is_first_membership() {
        let stems: &[&str] = &[
            "gqa-kernel-v1",
            "rmsnorm-kernel-v1",
            "silu-kernel-v1",
            "swiglu-kernel-v1",
            "rope-kernel-v1",
            "attention-kernel-v1",
            "layernorm-kernel-v1",
            "gelu-kernel-v1",
            "absolute-position-v1",
            "bidirectional-attention-v1",
            "alibi-kernel-v1",
            "qwen2-shapes-v1",
            "qwen2-e2e-verification-v1",
            "qwen3-shapes-v1",
            "qwen3-e2e-verification-v1",
            "qwen3moe-shapes-v1",
            "qwen3moe-e2e-verification-v1",
            "qwen35-shapes-v1",
            "qwen35-hybrid-forward-v1",
            "adamw-kernel-v1",
            "kmeans-kernel-v1",
            "softmax-kernel-v1",
            "some-unknown-contract-v1",
        ];
        for &stem in stems {
            assert_eq!(
                class_of(stem),
                classes_of(stem).first().copied(),
                "class_of / classes_of disagree for {}",
                stem
            );
        }
    }
}