Skip to main content

amx_sys/instructions/
modes.rs

//! AMX instruction parameter modes and encoding.
//!
//! This module defines the mode bits and parameters used by AMX instructions.
4
/// Write mode: selects which lanes an instruction writes.
///
/// Each mode is paired with a numeric parameter, referred to below as `val`
/// or `N`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteMode {
    /// Mode 0: `val` picks the lane set.
    ///
    /// * `val = 0` - all lanes
    /// * `val = 1` - odd lanes only
    /// * `val = 2` - even lanes only
    /// * `val = 3` - no lanes
    Mode0 = 0,
    /// Mode 1: write only lane #N (for vector ops: broadcast Y lane #N).
    Mode1 = 1,
    /// Mode 2: write only the first N lanes; `N = 0` writes all lanes.
    Mode2 = 2,
    /// Mode 3: write only the last N lanes; `N = 0` writes all lanes.
    Mode3 = 3,
    /// Mode 4: write the first N lanes; `N = 0` writes no lanes.
    Mode4 = 4,
}
20
/// Shuffle mode for vector operations.
///
/// The discriminant is the 2-bit selector value of the mode.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShuffleMode {
    /// Identity (no shuffle).
    S0 = 0,
    /// Interleave: move lane 1 to lane 2.
    S1 = 1,
    /// Move lane 1 to lane 4.
    S2 = 2,
    /// Move lane 1 to lane 8.
    S3 = 3,
}
30
/// Element size for operations.
///
/// The discriminant is the base-2 logarithm of the element width in bytes.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementSize {
    /// 8-bit elements.
    B8 = 0,
    /// 16-bit elements.
    B16 = 1,
    /// 32-bit elements.
    B32 = 2,
    /// 64-bit elements.
    B64 = 3,
}

impl ElementSize {
    /// Returns the width of one element in bytes (1, 2, 4 or 8).
    pub fn bytes(self) -> usize {
        match self {
            ElementSize::B8 => 1,
            ElementSize::B16 => 2,
            ElementSize::B32 => 4,
            ElementSize::B64 => 8,
        }
    }
}
46
/// Rounding mode applied by floating-point operations.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoundingMode {
    /// Round to nearest; ties go to the even value.
    RNE = 0,
    /// Round towards zero.
    RZ = 1,
    /// Round towards positive infinity.
    RP = 2,
    /// Round towards negative infinity.
    RM = 3,
}
60
/// Data type of the elements an operation works on.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    /// Signed integer elements.
    SignedInt = 0,
    /// Unsigned integer elements.
    UnsignedInt = 1,
    /// Floating-point elements.
    Float = 2,
}
69
70/// Instruction encoding parameters - extracted from the immediate field
71#[derive(Debug, Clone, Copy)]
72pub struct InstructionParams {
73    pub mode: u8,
74    pub value: u8,
75    pub shuffle: ShuffleMode,
76    pub element_size: ElementSize,
77    pub data_type: DataType,
78}
79
80impl InstructionParams {
81    /// Decode instruction parameters from the immediate field
82    pub fn from_immediate(imm: u64) -> Self {
83        let mode = ((imm >> 0) & 0xF) as u8;
84        let value = ((imm >> 4) & 0xFF) as u8;
85        let shuffle = match (imm >> 12) & 0x3 {
86            0 => ShuffleMode::S0,
87            1 => ShuffleMode::S1,
88            2 => ShuffleMode::S2,
89            3 => ShuffleMode::S3,
90            _ => ShuffleMode::S0,
91        };
92        let element_size = match (imm >> 14) & 0x3 {
93            0 => ElementSize::B8,
94            1 => ElementSize::B16,
95            2 => ElementSize::B32,
96            3 => ElementSize::B64,
97            _ => ElementSize::B8,
98        };
99        let data_type = match (imm >> 16) & 0x3 {
100            0 => DataType::SignedInt,
101            1 => DataType::UnsignedInt,
102            2 => DataType::Float,
103            _ => DataType::SignedInt,
104        };
105
106        InstructionParams {
107            mode,
108            value,
109            shuffle,
110            element_size,
111            data_type,
112        }
113    }
114
115    /// Encode to immediate field
116    pub fn to_immediate(self) -> u64 {
117        let mut imm = 0u64;
118        imm |= (self.mode as u64) << 0;
119        imm |= (self.value as u64) << 4;
120        imm |= ((self.shuffle as u8) as u64) << 12;
121        imm |= ((self.element_size as u8) as u64) << 14;
122        imm |= ((self.data_type as u8) as u64) << 16;
123        imm
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Every element size reports its width in bytes.
    #[test]
    fn element_size_bytes() {
        let expected = [
            (ElementSize::B8, 1),
            (ElementSize::B16, 2),
            (ElementSize::B32, 4),
            (ElementSize::B64, 8),
        ];

        for (size, bytes) in expected {
            assert_eq!(size.bytes(), bytes);
        }
    }

    /// Encoding and then decoding returns every field unchanged.
    #[test]
    fn params_roundtrip() {
        let original = InstructionParams {
            mode: 2,
            value: 42,
            shuffle: ShuffleMode::S2,
            element_size: ElementSize::B32,
            data_type: DataType::Float,
        };

        let InstructionParams {
            mode,
            value,
            shuffle,
            element_size,
            data_type,
        } = InstructionParams::from_immediate(original.to_immediate());

        assert_eq!(mode, 2);
        assert_eq!(value, 42);
        assert_eq!(shuffle, ShuffleMode::S2);
        assert_eq!(element_size, ElementSize::B32);
        assert_eq!(data_type, DataType::Float);
    }
}