Skip to main content

ras/assembler/
core.rs

1//! Core assembler implementation
2//!
3//! This module contains the main RasAssembler struct and basic operations
4//! that are architecture-independent. Uses two-pass assembly for jmp/call
5//! with label resolution.
6
7use crate::encoder::traits::InstructionEncoder;
8use crate::error::RasError;
9use crate::object::{ExternalReloc, ObjectSymbol, ObjectWriteOptions, ObjectWriteRequest, ObjectWriter};
10use crate::parser::{AssemblyParser, Line};
11use lamina_platform::{TargetArchitecture, TargetOperatingSystem};
12use std::collections::HashMap;
13
14/// Return type of the two-pass encoding pass.
15type EncodeResult = Result<(Vec<u8>, Vec<ObjectSymbol>, Vec<ExternalReloc>), RasError>;
16
/// Raw declarations of the Win32 loader functions used to resolve runtime
/// symbols on Windows builds.
#[cfg(windows)]
mod windows_loader {
    use std::ffi::c_char;
    use std::ffi::c_void;

    unsafe extern "system" {
        /// Looks up the handle of a module by its NUL-terminated name.
        pub fn GetModuleHandleA(module_name: *const c_char) -> *mut c_void;

        /// Looks up the address of the NUL-terminated symbol `proc_name` in `module`.
        pub fn GetProcAddress(module: *mut c_void, proc_name: *const c_char) -> *mut c_void;
    }
}
26
27/// Assembler: converts assembly text to object files
28pub struct RasAssembler {
29    pub(crate) target_arch: TargetArchitecture,
30    pub(crate) target_os: TargetOperatingSystem,
31    encoder: Box<dyn InstructionEncoder>,
32    object_writer: Box<dyn ObjectWriter>,
33    object_write_options: ObjectWriteOptions,
34    pub(crate) function_pointers: std::collections::HashMap<String, u64>, // Function name -> address
35    #[cfg(feature = "encoder")]
36    pub(crate) current_module: Option<*const lamina_mir::Module>, // Current module being compiled (for internal call detection)
37}
38
/// A placeholder emitted during the first encoding pass that must be filled in
/// once label offsets are known.
enum PatchPoint {
    /// x86-64 `jmp`/`call`/`jcc`: `offset` is where the 4-byte relative
    /// displacement starts in the code buffer; `target` is the label name.
    X86Rel32 { offset: usize, target: String },
    /// x86-64 `lea label(%rip), reg`: `offset` is where the 4-byte
    /// RIP-relative displacement starts; `target` is the label name.
    X86RipRel32 { offset: usize, target: String },
    /// ARX64 jump-and-link: `offset` is where the 4-byte instruction word
    /// starts; `rd` is the link register number.
    Arx64Jal { offset: usize, target: String, rd: u8 },
    /// ARX64 conditional branch: `offset` is where the 4-byte instruction word
    /// starts; `rs1`/`rs2` are the compared registers and `funct3` selects the
    /// comparison.
    Arx64Branch {
        offset: usize,
        target: String,
        rs1: u8,
        rs2: u8,
        funct3: u8,
    },
}
61
62impl RasAssembler {
63    /// Create a new assembler for the given target architecture and OS.
64    pub fn new(
65        target_arch: TargetArchitecture,
66        target_os: TargetOperatingSystem,
67    ) -> Result<Self, RasError> {
68        Self::with_object_write_options(target_arch, target_os, ObjectWriteOptions::default())
69    }
70
71    /// Create a new assembler with custom object-file write options.
72    pub fn with_object_write_options(
73        target_arch: TargetArchitecture,
74        target_os: TargetOperatingSystem,
75        object_write_options: ObjectWriteOptions,
76    ) -> Result<Self, RasError> {
77        // Create encoder based on target architecture
78        let encoder: Box<dyn InstructionEncoder> = match target_arch {
79            TargetArchitecture::X86_64 => Box::new(crate::encoder::x86_64::X86_64Encoder::new()),
80            TargetArchitecture::Aarch64 => Box::new(crate::encoder::aarch64::AArch64Encoder::new()),
81            TargetArchitecture::Arx64 => Box::new(crate::encoder::arx64::Arx64Encoder::new()),
82            TargetArchitecture::Riscv32 => {
83                Box::new(crate::encoder::riscv::RiscVEncoder::new(false))
84            }
85            TargetArchitecture::Riscv64 => Box::new(crate::encoder::riscv::RiscVEncoder::new(true)),
86            _ => {
87                return Err(RasError::UnsupportedTarget(
88                    crate::target::unsupported_target_hint(target_arch, target_os),
89                ));
90            }
91        };
92
93        let object_writer = match crate::object::object_writer_for_os(target_os) {
94            Ok(w) => w,
95            Err(_) => {
96                return Err(RasError::UnsupportedTarget(format!(
97                    "Unsupported OS for cross-compilation: {:?}. {}",
98                    target_os,
99                    crate::target::unsupported_target_hint(target_arch, target_os)
100                )));
101            }
102        };
103
104        Ok(Self {
105            target_arch,
106            target_os,
107            encoder,
108            object_writer,
109            object_write_options,
110            function_pointers: std::collections::HashMap::new(),
111            #[cfg(feature = "encoder")]
112            current_module: None,
113        })
114    }
115
116    /// Replace the object-file write options used by subsequent assembly calls.
117    pub fn set_object_write_options(&mut self, opts: ObjectWriteOptions) {
118        self.object_write_options = opts;
119    }
120
121    /// Assemble `asm_text` and write a relocatable object file to `output_path`.
122    ///
123    /// Uses two-pass assembly for `jmp`/`call` instructions with forward label
124    /// resolution.
125    pub fn assemble_text_to_object(
126        &mut self,
127        asm_text: &str,
128        output_path: &std::path::Path,
129    ) -> Result<(), RasError> {
130        let parsed = AssemblyParser::new()
131            .parse(asm_text)
132            .map_err(|e| RasError::ParseError(e.to_string()))?;
133
134        let (code, symbols, relocs) = self.encode_lines_two_pass(&parsed.lines)?;
135
136        self.object_writer
137            .write_object_file(
138                output_path,
139                &ObjectWriteRequest {
140                    code: &code,
141                    sections: &parsed.sections,
142                    symbols: &symbols,
143                    relocations: &relocs,
144                    target_arch: self.target_arch,
145                    target_os: self.target_os,
146                    opts: &self.object_write_options,
147                },
148            )
149            .map_err(|e| RasError::ObjectError(e.to_string()))?;
150
151        Ok(())
152    }
153
154    fn encode_lines_two_pass(&mut self, lines: &[Line]) -> EncodeResult {
155        let mut symbol_offsets: HashMap<String, usize> = HashMap::new();
156        let mut patch_points: Vec<PatchPoint> = Vec::new();
157        let mut code = Vec::new();
158        let mut current_offset = 0usize;
159
160        for line in lines {
161            match line {
162                Line::Label(sym) => {
163                    symbol_offsets.insert(sym.name.clone(), current_offset);
164                }
165                Line::Data(bytes) => {
166                    code.extend_from_slice(bytes);
167                    current_offset += bytes.len();
168                }
169                Line::Instruction(inst) => {
170                    let opcode = inst.opcode.to_lowercase();
171                    let is_jmp_call = opcode == "jmp" || opcode == "jmpq" || opcode == "call";
172
173                    if self.target_arch == TargetArchitecture::Arx64
174                        && let Some(patch) =
175                            arx64_label_patch(&opcode, &inst.operands, current_offset)?
176                        {
177                            code.extend_from_slice(&[0u8; 4]);
178                            patch_points.push(patch);
179                            current_offset += 4;
180                            continue;
181                        }
182
183                    // Handle leaq/lea with RIP-relative label: "leaq label(%rip), %reg"
184                    if self.target_arch == TargetArchitecture::X86_64
185                        && (opcode == "leaq" || opcode == "lea")
186                        && inst.operands.len() == 2
187                        && let Some(label) = extract_rip_label(inst.operands[0].trim())
188                    {
189                        let reg = parse_x86_reg(inst.operands[1].trim()).map_err(|e| {
190                            RasError::EncodingError(e.to_string())
191                        })?;
192                        // REX.W=1, REX.R set if reg>=8; opcode=0x8D; ModRM Mod=00 R/M=101(RIP-rel)
193                        let rex: u8 = 0x48 | ((reg >> 3) << 2);
194                        let modrm: u8 = ((reg & 7) << 3) | 5;
195                        code.extend_from_slice(&[rex, 0x8D, modrm, 0, 0, 0, 0]);
196                        patch_points.push(PatchPoint::X86RipRel32 {
197                            offset: current_offset + 3,
198                            target: label.to_string(),
199                        });
200                        current_offset += 7;
201                        continue;
202                    }
203
204                    // x86_64 conditional jumps: 2-byte opcode (0x0F 8x) + rel32
205                    if self.target_arch == TargetArchitecture::X86_64
206                        && inst.operands.len() == 1
207                        && let Some(cc_byte) = x86_cond_jmp_byte(&opcode)
208                    {
209                        let target = inst.operands[0].trim();
210                        code.extend_from_slice(&[0x0F, cc_byte, 0, 0, 0, 0]);
211                        patch_points.push(PatchPoint::X86Rel32 {
212                            offset: current_offset + 2,
213                            target: target.to_string(),
214                        });
215                        current_offset += 6;
216                        continue;
217                    }
218
219                    if is_jmp_call && inst.operands.len() == 1 {
220                        let target = inst.operands[0].trim();
221                        if self.target_arch == lamina_platform::TargetArchitecture::X86_64 {
222                            let is_call = opcode == "call";
223                            let opcode_byte: u8 = if is_call { 0xe8 } else { 0xe9 };
224                            code.push(opcode_byte);
225                            code.extend_from_slice(&[0u8; 4]);
226                            patch_points.push(PatchPoint::X86Rel32 {
227                                offset: current_offset + 1,
228                                target: target.to_string(),
229                            });
230                            current_offset += 5;
231                        } else {
232                            let bytes = self
233                                .encoder
234                                .encode_instruction(inst)
235                                .map_err(|e| RasError::EncodingError(e.to_string()))?;
236                            code.extend_from_slice(&bytes);
237                            current_offset += bytes.len();
238                        }
239                    } else {
240                        let bytes = self
241                            .encoder
242                            .encode_instruction(inst)
243                            .map_err(|e| RasError::EncodingError(e.to_string()))?;
244                        code.extend_from_slice(&bytes);
245                        current_offset += bytes.len();
246                    }
247                }
248            }
249        }
250
251        let mut external_relocs: Vec<ExternalReloc> = Vec::new();
252
253        for patch in &patch_points {
254            let target = match patch {
255                PatchPoint::X86Rel32 { target, .. }
256                | PatchPoint::X86RipRel32 { target, .. }
257                | PatchPoint::Arx64Jal { target, .. }
258                | PatchPoint::Arx64Branch { target, .. } => target,
259            };
260            if let Some(&target_offset) = symbol_offsets.get(target) {
261                match patch {
262                    PatchPoint::X86Rel32 { offset, .. }
263                    | PatchPoint::X86RipRel32 { offset, .. } => {
264                        let rel32 = (target_offset as i64) - (*offset as i64 + 4);
265                        let rel32_bytes = (rel32 as i32).to_le_bytes();
266                        code[*offset..*offset + 4].copy_from_slice(&rel32_bytes);
267                    }
268                    PatchPoint::Arx64Jal { offset, rd, .. } => {
269                        let rel = (target_offset as i64) - (*offset as i64);
270                        let word = arx64_j_type(rel as i32, *rd);
271                        code[*offset..*offset + 4].copy_from_slice(&word.to_le_bytes());
272                    }
273                    PatchPoint::Arx64Branch {
274                        offset,
275                        rs1,
276                        rs2,
277                        funct3,
278                        ..
279                    } => {
280                        let rel = (target_offset as i64) - (*offset as i64);
281                        let word = arx64_b_type(rel as i32, *rs2, *rs1, *funct3);
282                        code[*offset..*offset + 4].copy_from_slice(&word.to_le_bytes());
283                    }
284                }
285            } else {
286                // Unresolved symbol → external relocation for the linker.
287                match patch {
288                    PatchPoint::X86Rel32 { offset, target }
289                    | PatchPoint::X86RipRel32 { offset, target } => {
290                        external_relocs.push(ExternalReloc {
291                            offset: *offset,
292                            symbol: target.clone(),
293                        });
294                    }
295                    _ => {
296                        return Err(RasError::EncodingError(format!(
297                            "Undefined label: {}",
298                            target
299                        )));
300                    }
301                }
302            }
303        }
304
305        let symbols = lines
306            .iter()
307            .filter_map(|l| match l {
308                Line::Label(s) => Some(ObjectSymbol {
309                    name: s.name.clone(),
310                    global: s.global,
311                    section: s.section.clone(),
312                    value: symbol_offsets.get(&s.name).copied().unwrap_or(0) as u64,
313                }),
314                _ => None,
315            })
316            .collect();
317
318        Ok((code, symbols, external_relocs))
319    }
320
321    /// Register a function pointer for runtime calls.
322    ///
323    /// Resolves the named symbol using `dlsym` (Unix) or `GetProcAddress` (Windows)
324    /// and stores its address for use in generated code.
325    pub fn register_function(&mut self, name: &str) -> Result<(), RasError> {
326        #[cfg(unix)]
327        {
328            use std::ffi::CString;
329
330            let symbol = CString::new(name)
331                .map_err(|e| RasError::EncodingError(format!("Invalid function name: {}", e)))?;
332
333            // Try to resolve using RTLD_DEFAULT first (searches already loaded libraries)
334            // This is safer and doesn't require opening/closing handles
335            let ptr = unsafe { libc::dlsym(libc::RTLD_DEFAULT, symbol.as_ptr()) };
336
337            if ptr.is_null() {
338                // Fallback: try opening libc explicitly
339                // Clear any previous error
340                unsafe {
341                    libc::dlerror();
342                }
343
344                let handle = unsafe { libc::dlopen(std::ptr::null(), libc::RTLD_LAZY) };
345                if handle.is_null() {
346                    let err_msg = unsafe {
347                        let err_ptr = libc::dlerror();
348                        if err_ptr.is_null() {
349                            "unknown error (dlerror returned null)"
350                        } else {
351                            std::ffi::CStr::from_ptr(err_ptr)
352                                .to_str()
353                                .unwrap_or("unknown error")
354                        }
355                    };
356                    return Err(RasError::EncodingError(format!(
357                        "Failed to open libc: {}",
358                        err_msg
359                    )));
360                }
361
362                // Clear error before dlsym
363                unsafe {
364                    libc::dlerror();
365                }
366
367                let ptr2 = unsafe { libc::dlsym(handle, symbol.as_ptr()) };
368                if ptr2.is_null() {
369                    let err_msg = unsafe {
370                        let err_ptr = libc::dlerror();
371                        if err_ptr.is_null() {
372                            "symbol not found"
373                        } else {
374                            std::ffi::CStr::from_ptr(err_ptr)
375                                .to_str()
376                                .unwrap_or("unknown error")
377                        }
378                    };
379                    unsafe { libc::dlclose(handle) };
380                    return Err(RasError::EncodingError(format!(
381                        "Failed to resolve symbol {}: {}",
382                        name, err_msg
383                    )));
384                }
385
386                self.function_pointers.insert(name.to_string(), ptr2 as u64);
387                unsafe { libc::dlclose(handle) };
388            } else {
389                self.function_pointers.insert(name.to_string(), ptr as u64);
390            }
391
392            Ok(())
393        }
394
395        #[cfg(windows)]
396        {
397            use std::ffi::CString;
398            use windows_loader::{GetModuleHandleA, GetProcAddress};
399
400            let module = unsafe { GetModuleHandleA(c"msvcrt.dll".as_ptr() as *const i8) };
401            if module.is_null() {
402                return Err(RasError::EncodingError(
403                    "Failed to get msvcrt.dll handle".to_string(),
404                ));
405            }
406
407            let symbol = CString::new(name)
408                .map_err(|e| RasError::EncodingError(format!("Invalid function name: {}", e)))?;
409
410            let ptr = unsafe { GetProcAddress(module, symbol.as_ptr()) };
411            if ptr.is_null() {
412                return Err(RasError::EncodingError(format!(
413                    "Failed to resolve symbol {}",
414                    name
415                )));
416            }
417
418            self.function_pointers.insert(name.to_string(), ptr as u64);
419            Ok(())
420        }
421
422        #[cfg(not(any(unix, windows)))]
423        {
424            Err(RasError::EncodingError(
425                "Runtime function resolution not supported on this platform".to_string(),
426            ))
427        }
428    }
429
430    /// Compile all functions in a MIR module to machine code and return the raw bytes.
431    ///
432    /// Equivalent to calling [`compile_mir_to_binary_function`] with `function_name = None`.
433    /// Requires the `encoder` feature.
434    ///
435    /// [`compile_mir_to_binary_function`]: Self::compile_mir_to_binary_function
436    #[cfg(feature = "encoder")]
437    pub fn compile_mir_to_binary(
438        &mut self,
439        module: &lamina_mir::Module,
440    ) -> Result<Vec<u8>, RasError> {
441        let (code, _) = self.compile_mir_to_binary_function(module, None)?;
442        Ok(code)
443    }
444
445    /// Compile a specific function from a MIR module to binary.
446    ///
447    /// If `function_name` is `None`, all functions in the module are compiled.
448    /// Returns `(binary_code, function_offsets)` where `function_offsets` maps
449    /// each function name to its byte offset within `binary_code`.
450    #[cfg(feature = "encoder")]
451    pub fn compile_mir_to_binary_function(
452        &mut self,
453        module: &lamina_mir::Module,
454        function_name: Option<&str>,
455    ) -> Result<(Vec<u8>, std::collections::HashMap<String, usize>), RasError> {
456        // Store module reference for checking internal vs external calls
457        self.current_module = Some(module);
458        // Reuse register allocation and ABI from mir_codegen
459        match self.target_arch {
460            TargetArchitecture::X86_64 => {
461                crate::assembler::x86_64::compile_mir_x86_64_function(self, module, function_name)
462            }
463            TargetArchitecture::Aarch64 => {
464                crate::assembler::aarch64::compile_mir_aarch64_function(self, module, function_name)
465            }
466            TargetArchitecture::Riscv64 => crate::assembler::riscv::compile_mir_riscv_function(
467                self,
468                module,
469                function_name,
470                true,
471            ),
472            TargetArchitecture::Riscv32 => crate::assembler::riscv::compile_mir_riscv_function(
473                self,
474                module,
475                function_name,
476                false,
477            ),
478            _ => Err(RasError::UnsupportedTarget(format!(
479                "MIR compilation not supported for architecture: {:?}",
480                self.target_arch
481            ))),
482        }
483    }
484}
485
486fn arx64_label_patch(
487    opcode: &str,
488    operands: &[String],
489    offset: usize,
490) -> Result<Option<PatchPoint>, RasError> {
491    match opcode {
492        "j" if operands.len() == 1 && !is_numeric(&operands[0]) => Ok(Some(PatchPoint::Arx64Jal {
493            offset,
494            target: operands[0].trim().to_string(),
495            rd: 0,
496        })),
497        "call" if operands.len() == 1 && !is_numeric(&operands[0]) => {
498            Ok(Some(PatchPoint::Arx64Jal {
499                offset,
500                target: operands[0].trim().to_string(),
501                rd: 1,
502            }))
503        }
504        "jal" if operands.len() == 2 && !is_numeric(&operands[1]) => {
505            Ok(Some(PatchPoint::Arx64Jal {
506                offset,
507                target: operands[1].trim().to_string(),
508                rd: parse_arx64_reg(&operands[0])?,
509            }))
510        }
511        "beq" | "bne" | "blt" | "bge" | "bltu" | "bgeu"
512            if operands.len() == 3 && !is_numeric(&operands[2]) =>
513        {
514            Ok(Some(PatchPoint::Arx64Branch {
515                offset,
516                target: operands[2].trim().to_string(),
517                rs1: parse_arx64_reg(&operands[0])?,
518                rs2: parse_arx64_reg(&operands[1])?,
519                funct3: match opcode {
520                    "beq" => 0x0,
521                    "bne" => 0x1,
522                    "blt" => 0x4,
523                    "bge" => 0x5,
524                    "bltu" => 0x6,
525                    "bgeu" => 0x7,
526                    _ => unreachable!(),
527                },
528            }))
529        }
530        _ => Ok(None),
531    }
532}
533
/// Report whether `value` is a numeric literal rather than a label name.
///
/// Surrounding whitespace is ignored. Accepted forms are a decimal `i64`
/// (optionally signed) and a hexadecimal literal with a `0x`/`0X` prefix,
/// optionally preceded by a single `+` or `-`. The digits after the prefix
/// must be plain hex digits, so `0x-5` is not numeric while `-0x10` is.
fn is_numeric(value: &str) -> bool {
    let value = value.trim();
    if value.parse::<i64>().is_ok() {
        return true;
    }

    // A sign belongs in front of the `0x` prefix, never after it.
    let unsigned = value
        .strip_prefix('-')
        .or_else(|| value.strip_prefix('+'))
        .unwrap_or(value);

    unsigned
        .strip_prefix("0x")
        .or_else(|| unsigned.strip_prefix("0X"))
        .is_some_and(|hex| {
            // `from_str_radix` would accept a leading sign, so check digits first.
            hex.bytes().all(|b| b.is_ascii_hexdigit()) && i64::from_str_radix(hex, 16).is_ok()
        })
}
542
543fn parse_arx64_reg(value: &str) -> Result<u8, RasError> {
544    let value = value.trim().trim_start_matches('%');
545    let raw = match value {
546        "zero" => 0,
547        "ra" | "lr" => 1,
548        "sp" => 2,
549        _ => value
550            .strip_prefix('r')
551            .or_else(|| value.strip_prefix('x'))
552            .ok_or_else(|| RasError::EncodingError(format!("Unknown ARX64 register: {}", value)))?
553            .parse::<u8>()
554            .map_err(|_| RasError::EncodingError(format!("Unknown ARX64 register: {}", value)))?,
555    };
556    if raw < 32 {
557        Ok(raw)
558    } else {
559        Err(RasError::EncodingError(format!(
560            "ARX64 register out of range: {}",
561            value
562        )))
563    }
564}
565
/// Build a J-type instruction word (opcode `0x6f`) from a byte `offset` and
/// the link register `rd`.
///
/// Offset bits are scattered as: bit 20 → word bit 31, bits 10..1 → 30..21,
/// bit 11 → 20, bits 19..12 → 19..12. Bit 0 of `offset` is not encoded.
fn arx64_j_type(offset: i32, rd: u8) -> u32 {
    let imm = offset as u32;
    // Take `width` bits of the offset starting at bit `from`, place them at bit `to`.
    let field = |from: u32, width: u32, to: u32| ((imm >> from) & ((1u32 << width) - 1)) << to;

    let immediate = field(20, 1, 31) | field(1, 10, 21) | field(11, 1, 20) | field(12, 8, 12);
    let link_register = u32::from(rd) << 7;

    immediate | link_register | 0x6f
}
575
/// Build a B-type instruction word (opcode `0x63`) from a byte `offset`, the
/// two source registers and the `funct3` comparison selector.
///
/// Offset bits are scattered as: bit 12 → word bit 31, bits 10..5 → 30..25,
/// bits 4..1 → 11..8, bit 11 → 7. Bit 0 of `offset` is not encoded.
fn arx64_b_type(offset: i32, rs2: u8, rs1: u8, funct3: u8) -> u32 {
    let imm = offset as u32;
    // Take `width` bits of the offset starting at bit `from`, place them at bit `to`.
    let field = |from: u32, width: u32, to: u32| ((imm >> from) & ((1u32 << width) - 1)) << to;

    let immediate = field(12, 1, 31) | field(5, 6, 25) | field(1, 4, 8) | field(11, 1, 7);
    let registers = (u32::from(rs2) << 20) | (u32::from(rs1) << 15);
    let comparison = u32::from(funct3) << 12;

    immediate | registers | comparison | 0x63
}
587
/// Look up the second opcode byte of the near conditional jump (`0F 8x rel32`)
/// for an x86 conditional-jump mnemonic.
///
/// The mnemonic must already be lowercase. Any other mnemonic yields `None`.
fn x86_cond_jmp_byte(opcode: &str) -> Option<u8> {
    // The low nibble is the condition code; the opcode byte is 0x80 | code.
    let condition_code: u8 = match opcode {
        "jo" => 0x0,
        "jno" => 0x1,
        "jb" | "jnae" | "jc" => 0x2,
        "jnb" | "jae" | "jnc" => 0x3,
        "je" | "jz" => 0x4,
        "jne" | "jnz" => 0x5,
        "jbe" | "jna" => 0x6,
        "ja" | "jnbe" => 0x7,
        "js" => 0x8,
        "jns" => 0x9,
        "jp" | "jpe" => 0xA,
        "jnp" | "jpo" => 0xB,
        "jl" | "jnge" => 0xC,
        "jge" | "jnl" => 0xD,
        "jle" | "jng" => 0xE,
        "jg" | "jnle" => 0xF,
        _ => return None,
    };
    Some(0x80 | condition_code)
}
611
/// Pull the label out of a RIP-relative operand of the form `label(%rip)`.
///
/// Returns `None` when the operand has no parenthesised base, when the base
/// is not `rip` (compared case-insensitively, leading `%` ignored), or when
/// the part before the parenthesis is empty or looks like a numeric
/// displacement (starts with a digit or `-`).
fn extract_rip_label(op: &str) -> Option<&str> {
    let (before_paren, after_paren) = op.split_once('(')?;
    // A `)` ahead of the first `(` means the parentheses are not a base register.
    if before_paren.contains(')') {
        return None;
    }
    let (inside, _rest) = after_paren.split_once(')')?;

    let base = inside.trim().trim_start_matches('%');
    if !base.eq_ignore_ascii_case("rip") {
        return None;
    }

    let label = before_paren.trim();
    match label.chars().next() {
        // Empty, or a numeric displacement rather than a label name.
        None => None,
        Some(first) if first.is_ascii_digit() || first == '-' => None,
        Some(_) => Some(label),
    }
}
630
631fn parse_x86_reg(s: &str) -> Result<u8, crate::error::RasError> {
632    let s = s.trim().trim_start_matches('%');
633    match s {
634        "rax" | "eax" => Ok(0),
635        "rcx" | "ecx" => Ok(1),
636        "rdx" | "edx" => Ok(2),
637        "rbx" | "ebx" => Ok(3),
638        "rsp" | "esp" => Ok(4),
639        "rbp" | "ebp" => Ok(5),
640        "rsi" | "esi" => Ok(6),
641        "rdi" | "edi" => Ok(7),
642        "r8" => Ok(8),
643        "r9" => Ok(9),
644        "r10" => Ok(10),
645        "r11" => Ok(11),
646        "r12" => Ok(12),
647        "r13" => Ok(13),
648        "r14" => Ok(14),
649        "r15" => Ok(15),
650        _ => Err(crate::error::RasError::EncodingError(format!(
651            "Unknown x86 register: {}",
652            s
653        ))),
654    }
655}