Skip to main content

wave_emu/
executor.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
//! Instruction executor. Decodes instructions from binary, dispatches them to the
//! appropriate handlers, and executes operations across the active threads of a wave.
//! Handles ALU operations, memory access, control flow, wave operations, and atomics.
8
9// ControlFlowManager is now in Wave, not Executor
10use crate::decoder::{
11    AtomicOp, BitOpType, CmpOp, ControlOp, CvtType, DecodedInstruction, Decoder, F16Op,
12    F16PackedOp, F64DivSqrtOp, F64Op, FUnaryOp, MemWidth, MiscOp, Opcode, SyncOp, WaveOpType,
13    WaveReduceType,
14};
15use crate::memory::{DeviceMemory, LocalMemory};
16use crate::shuffle;
17use crate::stats::{ExecutionStats, InstructionCategory, TraceWriter};
18use crate::wave::Wave;
19use crate::EmulatorError;
20use half::f16;
21
22pub struct Executor<'a> {
23    decoder: Decoder<'a>,
24    trace: TraceWriter,
25    workgroup_id: [u32; 3],
26}
27
28impl<'a> Executor<'a> {
29    pub fn new(code: &'a [u8], trace_enabled: bool, workgroup_id: [u32; 3]) -> Self {
30        Self {
31            decoder: Decoder::new(code),
32            trace: TraceWriter::new(trace_enabled),
33            workgroup_id,
34        }
35    }
36
37    pub fn step(
38        &mut self,
39        wave: &mut Wave,
40        local_memory: &mut LocalMemory,
41        device_memory: &mut DeviceMemory,
42        stats: &mut ExecutionStats,
43    ) -> Result<StepResult, EmulatorError> {
44        if wave.is_halted() {
45            return Ok(StepResult::Halted);
46        }
47
48        if wave.active_mask == 0 && wave.control_flow.is_empty() {
49            wave.halt();
50            return Ok(StepResult::Halted);
51        }
52
53        let inst = self.decoder.decode_at(wave.pc)?;
54
55        if self.trace.is_enabled() {
56            let disasm = self.decoder.disassemble(&inst);
57            self.trace
58                .trace_instruction(self.workgroup_id, wave.wave_id, wave.pc, &disasm);
59        }
60
61        let result = self.execute_instruction(wave, &inst, local_memory, device_memory, stats)?;
62
63        match result {
64            ExecuteResult::Continue => {
65                wave.advance_pc(inst.size);
66                Ok(StepResult::Continue)
67            }
68            ExecuteResult::Jump(target) => {
69                wave.set_pc(target);
70                Ok(StepResult::Continue)
71            }
72            ExecuteResult::Halt => {
73                wave.halt();
74                Ok(StepResult::Halted)
75            }
76            ExecuteResult::Barrier => Ok(StepResult::Barrier),
77        }
78    }
79
80    fn execute_instruction(
81        &mut self,
82        wave: &mut Wave,
83        inst: &DecodedInstruction,
84        local_memory: &mut LocalMemory,
85        device_memory: &mut DeviceMemory,
86        stats: &mut ExecutionStats,
87    ) -> Result<ExecuteResult, EmulatorError> {
88        let original_mask = wave.active_mask;
89        let is_control_sync = inst.opcode == Opcode::Control && inst.is_sync_op();
90        let is_halt = is_control_sync && inst.modifier == SyncOp::Halt as u8;
91
92        if inst.is_predicated() {
93            let pred_mask = self.compute_predicate_mask(wave, inst.pred_reg, inst.pred_neg);
94
95            if is_halt {
96                wave.active_mask &= !pred_mask;
97                if wave.active_mask == 0 {
98                    return Ok(ExecuteResult::Halt);
99                }
100                return Ok(ExecuteResult::Continue);
101            } else if is_control_sync {
102                if pred_mask == 0 {
103                    return Ok(ExecuteResult::Continue);
104                }
105            } else {
106                wave.active_mask &= pred_mask;
107            }
108        }
109
110        let result = match inst.opcode {
111            Opcode::Iadd
112            | Opcode::Isub
113            | Opcode::Imul
114            | Opcode::ImulHi
115            | Opcode::Idiv
116            | Opcode::Imod
117            | Opcode::Ineg
118            | Opcode::Iabs
119            | Opcode::Imin
120            | Opcode::Imax => {
121                self.execute_integer_op(wave, inst);
122                stats.record_instruction(InstructionCategory::Integer);
123                Ok(ExecuteResult::Continue)
124            }
125            Opcode::Imad | Opcode::Iclamp => {
126                self.execute_integer_extended(wave, inst);
127                stats.record_instruction(InstructionCategory::Integer);
128                Ok(ExecuteResult::Continue)
129            }
130            Opcode::And
131            | Opcode::Or
132            | Opcode::Xor
133            | Opcode::Not
134            | Opcode::Shl
135            | Opcode::Shr
136            | Opcode::Sar => {
137                self.execute_bitwise_op(wave, inst);
138                stats.record_instruction(InstructionCategory::Integer);
139                Ok(ExecuteResult::Continue)
140            }
141            Opcode::BitOps => {
142                self.execute_bit_ops(wave, inst);
143                stats.record_instruction(InstructionCategory::Integer);
144                Ok(ExecuteResult::Continue)
145            }
146            Opcode::Fadd
147            | Opcode::Fsub
148            | Opcode::Fmul
149            | Opcode::Fdiv
150            | Opcode::Fneg
151            | Opcode::Fabs
152            | Opcode::Fmin
153            | Opcode::Fmax
154            | Opcode::Fsqrt => {
155                self.execute_float_op(wave, inst);
156                stats.record_instruction(InstructionCategory::Float);
157                Ok(ExecuteResult::Continue)
158            }
159            Opcode::Fma | Opcode::Fclamp => {
160                self.execute_float_extended(wave, inst);
161                stats.record_instruction(InstructionCategory::Float);
162                Ok(ExecuteResult::Continue)
163            }
164            Opcode::FUnaryOps => {
165                self.execute_float_unary(wave, inst);
166                stats.record_instruction(InstructionCategory::Float);
167                Ok(ExecuteResult::Continue)
168            }
169            Opcode::F16Ops | Opcode::F16PackedOps => {
170                self.execute_f16_op(wave, inst);
171                stats.record_instruction(InstructionCategory::Float);
172                Ok(ExecuteResult::Continue)
173            }
174            Opcode::F64Ops | Opcode::F64DivSqrt => {
175                self.execute_f64_op(wave, inst);
176                stats.record_instruction(InstructionCategory::Float);
177                Ok(ExecuteResult::Continue)
178            }
179            Opcode::Icmp | Opcode::Ucmp | Opcode::Fcmp => {
180                self.execute_compare(wave, inst);
181                stats.record_instruction(InstructionCategory::Integer);
182                Ok(ExecuteResult::Continue)
183            }
184            Opcode::Select => {
185                self.execute_select(wave, inst);
186                stats.record_instruction(InstructionCategory::Integer);
187                Ok(ExecuteResult::Continue)
188            }
189            Opcode::Cvt => {
190                self.execute_convert(wave, inst);
191                stats.record_instruction(InstructionCategory::Float);
192                Ok(ExecuteResult::Continue)
193            }
194            Opcode::LocalLoad => {
195                self.execute_local_load(wave, inst, local_memory, stats)?;
196                Ok(ExecuteResult::Continue)
197            }
198            Opcode::LocalStore => {
199                self.execute_local_store(wave, inst, local_memory, stats)?;
200                Ok(ExecuteResult::Continue)
201            }
202            Opcode::DeviceLoad => {
203                self.execute_device_load(wave, inst, device_memory, stats)?;
204                Ok(ExecuteResult::Continue)
205            }
206            Opcode::DeviceStore => {
207                self.execute_device_store(wave, inst, device_memory, stats)?;
208                Ok(ExecuteResult::Continue)
209            }
210            Opcode::LocalAtomic => {
211                self.execute_local_atomic(wave, inst, local_memory, stats)?;
212                Ok(ExecuteResult::Continue)
213            }
214            Opcode::DeviceAtomic => {
215                self.execute_device_atomic(wave, inst, device_memory, stats)?;
216                Ok(ExecuteResult::Continue)
217            }
218            Opcode::WaveOp => {
219                self.execute_wave_op(wave, inst);
220                stats.record_instruction(InstructionCategory::WaveOp);
221                Ok(ExecuteResult::Continue)
222            }
223            Opcode::Control => self.execute_control(wave, inst, stats),
224        };
225
226        if inst.is_predicated() && !is_control_sync {
227            wave.active_mask = original_mask;
228        }
229
230        result
231    }
232
233    fn compute_predicate_mask(&self, wave: &Wave, pred_reg: u8, negated: bool) -> u64 {
234        let mut mask: u64 = 0;
235        for lane in 0..wave.wave_width {
236            if wave.is_thread_active(lane) {
237                let pred = wave.threads[lane as usize].read_predicate(pred_reg);
238                let value = if negated { !pred } else { pred };
239                if value {
240                    mask |= 1u64 << lane;
241                }
242            }
243        }
244        mask
245    }
246
247    fn execute_integer_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
248        for lane in 0..wave.wave_width {
249            if !wave.is_thread_active(lane) {
250                continue;
251            }
252
253            let thread = &mut wave.threads[lane as usize];
254            let rs1 = thread.read_register(inst.rs1);
255            let rs2 = thread.read_register(inst.rs2);
256
257            let result = match inst.opcode {
258                Opcode::Iadd => rs1.wrapping_add(rs2),
259                Opcode::Isub => rs1.wrapping_sub(rs2),
260                Opcode::Imul => rs1.wrapping_mul(rs2),
261                Opcode::ImulHi => {
262                    let wide = (rs1 as i64).wrapping_mul(rs2 as i64);
263                    (wide >> 32) as u32
264                }
265                Opcode::Idiv => {
266                    if rs2 == 0 {
267                        0
268                    } else {
269                        (rs1 as i32).wrapping_div(rs2 as i32) as u32
270                    }
271                }
272                Opcode::Imod => {
273                    if rs2 == 0 {
274                        0
275                    } else {
276                        (rs1 as i32).wrapping_rem(rs2 as i32) as u32
277                    }
278                }
279                Opcode::Ineg => (-(rs1 as i32)) as u32,
280                Opcode::Iabs => (rs1 as i32).unsigned_abs(),
281                Opcode::Imin => (rs1 as i32).min(rs2 as i32) as u32,
282                Opcode::Imax => (rs1 as i32).max(rs2 as i32) as u32,
283                _ => 0,
284            };
285
286            thread.write_register(inst.rd, result);
287        }
288    }
289
290    fn execute_integer_extended(&self, wave: &mut Wave, inst: &DecodedInstruction) {
291        for lane in 0..wave.wave_width {
292            if !wave.is_thread_active(lane) {
293                continue;
294            }
295
296            let thread = &mut wave.threads[lane as usize];
297            let rs1 = thread.read_register(inst.rs1);
298            let rs2 = thread.read_register(inst.rs2);
299            let rs3 = thread.read_register(inst.rs3);
300
301            let result = match inst.opcode {
302                Opcode::Imad => rs1.wrapping_mul(rs2).wrapping_add(rs3),
303                Opcode::Iclamp => {
304                    let val = rs1 as i32;
305                    let lo = rs2 as i32;
306                    let hi = rs3 as i32;
307                    val.clamp(lo, hi) as u32
308                }
309                _ => 0,
310            };
311
312            thread.write_register(inst.rd, result);
313        }
314    }
315
316    fn execute_bitwise_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
317        for lane in 0..wave.wave_width {
318            if !wave.is_thread_active(lane) {
319                continue;
320            }
321
322            let thread = &mut wave.threads[lane as usize];
323            let rs1 = thread.read_register(inst.rs1);
324            let rs2 = thread.read_register(inst.rs2);
325
326            let result = match inst.opcode {
327                Opcode::And => rs1 & rs2,
328                Opcode::Or => rs1 | rs2,
329                Opcode::Xor => rs1 ^ rs2,
330                Opcode::Not => !rs1,
331                Opcode::Shl => rs1.wrapping_shl(rs2 & 0x1F),
332                Opcode::Shr => rs1.wrapping_shr(rs2 & 0x1F),
333                Opcode::Sar => ((rs1 as i32).wrapping_shr(rs2 & 0x1F)) as u32,
334                _ => 0,
335            };
336
337            thread.write_register(inst.rd, result);
338        }
339    }
340
341    fn execute_bit_ops(&self, wave: &mut Wave, inst: &DecodedInstruction) {
342        for lane in 0..wave.wave_width {
343            if !wave.is_thread_active(lane) {
344                continue;
345            }
346
347            let thread = &mut wave.threads[lane as usize];
348            let rs1 = thread.read_register(inst.rs1);
349            let rs2 = thread.read_register(inst.rs2);
350            let rs3 = thread.read_register(inst.rs3);
351            let rs4 = thread.read_register(inst.rs4);
352
353            let result = match inst.modifier {
354                m if m == BitOpType::Bitcount as u8 => rs1.count_ones(),
355                m if m == BitOpType::Bitfind as u8 => {
356                    if rs1 == 0 {
357                        u32::MAX
358                    } else {
359                        rs1.leading_zeros()
360                    }
361                }
362                m if m == BitOpType::Bitrev as u8 => rs1.reverse_bits(),
363                m if m == BitOpType::Bfe as u8 => {
364                    let offset = rs2 & 0x1F;
365                    let width = rs3 & 0x1F;
366                    if width == 0 {
367                        0
368                    } else {
369                        (rs1 >> offset) & ((1 << width) - 1)
370                    }
371                }
372                m if m == BitOpType::Bfi as u8 => {
373                    let offset = rs3 & 0x1F;
374                    let width = rs4 & 0x1F;
375                    if width == 0 {
376                        rs1
377                    } else {
378                        let mask = ((1u32 << width) - 1) << offset;
379                        (rs1 & !mask) | ((rs2 << offset) & mask)
380                    }
381                }
382                _ => 0,
383            };
384
385            thread.write_register(inst.rd, result);
386        }
387    }
388
389    fn execute_float_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
390        for lane in 0..wave.wave_width {
391            if !wave.is_thread_active(lane) {
392                continue;
393            }
394
395            let thread = &mut wave.threads[lane as usize];
396            let rs1 = f32::from_bits(thread.read_register(inst.rs1));
397            let rs2 = f32::from_bits(thread.read_register(inst.rs2));
398
399            let result = match inst.opcode {
400                Opcode::Fadd => rs1 + rs2,
401                Opcode::Fsub => rs1 - rs2,
402                Opcode::Fmul => rs1 * rs2,
403                Opcode::Fdiv => {
404                    if rs2 == 0.0 {
405                        f32::INFINITY
406                    } else {
407                        rs1 / rs2
408                    }
409                }
410                Opcode::Fneg => -rs1,
411                Opcode::Fabs => rs1.abs(),
412                Opcode::Fmin => rs1.min(rs2),
413                Opcode::Fmax => rs1.max(rs2),
414                Opcode::Fsqrt => rs1.sqrt(),
415                _ => 0.0,
416            };
417
418            thread.write_register(inst.rd, result.to_bits());
419        }
420    }
421
422    fn execute_float_extended(&self, wave: &mut Wave, inst: &DecodedInstruction) {
423        for lane in 0..wave.wave_width {
424            if !wave.is_thread_active(lane) {
425                continue;
426            }
427
428            let thread = &mut wave.threads[lane as usize];
429            let rs1 = f32::from_bits(thread.read_register(inst.rs1));
430            let rs2 = f32::from_bits(thread.read_register(inst.rs2));
431            let rs3 = f32::from_bits(thread.read_register(inst.rs3));
432
433            let result = match inst.opcode {
434                Opcode::Fma => rs1.mul_add(rs2, rs3),
435                Opcode::Fclamp => rs1.clamp(rs2, rs3),
436                _ => 0.0,
437            };
438
439            thread.write_register(inst.rd, result.to_bits());
440        }
441    }
442
443    fn execute_float_unary(&self, wave: &mut Wave, inst: &DecodedInstruction) {
444        for lane in 0..wave.wave_width {
445            if !wave.is_thread_active(lane) {
446                continue;
447            }
448
449            let thread = &mut wave.threads[lane as usize];
450            let rs1 = f32::from_bits(thread.read_register(inst.rs1));
451
452            let result = match inst.modifier {
453                m if m == FUnaryOp::Frsqrt as u8 => 1.0 / rs1.sqrt(),
454                m if m == FUnaryOp::Frcp as u8 => 1.0 / rs1,
455                m if m == FUnaryOp::Ffloor as u8 => rs1.floor(),
456                m if m == FUnaryOp::Fceil as u8 => rs1.ceil(),
457                m if m == FUnaryOp::Fround as u8 => rs1.round(),
458                m if m == FUnaryOp::Ftrunc as u8 => rs1.trunc(),
459                m if m == FUnaryOp::Ffract as u8 => rs1.fract(),
460                m if m == FUnaryOp::Fsat as u8 => rs1.clamp(0.0, 1.0),
461                m if m == FUnaryOp::Fsin as u8 => rs1.sin(),
462                m if m == FUnaryOp::Fcos as u8 => rs1.cos(),
463                m if m == FUnaryOp::Fexp2 as u8 => rs1.exp2(),
464                m if m == FUnaryOp::Flog2 as u8 => rs1.log2(),
465                _ => 0.0,
466            };
467
468            thread.write_register(inst.rd, result.to_bits());
469        }
470    }
471
472    fn execute_f16_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
473        for lane in 0..wave.wave_width {
474            if !wave.is_thread_active(lane) {
475                continue;
476            }
477
478            let thread = &mut wave.threads[lane as usize];
479            let rs1_bits = thread.read_register(inst.rs1);
480            let rs2_bits = thread.read_register(inst.rs2);
481            let rs3_bits = thread.read_register(inst.rs3);
482
483            let result = if inst.opcode == Opcode::F16Ops {
484                let a = f16::from_bits(rs1_bits as u16);
485                let b = f16::from_bits(rs2_bits as u16);
486                let c = f16::from_bits(rs3_bits as u16);
487
488                let r = match inst.modifier {
489                    m if m == F16Op::Hadd as u8 => f16::from_f32(a.to_f32() + b.to_f32()),
490                    m if m == F16Op::Hsub as u8 => f16::from_f32(a.to_f32() - b.to_f32()),
491                    m if m == F16Op::Hmul as u8 => f16::from_f32(a.to_f32() * b.to_f32()),
492                    m if m == F16Op::Hma as u8 => {
493                        f16::from_f32(a.to_f32().mul_add(b.to_f32(), c.to_f32()))
494                    }
495                    _ => f16::ZERO,
496                };
497                u32::from(r.to_bits())
498            } else {
499                let a_lo = f16::from_bits(rs1_bits as u16);
500                let a_hi = f16::from_bits((rs1_bits >> 16) as u16);
501                let b_lo = f16::from_bits(rs2_bits as u16);
502                let b_hi = f16::from_bits((rs2_bits >> 16) as u16);
503                let c_lo = f16::from_bits(rs3_bits as u16);
504                let c_hi = f16::from_bits((rs3_bits >> 16) as u16);
505
506                let (r_lo, r_hi) = match inst.modifier {
507                    m if m == F16PackedOp::Hadd2 as u8 => (
508                        f16::from_f32(a_lo.to_f32() + b_lo.to_f32()),
509                        f16::from_f32(a_hi.to_f32() + b_hi.to_f32()),
510                    ),
511                    m if m == F16PackedOp::Hmul2 as u8 => (
512                        f16::from_f32(a_lo.to_f32() * b_lo.to_f32()),
513                        f16::from_f32(a_hi.to_f32() * b_hi.to_f32()),
514                    ),
515                    m if m == F16PackedOp::Hma2 as u8 => (
516                        f16::from_f32(a_lo.to_f32().mul_add(b_lo.to_f32(), c_lo.to_f32())),
517                        f16::from_f32(a_hi.to_f32().mul_add(b_hi.to_f32(), c_hi.to_f32())),
518                    ),
519                    _ => (f16::ZERO, f16::ZERO),
520                };
521                u32::from(r_lo.to_bits()) | (u32::from(r_hi.to_bits()) << 16)
522            };
523
524            thread.write_register(inst.rd, result);
525        }
526    }
527
528    fn execute_f64_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
529        for lane in 0..wave.wave_width {
530            if !wave.is_thread_active(lane) {
531                continue;
532            }
533
534            let thread = &mut wave.threads[lane as usize];
535
536            let rs1_lo = thread.read_register(inst.rs1);
537            let rs1_hi = thread.read_register(inst.rs1 + 1);
538            let a = f64::from_bits((u64::from(rs1_hi) << 32) | u64::from(rs1_lo));
539
540            let rs2_lo = thread.read_register(inst.rs2);
541            let rs2_hi = thread.read_register(inst.rs2 + 1);
542            let b = f64::from_bits((u64::from(rs2_hi) << 32) | u64::from(rs2_lo));
543
544            let result = if inst.opcode == Opcode::F64Ops {
545                let rs3_lo = thread.read_register(inst.rs3);
546                let rs3_hi = thread.read_register(inst.rs3 + 1);
547                let c = f64::from_bits((u64::from(rs3_hi) << 32) | u64::from(rs3_lo));
548
549                match inst.modifier {
550                    m if m == F64Op::Dadd as u8 => a + b,
551                    m if m == F64Op::Dsub as u8 => a - b,
552                    m if m == F64Op::Dmul as u8 => a * b,
553                    m if m == F64Op::Dma as u8 => a.mul_add(b, c),
554                    _ => 0.0,
555                }
556            } else {
557                match inst.modifier {
558                    m if m == F64DivSqrtOp::Ddiv as u8 => a / b,
559                    m if m == F64DivSqrtOp::Dsqrt as u8 => a.sqrt(),
560                    _ => 0.0,
561                }
562            };
563
564            let bits = result.to_bits();
565            thread.write_register(inst.rd, bits as u32);
566            thread.write_register(inst.rd + 1, (bits >> 32) as u32);
567        }
568    }
569
570    fn execute_compare(&self, wave: &mut Wave, inst: &DecodedInstruction) {
571        for lane in 0..wave.wave_width {
572            if !wave.is_thread_active(lane) {
573                continue;
574            }
575
576            let thread = &mut wave.threads[lane as usize];
577            let rs1 = thread.read_register(inst.rs1);
578            let rs2 = thread.read_register(inst.rs2);
579
580            let result = match inst.opcode {
581                Opcode::Icmp => {
582                    let a = rs1 as i32;
583                    let b = rs2 as i32;
584                    match inst.modifier {
585                        m if m == CmpOp::Eq as u8 => a == b,
586                        m if m == CmpOp::Ne as u8 => a != b,
587                        m if m == CmpOp::Lt as u8 => a < b,
588                        m if m == CmpOp::Le as u8 => a <= b,
589                        m if m == CmpOp::Gt as u8 => a > b,
590                        m if m == CmpOp::Ge as u8 => a >= b,
591                        _ => false,
592                    }
593                }
594                Opcode::Ucmp => match inst.modifier {
595                    m if m == CmpOp::Lt as u8 => rs1 < rs2,
596                    m if m == CmpOp::Le as u8 => rs1 <= rs2,
597                    _ => false,
598                },
599                Opcode::Fcmp => {
600                    let a = f32::from_bits(rs1);
601                    let b = f32::from_bits(rs2);
602                    match inst.modifier {
603                        m if m == CmpOp::Eq as u8 => a == b,
604                        m if m == CmpOp::Ne as u8 => a != b,
605                        m if m == CmpOp::Lt as u8 => a < b,
606                        m if m == CmpOp::Le as u8 => a <= b,
607                        m if m == CmpOp::Gt as u8 => a > b,
608                        m if m == CmpOp::Ord as u8 => !a.is_nan() && !b.is_nan(),
609                        m if m == CmpOp::Unord as u8 => a.is_nan() || b.is_nan(),
610                        _ => false,
611                    }
612                }
613                _ => false,
614            };
615
616            thread.write_predicate(inst.rd, result);
617        }
618    }
619
620    fn execute_select(&self, wave: &mut Wave, inst: &DecodedInstruction) {
621        for lane in 0..wave.wave_width {
622            if !wave.is_thread_active(lane) {
623                continue;
624            }
625
626            let thread = &mut wave.threads[lane as usize];
627            let pred = thread.read_predicate(inst.modifier);
628            let rs1 = thread.read_register(inst.rs1);
629            let rs2 = thread.read_register(inst.rs2);
630
631            let result = if pred { rs1 } else { rs2 };
632            thread.write_register(inst.rd, result);
633        }
634    }
635
636    fn execute_convert(&self, wave: &mut Wave, inst: &DecodedInstruction) {
637        for lane in 0..wave.wave_width {
638            if !wave.is_thread_active(lane) {
639                continue;
640            }
641
642            let thread = &mut wave.threads[lane as usize];
643            let rs1 = thread.read_register(inst.rs1);
644
645            let result = match inst.modifier {
646                m if m == CvtType::F32I32 as u8 => f32::from_bits(rs1) as i32 as u32, // f32 → i32
647                m if m == CvtType::F32U32 as u8 => f32::from_bits(rs1) as u32,        // f32 → u32
648                m if m == CvtType::I32F32 as u8 => ((rs1 as i32) as f32).to_bits(),   // i32 → f32
649                m if m == CvtType::U32F32 as u8 => (rs1 as f32).to_bits(),            // u32 → f32
650                m if m == CvtType::F32F16 as u8 => f16::from_bits(rs1 as u16).to_f32().to_bits(),
651                m if m == CvtType::F16F32 as u8 => {
652                    u32::from(f16::from_f32(f32::from_bits(rs1)).to_bits())
653                }
654                m if m == CvtType::F32F64 as u8 => {
655                    let rs1_hi = thread.read_register(inst.rs1 + 1);
656                    let d = f64::from_bits((u64::from(rs1_hi) << 32) | u64::from(rs1));
657                    (d as f32).to_bits()
658                }
659                m if m == CvtType::F64F32 as u8 => {
660                    let f = f32::from_bits(rs1);
661                    let d = f64::from(f);
662                    let bits = d.to_bits();
663                    thread.write_register(inst.rd + 1, (bits >> 32) as u32);
664                    bits as u32
665                }
666                _ => 0,
667            };
668
669            thread.write_register(inst.rd, result);
670        }
671    }
672
673    fn execute_local_load(
674        &self,
675        wave: &mut Wave,
676        inst: &DecodedInstruction,
677        local_memory: &mut LocalMemory,
678        stats: &mut ExecutionStats,
679    ) -> Result<(), EmulatorError> {
680        let width = match inst.modifier {
681            m if m == MemWidth::U8 as u8 => 1,
682            m if m == MemWidth::U16 as u8 => 2,
683            m if m == MemWidth::U32 as u8 => 4,
684            m if m == MemWidth::U64 as u8 => 8,
685            _ => 4,
686        };
687
688        for lane in 0..wave.wave_width {
689            if !wave.is_thread_active(lane) {
690                continue;
691            }
692
693            let thread = &mut wave.threads[lane as usize];
694            let addr = thread.read_register(inst.rs1);
695
696            let value = match width {
697                1 => u32::from(local_memory.read_u8(addr)?),
698                2 => u32::from(local_memory.read_u16(addr)?),
699                4 => local_memory.read_u32(addr)?,
700                8 => {
701                    let val = local_memory.read_u64(addr)?;
702                    thread.write_register(inst.rd + 1, (val >> 32) as u32);
703                    val as u32
704                }
705                _ => 0,
706            };
707
708            thread.write_register(inst.rd, value);
709            stats.record_local_load(width as u64);
710        }
711
712        stats.record_instruction(InstructionCategory::Memory);
713        Ok(())
714    }
715
716    fn execute_local_store(
717        &self,
718        wave: &mut Wave,
719        inst: &DecodedInstruction,
720        local_memory: &mut LocalMemory,
721        stats: &mut ExecutionStats,
722    ) -> Result<(), EmulatorError> {
723        let width = match inst.modifier {
724            m if m == MemWidth::U8 as u8 => 1,
725            m if m == MemWidth::U16 as u8 => 2,
726            m if m == MemWidth::U32 as u8 => 4,
727            m if m == MemWidth::U64 as u8 => 8,
728            _ => 4,
729        };
730
731        for lane in 0..wave.wave_width {
732            if !wave.is_thread_active(lane) {
733                continue;
734            }
735
736            let thread = &wave.threads[lane as usize];
737            let addr = thread.read_register(inst.rs1);
738            let value = thread.read_register(inst.rs2);
739
740            match width {
741                1 => local_memory.write_u8(addr, value as u8)?,
742                2 => local_memory.write_u16(addr, value as u16)?,
743                4 => local_memory.write_u32(addr, value)?,
744                8 => {
745                    let hi = thread.read_register(inst.rs2 + 1);
746                    let val = (u64::from(hi) << 32) | u64::from(value);
747                    local_memory.write_u64(addr, val)?;
748                }
749                _ => {}
750            }
751
752            stats.record_local_store(width as u64);
753        }
754
755        stats.record_instruction(InstructionCategory::Memory);
756        Ok(())
757    }
758
759    fn execute_device_load(
760        &self,
761        wave: &mut Wave,
762        inst: &DecodedInstruction,
763        device_memory: &mut DeviceMemory,
764        stats: &mut ExecutionStats,
765    ) -> Result<(), EmulatorError> {
766        let width = match inst.modifier {
767            m if m == MemWidth::U8 as u8 => 1,
768            m if m == MemWidth::U16 as u8 => 2,
769            m if m == MemWidth::U32 as u8 => 4,
770            m if m == MemWidth::U64 as u8 => 8,
771            m if m == MemWidth::U128 as u8 => 16,
772            _ => 4,
773        };
774
775        for lane in 0..wave.wave_width {
776            if !wave.is_thread_active(lane) {
777                continue;
778            }
779
780            let thread = &mut wave.threads[lane as usize];
781            let addr = u64::from(thread.read_register(inst.rs1));
782
783            match width {
784                1 => {
785                    let val = device_memory.read_u8(addr)?;
786                    thread.write_register(inst.rd, u32::from(val));
787                }
788                2 => {
789                    let val = device_memory.read_u16(addr)?;
790                    thread.write_register(inst.rd, u32::from(val));
791                }
792                4 => {
793                    let val = device_memory.read_u32(addr)?;
794                    thread.write_register(inst.rd, val);
795                }
796                8 => {
797                    let val = device_memory.read_u64(addr)?;
798                    thread.write_register(inst.rd, val as u32);
799                    thread.write_register(inst.rd + 1, (val >> 32) as u32);
800                }
801                16 => {
802                    let val = device_memory.read_u128(addr)?;
803                    thread.write_register(inst.rd, val as u32);
804                    thread.write_register(inst.rd + 1, (val >> 32) as u32);
805                    thread.write_register(inst.rd + 2, (val >> 64) as u32);
806                    thread.write_register(inst.rd + 3, (val >> 96) as u32);
807                }
808                _ => {}
809            }
810
811            stats.record_device_load(width as u64);
812        }
813
814        stats.record_instruction(InstructionCategory::Memory);
815        Ok(())
816    }
817
818    fn execute_device_store(
819        &self,
820        wave: &mut Wave,
821        inst: &DecodedInstruction,
822        device_memory: &mut DeviceMemory,
823        stats: &mut ExecutionStats,
824    ) -> Result<(), EmulatorError> {
825        let width = match inst.modifier {
826            m if m == MemWidth::U8 as u8 => 1,
827            m if m == MemWidth::U16 as u8 => 2,
828            m if m == MemWidth::U32 as u8 => 4,
829            m if m == MemWidth::U64 as u8 => 8,
830            m if m == MemWidth::U128 as u8 => 16,
831            _ => 4,
832        };
833
834        for lane in 0..wave.wave_width {
835            if !wave.is_thread_active(lane) {
836                continue;
837            }
838
839            let thread = &wave.threads[lane as usize];
840            let addr = u64::from(thread.read_register(inst.rs1));
841            let value = thread.read_register(inst.rs2);
842
843            match width {
844                1 => device_memory.write_u8(addr, value as u8)?,
845                2 => device_memory.write_u16(addr, value as u16)?,
846                4 => device_memory.write_u32(addr, value)?,
847                8 => {
848                    let hi = thread.read_register(inst.rs2 + 1);
849                    let val = (u64::from(hi) << 32) | u64::from(value);
850                    device_memory.write_u64(addr, val)?;
851                }
852                16 => {
853                    let w0 = value;
854                    let w1 = thread.read_register(inst.rs2 + 1);
855                    let w2 = thread.read_register(inst.rs2 + 2);
856                    let w3 = thread.read_register(inst.rs2 + 3);
857                    let val = u128::from(w0)
858                        | (u128::from(w1) << 32)
859                        | (u128::from(w2) << 64)
860                        | (u128::from(w3) << 96);
861                    device_memory.write_u128(addr, val)?;
862                }
863                _ => {}
864            }
865
866            stats.record_device_store(width as u64);
867        }
868
869        stats.record_instruction(InstructionCategory::Memory);
870        Ok(())
871    }
872
873    fn execute_local_atomic(
874        &self,
875        wave: &mut Wave,
876        inst: &DecodedInstruction,
877        local_memory: &mut LocalMemory,
878        stats: &mut ExecutionStats,
879    ) -> Result<(), EmulatorError> {
880        let non_returning = inst.is_non_returning_atomic();
881
882        for lane in 0..wave.wave_width {
883            if !wave.is_thread_active(lane) {
884                continue;
885            }
886
887            let thread = &mut wave.threads[lane as usize];
888            let addr = thread.read_register(inst.rs1);
889            let value = thread.read_register(inst.rs2);
890
891            let old = match inst.modifier {
892                m if m == AtomicOp::Add as u8 => local_memory.atomic_add(addr, value)?,
893                m if m == AtomicOp::Sub as u8 => local_memory.atomic_sub(addr, value)?,
894                m if m == AtomicOp::Min as u8 => local_memory.atomic_min(addr, value)?,
895                m if m == AtomicOp::Max as u8 => local_memory.atomic_max(addr, value)?,
896                m if m == AtomicOp::And as u8 => local_memory.atomic_and(addr, value)?,
897                m if m == AtomicOp::Or as u8 => local_memory.atomic_or(addr, value)?,
898                m if m == AtomicOp::Xor as u8 => local_memory.atomic_xor(addr, value)?,
899                m if m == AtomicOp::Exchange as u8 => local_memory.atomic_exchange(addr, value)?,
900                _ => {
901                    let expected = thread.read_register(inst.rs3);
902                    local_memory.atomic_cas(addr, expected, value)?
903                }
904            };
905
906            if !non_returning {
907                thread.write_register(inst.rd, old);
908            }
909
910            stats.atomic_ops += 1;
911        }
912
913        stats.record_instruction(InstructionCategory::Atomic);
914        Ok(())
915    }
916
917    fn execute_device_atomic(
918        &self,
919        wave: &mut Wave,
920        inst: &DecodedInstruction,
921        device_memory: &mut DeviceMemory,
922        stats: &mut ExecutionStats,
923    ) -> Result<(), EmulatorError> {
924        let non_returning = inst.is_non_returning_atomic();
925
926        for lane in 0..wave.wave_width {
927            if !wave.is_thread_active(lane) {
928                continue;
929            }
930
931            let thread = &mut wave.threads[lane as usize];
932            let addr = u64::from(thread.read_register(inst.rs1));
933            let value = thread.read_register(inst.rs2);
934
935            let old = match inst.modifier {
936                m if m == AtomicOp::Add as u8 => device_memory.atomic_add(addr, value)?,
937                m if m == AtomicOp::Sub as u8 => device_memory.atomic_sub(addr, value)?,
938                m if m == AtomicOp::Min as u8 => device_memory.atomic_min(addr, value)?,
939                m if m == AtomicOp::Max as u8 => device_memory.atomic_max(addr, value)?,
940                m if m == AtomicOp::And as u8 => device_memory.atomic_and(addr, value)?,
941                m if m == AtomicOp::Or as u8 => device_memory.atomic_or(addr, value)?,
942                m if m == AtomicOp::Xor as u8 => device_memory.atomic_xor(addr, value)?,
943                m if m == AtomicOp::Exchange as u8 => device_memory.atomic_exchange(addr, value)?,
944                _ => {
945                    let expected = thread.read_register(inst.rs3);
946                    device_memory.atomic_cas(addr, expected, value)?
947                }
948            };
949
950            if !non_returning {
951                thread.write_register(inst.rd, old);
952            }
953
954            stats.atomic_ops += 1;
955        }
956
957        stats.record_instruction(InstructionCategory::Atomic);
958        Ok(())
959    }
960
961    fn execute_wave_op(&self, wave: &mut Wave, inst: &DecodedInstruction) {
962        if inst.is_wave_reduce() {
963            let reduce_mod = inst.modifier - 8;
964            match reduce_mod {
965                m if m == WaveReduceType::PrefixSum as u8 => {
966                    shuffle::wave_prefix_sum(wave, inst.rd, inst.rs1);
967                }
968                m if m == WaveReduceType::ReduceAdd as u8 => {
969                    shuffle::wave_reduce_add(wave, inst.rd, inst.rs1);
970                }
971                m if m == WaveReduceType::ReduceMin as u8 => {
972                    shuffle::wave_reduce_min(wave, inst.rd, inst.rs1);
973                }
974                m if m == WaveReduceType::ReduceMax as u8 => {
975                    shuffle::wave_reduce_max(wave, inst.rd, inst.rs1);
976                }
977                _ => {}
978            }
979        } else {
980            match inst.modifier {
981                m if m == WaveOpType::Shuffle as u8 => {
982                    shuffle::wave_shuffle(wave, inst.rd, inst.rs1, inst.rs2);
983                }
984                m if m == WaveOpType::ShuffleUp as u8 => {
985                    shuffle::wave_shuffle_up(wave, inst.rd, inst.rs1, inst.rs2);
986                }
987                m if m == WaveOpType::ShuffleDown as u8 => {
988                    shuffle::wave_shuffle_down(wave, inst.rd, inst.rs1, inst.rs2);
989                }
990                m if m == WaveOpType::ShuffleXor as u8 => {
991                    shuffle::wave_shuffle_xor(wave, inst.rd, inst.rs1, inst.rs2);
992                }
993                m if m == WaveOpType::Broadcast as u8 => {
994                    shuffle::wave_broadcast(wave, inst.rd, inst.rs1, inst.rs2);
995                }
996                m if m == WaveOpType::Ballot as u8 => {
997                    shuffle::wave_ballot(wave, inst.rd, inst.rs1);
998                }
999                m if m == WaveOpType::Any as u8 => {
1000                    shuffle::wave_any(wave, inst.rd, inst.rs1);
1001                }
1002                m if m == WaveOpType::All as u8 => {
1003                    shuffle::wave_all(wave, inst.rd, inst.rs1);
1004                }
1005                _ => {}
1006            }
1007        }
1008    }
1009
1010    fn execute_control(
1011        &mut self,
1012        wave: &mut Wave,
1013        inst: &DecodedInstruction,
1014        stats: &mut ExecutionStats,
1015    ) -> Result<ExecuteResult, EmulatorError> {
1016        stats.record_instruction(InstructionCategory::Control);
1017
1018        if inst.is_sync_op() {
1019            return self.execute_sync_op(wave, inst);
1020        }
1021
1022        if inst.is_misc_op() {
1023            return self.execute_misc_op(wave, inst);
1024        }
1025
1026        self.execute_control_flow(wave, inst, stats)
1027    }
1028
1029    fn execute_sync_op(
1030        &self,
1031        wave: &mut Wave,
1032        inst: &DecodedInstruction,
1033    ) -> Result<ExecuteResult, EmulatorError> {
1034        match inst.modifier {
1035            m if m == SyncOp::Return as u8 => {
1036                if let Some(return_pc) = wave.pop_call() {
1037                    Ok(ExecuteResult::Jump(return_pc))
1038                } else {
1039                    Ok(ExecuteResult::Halt)
1040                }
1041            }
1042            m if m == SyncOp::Halt as u8 => Ok(ExecuteResult::Halt),
1043            m if m == SyncOp::Barrier as u8 => Ok(ExecuteResult::Barrier),
1044            m if m == SyncOp::Nop as u8 || m == SyncOp::Wait as u8 => Ok(ExecuteResult::Continue),
1045            _ => Ok(ExecuteResult::Continue),
1046        }
1047    }
1048
1049    fn execute_misc_op(
1050        &self,
1051        wave: &mut Wave,
1052        inst: &DecodedInstruction,
1053    ) -> Result<ExecuteResult, EmulatorError> {
1054        match inst.modifier {
1055            m if m == MiscOp::Mov as u8 => {
1056                for lane in 0..wave.wave_width {
1057                    if wave.is_thread_active(lane) {
1058                        let thread = &mut wave.threads[lane as usize];
1059                        let value = thread.read_register(inst.rs1);
1060                        thread.write_register(inst.rd, value);
1061                    }
1062                }
1063            }
1064            m if m == MiscOp::MovImm as u8 => {
1065                for lane in 0..wave.wave_width {
1066                    if wave.is_thread_active(lane) {
1067                        wave.threads[lane as usize].write_register(inst.rd, inst.immediate);
1068                    }
1069                }
1070            }
1071            m if m == MiscOp::MovSr as u8 => {
1072                for lane in 0..wave.wave_width {
1073                    if wave.is_thread_active(lane) {
1074                        let thread = &mut wave.threads[lane as usize];
1075                        let value = thread.read_special(inst.rs1);
1076                        thread.write_register(inst.rd, value);
1077                    }
1078                }
1079            }
1080            _ => {}
1081        }
1082        Ok(ExecuteResult::Continue)
1083    }
1084
1085    fn execute_control_flow(
1086        &mut self,
1087        wave: &mut Wave,
1088        inst: &DecodedInstruction,
1089        stats: &mut ExecutionStats,
1090    ) -> Result<ExecuteResult, EmulatorError> {
1091        match inst.modifier {
1092            m if m == ControlOp::If as u8 => {
1093                let mut pred_mask: u64 = 0;
1094                for lane in 0..wave.wave_width {
1095                    if wave.is_thread_active(lane) {
1096                        if wave.threads[lane as usize].read_predicate(inst.rs1) {
1097                            pred_mask |= 1u64 << lane;
1098                        }
1099                    }
1100                }
1101
1102                let then_mask = wave.active_mask & pred_mask;
1103                let else_mask = wave.active_mask & !pred_mask;
1104
1105                if then_mask != wave.active_mask && else_mask != 0 {
1106                    stats.record_divergent_branch();
1107                }
1108
1109                let (new_mask, _) = wave.control_flow.handle_if(wave.active_mask, pred_mask)?;
1110                wave.active_mask = new_mask;
1111                Ok(ExecuteResult::Continue)
1112            }
1113            m if m == ControlOp::Else as u8 => {
1114                let (new_mask, _) = wave.control_flow.handle_else(wave.active_mask)?;
1115                wave.active_mask = new_mask;
1116                Ok(ExecuteResult::Continue)
1117            }
1118            m if m == ControlOp::Endif as u8 => {
1119                let new_mask = wave.control_flow.handle_endif()?;
1120                wave.active_mask = new_mask;
1121                Ok(ExecuteResult::Continue)
1122            }
1123            m if m == ControlOp::Loop as u8 => {
1124                let body_start = wave.pc + inst.size;
1125                let new_mask = wave
1126                    .control_flow
1127                    .handle_loop(wave.active_mask, body_start)?;
1128                wave.active_mask = new_mask;
1129                Ok(ExecuteResult::Continue)
1130            }
1131            m if m == ControlOp::Break as u8 => {
1132                let mut pred_mask: u64 = 0;
1133                for lane in 0..wave.wave_width {
1134                    if wave.is_thread_active(lane) {
1135                        if wave.threads[lane as usize].read_predicate(inst.rs1) {
1136                            pred_mask |= 1u64 << lane;
1137                        }
1138                    }
1139                }
1140
1141                if self.trace.is_enabled() {
1142                    eprintln!(
1143                        "  BREAK: active_mask=0x{:x}, pred_mask=0x{:x}, pred_reg=p{}",
1144                        wave.active_mask, pred_mask, inst.rs1
1145                    );
1146                }
1147
1148                let (new_mask, jump) = wave
1149                    .control_flow
1150                    .handle_break(wave.active_mask, pred_mask)?;
1151                wave.active_mask = new_mask;
1152
1153                if self.trace.is_enabled() {
1154                    eprintln!("  BREAK: new_active_mask=0x{new_mask:x}, jump={jump:?}");
1155                }
1156
1157                Ok(ExecuteResult::Continue)
1158            }
1159            m if m == ControlOp::Continue as u8 => {
1160                let mut pred_mask: u64 = 0;
1161                for lane in 0..wave.wave_width {
1162                    if wave.is_thread_active(lane) {
1163                        if wave.threads[lane as usize].read_predicate(inst.rs1) {
1164                            pred_mask |= 1u64 << lane;
1165                        }
1166                    }
1167                }
1168
1169                let (new_mask, jump) = wave
1170                    .control_flow
1171                    .handle_continue(wave.active_mask, pred_mask)?;
1172                wave.active_mask = new_mask;
1173                if let Some(target) = jump {
1174                    Ok(ExecuteResult::Jump(target))
1175                } else {
1176                    Ok(ExecuteResult::Continue)
1177                }
1178            }
1179            m if m == ControlOp::Endloop as u8 => {
1180                if self.trace.is_enabled() {
1181                    eprintln!("  ENDLOOP: active_mask=0x{:x}", wave.active_mask);
1182                }
1183
1184                let (new_mask, jump) = wave.control_flow.handle_endloop(wave.active_mask)?;
1185                wave.active_mask = new_mask;
1186
1187                if self.trace.is_enabled() {
1188                    eprintln!("  ENDLOOP: new_active_mask=0x{new_mask:x}, jump={jump:?}");
1189                }
1190
1191                if let Some(target) = jump {
1192                    Ok(ExecuteResult::Jump(target))
1193                } else {
1194                    Ok(ExecuteResult::Continue)
1195                }
1196            }
1197            m if m == ControlOp::Call as u8 => {
1198                let return_pc = wave.pc + inst.size;
1199                wave.push_call(return_pc)
1200                    .map_err(|_| EmulatorError::StackOverflow {
1201                        kind: "call".into(),
1202                    })?;
1203                Ok(ExecuteResult::Jump(inst.immediate))
1204            }
1205            _ => Ok(ExecuteResult::Continue),
1206        }
1207    }
1208}
1209
/// Outcome of a single `Executor::step` call for one wave.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepResult {
    /// The wave was stepped and can be stepped again. How this is derived from
    /// the instruction result is decided in the (not shown) remainder of `step`.
    Continue,
    /// The wave is halted. `step` returns this when the wave was already halted,
    /// or when no lane is active and the control-flow stack is empty (in which
    /// case the wave is halted first).
    Halted,
    /// Presumably reported when the wave reached a barrier
    /// (`ExecuteResult::Barrier`). TODO confirm the mapping in `step`.
    Barrier,
}
1216
/// Result of executing one instruction, as returned by the per-category
/// handlers to `step`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExecuteResult {
    /// Proceed to the next instruction. Presumably `step` advances `pc` by the
    /// instruction size; confirm in `step`.
    Continue,
    /// Transfer control to the given program counter. Produced by `Call`
    /// (`inst.immediate`), `Return` (the popped return address), and by
    /// `Continue`/`Endloop` when the control-flow stack returns a target.
    Jump(u32),
    /// Stop the wave. Produced by `SyncOp::Halt`, and by `SyncOp::Return` when
    /// the call stack is empty.
    Halt,
    /// Produced by `SyncOp::Barrier`.
    Barrier,
}
1224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::MISC_OP_FLAG;

    /// Packs the fields shared by every encoding into the leading 32-bit word.
    fn encode_word0(opcode: u8, rd: u8, rs1: u8, rs2: u8, modifier: u8, flags: u8) -> u32 {
        let field = |value: u8, mask: u32, shift: u32| (u32::from(value) & mask) << shift;
        field(opcode, 0x3F, 26)
            | field(rd, 0x1F, 21)
            | field(rs1, 0x1F, 16)
            | field(rs2, 0x1F, 11)
            | field(modifier, 0x0F, 7)
            | field(flags, 0x03, 0)
    }

    /// Encodes a single-word instruction as little-endian bytes.
    fn encode_base(opcode: u8, rd: u8, rs1: u8, rs2: u8, modifier: u8, flags: u8) -> Vec<u8> {
        encode_word0(opcode, rd, rs1, rs2, modifier, flags)
            .to_le_bytes()
            .to_vec()
    }

    /// Encodes a two-word instruction: the base word followed by a 32-bit immediate.
    fn encode_extended(
        opcode: u8,
        rd: u8,
        rs1: u8,
        rs2: u8,
        modifier: u8,
        flags: u8,
        imm: u32,
    ) -> Vec<u8> {
        let mut bytes = encode_base(opcode, rd, rs1, rs2, modifier, flags);
        bytes.extend_from_slice(&imm.to_le_bytes());
        bytes
    }

    /// Builds the four-lane wave shared by every test.
    fn four_lane_wave() -> Wave {
        Wave::new(4, 32, 0, [0, 0, 0], [4, 1, 1], [1, 1, 1], 0, 4, 1)
    }

    /// Runs exactly one `Executor::step` over `code` with fresh memories and stats.
    fn step_once(code: &[u8], wave: &mut Wave) {
        let mut executor = Executor::new(code, false, [0, 0, 0]);
        let mut local_memory = LocalMemory::new(1024);
        let mut device_memory = DeviceMemory::new(1024);
        let mut stats = ExecutionStats::new();

        executor
            .step(wave, &mut local_memory, &mut device_memory, &mut stats)
            .unwrap();
    }

    #[test]
    fn test_executor_iadd() {
        let code = encode_base(0x00, 3, 1, 2, 0, 0);
        let mut wave = four_lane_wave();

        for lane in 0..4 {
            wave.threads[lane].write_register(1, 10);
            wave.threads[lane].write_register(2, 20);
        }

        step_once(&code, &mut wave);

        for lane in 0..4 {
            assert_eq!(wave.threads[lane].read_register(3), 30);
        }
    }

    #[test]
    fn test_executor_mov_imm() {
        let code = encode_extended(0x3F, 5, 0, 0, 1, MISC_OP_FLAG as u8, 0xDEADBEEF);
        let mut wave = four_lane_wave();

        step_once(&code, &mut wave);

        for lane in 0..4 {
            assert_eq!(wave.threads[lane].read_register(5), 0xDEADBEEF);
        }
    }

    #[test]
    fn test_executor_respects_active_mask() {
        let code = encode_base(0x00, 3, 1, 2, 0, 0);
        let mut wave = four_lane_wave();

        // Only lanes 0 and 2 are active.
        wave.active_mask = 0b0101;

        for lane in 0..4 {
            wave.threads[lane].write_register(1, 10);
            wave.threads[lane].write_register(2, 20);
            wave.threads[lane].write_register(3, 0);
        }

        step_once(&code, &mut wave);

        let expected: [u32; 4] = [30, 0, 30, 0];
        for (lane, want) in expected.iter().enumerate() {
            assert_eq!(wave.threads[lane].read_register(3), *want);
        }
    }
}