1use crate::control_flow::ControlFlowManager;
10use crate::thread::{SpecialRegisters, Thread};
11
/// Scheduling state of a [`Wave`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WaveStatus {
    /// The wave may be scheduled; this is the state a freshly built wave starts in.
    Ready,
    /// The wave is parked and waits for a `resume` before it is `Ready` again.
    Suspended,
    /// The wave has been halted.
    Halted,
}
18
19#[derive(Debug)]
20pub struct Wave {
21 pub threads: Vec<Thread>,
22 pub pc: u32,
23 pub active_mask: u64,
24 pub status: WaveStatus,
25 pub wave_width: u32,
26 pub wave_id: u32,
27 call_stack: Vec<u32>,
28 pub control_flow: ControlFlowManager,
29}
30
31impl Wave {
32 pub fn new(
33 wave_width: u32,
34 register_count: u32,
35 wave_id: u32,
36 workgroup_id: [u32; 3],
37 workgroup_size: [u32; 3],
38 grid_size: [u32; 3],
39 base_thread_index: u32,
40 total_threads_in_workgroup: u32,
41 num_waves: u32,
42 ) -> Self {
43 let mut threads = Vec::with_capacity(wave_width as usize);
44
45 for lane_id in 0..wave_width {
46 let global_thread_index = base_thread_index + lane_id;
47
48 let thread_id = Self::compute_thread_id(global_thread_index, workgroup_size);
49
50 let special = SpecialRegisters {
51 thread_id,
52 wave_id,
53 lane_id,
54 workgroup_id,
55 workgroup_size,
56 grid_size,
57 wave_width,
58 num_waves,
59 };
60
61 threads.push(Thread::with_special_registers(register_count, special));
62 }
63
64 let active_threads =
65 (total_threads_in_workgroup.saturating_sub(base_thread_index)).min(wave_width);
66 let active_mask = if active_threads >= 64 {
67 u64::MAX
68 } else {
69 (1u64 << active_threads) - 1
70 };
71
72 Self {
73 threads,
74 pc: 0,
75 active_mask,
76 status: WaveStatus::Ready,
77 wave_width,
78 wave_id,
79 call_stack: Vec::with_capacity(8),
80 control_flow: ControlFlowManager::new(),
81 }
82 }
83
84 fn compute_thread_id(linear_index: u32, workgroup_size: [u32; 3]) -> [u32; 3] {
85 let x = linear_index % workgroup_size[0];
86 let y = (linear_index / workgroup_size[0]) % workgroup_size[1];
87 let z = linear_index / (workgroup_size[0] * workgroup_size[1]);
88 [x, y, z]
89 }
90
91 pub fn is_thread_active(&self, lane: u32) -> bool {
92 if lane >= 64 {
93 return false;
94 }
95 (self.active_mask & (1u64 << lane)) != 0
96 }
97
98 pub fn active_thread_count(&self) -> u32 {
99 self.active_mask.count_ones()
100 }
101
102 pub fn set_thread_active(&mut self, lane: u32, active: bool) {
103 if lane < 64 {
104 if active {
105 self.active_mask |= 1u64 << lane;
106 } else {
107 self.active_mask &= !(1u64 << lane);
108 }
109 }
110 }
111
112 pub fn push_call(&mut self, return_pc: u32) -> Result<(), &'static str> {
113 if self.call_stack.len() >= 8 {
114 return Err("call stack overflow");
115 }
116 self.call_stack.push(return_pc);
117 Ok(())
118 }
119
120 pub fn pop_call(&mut self) -> Option<u32> {
121 self.call_stack.pop()
122 }
123
124 pub fn call_depth(&self) -> usize {
125 self.call_stack.len()
126 }
127
128 pub fn halt(&mut self) {
129 self.status = WaveStatus::Halted;
130 self.active_mask = 0;
131 }
132
133 pub fn suspend(&mut self) {
134 self.status = WaveStatus::Suspended;
135 }
136
137 pub fn resume(&mut self) {
138 if self.status == WaveStatus::Suspended {
139 self.status = WaveStatus::Ready;
140 }
141 }
142
143 pub fn is_halted(&self) -> bool {
144 self.status == WaveStatus::Halted
145 }
146
147 pub fn is_ready(&self) -> bool {
148 self.status == WaveStatus::Ready
149 }
150
151 pub fn advance_pc(&mut self, bytes: u32) {
152 self.pc = self.pc.wrapping_add(bytes);
153 }
154
155 pub fn set_pc(&mut self, pc: u32) {
156 self.pc = pc;
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A 32-lane wave that is the only wave of a fully populated 32x1x1 workgroup.
    fn single_full_wave() -> Wave {
        Wave::new(32, 32, 0, [0, 0, 0], [32, 1, 1], [1, 1, 1], 0, 32, 1)
    }

    /// The first of two waves in a 64-thread workgroup has every lane active.
    #[test]
    fn test_wave_new() {
        let wave = Wave::new(32, 32, 0, [0, 0, 0], [64, 1, 1], [1, 1, 1], 0, 64, 2);

        assert_eq!(wave.status, WaveStatus::Ready);
        assert_eq!(wave.active_mask, 0xFFFF_FFFF);
        assert_eq!(wave.threads.len(), 32);
    }

    /// The second wave of a 48-thread workgroup only has its low 16 lanes active.
    #[test]
    fn test_wave_partial_active() {
        let wave = Wave::new(32, 32, 1, [0, 0, 0], [48, 1, 1], [1, 1, 1], 32, 48, 2);

        assert_eq!(wave.active_thread_count(), 16);
        assert_eq!(wave.active_mask, 0xFFFF);
    }

    /// Linear lane indices are unfolded into 3-D thread ids, x fastest.
    #[test]
    fn test_wave_thread_ids() {
        let wave = Wave::new(32, 32, 0, [1, 2, 3], [8, 4, 2], [4, 4, 1], 0, 64, 2);

        let expectations = [
            (0usize, [0u32, 0, 0]),
            (8, [0, 1, 0]),
            (16, [0, 2, 0]),
        ];
        for (lane, expected) in expectations {
            assert_eq!(wave.threads[lane].special_registers.thread_id, expected);
        }
    }

    /// Each thread records its own lane id and exposes it as special register 4.
    #[test]
    fn test_wave_lane_ids() {
        let wave = Wave::new(4, 32, 0, [0, 0, 0], [4, 1, 1], [1, 1, 1], 0, 4, 1);

        for (lane, expected) in [(0usize, 0u32), (1, 1), (2, 2), (3, 3)] {
            assert_eq!(wave.threads[lane].special_registers.lane_id, expected);
        }

        assert_eq!(wave.threads[0].read_special(4), 0);
        assert_eq!(wave.threads[1].read_special(4), 1);
        assert_eq!(wave.threads[2].read_special(4), 2);
        assert_eq!(wave.threads[3].read_special(4), 3);
    }

    /// Return addresses come back in last-in, first-out order.
    #[test]
    fn test_wave_call_stack() {
        let mut wave = single_full_wave();

        for return_pc in [0x100, 0x200] {
            wave.push_call(return_pc).unwrap();
        }
        assert_eq!(wave.call_depth(), 2);

        assert_eq!(wave.pop_call(), Some(0x200));
        assert_eq!(wave.pop_call(), Some(0x100));
        assert_eq!(wave.pop_call(), None);
    }

    /// Halting flips the status and clears every lane.
    #[test]
    fn test_wave_halt() {
        let mut wave = single_full_wave();

        wave.halt();

        assert!(wave.is_halted());
        assert_eq!(wave.active_mask, 0);
    }

    /// A suspended wave becomes ready again after `resume`.
    #[test]
    fn test_wave_suspend_resume() {
        let mut wave = single_full_wave();

        wave.suspend();
        assert_eq!(wave.status, WaveStatus::Suspended);

        wave.resume();
        assert!(wave.is_ready());
    }

    /// Toggling one lane leaves its neighbour untouched.
    #[test]
    fn test_wave_set_thread_active() {
        let mut wave = single_full_wave();

        wave.set_thread_active(5, false);
        assert!(!wave.is_thread_active(5));
        assert!(wave.is_thread_active(4));

        wave.set_thread_active(5, true);
        assert!(wave.is_thread_active(5));
    }
}