// File: pe_assembler/helpers/pe_writer/mod.rs

//! Generic trait for PE file writers.
//!
//! This module provides a generic interface for writing PE files, mirroring the design of the
//! `pe_reader` module.

use std::io::{self, Seek, SeekFrom, Write};

use byteorder::{LittleEndian, WriteBytesExt};
use gaia_types::GaiaError;

use crate::types::{DataDirectory, DosHeader, NtHeader, OptionalHeader, PeProgram, PeSection, SubsystemType};

10/// PE 文件写入器的通用 trait
11pub trait PeWriter<W: Write + Seek> {
12    /// 获取写入器的可变引用
13    fn get_writer(&mut self) -> &mut W;
14
15    /// 获取当前流位置
16    fn stream_position(&mut self) -> Result<u64, GaiaError> {
17        Ok(self.get_writer().stream_position()?)
18    }
19
20    /// 将 PE 程序写入流(通用实现)
21    fn write_program(&mut self, program: &PeProgram) -> Result<(), GaiaError> {
22        // 写入 DOS 头
23        self.write_dos_header(&program.header.dos_header)?;
24
25        // 写入 DOS stub(简单的 DOS 程序)
26        self.write_dos_stub()?;
27
28        // 对齐到 PE 头位置
29        let pe_header_offset = program.header.dos_header.e_lfanew as u64;
30        self.pad_to_offset(pe_header_offset)?;
31
32        // 写入 NT 头
33        self.write_nt_header(&program.header.nt_header)?;
34
35        // 写入 COFF 头
36        self.write_coff_header(&program.header.coff_header)?;
37
38        // 写入可选头
39        self.write_optional_header(&program.header.optional_header)?;
40
41        // 写入节头
42        for section in &program.sections {
43            self.write_section_header(section)?;
44        }
45
46        // 对齐整个头部区域到文件对齐边界
47        let file_alignment = program.header.optional_header.file_alignment;
48        self.align_to_boundary(file_alignment)?;
49
50        // 写入节数据
51        for section in &program.sections {
52            if !section.data.is_empty() {
53                // 对齐到节的文件偏移
54                self.pad_to_offset(section.pointer_to_raw_data as u64)?;
55                self.get_writer().write_all(&section.data)?;
56
57                // 对齐到文件对齐边界
58                self.align_to_boundary(file_alignment)?;
59            }
60        }
61
62        // 写入导入表(如果存在)
63        let pointer_size: usize = if program.header.optional_header.magic == 0x020B { 8 } else { 4 };
64        self.write_import_table(&program.imports, &program.sections, pointer_size)?;
65
66        Ok(())
67    }
68
69    /// 写入 DOS 头(通用实现)
70    fn write_dos_header(&mut self, dos_header: &DosHeader) -> Result<(), GaiaError> {
71        let writer = self.get_writer();
72        writer.write_u16::<LittleEndian>(dos_header.e_magic)?;
73        writer.write_u16::<LittleEndian>(dos_header.e_cblp)?;
74        writer.write_u16::<LittleEndian>(dos_header.e_cp)?;
75        writer.write_u16::<LittleEndian>(dos_header.e_crlc)?;
76        writer.write_u16::<LittleEndian>(dos_header.e_cparhdr)?;
77        writer.write_u16::<LittleEndian>(dos_header.e_min_allocate)?;
78        writer.write_u16::<LittleEndian>(dos_header.e_max_allocate)?;
79        writer.write_u16::<LittleEndian>(dos_header.e_ss)?;
80        writer.write_u16::<LittleEndian>(dos_header.e_sp)?;
81        writer.write_u16::<LittleEndian>(dos_header.e_check_sum)?;
82        writer.write_u16::<LittleEndian>(dos_header.e_ip)?;
83        writer.write_u16::<LittleEndian>(dos_header.e_cs)?;
84        writer.write_u16::<LittleEndian>(dos_header.e_lfarlc)?;
85        writer.write_u16::<LittleEndian>(dos_header.e_ovno)?;
86        for &res in &dos_header.e_res {
87            writer.write_u16::<LittleEndian>(res)?;
88        }
89        writer.write_u16::<LittleEndian>(dos_header.e_oem_id)?;
90        writer.write_u16::<LittleEndian>(dos_header.e_oem_info)?;
91        for &res in &dos_header.e_res2 {
92            writer.write_u16::<LittleEndian>(res)?;
93        }
94        writer.write_u32::<LittleEndian>(dos_header.e_lfanew)?;
95        Ok(())
96    }
97
98    /// 写入 DOS stub(通用实现)
99    fn write_dos_stub(&mut self) -> Result<(), GaiaError> {
100        // 简单的 DOS stub 程序
101        let dos_stub = b"This program cannot be run in DOS mode.\r\n$";
102        self.get_writer().write_all(dos_stub)?;
103        // 填充到 PE 头位置
104        while self.stream_position()? < 0x80 {
105            self.get_writer().write_u8(0)?;
106        }
107        Ok(())
108    }
109
110    /// 写入 NT 头(通用实现)
111    fn write_nt_header(&mut self, nt_header: &NtHeader) -> Result<(), GaiaError> {
112        self.get_writer().write_u32::<LittleEndian>(nt_header.signature)?;
113        Ok(())
114    }
115
116    /// 写入 COFF 头(通用实现)
117    fn write_coff_header(&mut self, coff_header: &crate::types::coff::CoffHeader) -> Result<(), GaiaError> {
118        let writer = self.get_writer();
119        writer.write_u16::<LittleEndian>(coff_header.machine)?;
120        writer.write_u16::<LittleEndian>(coff_header.number_of_sections)?;
121        writer.write_u32::<LittleEndian>(coff_header.time_date_stamp)?;
122        writer.write_u32::<LittleEndian>(coff_header.pointer_to_symbol_table)?;
123        writer.write_u32::<LittleEndian>(coff_header.number_of_symbols)?;
124        writer.write_u16::<LittleEndian>(coff_header.size_of_optional_header)?;
125        writer.write_u16::<LittleEndian>(coff_header.characteristics)?;
126        Ok(())
127    }
128
129    /// 写入可选头(通用实现)
130    fn write_optional_header(&mut self, optional_header: &OptionalHeader) -> Result<(), GaiaError> {
131        let writer = self.get_writer();
132        writer.write_u16::<LittleEndian>(optional_header.magic)?;
133        writer.write_u8(optional_header.major_linker_version)?;
134        writer.write_u8(optional_header.minor_linker_version)?;
135        writer.write_u32::<LittleEndian>(optional_header.size_of_code)?;
136        writer.write_u32::<LittleEndian>(optional_header.size_of_initialized_data)?;
137        writer.write_u32::<LittleEndian>(optional_header.size_of_uninitialized_data)?;
138        writer.write_u32::<LittleEndian>(optional_header.address_of_entry_point)?;
139        writer.write_u32::<LittleEndian>(optional_header.base_of_code)?;
140
141        // 根据架构写入不同的字段
142        if optional_header.magic == 0x20b {
143            // PE32+
144            writer.write_u64::<LittleEndian>(optional_header.image_base)?;
145        }
146        else {
147            // PE32
148            let base_of_data = optional_header.base_of_data.unwrap_or(0);
149            writer.write_u32::<LittleEndian>(base_of_data)?;
150            writer.write_u32::<LittleEndian>(optional_header.image_base as u32)?;
151        }
152
153        writer.write_u32::<LittleEndian>(optional_header.section_alignment)?;
154        writer.write_u32::<LittleEndian>(optional_header.file_alignment)?;
155        writer.write_u16::<LittleEndian>(optional_header.major_operating_system_version)?;
156        writer.write_u16::<LittleEndian>(optional_header.minor_operating_system_version)?;
157        writer.write_u16::<LittleEndian>(optional_header.major_image_version)?;
158        writer.write_u16::<LittleEndian>(optional_header.minor_image_version)?;
159        writer.write_u16::<LittleEndian>(optional_header.major_subsystem_version)?;
160        writer.write_u16::<LittleEndian>(optional_header.minor_subsystem_version)?;
161        writer.write_u32::<LittleEndian>(optional_header.win32_version_value)?;
162        writer.write_u32::<LittleEndian>(optional_header.size_of_image)?;
163        writer.write_u32::<LittleEndian>(optional_header.size_of_headers)?;
164        writer.write_u32::<LittleEndian>(optional_header.checksum)?;
165
166        // 写入子系统
167        let subsystem_value = match optional_header.subsystem {
168            SubsystemType::Console => 3,
169            SubsystemType::Windows => 2,
170            SubsystemType::Native => 1,
171            _ => 3, // 默认为控制台
172        };
173        writer.write_u16::<LittleEndian>(subsystem_value)?;
174
175        writer.write_u16::<LittleEndian>(optional_header.dll_characteristics)?;
176
177        // 根据架构写入不同大小的字段
178        if optional_header.magic == 0x20b {
179            // PE32+
180            writer.write_u64::<LittleEndian>(optional_header.size_of_stack_reserve)?;
181            writer.write_u64::<LittleEndian>(optional_header.size_of_stack_commit)?;
182            writer.write_u64::<LittleEndian>(optional_header.size_of_heap_reserve)?;
183            writer.write_u64::<LittleEndian>(optional_header.size_of_heap_commit)?;
184        }
185        else {
186            // PE32
187            writer.write_u32::<LittleEndian>(optional_header.size_of_stack_reserve as u32)?;
188            writer.write_u32::<LittleEndian>(optional_header.size_of_stack_commit as u32)?;
189            writer.write_u32::<LittleEndian>(optional_header.size_of_heap_reserve as u32)?;
190            writer.write_u32::<LittleEndian>(optional_header.size_of_heap_commit as u32)?;
191        }
192
193        writer.write_u32::<LittleEndian>(optional_header.loader_flags)?;
194        writer.write_u32::<LittleEndian>(optional_header.number_of_rva_and_sizes)?;
195
196        // 写入数据目录
197        for data_dir in &optional_header.data_directories {
198            self.write_data_directory(data_dir)?;
199        }
200
201        Ok(())
202    }
203
204    /// 写入数据目录(通用实现)
205    fn write_data_directory(&mut self, data_dir: &DataDirectory) -> Result<(), GaiaError> {
206        let writer = self.get_writer();
207        writer.write_u32::<LittleEndian>(data_dir.virtual_address)?;
208        writer.write_u32::<LittleEndian>(data_dir.size)?;
209        Ok(())
210    }
211
212    /// 写入节头(通用实现)
213    fn write_section_header(&mut self, section: &PeSection) -> Result<(), GaiaError> {
214        let writer = self.get_writer();
215        // 写入节名(8字节,不足补0)
216        let mut name_bytes = [0u8; 8];
217        let name_len = section.name.len().min(8);
218        name_bytes[..name_len].copy_from_slice(&section.name.as_bytes()[..name_len]);
219        writer.write_all(&name_bytes)?;
220
221        writer.write_u32::<LittleEndian>(section.virtual_size)?;
222        writer.write_u32::<LittleEndian>(section.virtual_address)?;
223        writer.write_u32::<LittleEndian>(section.size_of_raw_data)?;
224        writer.write_u32::<LittleEndian>(section.pointer_to_raw_data)?;
225        writer.write_u32::<LittleEndian>(section.pointer_to_relocations)?;
226        writer.write_u32::<LittleEndian>(section.pointer_to_line_numbers)?;
227        writer.write_u16::<LittleEndian>(section.number_of_relocations)?;
228        writer.write_u16::<LittleEndian>(section.number_of_line_numbers)?;
229        writer.write_u32::<LittleEndian>(section.characteristics)?;
230        Ok(())
231    }
232
233    /// 填充到指定偏移(通用实现)
234    fn pad_to_offset(&mut self, target_offset: u64) -> Result<(), GaiaError> {
235        let current_pos = self.stream_position()?;
236        if current_pos < target_offset {
237            let padding_size = target_offset - current_pos;
238            for _ in 0..padding_size {
239                self.get_writer().write_u8(0)?;
240            }
241        }
242        Ok(())
243    }
244
245    /// 对齐到边界(通用实现)
246    fn align_to_boundary(&mut self, alignment: u32) -> Result<(), GaiaError> {
247        let current_pos = self.stream_position()?;
248        let remainder = current_pos % alignment as u64;
249        if remainder != 0 {
250            let padding = alignment as u64 - remainder;
251            for _ in 0..padding {
252                self.get_writer().write_u8(0)?;
253            }
254        }
255        Ok(())
256    }
257
258    /// 写入导入表(通用实现)
259    fn write_import_table(
260        &mut self,
261        imports: &crate::types::tables::ImportTable,
262        sections: &[PeSection],
263        pointer_size: usize,
264    ) -> Result<(), GaiaError> {
265        // 如果没有导入,直接返回
266        if imports.entries.is_empty() {
267            return Ok(());
268        }
269
270        // 查找 .idata 节
271        let idata_section = sections.iter().find(|s| s.name == ".idata");
272        if let Some(section) = idata_section {
273            // 移动到 .idata 节的文件偏移
274            self.pad_to_offset(section.pointer_to_raw_data as u64)?;
275
276            let base_rva = section.virtual_address;
277            let mut current_rva = base_rva + ((imports.entries.len() + 1) * 20) as u32; // 跳过导入描述符表
278
279            // 计算 DLL 名称的 RVA
280            let mut dll_name_rvas = Vec::new();
281            for entry in &imports.entries {
282                dll_name_rvas.push(current_rva);
283                current_rva += (entry.dll_name.len() + 1) as u32; // 包括空终止符
284            }
285
286            // 对齐到 2 字节边界(名称通常按字对齐)
287            if current_rva % 2 != 0 {
288                current_rva += 1;
289            }
290
291            // 计算函数名称的 RVA(Hint+Name)
292            let mut function_name_rvas = Vec::new();
293            for entry in &imports.entries {
294                let mut entry_function_rvas = Vec::new();
295                for function in &entry.functions {
296                    entry_function_rvas.push(current_rva);
297                    current_rva += (2 + function.len() + 1) as u32; // Hint(2字节) + 函数名 + 空终止符
298                }
299                function_name_rvas.push(entry_function_rvas);
300            }
301
302            // 对齐到 2 字节边界(IMAGE_IMPORT_BY_NAME 推荐字对齐)
303            if current_rva % 2 != 0 {
304                current_rva += 1;
305            }
306
307            // 计算 INT 与 IAT 的 RVA(在函数名称之后)
308            // 先分配 INT(OriginalFirstThunk 指向的名称指针数组),再分配 IAT(FirstThunk 指向的地址数组)
309            // 对齐到 pointer_size 字节边界
310            if current_rva % (pointer_size as u32) != 0 {
311                current_rva = (current_rva + (pointer_size as u32) - 1) & !((pointer_size as u32) - 1);
312            }
313
314            let mut int_rvas = Vec::new();
315            for entry in &imports.entries {
316                int_rvas.push(current_rva);
317                current_rva += (entry.functions.len() as u32) * (pointer_size as u32) + (pointer_size as u32);
318            }
319
320            // 对齐到 pointer_size 字节边界
321            if current_rva % (pointer_size as u32) != 0 {
322                current_rva = (current_rva + (pointer_size as u32) - 1) & !((pointer_size as u32) - 1);
323            }
324
325            let mut iat_rvas = Vec::new();
326            for entry in &imports.entries {
327                iat_rvas.push(current_rva);
328                current_rva += (entry.functions.len() as u32) * (pointer_size as u32) + (pointer_size as u32);
329            }
330
331            // 写入导入描述符表
332            for (i, _entry) in imports.entries.iter().enumerate() {
333                let writer = self.get_writer();
334                // x64:使用经典布局(OFT 指向 INT,IAT 初始为 0)
335                // x86:为了兼容更多加载器,采用 OFT=0(不使用 INT),IAT 初始填入 Hint/Name 的 RVA
336                if pointer_size == 8 {
337                    writer.write_u32::<LittleEndian>(int_rvas[i])?; // OriginalFirstThunk (INT)
338                }
339                else {
340                    writer.write_u32::<LittleEndian>(0)?; // OriginalFirstThunk = 0(使用 IAT 作为查找表)
341                }
342                writer.write_u32::<LittleEndian>(0)?; // TimeDateStamp
343                writer.write_u32::<LittleEndian>(0)?; // ForwarderChain
344                writer.write_u32::<LittleEndian>(dll_name_rvas[i])?; // Name RVA
345                writer.write_u32::<LittleEndian>(iat_rvas[i])?; // FirstThunk 指向 IAT(地址数组)
346            }
347
348            // 写入终止符(全零的导入描述符)
349            {
350                let writer = self.get_writer();
351                for _ in 0..5 {
352                    writer.write_u32::<LittleEndian>(0)?;
353                }
354            }
355
356            // 写入 DLL 名称字符串
357            for entry in &imports.entries {
358                let writer = self.get_writer();
359                writer.write_all(entry.dll_name.as_bytes())?;
360                writer.write_u8(0)?; // 空终止符
361            }
362
363            // 按 2 字节对齐
364            if self.stream_position()? % 2 != 0 {
365                self.get_writer().write_u8(0)?;
366            }
367
368            // 写入函数名称字符串
369            for (_i, entry) in imports.entries.iter().enumerate() {
370                for (_j, function) in entry.functions.iter().enumerate() {
371                    let writer = self.get_writer();
372                    writer.write_u16::<LittleEndian>(0)?; // Hint
373                    writer.write_all(function.as_bytes())?;
374                    writer.write_u8(0)?; // 空终止符
375                }
376            }
377
378            // 按 2 字节对齐
379            if self.stream_position()? % 2 != 0 {
380                self.get_writer().write_u8(0)?;
381            }
382
383            // 写入 INT(OriginalFirstThunk 指向的名称指针数组)
384            // 按 pointer_size 字节对齐
385            while self.stream_position()? % (pointer_size as u64) != 0 {
386                self.get_writer().write_u8(0)?;
387            }
388            for (i, entry) in imports.entries.iter().enumerate() {
389                for j in 0..entry.functions.len() {
390                    let writer = self.get_writer();
391                    if pointer_size == 8 {
392                        writer.write_u64::<LittleEndian>(function_name_rvas[i][j] as u64)?;
393                    }
394                    else {
395                        writer.write_u32::<LittleEndian>(function_name_rvas[i][j])?;
396                    }
397                }
398                let writer = self.get_writer();
399                if pointer_size == 8 {
400                    writer.write_u64::<LittleEndian>(0)?;
401                }
402                else {
403                    writer.write_u32::<LittleEndian>(0)?;
404                }
405            }
406
407            // 写入 IAT(FirstThunk 指向的地址数组)。
408            // 经典模式下 IAT 初始为 0,加载器解析后填入实际地址。
409            // 按 pointer_size 字节对齐
410            while self.stream_position()? % (pointer_size as u64) != 0 {
411                self.get_writer().write_u8(0)?;
412            }
413            for (i, entry) in imports.entries.iter().enumerate() {
414                for _j in 0..entry.functions.len() {
415                    let writer = self.get_writer();
416                    if pointer_size == 8 {
417                        // x64:IAT 初始填入 Hint/Name 的 RVA(与 x86 一致)
418                        // 加载器解析后会覆盖为实际地址。
419                        writer.write_u64::<LittleEndian>(function_name_rvas[i][_j] as u64)?;
420                    }
421                    else {
422                        // x86:IAT 初始填入 Hint/Name 的 RVA,加载器解析后覆盖
423                        writer.write_u32::<LittleEndian>(function_name_rvas[i][_j])?;
424                    }
425                }
426                // 终止符 0
427                let writer = self.get_writer();
428                if pointer_size == 8 {
429                    writer.write_u64::<LittleEndian>(0)?;
430                }
431                else {
432                    writer.write_u32::<LittleEndian>(0)?;
433                }
434            }
435
436            // 对齐到节的大小
437            self.align_to_boundary(section.size_of_raw_data)?;
438        }
439
440        Ok(())
441    }
442}