//! SIMD instruction support
//!
//! Vectorization and SIMD code generation, plus x86-64 calling-convention
//! helpers for emitting function prologues, epilogues and calls.

use anyhow::{bail, Result};
use std::fmt::Write as FmtWrite;

/// SIMD register types
///
/// Covers the 16 SSE (`xmm`) registers, the 16 AVX (`ymm`) registers and the
/// first 8 AVX-512 (`zmm`) registers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdRegister {
    // SSE registers (128-bit)
    XMM0,
    XMM1,
    XMM2,
    XMM3,
    XMM4,
    XMM5,
    XMM6,
    XMM7,
    XMM8,
    XMM9,
    XMM10,
    XMM11,
    XMM12,
    XMM13,
    XMM14,
    XMM15,

    // AVX registers (256-bit)
    YMM0,
    YMM1,
    YMM2,
    YMM3,
    YMM4,
    YMM5,
    YMM6,
    YMM7,
    YMM8,
    YMM9,
    YMM10,
    YMM11,
    YMM12,
    YMM13,
    YMM14,
    YMM15,

    // AVX-512 registers (512-bit)
    ZMM0,
    ZMM1,
    ZMM2,
    ZMM3,
    ZMM4,
    ZMM5,
    ZMM6,
    ZMM7,
}

impl SimdRegister {
    /// Returns the lowercase assembler name of the register, e.g. `"ymm12"`.
    pub fn name(&self) -> &'static str {
        // Exhaustive on purpose: a catch-all arm would silently emit the wrong
        // register for any variant that is missing here.
        match self {
            SimdRegister::XMM0 => "xmm0",
            SimdRegister::XMM1 => "xmm1",
            SimdRegister::XMM2 => "xmm2",
            SimdRegister::XMM3 => "xmm3",
            SimdRegister::XMM4 => "xmm4",
            SimdRegister::XMM5 => "xmm5",
            SimdRegister::XMM6 => "xmm6",
            SimdRegister::XMM7 => "xmm7",
            SimdRegister::XMM8 => "xmm8",
            SimdRegister::XMM9 => "xmm9",
            SimdRegister::XMM10 => "xmm10",
            SimdRegister::XMM11 => "xmm11",
            SimdRegister::XMM12 => "xmm12",
            SimdRegister::XMM13 => "xmm13",
            SimdRegister::XMM14 => "xmm14",
            SimdRegister::XMM15 => "xmm15",

            SimdRegister::YMM0 => "ymm0",
            SimdRegister::YMM1 => "ymm1",
            SimdRegister::YMM2 => "ymm2",
            SimdRegister::YMM3 => "ymm3",
            SimdRegister::YMM4 => "ymm4",
            SimdRegister::YMM5 => "ymm5",
            SimdRegister::YMM6 => "ymm6",
            SimdRegister::YMM7 => "ymm7",
            SimdRegister::YMM8 => "ymm8",
            SimdRegister::YMM9 => "ymm9",
            SimdRegister::YMM10 => "ymm10",
            SimdRegister::YMM11 => "ymm11",
            SimdRegister::YMM12 => "ymm12",
            SimdRegister::YMM13 => "ymm13",
            SimdRegister::YMM14 => "ymm14",
            SimdRegister::YMM15 => "ymm15",

            SimdRegister::ZMM0 => "zmm0",
            SimdRegister::ZMM1 => "zmm1",
            SimdRegister::ZMM2 => "zmm2",
            SimdRegister::ZMM3 => "zmm3",
            SimdRegister::ZMM4 => "zmm4",
            SimdRegister::ZMM5 => "zmm5",
            SimdRegister::ZMM6 => "zmm6",
            SimdRegister::ZMM7 => "zmm7",
        }
    }

    /// Returns the register width in bits: 128 (`xmm`), 256 (`ymm`) or 512 (`zmm`).
    pub fn width(&self) -> usize {
        match self {
            SimdRegister::XMM0
            | SimdRegister::XMM1
            | SimdRegister::XMM2
            | SimdRegister::XMM3
            | SimdRegister::XMM4
            | SimdRegister::XMM5
            | SimdRegister::XMM6
            | SimdRegister::XMM7
            | SimdRegister::XMM8
            | SimdRegister::XMM9
            | SimdRegister::XMM10
            | SimdRegister::XMM11
            | SimdRegister::XMM12
            | SimdRegister::XMM13
            | SimdRegister::XMM14
            | SimdRegister::XMM15 => 128,

            SimdRegister::YMM0
            | SimdRegister::YMM1
            | SimdRegister::YMM2
            | SimdRegister::YMM3
            | SimdRegister::YMM4
            | SimdRegister::YMM5
            | SimdRegister::YMM6
            | SimdRegister::YMM7
            | SimdRegister::YMM8
            | SimdRegister::YMM9
            | SimdRegister::YMM10
            | SimdRegister::YMM11
            | SimdRegister::YMM12
            | SimdRegister::YMM13
            | SimdRegister::YMM14
            | SimdRegister::YMM15 => 256,

            SimdRegister::ZMM0
            | SimdRegister::ZMM1
            | SimdRegister::ZMM2
            | SimdRegister::ZMM3
            | SimdRegister::ZMM4
            | SimdRegister::ZMM5
            | SimdRegister::ZMM6
            | SimdRegister::ZMM7 => 512,
        }
    }
}

/// Packed (vector) operation selector passed to [`SimdCodeGen::emit_simd`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdOp {
    // ---- Arithmetic ----
    /// Lane-wise addition.
    AddPacked,
    /// Lane-wise subtraction.
    SubPacked,
    /// Lane-wise multiplication.
    MulPacked,
    /// Lane-wise division.
    DivPacked,

    // ---- Logical ----
    /// Bitwise AND.
    AndPacked,
    /// Bitwise OR.
    OrPacked,
    /// Bitwise XOR.
    XorPacked,

    // ---- Comparison ----
    /// Lane-wise equality comparison.
    CmpEqPacked,
    /// Lane-wise less-than comparison.
    CmpLtPacked,
    /// Lane-wise greater-than comparison.
    CmpGtPacked,

    // ---- Data movement ----
    /// Vector move.
    MovePacked,
    /// Vector load.
    LoadPacked,
    /// Vector store.
    StorePacked,

    // ---- Lane rearrangement ----
    /// Lane shuffle.
    Shuffle,
    /// Lane permutation.
    Permute,
    /// Broadcast of one element across lanes.
    Broadcast,
}

156/// SIMD code generator
157pub struct SimdCodeGen {
158    output: String,
159    target_features: SimdFeatures,
160}
161
/// Set of x86 SIMD instruction-set extensions that generated code may rely on.
#[derive(Debug, Clone)]
pub struct SimdFeatures {
    /// SSE.
    pub sse: bool,
    /// SSE2.
    pub sse2: bool,
    /// SSE3.
    pub sse3: bool,
    /// SSE4.1.
    pub sse4_1: bool,
    /// AVX.
    pub avx: bool,
    /// AVX2.
    pub avx2: bool,
    /// AVX-512.
    pub avx512: bool,
}

impl SimdFeatures {
    /// Baseline preset: only SSE and SSE2 enabled (the x86-64 baseline).
    pub fn baseline() -> Self {
        let (sse, sse2) = (true, true);
        Self {
            sse,
            sse2,
            sse3: false,
            sse4_1: false,
            avx: false,
            avx2: false,
            avx512: false,
        }
    }

    /// AVX2 preset: everything up to and including AVX2, but not AVX-512.
    pub fn avx2() -> Self {
        // Start from the SSE/SSE2 baseline and switch on the newer extensions.
        Self {
            sse3: true,
            sse4_1: true,
            avx: true,
            avx2: true,
            ..Self::baseline()
        }
    }
}

201impl SimdCodeGen {
202    /// Create a new SIMD code generator
203    pub fn new(features: SimdFeatures) -> Self {
204        Self {
205            output: String::new(),
206            target_features: features,
207        }
208    }
209
210    /// Generate SIMD instruction
211    pub fn emit_simd(
212        &mut self,
213        op: SimdOp,
214        dest: SimdRegister,
215        src1: SimdRegister,
216        src2: Option<SimdRegister>,
217    ) -> Result<()> {
218        let instr = match op {
219            SimdOp::AddPacked => {
220                if self.target_features.avx {
221                    "vaddps"
222                } else {
223                    "addps"
224                }
225            }
226            SimdOp::SubPacked => {
227                if self.target_features.avx {
228                    "vsubps"
229                } else {
230                    "subps"
231                }
232            }
233            SimdOp::MulPacked => {
234                if self.target_features.avx {
235                    "vmulps"
236                } else {
237                    "mulps"
238                }
239            }
240            SimdOp::DivPacked => {
241                if self.target_features.avx {
242                    "vdivps"
243                } else {
244                    "divps"
245                }
246            }
247            SimdOp::MovePacked => {
248                if self.target_features.avx {
249                    "vmovaps"
250                } else {
251                    "movaps"
252                }
253            }
254            SimdOp::LoadPacked => "movups",
255            SimdOp::StorePacked => "movups",
256            _ => "movaps",
257        };
258
259        if let Some(src2) = src2 {
260            writeln!(
261                &mut self.output,
262                "    {} {}, {}, {}",
263                instr,
264                dest.name(),
265                src1.name(),
266                src2.name()
267            )?;
268        } else {
269            writeln!(
270                &mut self.output,
271                "    {} {}, {}",
272                instr,
273                dest.name(),
274                src1.name()
275            )?;
276        }
277
278        Ok(())
279    }
280
281    /// Vectorize a loop
282    pub fn vectorize_loop(&mut self, iterations: usize, element_size: usize) -> Result<String> {
283        let vector_width = if self.target_features.avx2 { 256 } else { 128 };
284        let elements_per_vector = vector_width / (element_size * 8);
285
286        let mut code = String::new();
287        writeln!(
288            &mut code,
289            "    # Vectorized loop ({} elements per iteration)",
290            elements_per_vector
291        )?;
292        writeln!(
293            &mut code,
294            "    mov rcx, {}",
295            iterations / elements_per_vector
296        )?;
297        writeln!(&mut code, ".Lvector_loop:")?;
298
299        // Load vectors
300        writeln!(&mut code, "    movups xmm0, [rsi]")?;
301        writeln!(&mut code, "    movups xmm1, [rdi]")?;
302
303        // Perform operation
304        writeln!(&mut code, "    addps xmm0, xmm1")?;
305
306        // Store result
307        writeln!(&mut code, "    movups [rdx], xmm0")?;
308
309        // Advance pointers
310        writeln!(
311            &mut code,
312            "    add rsi, {}",
313            elements_per_vector * element_size
314        )?;
315        writeln!(
316            &mut code,
317            "    add rdi, {}",
318            elements_per_vector * element_size
319        )?;
320        writeln!(
321            &mut code,
322            "    add rdx, {}",
323            elements_per_vector * element_size
324        )?;
325
326        writeln!(&mut code, "    dec rcx")?;
327        writeln!(&mut code, "    jnz .Lvector_loop")?;
328
329        // Handle remainder
330        let remainder = iterations % elements_per_vector;
331        if remainder > 0 {
332            writeln!(&mut code, "    # Handle {} remaining elements", remainder)?;
333        }
334
335        Ok(code)
336    }
337
338    /// Get generated code
339    pub fn get_code(&self) -> &str {
340        &self.output
341    }
342}
343
/// Calling convention
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallingConvention {
    /// System V AMD64 ABI (Linux, macOS)
    SystemV,
    /// Microsoft x64 calling convention (Windows)
    Win64,
    /// Custom calling convention
    Custom,
}

impl CallingConvention {
    /// Integer/pointer argument registers, in argument order.
    pub fn param_registers(&self) -> Vec<&'static str> {
        match self {
            CallingConvention::SystemV => {
                vec!["rdi", "rsi", "rdx", "rcx", "r8", "r9"]
            }
            CallingConvention::Win64 => {
                vec!["rcx", "rdx", "r8", "r9"]
            }
            CallingConvention::Custom => {
                vec!["rdi", "rsi", "rdx", "rcx"]
            }
        }
    }

    /// Register holding an integer/pointer return value (`rax` for every convention).
    pub fn return_register(&self) -> &'static str {
        "rax"
    }

    /// General-purpose registers a callee must preserve.
    ///
    /// NOTE: Win64 also treats `xmm6`–`xmm15` as non-volatile; vector
    /// registers are not part of this list.
    pub fn callee_saved_registers(&self) -> Vec<&'static str> {
        match self {
            CallingConvention::SystemV => {
                vec!["rbx", "r12", "r13", "r14", "r15", "rbp"]
            }
            CallingConvention::Win64 => {
                vec!["rbx", "rbp", "rdi", "rsi", "r12", "r13", "r14", "r15"]
            }
            CallingConvention::Custom => {
                vec!["rbx", "r12", "r13", "r14", "r15"]
            }
        }
    }

    /// Whether `rsp` must be aligned to [`stack_alignment`](Self::stack_alignment)
    /// at call sites; `true` for every supported convention.
    pub fn requires_stack_alignment(&self) -> bool {
        true
    }

    /// Required stack alignment in bytes at a `call` instruction.
    pub fn stack_alignment(&self) -> usize {
        16
    }

    /// Bytes of caller-allocated shadow ("home") space that must sit directly
    /// above the return address at a call: 32 for Win64, 0 otherwise.
    pub fn shadow_space(&self) -> usize {
        match self {
            CallingConvention::Win64 => 32,
            CallingConvention::SystemV | CallingConvention::Custom => 0,
        }
    }
}

/// Function call generator with calling convention support
pub struct CallGenerator {
    convention: CallingConvention,
}

impl CallGenerator {
    /// Create a new call generator
    pub fn new(convention: CallingConvention) -> Self {
        Self { convention }
    }

    /// Callee-saved registers the prologue pushes. `rbp` is excluded because the
    /// frame setup (`push rbp` ... `pop rbp`) already preserves it.
    fn saved_registers(&self) -> Vec<&'static str> {
        self.convention
            .callee_saved_registers()
            .into_iter()
            .filter(|reg| *reg != "rbp")
            .collect()
    }

    /// Generates a function prologue.
    ///
    /// Layout: `push rbp; mov rbp, rsp`, then `sub rsp` reserves at least
    /// `stack_size` bytes of locals, then the callee-saved registers are pushed.
    /// The reservation is padded so that `rsp` stays aligned after the pushes,
    /// assuming the usual ABI state on entry (`rsp` aligned right after `push rbp`).
    pub fn generate_prologue(&self, stack_size: usize) -> String {
        let mut code = String::new();

        // Save frame pointer
        code.push_str("    push rbp\n");
        code.push_str("    mov rbp, rsp\n");

        let saved = self.saved_registers();
        let save_bytes = saved.len() * 8;

        // Locals plus saved registers together must be a multiple of the alignment.
        let reserve = if self.convention.requires_stack_alignment() {
            let align = self.convention.stack_alignment();
            (stack_size + save_bytes + align - 1) / align * align - save_bytes
        } else {
            stack_size
        };

        if reserve > 0 {
            code.push_str(&format!("    sub rsp, {}\n", reserve));
        }

        // Save callee-saved registers
        for reg in &saved {
            code.push_str(&format!("    push {}\n", reg));
        }

        code
    }

    /// Generates a function epilogue matching [`generate_prologue`](Self::generate_prologue).
    ///
    /// Assumes `rsp` is back at the value it had at the end of the prologue.
    pub fn generate_epilogue(&self) -> String {
        let mut code = String::new();

        // Restore callee-saved registers (in reverse order)
        for reg in self.saved_registers().iter().rev() {
            code.push_str(&format!("    pop {}\n", reg));
        }

        // Restore stack
        code.push_str("    mov rsp, rbp\n");
        code.push_str("    pop rbp\n");
        code.push_str("    ret\n");

        code
    }

    /// Generates a call to `func_name`; each element of `args` is assembler
    /// operand text (register, immediate, or memory operand).
    ///
    /// Arguments beyond the parameter registers are pushed right-to-left, so the
    /// first stack argument ends up lowest in memory, directly above the return
    /// address (or above the Win64 shadow space). Instead of realigning with
    /// `and rsp, -16`, which would lose the original `rsp` and could leave a gap
    /// under the stack arguments, 8 bytes of padding are added when an odd
    /// number of arguments is pushed. This assumes `rsp` is aligned when this
    /// code starts, as it is after [`generate_prologue`](Self::generate_prologue).
    ///
    /// NOTE: register arguments are assigned in order with no parallel-move
    /// resolution; an argument naming an earlier parameter register sees its
    /// new value. `rsp`-relative arguments are shifted by the pushes.
    pub fn generate_call(&self, func_name: &str, args: &[String]) -> String {
        let mut code = String::new();
        let param_regs = self.convention.param_registers();
        let (reg_args, stack_args) = args.split_at(args.len().min(param_regs.len()));

        let stack_arg_bytes = stack_args.len() * 8;
        let padding = if self.convention.requires_stack_alignment() {
            let align = self.convention.stack_alignment();
            (align - stack_arg_bytes % align) % align
        } else {
            0
        };
        let shadow = self.convention.shadow_space();

        if padding > 0 {
            code.push_str(&format!("    sub rsp, {}\n", padding));
        }

        // Stack arguments go first so they are read before any parameter register
        // is overwritten; reverse order puts the first one at the lowest address.
        for arg in stack_args.iter().rev() {
            code.push_str(&format!("    push {}\n", arg));
        }

        if shadow > 0 {
            code.push_str(&format!("    sub rsp, {}\n", shadow));
        }

        // Move arguments to registers
        for (reg, arg) in param_regs.iter().zip(reg_args) {
            code.push_str(&format!("    mov {}, {}\n", reg, arg));
        }

        // Call function
        code.push_str(&format!("    call {}\n", func_name));

        // Release padding, stack arguments and shadow space in one step.
        let cleanup = padding + stack_arg_bytes + shadow;
        if cleanup > 0 {
            code.push_str(&format!("    add rsp, {}\n", cleanup));
        }

        code
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_codegen() {
        let mut simd = SimdCodeGen::new(SimdFeatures::baseline());
        let (dest, lhs, rhs) = (SimdRegister::XMM0, SimdRegister::XMM1, SimdRegister::XMM2);

        simd.emit_simd(SimdOp::AddPacked, dest, lhs, Some(rhs))
            .expect("emitting a packed add should succeed");

        assert!(simd.get_code().contains("addps"));
    }

    #[test]
    fn test_calling_convention() {
        let regs = CallingConvention::SystemV.param_registers();
        assert_eq!(regs[..2], ["rdi", "rsi"]);
    }

    #[test]
    fn test_call_generator() {
        let prologue = CallGenerator::new(CallingConvention::SystemV).generate_prologue(32);
        for needle in ["push rbp", "mov rbp, rsp"] {
            assert!(prologue.contains(needle), "prologue lacks `{}`", needle);
        }
    }
}