Skip to main content

wave_emu/
wave.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Wave state management. A wave contains W threads sharing a program counter.
5//!
6//! Active mask (u64 bitmask) tracks which threads execute each instruction.
7//! Supports wave widths up to 64. Status tracks ready/suspended/halted state.
8
9use crate::control_flow::ControlFlowManager;
10use crate::thread::{SpecialRegisters, Thread};
11
/// Scheduling state of a [`Wave`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WaveStatus {
    /// The wave may execute. This is the status a freshly constructed wave
    /// starts in, and the one `Wave::resume` restores after a suspension.
    Ready,
    /// The wave has been paused via `Wave::suspend` and stays paused until
    /// `Wave::resume` is called.
    Suspended,
    /// The wave has been stopped via `Wave::halt`, which also clears its
    /// active mask. `Wave::resume` does not leave this state.
    Halted,
}
18
19#[derive(Debug)]
20pub struct Wave {
21    pub threads: Vec<Thread>,
22    pub pc: u32,
23    pub active_mask: u64,
24    pub status: WaveStatus,
25    pub wave_width: u32,
26    pub wave_id: u32,
27    call_stack: Vec<u32>,
28    pub control_flow: ControlFlowManager,
29}
30
31impl Wave {
32    pub fn new(
33        wave_width: u32,
34        register_count: u32,
35        wave_id: u32,
36        workgroup_id: [u32; 3],
37        workgroup_size: [u32; 3],
38        grid_size: [u32; 3],
39        base_thread_index: u32,
40        total_threads_in_workgroup: u32,
41        num_waves: u32,
42    ) -> Self {
43        let mut threads = Vec::with_capacity(wave_width as usize);
44
45        for lane_id in 0..wave_width {
46            let global_thread_index = base_thread_index + lane_id;
47
48            let thread_id = Self::compute_thread_id(global_thread_index, workgroup_size);
49
50            let special = SpecialRegisters {
51                thread_id,
52                wave_id,
53                lane_id,
54                workgroup_id,
55                workgroup_size,
56                grid_size,
57                wave_width,
58                num_waves,
59            };
60
61            threads.push(Thread::with_special_registers(register_count, special));
62        }
63
64        let active_threads =
65            (total_threads_in_workgroup.saturating_sub(base_thread_index)).min(wave_width);
66        let active_mask = if active_threads >= 64 {
67            u64::MAX
68        } else {
69            (1u64 << active_threads) - 1
70        };
71
72        Self {
73            threads,
74            pc: 0,
75            active_mask,
76            status: WaveStatus::Ready,
77            wave_width,
78            wave_id,
79            call_stack: Vec::with_capacity(8),
80            control_flow: ControlFlowManager::new(),
81        }
82    }
83
84    fn compute_thread_id(linear_index: u32, workgroup_size: [u32; 3]) -> [u32; 3] {
85        let x = linear_index % workgroup_size[0];
86        let y = (linear_index / workgroup_size[0]) % workgroup_size[1];
87        let z = linear_index / (workgroup_size[0] * workgroup_size[1]);
88        [x, y, z]
89    }
90
91    pub fn is_thread_active(&self, lane: u32) -> bool {
92        if lane >= 64 {
93            return false;
94        }
95        (self.active_mask & (1u64 << lane)) != 0
96    }
97
98    pub fn active_thread_count(&self) -> u32 {
99        self.active_mask.count_ones()
100    }
101
102    pub fn set_thread_active(&mut self, lane: u32, active: bool) {
103        if lane < 64 {
104            if active {
105                self.active_mask |= 1u64 << lane;
106            } else {
107                self.active_mask &= !(1u64 << lane);
108            }
109        }
110    }
111
112    pub fn push_call(&mut self, return_pc: u32) -> Result<(), &'static str> {
113        if self.call_stack.len() >= 8 {
114            return Err("call stack overflow");
115        }
116        self.call_stack.push(return_pc);
117        Ok(())
118    }
119
120    pub fn pop_call(&mut self) -> Option<u32> {
121        self.call_stack.pop()
122    }
123
124    pub fn call_depth(&self) -> usize {
125        self.call_stack.len()
126    }
127
128    pub fn halt(&mut self) {
129        self.status = WaveStatus::Halted;
130        self.active_mask = 0;
131    }
132
133    pub fn suspend(&mut self) {
134        self.status = WaveStatus::Suspended;
135    }
136
137    pub fn resume(&mut self) {
138        if self.status == WaveStatus::Suspended {
139            self.status = WaveStatus::Ready;
140        }
141    }
142
143    pub fn is_halted(&self) -> bool {
144        self.status == WaveStatus::Halted
145    }
146
147    pub fn is_ready(&self) -> bool {
148        self.status == WaveStatus::Ready
149    }
150
151    pub fn advance_pc(&mut self, bytes: u32) {
152        self.pc = self.pc.wrapping_add(bytes);
153    }
154
155    pub fn set_pc(&mut self, pc: u32) {
156        self.pc = pc;
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A completely filled 32-lane wave that is the only wave of a 32x1x1
    /// workgroup.
    fn lone_full_wave() -> Wave {
        Wave::new(32, 32, 0, [0, 0, 0], [32, 1, 1], [1, 1, 1], 0, 32, 1)
    }

    // Construction: the first of two waves in a 64-thread workgroup is full.
    #[test]
    fn test_wave_new() {
        let first_of_two = Wave::new(32, 32, 0, [0, 0, 0], [64, 1, 1], [1, 1, 1], 0, 64, 2);

        assert_eq!(first_of_two.threads.len(), 32);
        assert_eq!(first_of_two.active_mask, 0xFFFF_FFFF);
        assert_eq!(first_of_two.status, WaveStatus::Ready);
    }

    // Construction: the second wave of a 48-thread workgroup only has 16 live
    // lanes.
    #[test]
    fn test_wave_partial_active() {
        let trailing = Wave::new(32, 32, 1, [0, 0, 0], [48, 1, 1], [1, 1, 1], 32, 48, 2);

        assert_eq!(trailing.active_mask, 0xFFFF);
        assert_eq!(trailing.active_thread_count(), 16);
    }

    // Linear lane indices are unpacked into x-fastest 3-D thread ids.
    #[test]
    fn test_wave_thread_ids() {
        let wave = Wave::new(32, 32, 0, [1, 2, 3], [8, 4, 2], [4, 4, 1], 0, 64, 2);

        let expectations = [(0usize, [0u32, 0, 0]), (8, [0, 1, 0]), (16, [0, 2, 0])];
        for (lane, expected_id) in expectations {
            assert_eq!(wave.threads[lane].special_registers.thread_id, expected_id);
        }
    }

    // Every thread learns its own lane id, both in the struct and through the
    // special-register read path.
    #[test]
    fn test_wave_lane_ids() {
        let wave = Wave::new(4, 32, 0, [0, 0, 0], [4, 1, 1], [1, 1, 1], 0, 4, 1);

        for lane in 0..4u32 {
            assert_eq!(wave.threads[lane as usize].special_registers.lane_id, lane);
        }

        assert_eq!(wave.threads[0].read_special(4), 0);
        assert_eq!(wave.threads[1].read_special(4), 1);
        assert_eq!(wave.threads[2].read_special(4), 2);
        assert_eq!(wave.threads[3].read_special(4), 3);
    }

    // The call stack is last-in, first-out and reports `None` once drained.
    #[test]
    fn test_wave_call_stack() {
        let mut wave = lone_full_wave();

        for return_pc in [0x100, 0x200] {
            wave.push_call(return_pc).unwrap();
        }
        assert_eq!(wave.call_depth(), 2);

        assert_eq!(wave.pop_call(), Some(0x200));
        assert_eq!(wave.pop_call(), Some(0x100));
        assert_eq!(wave.pop_call(), None);
    }

    // Halting flips the status and deactivates every lane.
    #[test]
    fn test_wave_halt() {
        let mut wave = lone_full_wave();

        wave.halt();

        assert!(wave.is_halted());
        assert_eq!(wave.active_mask, 0);
    }

    // A suspended wave becomes ready again after `resume`.
    #[test]
    fn test_wave_suspend_resume() {
        let mut wave = lone_full_wave();

        wave.suspend();
        assert_eq!(wave.status, WaveStatus::Suspended);

        wave.resume();
        assert!(wave.is_ready());
    }

    // Toggling one lane leaves its neighbours alone.
    #[test]
    fn test_wave_set_thread_active() {
        let mut wave = lone_full_wave();

        wave.set_thread_active(5, false);
        assert!(!wave.is_thread_active(5));
        assert!(wave.is_thread_active(4));

        wave.set_thread_active(5, true);
        assert!(wave.is_thread_active(5));
    }
}