use std::path::{Path, PathBuf};
use memmap2::{MmapMut, MmapOptions};
use crate::{utils::write_compressed_uint, Error, Result};
/// Storage behind an [`Output`]: either a memory-mapped file or a heap buffer.
enum OutputBacking {
/// Output backed by a writable memory mapping of a file on disk.
File {
/// Writable mapping over the target file.
mmap: MmapMut,
/// Path of the mapped file; used to reopen it for truncation and to remove
/// it when the output is dropped without being finalized.
target_path: PathBuf,
},
/// Output held entirely in process memory.
Memory {
/// The buffer contents.
data: Vec<u8>,
},
}
/// A fixed-size, randomly writable byte buffer that is backed either by a
/// memory-mapped file or by an in-memory vector.
pub struct Output {
/// Where the bytes actually live.
backing: OutputBacking,
/// Set to `true` once `finalize` or `into_vec` has succeeded. When it is still
/// `false` at drop time, a file-backed output removes its target file.
finalized: bool,
}
impl Output {
pub fn create<P: AsRef<Path>>(target_path: P, size: u64) -> Result<Self> {
let target_path = target_path.as_ref().to_path_buf();
let file = std::fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(true)
.open(&target_path)
.map_err(|e| Error::MmapFailed(format!("Failed to create target file: {e}")))?;
file.set_len(size)
.map_err(|e| Error::MmapFailed(format!("Failed to set file size: {e}")))?;
let mmap = unsafe {
MmapOptions::new()
.map_mut(&file)
.map_err(|e| Error::MmapFailed(format!("Failed to create memory mapping: {e}")))?
};
Ok(Self {
backing: OutputBacking::File { mmap, target_path },
finalized: false,
})
}
pub fn create_in_memory(size: u64) -> Result<Self> {
let size_usize = usize::try_from(size).map_err(|_| {
Error::MmapFailed(format!("Size {size} too large for target architecture"))
})?;
let data = vec![0u8; size_usize];
Ok(Self {
backing: OutputBacking::Memory { data },
finalized: false,
})
}
/// Returns `true` when the output lives in memory rather than in a mapped file.
#[must_use]
pub fn is_in_memory(&self) -> bool {
    match self.backing {
        OutputBacking::Memory { .. } => true,
        OutputBacking::File { .. } => false,
    }
}
/// Returns the whole buffer as a read-only byte slice.
pub fn as_slice(&self) -> &[u8] {
    match self.backing {
        OutputBacking::File { ref mmap, .. } => &mmap[..],
        OutputBacking::Memory { ref data } => data.as_slice(),
    }
}
/// Returns the whole buffer as a mutable byte slice.
pub fn as_mut_slice(&mut self) -> &mut [u8] {
    match self.backing {
        OutputBacking::File { ref mut mmap, .. } => &mut mmap[..],
        OutputBacking::Memory { ref mut data } => data.as_mut_slice(),
    }
}
pub fn get_mut_range(&mut self, start: usize, end: usize) -> Result<&mut [u8]> {
let len = self.size();
if end > len {
return Err(Error::MmapFailed(format!(
"Range end {end} exceeds buffer size {len}"
)));
}
if start > end {
return Err(Error::MmapFailed(format!(
"Range start {start} is greater than end {end}"
)));
}
Ok(&mut self.as_mut_slice()[start..end])
}
pub fn get_mut_slice(&mut self, start: usize, size: usize) -> Result<&mut [u8]> {
let end = start + size;
let len = self.size();
if end > len {
return Err(Error::MmapFailed(format!(
"Write would exceed buffer size: start={start}, size={size}, end={end}, buffer_size={len}"
)));
}
self.get_mut_range(start, end)
}
pub fn write_at(&mut self, offset: u64, data: &[u8]) -> Result<()> {
let start = usize::try_from(offset).map_err(|_| {
Error::MmapFailed(format!("Offset {offset} too large for target architecture"))
})?;
let end = start + data.len();
let len = self.size();
if end > len {
return Err(Error::MmapFailed(format!(
"Write would exceed buffer size: offset={}, len={}, buffer_size={}",
offset,
data.len(),
len
)));
}
self.as_mut_slice()[start..end].copy_from_slice(data);
Ok(())
}
pub fn read_at(&self, offset: u64, buffer: &mut [u8]) -> Result<()> {
let start = usize::try_from(offset).map_err(|_| {
Error::MmapFailed(format!("Offset {offset} too large for target architecture"))
})?;
let end = start + buffer.len();
let len = self.size();
if end > len {
return Err(Error::MmapFailed(format!(
"Read would exceed buffer size: offset={}, len={}, buffer_size={}",
offset,
buffer.len(),
len
)));
}
buffer.copy_from_slice(&self.as_slice()[start..end]);
Ok(())
}
/// Copies `size` bytes within the buffer from `source_offset` to
/// `target_offset`. The two regions may overlap.
///
/// # Errors
///
/// Returns [`Error::MmapFailed`] when any argument does not fit into `usize`,
/// when an offset plus `size` overflows, or when the source or target region
/// extends beyond the buffer. The source region is validated first.
pub fn copy_range(&mut self, source_offset: u64, target_offset: u64, size: u64) -> Result<()> {
    let source_start = usize::try_from(source_offset).map_err(|_| {
        Error::MmapFailed(format!(
            "Source offset {source_offset} too large for target architecture"
        ))
    })?;
    let target_start = usize::try_from(target_offset).map_err(|_| {
        Error::MmapFailed(format!(
            "Target offset {target_offset} too large for target architecture"
        ))
    })?;
    let copy_size = usize::try_from(size).map_err(|_| {
        Error::MmapFailed(format!("Size {size} too large for target architecture"))
    })?;
    let len = self.size();

    // Checked additions: a wrapped end would pass the bounds check and make
    // `copy_within` panic instead of returning an error.
    let source_end = match source_start.checked_add(copy_size) {
        Some(end) if end <= len => end,
        Some(source_end) => {
            return Err(Error::MmapFailed(format!(
                "Source range would exceed buffer size: {source_start}..{source_end} (buffer size: {len})"
            )))
        }
        None => {
            return Err(Error::MmapFailed(format!(
                "Source range would exceed buffer size: start {source_start} plus size {copy_size} overflows (buffer size: {len})"
            )))
        }
    };
    match target_start.checked_add(copy_size) {
        Some(end) if end <= len => {}
        Some(target_end) => {
            return Err(Error::MmapFailed(format!(
                "Target range would exceed buffer size: {target_start}..{target_end} (buffer size: {len})"
            )))
        }
        None => {
            return Err(Error::MmapFailed(format!(
                "Target range would exceed buffer size: start {target_start} plus size {copy_size} overflows (buffer size: {len})"
            )))
        }
    }

    self.as_mut_slice()
        .copy_within(source_start..source_end, target_start);
    Ok(())
}
pub fn zero_range(&mut self, offset: u64, size: u64) -> Result<()> {
let start = usize::try_from(offset).map_err(|_| {
Error::MmapFailed(format!("Offset {offset} too large for target architecture"))
})?;
let zero_size = usize::try_from(size).map_err(|_| {
Error::MmapFailed(format!("Size {size} too large for target architecture"))
})?;
let slice = self.get_mut_slice(start, zero_size)?;
slice.fill(0);
Ok(())
}
pub fn write_byte_at(&mut self, offset: u64, byte: u8) -> Result<()> {
let index = usize::try_from(offset).map_err(|_| {
Error::MmapFailed(format!("Offset {offset} too large for target architecture"))
})?;
let len = self.size();
if index >= len {
return Err(Error::MmapFailed(format!(
"Byte write would exceed buffer size: offset={offset}, buffer_size={len}"
)));
}
self.as_mut_slice()[index] = byte;
Ok(())
}
/// Writes `value` at `offset` as two little-endian bytes.
///
/// # Errors
///
/// Propagates any error reported by `write_at`.
pub fn write_u16_le_at(&mut self, offset: u64, value: u16) -> Result<()> {
    let encoded: [u8; 2] = value.to_le_bytes();
    self.write_at(offset, &encoded)
}
/// Writes `value` at `offset` as four little-endian bytes.
///
/// # Errors
///
/// Propagates any error reported by `write_at`.
pub fn write_u32_le_at(&mut self, offset: u64, value: u32) -> Result<()> {
    let encoded: [u8; 4] = value.to_le_bytes();
    self.write_at(offset, &encoded)
}
/// Writes `value` at `offset` as eight little-endian bytes.
///
/// # Errors
///
/// Propagates any error reported by `write_at`.
pub fn write_u64_le_at(&mut self, offset: u64, value: u64) -> Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    self.write_at(offset, &encoded)
}
/// Encodes `value` with `write_compressed_uint` and writes the encoded bytes at
/// `offset`.
///
/// Returns the offset immediately after the written bytes.
///
/// # Errors
///
/// Propagates any error reported by `write_at`.
pub fn write_compressed_uint_at(&mut self, offset: u64, value: u32) -> Result<u64> {
    let mut encoded: Vec<u8> = Vec::new();
    write_compressed_uint(value, &mut encoded);

    let written = encoded.len() as u64;
    self.write_at(offset, &encoded).map(|()| offset + written)
}
pub fn write_aligned_data(&mut self, offset: u64, data: &[u8]) -> Result<u64> {
self.write_at(offset, data)?;
let data_end = offset + data.len() as u64;
let padding_needed = (4 - (data.len() % 4)) % 4;
if padding_needed > 0 {
let padding_slice = self.get_mut_slice(
usize::try_from(data_end).map_err(|_| {
Error::MmapFailed(format!(
"Data end offset {data_end} too large for target architecture"
))
})?,
padding_needed,
)?;
padding_slice.fill(0xFF);
}
Ok(data_end + padding_needed as u64)
}
/// Writes `data` at `*position` and then moves `*position` past the written
/// bytes.
///
/// # Errors
///
/// Propagates any error reported by `get_mut_slice`; `*position` is left
/// unchanged in that case.
pub fn write_and_advance(&mut self, position: &mut usize, data: &[u8]) -> Result<()> {
    let written = data.len();
    self.get_mut_slice(*position, written)?
        .copy_from_slice(data);
    *position += written;
    Ok(())
}
pub fn fill_region(&mut self, offset: u64, size: usize, fill_byte: u8) -> Result<()> {
let slice = self.get_mut_slice(
usize::try_from(offset).map_err(|_| {
Error::MmapFailed(format!("Offset {offset} too large for target architecture"))
})?,
size,
)?;
slice.fill(fill_byte);
Ok(())
}
pub fn add_heap_padding(&mut self, current_pos: usize, heap_start: usize) -> Result<()> {
let bytes_written = current_pos - heap_start;
let padding_needed = (4 - (bytes_written % 4)) % 4;
if padding_needed > 0 {
self.fill_region(current_pos as u64, padding_needed, 0xFF)?;
}
Ok(())
}
/// Returns the total length of the buffer in bytes.
#[must_use]
pub fn size(&self) -> usize {
    match self.backing {
        OutputBacking::Memory { ref data } => data.len(),
        OutputBacking::File { ref mmap, .. } => mmap.len(),
    }
}
pub fn flush(&mut self) -> Result<()> {
match &mut self.backing {
OutputBacking::File { mmap, .. } => mmap
.flush()
.map_err(|e| Error::MmapFailed(format!("Failed to flush memory mapping: {e}"))),
OutputBacking::Memory { .. } => {
Ok(())
}
}
}
pub fn finalize(mut self, actual_size: Option<u64>) -> Result<()> {
if self.finalized {
return Err(Error::FinalizationFailed(
"Output has already been finalized".to_string(),
));
}
let backing = std::mem::replace(
&mut self.backing,
OutputBacking::Memory { data: Vec::new() }, );
match backing {
OutputBacking::File { mmap, target_path } => {
mmap.flush().map_err(|e| {
Error::FinalizationFailed(format!("Failed to flush memory mapping: {e}"))
})?;
if let Some(size) = actual_size {
drop(mmap);
let file = std::fs::OpenOptions::new()
.write(true)
.open(&target_path)
.map_err(|e| {
Error::FinalizationFailed(format!(
"Failed to reopen file for truncation: {e}"
))
})?;
file.set_len(size).map_err(|e| {
Error::FinalizationFailed(format!(
"Failed to truncate file to {size} bytes: {e}"
))
})?;
}
self.finalized = true;
Ok(())
}
OutputBacking::Memory { .. } => Err(Error::FinalizationFailed(
"Cannot finalize in-memory output to file; use into_vec() instead".to_string(),
)),
}
}
/// Consumes the output and returns its contents as a vector, optionally
/// shortened to `actual_size` bytes.
///
/// An in-memory output hands over its buffer directly; a file-backed output
/// is copied out of the mapping.
///
/// # Errors
///
/// Returns [`Error::FinalizationFailed`] when the output was already
/// finalized, when `actual_size` does not fit into `usize`, or when it is
/// larger than the buffer.
pub fn into_vec(mut self, actual_size: Option<u64>) -> Result<Vec<u8>> {
    if self.finalized {
        return Err(Error::FinalizationFailed(
            "Output has already been finalized".to_string(),
        ));
    }

    let taken = std::mem::replace(
        &mut self.backing,
        OutputBacking::Memory { data: Vec::new() },
    );
    let mut bytes = match taken {
        OutputBacking::File { mmap, .. } => Vec::from(&mmap[..]),
        OutputBacking::Memory { data } => data,
    };

    if let Some(size) = actual_size {
        let Ok(keep) = usize::try_from(size) else {
            return Err(Error::FinalizationFailed(format!(
                "Requested size {size} too large for target architecture"
            )));
        };
        let available = bytes.len();
        if keep > available {
            return Err(Error::FinalizationFailed(format!(
                "Requested size {} exceeds buffer size {}",
                keep, available
            )));
        }
        bytes.truncate(keep);
    }

    self.finalized = true;
    Ok(bytes)
}
/// Returns the path of the backing file, or `None` for an in-memory output.
pub fn target_path(&self) -> Option<&Path> {
    if let OutputBacking::File { target_path, .. } = &self.backing {
        Some(target_path.as_path())
    } else {
        None
    }
}
}
/// Cleans up an output that was dropped without being finalized.
impl Drop for Output {
    /// Removes the target file of a file-backed output that was never
    /// finalized. All failures are ignored because `drop` cannot report them.
    fn drop(&mut self) {
        if self.finalized {
            return;
        }

        // Best effort, as before: push pending changes out of the mapping.
        let _ = self.flush();

        // Take the backing out so the mapping can be released BEFORE the file
        // is deleted. Windows refuses to delete a file that still has a mapped
        // view, so removing it while the mapping was alive silently left the
        // partial file behind on that platform.
        let backing = std::mem::replace(
            &mut self.backing,
            OutputBacking::Memory { data: Vec::new() },
        );
        if let OutputBacking::File { mmap, target_path } = backing {
            drop(mmap);
            let _ = std::fs::remove_file(&target_path);
        }
    }
}
/// A cursor over an [`Output`] that tracks a write position, so the output can
/// be used through `std::io::Write`.
pub struct OutputWriter<'a> {
/// The output being written to.
output: &'a mut Output,
/// Byte offset in `output` at which the next write is placed.
position: u64,
}
impl<'a> OutputWriter<'a> {
pub fn new(output: &'a mut Output, offset: u64) -> Self {
Self {
output,
position: offset,
}
}
#[must_use]
pub fn position(&self) -> u64 {
self.position
}
pub fn set_position(&mut self, position: u64) {
self.position = position;
}
#[must_use]
pub fn bytes_written_since(&self, start_offset: u64) -> u64 {
self.position.saturating_sub(start_offset)
}
}
impl std::io::Write for OutputWriter<'_> {
    /// Writes all of `buf` at the current position and advances past it.
    ///
    /// The write is all-or-nothing: if `buf` does not fit, an error is returned
    /// and the position is left unchanged.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = buf.len();
        match self.output.write_at(self.position, buf) {
            Ok(()) => {
                self.position += written as u64;
                Ok(written)
            }
            Err(e) => Err(std::io::Error::other(e.to_string())),
        }
    }

    /// Flushes the underlying output.
    fn flush(&mut self) -> std::io::Result<()> {
        match self.output.flush() {
            Ok(()) => Ok(()),
            Err(e) => Err(std::io::Error::other(e.to_string())),
        }
    }
}
impl Output {
pub fn writer_at(&mut self, offset: u64) -> OutputWriter<'_> {
OutputWriter::new(self, offset)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::file::pe::DosHeader;
    use std::io::Write;
    use tempfile::tempdir;

    /// A freshly created file-backed output has the requested size and is not finalized.
    #[test]
    fn test_mmap_file_creation() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");

        let output = Output::create(&path, 1024).unwrap();

        assert!(!output.finalized);
        assert_eq!(output.size(), 1024);
    }

    /// Byte, little-endian integer and slice writes land at the requested offsets.
    #[test]
    fn test_write_operations() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");
        let mut output = Output::create(&path, 1024).unwrap();

        output.write_byte_at(0, 0x42).unwrap();
        output.write_u32_le_at(4, 0x12345678).unwrap();
        output.write_at(8, b"Hello, World!").unwrap();

        let bytes = output.as_mut_slice();
        assert_eq!(bytes[0], 0x42);
        assert_eq!(&bytes[4..8], &[0x78, 0x56, 0x34, 0x12]);
        assert_eq!(&bytes[8..21], b"Hello, World!");
    }

    /// `copy_range` duplicates a region inside a file-backed output.
    #[test]
    fn test_copy_range() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");
        let mut output = Output::create(&path, 1024).unwrap();

        output.write_at(0, b"Hello, World!").unwrap();
        output.copy_range(0, 100, 13).unwrap();

        let bytes = output.as_mut_slice();
        assert_eq!(&bytes[100..113], b"Hello, World!");
    }

    /// `zero_range` clears only the requested bytes.
    #[test]
    fn test_zero_range() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");
        let mut output = Output::create(&path, 1024).unwrap();

        output.write_at(0, b"Hello, World!").unwrap();
        output.zero_range(5, 5).unwrap();

        let bytes = output.as_mut_slice();
        assert_eq!(&bytes[0..5], b"Hello");
        assert_eq!(&bytes[5..10], &[0, 0, 0, 0, 0]);
        assert_eq!(&bytes[10..13], b"ld!");
    }

    /// A finalized output leaves its file on disk with the written content.
    #[test]
    fn test_finalization() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");

        {
            let mut output = Output::create(&path, 16).unwrap();
            output.write_at(0, b"Test content").unwrap();
            output.finalize(None).unwrap();
        }

        assert!(path.exists());
        let on_disk = std::fs::read(&path).unwrap();
        assert_eq!(&on_disk[0..12], b"Test content");
    }

    /// Writes past the end of the buffer are rejected.
    #[test]
    fn test_bounds_checking() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");
        let mut output = Output::create(&path, 10).unwrap();

        let oversized_write = output.write_at(8, b"too long");
        assert!(oversized_write.is_err());

        let byte_past_end = output.write_byte_at(10, 0x42);
        assert!(byte_past_end.is_err());
    }

    /// The writer advances its position with every write.
    #[test]
    fn test_output_writer() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");
        let mut output = Output::create(&path, 1024).unwrap();

        {
            let mut cursor = output.writer_at(0x10);
            assert_eq!(cursor.position(), 0x10);

            cursor.write_all(b"Hello").unwrap();
            assert_eq!(cursor.position(), 0x15);

            cursor.write_all(b", World!").unwrap();
            assert_eq!(cursor.position(), 0x1D);
            assert_eq!(cursor.bytes_written_since(0x10), 0x0D);
        }

        let bytes = output.as_mut_slice();
        assert_eq!(&bytes[0x10..0x1D], b"Hello, World!");
    }

    /// A DOS header can be emitted through the writer.
    #[test]
    fn test_output_writer_with_pe_header() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");
        let mut output = Output::create(&path, 1024).unwrap();

        {
            let mut cursor = output.writer_at(0);
            DosHeader::write_standard(&mut cursor).unwrap();
            assert_eq!(cursor.position(), 128);
        }

        let bytes = output.as_mut_slice();
        assert_eq!(bytes[0], 0x4D);
        assert_eq!(bytes[1], 0x5A);
        assert_eq!(&bytes[0x3C..0x40], &[0x80, 0x00, 0x00, 0x00]);
    }

    /// Finalizing with a size shrinks the file to that size.
    #[test]
    fn test_finalize_with_truncation() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("truncate_test.bin");
        let initial_size = 1024 * 1024;
        let actual_content_size = 100u64;

        {
            let mut output = Output::create(&path, initial_size).unwrap();
            output.write_at(0, &[0x42u8; 100]).unwrap();
            output.finalize(Some(actual_content_size)).unwrap();
        }

        let file_len = std::fs::metadata(&path).unwrap().len();
        assert_eq!(file_len, actual_content_size);

        let on_disk = std::fs::read(&path).unwrap();
        assert_eq!(on_disk.len(), 100);
        assert!(on_disk.iter().all(|&b| b == 0x42));
    }

    /// Finalizing without a size keeps the originally requested file length.
    #[test]
    fn test_finalize_without_truncation() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("no_truncate_test.bin");
        let initial_size = 1024u64;

        {
            let mut output = Output::create(&path, initial_size).unwrap();
            output.write_at(0, b"Hello").unwrap();
            output.finalize(None).unwrap();
        }

        let file_len = std::fs::metadata(&path).unwrap().len();
        assert_eq!(file_len, initial_size);
    }

    /// An in-memory output reports its size and has no target path.
    #[test]
    fn test_create_in_memory() {
        let output = Output::create_in_memory(1024).unwrap();

        assert!(output.is_in_memory());
        assert!(output.target_path().is_none());
        assert!(!output.finalized);
        assert_eq!(output.size(), 1024);
    }

    /// The write helpers behave the same on an in-memory output.
    #[test]
    fn test_in_memory_write_operations() {
        let mut output = Output::create_in_memory(1024).unwrap();

        output.write_byte_at(0, 0x4D).unwrap();
        output.write_byte_at(1, 0x5A).unwrap();
        output.write_u32_le_at(4, 0xDEADBEEF).unwrap();
        output.write_at(8, b"Hello, Memory!").unwrap();

        let bytes = output.as_slice();
        assert_eq!(bytes[0], 0x4D);
        assert_eq!(bytes[1], 0x5A);
        assert_eq!(&bytes[4..8], &[0xEF, 0xBE, 0xAD, 0xDE]);
        assert_eq!(&bytes[8..22], b"Hello, Memory!");
    }

    /// `into_vec(None)` returns the whole buffer.
    #[test]
    fn test_into_vec_basic() {
        let mut output = Output::create_in_memory(1024).unwrap();
        output.write_at(0, b"Test data").unwrap();

        let extracted = output.into_vec(None).unwrap();

        assert_eq!(extracted.len(), 1024);
        assert_eq!(&extracted[0..9], b"Test data");
    }

    /// `into_vec(Some(n))` returns only the first `n` bytes.
    #[test]
    fn test_into_vec_with_truncation() {
        let mut output = Output::create_in_memory(1024).unwrap();
        output.write_at(0, b"Test data here").unwrap();

        let extracted = output.into_vec(Some(14)).unwrap();

        assert_eq!(extracted.len(), 14);
        assert_eq!(&extracted, b"Test data here");
    }

    /// Requesting more bytes than the buffer holds is an error.
    #[test]
    fn test_into_vec_size_validation() {
        let output = Output::create_in_memory(100).unwrap();

        let outcome = output.into_vec(Some(200));

        assert!(outcome.is_err());
    }

    /// An in-memory output cannot be finalized to a file.
    #[test]
    fn test_in_memory_finalize_fails() {
        let output = Output::create_in_memory(1024).unwrap();

        let outcome = output.finalize(None);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("in-memory output"));
    }

    /// A file-backed output reports its path and is not in-memory.
    #[test]
    fn test_file_backed_is_not_in_memory() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.bin");

        let output = Output::create(&path, 1024).unwrap();

        assert!(!output.is_in_memory());
        assert!(output.target_path().is_some());
        assert_eq!(output.target_path().unwrap(), path);
    }

    /// `copy_range` leaves the source intact and fills the target.
    #[test]
    fn test_in_memory_copy_range() {
        let mut output = Output::create_in_memory(256).unwrap();

        output.write_at(0, b"Source Data").unwrap();
        output.copy_range(0, 100, 11).unwrap();

        let bytes = output.as_slice();
        assert_eq!(&bytes[0..11], b"Source Data");
        assert_eq!(&bytes[100..111], b"Source Data");
    }

    /// `zero_range` clears a window inside a previously filled region.
    #[test]
    fn test_in_memory_zero_range() {
        let mut output = Output::create_in_memory(64).unwrap();

        output.fill_region(0, 32, 0xFF).unwrap();
        output.zero_range(8, 8).unwrap();

        let bytes = output.as_slice();
        assert!(bytes[0..8].iter().all(|&b| b == 0xFF));
        assert!(bytes[8..16].iter().all(|&b| b == 0x00));
        assert!(bytes[16..32].iter().all(|&b| b == 0xFF));
    }
}