use std::{
collections::HashMap,
fmt,
sync::{
atomic::{AtomicU64, AtomicUsize, Ordering},
Arc, RwLock,
},
};
use imbl::HashMap as ImHashMap;
use crate::{
emulation::{engine::EmulationError, EmValue, HeapRef},
metadata::{
token::Token,
typesystem::{CilFlavor, PointerSize},
},
Result,
};
/// `(algorithm_type, key, iv)` as returned by
/// `ManagedHeap::get_symmetric_algorithm_info`; `key` and `iv` stay `None`
/// until they are explicitly set on the algorithm object.
pub type SymmetricAlgorithmInfo = (Arc<str>, Option<Vec<u8>>, Option<Vec<u8>>);
/// `(algorithm, key, iv, is_encryptor)` as returned by
/// `ManagedHeap::get_crypto_transform_info`.
pub type CryptoTransformInfo = (Arc<str>, Vec<u8>, Vec<u8>, bool);
/// `(password, salt, iterations, hash_algorithm)` as returned by
/// `ManagedHeap::get_key_derivation_params`.
pub type KeyDerivationInfo = (Vec<u8>, Vec<u8>, u32, Arc<str>);
/// Iterator over a snapshot of a heap's object ids, yielding each object
/// cloned at the moment it is reached.
///
/// The id list is captured when the iterator is created; each object is looked
/// up lazily, and the heap's read lock is held only for the duration of a
/// single `next` call.
pub struct HeapIter<'a> {
    /// Heap the ids were taken from.
    heap: &'a ManagedHeap,
    /// Ids that have not been visited yet.
    keys: std::vec::IntoIter<u64>,
}

impl Iterator for HeapIter<'_> {
    type Item = (HeapRef, HeapObject);

    /// Returns the next object whose id is still present in the heap.
    ///
    /// Ids whose objects vanished after the snapshot was taken (e.g. after the
    /// heap was cleared) are skipped instead of terminating iteration, so the
    /// remaining live objects are still visited.
    fn next(&mut self) -> Option<Self::Item> {
        let state = self.heap.state.read().expect("heap lock poisoned");
        self.keys.by_ref().find_map(|id| {
            state
                .objects
                .get(&id)
                .map(|obj| (HeapRef::new(id), obj.clone()))
        })
    }

    /// Reports the number of ids left in the snapshot.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.keys.size_hint()
    }
}

/// `len()` is the number of ids remaining in the snapshot. If objects are
/// removed from the heap while iterating, fewer items than that may be yielded.
impl ExactSizeIterator for HeapIter<'_> {}
/// Text encoding represented by an emulated encoding object.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncodingType {
    /// UTF-8.
    Utf8,
    /// ASCII.
    Ascii,
    /// UTF-16, little-endian.
    Utf16Le,
    /// UTF-16, big-endian.
    Utf16Be,
    /// UTF-32 (byte order is not specified by this enum).
    Utf32,
}
/// A value stored on the emulated managed heap.
///
/// [`HeapObject::kind`] gives a short variant name used in type-mismatch
/// errors, and [`HeapObject::estimated_size`] the byte count charged against
/// the heap limit.
#[derive(Clone, Debug)]
pub enum HeapObject {
    /// An immutable string.
    String(Arc<str>),
    /// A single-dimensional array.
    Array {
        /// Declared element type.
        element_type: CilFlavor,
        /// Element values in index order.
        elements: Vec<EmValue>,
    },
    /// A multi-dimensional array whose elements are kept in one flat vector.
    MultiArray {
        /// Declared element type.
        element_type: CilFlavor,
        /// Length of each dimension.
        dimensions: Vec<usize>,
        /// All elements; when allocated by the heap the length is the product
        /// of `dimensions`.
        elements: Vec<EmValue>,
    },
    /// A general object: its type token plus field values keyed by field token.
    Object {
        type_token: Token,
        fields: HashMap<Token, EmValue>,
    },
    /// A boxed value together with the token of its type.
    BoxedValue {
        type_token: Token,
        value: Box<EmValue>,
    },
    /// A delegate bound to `method_token`.
    Delegate {
        type_token: Token,
        /// Bound instance, if any (presumably `None` for static methods — confirm).
        target: Option<HeapRef>,
        method_token: Token,
    },
    /// An encoding object.
    Encoding {
        encoding_type: EncodingType,
    },
    /// A cryptographic algorithm object identified only by its name.
    CryptoAlgorithm {
        algorithm_type: Arc<str>,
    },
    /// A symmetric cipher object; `key` and `iv` are `None` until set.
    SymmetricAlgorithm {
        algorithm_type: Arc<str>,
        key: Option<Vec<u8>>,
        iv: Option<Vec<u8>>,
    },
    /// A cipher transform created from a key and IV.
    CryptoTransform {
        algorithm: Arc<str>,
        key: Vec<u8>,
        iv: Vec<u8>,
        /// `true` for an encryptor, `false` for a decryptor.
        is_encryptor: bool,
    },
    /// A reflection handle for the method identified by `method_token`.
    ReflectionMethod {
        method_token: Token,
    },
    /// Key-derivation parameters: password, salt, iteration count, and hash name.
    KeyDerivation {
        password: Vec<u8>,
        salt: Vec<u8>,
        iterations: u32,
        hash_algorithm: Arc<str>,
    },
    /// An in-memory byte stream.
    Stream {
        data: Vec<u8>,
        /// Read/write cursor; may point past the end of `data`.
        position: usize,
    },
    /// A stream that applies `transform` to `underlying_stream`.
    CryptoStream {
        underlying_stream: HeapRef,
        transform: HeapRef,
        /// `0` is displayed as read mode; any other value as write mode.
        mode: u8,
        /// Cached transformed output; `None` until it has been computed and stored.
        transformed_data: Option<Vec<u8>>,
        /// Read cursor into `transformed_data`.
        transformed_pos: usize,
        /// Bytes written to the stream that have not yet been flushed.
        write_buffer: Vec<u8>,
    },
}
impl HeapObject {
#[must_use]
pub fn kind(&self) -> &'static str {
match self {
HeapObject::String(_) => "string",
HeapObject::Array { .. } => "array",
HeapObject::MultiArray { .. } => "multi-dimensional array",
HeapObject::Object { .. } => "object",
HeapObject::BoxedValue { .. } => "boxed value",
HeapObject::Delegate { .. } => "delegate",
HeapObject::Encoding { .. } => "encoding",
HeapObject::CryptoAlgorithm { .. } => "crypto algorithm",
HeapObject::SymmetricAlgorithm { .. } => "symmetric algorithm",
HeapObject::CryptoTransform { .. } => "crypto transform",
HeapObject::ReflectionMethod { .. } => "reflection method",
HeapObject::KeyDerivation { .. } => "key derivation",
HeapObject::Stream { .. } => "stream",
HeapObject::CryptoStream { .. } => "crypto stream",
}
}
#[must_use]
pub fn estimated_size(&self) -> usize {
match self {
HeapObject::String(s) => 24 + s.len() * 2, HeapObject::Array { elements, .. } => 24 + elements.len() * 8,
HeapObject::MultiArray { elements, .. } => 32 + elements.len() * 8,
HeapObject::Object { fields, .. } => 24 + fields.len() * 16,
HeapObject::BoxedValue { .. }
| HeapObject::CryptoAlgorithm { .. }
| HeapObject::ReflectionMethod { .. } => 32,
HeapObject::CryptoTransform { key, iv, .. } => 48 + key.len() + iv.len(),
HeapObject::Delegate { .. } => 48,
HeapObject::Encoding { .. } => 24,
HeapObject::SymmetricAlgorithm { key, iv, .. } => {
32 + key.as_ref().map_or(0, Vec::len) + iv.as_ref().map_or(0, Vec::len)
}
HeapObject::KeyDerivation { password, salt, .. } => 48 + password.len() + salt.len(),
HeapObject::Stream { data, .. } => 32 + data.len(),
HeapObject::CryptoStream {
transformed_data,
write_buffer,
..
} => 64 + transformed_data.as_ref().map_or(0, Vec::len) + write_buffer.len(),
}
}
}
impl fmt::Display for HeapObject {
    /// Formats a short, human-readable summary of the object.
    ///
    /// Strings longer than 50 bytes print at most their first 47 bytes, rounded
    /// down to a UTF-8 character boundary, followed by `...`. Containers, streams
    /// and crypto objects print only their shape and metadata.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HeapObject::String(s) => {
                if s.len() > 50 {
                    // A fixed byte offset can land inside a multi-byte character,
                    // and slicing there panics; back off to the nearest boundary.
                    let cut = (0..=47)
                        .rev()
                        .find(|&i| s.is_char_boundary(i))
                        .unwrap_or(0);
                    write!(f, "\"{}...\"", &s[..cut])
                } else {
                    write!(f, "\"{s}\"")
                }
            }
            HeapObject::Array {
                element_type,
                elements,
            } => {
                write!(f, "{:?}[{}]", element_type, elements.len())
            }
            HeapObject::MultiArray {
                element_type,
                dimensions,
                ..
            } => {
                write!(f, "{element_type:?}[")?;
                for (i, dim) in dimensions.iter().enumerate() {
                    if i > 0 {
                        write!(f, ",")?;
                    }
                    write!(f, "{dim}")?;
                }
                write!(f, "]")
            }
            HeapObject::Object { type_token, .. } => {
                write!(f, "object({type_token})")
            }
            HeapObject::BoxedValue { type_token, value } => {
                write!(f, "boxed({type_token}, {value})")
            }
            HeapObject::Delegate {
                type_token,
                method_token,
                ..
            } => {
                write!(f, "delegate({type_token}, {method_token})")
            }
            HeapObject::Encoding { encoding_type } => {
                write!(f, "encoding({encoding_type:?})")
            }
            HeapObject::CryptoAlgorithm { algorithm_type } => {
                write!(f, "crypto_algorithm({algorithm_type})")
            }
            HeapObject::SymmetricAlgorithm { algorithm_type, .. } => {
                write!(f, "symmetric_algorithm({algorithm_type})")
            }
            HeapObject::CryptoTransform {
                algorithm,
                is_encryptor,
                key,
                ..
            } => {
                let mode = if *is_encryptor { "encrypt" } else { "decrypt" };
                write!(
                    f,
                    "crypto_transform({} {} key={}B)",
                    algorithm,
                    mode,
                    key.len()
                )
            }
            HeapObject::ReflectionMethod { method_token } => {
                write!(f, "reflection_method(0x{:08x})", method_token.value())
            }
            HeapObject::KeyDerivation {
                hash_algorithm,
                iterations,
                ..
            } => {
                write!(
                    f,
                    "key_derivation({hash_algorithm}, {iterations} iterations)"
                )
            }
            HeapObject::Stream { data, position } => {
                write!(f, "stream({} bytes, pos={})", data.len(), position)
            }
            HeapObject::CryptoStream {
                mode,
                transformed_data,
                write_buffer,
                ..
            } => {
                let mode_str = if *mode == 0 { "Read" } else { "Write" };
                let cached = transformed_data.as_ref().map_or(0, Vec::len);
                let buffered = write_buffer.len();
                write!(
                    f,
                    "crypto_stream(mode={mode_str}, cached={cached}, buffered={buffered})"
                )
            }
        }
    }
}
/// Lock-protected interior of [`ManagedHeap`].
#[derive(Clone, Debug)]
struct HeapState {
    /// Live objects keyed by their heap id.
    objects: ImHashMap<u64, HeapObject>,
}
/// Emulated managed heap with a byte limit enforced at allocation time.
///
/// Objects live in an `imbl` persistent map behind a `RwLock`, so
/// [`ManagedHeap::fork`] can clone the table without copying every object.
/// All sizes are estimates from [`HeapObject::estimated_size`], not real host memory.
#[derive(Debug)]
pub struct ManagedHeap {
    /// Object table, guarded for shared access.
    state: RwLock<HeapState>,
    /// Id handed out by the next allocation.
    next_id: AtomicU64,
    /// Running total of estimated object sizes, in bytes.
    current_size: AtomicUsize,
    /// Limit that allocations may not push `current_size` past, in bytes.
    max_size: usize,
}
impl ManagedHeap {
#[must_use]
pub fn new(max_size: usize) -> Self {
ManagedHeap {
state: RwLock::new(HeapState {
objects: ImHashMap::new(),
}),
next_id: AtomicU64::new(1),
current_size: AtomicUsize::new(0),
max_size,
}
}
#[must_use]
pub fn default_size() -> Self {
Self::new(64 * 1024 * 1024)
}
fn check_allocation(&self, size: usize) -> Result<()> {
let current = self.current_size.load(Ordering::Relaxed);
if current + size > self.max_size {
return Err(EmulationError::HeapMemoryLimitExceeded {
current,
limit: self.max_size,
}
.into());
}
Ok(())
}
fn alloc_object_internal(&self, obj: HeapObject) -> Result<HeapRef> {
let size = obj.estimated_size();
self.check_allocation(size)?;
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let heap_ref = HeapRef::new(id);
let mut state = self.state.write().expect("heap lock poisoned");
state.objects.insert(heap_ref.id(), obj);
self.current_size.fetch_add(size, Ordering::Relaxed);
Ok(heap_ref)
}
#[must_use]
pub fn fork(&self) -> Self {
let state = self.state.read().expect("heap lock poisoned");
ManagedHeap {
state: RwLock::new(HeapState {
objects: state.objects.clone(),
}),
next_id: AtomicU64::new(self.next_id.load(Ordering::SeqCst)),
current_size: AtomicUsize::new(self.current_size.load(Ordering::Relaxed)),
max_size: self.max_size,
}
}
pub fn alloc_string(&self, value: &str) -> Result<HeapRef> {
let arc_str: Arc<str> = value.into();
self.alloc_object_internal(HeapObject::String(arc_str))
}
pub fn alloc_array(&self, element_type: CilFlavor, length: usize) -> Result<HeapRef> {
let elements = vec![EmValue::default_for_flavor(&element_type); length];
self.alloc_object_internal(HeapObject::Array {
element_type,
elements,
})
}
pub fn alloc_array_with_values(
&self,
element_type: CilFlavor,
elements: Vec<EmValue>,
) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::Array {
element_type,
elements,
})
}
pub fn alloc_multi_array(
&self,
element_type: CilFlavor,
dimensions: Vec<usize>,
) -> Result<HeapRef> {
let total_elements: usize = dimensions.iter().product();
let elements = vec![EmValue::default_for_flavor(&element_type); total_elements];
self.alloc_object_internal(HeapObject::MultiArray {
element_type,
dimensions,
elements,
})
}
pub fn alloc_object_with_fields(
&self,
type_token: Token,
field_types: &[(Token, CilFlavor)],
) -> Result<HeapRef> {
let mut fields = HashMap::new();
for (token, cil_flavor) in field_types {
fields.insert(*token, EmValue::default_for_flavor(cil_flavor));
}
self.alloc_object_internal(HeapObject::Object { type_token, fields })
}
pub fn alloc_object(&self, type_token: Token) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::Object {
type_token,
fields: HashMap::new(),
})
}
pub fn alloc_boxed(&self, type_token: Token, value: EmValue) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::BoxedValue {
type_token,
value: Box::new(value),
})
}
pub fn alloc_delegate(
&self,
type_token: Token,
target: Option<HeapRef>,
method_token: Token,
) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::Delegate {
type_token,
target,
method_token,
})
}
pub fn get(&self, heap_ref: HeapRef) -> Result<HeapObject> {
let state = self.state.read().expect("heap lock poisoned");
state.objects.get(&heap_ref.id()).cloned().ok_or(
EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into(),
)
}
pub fn with_object_mut<F, R>(&self, heap_ref: HeapRef, f: F) -> Result<R>
where
F: FnOnce(&mut HeapObject) -> Result<R>,
{
let mut state = self.state.write().expect("heap lock poisoned");
let obj =
state
.objects
.get_mut(&heap_ref.id())
.ok_or(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
})?;
f(obj)
}
pub fn get_string(&self, heap_ref: HeapRef) -> Result<Arc<str>> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id()) {
Some(HeapObject::String(s)) => Ok(Arc::clone(s)),
Some(other) => Err(EmulationError::HeapTypeMismatch {
expected: "string",
found: other.kind(),
}
.into()),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn get_array_element(&self, heap_ref: HeapRef, index: usize) -> Result<EmValue> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id()) {
Some(HeapObject::Array { elements, .. }) => {
if index >= elements.len() {
Err(EmulationError::ArrayIndexOutOfBounds {
index: i64::try_from(index).unwrap_or(i64::MAX),
length: elements.len(),
}
.into())
} else {
Ok(elements[index].clone())
}
}
Some(other) => Err(EmulationError::HeapTypeMismatch {
expected: "array",
found: other.kind(),
}
.into()),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn set_array_element(&self, heap_ref: HeapRef, index: usize, value: EmValue) -> Result<()> {
let mut state = self.state.write().expect("heap lock poisoned");
match state.objects.get_mut(&heap_ref.id()) {
Some(HeapObject::Array { elements, .. }) => {
if index >= elements.len() {
Err(EmulationError::ArrayIndexOutOfBounds {
index: i64::try_from(index).unwrap_or(i64::MAX),
length: elements.len(),
}
.into())
} else {
elements[index] = value;
Ok(())
}
}
Some(other) => Err(EmulationError::HeapTypeMismatch {
expected: "array",
found: other.kind(),
}
.into()),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn get_array_length(&self, heap_ref: HeapRef) -> Result<usize> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id()) {
Some(HeapObject::Array { elements, .. }) => Ok(elements.len()),
Some(HeapObject::MultiArray { dimensions, .. }) => Ok(dimensions.iter().product()),
Some(other) => Err(EmulationError::HeapTypeMismatch {
expected: "array",
found: other.kind(),
}
.into()),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn get_array_element_type(&self, heap_ref: HeapRef) -> Result<CilFlavor> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id()) {
Some(
HeapObject::Array { element_type, .. }
| HeapObject::MultiArray { element_type, .. },
) => Ok(element_type.clone()),
Some(other) => Err(EmulationError::HeapTypeMismatch {
expected: "array",
found: other.kind(),
}
.into()),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn get_field(&self, heap_ref: HeapRef, field_token: Token) -> Result<EmValue> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id()) {
Some(HeapObject::Object { fields, .. }) => fields
.get(&field_token)
.cloned()
.ok_or(EmulationError::FieldNotFound { token: field_token }.into()),
Some(other) => Err(EmulationError::HeapTypeMismatch {
expected: "object",
found: other.kind(),
}
.into()),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn set_field(&self, heap_ref: HeapRef, field_token: Token, value: EmValue) -> Result<()> {
let mut state = self.state.write().expect("heap lock poisoned");
match state.objects.get_mut(&heap_ref.id()) {
Some(HeapObject::Object { fields, .. }) => {
fields.insert(field_token, value);
Ok(())
}
Some(other) => Err(EmulationError::HeapTypeMismatch {
expected: "object",
found: other.kind(),
}
.into()),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn get_type_token(&self, heap_ref: HeapRef) -> Result<Token> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id()) {
Some(
HeapObject::Object { type_token, .. }
| HeapObject::BoxedValue { type_token, .. }
| HeapObject::Delegate { type_token, .. },
) => Ok(*type_token),
Some(HeapObject::String(_)) => Ok(Token::new(0x0100_0001)), Some(HeapObject::Array { .. } | HeapObject::MultiArray { .. }) => {
Ok(Token::new(0x0100_0002))
}
Some(HeapObject::Encoding { .. }) => Ok(Token::new(0x0100_0003)), Some(HeapObject::CryptoAlgorithm { .. }) => Ok(Token::new(0x0100_0004)),
Some(HeapObject::SymmetricAlgorithm { .. }) => Ok(Token::new(0x0100_0005)),
Some(HeapObject::CryptoTransform { .. }) => Ok(Token::new(0x0100_0006)),
Some(HeapObject::ReflectionMethod { method_token }) => Ok(*method_token),
Some(HeapObject::KeyDerivation { .. }) => Ok(Token::new(0x0100_0007)),
Some(HeapObject::Stream { .. }) => Ok(Token::new(0x0100_0008)),
Some(HeapObject::CryptoStream { .. }) => Ok(Token::new(0x0100_0009)),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn unbox(&self, heap_ref: HeapRef) -> Result<EmValue> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id()) {
Some(HeapObject::BoxedValue { value, .. }) => Ok((**value).clone()),
Some(other) => Err(EmulationError::HeapTypeMismatch {
expected: "boxed value",
found: other.kind(),
}
.into()),
None => Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into()),
}
}
pub fn get_boxed_value(&self, heap_ref: HeapRef) -> Result<EmValue> {
self.unbox(heap_ref)
}
#[must_use]
pub fn contains(&self, heap_ref: HeapRef) -> bool {
let state = self.state.read().expect("heap lock poisoned");
state.objects.contains_key(&heap_ref.id())
}
#[must_use]
pub fn current_size(&self) -> usize {
self.current_size.load(Ordering::Relaxed)
}
#[must_use]
pub fn max_size(&self) -> usize {
self.max_size
}
#[must_use]
pub fn object_count(&self) -> usize {
let state = self.state.read().expect("heap lock poisoned");
state.objects.len()
}
pub fn clear(&self) {
let mut state = self.state.write().expect("heap lock poisoned");
state.objects.clear();
self.current_size.store(0, Ordering::Relaxed);
}
#[must_use]
pub fn to_vec(&self) -> Vec<(HeapRef, HeapObject)> {
let state = self.state.read().expect("heap lock poisoned");
state
.objects
.iter()
.map(|(&id, obj)| (HeapRef::new(id), obj.clone()))
.collect()
}
pub fn iter(&self) -> HeapIter<'_> {
let keys: Vec<u64> = {
let state = self.state.read().expect("heap lock poisoned");
state.objects.keys().copied().collect()
};
HeapIter {
heap: self,
keys: keys.into_iter(),
}
}
#[must_use]
pub fn object_count_estimate(&self) -> usize {
let state = self.state.read().expect("heap lock poisoned");
state.objects.len()
}
pub fn alloc_byte_array(&self, data: &[u8]) -> Result<HeapRef> {
let elements: Vec<EmValue> = data.iter().map(|&b| EmValue::I32(i32::from(b))).collect();
self.alloc_array_with_values(CilFlavor::U1, elements)
}
#[must_use]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
pub fn get_byte_array(&self, heap_ref: HeapRef) -> Option<Vec<u8>> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::Array { elements, .. } => {
let mut bytes = Vec::with_capacity(elements.len());
for e in elements {
match e {
EmValue::I32(n) => bytes.push(*n as u8),
_ => return None, }
}
Some(bytes)
}
_ => None,
}
}
#[must_use]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
pub fn get_array_as_bytes(&self, heap_ref: HeapRef, ptr_size: PointerSize) -> Option<Vec<u8>> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::Array {
elements,
element_type,
} => {
let element_size = element_type.element_size(ptr_size)?;
let mut bytes = Vec::with_capacity(elements.len() * element_size);
for e in elements {
match e {
EmValue::I32(n) => match element_size {
2 => bytes.extend_from_slice(&(*n as i16).to_le_bytes()),
4 => bytes.extend_from_slice(&n.to_le_bytes()),
_ => bytes.push(*n as u8),
},
EmValue::I64(n) => {
bytes.extend_from_slice(&n.to_le_bytes());
}
EmValue::F32(f) => {
bytes.extend_from_slice(&f.to_le_bytes());
}
EmValue::F64(f) => {
bytes.extend_from_slice(&f.to_le_bytes());
}
_ => return None, }
}
Some(bytes)
}
_ => None,
}
}
#[must_use]
pub fn get_string_opt(&self, heap_ref: HeapRef) -> Option<Arc<str>> {
self.get_string(heap_ref).ok()
}
pub fn alloc_encoding(&self, encoding_type: EncodingType) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::Encoding { encoding_type })
}
#[must_use]
pub fn get_encoding_type(&self, heap_ref: HeapRef) -> Option<EncodingType> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::Encoding { encoding_type } => Some(*encoding_type),
_ => None,
}
}
pub fn alloc_crypto_algorithm(&self, algorithm_type: &str) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::CryptoAlgorithm {
algorithm_type: algorithm_type.into(),
})
}
#[must_use]
pub fn get_crypto_algorithm_type(&self, heap_ref: HeapRef) -> Option<Arc<str>> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::CryptoAlgorithm { algorithm_type } => Some(Arc::clone(algorithm_type)),
_ => None,
}
}
pub fn alloc_symmetric_algorithm(&self, algorithm_type: &str) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::SymmetricAlgorithm {
algorithm_type: algorithm_type.into(),
key: None,
iv: None,
})
}
pub fn set_symmetric_key(&self, heap_ref: HeapRef, key: Vec<u8>) -> Result<()> {
let mut state = self.state.write().expect("heap lock poisoned");
if let Some(HeapObject::SymmetricAlgorithm { key: key_slot, .. }) =
state.objects.get_mut(&heap_ref.id())
{
*key_slot = Some(key);
return Ok(());
}
Err(EmulationError::HeapTypeMismatch {
expected: "SymmetricAlgorithm",
found: "other",
}
.into())
}
pub fn set_symmetric_iv(&self, heap_ref: HeapRef, iv: Vec<u8>) -> Result<()> {
let mut state = self.state.write().expect("heap lock poisoned");
if let Some(HeapObject::SymmetricAlgorithm { iv: iv_slot, .. }) =
state.objects.get_mut(&heap_ref.id())
{
*iv_slot = Some(iv);
return Ok(());
}
Err(EmulationError::HeapTypeMismatch {
expected: "SymmetricAlgorithm",
found: "other",
}
.into())
}
#[must_use]
pub fn get_symmetric_algorithm_info(
&self,
heap_ref: HeapRef,
) -> Option<SymmetricAlgorithmInfo> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::SymmetricAlgorithm {
algorithm_type,
key,
iv,
} => Some((algorithm_type.clone(), key.clone(), iv.clone())),
_ => None,
}
}
pub fn alloc_crypto_transform(
&self,
algorithm: &str,
key: Vec<u8>,
iv: Vec<u8>,
is_encryptor: bool,
) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::CryptoTransform {
algorithm: algorithm.into(),
key,
iv,
is_encryptor,
})
}
#[must_use]
pub fn get_crypto_transform_info(&self, heap_ref: HeapRef) -> Option<CryptoTransformInfo> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::CryptoTransform {
algorithm,
key,
iv,
is_encryptor,
} => Some((algorithm.clone(), key.clone(), iv.clone(), *is_encryptor)),
_ => None,
}
}
pub fn alloc_reflection_method(&self, method_token: Token) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::ReflectionMethod { method_token })
}
#[must_use]
pub fn get_reflection_method_token(&self, heap_ref: HeapRef) -> Option<Token> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::ReflectionMethod { method_token } => Some(*method_token),
_ => None,
}
}
pub fn replace_with_key_derivation(
&self,
heap_ref: HeapRef,
password: Vec<u8>,
salt: Vec<u8>,
iterations: u32,
hash_algorithm: &str,
) -> Result<()> {
let mut state = self.state.write().expect("heap lock poisoned");
if state.objects.contains_key(&heap_ref.id()) {
state.objects.insert(
heap_ref.id(),
HeapObject::KeyDerivation {
password,
salt,
iterations,
hash_algorithm: hash_algorithm.into(),
},
);
Ok(())
} else {
Err(EmulationError::InvalidHeapReference {
reference_id: heap_ref.id(),
}
.into())
}
}
pub fn alloc_key_derivation(
&self,
password: Vec<u8>,
salt: Vec<u8>,
iterations: u32,
hash_algorithm: &str,
) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::KeyDerivation {
password,
salt,
iterations,
hash_algorithm: hash_algorithm.into(),
})
}
#[must_use]
pub fn get_key_derivation_params(&self, heap_ref: HeapRef) -> Option<KeyDerivationInfo> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::KeyDerivation {
password,
salt,
iterations,
hash_algorithm,
} => Some((
password.clone(),
salt.clone(),
*iterations,
hash_algorithm.clone(),
)),
_ => None,
}
}
pub fn alloc_stream(&self, data: Vec<u8>) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::Stream { data, position: 0 })
}
#[must_use]
pub fn get_stream_data(&self, heap_ref: HeapRef) -> Option<(Vec<u8>, usize)> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::Stream { data, position } => Some((data.clone(), *position)),
_ => None,
}
}
pub fn set_stream_position(&self, heap_ref: HeapRef, new_position: usize) -> bool {
let mut state = self.state.write().expect("heap lock poisoned");
if let Some(HeapObject::Stream { position, .. }) = state.objects.get_mut(&heap_ref.id()) {
*position = new_position;
true
} else {
false
}
}
pub fn write_to_stream(&self, heap_ref: HeapRef, bytes: &[u8]) -> usize {
let mut state = self.state.write().expect("heap lock poisoned");
if let Some(HeapObject::Stream { data, position }) = state.objects.get_mut(&heap_ref.id()) {
let write_len = bytes.len();
let required_len = *position + write_len;
if data.len() < required_len {
data.resize(required_len, 0);
}
data[*position..*position + write_len].copy_from_slice(bytes);
*position += write_len;
write_len
} else {
0
}
}
pub fn replace_with_stream(&self, heap_ref: HeapRef, data: Vec<u8>) -> bool {
let mut state = self.state.write().expect("heap lock poisoned");
if let Some(old_obj) = state.objects.get(&heap_ref.id()) {
let old_size = old_obj.estimated_size();
let new_obj = HeapObject::Stream { data, position: 0 };
let new_size = new_obj.estimated_size();
state.objects.insert(heap_ref.id(), new_obj);
if new_size >= old_size {
self.current_size
.fetch_add(new_size - old_size, Ordering::Relaxed);
} else {
self.current_size
.fetch_sub(old_size - new_size, Ordering::Relaxed);
}
true
} else {
false
}
}
pub fn alloc_crypto_stream(
&self,
underlying_stream: HeapRef,
transform: HeapRef,
mode: u8,
) -> Result<HeapRef> {
self.alloc_object_internal(HeapObject::CryptoStream {
underlying_stream,
transform,
mode,
transformed_data: None,
transformed_pos: 0,
write_buffer: Vec::new(),
})
}
#[must_use]
pub fn get_crypto_stream_info(&self, heap_ref: HeapRef) -> Option<(HeapRef, HeapRef, u8)> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::CryptoStream {
underlying_stream,
transform,
mode,
..
} => Some((*underlying_stream, *transform, *mode)),
_ => None,
}
}
pub fn replace_with_crypto_stream(
&self,
heap_ref: HeapRef,
underlying_stream: HeapRef,
transform: HeapRef,
mode: u8,
) -> bool {
let mut state = self.state.write().expect("heap lock poisoned");
if let Some(old_obj) = state.objects.get(&heap_ref.id()) {
let old_size = old_obj.estimated_size();
let new_obj = HeapObject::CryptoStream {
underlying_stream,
transform,
mode,
transformed_data: None,
transformed_pos: 0,
write_buffer: Vec::new(),
};
let new_size = new_obj.estimated_size();
state.objects.insert(heap_ref.id(), new_obj);
if new_size >= old_size {
self.current_size
.fetch_add(new_size - old_size, Ordering::Relaxed);
} else {
self.current_size
.fetch_sub(old_size - new_size, Ordering::Relaxed);
}
true
} else {
false
}
}
#[must_use]
pub fn get_crypto_stream_transformed(&self, heap_ref: HeapRef) -> Option<(Vec<u8>, usize)> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::CryptoStream {
transformed_data: Some(data),
transformed_pos,
..
} => Some((data.clone(), *transformed_pos)),
_ => None,
}
}
pub fn set_crypto_stream_transformed(&self, heap_ref: HeapRef, data: Vec<u8>) -> bool {
let mut state = self.state.write().expect("heap lock poisoned");
let new_size = data.len();
let old_size = match state.objects.get(&heap_ref.id()) {
Some(HeapObject::CryptoStream {
transformed_data, ..
}) => transformed_data.as_ref().map_or(0, Vec::len),
_ => return false,
};
if let Some(HeapObject::CryptoStream {
transformed_data,
transformed_pos,
..
}) = state.objects.get_mut(&heap_ref.id())
{
*transformed_data = Some(data);
*transformed_pos = 0; } else {
return false;
}
if new_size >= old_size {
self.current_size
.fetch_add(new_size - old_size, Ordering::Relaxed);
} else {
self.current_size
.fetch_sub(old_size - new_size, Ordering::Relaxed);
}
true
}
pub fn read_crypto_stream(&self, heap_ref: HeapRef, count: usize) -> Option<Vec<u8>> {
let mut state = self.state.write().expect("heap lock poisoned");
if let Some(HeapObject::CryptoStream {
transformed_data: Some(data),
transformed_pos,
..
}) = state.objects.get_mut(&heap_ref.id())
{
let available = data.len().saturating_sub(*transformed_pos);
let to_read = count.min(available);
let result = data[*transformed_pos..*transformed_pos + to_read].to_vec();
*transformed_pos += to_read;
Some(result)
} else {
None
}
}
pub fn crypto_stream_append_write(&self, heap_ref: HeapRef, data: &[u8]) -> bool {
let mut state = self.state.write().expect("heap lock poisoned");
let data_len = data.len();
if let Some(HeapObject::CryptoStream { write_buffer, .. }) =
state.objects.get_mut(&heap_ref.id())
{
write_buffer.extend_from_slice(data);
self.current_size.fetch_add(data_len, Ordering::Relaxed);
true
} else {
false
}
}
#[must_use]
pub fn get_crypto_stream_write_buffer(&self, heap_ref: HeapRef) -> Option<Vec<u8>> {
let state = self.state.read().expect("heap lock poisoned");
match state.objects.get(&heap_ref.id())? {
HeapObject::CryptoStream { write_buffer, .. } => Some(write_buffer.clone()),
_ => None,
}
}
pub fn clear_crypto_stream_write_buffer(&self, heap_ref: HeapRef) -> bool {
let mut state = self.state.write().expect("heap lock poisoned");
let buffer_len = match state.objects.get(&heap_ref.id()) {
Some(HeapObject::CryptoStream { write_buffer, .. }) => write_buffer.len(),
_ => return false,
};
if let Some(HeapObject::CryptoStream { write_buffer, .. }) =
state.objects.get_mut(&heap_ref.id())
{
write_buffer.clear();
}
self.current_size.fetch_sub(buffer_len, Ordering::Relaxed);
true
}
}
impl<'a> IntoIterator for &'a ManagedHeap {
type Item = (HeapRef, HeapObject);
type IntoIter = HeapIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl Clone for ManagedHeap {
fn clone(&self) -> Self {
self.fork()
}
}
impl Default for ManagedHeap {
fn default() -> Self {
Self::default_size()
}
}
/// One-line summary: live object count and current/maximum estimated bytes.
impl fmt::Display for ManagedHeap {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Keep the read guard alive while sampling the size so the two figures
        // come from the same moment relative to writers.
        let guard = self.state.read().expect("heap lock poisoned");
        let count = guard.objects.len();
        let used = self.current_size.load(Ordering::Relaxed);
        let limit = self.max_size;
        write!(f, "Heap({count} objects, {used}/{limit} bytes)")
    }
}
/// Unit tests for `ManagedHeap` allocation, access, limits and fork isolation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Error;

    // Round-trips a string through the heap.
    #[test]
    fn test_heap_alloc_string() {
        let heap = ManagedHeap::new(1024 * 1024);
        let string_ref = heap.alloc_string("Hello, World!").unwrap();
        assert!(heap.contains(string_ref));
        let value = heap.get_string(string_ref).unwrap();
        assert_eq!(&*value, "Hello, World!");
    }

    // New arrays have the requested length and default-valued elements.
    #[test]
    fn test_heap_alloc_array() {
        let heap = ManagedHeap::new(1024 * 1024);
        let array_ref = heap.alloc_array(CilFlavor::I4, 10).unwrap();
        assert!(heap.contains(array_ref));
        let length = heap.get_array_length(array_ref).unwrap();
        assert_eq!(length, 10);
        let elem = heap.get_array_element(array_ref, 0).unwrap();
        assert_eq!(elem, EmValue::I32(0));
    }

    // Element writes are visible to reads; out-of-range indices fail both ways.
    #[test]
    fn test_heap_array_operations() {
        let heap = ManagedHeap::new(1024 * 1024);
        let array_ref = heap.alloc_array(CilFlavor::I4, 5).unwrap();
        heap.set_array_element(array_ref, 2, EmValue::I32(42))
            .unwrap();
        let elem = heap.get_array_element(array_ref, 2).unwrap();
        assert_eq!(elem, EmValue::I32(42));
        assert!(heap.get_array_element(array_ref, 10).is_err());
        assert!(heap
            .set_array_element(array_ref, 10, EmValue::I32(0))
            .is_err());
    }

    // Declared fields start at their default and can be overwritten.
    #[test]
    fn test_heap_alloc_object() {
        let heap = ManagedHeap::new(1024 * 1024);
        let type_token = Token::new(0x0200_0001);
        let field_token = Token::new(0x0400_0001);
        let obj_ref = heap
            .alloc_object_with_fields(type_token, &[(field_token, CilFlavor::I4)])
            .unwrap();
        let field = heap.get_field(obj_ref, field_token).unwrap();
        assert_eq!(field, EmValue::I32(0));
        heap.set_field(obj_ref, field_token, EmValue::I32(42))
            .unwrap();
        let field = heap.get_field(obj_ref, field_token).unwrap();
        assert_eq!(field, EmValue::I32(42));
    }

    // Box then unbox returns the original value.
    #[test]
    fn test_heap_boxing() {
        let heap = ManagedHeap::new(1024 * 1024);
        let type_token = Token::new(0x0200_0001);
        let boxed_ref = heap.alloc_boxed(type_token, EmValue::I32(42)).unwrap();
        let value = heap.unbox(boxed_ref).unwrap();
        assert_eq!(value, EmValue::I32(42));
    }

    // An allocation larger than the limit reports HeapMemoryLimitExceeded.
    #[test]
    fn test_heap_out_of_memory() {
        let heap = ManagedHeap::new(100);
        let large_string = "A".repeat(1000);
        let result = heap.alloc_string(&large_string);
        assert!(matches!(
            result,
            Err(Error::Emulation(ref e)) if matches!(e.as_ref(), EmulationError::HeapMemoryLimitExceeded { .. })
        ));
    }

    // A never-allocated ref is rejected and not reported as contained.
    #[test]
    fn test_heap_invalid_reference() {
        let heap = ManagedHeap::new(1024 * 1024);
        let fake_ref = HeapRef::new(9999);
        assert!(heap.get(fake_ref).is_err());
        assert!(!heap.contains(fake_ref));
    }

    // Array access on a string object reports HeapTypeMismatch.
    #[test]
    fn test_heap_type_mismatch() {
        let heap = ManagedHeap::new(1024 * 1024);
        let string_ref = heap.alloc_string("test").unwrap();
        assert!(matches!(
            heap.get_array_element(string_ref, 0),
            Err(Error::Emulation(ref e)) if matches!(e.as_ref(), EmulationError::HeapTypeMismatch { .. })
        ));
    }

    // clear() removes every object and zeroes the size total.
    #[test]
    fn test_heap_clear() {
        let heap = ManagedHeap::new(1024 * 1024);
        heap.alloc_string("test").unwrap();
        heap.alloc_array(CilFlavor::I4, 10).unwrap();
        assert!(heap.object_count() > 0);
        heap.clear();
        assert_eq!(heap.object_count(), 0);
        assert_eq!(heap.current_size(), 0);
    }

    // The heap summary includes the object count.
    #[test]
    fn test_heap_display() {
        let heap = ManagedHeap::new(1024 * 1024);
        heap.alloc_string("test").unwrap();
        let display = format!("{heap}");
        assert!(display.contains("1 objects"));
    }

    // Object summaries include string contents and array lengths.
    #[test]
    fn test_heap_object_display() {
        let obj = HeapObject::String("test".into());
        assert!(format!("{obj}").contains("test"));
        let obj = HeapObject::Array {
            element_type: CilFlavor::I4,
            elements: vec![EmValue::I32(0); 5],
        };
        assert!(format!("{obj}").contains("5"));
    }

    // Previously returned Arc<str> values stay valid across later allocations.
    #[test]
    fn test_heap_concurrent_access() {
        let heap = ManagedHeap::new(1024 * 1024);
        let s1 = heap.alloc_string("first").unwrap();
        let s2 = heap.alloc_string("second").unwrap();
        let str1 = heap.get_string(s1).unwrap();
        let str2 = heap.get_string(s2).unwrap();
        assert_eq!(&*str1, "first");
        assert_eq!(&*str2, "second");
        let s3 = heap.alloc_string("third").unwrap();
        let str3 = heap.get_string(s3).unwrap();
        assert_eq!(&*str3, "third");
        assert_eq!(&*str1, "first");
    }

    // A fork sees pre-fork objects; writes and allocations stay on their own side.
    #[test]
    fn test_heap_fork() {
        let heap = ManagedHeap::new(1024 * 1024);
        let s1 = heap.alloc_string("original").unwrap();
        let arr1 = heap.alloc_array(CilFlavor::I4, 5).unwrap();
        heap.set_array_element(arr1, 0, EmValue::I32(42)).unwrap();
        let forked = heap.fork();
        assert_eq!(heap.get_string(s1).unwrap().as_ref(), "original");
        assert_eq!(forked.get_string(s1).unwrap().as_ref(), "original");
        assert_eq!(heap.get_array_element(arr1, 0).unwrap(), EmValue::I32(42));
        assert_eq!(forked.get_array_element(arr1, 0).unwrap(), EmValue::I32(42));
        forked
            .set_array_element(arr1, 0, EmValue::I32(100))
            .unwrap();
        let s2 = forked.alloc_string("forked").unwrap();
        assert_eq!(heap.get_array_element(arr1, 0).unwrap(), EmValue::I32(42));
        assert!(!heap.contains(s2));
        assert_eq!(
            forked.get_array_element(arr1, 0).unwrap(),
            EmValue::I32(100)
        );
        assert!(forked.contains(s2));
        assert_eq!(forked.get_string(s2).unwrap().as_ref(), "forked");
    }

    // Two forks of the same heap do not see each other's allocations.
    #[test]
    fn test_heap_fork_isolation() {
        let heap = ManagedHeap::new(1024 * 1024);
        let s1 = heap.alloc_string("hello").unwrap();
        let fork1 = heap.fork();
        let fork2 = heap.fork();
        let f1_str = fork1.alloc_string("fork1").unwrap();
        let f2_str = fork2.alloc_string("fork2").unwrap();
        assert!(fork1.contains(s1));
        assert!(fork2.contains(s1));
        assert_eq!(fork1.get_string(s1).unwrap().as_ref(), "hello");
        assert_eq!(fork2.get_string(s1).unwrap().as_ref(), "hello");
        assert_eq!(fork1.get_string(f1_str).unwrap().as_ref(), "fork1");
        assert_eq!(fork2.get_string(f2_str).unwrap().as_ref(), "fork2");
        assert!(heap.contains(s1));
        assert!(!heap.contains(f1_str));
    }

    // Mutating one element in a fork leaves the parent's copy and the other elements untouched.
    #[test]
    fn test_heap_fork_cow_semantics() {
        let heap = ManagedHeap::new(1024 * 1024);
        let arr = heap.alloc_array(CilFlavor::I4, 3).unwrap();
        heap.set_array_element(arr, 0, EmValue::I32(1)).unwrap();
        heap.set_array_element(arr, 1, EmValue::I32(2)).unwrap();
        heap.set_array_element(arr, 2, EmValue::I32(3)).unwrap();
        let forked = heap.fork();
        assert_eq!(heap.get_array_element(arr, 0).unwrap(), EmValue::I32(1));
        assert_eq!(forked.get_array_element(arr, 0).unwrap(), EmValue::I32(1));
        forked.set_array_element(arr, 0, EmValue::I32(100)).unwrap();
        assert_eq!(heap.get_array_element(arr, 0).unwrap(), EmValue::I32(1));
        assert_eq!(heap.get_array_element(arr, 1).unwrap(), EmValue::I32(2));
        assert_eq!(heap.get_array_element(arr, 2).unwrap(), EmValue::I32(3));
        assert_eq!(forked.get_array_element(arr, 0).unwrap(), EmValue::I32(100));
        assert_eq!(forked.get_array_element(arr, 1).unwrap(), EmValue::I32(2));
        assert_eq!(forked.get_array_element(arr, 2).unwrap(), EmValue::I32(3));
    }
}