// pe_assembler/helpers/pe_writer/mod.rs

//! PE 文件写入器的通用 trait
//!
//! 此模块提供 PE 文件写入的通用接口,类似于 pe_reader 模块的设计

use crate::types::{DataDirectory, DosHeader, NtHeader, OptionalHeader, PeHeader, PeProgram, PeSection, SubsystemType};
use byteorder::{LittleEndian, WriteBytesExt};
use gaia_types::GaiaError;
use std::io::{Seek, SeekFrom, Write};
9
10/// PE 文件写入器的通用 trait
11pub trait PeWriter<W: Write + Seek> {
12    /// 获取写入器的可变引用
13    fn get_writer(&mut self) -> &mut W;
14
15    /// 获取当前流位置
16    fn stream_position(&mut self) -> Result<u64, GaiaError> {
17        Ok(self.get_writer().stream_position()?)
18    }
19
20    /// 将 PE 程序写入流(通用实现)
21    fn write_program(&mut self, program: &PeProgram) -> Result<(), GaiaError> {
22        // 写入 DOS 头
23        self.write_dos_header(&program.header.dos_header)?;
24
25        // 写入 DOS stub(简单的 DOS 程序)
26        self.write_dos_stub()?;
27
28        // 对齐到 PE 头位置
29        let pe_header_offset = program.header.dos_header.e_lfanew as u64;
30        self.pad_to_offset(pe_header_offset)?;
31
32        // 写入 NT 头
33        self.write_nt_header(&program.header.nt_header)?;
34
35        // 写入 COFF 头
36        self.write_coff_header(&program.header.coff_header)?;
37
38        // 写入可选头
39        self.write_optional_header(&program.header.optional_header)?;
40
41        // 写入节头
42        for section in &program.sections {
43            self.write_section_header(section)?;
44        }
45
46        // 对齐整个头部区域到文件对齐边界
47        let file_alignment = program.header.optional_header.file_alignment;
48        self.align_to_boundary(file_alignment)?;
49
50        // 写入节数据
51        for section in &program.sections {
52            if !section.data.is_empty() {
53                // 对齐到节的文件偏移
54                self.pad_to_offset(section.pointer_to_raw_data as u64)?;
55                self.get_writer().write_all(&section.data)?;
56
57                // 对齐到文件对齐边界
58                self.align_to_boundary(file_alignment)?;
59            }
60        }
61
62        // 写入导入表(如果存在)
63        let pointer_size: usize = if program.header.optional_header.magic == 0x020B { 8 } else { 4 };
64        self.write_import_table(&program.imports, &program.sections, pointer_size)?;
65
66        Ok(())
67    }
68
69    /// 写入 DOS 头(通用实现)
70    fn write_dos_header(&mut self, dos_header: &DosHeader) -> Result<(), GaiaError> {
71        let writer = self.get_writer();
72        writer.write_u16::<LittleEndian>(dos_header.e_magic)?;
73        writer.write_u16::<LittleEndian>(dos_header.e_cblp)?;
74        writer.write_u16::<LittleEndian>(dos_header.e_cp)?;
75        writer.write_u16::<LittleEndian>(dos_header.e_crlc)?;
76        writer.write_u16::<LittleEndian>(dos_header.e_cparhdr)?;
77        writer.write_u16::<LittleEndian>(dos_header.e_min_allocate)?;
78        writer.write_u16::<LittleEndian>(dos_header.e_max_allocate)?;
79        writer.write_u16::<LittleEndian>(dos_header.e_ss)?;
80        writer.write_u16::<LittleEndian>(dos_header.e_sp)?;
81        writer.write_u16::<LittleEndian>(dos_header.e_check_sum)?;
82        writer.write_u16::<LittleEndian>(dos_header.e_ip)?;
83        writer.write_u16::<LittleEndian>(dos_header.e_cs)?;
84        writer.write_u16::<LittleEndian>(dos_header.e_lfarlc)?;
85        writer.write_u16::<LittleEndian>(dos_header.e_ovno)?;
86        for &res in &dos_header.e_res {
87            writer.write_u16::<LittleEndian>(res)?;
88        }
89        writer.write_u16::<LittleEndian>(dos_header.e_oem_id)?;
90        writer.write_u16::<LittleEndian>(dos_header.e_oem_info)?;
91        for &res in &dos_header.e_res2 {
92            writer.write_u16::<LittleEndian>(res)?;
93        }
94        writer.write_u32::<LittleEndian>(dos_header.e_lfanew)?;
95        Ok(())
96    }
97
98    /// 写入 DOS stub(通用实现)
99    fn write_dos_stub(&mut self) -> Result<(), GaiaError> {
100        // 简单的 DOS stub 程序
101        let dos_stub = b"This program cannot be run in DOS mode.\r\n$";
102        self.get_writer().write_all(dos_stub)?;
103        // 填充到 PE 头位置
104        while self.stream_position()? < 0x80 {
105            self.get_writer().write_u8(0)?;
106        }
107        Ok(())
108    }
109
110    /// 写入 NT 头(通用实现)
111    fn write_nt_header(&mut self, nt_header: &NtHeader) -> Result<(), GaiaError> {
112        self.get_writer().write_u32::<LittleEndian>(nt_header.signature)?;
113        Ok(())
114    }
115
116    /// 写入 COFF 头(通用实现)
117    fn write_coff_header(&mut self, coff_header: &crate::types::coff::CoffHeader) -> Result<(), GaiaError> {
118        let writer = self.get_writer();
119        writer.write_u16::<LittleEndian>(coff_header.machine)?;
120        writer.write_u16::<LittleEndian>(coff_header.number_of_sections)?;
121        writer.write_u32::<LittleEndian>(coff_header.time_date_stamp)?;
122        writer.write_u32::<LittleEndian>(coff_header.pointer_to_symbol_table)?;
123        writer.write_u32::<LittleEndian>(coff_header.number_of_symbols)?;
124        writer.write_u16::<LittleEndian>(coff_header.size_of_optional_header)?;
125        writer.write_u16::<LittleEndian>(coff_header.characteristics)?;
126        Ok(())
127    }
128
129    /// 写入可选头(通用实现)
130    fn write_optional_header(&mut self, optional_header: &OptionalHeader) -> Result<(), GaiaError> {
131        let writer = self.get_writer();
132        writer.write_u16::<LittleEndian>(optional_header.magic)?;
133        writer.write_u8(optional_header.major_linker_version)?;
134        writer.write_u8(optional_header.minor_linker_version)?;
135        writer.write_u32::<LittleEndian>(optional_header.size_of_code)?;
136        writer.write_u32::<LittleEndian>(optional_header.size_of_initialized_data)?;
137        writer.write_u32::<LittleEndian>(optional_header.size_of_uninitialized_data)?;
138        writer.write_u32::<LittleEndian>(optional_header.address_of_entry_point)?;
139        writer.write_u32::<LittleEndian>(optional_header.base_of_code)?;
140
141        // 根据架构写入不同的字段
142        if optional_header.magic == 0x20b {
143            // PE32+
144            writer.write_u64::<LittleEndian>(optional_header.image_base)?;
145        } else {
146            // PE32
147            let base_of_data = optional_header.base_of_data.unwrap_or(0);
148            writer.write_u32::<LittleEndian>(base_of_data)?;
149            writer.write_u32::<LittleEndian>(optional_header.image_base as u32)?;
150        }
151
152        writer.write_u32::<LittleEndian>(optional_header.section_alignment)?;
153        writer.write_u32::<LittleEndian>(optional_header.file_alignment)?;
154        writer.write_u16::<LittleEndian>(optional_header.major_operating_system_version)?;
155        writer.write_u16::<LittleEndian>(optional_header.minor_operating_system_version)?;
156        writer.write_u16::<LittleEndian>(optional_header.major_image_version)?;
157        writer.write_u16::<LittleEndian>(optional_header.minor_image_version)?;
158        writer.write_u16::<LittleEndian>(optional_header.major_subsystem_version)?;
159        writer.write_u16::<LittleEndian>(optional_header.minor_subsystem_version)?;
160        writer.write_u32::<LittleEndian>(optional_header.win32_version_value)?;
161        writer.write_u32::<LittleEndian>(optional_header.size_of_image)?;
162        writer.write_u32::<LittleEndian>(optional_header.size_of_headers)?;
163        writer.write_u32::<LittleEndian>(optional_header.checksum)?;
164
165        // 写入子系统
166        let subsystem_value = match optional_header.subsystem {
167            SubsystemType::Console => 3,
168            SubsystemType::Windows => 2,
169            SubsystemType::Native => 1,
170            _ => 3, // 默认为控制台
171        };
172        writer.write_u16::<LittleEndian>(subsystem_value)?;
173
174        writer.write_u16::<LittleEndian>(optional_header.dll_characteristics)?;
175
176        // 根据架构写入不同大小的字段
177        if optional_header.magic == 0x20b {
178            // PE32+
179            writer.write_u64::<LittleEndian>(optional_header.size_of_stack_reserve)?;
180            writer.write_u64::<LittleEndian>(optional_header.size_of_stack_commit)?;
181            writer.write_u64::<LittleEndian>(optional_header.size_of_heap_reserve)?;
182            writer.write_u64::<LittleEndian>(optional_header.size_of_heap_commit)?;
183        } else {
184            // PE32
185            writer.write_u32::<LittleEndian>(optional_header.size_of_stack_reserve as u32)?;
186            writer.write_u32::<LittleEndian>(optional_header.size_of_stack_commit as u32)?;
187            writer.write_u32::<LittleEndian>(optional_header.size_of_heap_reserve as u32)?;
188            writer.write_u32::<LittleEndian>(optional_header.size_of_heap_commit as u32)?;
189        }
190
191        writer.write_u32::<LittleEndian>(optional_header.loader_flags)?;
192        writer.write_u32::<LittleEndian>(optional_header.number_of_rva_and_sizes)?;
193
194        // 写入数据目录
195        for data_dir in &optional_header.data_directories {
196            self.write_data_directory(data_dir)?;
197        }
198
199        Ok(())
200    }
201
202    /// 写入数据目录(通用实现)
203    fn write_data_directory(&mut self, data_dir: &DataDirectory) -> Result<(), GaiaError> {
204        let writer = self.get_writer();
205        writer.write_u32::<LittleEndian>(data_dir.virtual_address)?;
206        writer.write_u32::<LittleEndian>(data_dir.size)?;
207        Ok(())
208    }
209
210    /// 写入节头(通用实现)
211    fn write_section_header(&mut self, section: &PeSection) -> Result<(), GaiaError> {
212        let writer = self.get_writer();
213        // 写入节名(8字节,不足补0)
214        let mut name_bytes = [0u8; 8];
215        let name_len = section.name.len().min(8);
216        name_bytes[..name_len].copy_from_slice(&section.name.as_bytes()[..name_len]);
217        writer.write_all(&name_bytes)?;
218
219        writer.write_u32::<LittleEndian>(section.virtual_size)?;
220        writer.write_u32::<LittleEndian>(section.virtual_address)?;
221        writer.write_u32::<LittleEndian>(section.size_of_raw_data)?;
222        writer.write_u32::<LittleEndian>(section.pointer_to_raw_data)?;
223        writer.write_u32::<LittleEndian>(section.pointer_to_relocations)?;
224        writer.write_u32::<LittleEndian>(section.pointer_to_line_numbers)?;
225        writer.write_u16::<LittleEndian>(section.number_of_relocations)?;
226        writer.write_u16::<LittleEndian>(section.number_of_line_numbers)?;
227        writer.write_u32::<LittleEndian>(section.characteristics)?;
228        Ok(())
229    }
230
231    /// 填充到指定偏移(通用实现)
232    fn pad_to_offset(&mut self, target_offset: u64) -> Result<(), GaiaError> {
233        let current_pos = self.stream_position()?;
234        if current_pos < target_offset {
235            let padding_size = target_offset - current_pos;
236            for _ in 0..padding_size {
237                self.get_writer().write_u8(0)?;
238            }
239        }
240        Ok(())
241    }
242
243    /// 对齐到边界(通用实现)
244    fn align_to_boundary(&mut self, alignment: u32) -> Result<(), GaiaError> {
245        let current_pos = self.stream_position()?;
246        let remainder = current_pos % alignment as u64;
247        if remainder != 0 {
248            let padding = alignment as u64 - remainder;
249            for _ in 0..padding {
250                self.get_writer().write_u8(0)?;
251            }
252        }
253        Ok(())
254    }
255
256    /// 写入导入表(通用实现)
257    fn write_import_table(
258        &mut self,
259        imports: &crate::types::tables::ImportTable,
260        sections: &[PeSection],
261        pointer_size: usize,
262    ) -> Result<(), GaiaError> {
263        // 如果没有导入,直接返回
264        if imports.entries.is_empty() {
265            return Ok(());
266        }
267
268        // 查找 .idata 节
269        let idata_section = sections.iter().find(|s| s.name == ".idata");
270        if let Some(section) = idata_section {
271            // 移动到 .idata 节的文件偏移
272            self.pad_to_offset(section.pointer_to_raw_data as u64)?;
273
274            let base_rva = section.virtual_address;
275            let mut current_rva = base_rva + ((imports.entries.len() + 1) * 20) as u32; // 跳过导入描述符表
276
277            // 计算 DLL 名称的 RVA
278            let mut dll_name_rvas = Vec::new();
279            for entry in &imports.entries {
280                dll_name_rvas.push(current_rva);
281                current_rva += (entry.dll_name.len() + 1) as u32; // 包括空终止符
282            }
283
284            // 对齐到 2 字节边界(名称通常按字对齐)
285            if current_rva % 2 != 0 {
286                current_rva += 1;
287            }
288
289            // 计算函数名称的 RVA(Hint+Name)
290            let mut function_name_rvas = Vec::new();
291            for entry in &imports.entries {
292                let mut entry_function_rvas = Vec::new();
293                for function in &entry.functions {
294                    entry_function_rvas.push(current_rva);
295                    current_rva += (2 + function.len() + 1) as u32; // Hint(2字节) + 函数名 + 空终止符
296                }
297                function_name_rvas.push(entry_function_rvas);
298            }
299
300            // 对齐到 2 字节边界(IMAGE_IMPORT_BY_NAME 推荐字对齐)
301            if current_rva % 2 != 0 {
302                current_rva += 1;
303            }
304
305            // 计算 INT 与 IAT 的 RVA(在函数名称之后)
306            // 先分配 INT(OriginalFirstThunk 指向的名称指针数组),再分配 IAT(FirstThunk 指向的地址数组)
307            // 对齐到 pointer_size 字节边界
308            if current_rva % (pointer_size as u32) != 0 {
309                current_rva = (current_rva + (pointer_size as u32) - 1) & !((pointer_size as u32) - 1);
310            }
311
312            let mut int_rvas = Vec::new();
313            for entry in &imports.entries {
314                int_rvas.push(current_rva);
315                current_rva += ((entry.functions.len() as u32) * (pointer_size as u32) + (pointer_size as u32));
316            }
317
318            // 对齐到 pointer_size 字节边界
319            if current_rva % (pointer_size as u32) != 0 {
320                current_rva = (current_rva + (pointer_size as u32) - 1) & !((pointer_size as u32) - 1);
321            }
322
323            let mut iat_rvas = Vec::new();
324            for entry in &imports.entries {
325                iat_rvas.push(current_rva);
326                current_rva += ((entry.functions.len() as u32) * (pointer_size as u32) + (pointer_size as u32));
327            }
328
329            // 写入导入描述符表
330            for (i, _entry) in imports.entries.iter().enumerate() {
331                let writer = self.get_writer();
332                // x64:使用经典布局(OFT 指向 INT,IAT 初始为 0)
333                // x86:为了兼容更多加载器,采用 OFT=0(不使用 INT),IAT 初始填入 Hint/Name 的 RVA
334                if pointer_size == 8 {
335                    writer.write_u32::<LittleEndian>(int_rvas[i])?; // OriginalFirstThunk (INT)
336                } else {
337                    writer.write_u32::<LittleEndian>(0)?; // OriginalFirstThunk = 0(使用 IAT 作为查找表)
338                }
339                writer.write_u32::<LittleEndian>(0)?; // TimeDateStamp
340                writer.write_u32::<LittleEndian>(0)?; // ForwarderChain
341                writer.write_u32::<LittleEndian>(dll_name_rvas[i])?; // Name RVA
342                writer.write_u32::<LittleEndian>(iat_rvas[i])?; // FirstThunk 指向 IAT(地址数组)
343            }
344
345            // 写入终止符(全零的导入描述符)
346            {
347                let writer = self.get_writer();
348                for _ in 0..5 {
349                    writer.write_u32::<LittleEndian>(0)?;
350                }
351            }
352
353            // 写入 DLL 名称字符串
354            for entry in &imports.entries {
355                let writer = self.get_writer();
356                writer.write_all(entry.dll_name.as_bytes())?;
357                writer.write_u8(0)?; // 空终止符
358            }
359
360            // 按 2 字节对齐
361            if self.stream_position()? % 2 != 0 {
362                self.get_writer().write_u8(0)?;
363            }
364
365            // 写入函数名称字符串
366            for (_i, entry) in imports.entries.iter().enumerate() {
367                for (_j, function) in entry.functions.iter().enumerate() {
368                    let writer = self.get_writer();
369                    writer.write_u16::<LittleEndian>(0)?; // Hint
370                    writer.write_all(function.as_bytes())?;
371                    writer.write_u8(0)?; // 空终止符
372                }
373            }
374
375            // 按 2 字节对齐
376            if self.stream_position()? % 2 != 0 {
377                self.get_writer().write_u8(0)?;
378            }
379
380            // 写入 INT(OriginalFirstThunk 指向的名称指针数组)
381            // 按 pointer_size 字节对齐
382            while self.stream_position()? % (pointer_size as u64) != 0 {
383                self.get_writer().write_u8(0)?;
384            }
385            for (i, entry) in imports.entries.iter().enumerate() {
386                for j in 0..entry.functions.len() {
387                    let writer = self.get_writer();
388                    if pointer_size == 8 {
389                        writer.write_u64::<LittleEndian>(function_name_rvas[i][j] as u64)?;
390                    } else {
391                        writer.write_u32::<LittleEndian>(function_name_rvas[i][j])?;
392                    }
393                }
394                let writer = self.get_writer();
395                if pointer_size == 8 {
396                    writer.write_u64::<LittleEndian>(0)?;
397                } else {
398                    writer.write_u32::<LittleEndian>(0)?;
399                }
400            }
401
402            // 写入 IAT(FirstThunk 指向的地址数组)。
403            // 经典模式下 IAT 初始为 0,加载器解析后填入实际地址。
404            // 按 pointer_size 字节对齐
405            while self.stream_position()? % (pointer_size as u64) != 0 {
406                self.get_writer().write_u8(0)?;
407            }
408            for (i, entry) in imports.entries.iter().enumerate() {
409                for _j in 0..entry.functions.len() {
410                    let writer = self.get_writer();
411                    if pointer_size == 8 {
412                        // x64:IAT 初始为 0,由加载器填充
413                        writer.write_u64::<LittleEndian>(0)?;
414                    } else {
415                        // x86:IAT 初始填入 Hint/Name 的 RVA,加载器解析后覆盖
416                        writer.write_u32::<LittleEndian>(function_name_rvas[i][_j])?;
417                    }
418                }
419                // 终止符 0
420                let writer = self.get_writer();
421                if pointer_size == 8 {
422                    writer.write_u64::<LittleEndian>(0)?;
423                } else {
424                    writer.write_u32::<LittleEndian>(0)?;
425                }
426            }
427
428            // 对齐到节的大小
429            self.align_to_boundary(section.size_of_raw_data)?;
430        }
431
432        Ok(())
433    }
434}