1use crate::specialization::SpecMap;
4
5#[derive(Debug, Clone)]
7pub struct DispatchShape {
8 pub id: &'static str,
10 pub workgroup_size: [u32; 3],
12 pub shared_memory_bytes: u32,
14 pub inputs: Vec<&'static str>,
16 pub outputs: Vec<&'static str>,
18 pub specs: SpecMap,
20}
21
/// Resource limits a fused dispatch must stay within.
///
/// The type is plain `Copy` data, so it also derives `PartialEq`/`Eq`; this
/// lets callers and tests compare cap presets directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionCaps {
    /// Upper bound, in bytes, on the combined shared memory of the two
    /// dispatches being fused.
    pub max_shared_memory_bytes: u32,
    /// Upper bound on the number of invocations in one workgroup (the product
    /// of the three workgroup-size components).
    pub max_invocations_per_workgroup: u32,
}

impl Default for FusionCaps {
    /// Returns the baseline caps: 16 KiB of shared memory and 256 invocations
    /// per workgroup.
    fn default() -> Self {
        Self {
            max_shared_memory_bytes: 16 * 1024,
            max_invocations_per_workgroup: 256,
        }
    }
}

impl FusionCaps {
    /// Returns the larger preset: 128 KiB of shared memory and 1024
    /// invocations per workgroup.
    ///
    /// This is a `const fn`, so it can initialize constants and statics.
    #[must_use]
    pub const fn high_end() -> Self {
        Self {
            max_shared_memory_bytes: 128 * 1024,
            max_invocations_per_workgroup: 1024,
        }
    }
}
50
/// Outcome of asking whether two dispatches may be fused.
///
/// Every variant other than [`FusionDecision::Accept`] is a rejection and
/// names the check that failed.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum FusionDecision {
    /// All admission checks passed.
    Accept,
    /// The workgroup sizes are unusable for fusion.
    ///
    /// `FusionPass::decide` returns this both when the two sizes differ and
    /// when they match but their invocation count exceeds
    /// `FusionCaps::max_invocations_per_workgroup`; in the latter case the two
    /// fields hold the same value.
    WorkgroupSizeMismatch {
        /// Workgroup size of the upstream dispatch.
        upstream: [u32; 3],
        /// Workgroup size of the downstream dispatch.
        downstream: [u32; 3],
    },
    /// The combined shared memory requirement exceeds the cap.
    SharedMemoryBudget {
        /// Sum of both dispatches' shared memory in bytes, held in a `u64` so
        /// the exact total is reported.
        needed: u64,
        /// The cap that was exceeded, in bytes.
        cap: u32,
    },
    /// An upstream output read by the downstream dispatch is also listed among
    /// the other consumers.
    OutputConsumedElsewhere,
    /// No upstream output is read by the downstream dispatch.
    NoPipelineDependency,
}
76
77pub struct FusionPass;
79
80impl FusionPass {
81 #[must_use]
83 pub fn decide(
84 upstream: &DispatchShape,
85 downstream: &DispatchShape,
86 caps: FusionCaps,
87 other_consumers: &[&str],
88 ) -> FusionDecision {
89 if upstream.workgroup_size != downstream.workgroup_size {
90 return FusionDecision::WorkgroupSizeMismatch {
91 upstream: upstream.workgroup_size,
92 downstream: downstream.workgroup_size,
93 };
94 }
95 let invocations = u128::from(upstream.workgroup_size[0])
96 * u128::from(upstream.workgroup_size[1])
97 * u128::from(upstream.workgroup_size[2]);
98 if invocations > u128::from(caps.max_invocations_per_workgroup) {
99 return FusionDecision::WorkgroupSizeMismatch {
100 upstream: upstream.workgroup_size,
101 downstream: downstream.workgroup_size,
102 };
103 }
104 let needed =
105 u64::from(upstream.shared_memory_bytes) + u64::from(downstream.shared_memory_bytes);
106 if needed > u64::from(caps.max_shared_memory_bytes) {
107 return FusionDecision::SharedMemoryBudget {
108 needed,
109 cap: caps.max_shared_memory_bytes,
110 };
111 }
112
113 let mut has_pipeline_dependency = false;
114 for output in &upstream.outputs {
115 if !downstream.inputs.iter().any(|input| input == output) {
116 continue;
117 }
118 has_pipeline_dependency = true;
119 if other_consumers.iter().any(|consumer| consumer == output) {
120 return FusionDecision::OutputConsumedElsewhere;
121 }
122 }
123 if !has_pipeline_dependency {
124 return FusionDecision::NoPipelineDependency;
125 }
126 FusionDecision::Accept
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixture dispatch with a 64x1x1 workgroup, 1 KiB of shared
    /// memory and an empty specialization map.
    fn dispatch(
        id: &'static str,
        inputs: &[&'static str],
        outputs: &[&'static str],
    ) -> DispatchShape {
        DispatchShape {
            id,
            inputs: inputs.to_vec(),
            outputs: outputs.to_vec(),
            workgroup_size: [64, 1, 1],
            shared_memory_bytes: 1024,
            specs: SpecMap::new(),
        }
    }

    #[test]
    fn straight_producer_consumer_fuses() {
        let producer = dispatch("load", &["in"], &["stage"]);
        let consumer = dispatch("xor", &["stage"], &["out"]);

        let verdict = FusionPass::decide(&producer, &consumer, FusionCaps::high_end(), &[]);

        assert_eq!(verdict, FusionDecision::Accept);
    }

    #[test]
    fn third_consumer_rejects() {
        let producer = dispatch("a", &[], &["x"]);
        let consumer = dispatch("b", &["x"], &[]);

        let verdict = FusionPass::decide(&producer, &consumer, FusionCaps::high_end(), &["x"]);

        assert_eq!(verdict, FusionDecision::OutputConsumedElsewhere);
    }

    #[test]
    fn workgroup_invocation_overflow_rejects_instead_of_wrapping_or_clamping() {
        // The product of these components does not fit in 64 bits.
        let oversized = [u32::MAX, u32::MAX, 2];
        let producer = DispatchShape {
            workgroup_size: oversized,
            ..dispatch("wide-a", &["in"], &["stage"])
        };
        let consumer = DispatchShape {
            workgroup_size: oversized,
            ..dispatch("wide-b", &["stage"], &["out"])
        };

        let verdict = FusionPass::decide(&producer, &consumer, FusionCaps::high_end(), &[]);

        assert_eq!(
            verdict,
            FusionDecision::WorkgroupSizeMismatch {
                upstream: oversized,
                downstream: oversized,
            }
        );
    }

    #[test]
    fn shared_memory_overflow_rejects_instead_of_appearing_under_cap() {
        // The two requirements sum to one more than u32::MAX.
        let producer = DispatchShape {
            shared_memory_bytes: u32::MAX,
            ..dispatch("smem-a", &["in"], &["stage"])
        };
        let consumer = DispatchShape {
            shared_memory_bytes: 1,
            ..dispatch("smem-b", &["stage"], &["out"])
        };
        let caps = FusionCaps::high_end();

        let verdict = FusionPass::decide(&producer, &consumer, caps, &[]);

        assert_eq!(
            verdict,
            FusionDecision::SharedMemoryBudget {
                needed: u64::from(u32::MAX) + 1,
                cap: caps.max_shared_memory_bytes,
            }
        );
    }

    #[test]
    fn source_has_no_clamped_fusion_admission_math() {
        // The needle is assembled with `concat!` so this file does not contain
        // the forbidden text itself.
        let forbidden = concat!(".", "saturating_");
        let source = include_str!("fusion.rs");
        assert!(
            !source.contains(forbidden),
            "fusion admission must use widened exact arithmetic, not silent clamps"
        );
    }
}