1use crate::Shape;
22use crate::binding_manifest::BindingManifest;
23use crate::component::ModelComponent;
24use crate::hir::{HirModule, HirNodeId, HirOp};
25use crate::lir::LirModule;
26use crate::mir::MirModule;
27use crate::shape::DimBinding;
28
29#[derive(Debug, Clone, PartialEq, Eq)]
31pub struct HirReflection {
32 pub name: String,
33 pub node_count: usize,
34 pub fusion_policy: String,
35 pub inputs: Vec<(String, Shape)>,
36 pub params: Vec<(String, Shape)>,
37 pub outputs: Vec<Shape>,
38 pub block_labels: Vec<String>,
39}
40
41impl HirReflection {
42 pub fn from_hir(hir: &HirModule) -> Self {
43 let mut inputs = Vec::new();
44 let mut params = Vec::new();
45 let mut block_labels = Vec::new();
46 for node in hir.nodes().iter() {
47 let label = node
48 .name
49 .clone()
50 .unwrap_or_else(|| format!("{:?}", node.op));
51 match &node.op {
52 HirOp::Input { name } => inputs.push((name.clone(), node.shape.clone())),
53 HirOp::Param { name } => params.push((name.clone(), node.shape.clone())),
54 HirOp::LlamaDecoderBlock { .. }
55 | HirOp::SwiGLU
56 | HirOp::Attention { .. }
57 | HirOp::GatedDeltaNet { .. }
58 | HirOp::Qwen35MtpHead { .. } => block_labels.push(label),
59 _ => {}
60 }
61 }
62 let outputs = hir
63 .outputs
64 .iter()
65 .map(|&id| hir.node(id).shape.clone())
66 .collect();
67 HirReflection {
68 name: hir.name.clone(),
69 node_count: hir.nodes().len(),
70 fusion_policy: format!("{:?}", hir.fusion_policy),
71 inputs,
72 params,
73 outputs,
74 block_labels,
75 }
76 }
77}
78
/// Read-only summary of a `MirModule`, produced by `MirReflection::from_mir`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MirReflection {
    /// Name of the underlying graph.
    pub name: String,
    /// Total number of nodes in the graph.
    pub node_count: usize,
    /// `(op kind, occurrence count)` pairs, sorted by op-kind name.
    pub op_kinds: Vec<(String, usize)>,
}
86
87impl MirReflection {
88 pub fn from_mir(mir: &MirModule) -> Self {
89 let g = mir.as_graph();
90 let mut counts = std::collections::HashMap::new();
91 for node in g.nodes() {
92 *counts.entry(format!("{:?}", node.op.kind())).or_default() += 1;
93 }
94 let mut op_kinds: Vec<_> = counts.into_iter().collect();
95 op_kinds.sort_by(|a, b| a.0.cmp(&b.0));
96 MirReflection {
97 name: g.name.clone(),
98 node_count: g.nodes().len(),
99 op_kinds,
100 }
101 }
102}
103
104pub fn layout_from_lir(lir: &LirModule) -> BindingManifest {
106 BindingManifest::from_lir(lir)
107}
108
109pub fn layout_for_binding(lir: &LirModule, _component: &ModelComponent) -> BindingManifest {
111 layout_from_lir(lir)
112}
113
/// Difference between a template manifest and a specialized manifest,
/// produced by `ManifestDiff::compare`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ManifestDiff {
    /// `arena_size` of the template manifest.
    pub template_arena: usize,
    /// `arena_size` of the specialized manifest.
    pub specialized_arena: usize,
    /// Parameter names present in the template but not in the specialized manifest.
    pub params_only_in_template: Vec<String>,
    /// Parameter names present in the specialized manifest but not in the template.
    pub params_only_in_specialized: Vec<String>,
}
122
123impl ManifestDiff {
124 pub fn compare(template: &BindingManifest, specialized: &BindingManifest) -> Self {
125 let t: std::collections::HashSet<_> = template.param_names().collect();
126 let s: std::collections::HashSet<_> = specialized.param_names().collect();
127 Self {
128 template_arena: template.arena_size,
129 specialized_arena: specialized.arena_size,
130 params_only_in_template: t.difference(&s).map(|x| (*x).to_string()).collect(),
131 params_only_in_specialized: s.difference(&t).map(|x| (*x).to_string()).collect(),
132 }
133 }
134}
135
/// Specialization choice attached to each block record by
/// `probe_block_specialization`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlockSpecialization {
    /// No explicit specialization requested.
    // NOTE(review): exact lowering semantics are defined outside this file.
    Default,
    /// Name suggests lowering the block as one fused transformer layer; confirm at the use site.
    FusedTransformerLayer,
    /// Name suggests lowering the block into unfused primitive ops; confirm at the use site.
    UnfusedPrimitives,
}
143
144#[derive(Debug, Clone, PartialEq, Eq)]
146pub struct SpecializeBlockRecord {
147 pub node: HirNodeId,
148 pub label: String,
149 pub choice: BlockSpecialization,
150}
151
152pub fn probe_block_specialization(
154 hir: &HirModule,
155 choice: BlockSpecialization,
156) -> Vec<SpecializeBlockRecord> {
157 hir.nodes()
158 .iter()
159 .filter_map(|node| {
160 let fused = matches!(
161 node.op,
162 HirOp::LlamaDecoderBlock { .. } | HirOp::SwiGLU | HirOp::GatedDeltaNet { .. }
163 );
164 if !fused {
165 return None;
166 }
167 let effective = choice;
168 Some(SpecializeBlockRecord {
169 node: node.id,
170 label: node
171 .name
172 .clone()
173 .unwrap_or_else(|| format!("{:?}", node.op)),
174 choice: effective,
175 })
176 })
177 .collect()
178}
179
180pub fn symbolic_layout_hint(binding: &DimBinding) -> String {
182 format!("DimBinding({:?})", binding)
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirMut;
    use crate::{DType, HirModule};

    /// A module with one input and one param reflects exactly one of each.
    #[test]
    fn hir_reflection_lists_inputs() {
        let mut module = HirModule::new("t");
        {
            // Keep the mutable builder borrow confined to this scope.
            let mut builder = HirMut::new(&mut module);
            let _x = builder.input("x", Shape::new(&[1, 4], DType::F32));
            let _w = builder.param("w", Shape::new(&[4, 2], DType::F32));
        }

        let reflection = HirReflection::from_hir(&module);
        assert_eq!(reflection.inputs.len(), 1);
        assert_eq!(reflection.params.len(), 1);
    }
}