Skip to main content

pe_assembler/helpers/pe_writer/
mod.rs

1//! PE 文件写入器的通用 trait
2//!
3//! 此模块提供 PE 文件写入的通用接口,类似于 pe_reader 模块的设计
4
5use crate::types::{DataDirectory, DosHeader, NtHeader, OptionalHeader, PeProgram, PeSection, SubsystemType};
6use byteorder::{LittleEndian, WriteBytesExt};
7use gaia_types::GaiaError;
8use std::io::{Seek, Write};
9
10/// PE 文件写入器的通用 trait
11pub trait PeWriter<W: Write + Seek> {
12    /// 获取写入器的可变引用
13    fn get_writer(&mut self) -> &mut W;
14
15    /// 获取当前流位置
16    fn stream_position(&mut self) -> Result<u64, GaiaError> {
17        Ok(self.get_writer().stream_position()?)
18    }
19
20    /// 将 PE 程序写入流(通用实现)
21    fn write_program(&mut self, program: &PeProgram) -> Result<(), GaiaError> {
22        // 写入 DOS 头
23        self.write_dos_header(&program.header.dos_header)?;
24
25        // 写入 DOS stub(简单的 DOS 程序)
26        self.write_dos_stub()?;
27
28        // 对齐到 PE 头位置
29        let pe_header_offset = program.header.dos_header.e_lfanew as u64;
30        self.pad_to_offset(pe_header_offset)?;
31
32        // 写入 NT 头
33        self.write_nt_header(&program.header.nt_header)?;
34
35        // 写入 COFF 头
36        self.write_coff_header(&program.header.coff_header)?;
37
38        // 写入可选头
39        self.write_optional_header(&program.header.optional_header)?;
40
41        // 写入节头
42        for section in &program.sections {
43            self.write_section_header(section)?;
44        }
45
46        // 对齐整个头部区域到文件对齐边界
47        let file_alignment = program.header.optional_header.file_alignment;
48        self.align_to_boundary(file_alignment)?;
49
50        // 写入节数据
51        for section in &program.sections {
52            if section.name == ".idata" && !program.imports.entries.is_empty() {
53                // 对齐到节的文件偏移
54                self.pad_to_offset(section.pointer_to_raw_data as u64)?;
55                self.write_import_table(program)?;
56
57                // 对齐到文件对齐边界
58                self.align_to_boundary(file_alignment)?;
59            }
60            else if !section.data.is_empty() {
61                // 对齐到节的文件偏移
62                self.pad_to_offset(section.pointer_to_raw_data as u64)?;
63                self.get_writer().write_all(&section.data)?;
64
65                // 对齐到文件对齐边界
66                self.align_to_boundary(file_alignment)?;
67            }
68        }
69
70        Ok(())
71    }
72
73    /// 写入导入表(通用实现)
74    fn write_import_table(&mut self, program: &PeProgram) -> Result<(), GaiaError> {
75        let arch_magic = program.header.optional_header.magic;
76        let is_64 = arch_magic == 0x20b;
77        let pointer_size = if is_64 { 8 } else { 4 };
78
79        // 查找 .idata 节的 RVA
80        let idata_section = program
81            .sections
82            .iter()
83            .find(|s| s.name == ".idata")
84            .ok_or_else(|| GaiaError::syntax_error("Missing .idata section", gaia_types::SourceLocation::default()))?;
85        let import_rva_base = idata_section.virtual_address;
86
87        // 计算各个部分的 RVA
88        // 1. Import Directory Table (array of 20-byte descriptors, null-terminated)
89        let mut current_rva = import_rva_base + ((program.imports.entries.len() + 1) * 20) as u32;
90
91        // 2. DLL Names
92        let mut dll_name_rvas = Vec::new();
93        for entry in &program.imports.entries {
94            dll_name_rvas.push(current_rva);
95            current_rva += (entry.dll_name.len() as u32) + 1;
96        }
97        if current_rva % 2 != 0 {
98            current_rva += 1;
99        }
100
101        // 3. Hint/Name Table
102        let mut hint_name_rvas = Vec::new();
103        for entry in &program.imports.entries {
104            let mut entry_hint_name_rvas = Vec::new();
105            for func in &entry.functions {
106                // 名称按 2 字节对齐
107                if current_rva % 2 != 0 {
108                    current_rva += 1;
109                }
110                entry_hint_name_rvas.push(current_rva);
111                current_rva += 2 + (func.len() as u32) + 1;
112            }
113            hint_name_rvas.push(entry_hint_name_rvas);
114        }
115
116        // 4. INT (Import Lookup Table)
117        if current_rva % pointer_size != 0 {
118            current_rva = (current_rva + pointer_size - 1) & !(pointer_size - 1);
119        }
120        let mut int_rvas = Vec::new();
121        for entry in &program.imports.entries {
122            int_rvas.push(current_rva);
123            current_rva += ((entry.functions.len() + 1) as u32) * pointer_size;
124        }
125
126        // 5. IAT (Import Address Table)
127        if current_rva % pointer_size != 0 {
128            current_rva = (current_rva + pointer_size - 1) & !(pointer_size - 1);
129        }
130        let mut iat_rvas = Vec::new();
131        for entry in &program.imports.entries {
132            iat_rvas.push(current_rva);
133            current_rva += ((entry.functions.len() + 1) as u32) * pointer_size;
134        }
135
136        // 写入数据
137        // 1. 写入 Descriptors
138        for i in 0..program.imports.entries.len() {
139            let writer = self.get_writer();
140            writer.write_u32::<LittleEndian>(int_rvas[i])?; // OriginalFirstThunk (INT)
141            writer.write_u32::<LittleEndian>(0)?; // TimeDateStamp
142            writer.write_u32::<LittleEndian>(0)?; // ForwarderChain
143            writer.write_u32::<LittleEndian>(dll_name_rvas[i])?; // Name RVA
144            writer.write_u32::<LittleEndian>(iat_rvas[i])?; // FirstThunk (IAT)
145        }
146        // Null terminator
147        {
148            let writer = self.get_writer();
149            for _ in 0..5 {
150                writer.write_u32::<LittleEndian>(0)?;
151            }
152        }
153
154        // 2. 写入 DLL Names
155        for entry in &program.imports.entries {
156            let writer = self.get_writer();
157            writer.write_all(entry.dll_name.as_bytes())?;
158            writer.write_u8(0)?;
159        }
160        {
161            let current_pos = self.stream_position()?;
162            if (current_pos - idata_section.pointer_to_raw_data as u64) % 2 != 0 {
163                self.get_writer().write_u8(0)?;
164            }
165        }
166
167        // 3. 写入 Hint/Name Table
168        for entry in &program.imports.entries {
169            for func in &entry.functions {
170                // 对齐到 2 字节
171                let current_pos = self.stream_position()?;
172                if (current_pos - idata_section.pointer_to_raw_data as u64) % 2 != 0 {
173                    self.get_writer().write_u8(0)?;
174                }
175
176                let writer = self.get_writer();
177                writer.write_u16::<LittleEndian>(0)?; // Hint
178                writer.write_all(func.as_bytes())?;
179                writer.write_u8(0)?;
180            }
181        }
182
183        // 4. 写入 INT
184        self.pad_to_offset(idata_section.pointer_to_raw_data as u64 + (int_rvas[0] - import_rva_base) as u64)?;
185        for (i, entry) in program.imports.entries.iter().enumerate() {
186            for j in 0..entry.functions.len() {
187                let writer = self.get_writer();
188                if is_64 {
189                    writer.write_u64::<LittleEndian>(hint_name_rvas[i][j] as u64)?;
190                }
191                else {
192                    writer.write_u32::<LittleEndian>(hint_name_rvas[i][j])?;
193                }
194            }
195            let writer = self.get_writer();
196            if is_64 {
197                writer.write_u64::<LittleEndian>(0)?;
198            }
199            else {
200                writer.write_u32::<LittleEndian>(0)?;
201            }
202        }
203
204        // 5. 写入 IAT
205        self.pad_to_offset(idata_section.pointer_to_raw_data as u64 + (iat_rvas[0] - import_rva_base) as u64)?;
206        for (i, entry) in program.imports.entries.iter().enumerate() {
207            for j in 0..entry.functions.len() {
208                // 兼容模式:IAT 初始填入名称 RVA
209                let writer = self.get_writer();
210                if is_64 {
211                    writer.write_u64::<LittleEndian>(hint_name_rvas[i][j] as u64)?;
212                }
213                else {
214                    writer.write_u32::<LittleEndian>(hint_name_rvas[i][j])?;
215                }
216            }
217            let writer = self.get_writer();
218            if is_64 {
219                writer.write_u64::<LittleEndian>(0)?;
220            }
221            else {
222                writer.write_u32::<LittleEndian>(0)?;
223            }
224        }
225
226        Ok(())
227    }
228
229    /// 写入 DOS 头(通用实现)
230    fn write_dos_header(&mut self, dos_header: &DosHeader) -> Result<(), GaiaError> {
231        let writer = self.get_writer();
232        writer.write_u16::<LittleEndian>(dos_header.e_magic)?;
233        writer.write_u16::<LittleEndian>(dos_header.e_cblp)?;
234        writer.write_u16::<LittleEndian>(dos_header.e_cp)?;
235        writer.write_u16::<LittleEndian>(dos_header.e_crlc)?;
236        writer.write_u16::<LittleEndian>(dos_header.e_cparhdr)?;
237        writer.write_u16::<LittleEndian>(dos_header.e_min_allocate)?;
238        writer.write_u16::<LittleEndian>(dos_header.e_max_allocate)?;
239        writer.write_u16::<LittleEndian>(dos_header.e_ss)?;
240        writer.write_u16::<LittleEndian>(dos_header.e_sp)?;
241        writer.write_u16::<LittleEndian>(dos_header.e_check_sum)?;
242        writer.write_u16::<LittleEndian>(dos_header.e_ip)?;
243        writer.write_u16::<LittleEndian>(dos_header.e_cs)?;
244        writer.write_u16::<LittleEndian>(dos_header.e_lfarlc)?;
245        writer.write_u16::<LittleEndian>(dos_header.e_ovno)?;
246        for &res in &dos_header.e_res {
247            writer.write_u16::<LittleEndian>(res)?;
248        }
249        writer.write_u16::<LittleEndian>(dos_header.e_oem_id)?;
250        writer.write_u16::<LittleEndian>(dos_header.e_oem_info)?;
251        for &res in &dos_header.e_res2 {
252            writer.write_u16::<LittleEndian>(res)?;
253        }
254        writer.write_u32::<LittleEndian>(dos_header.e_lfanew)?;
255        Ok(())
256    }
257
258    /// 写入 DOS stub(通用实现)
259    fn write_dos_stub(&mut self) -> Result<(), GaiaError> {
260        // 简单的 DOS stub 程序
261        let dos_stub = b"This program cannot be run in DOS mode.\r\n$";
262        self.get_writer().write_all(dos_stub)?;
263        // 填充到 PE 头位置
264        while self.stream_position()? < 0x80 {
265            self.get_writer().write_u8(0)?;
266        }
267        Ok(())
268    }
269
270    /// 写入 NT 头(通用实现)
271    fn write_nt_header(&mut self, nt_header: &NtHeader) -> Result<(), GaiaError> {
272        self.get_writer().write_u32::<LittleEndian>(nt_header.signature)?;
273        Ok(())
274    }
275
276    /// 写入 COFF 头(通用实现)
277    fn write_coff_header(&mut self, coff_header: &crate::types::coff::CoffHeader) -> Result<(), GaiaError> {
278        let writer = self.get_writer();
279        writer.write_u16::<LittleEndian>(coff_header.machine)?;
280        writer.write_u16::<LittleEndian>(coff_header.number_of_sections)?;
281        writer.write_u32::<LittleEndian>(coff_header.time_date_stamp)?;
282        writer.write_u32::<LittleEndian>(coff_header.pointer_to_symbol_table)?;
283        writer.write_u32::<LittleEndian>(coff_header.number_of_symbols)?;
284        writer.write_u16::<LittleEndian>(coff_header.size_of_optional_header)?;
285        writer.write_u16::<LittleEndian>(coff_header.characteristics)?;
286        Ok(())
287    }
288
289    /// 写入可选头(通用实现)
290    fn write_optional_header(&mut self, optional_header: &OptionalHeader) -> Result<(), GaiaError> {
291        let writer = self.get_writer();
292        writer.write_u16::<LittleEndian>(optional_header.magic)?;
293        writer.write_u8(optional_header.major_linker_version)?;
294        writer.write_u8(optional_header.minor_linker_version)?;
295        writer.write_u32::<LittleEndian>(optional_header.size_of_code)?;
296        writer.write_u32::<LittleEndian>(optional_header.size_of_initialized_data)?;
297        writer.write_u32::<LittleEndian>(optional_header.size_of_uninitialized_data)?;
298        writer.write_u32::<LittleEndian>(optional_header.address_of_entry_point)?;
299        writer.write_u32::<LittleEndian>(optional_header.base_of_code)?;
300
301        // 根据架构写入不同的字段
302        if optional_header.magic == 0x20b {
303            // PE32+
304            writer.write_u64::<LittleEndian>(optional_header.image_base)?;
305        }
306        else {
307            // PE32
308            let base_of_data = optional_header.base_of_data.unwrap_or(0);
309            writer.write_u32::<LittleEndian>(base_of_data)?;
310            writer.write_u32::<LittleEndian>(optional_header.image_base as u32)?;
311        }
312
313        writer.write_u32::<LittleEndian>(optional_header.section_alignment)?;
314        writer.write_u32::<LittleEndian>(optional_header.file_alignment)?;
315        writer.write_u16::<LittleEndian>(optional_header.major_operating_system_version)?;
316        writer.write_u16::<LittleEndian>(optional_header.minor_operating_system_version)?;
317        writer.write_u16::<LittleEndian>(optional_header.major_image_version)?;
318        writer.write_u16::<LittleEndian>(optional_header.minor_image_version)?;
319        writer.write_u16::<LittleEndian>(optional_header.major_subsystem_version)?;
320        writer.write_u16::<LittleEndian>(optional_header.minor_subsystem_version)?;
321        writer.write_u32::<LittleEndian>(optional_header.win32_version_value)?;
322        writer.write_u32::<LittleEndian>(optional_header.size_of_image)?;
323        writer.write_u32::<LittleEndian>(optional_header.size_of_headers)?;
324        writer.write_u32::<LittleEndian>(optional_header.checksum)?;
325
326        // 写入子系统
327        let subsystem_value = match optional_header.subsystem {
328            SubsystemType::Console => 3,
329            SubsystemType::Windows => 2,
330            SubsystemType::Native => 1,
331            _ => 3, // 默认为控制台
332        };
333        writer.write_u16::<LittleEndian>(subsystem_value)?;
334
335        writer.write_u16::<LittleEndian>(optional_header.dll_characteristics)?;
336
337        // 根据架构写入不同大小的字段
338        if optional_header.magic == 0x20b {
339            // PE32+
340            writer.write_u64::<LittleEndian>(optional_header.size_of_stack_reserve)?;
341            writer.write_u64::<LittleEndian>(optional_header.size_of_stack_commit)?;
342            writer.write_u64::<LittleEndian>(optional_header.size_of_heap_reserve)?;
343            writer.write_u64::<LittleEndian>(optional_header.size_of_heap_commit)?;
344        }
345        else {
346            // PE32
347            writer.write_u32::<LittleEndian>(optional_header.size_of_stack_reserve as u32)?;
348            writer.write_u32::<LittleEndian>(optional_header.size_of_stack_commit as u32)?;
349            writer.write_u32::<LittleEndian>(optional_header.size_of_heap_reserve as u32)?;
350            writer.write_u32::<LittleEndian>(optional_header.size_of_heap_commit as u32)?;
351        }
352
353        writer.write_u32::<LittleEndian>(optional_header.loader_flags)?;
354        writer.write_u32::<LittleEndian>(optional_header.number_of_rva_and_sizes)?;
355
356        // 写入数据目录
357        for data_dir in &optional_header.data_directories {
358            self.write_data_directory(data_dir)?;
359        }
360
361        Ok(())
362    }
363
364    /// 写入数据目录(通用实现)
365    fn write_data_directory(&mut self, data_dir: &DataDirectory) -> Result<(), GaiaError> {
366        let writer = self.get_writer();
367        writer.write_u32::<LittleEndian>(data_dir.virtual_address)?;
368        writer.write_u32::<LittleEndian>(data_dir.size)?;
369        Ok(())
370    }
371
372    /// 写入节头(通用实现)
373    fn write_section_header(&mut self, section: &PeSection) -> Result<(), GaiaError> {
374        let writer = self.get_writer();
375        // 写入节名(8字节,不足补0)
376        let mut name_bytes = [0u8; 8];
377        let name_len = section.name.len().min(8);
378        name_bytes[..name_len].copy_from_slice(&section.name.as_bytes()[..name_len]);
379        writer.write_all(&name_bytes)?;
380
381        writer.write_u32::<LittleEndian>(section.virtual_size)?;
382        writer.write_u32::<LittleEndian>(section.virtual_address)?;
383        writer.write_u32::<LittleEndian>(section.size_of_raw_data)?;
384        writer.write_u32::<LittleEndian>(section.pointer_to_raw_data)?;
385        writer.write_u32::<LittleEndian>(section.pointer_to_relocations)?;
386        writer.write_u32::<LittleEndian>(section.pointer_to_line_numbers)?;
387        writer.write_u16::<LittleEndian>(section.number_of_relocations)?;
388        writer.write_u16::<LittleEndian>(section.number_of_line_numbers)?;
389        writer.write_u32::<LittleEndian>(section.characteristics)?;
390        Ok(())
391    }
392
393    /// 填充到指定偏移(通用实现)
394    fn pad_to_offset(&mut self, target_offset: u64) -> Result<(), GaiaError> {
395        let current_pos = self.stream_position()?;
396        if current_pos < target_offset {
397            let padding_size = target_offset - current_pos;
398            for _ in 0..padding_size {
399                self.get_writer().write_u8(0)?;
400            }
401        }
402        Ok(())
403    }
404
405    /// 对齐到边界(通用实现)
406    fn align_to_boundary(&mut self, alignment: u32) -> Result<(), GaiaError> {
407        let current_pos = self.stream_position()?;
408        let remainder = current_pos % alignment as u64;
409        if remainder != 0 {
410            let padding = alignment as u64 - remainder;
411            for _ in 0..padding {
412                self.get_writer().write_u8(0)?;
413            }
414        }
415        Ok(())
416    }
417}