1use crate::Shape;
10use crate::binding_manifest::BindingManifest;
11use crate::component::ModelComponent;
12use crate::hir::{HirModule, HirNodeId, HirOp};
13use crate::lir::LirModule;
14use crate::mir::MirModule;
15use crate::shape::DimBinding;
16
17#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct HirReflection {
20 pub name: String,
21 pub node_count: usize,
22 pub fusion_policy: String,
23 pub inputs: Vec<(String, Shape)>,
24 pub params: Vec<(String, Shape)>,
25 pub outputs: Vec<Shape>,
26 pub block_labels: Vec<String>,
27}
28
29impl HirReflection {
30 pub fn from_hir(hir: &HirModule) -> Self {
31 let mut inputs = Vec::new();
32 let mut params = Vec::new();
33 let mut block_labels = Vec::new();
34 for node in hir.nodes().iter() {
35 let label = node
36 .name
37 .clone()
38 .unwrap_or_else(|| format!("{:?}", node.op));
39 match &node.op {
40 HirOp::Input { name } => inputs.push((name.clone(), node.shape.clone())),
41 HirOp::Param { name } => params.push((name.clone(), node.shape.clone())),
42 HirOp::LlamaDecoderBlock { .. }
43 | HirOp::SwiGLU
44 | HirOp::Attention { .. }
45 | HirOp::GatedDeltaNet { .. }
46 | HirOp::Qwen35MtpHead { .. } => block_labels.push(label),
47 _ => {}
48 }
49 }
50 let outputs = hir
51 .outputs
52 .iter()
53 .map(|&id| hir.node(id).shape.clone())
54 .collect();
55 HirReflection {
56 name: hir.name.clone(),
57 node_count: hir.nodes().len(),
58 fusion_policy: format!("{:?}", hir.fusion_policy),
59 inputs,
60 params,
61 outputs,
62 block_labels,
63 }
64 }
65}
66
67#[derive(Debug, Clone, PartialEq, Eq)]
69pub struct MirReflection {
70 pub name: String,
71 pub node_count: usize,
72 pub op_kinds: Vec<(String, usize)>,
73}
74
75impl MirReflection {
76 pub fn from_mir(mir: &MirModule) -> Self {
77 let g = mir.as_graph();
78 let mut counts = std::collections::HashMap::new();
79 for node in g.nodes() {
80 *counts.entry(format!("{:?}", node.op.kind())).or_default() += 1;
81 }
82 let mut op_kinds: Vec<_> = counts.into_iter().collect();
83 op_kinds.sort_by(|a, b| a.0.cmp(&b.0));
84 MirReflection {
85 name: g.name.clone(),
86 node_count: g.nodes().len(),
87 op_kinds,
88 }
89 }
90}
91
92pub fn layout_from_lir(lir: &LirModule) -> BindingManifest {
94 BindingManifest::from_lir(lir)
95}
96
97pub fn layout_for_binding(lir: &LirModule, _component: &ModelComponent) -> BindingManifest {
99 layout_from_lir(lir)
100}
101
102#[derive(Debug, Clone, PartialEq, Eq)]
104pub struct ManifestDiff {
105 pub template_arena: usize,
106 pub specialized_arena: usize,
107 pub params_only_in_template: Vec<String>,
108 pub params_only_in_specialized: Vec<String>,
109}
110
111impl ManifestDiff {
112 pub fn compare(template: &BindingManifest, specialized: &BindingManifest) -> Self {
113 let t: std::collections::HashSet<_> = template.param_names().collect();
114 let s: std::collections::HashSet<_> = specialized.param_names().collect();
115 Self {
116 template_arena: template.arena_size,
117 specialized_arena: specialized.arena_size,
118 params_only_in_template: t.difference(&s).map(|x| (*x).to_string()).collect(),
119 params_only_in_specialized: s.difference(&t).map(|x| (*x).to_string()).collect(),
120 }
121 }
122}
123
/// Specialization choice recorded for a block by [`probe_block_specialization`].
///
/// NOTE(review): this file only stores the chosen variant on each record; what
/// each variant makes the compiler do is decided elsewhere. Confirm there
/// before relying on the variant names alone.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlockSpecialization {
    /// Presumably "no explicit preference" (TODO confirm).
    Default,
    /// Presumably keeps the block as one fused transformer layer (TODO confirm).
    FusedTransformerLayer,
    /// Presumably lowers the block to unfused primitive ops (TODO confirm).
    UnfusedPrimitives,
}
131
132#[derive(Debug, Clone, PartialEq, Eq)]
134pub struct SpecializeBlockRecord {
135 pub node: HirNodeId,
136 pub label: String,
137 pub choice: BlockSpecialization,
138}
139
140pub fn probe_block_specialization(
142 hir: &HirModule,
143 choice: BlockSpecialization,
144) -> Vec<SpecializeBlockRecord> {
145 hir.nodes()
146 .iter()
147 .filter_map(|node| {
148 let fused = matches!(
149 node.op,
150 HirOp::LlamaDecoderBlock { .. } | HirOp::SwiGLU | HirOp::GatedDeltaNet { .. }
151 );
152 if !fused {
153 return None;
154 }
155 let effective = choice;
156 Some(SpecializeBlockRecord {
157 node: node.id,
158 label: node
159 .name
160 .clone()
161 .unwrap_or_else(|| format!("{:?}", node.op)),
162 choice: effective,
163 })
164 })
165 .collect()
166}
167
168pub fn symbolic_layout_hint(binding: &DimBinding) -> String {
170 format!("DimBinding({:?})", binding)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirMut;
    use crate::{DType, HirModule};

    /// A module with one input and one param reflects exactly one of each.
    #[test]
    fn hir_reflection_lists_inputs() {
        let mut module = HirModule::new("t");
        {
            // Keep the mutable builder borrow confined to graph construction.
            let mut builder = HirMut::new(&mut module);
            let _input = builder.input("x", Shape::new(&[1, 4], DType::F32));
            let _weight = builder.param("w", Shape::new(&[4, 2], DType::F32));
        }

        let reflection = HirReflection::from_hir(&module);
        assert_eq!(reflection.inputs.len(), 1);
        assert_eq!(reflection.params.len(), 1);
    }
}