1use anyhow::Result;
6use std::fmt::Write as FmtWrite;
7
/// An x86 SIMD register operand.
///
/// Covers the sixteen 128-bit XMM registers, the sixteen 256-bit YMM registers and the
/// first eight 512-bit ZMM registers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdRegister {
    // 128-bit registers.
    XMM0,
    XMM1,
    XMM2,
    XMM3,
    XMM4,
    XMM5,
    XMM6,
    XMM7,
    XMM8,
    XMM9,
    XMM10,
    XMM11,
    XMM12,
    XMM13,
    XMM14,
    XMM15,

    // 256-bit registers.
    YMM0,
    YMM1,
    YMM2,
    YMM3,
    YMM4,
    YMM5,
    YMM6,
    YMM7,
    YMM8,
    YMM9,
    YMM10,
    YMM11,
    YMM12,
    YMM13,
    YMM14,
    YMM15,

    // 512-bit registers.
    ZMM0,
    ZMM1,
    ZMM2,
    ZMM3,
    ZMM4,
    ZMM5,
    ZMM6,
    ZMM7,
}

impl SimdRegister {
    /// Returns the register's lower-case assembler name, e.g. `"xmm9"` or `"ymm15"`.
    ///
    /// The match is exhaustive on purpose: every register maps to its own name, and a
    /// newly added variant is a compile error here instead of silently being printed as
    /// some other register.
    pub fn name(&self) -> &'static str {
        match self {
            Self::XMM0 => "xmm0",
            Self::XMM1 => "xmm1",
            Self::XMM2 => "xmm2",
            Self::XMM3 => "xmm3",
            Self::XMM4 => "xmm4",
            Self::XMM5 => "xmm5",
            Self::XMM6 => "xmm6",
            Self::XMM7 => "xmm7",
            Self::XMM8 => "xmm8",
            Self::XMM9 => "xmm9",
            Self::XMM10 => "xmm10",
            Self::XMM11 => "xmm11",
            Self::XMM12 => "xmm12",
            Self::XMM13 => "xmm13",
            Self::XMM14 => "xmm14",
            Self::XMM15 => "xmm15",

            Self::YMM0 => "ymm0",
            Self::YMM1 => "ymm1",
            Self::YMM2 => "ymm2",
            Self::YMM3 => "ymm3",
            Self::YMM4 => "ymm4",
            Self::YMM5 => "ymm5",
            Self::YMM6 => "ymm6",
            Self::YMM7 => "ymm7",
            Self::YMM8 => "ymm8",
            Self::YMM9 => "ymm9",
            Self::YMM10 => "ymm10",
            Self::YMM11 => "ymm11",
            Self::YMM12 => "ymm12",
            Self::YMM13 => "ymm13",
            Self::YMM14 => "ymm14",
            Self::YMM15 => "ymm15",

            Self::ZMM0 => "zmm0",
            Self::ZMM1 => "zmm1",
            Self::ZMM2 => "zmm2",
            Self::ZMM3 => "zmm3",
            Self::ZMM4 => "zmm4",
            Self::ZMM5 => "zmm5",
            Self::ZMM6 => "zmm6",
            Self::ZMM7 => "zmm7",
        }
    }

    /// Returns the register width in bits: 128 for XMM, 256 for YMM, 512 for ZMM.
    pub fn width(&self) -> usize {
        match self {
            Self::XMM0
            | Self::XMM1
            | Self::XMM2
            | Self::XMM3
            | Self::XMM4
            | Self::XMM5
            | Self::XMM6
            | Self::XMM7
            | Self::XMM8
            | Self::XMM9
            | Self::XMM10
            | Self::XMM11
            | Self::XMM12
            | Self::XMM13
            | Self::XMM14
            | Self::XMM15 => 128,

            Self::YMM0
            | Self::YMM1
            | Self::YMM2
            | Self::YMM3
            | Self::YMM4
            | Self::YMM5
            | Self::YMM6
            | Self::YMM7
            | Self::YMM8
            | Self::YMM9
            | Self::YMM10
            | Self::YMM11
            | Self::YMM12
            | Self::YMM13
            | Self::YMM14
            | Self::YMM15 => 256,

            Self::ZMM0
            | Self::ZMM1
            | Self::ZMM2
            | Self::ZMM3
            | Self::ZMM4
            | Self::ZMM5
            | Self::ZMM6
            | Self::ZMM7 => 512,
        }
    }
}
125
/// Packed SIMD operations that `SimdCodeGen::emit_simd` can be asked to emit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SimdOp {
    // Lane-wise arithmetic.
    AddPacked,
    SubPacked,
    MulPacked,
    DivPacked,

    // Lane-wise bitwise logic.
    AndPacked,
    OrPacked,
    XorPacked,

    // Lane-wise comparisons.
    CmpEqPacked,
    CmpLtPacked,
    CmpGtPacked,

    // Data movement.
    MovePacked,
    LoadPacked,
    StorePacked,

    // Lane rearrangement.
    Shuffle,
    Permute,
    Broadcast,
}
155
156pub struct SimdCodeGen {
158 output: String,
159 target_features: SimdFeatures,
160}
161
/// Flags for the x86 SIMD instruction-set extensions a target provides.
#[derive(Debug, Clone)]
pub struct SimdFeatures {
    /// SSE.
    pub sse: bool,
    /// SSE2.
    pub sse2: bool,
    /// SSE3.
    pub sse3: bool,
    /// SSE4.1.
    pub sse4_1: bool,
    /// AVX.
    pub avx: bool,
    /// AVX2.
    pub avx2: bool,
    /// AVX-512.
    pub avx512: bool,
}

impl SimdFeatures {
    /// Minimal profile: SSE and SSE2 only, every later extension disabled.
    pub fn baseline() -> Self {
        Self {
            sse: true,
            sse2: true,
            sse3: false,
            sse4_1: false,
            avx: false,
            avx2: false,
            avx512: false,
        }
    }

    /// Baseline plus SSE3, SSE4.1, AVX and AVX2; AVX-512 stays disabled.
    pub fn avx2() -> Self {
        Self {
            sse3: true,
            sse4_1: true,
            avx: true,
            avx2: true,
            ..Self::baseline()
        }
    }
}
200
201impl SimdCodeGen {
202 pub fn new(features: SimdFeatures) -> Self {
204 Self {
205 output: String::new(),
206 target_features: features,
207 }
208 }
209
210 pub fn emit_simd(
212 &mut self,
213 op: SimdOp,
214 dest: SimdRegister,
215 src1: SimdRegister,
216 src2: Option<SimdRegister>,
217 ) -> Result<()> {
218 let instr = match op {
219 SimdOp::AddPacked => {
220 if self.target_features.avx {
221 "vaddps"
222 } else {
223 "addps"
224 }
225 }
226 SimdOp::SubPacked => {
227 if self.target_features.avx {
228 "vsubps"
229 } else {
230 "subps"
231 }
232 }
233 SimdOp::MulPacked => {
234 if self.target_features.avx {
235 "vmulps"
236 } else {
237 "mulps"
238 }
239 }
240 SimdOp::DivPacked => {
241 if self.target_features.avx {
242 "vdivps"
243 } else {
244 "divps"
245 }
246 }
247 SimdOp::MovePacked => {
248 if self.target_features.avx {
249 "vmovaps"
250 } else {
251 "movaps"
252 }
253 }
254 SimdOp::LoadPacked => "movups",
255 SimdOp::StorePacked => "movups",
256 _ => "movaps",
257 };
258
259 if let Some(src2) = src2 {
260 writeln!(
261 &mut self.output,
262 " {} {}, {}, {}",
263 instr,
264 dest.name(),
265 src1.name(),
266 src2.name()
267 )?;
268 } else {
269 writeln!(
270 &mut self.output,
271 " {} {}, {}",
272 instr,
273 dest.name(),
274 src1.name()
275 )?;
276 }
277
278 Ok(())
279 }
280
281 pub fn vectorize_loop(&mut self, iterations: usize, element_size: usize) -> Result<String> {
283 let vector_width = if self.target_features.avx2 { 256 } else { 128 };
284 let elements_per_vector = vector_width / (element_size * 8);
285
286 let mut code = String::new();
287 writeln!(
288 &mut code,
289 " # Vectorized loop ({} elements per iteration)",
290 elements_per_vector
291 )?;
292 writeln!(
293 &mut code,
294 " mov rcx, {}",
295 iterations / elements_per_vector
296 )?;
297 writeln!(&mut code, ".Lvector_loop:")?;
298
299 writeln!(&mut code, " movups xmm0, [rsi]")?;
301 writeln!(&mut code, " movups xmm1, [rdi]")?;
302
303 writeln!(&mut code, " addps xmm0, xmm1")?;
305
306 writeln!(&mut code, " movups [rdx], xmm0")?;
308
309 writeln!(
311 &mut code,
312 " add rsi, {}",
313 elements_per_vector * element_size
314 )?;
315 writeln!(
316 &mut code,
317 " add rdi, {}",
318 elements_per_vector * element_size
319 )?;
320 writeln!(
321 &mut code,
322 " add rdx, {}",
323 elements_per_vector * element_size
324 )?;
325
326 writeln!(&mut code, " dec rcx")?;
327 writeln!(&mut code, " jnz .Lvector_loop")?;
328
329 let remainder = iterations % elements_per_vector;
331 if remainder > 0 {
332 writeln!(&mut code, " # Handle {} remaining elements", remainder)?;
333 }
334
335 Ok(code)
336 }
337
338 pub fn get_code(&self) -> &str {
340 &self.output
341 }
342}
343
/// Register and stack rules for the calling conventions `CallGenerator` can target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallingConvention {
    /// System V AMD64 ABI.
    SystemV,
    /// Microsoft x64 (Windows) calling convention.
    Win64,
    /// In-house convention; its register lists are defined below, not by an external ABI.
    Custom,
}

impl CallingConvention {
    /// Integer argument registers, in argument order.
    pub fn param_registers(&self) -> Vec<&'static str> {
        let regs: &[&'static str] = match self {
            Self::SystemV => &["rdi", "rsi", "rdx", "rcx", "r8", "r9"],
            Self::Win64 => &["rcx", "rdx", "r8", "r9"],
            Self::Custom => &["rdi", "rsi", "rdx", "rcx"],
        };
        regs.to_vec()
    }

    /// Register holding an integer return value; `rax` for every variant.
    pub fn return_register(&self) -> &'static str {
        "rax"
    }

    /// General-purpose registers the callee must preserve across the call.
    pub fn callee_saved_registers(&self) -> Vec<&'static str> {
        let regs: &[&'static str] = match self {
            Self::SystemV => &["rbx", "r12", "r13", "r14", "r15", "rbp"],
            Self::Win64 => &["rbx", "rbp", "rdi", "rsi", "r12", "r13", "r14", "r15"],
            Self::Custom => &["rbx", "r12", "r13", "r14", "r15"],
        };
        regs.to_vec()
    }

    /// Whether `rsp` must be kept aligned; true for every variant.
    pub fn requires_stack_alignment(&self) -> bool {
        true
    }

    /// Required `rsp` alignment in bytes; 16 for every variant.
    pub fn stack_alignment(&self) -> usize {
        16
    }
}
401
402pub struct CallGenerator {
404 convention: CallingConvention,
405}
406
407impl CallGenerator {
408 pub fn new(convention: CallingConvention) -> Self {
410 Self { convention }
411 }
412
413 pub fn generate_prologue(&self, stack_size: usize) -> String {
415 let mut code = String::new();
416
417 code.push_str(" push rbp\n");
419 code.push_str(" mov rbp, rsp\n");
420
421 let aligned_size = if self.convention.requires_stack_alignment() {
423 (stack_size + self.convention.stack_alignment() - 1)
424 & !(self.convention.stack_alignment() - 1)
425 } else {
426 stack_size
427 };
428
429 if aligned_size > 0 {
430 code.push_str(&format!(" sub rsp, {}\n", aligned_size));
431 }
432
433 for reg in self.convention.callee_saved_registers() {
435 code.push_str(&format!(" push {}\n", reg));
436 }
437
438 code
439 }
440
441 pub fn generate_epilogue(&self) -> String {
443 let mut code = String::new();
444
445 for reg in self.convention.callee_saved_registers().iter().rev() {
447 code.push_str(&format!(" pop {}\n", reg));
448 }
449
450 code.push_str(" mov rsp, rbp\n");
452 code.push_str(" pop rbp\n");
453 code.push_str(" ret\n");
454
455 code
456 }
457
458 pub fn generate_call(&self, func_name: &str, args: &[String]) -> String {
460 let mut code = String::new();
461 let param_regs = self.convention.param_registers();
462
463 for (i, arg) in args.iter().enumerate() {
465 if i < param_regs.len() {
466 code.push_str(&format!(" mov {}, {}\n", param_regs[i], arg));
467 } else {
468 code.push_str(&format!(" push {}\n", arg));
470 }
471 }
472
473 if self.convention.requires_stack_alignment() {
475 code.push_str(" and rsp, -16\n");
476 }
477
478 code.push_str(&format!(" call {}\n", func_name));
480
481 let stack_args = args.len().saturating_sub(param_regs.len());
483 if stack_args > 0 {
484 code.push_str(&format!(" add rsp, {}\n", stack_args * 8));
485 }
486
487 code
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// An SSE-only target must lower a packed add to the legacy `addps` mnemonic.
    #[test]
    fn test_simd_codegen() {
        let mut simd = SimdCodeGen::new(SimdFeatures::baseline());
        let (dest, lhs, rhs) = (SimdRegister::XMM0, SimdRegister::XMM1, SimdRegister::XMM2);

        simd.emit_simd(SimdOp::AddPacked, dest, lhs, Some(rhs))
            .unwrap();

        assert!(simd.get_code().contains("addps"));
    }

    /// System V hands its first two integer arguments to `rdi`, then `rsi`.
    #[test]
    fn test_calling_convention() {
        let params = CallingConvention::SystemV.param_registers();
        assert_eq!(params[..2], ["rdi", "rsi"]);
    }

    /// The prologue must set up the `rbp` frame pointer.
    #[test]
    fn test_call_generator() {
        let prologue = CallGenerator::new(CallingConvention::SystemV).generate_prologue(32);

        for expected in ["push rbp", "mov rbp, rsp"].iter() {
            assert!(prologue.contains(*expected));
        }
    }
}