Skip to main content

gaia_assembler/backends/x86/
mod.rs

1//! Native x86_64 backend compiler
2
3use crate::{
4    config::GaiaConfig,
5    instruction::{CoreInstruction, DomainInstruction, GaiaInstruction},
6    program::{GaiaConstant, GaiaFunction, GaiaModule},
7    Backend, GeneratedFiles,
8};
9use gaia_types::{
10    helpers::{AbiCompatible, ApiCompatible, Architecture, CompilationTarget},
11    GaiaError, Result,
12};
13use std::collections::HashMap;
14
/// Backend that lowers Gaia modules to native x86_64 machine code.
///
/// The type carries no state; it exists so the code generator can be plugged
/// in through the `Backend` trait and created with `Default::default()`.
#[derive(Default)]
pub struct X86Backend {}
18
19impl Backend for X86Backend {
20    fn name(&self) -> &'static str {
21        "Native x86_64"
22    }
23
24    fn primary_target(&self) -> CompilationTarget {
25        CompilationTarget { build: Architecture::X86_64, host: AbiCompatible::PE, target: ApiCompatible::MicrosoftVisualC }
26    }
27
28    fn match_score(&self, target: &CompilationTarget) -> f32 {
29        if target.build == Architecture::X86_64 && target.host == AbiCompatible::PE {
30            if target.target == ApiCompatible::MicrosoftVisualC {
31                return 100.0; // Perfect match for native x86_64 on Windows
32            }
33            return 80.0;
34        }
35        0.0
36    }
37
38    fn generate(&self, program: &GaiaModule, _config: &GaiaConfig) -> Result<GeneratedFiles> {
39        let mut code = Vec::new();
40        let mut external_call_positions = HashMap::new();
41
42        // 1. Generate entry point (stub)
43        // sub rsp, 32 (shadow space for Win64 calls)
44        code.extend_from_slice(&[0x48, 0x83, 0xEC, 0x20]);
45
46        // call <main>
47        let call_main_pos = code.len();
48        code.extend_from_slice(&[0xE8, 0x00, 0x00, 0x00, 0x00]);
49
50        // mov rcx, rax (return value as exit code)
51        code.extend_from_slice(&[0x48, 0x89, 0xC1]);
52
53        // call [rip + <iat_offset>] (ExitProcess)
54        // Note: ExitProcess is usually required for EXEs, but we'll try to find it in imports
55        let call_exit_pos = code.len();
56        code.extend_from_slice(&[0xFF, 0x15, 0x00, 0x00, 0x00, 0x00]);
57
58        // 2. Generate functions
59        let mut function_offsets = HashMap::new();
60        for function in &program.functions {
61            function_offsets.insert(function.name.clone(), code.len());
62            self.generate_function(function, &mut code, &mut external_call_positions)?;
63        }
64
65        // 3. Patch main call
66        let main_name = if function_offsets.contains_key("main") {
67            "main"
68        }
69        else {
70            program.functions.first().map(|f| f.name.as_str()).unwrap_or("")
71        };
72
73        if !main_name.is_empty() {
74            let main_offset = function_offsets[main_name];
75            let relative_offset = (main_offset as i32) - (call_main_pos as i32 + 5);
76            code[call_main_pos + 1..call_main_pos + 5].copy_from_slice(&relative_offset.to_le_bytes());
77        }
78
79        let mut files = HashMap::new();
80        files.insert("main.bin".to_string(), code.clone());
81
82        let pe_bytes = self.create_pe_exe(&code, program, call_exit_pos, &external_call_positions)?;
83        files.insert("main.exe".to_string(), pe_bytes);
84
85        Ok(GeneratedFiles { files, diagnostics: vec![] })
86    }
87}
88
89impl X86Backend {
90    fn generate_function(
91        &self,
92        function: &GaiaFunction,
93        code: &mut Vec<u8>,
94        external_call_positions: &mut HashMap<String, Vec<usize>>,
95    ) -> Result<()> {
96        let mut labels = HashMap::new();
97        let mut jump_patches = Vec::new();
98        let _function_start = code.len();
99
100        // Function prologue
101        code.push(0x55); // push rbp
102        code.extend_from_slice(&[0x48, 0x89, 0xE5]); // mov rbp, rsp
103
104        // Calculate stack space for locals and shadow space (32 bytes for Win64)
105        // Ensure 16-byte alignment
106        let locals_count = function
107            .blocks
108            .iter()
109            .flat_map(|b| &b.instructions)
110            .filter(|i| matches!(i, GaiaInstruction::Core(CoreInstruction::Alloca(_, _))))
111            .count();
112        let locals_size = locals_count * 8;
113        let shadow_space = 32;
114        let total_stack_size = (locals_size + shadow_space + 15) & !15;
115
116        if total_stack_size > 0 {
117            if total_stack_size <= 127 {
118                code.extend_from_slice(&[0x48, 0x83, 0xEC]); // sub rsp, imm8
119                code.push(total_stack_size as u8);
120            }
121            else {
122                code.extend_from_slice(&[0x48, 0x81, 0xEC]); // sub rsp, imm32
123                code.extend_from_slice(&(total_stack_size as i32).to_le_bytes());
124            }
125        }
126
127        for block in &function.blocks {
128            labels.insert(block.label.clone(), code.len());
129
130            for inst in &block.instructions {
131                match inst {
132                    GaiaInstruction::Core(core_inst) => match core_inst {
133                        CoreInstruction::PushConstant(constant) => {
134                            match constant {
135                                GaiaConstant::I8(v) => {
136                                    code.push(0x6A); // push (imm8)
137                                    code.push(*v as u8);
138                                }
139                                GaiaConstant::I16(v) => {
140                                    code.push(0x68); // push (imm32, sign-extended)
141                                    code.extend_from_slice(&(*v as i32).to_le_bytes());
142                                }
143                                GaiaConstant::I32(v) => {
144                                    code.push(0x68); // push (imm32)
145                                    code.extend_from_slice(&v.to_le_bytes());
146                                }
147                                GaiaConstant::I64(v) => {
148                                    // mov rax, imm64; push rax
149                                    code.extend_from_slice(&[0x48, 0xB8]);
150                                    code.extend_from_slice(&v.to_le_bytes());
151                                    code.push(0x50); // push rax
152                                }
153                                _ => return Err(GaiaError::custom_error("Unsupported constant type for x86 backend")),
154                            }
155                        }
156                        CoreInstruction::Add(_) => {
157                            // pop rbx; pop rax; add rax, rbx; push rax
158                            code.push(0x5B); // pop rbx
159                            code.push(0x58); // pop rax
160                            code.extend_from_slice(&[0x48, 0x01, 0xD8]); // add rax, rbx
161                            code.push(0x50); // push rax
162                        }
163                        CoreInstruction::Sub(_) => {
164                            // pop rbx; pop rax; sub rax, rbx; push rax
165                            code.push(0x5B); // pop rbx
166                            code.push(0x58); // pop rax
167                            code.extend_from_slice(&[0x48, 0x29, 0xD8]); // sub rax, rbx
168                            code.push(0x50); // push rax
169                        }
170                        CoreInstruction::Mul(_) => {
171                            // pop rbx; pop rax; imul rax, rbx; push rax
172                            code.push(0x5B); // pop rbx
173                            code.push(0x58); // pop rax
174                            code.extend_from_slice(&[0x48, 0x0F, 0xAF, 0xC3]); // imul rax, rbx
175                            code.push(0x50); // push rax
176                        }
177                        CoreInstruction::Div(_) => {
178                            // pop rbx; pop rax; cqo; idiv rbx; push rax
179                            code.push(0x5B); // pop rbx
180                            code.push(0x58); // pop rax
181                            code.extend_from_slice(&[0x48, 0x99]); // cqo
182                            code.extend_from_slice(&[0x48, 0xF7, 0xFB]); // idiv rbx
183                            code.push(0x50); // push rax
184                        }
185                        CoreInstruction::Shl(_) => {
186                            // pop rcx; pop rax; shl rax, cl; push rax
187                            code.push(0x59); // pop rcx
188                            code.push(0x58); // pop rax
189                            code.extend_from_slice(&[0x48, 0xD3, 0xE0]); // shl rax, cl
190                            code.push(0x50); // push rax
191                        }
192                        CoreInstruction::Shr(_) => {
193                            // pop rcx; pop rax; shr rax, cl; push rax
194                            code.push(0x59); // pop rcx
195                            code.push(0x58); // pop rax
196                            code.extend_from_slice(&[0x48, 0xD3, 0xE8]); // shr rax, cl
197                            code.push(0x50); // push rax
198                        }
199                        CoreInstruction::And(_) => {
200                            // pop rbx; pop rax; and rax, rbx; push rax
201                            code.push(0x5B); // pop rbx
202                            code.push(0x58); // pop rax
203                            code.extend_from_slice(&[0x48, 0x21, 0xD8]); // and rax, rbx
204                            code.push(0x50); // push rax
205                        }
206                        CoreInstruction::Or(_) => {
207                            // pop rbx; pop rax; or rax, rbx; push rax
208                            code.push(0x5B); // pop rbx
209                            code.push(0x58); // pop rax
210                            code.extend_from_slice(&[0x48, 0x09, 0xD8]); // or rax, rbx
211                            code.push(0x50); // push rax
212                        }
213                        CoreInstruction::Xor(_) => {
214                            // pop rbx; pop rax; xor rax, rbx; push rax
215                            code.push(0x5B); // pop rbx
216                            code.push(0x58); // pop rax
217                            code.extend_from_slice(&[0x48, 0x31, 0xD8]); // xor rax, rbx
218                            code.push(0x50); // push rax
219                        }
220                        CoreInstruction::Alloca(_, _) => {
221                            // Space already reserved in prologue
222                        }
223                        CoreInstruction::LoadLocal(index, _) => {
224                            // mov rax, [rbp - offset]; push rax
225                            let offset = 32 + (index + 1) * 8;
226                            code.extend_from_slice(&[0x48, 0x8B, 0x85]); // mov rax, [rbp - imm32]
227                            code.extend_from_slice(&(-(offset as i32)).to_le_bytes());
228                            code.push(0x50); // push rax
229                        }
230                        CoreInstruction::StoreLocal(index, _) => {
231                            // pop rax; mov [rbp - offset], rax
232                            let offset = 32 + (index + 1) * 8;
233                            code.push(0x58); // pop rax
234                            code.extend_from_slice(&[0x48, 0x89, 0x85]); // mov [rbp - imm32], rax
235                            code.extend_from_slice(&(-(offset as i32)).to_le_bytes());
236                        }
237                        CoreInstruction::Label(name) => {
238                            labels.insert(name.clone(), code.len());
239                        }
240                        CoreInstruction::Br(target) => {
241                            code.push(0xE9); // jmp rel32
242                            jump_patches.push((code.len(), target.clone()));
243                            code.extend_from_slice(&[0, 0, 0, 0]);
244                        }
245                        CoreInstruction::BrTrue(target) => {
246                            // pop rax; test rax, rax; jnz rel32
247                            code.push(0x58); // pop rax
248                            code.extend_from_slice(&[0x48, 0x85, 0xC0]); // test rax, rax
249                            code.extend_from_slice(&[0x0F, 0x85]); // jnz rel32
250                            jump_patches.push((code.len(), target.clone()));
251                            code.extend_from_slice(&[0, 0, 0, 0]);
252                        }
253                        CoreInstruction::BrFalse(target) => {
254                            // pop rax; test rax, rax; jz rel32
255                            code.push(0x58); // pop rax
256                            code.extend_from_slice(&[0x48, 0x85, 0xC0]); // test rax, rax
257                            code.extend_from_slice(&[0x0F, 0x84]); // jz rel32
258                            jump_patches.push((code.len(), target.clone()));
259                            code.extend_from_slice(&[0, 0, 0, 0]);
260                        }
261                        CoreInstruction::Cmp(cond, _) => {
262                            use crate::instruction::CmpCondition;
263                            // pop rbx; pop rax; cmp rax, rbx; set<cond> al; movzx rax, al; push rax
264                            code.push(0x5B); // pop rbx
265                            code.push(0x58); // pop rax
266                            code.extend_from_slice(&[0x48, 0x39, 0xD8]); // cmp rax, rbx
267                            match cond {
268                                CmpCondition::Eq => code.extend_from_slice(&[0x0F, 0x94, 0xC0]), // sete al
269                                CmpCondition::Ne => code.extend_from_slice(&[0x0F, 0x95, 0xC0]), // setne al
270                                CmpCondition::Lt => code.extend_from_slice(&[0x0F, 0x9C, 0xC0]), // setl al
271                                CmpCondition::Le => code.extend_from_slice(&[0x0F, 0x9E, 0xC0]), // setle al
272                                CmpCondition::Gt => code.extend_from_slice(&[0x0F, 0x9F, 0xC0]), // setg al
273                                CmpCondition::Ge => code.extend_from_slice(&[0x0F, 0x9D, 0xC0]), // setge al
274                            }
275                            code.extend_from_slice(&[0x48, 0x0F, 0xB6, 0xC0]); // movzx rax, al
276                            code.push(0x50); // push rax
277                        }
278                        CoreInstruction::Call(name, argc) => {
279                            // Win64 ABI: RCX, RDX, R8, R9
280                            // Gaia Call: arguments are on stack in order.
281                            // So the last argument is at the top of the stack.
282
283                            // 1. Pop arguments into registers
284                            if *argc >= 1 {
285                                code.push(0x59);
286                            } // pop rcx (1st arg if only 1, or temporary)
287                            if *argc >= 2 {
288                                code.push(0x5A); // pop rdx (2nd arg)
289                                                 // swap rcx, rdx to get order right if we popped them in reverse
290                                                 // Actually, if stack is [arg1, arg2], pop rdx gives arg2, pop rcx gives arg1. Correct.
291                            }
292                            if *argc >= 3 {
293                                code.push(0x41);
294                                code.push(0x58); // pop r8
295                                                 // Now we have: R8=arg3, RDX=arg2, RCX=arg1. Correct.
296                            }
297                            if *argc >= 4 {
298                                code.push(0x41);
299                                code.push(0x59); // pop r9
300                            }
301                            // argc > 4 would need more work (stack arguments)
302
303                            // 2. Call the function
304                            external_call_positions.entry(name.clone()).or_default().push(code.len());
305                            // call [rip + offset] (IAT style)
306                            code.extend_from_slice(&[0xFF, 0x15, 0x00, 0x00, 0x00, 0x00]);
307
308                            // 3. Push return value (rax) back to stack
309                            code.push(0x50); // push rax
310                        }
311                        _ => return Err(GaiaError::custom_error(format!("Unsupported core instruction: {:?}", core_inst))),
312                    },
313                    GaiaInstruction::Domain(domain_inst) => match domain_inst {
314                        DomainInstruction::Neural(node) => {
315                            // 调用 matmul (专用加速路径)
316                            if let gaia_types::neural::NeuralNode::MatMul(_) = node {
317                                let symbol = "gaia_matmul".to_string();
318                                external_call_positions.entry(symbol).or_default().push(code.len());
319                                // 这里插入一个占位符,后续在 PE 生成时打补丁到导入表或静态库
320                                code.extend_from_slice(&[0xFF, 0x15, 0x00, 0x00, 0x00, 0x00]);
321                            }
322                        }
323                        _ => return Err(GaiaError::custom_error(format!("Unsupported domain instruction: {:?}", domain_inst))),
324                    },
325                    _ => return Err(GaiaError::custom_error(format!("Unsupported instruction tier for x86: {:?}", inst))),
326                }
327            }
328
329            // Handle terminator
330            match &block.terminator {
331                crate::program::GaiaTerminator::Jump(target) => {
332                    code.push(0xE9); // jmp rel32
333                    jump_patches.push((code.len(), target.clone()));
334                    code.extend_from_slice(&[0, 0, 0, 0]);
335                }
336                crate::program::GaiaTerminator::Branch { true_label, false_label } => {
337                    // pop rax; test rax, rax; jnz true; jmp false
338                    code.push(0x58); // pop rax
339                    code.extend_from_slice(&[0x48, 0x85, 0xC0]); // test rax, rax
340                    code.extend_from_slice(&[0x0F, 0x85]); // jnz rel32
341                    jump_patches.push((code.len(), true_label.clone()));
342                    code.extend_from_slice(&[0, 0, 0, 0]);
343
344                    code.push(0xE9); // jmp rel32
345                    jump_patches.push((code.len(), false_label.clone()));
346                    code.extend_from_slice(&[0, 0, 0, 0]);
347                }
348                crate::program::GaiaTerminator::Return => {
349                    // Function epilogue
350                    code.extend_from_slice(&[0x48, 0x89, 0xEC]); // mov rsp, rbp
351                    code.push(0x5D); // pop rbp
352                    code.push(0xC3); // ret
353                }
354                _ => {}
355            }
356        }
357
358        for (pos, name) in jump_patches {
359            if let Some(&label_pos) = labels.get(&name) {
360                let relative_offset = (label_pos as i32) - (pos as i32 + 4);
361                code[pos..pos + 4].copy_from_slice(&relative_offset.to_le_bytes());
362            }
363        }
364
365        Ok(())
366    }
367
368    fn create_pe_exe(
369        &self,
370        code: &[u8],
371        program: &GaiaModule,
372        call_exit_pos: usize,
373        external_call_positions: &HashMap<String, Vec<usize>>,
374    ) -> Result<Vec<u8>> {
375        let mut imports = pe_assembler::types::ImportTable::new();
376
377        // Group imports by library
378        let mut lib_imports: HashMap<String, Vec<String>> = HashMap::new();
379        for imp in &program.imports {
380            lib_imports.entry(imp.library.clone()).or_default().push(imp.symbol.clone());
381        }
382
383        // Ensure kernel32.dll!ExitProcess is present if we are an EXE and it's not provided
384        if !lib_imports.values().any(|funcs| funcs.contains(&"ExitProcess".to_string())) {
385            lib_imports.entry("kernel32.dll".to_string()).or_default().push("ExitProcess".to_string());
386        }
387
388        for (lib, funcs) in lib_imports {
389            imports.entries.push(pe_assembler::types::ImportEntry { dll_name: lib, functions: funcs });
390        }
391
392        let mut pe_program = pe_assembler::types::PeProgram::create_executable(code.to_vec()).with_imports(imports);
393
394        // Ensure some critical fields are set correctly for native x64
395        pe_program.header.optional_header.image_base = 0x400000;
396        pe_program.header.optional_header.section_alignment = 0x1000;
397        pe_program.header.optional_header.file_alignment = 0x200;
398        pe_program.header.optional_header.major_operating_system_version = 6;
399        pe_program.header.optional_header.minor_operating_system_version = 0;
400        pe_program.header.optional_header.major_subsystem_version = 6;
401        pe_program.header.optional_header.minor_subsystem_version = 0;
402
403        // DYNAMIC_BASE | NX_COMPAT | NO_SEH | TERMINAL_SERVER_AWARE
404        pe_program.header.optional_header.dll_characteristics = 0x8160;
405
406        // Recalculate size of image and headers
407        let text_size_aligned = (pe_program.sections[0].data.len() as u32 + 0xFFF) & !0xFFF;
408        let idata_size_aligned =
409            if pe_program.sections.len() > 1 { (pe_program.sections[1].virtual_size + 0xFFF) & !0xFFF } else { 0 };
410        pe_program.header.optional_header.size_of_image = 0x1000 + text_size_aligned + idata_size_aligned;
411        pe_program.header.optional_header.size_of_headers = 0x200;
412
413        // Patch calls to imported functions
414        if pe_program.sections.len() > 1 {
415            let iat_rva = pe_program.header.optional_header.data_directories[12].virtual_address;
416            let code_data = &mut pe_program.sections[0].data;
417
418            // Find ExitProcess in IAT
419            let mut exit_process_iat_rva = 0;
420            let mut current_iat_offset = 0;
421
422            for entry in &pe_program.imports.entries {
423                for (_i, func) in entry.functions.iter().enumerate() {
424                    let rva = iat_rva + current_iat_offset;
425
426                    if func == "ExitProcess" {
427                        exit_process_iat_rva = rva;
428                    }
429
430                    if let Some(positions) = external_call_positions.get(func) {
431                        for &pos in positions {
432                            let next_rip_rva = 0x1000 + pos as u32 + 6;
433                            let relative_offset = (rva as i32) - (next_rip_rva as i32);
434                            code_data[pos + 2..pos + 6].copy_from_slice(&relative_offset.to_le_bytes());
435                        }
436                    }
437
438                    current_iat_offset += 8; // x64 IAT entry size
439                }
440                current_iat_offset += 8; // Null terminator for DLL
441            }
442
443            // Patch the hardcoded ExitProcess call in entry point
444            if exit_process_iat_rva != 0 {
445                let next_rip_rva_exit = 0x1000 + call_exit_pos as u32 + 6;
446                let relative_offset_exit = (exit_process_iat_rva as i32) - (next_rip_rva_exit as i32);
447                code_data[call_exit_pos + 2..call_exit_pos + 6].copy_from_slice(&relative_offset_exit.to_le_bytes());
448            }
449        }
450
451        let mut buffer = Vec::new();
452        let mut cursor = std::io::Cursor::new(&mut buffer);
453        let mut writer = pe_assembler::formats::exe::writer::ExeWriter::new(&mut cursor);
454        use pe_assembler::helpers::PeWriter;
455        writer.write_program(&pe_program).map_err(|e| GaiaError::custom_error(format!("PE write error: {}", e)))?;
456        Ok(buffer)
457    }
458}