use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
use std::thread;
use std::time::Duration;
use windows::Win32::Foundation::HANDLE;
use crate::memory::{read_memory_bytes, write_memory_bytes, MemoryError};
use crate::memory_resolver::MemoryAddress;
use super::utils::SendableHandle;
/// Describes how the address of the locked memory is obtained.
#[derive(Clone)]
pub enum AddressSource {
    /// A fixed address that is used as-is, without any lookup.
    Static(usize),
    /// An address computed by `MemoryAddress::resolve_address` each time it is resolved.
    Dynamic(MemoryAddress),
}
impl AddressSource {
    /// Produces the concrete address described by this source.
    ///
    /// `Static` addresses are returned unchanged and ignore `handle` and `pid`;
    /// `Dynamic` addresses are delegated to `MemoryAddress::resolve_address`.
    ///
    /// # Errors
    /// Returns `MemoryError::ReadFailed`, carrying the resolver's error text, when a
    /// dynamic address cannot be resolved.
    pub fn resolve(&self, handle: HANDLE, pid: u32) -> Result<usize, MemoryError> {
        let resolver = match self {
            Self::Static(fixed) => return Ok(*fixed),
            Self::Dynamic(resolver) => resolver,
        };
        resolver.resolve_address(handle, pid).map_err(|cause| {
            MemoryError::ReadFailed(format!("Failed to resolve dynamic address: {}", cause))
        })
    }
}
/// Keeps a byte pattern written at an address of a target process by periodically
/// re-checking it from a background thread.
///
/// Create one with [`MemoryLock::builder`], then call `lock_value` or `lock_bytes`.
pub struct MemoryLock {
    /// Process handle used for every read and write; `None` after `reset`.
    handle: Option<SendableHandle>,
    /// How the target address is obtained; `None` after `reset`.
    address_source: Option<AddressSource>,
    /// Target process id, passed to address resolution.
    pid: Option<u32>,
    /// Number of bytes in `locked_value`.
    size: usize,
    /// Bytes that are kept at the target address while locked.
    locked_value: Vec<u8>,
    /// Delay between two checks performed by the worker thread.
    scan_interval: Duration,
    /// Shared with the worker; setting it asks the worker loop to exit.
    stop_flag: Arc<AtomicBool>,
    /// The running worker, present while the lock is active.
    worker_thread: Option<thread::JoinHandle<()>>,
}
impl MemoryLock {
pub fn builder() -> MemoryLockBuilder {
MemoryLockBuilder::new()
}
pub fn set_scan_interval(&mut self, interval: Duration) {
self.scan_interval = interval;
}
pub fn lock_value<T: Copy + AsBytes>(&mut self, value: T) -> Result<(), MemoryError> {
let handle = self.handle.as_ref().ok_or_else(|| {
MemoryError::WriteFailed("handle must be set. Call .handle(handle) before lock_value().".to_string())
})?;
let address_source = self.address_source.as_ref().ok_or_else(|| {
MemoryError::WriteFailed("address must be set. Call .address(addr) or .address_from_resolver(addr) before lock_value().".to_string())
})?;
let pid = self.pid.ok_or_else(|| {
MemoryError::WriteFailed("PID must be set for address resolution. Call .pid(pid) before lock_value().".to_string())
})?;
let address = match address_source.resolve(handle.0, pid) {
Ok(addr) => addr,
Err(e) => {
#[cfg(debug_assertions)]
println!("[MemoryLock] âš Initial address resolution failed ({}), background thread will retry", e);
let bytes = value.as_bytes();
self.size = bytes.len();
self.locked_value = bytes.to_vec();
return self.start_monitoring(0); }
};
let bytes = value.as_bytes();
self.size = bytes.len();
self.locked_value = bytes.to_vec();
write_memory_bytes(handle.0, address, &self.locked_value)?;
self.start_monitoring(address)?;
Ok(())
}
pub fn lock_bytes(&mut self, bytes: &[u8]) -> Result<(), MemoryError> {
let handle = self.handle.as_ref().ok_or_else(|| {
MemoryError::WriteFailed("handle must be set. Call .handle(handle) before lock_bytes().".to_string())
})?;
let address_source = self.address_source.as_ref().ok_or_else(|| {
MemoryError::WriteFailed("address must be set. Call .address(addr) or .address_from_resolver(addr) before lock_bytes().".to_string())
})?;
let pid = self.pid.ok_or_else(|| {
MemoryError::WriteFailed("PID must be set for address resolution. Call .pid(pid) before lock_bytes().".to_string())
})?;
let address = match address_source.resolve(handle.0, pid) {
Ok(addr) => addr,
Err(e) => {
#[cfg(debug_assertions)]
println!("[MemoryLock] âš Initial address resolution failed ({}), background thread will retry", e);
self.size = bytes.len();
self.locked_value = bytes.to_vec();
return self.start_monitoring(0); }
};
self.size = bytes.len();
self.locked_value = bytes.to_vec();
write_memory_bytes(handle.0, address, &self.locked_value)?;
self.start_monitoring(address)?;
Ok(())
}
pub fn unlock(&mut self) -> Result<(), MemoryError> {
if let Some(thread) = self.worker_thread.take() {
self.stop_flag.store(true, Ordering::Relaxed);
if let Err(e) = thread.join() {
return Err(MemoryError::ReadFailed(
format!("Failed to join worker thread: {:?}", e)
));
}
}
Ok(())
}
pub fn reset(&mut self) {
if self.is_locked() {
let _ = self.unlock();
}
self.handle = None;
self.pid = None;
self.address_source = None;
self.size = 0;
self.locked_value.clear();
self.scan_interval = Duration::from_millis(10); self.stop_flag.store(false, Ordering::Relaxed);
self.worker_thread = None;
}
pub fn is_locked(&self) -> bool {
self.worker_thread.is_some()
}
pub fn get_locked_value(&self) -> &[u8] {
&self.locked_value
}
fn start_monitoring(&mut self, initial_address: usize) -> Result<(), MemoryError> {
if self.worker_thread.is_some() {
self.unlock()?;
}
self.stop_flag.store(false, Ordering::Relaxed);
let handle = self.handle.as_ref().unwrap();
let address_source = self.address_source.clone().unwrap();
let pid = self.pid.unwrap();
let handle_int = handle.0 .0 as isize; let locked_value = self.locked_value.clone();
let size = self.size;
let interval = self.scan_interval;
let stop_flag: Arc<AtomicBool> = Arc::clone(&self.stop_flag);
let thread_handle = thread::spawn(move || {
let handle = HANDLE(handle_int as *mut std::ffi::c_void);
let mut current_address = initial_address;
while !stop_flag.load(Ordering::Relaxed) {
if let AddressSource::Dynamic(_) = address_source {
match address_source.resolve(handle, pid) {
Ok(new_addr) => {
current_address = new_addr;
}
Err(_) => {
thread::sleep(interval);
continue;
}
}
}
match read_memory_bytes(handle, current_address, size) {
Ok(current_value) => {
if current_value != locked_value {
let _ = write_memory_bytes(handle, current_address, &locked_value);
}
}
Err(_) => {
thread::sleep(interval);
continue;
}
}
thread::sleep(interval);
}
});
self.worker_thread = Some(thread_handle);
Ok(())
}
}
impl Drop for MemoryLock {
    /// Stops the background worker so it does not outlive the lock. Any error is
    /// discarded because `drop` has no way to report it.
    fn drop(&mut self) {
        self.unlock().ok();
    }
}
/// Exposes a value as a borrowed slice of raw bytes.
pub trait AsBytes {
    /// Returns the bytes that represent `self`, borrowed for as long as `self` lives.
    fn as_bytes(&self) -> &[u8];
}
/// Implements [`AsBytes`] for primitive numeric types by viewing the value's own
/// storage as a byte slice (native endianness).
///
/// Only invoke this with types that have no padding bytes and no interior pointers.
macro_rules! impl_as_bytes {
    ($($t:ty),* $(,)?) => {
        $(
            impl AsBytes for $t {
                fn as_bytes(&self) -> &[u8] {
                    // SAFETY: `$t` is a primitive integer/float with no padding, so all
                    // `size_of::<$t>()` bytes behind `self` are initialised; `u8` has
                    // alignment 1; and the returned slice borrows `self`, so it cannot
                    // outlive the value it views.
                    unsafe {
                        std::slice::from_raw_parts(
                            (self as *const $t).cast::<u8>(),
                            std::mem::size_of::<$t>(),
                        )
                    }
                }
            }
        )*
    };
}
impl_as_bytes!(u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64);
/// Fluent, clonable configuration for a [`MemoryLock`].
///
/// Every setter consumes and returns the builder; `build()` validates the result.
#[derive(Clone)]
pub struct MemoryLockBuilder {
    /// Handle to the target process (required by `build`).
    handle: Option<HANDLE>,
    /// Target process id (required by `build` only for dynamic addresses).
    pid: Option<u32>,
    /// Where the locked memory lives (required by `build`).
    address_source: Option<AddressSource>,
    /// Bytes to keep written (required by `build`).
    locked_value: Option<Vec<u8>>,
    /// Delay between worker checks; 10 ms unless overridden.
    scan_interval: Duration,
}
impl MemoryLockBuilder {
    /// Creates a builder with nothing set and a 10 ms scan interval.
    #[must_use]
    pub fn new() -> Self {
        Self {
            handle: None,
            pid: None,
            address_source: None,
            locked_value: None,
            scan_interval: Duration::from_millis(10),
        }
    }

    /// Sets the process handle used for every read and write (required).
    #[must_use]
    pub fn handle(mut self, handle: HANDLE) -> Self {
        self.handle = Some(handle);
        self
    }

    /// Sets the target process id (required when using a dynamic address).
    #[must_use]
    pub fn pid(mut self, pid: u32) -> Self {
        self.pid = Some(pid);
        self
    }

    /// Locks a fixed address, replacing any address configured earlier.
    #[must_use]
    pub fn address(mut self, addr: usize) -> Self {
        self.address_source = Some(AddressSource::Static(addr));
        self
    }

    /// Locks an address resolved through `mem_addr`, replacing any address configured
    /// earlier. Requires [`pid`](Self::pid) to be set before `build()`.
    #[must_use]
    pub fn address_from_resolver(mut self, mem_addr: MemoryAddress) -> Self {
        self.address_source = Some(AddressSource::Dynamic(mem_addr));
        self
    }

    /// Sets the value to lock from the raw bytes of `val`.
    #[must_use]
    pub fn value<T: Copy + AsBytes>(mut self, val: T) -> Self {
        self.locked_value = Some(val.as_bytes().to_vec());
        self
    }

    /// Sets the raw bytes to lock.
    #[must_use]
    pub fn bytes(mut self, data: Vec<u8>) -> Self {
        self.locked_value = Some(data);
        self
    }

    /// Sets the delay between worker checks, in milliseconds.
    #[must_use]
    pub fn scan_interval_ms(mut self, ms: u64) -> Self {
        self.scan_interval = Duration::from_millis(ms);
        self
    }

    /// Sets the delay between worker checks.
    #[must_use]
    pub fn scan_interval(mut self, interval: Duration) -> Self {
        self.scan_interval = interval;
        self
    }

    /// Validates the configuration and creates an unlocked [`MemoryLock`].
    ///
    /// # Errors
    /// Returns `MemoryError::WriteFailed` if the handle, address, or value is missing,
    /// or if a dynamic address is configured without a PID.
    pub fn build(self) -> Result<MemoryLock, MemoryError> {
        let handle = self.handle.ok_or_else(|| {
            MemoryError::WriteFailed(
                "handle must be set. Call .handle(handle) before build().".to_string()
            )
        })?;
        let address_source = self.address_source.ok_or_else(|| {
            MemoryError::WriteFailed(
                "address must be set. Call .address(addr) or .address_from_resolver(addr) before build().".to_string()
            )
        })?;
        let locked_value = self.locked_value.ok_or_else(|| {
            MemoryError::WriteFailed(
                "locked_value must be set. Call .value(val) or .bytes(data) before build().".to_string()
            )
        })?;
        let size = locked_value.len();
        // Static addresses are used as-is; only dynamic resolution needs the PID.
        if matches!(address_source, AddressSource::Dynamic(_)) && self.pid.is_none() {
            return Err(MemoryError::WriteFailed(
                "PID must be set when using dynamic address resolution. Call .pid(pid) before build().".to_string()
            ));
        }
        Ok(MemoryLock {
            handle: Some(SendableHandle(handle)),
            pid: self.pid,
            address_source: Some(address_source),
            size,
            locked_value,
            scan_interval: self.scan_interval,
            stop_flag: Arc::new(AtomicBool::new(false)),
            worker_thread: None,
        })
    }
}

impl Default for MemoryLockBuilder {
    /// Equivalent to [`MemoryLockBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of a `build()` error, failing the test if `build()` succeeded.
    fn build_error_text(outcome: Result<MemoryLock, MemoryError>) -> String {
        match outcome {
            Ok(_) => panic!("expected build() to fail"),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn test_as_bytes_u32() {
        let word: u32 = 0x1234_5678;
        let raw = word.as_bytes();
        assert_eq!(raw.len(), 4);
        assert_eq!((raw[0], raw[3]), (0x78, 0x12));
    }

    #[test]
    fn test_builder_clone_support() {
        let template = MemoryLock::builder()
            .address(0x1000)
            .value(100u32)
            .scan_interval_ms(5);
        let outcomes: Vec<_> = (0..2)
            .map(|_| template.clone().handle(HANDLE::default()).build())
            .collect();
        assert!(outcomes.iter().all(|outcome| outcome.is_ok()));
    }

    #[test]
    fn test_builder_validation_missing_handle() {
        let text = build_error_text(
            MemoryLock::builder().address(0x1000).value(100u32).build(),
        );
        assert!(text.contains("handle must be set"));
    }

    #[test]
    fn test_builder_validation_missing_address() {
        let text = build_error_text(
            MemoryLock::builder()
                .handle(HANDLE::default())
                .value(100u32)
                .build(),
        );
        assert!(text.contains("address must be set"));
    }

    #[test]
    fn test_builder_validation_missing_value() {
        let text = build_error_text(
            MemoryLock::builder()
                .handle(HANDLE::default())
                .address(0x1000)
                .build(),
        );
        assert!(text.contains("locked_value must be set"));
    }
}