use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, RwLock,
};
use imbl::HashMap as ImHashMap;
use crate::{
emulation::{
memory::{
region::{MemoryProtection, MemoryRegion, SectionInfo},
statics::StaticFieldStorage,
},
EmValue, EmulationError, HeapRef, ManagedHeap,
},
metadata::token::Token,
Error, Result,
};
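/// A cheaply cloneable, reference-counted handle to a [`ManagedHeap`].
///
/// Cloning bumps the `Arc` count, so several `AddressSpace` instances can
/// operate on the same underlying heap; use [`SharedHeap::fork`] when an
/// independent copy is needed instead.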
#[derive(Clone, Debug)]
pub struct SharedHeap {
inner: Arc<ManagedHeap>,
}
impl SharedHeap {
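    /// Creates a heap handle with the given maximum heap size in bytes.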
#[must_use]
pub fn new(max_size: usize) -> Self {
Self {
inner: Arc::new(ManagedHeap::new(max_size)),
}
}
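    /// Wraps an existing [`ManagedHeap`] in a shared handle.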
pub fn from_heap(heap: ManagedHeap) -> Self {
Self {
inner: Arc::new(heap),
}
}
#[must_use]
pub fn heap(&self) -> &ManagedHeap {
&self.inner
}
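    /// Returns the number of live handles to this heap.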
#[must_use]
pub fn ref_count(&self) -> usize {
Arc::strong_count(&self.inner)
}
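    /// Returns `true` if this is the only handle to the heap.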
#[must_use]
pub fn is_unique(&self) -> bool {
Arc::strong_count(&self.inner) == 1
}
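    /// Returns a deep copy of the heap in a new handle.
    ///
    /// Unlike `clone`, the forked heap is independent: allocations made
    /// through one handle are not visible through the other.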
#[must_use]
pub fn fork(&self) -> Self {
Self {
inner: Arc::new(self.inner.fork()),
}
}
}
impl Default for SharedHeap {
fn default() -> Self {
        Self::new(64 * 1024 * 1024) // 64 MiB default heap
    }
}
impl std::ops::Deref for SharedHeap {
type Target = ManagedHeap;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
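/// An emulated flat address space.
///
/// Combines a [`SharedHeap`] for managed objects, a set of mapped
/// [`MemoryRegion`]s for raw memory, per-token static field storage, and
/// page-granular protection overrides layered over each region's own
/// protection.
///
/// Illustrative usage (a sketch; see the tests below for exact semantics):
/// ```ignore
/// let space = AddressSpace::new();
/// space.map_data(0x1000, &[0xDE, 0xAD], "demo").unwrap();
/// assert_eq!(space.read(0x1000, 2).unwrap(), vec![0xDE, 0xAD]);
/// ```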
#[derive(Debug)]
pub struct AddressSpace {
heap: SharedHeap,
regions: RwLock<Vec<MemoryRegion>>,
statics: StaticFieldStorage,
next_address: AtomicU64,
size: u64,
protection_overrides: RwLock<ImHashMap<u64, MemoryProtection>>,
}
impl AddressSpace {
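    /// Page granularity (4 KiB) used for alignment and protection overrides.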
const PAGE_SIZE: u64 = 0x1000;
#[must_use]
pub fn new() -> Self {
        Self::with_config(64 * 1024 * 1024, 0x1_0000_0000) // 64 MiB heap, 4 GiB address space
    }
#[must_use]
pub fn with_config(heap_size: usize, address_space_size: u64) -> Self {
Self {
heap: SharedHeap::new(heap_size),
regions: RwLock::new(Vec::new()),
statics: StaticFieldStorage::new(),
            next_address: AtomicU64::new(0x1000_0000),
            size: address_space_size,
protection_overrides: RwLock::new(ImHashMap::new()),
}
}
#[must_use]
pub fn with_heap(heap: SharedHeap) -> Self {
Self {
heap,
regions: RwLock::new(Vec::new()),
statics: StaticFieldStorage::new(),
next_address: AtomicU64::new(0x1000_0000),
size: 0x1_0000_0000,
protection_overrides: RwLock::new(ImHashMap::new()),
}
}
#[must_use]
pub fn heap(&self) -> &SharedHeap {
&self.heap
}
#[must_use]
pub fn managed_heap(&self) -> &ManagedHeap {
self.heap.heap()
}
#[must_use]
pub fn statics(&self) -> &StaticFieldStorage {
&self.statics
}
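    /// Maps `region` at a caller-chosen address, failing if it overlaps an
    /// existing mapping. `address` is used for error reporting; the region is
    /// expected to already carry the matching base.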
pub fn map_at(&self, address: u64, region: MemoryRegion) -> Result<()> {
let mut regions = self.regions.write().map_err(|_| {
Error::from(EmulationError::InternalError {
description: "region lock poisoned".to_string(),
})
})?;
for existing in regions.iter() {
            if Self::regions_overlap(existing, &region) {
return Err(EmulationError::InvalidAddress {
address,
reason: "region overlaps with existing mapping".to_string(),
}
.into());
}
}
regions.push(region);
Ok(())
}
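    /// Maps `region` at the next free address, chosen by a monotonic bump
    /// allocator with 4 KiB alignment.
    ///
    /// PE images are rejected here because they carry a preferred base; use
    /// [`Self::map_at`] for those. Note that the bump pointer is not
    /// currently checked against the configured address-space size.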
    pub fn map(&self, region: MemoryRegion) -> Result<u64> {
        // Reject PE images before reserving space, so the bump allocator does
        // not leak an address range on error.
        if region.is_pe_image() {
            return Err(EmulationError::InternalError {
                description: "PE images must use map_at with explicit base address".to_string(),
            }
            .into());
        }
        let size = region.size();
        let aligned_size = (size + 0xFFF) & !0xFFF; // round up to 4 KiB pages
        let base = self
            .next_address
            .fetch_add(aligned_size as u64, Ordering::SeqCst);
        let region = region.with_base(base);
        self.map_at(base, region)?;
        Ok(base)
    }
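    /// Removes the region whose base address is exactly `base`, failing if no
    /// such region exists.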
pub fn unmap(&self, base: u64) -> Result<()> {
let mut regions = self.regions.write().map_err(|_| {
Error::from(EmulationError::InternalError {
description: "region lock poisoned".to_string(),
})
})?;
if let Some(pos) = regions.iter().position(|r| r.base() == base) {
regions.remove(pos);
Ok(())
} else {
Err(EmulationError::InvalidAddress {
address: base,
reason: "no region at this address".to_string(),
}
.into())
}
}
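    /// Reads `len` bytes starting at `address` from the first region that
    /// contains the entire range.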
pub fn read(&self, address: u64, len: usize) -> Result<Vec<u8>> {
let regions = self.regions.read().map_err(|_| {
Error::from(EmulationError::InternalError {
description: "region lock poisoned".to_string(),
})
})?;
for region in regions.iter() {
if region.contains_range(address, len) {
return region.read(address, len).ok_or_else(|| {
EmulationError::InvalidAddress {
address,
reason: "read failed".to_string(),
}
.into()
});
}
}
Err(EmulationError::InvalidAddress {
address,
reason: "address not mapped".to_string(),
}
.into())
}
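    /// Writes `data` at `address` into the first region that contains the
    /// entire range, failing for unmapped or non-writable memory.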
pub fn write(&self, address: u64, data: &[u8]) -> Result<()> {
let regions = self.regions.read().map_err(|_| {
Error::from(EmulationError::InternalError {
description: "region lock poisoned".to_string(),
})
})?;
for region in regions.iter() {
if region.contains_range(address, data.len()) {
if region.write(address, data) {
return Ok(());
}
return Err(EmulationError::InvalidAddress {
address,
reason: "write failed (possibly read-only)".to_string(),
}
.into());
}
}
Err(EmulationError::InvalidAddress {
address,
reason: "address not mapped".to_string(),
}
.into())
}
#[must_use]
pub fn is_valid(&self, address: u64) -> bool {
let Ok(regions) = self.regions.read() else {
return false;
};
regions.iter().any(|r| r.contains(address))
}
#[must_use]
pub fn get_region(&self, address: u64) -> Option<MemoryRegion> {
let regions = self.regions.read().ok()?;
regions.iter().find(|r| r.contains(address)).cloned()
}
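    /// Returns the effective protection at `address`: a page-level override
    /// if one exists, otherwise the protection reported by the containing
    /// region.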
#[must_use]
pub fn get_protection(&self, address: u64) -> Option<MemoryProtection> {
let page_addr = address & !(Self::PAGE_SIZE - 1);
if let Ok(overrides) = self.protection_overrides.read() {
if let Some(&prot) = overrides.get(&page_addr) {
return Some(prot);
}
}
let regions = self.regions.read().ok()?;
regions
.iter()
.find(|r| r.contains(address))
.map(|r| r.protection_at(address))
}
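    /// Overrides the protection of every page touched by
    /// `[address, address + size)` and returns the previous effective
    /// protection of the first page.
    ///
    /// When no override exists yet, the reported old protection of an
    /// executable region is normalized to `READ_EXECUTE`.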
pub fn set_protection(
&self,
address: u64,
size: usize,
new_protection: MemoryProtection,
) -> Option<MemoryProtection> {
let start_page = address & !(Self::PAGE_SIZE - 1);
let old_protection = if let Ok(overrides) = self.protection_overrides.read() {
if overrides.contains_key(&start_page) {
drop(overrides);
self.get_protection(address)?
} else {
drop(overrides);
let region_prot = self.get_protection(address)?;
if region_prot.contains(MemoryProtection::EXECUTE) {
MemoryProtection::READ_EXECUTE
} else {
region_prot
}
}
} else {
self.get_protection(address)?
};
let end_addr = address.saturating_add(size as u64);
let end_page = (end_addr + Self::PAGE_SIZE - 1) & !(Self::PAGE_SIZE - 1);
if let Ok(mut overrides) = self.protection_overrides.write() {
let mut page = start_page;
while page < end_page {
overrides.insert(page, new_protection);
page += Self::PAGE_SIZE;
}
}
Some(old_protection)
}
#[must_use]
pub fn get_static(&self, field_token: Token) -> Option<EmValue> {
self.statics.get(field_token)
}
pub fn set_static(&self, field_token: Token, value: EmValue) {
self.statics.set(field_token, value);
}
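    /// Allocates `size` bytes of unmanaged memory, maps it, and returns the
    /// base address. Pair with [`Self::free_unmanaged`].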
pub fn alloc_unmanaged(&self, size: usize) -> Result<u64> {
let region = MemoryRegion::unmanaged_alloc(0, size);
self.map(region)
}
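    /// Frees a mapping previously created by [`Self::alloc_unmanaged`],
    /// rejecting addresses that belong to any other kind of region.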
pub fn free_unmanaged(&self, address: u64) -> Result<()> {
let regions = self.regions.read().map_err(|_| {
Error::from(EmulationError::InternalError {
description: "region lock poisoned".to_string(),
})
})?;
let is_unmanaged = regions
.iter()
.any(|r| r.base() == address && r.is_unmanaged_alloc());
drop(regions);
if is_unmanaged {
self.unmap(address)
} else {
Err(EmulationError::InvalidAddress {
address,
reason: "not an unmanaged allocation".to_string(),
}
.into())
}
}
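    /// Copies `size` bytes from `src` to `dest` through a temporary buffer,
    /// so overlapping ranges behave like `memmove`.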
pub fn copy_block(&self, dest: u64, src: u64, size: usize) -> Result<()> {
if size == 0 {
return Ok(());
}
let src_data = self.read(src, size)?;
self.write(dest, &src_data)
}
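    /// Fills `size` bytes at `address` with `value` (a `memset` analogue).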
pub fn init_block(&self, address: u64, value: u8, size: usize) -> Result<()> {
if size == 0 {
return Ok(());
}
let data = vec![value; size];
self.write(address, &data)
}
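    /// Maps a PE image at its preferred base and returns that base; no
    /// relocation to an alternative base is performed here.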
pub fn map_pe_image(
&self,
data: &[u8],
preferred_base: u64,
sections: Vec<SectionInfo>,
name: impl Into<String>,
) -> Result<u64> {
let region = MemoryRegion::pe_image(preferred_base, data, sections, name);
self.map_at(preferred_base, region)?;
Ok(preferred_base)
}
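    /// Maps a read-write copy of `data` at `address` under the given label.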
pub fn map_data(&self, address: u64, data: &[u8], label: impl Into<String>) -> Result<()> {
let region = MemoryRegion::mapped_data(address, data, label, MemoryProtection::READ_WRITE);
self.map_at(address, region)
}
#[must_use]
pub fn regions(&self) -> Vec<(u64, usize, String)> {
match self.regions.read() {
Ok(regions) => regions
.iter()
.map(|r| (r.base(), r.size(), r.label().to_string()))
.collect(),
Err(_) => Vec::new(),
}
}
#[must_use]
pub fn mapped_size(&self) -> usize {
match self.regions.read() {
Ok(regions) => regions.iter().map(MemoryRegion::size).sum(),
Err(_) => 0,
}
}
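    /// Half-open interval overlap test over `[base, end)` ranges.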
fn regions_overlap(a: &MemoryRegion, b: &MemoryRegion) -> bool {
let a_start = a.base();
let a_end = a.end();
let b_start = b.base();
let b_end = b.end();
a_start < b_end && b_start < a_end
}
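    /// Allocates a string on the shared managed heap.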
pub fn alloc_string(&self, value: &str) -> Result<HeapRef> {
self.heap.alloc_string(value)
}
pub fn get_string(&self, heap_ref: HeapRef) -> Result<std::sync::Arc<str>> {
self.heap.get_string(heap_ref)
}
pub fn alloc_object(&self, type_token: Token) -> Result<HeapRef> {
self.heap.alloc_object(type_token)
}
pub fn get_field(&self, heap_ref: HeapRef, field_token: Token) -> Result<EmValue> {
self.heap.get_field(heap_ref, field_token)
}
pub fn set_field(&self, heap_ref: HeapRef, field_token: Token, value: EmValue) -> Result<()> {
self.heap.set_field(heap_ref, field_token, value)
}
}
impl Default for AddressSpace {
fn default() -> Self {
Self::new()
}
}
impl Clone for AddressSpace {
fn clone(&self) -> Self {
let regions = match self.regions.read() {
Ok(r) => r.clone(),
Err(_) => Vec::new(),
};
let protection_overrides = match self.protection_overrides.read() {
            Ok(p) => p.clone(),
            Err(_) => ImHashMap::new(),
};
Self {
            heap: self.heap.clone(),
            regions: RwLock::new(regions),
statics: self.statics.clone(),
next_address: AtomicU64::new(self.next_address.load(Ordering::SeqCst)),
size: self.size,
protection_overrides: RwLock::new(protection_overrides),
}
}
}
impl AddressSpace {
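    /// Creates a sibling address space that keeps this one's region layout
    /// but starts with a fresh default heap, empty statics, and no protection
    /// overrides. Regions are cloned, so backing memory may still be shared
    /// depending on `MemoryRegion`'s clone semantics.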
#[must_use]
pub fn spawn_fresh(&self) -> Self {
let regions = match self.regions.read() {
Ok(r) => r.clone(),
Err(_) => Vec::new(),
};
Self {
heap: SharedHeap::default(),
regions: RwLock::new(regions),
statics: StaticFieldStorage::new(),
next_address: AtomicU64::new(self.next_address.load(Ordering::SeqCst)),
size: self.size,
protection_overrides: RwLock::new(ImHashMap::new()),
}
}
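    /// Creates a fully isolated copy: regions, heap, statics, and protection
    /// overrides are all forked so writes on either side stay invisible to
    /// the other. Regions whose `fork` fails are silently dropped from the
    /// copy.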
#[must_use]
pub fn fork(&self) -> Self {
let regions = match self.regions.read() {
Ok(r) => r.iter().filter_map(|region| region.fork().ok()).collect(),
Err(_) => Vec::new(),
};
let protection_overrides = match self.protection_overrides.read() {
Ok(p) => p.clone(),
Err(_) => ImHashMap::new(),
};
Self {
heap: self.heap.fork(),
regions: RwLock::new(regions),
statics: self.statics.fork(),
next_address: AtomicU64::new(self.next_address.load(Ordering::SeqCst)),
size: self.size,
protection_overrides: RwLock::new(protection_overrides),
}
}
}
#[cfg(test)]
mod tests {
use crate::{
emulation::{
memory::{
addressspace::{AddressSpace, SharedHeap},
region::MemoryProtection,
},
EmValue,
},
metadata::token::Token,
};
#[test]
fn test_address_space_creation() {
let space = AddressSpace::new();
assert!(space.regions().is_empty());
}
#[test]
fn test_map_and_read_data() {
let space = AddressSpace::new();
let data = vec![0xDE, 0xAD, 0xBE, 0xEF];
space.map_data(0x1000, &data, "test").unwrap();
let read = space.read(0x1000, 4).unwrap();
assert_eq!(read, data);
}
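    #[test]
    fn test_map_at_rejects_overlap() {
        // A sketch based on the half-open overlap check in `regions_overlap`:
        // a second mapping inside an existing range must be refused.
        let space = AddressSpace::new();
        space.map_data(0x1000, &[0u8; 0x1000], "a").unwrap();
        assert!(space.map_data(0x1800, &[0u8; 0x100], "b").is_err());
    }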
#[test]
fn test_write_data() {
let space = AddressSpace::new();
space.map_data(0x1000, &[0u8; 16], "test").unwrap();
space.write(0x1000, &[0xCA, 0xFE]).unwrap();
let read = space.read(0x1000, 2).unwrap();
assert_eq!(read, vec![0xCA, 0xFE]);
}
#[test]
fn test_static_fields() {
let space = AddressSpace::new();
let field = Token::new(0x04000001);
assert!(space.get_static(field).is_none());
space.set_static(field, EmValue::I32(42));
assert_eq!(space.get_static(field), Some(EmValue::I32(42)));
}
#[test]
fn test_shared_heap() {
let space1 = AddressSpace::new();
let str_ref = space1.alloc_string("Hello").unwrap();
let space2 = space1.clone();
let s1 = space1.get_string(str_ref).unwrap();
let s2 = space2.get_string(str_ref).unwrap();
assert_eq!(&*s1, "Hello");
assert_eq!(&*s2, "Hello");
let str_ref2 = space2.alloc_string("World").unwrap();
let s3 = space1.get_string(str_ref2).unwrap();
assert_eq!(&*s3, "World");
}
#[test]
fn test_unmanaged_alloc() {
let space = AddressSpace::new();
let addr = space.alloc_unmanaged(256).unwrap();
assert!(space.is_valid(addr));
space.write(addr, &[1, 2, 3, 4]).unwrap();
let data = space.read(addr, 4).unwrap();
assert_eq!(data, vec![1, 2, 3, 4]);
space.free_unmanaged(addr).unwrap();
assert!(!space.is_valid(addr));
}
#[test]
fn test_heap_delegation() {
let space = AddressSpace::new();
let str_ref = space.alloc_string("Test").unwrap();
let s = space.get_string(str_ref).unwrap();
assert_eq!(&*s, "Test");
let type_token = Token::new(0x02000001);
let field_token = Token::new(0x04000001);
let obj_ref = space.alloc_object(type_token).unwrap();
space
.set_field(obj_ref, field_token, EmValue::I32(100))
.unwrap();
let value = space.get_field(obj_ref, field_token).unwrap();
assert_eq!(value, EmValue::I32(100));
}
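    #[test]
    fn test_copy_and_init_block() {
        // Exercises the memset/memmove analogues; assumes `map_data` backs
        // the range with writable memory, as in the tests above.
        let space = AddressSpace::new();
        space.map_data(0x1000, &[0u8; 16], "test").unwrap();
        space.init_block(0x1000, 0xAA, 4).unwrap();
        assert_eq!(space.read(0x1000, 4).unwrap(), vec![0xAA; 4]);
        space.copy_block(0x1008, 0x1000, 4).unwrap();
        assert_eq!(space.read(0x1008, 4).unwrap(), vec![0xAA; 4]);
    }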
#[test]
fn test_fork_memory_isolation() {
let space = AddressSpace::new();
space.map_data(0x1000, &[1, 2, 3, 4], "test").unwrap();
let forked = space.fork();
assert_eq!(space.read(0x1000, 4).unwrap(), vec![1, 2, 3, 4]);
assert_eq!(forked.read(0x1000, 4).unwrap(), vec![1, 2, 3, 4]);
forked.write(0x1000, &[0xFF, 0xFE]).unwrap();
assert_eq!(space.read(0x1000, 4).unwrap(), vec![1, 2, 3, 4]);
assert_eq!(forked.read(0x1000, 4).unwrap(), vec![0xFF, 0xFE, 3, 4]);
}
#[test]
fn test_fork_heap_isolation() {
let space = AddressSpace::new();
let str_ref = space.alloc_string("Original").unwrap();
let forked = space.fork();
assert_eq!(&*space.get_string(str_ref).unwrap(), "Original");
assert_eq!(&*forked.get_string(str_ref).unwrap(), "Original");
let new_ref = forked.alloc_string("Forked").unwrap();
assert_eq!(&*forked.get_string(new_ref).unwrap(), "Forked");
assert!(space.get_string(new_ref).is_err());
}
#[test]
fn test_fork_statics_isolation() {
let space = AddressSpace::new();
let field = Token::new(0x04000001);
space.set_static(field, EmValue::I32(42));
let forked = space.fork();
assert_eq!(space.get_static(field), Some(EmValue::I32(42)));
assert_eq!(forked.get_static(field), Some(EmValue::I32(42)));
forked.set_static(field, EmValue::I32(100));
assert_eq!(space.get_static(field), Some(EmValue::I32(42)));
assert_eq!(forked.get_static(field), Some(EmValue::I32(100)));
}
#[test]
fn test_fork_protection_isolation() {
let space = AddressSpace::new();
        space.map_data(0x1000, &[0u8; 0x2000], "test").unwrap();
space.set_protection(0x1000, 0x1000, MemoryProtection::READ_EXECUTE);
let forked = space.fork();
assert_eq!(
space.get_protection(0x1000),
Some(MemoryProtection::READ_EXECUTE)
);
assert_eq!(
forked.get_protection(0x1000),
Some(MemoryProtection::READ_EXECUTE)
);
forked.set_protection(0x1000, 0x1000, MemoryProtection::READ_WRITE);
assert_eq!(
space.get_protection(0x1000),
Some(MemoryProtection::READ_EXECUTE)
);
assert_eq!(
forked.get_protection(0x1000),
Some(MemoryProtection::READ_WRITE)
);
}
#[test]
fn test_multiple_forks_isolation() {
let space = AddressSpace::new();
let field = Token::new(0x04000001);
space.set_static(field, EmValue::I32(1));
let fork1 = space.fork();
let fork2 = space.fork();
fork1.set_static(field, EmValue::I32(10));
fork2.set_static(field, EmValue::I32(20));
assert_eq!(space.get_static(field), Some(EmValue::I32(1)));
assert_eq!(fork1.get_static(field), Some(EmValue::I32(10)));
assert_eq!(fork2.get_static(field), Some(EmValue::I32(20)));
}
#[test]
fn test_shared_heap_fork() {
let heap = SharedHeap::new(1024 * 1024);
let str_ref = heap.alloc_string("Hello").unwrap();
let forked = heap.fork();
assert_eq!(&*heap.get_string(str_ref).unwrap(), "Hello");
assert_eq!(&*forked.get_string(str_ref).unwrap(), "Hello");
let new_ref = forked.alloc_string("World").unwrap();
assert_eq!(&*forked.get_string(new_ref).unwrap(), "World");
assert!(heap.get_string(new_ref).is_err());
}
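    #[test]
    fn test_shared_heap_ref_count() {
        // Pure `Arc` bookkeeping: cloning a handle bumps the strong count.
        let heap = SharedHeap::new(1024);
        assert!(heap.is_unique());
        let clone = heap.clone();
        assert_eq!(heap.ref_count(), 2);
        drop(clone);
        assert!(heap.is_unique());
    }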
}