use std::collections::BTreeMap;
/// Bundle of optional virtual resources: a file system and a registry.
///
/// Both fields are `None` in the default value; the `with_vfs` and
/// `with_registry` builders on `Context` attach them.
#[derive(Debug, Default, Clone)]
pub struct Context {
/// Virtual file system, if one has been attached.
pub vfs: Option<VirtualFs>,
/// Virtual registry, if one has been attached.
pub registry: Option<VirtualRegistry>,
}
impl Context {
    /// Returns a context with neither a file system nor a registry attached.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Consumes the context and returns it with `vfs` attached, replacing
    /// any file system that was attached before. The registry is kept as is.
    #[must_use]
    pub fn with_vfs(self, vfs: VirtualFs) -> Self {
        Self {
            vfs: Some(vfs),
            ..self
        }
    }

    /// Consumes the context and returns it with `reg` attached, replacing
    /// any registry that was attached before. The file system is kept as is.
    #[must_use]
    pub fn with_registry(self, reg: VirtualRegistry) -> Self {
        Self {
            registry: Some(reg),
            ..self
        }
    }
}
/// In-memory file store with handle-based access.
///
/// File contents and open handles are tracked in separate maps; each open
/// handle carries its own cursor.
#[derive(Debug, Default, Clone)]
pub struct VirtualFs {
/// File contents keyed by normalized path (the output of `normalize_path`).
files: BTreeMap<String, Vec<u8>>,
/// Currently open handles, keyed by handle value.
open: BTreeMap<u32, FileHandle>,
/// Counter added to `HANDLE_BASE` to produce the next handle value.
next_handle: u32,
}
/// State attached to one open file handle.
#[derive(Debug, Clone)]
pub struct FileHandle {
/// Normalized path of the file this handle refers to.
pub path: String,
/// Byte offset at which the next read or write through this handle starts.
pub pos: u64,
/// Access mode the handle was opened with.
pub access: FileAccess,
}
/// Access mode granted to an open file handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileAccess {
    /// The handle may only be read from.
    Read,
    /// The handle may only be written to.
    Write,
    /// The handle may be both read from and written to.
    ReadWrite,
}

impl FileAccess {
    /// Derives the access mode from a Win32 `dwDesiredAccess` mask.
    ///
    /// Read access is requested by `GENERIC_READ`, `GENERIC_ALL` or the
    /// file-specific `FILE_READ_DATA` right. Write access is requested by
    /// `GENERIC_WRITE`, `GENERIC_ALL`, `FILE_WRITE_DATA` or
    /// `FILE_APPEND_DATA`. Recognising the file-specific bits means that
    /// composite masks such as `FILE_GENERIC_READ` (`0x0012_0089`) and
    /// `FILE_GENERIC_WRITE` (`0x0012_0116`) are classified correctly.
    ///
    /// A mask that requests neither (for example `0`, used to query
    /// attributes only) maps to [`FileAccess::Read`].
    #[must_use]
    pub fn from_win32_desired_access(flags: u32) -> Self {
        const GENERIC_READ: u32 = 0x8000_0000;
        const GENERIC_WRITE: u32 = 0x4000_0000;
        const GENERIC_ALL: u32 = 0x1000_0000;
        const FILE_READ_DATA: u32 = 0x0000_0001;
        const FILE_WRITE_DATA: u32 = 0x0000_0002;
        const FILE_APPEND_DATA: u32 = 0x0000_0004;

        const READ_BITS: u32 = GENERIC_READ | GENERIC_ALL | FILE_READ_DATA;
        const WRITE_BITS: u32 = GENERIC_WRITE | GENERIC_ALL | FILE_WRITE_DATA | FILE_APPEND_DATA;

        let read = flags & READ_BITS != 0;
        let write = flags & WRITE_BITS != 0;
        match (read, write) {
            (true, true) => FileAccess::ReadWrite,
            (false, true) => FileAccess::Write,
            // No recognised write bit: fall back to read-only, which is also
            // the answer for an empty mask.
            (true, false) | (false, false) => FileAccess::Read,
        }
    }

    /// Returns `true` for modes that permit reading.
    fn allows_read(self) -> bool {
        matches!(self, FileAccess::Read | FileAccess::ReadWrite)
    }

    /// Returns `true` for modes that permit writing.
    fn allows_write(self) -> bool {
        matches!(self, FileAccess::Write | FileAccess::ReadWrite)
    }
}
/// Base value for file handles produced by [`VirtualFs::open`]: each handle
/// is `HANDLE_BASE` plus a per-file-system counter (wrapping).
pub const HANDLE_BASE: u32 = 0x6800_0000;

impl VirtualFs {
    /// Creates an empty file system: no files and no open handles.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `bytes` under `path`, replacing any previous contents.
    ///
    /// The path is normalized before use, as it is in every path-taking
    /// method of this type.
    pub fn insert(&mut self, path: &str, bytes: Vec<u8>) {
        self.files.insert(normalize_path(path), bytes);
    }

    /// Returns the full contents of the file at `path`, or `None` if there
    /// is no such file.
    #[must_use]
    pub fn read(&self, path: &str) -> Option<&[u8]> {
        self.files.get(&normalize_path(path)).map(Vec::as_slice)
    }

    /// Returns `true` if a file is stored at `path`.
    #[must_use]
    pub fn contains(&self, path: &str) -> bool {
        self.files.contains_key(&normalize_path(path))
    }

    /// Replaces the contents of the file at `path` with `bytes`, creating the
    /// file if it does not exist. Equivalent to [`VirtualFs::insert`].
    pub fn write_path(&mut self, path: &str, bytes: Vec<u8>) {
        self.insert(path, bytes);
    }

    /// Deletes the file at `path`, returning `true` if a file was removed.
    ///
    /// Handles that are still open on the file are not closed; handle-based
    /// reads, writes and size queries on them return `None` afterwards.
    pub fn remove(&mut self, path: &str) -> bool {
        self.files.remove(&normalize_path(path)).is_some()
    }

    /// Iterates over `(normalized path, size in bytes)` for every stored file,
    /// in path order.
    pub fn list(&self) -> impl Iterator<Item = (&str, usize)> {
        self.files.iter().map(|(k, v)| (k.as_str(), v.len()))
    }

    /// Opens `path` with the given access mode and returns a new handle.
    ///
    /// A missing file is created empty when `access` permits writing;
    /// otherwise `None` is returned. Opening never truncates an existing
    /// file. The new handle's cursor starts at offset 0.
    pub fn open(&mut self, path: &str, access: FileAccess) -> Option<u32> {
        let key = normalize_path(path);
        if !self.files.contains_key(&key) {
            if !access.allows_write() {
                return None;
            }
            self.files.insert(key.clone(), Vec::new());
        }
        let handle = HANDLE_BASE.wrapping_add(self.next_handle);
        self.next_handle = self.next_handle.wrapping_add(1);
        self.open.insert(
            handle,
            FileHandle {
                path: key,
                pos: 0,
                access,
            },
        );
        Some(handle)
    }

    /// Closes `handle`, returning `true` if it was open.
    pub fn close(&mut self, handle: u32) -> bool {
        self.open.remove(&handle).is_some()
    }

    /// Reads from the handle's cursor into `buf` and advances the cursor.
    ///
    /// Returns the number of bytes copied, which is `0` at or past end of
    /// file. Returns `None` if the handle is unknown, was not opened for
    /// reading, or its file has been removed.
    pub fn read_handle(&mut self, handle: u32, buf: &mut [u8]) -> Option<usize> {
        let fh = self.open.get_mut(&handle)?;
        if !fh.access.allows_read() {
            return None;
        }
        let file = self.files.get(&fh.path)?;
        // A cursor that does not even fit in `usize` is necessarily past the
        // end of an in-memory file, so it is reported as end of file too.
        let remaining = match Self::index_of(fh.pos) {
            Some(pos) if pos < file.len() => &file[pos..],
            _ => return Some(0),
        };
        let n = buf.len().min(remaining.len());
        buf[..n].copy_from_slice(&remaining[..n]);
        // Cannot overflow: the new position is at most `file.len()`.
        fh.pos += n as u64;
        Some(n)
    }

    /// Writes `data` at the handle's cursor and advances the cursor.
    ///
    /// The file grows as needed; a gap between the old end of file and the
    /// cursor is filled with zero bytes. Returns the number of bytes
    /// written.
    ///
    /// Returns `None` if the handle is unknown, was not opened for writing,
    /// or its file has been removed. Also returns `None`, leaving the file
    /// and the cursor untouched, when the write would end beyond what can
    /// be addressed or allocated in memory (for example after seeking close
    /// to `u64::MAX`).
    pub fn write_handle(&mut self, handle: u32, data: &[u8]) -> Option<usize> {
        let fh = self.open.get_mut(&handle)?;
        if !fh.access.allows_write() {
            return None;
        }
        let file = self.files.get_mut(&fh.path)?;
        let start = Self::index_of(fh.pos)?;
        let end = start.checked_add(data.len())?;
        if end > file.len() {
            // Reserve fallibly first so an absurd seek position is reported
            // as a failed write instead of a capacity-overflow panic.
            file.try_reserve(end - file.len()).ok()?;
            file.resize(end, 0);
        }
        file[start..end].copy_from_slice(data);
        fh.pos = end as u64;
        Some(data.len())
    }

    /// Moves the handle's cursor to the absolute offset `pos` and returns it.
    ///
    /// Any offset is accepted, including ones past end of file. Returns
    /// `None` if the handle is unknown.
    pub fn seek(&mut self, handle: u32, pos: u64) -> Option<u64> {
        let fh = self.open.get_mut(&handle)?;
        fh.pos = pos;
        Some(pos)
    }

    /// Returns the current size in bytes of the file behind `handle`, or
    /// `None` if the handle is unknown or its file has been removed.
    #[must_use]
    pub fn size(&self, handle: u32) -> Option<u64> {
        let fh = self.open.get(&handle)?;
        let file = self.files.get(&fh.path)?;
        Some(file.len() as u64)
    }

    /// Returns `true` if `handle` is currently open in this file system.
    #[must_use]
    pub fn owns(&self, handle: u32) -> bool {
        self.open.contains_key(&handle)
    }

    /// Converts a 64-bit cursor into an in-memory index, or `None` when it
    /// does not fit in `usize` on this target.
    fn index_of(pos: u64) -> Option<usize> {
        // Fully qualified so the call does not rely on `TryFrom` being in
        // the prelude.
        <usize as std::convert::TryFrom<u64>>::try_from(pos).ok()
    }
}
/// Produces the canonical lookup form of `path`: every backslash becomes a
/// forward slash and every other character is replaced by its lowercase
/// mapping (which may be more than one character).
fn normalize_path(path: &str) -> String {
    let mut normalized = String::with_capacity(path.len());
    normalized.extend(
        path.chars()
            // '/' is its own lowercase form, so swapping the separator first
            // and lowercasing everything afterwards gives the same result.
            .map(|ch| if ch == '\\' { '/' } else { ch })
            .flat_map(char::to_lowercase),
    );
    normalized
}
/// In-memory registry: a map of keys holding named values, plus a table of
/// open key handles.
#[derive(Debug, Default, Clone)]
pub struct VirtualRegistry {
/// Stored keys, indexed by normalized key path.
keys: BTreeMap<String, RegistryKey>,
/// Currently open key handles, indexed by handle value.
open: BTreeMap<u32, OpenKey>,
/// Counter added to `HKEY_USER_BASE` to produce the next handle value.
next_handle: u32,
}
/// State attached to one open registry key handle.
#[derive(Debug, Clone)]
pub struct OpenKey {
/// Normalized full path of the key the handle refers to.
pub path: String,
}
/// A single registry key: a set of named values.
#[derive(Debug, Default, Clone)]
pub struct RegistryKey {
/// Values stored under this key, indexed by ASCII-lowercased value name.
values: BTreeMap<String, RegistryValue>,
}
/// Typed payload of a registry value.
///
/// NOTE(review): variant names presumably mirror the Win32 `REG_SZ`,
/// `REG_EXPAND_SZ`, `REG_DWORD`, `REG_QWORD`, `REG_BINARY` and
/// `REG_MULTI_SZ` value types; confirm against the consumers of this enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryValue {
/// A string.
Sz(String),
/// A string; presumably one containing unexpanded environment references.
ExpandSz(String),
/// A 32-bit unsigned integer.
Dword(u32),
/// A 64-bit unsigned integer.
Qword(u64),
/// Arbitrary bytes.
Binary(Vec<u8>),
/// A list of strings.
MultiSz(Vec<String>),
}
/// Predefined root handle; `VirtualRegistry::predefined_path` maps it to
/// `"hkey_classes_root"`.
pub const HKEY_CLASSES_ROOT: u32 = 0x8000_0000;
/// Predefined root handle; `VirtualRegistry::predefined_path` maps it to
/// `"hkey_current_user"`.
pub const HKEY_CURRENT_USER: u32 = 0x8000_0001;
/// Predefined root handle; `VirtualRegistry::predefined_path` maps it to
/// `"hkey_local_machine"`.
pub const HKEY_LOCAL_MACHINE: u32 = 0x8000_0002;
/// Predefined root handle; `VirtualRegistry::predefined_path` maps it to
/// `"hkey_users"`.
pub const HKEY_USERS: u32 = 0x8000_0003;
/// Short alias for `HKEY_CLASSES_ROOT`.
pub const HKCR: u32 = HKEY_CLASSES_ROOT;
/// Short alias for `HKEY_CURRENT_USER`.
pub const HKCU: u32 = HKEY_CURRENT_USER;
/// Short alias for `HKEY_LOCAL_MACHINE`.
pub const HKLM: u32 = HKEY_LOCAL_MACHINE;
/// Short alias for `HKEY_USERS`.
pub const HKU: u32 = HKEY_USERS;
/// Base value for handles produced by `VirtualRegistry::open_key`: each
/// handle is this base plus a per-registry counter (wrapping).
pub const HKEY_USER_BASE: u32 = 0x6900_0000;
impl VirtualRegistry {
    /// Creates an empty registry: no keys and no open handles.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the value `name` under `key_path`, creating the key if needed and
    /// replacing any existing value with the same name.
    ///
    /// The key path is normalized with `normalize_path`; the value name is
    /// ASCII-lowercased. Root abbreviations are not expanded: a key stored
    /// as `"HKLM\\..."` lives under `"hklm/..."`. To be reachable through
    /// [`VirtualRegistry::open_key`], a path must start with one of the root
    /// names returned by [`VirtualRegistry::predefined_path`].
    pub fn set_value(&mut self, key_path: &str, name: &str, value: RegistryValue) {
        self.keys
            .entry(normalize_path(key_path))
            .or_default()
            .values
            .insert(name.to_ascii_lowercase(), value);
    }

    /// Looks up the value `name` under `key_path`, using the same key and
    /// name normalization as [`VirtualRegistry::set_value`].
    #[must_use]
    pub fn get_value(&self, key_path: &str, name: &str) -> Option<&RegistryValue> {
        let key = self.keys.get(&normalize_path(key_path))?;
        key.values.get(&name.to_ascii_lowercase())
    }

    /// Iterates over every stored value as `(key path, value name, value)`,
    /// ordered by key path and then by value name.
    pub fn all_values(&self) -> impl Iterator<Item = (&str, &str, &RegistryValue)> {
        self.keys.iter().flat_map(|(key_path, key)| {
            key.values
                .iter()
                .map(move |(name, value)| (key_path.as_str(), name.as_str(), value))
        })
    }

    /// Returns `true` if a key is stored at exactly `key_path` (after
    /// normalization). Ancestors of stored keys are not reported.
    #[must_use]
    pub fn contains_key(&self, key_path: &str) -> bool {
        self.keys.contains_key(&normalize_path(key_path))
    }

    /// Maps a predefined root handle to its normalized root path, or `None`
    /// if `hkey` is not one of the predefined roots.
    #[must_use]
    pub fn predefined_path(hkey: u32) -> Option<&'static str> {
        match hkey {
            HKEY_CLASSES_ROOT => Some("hkey_classes_root"),
            HKEY_CURRENT_USER => Some("hkey_current_user"),
            HKEY_LOCAL_MACHINE => Some("hkey_local_machine"),
            HKEY_USERS => Some("hkey_users"),
            _ => None,
        }
    }

    /// Opens `subkey` relative to `base_hkey` and returns a new handle.
    ///
    /// `base_hkey` is either a predefined root or a handle previously
    /// returned by this method. An empty `subkey` refers to the base key
    /// itself.
    ///
    /// The key is considered to exist when it is stored directly or when it
    /// is an ancestor of a stored key. That allows step-by-step navigation:
    /// open `Software` under a root, then open `Foo` relative to the
    /// returned handle, even though only `.../software/foo` holds values.
    ///
    /// Returns `None` if `base_hkey` is neither predefined nor open, or if
    /// no such key exists.
    pub fn open_key(&mut self, base_hkey: u32, subkey: &str) -> Option<u32> {
        let base_path = match Self::predefined_path(base_hkey) {
            Some(root) => root.to_string(),
            None => self.open.get(&base_hkey)?.path.clone(),
        };
        let combined = if subkey.is_empty() {
            base_path
        } else {
            format!("{}/{}", base_path, normalize_path(subkey))
        };
        if !self.key_exists(&combined) {
            return None;
        }
        let h = HKEY_USER_BASE.wrapping_add(self.next_handle);
        self.next_handle = self.next_handle.wrapping_add(1);
        self.open.insert(h, OpenKey { path: combined });
        Some(h)
    }

    /// Closes `hkey`, returning `true` on success.
    ///
    /// Predefined roots are never tracked as open, so closing one always
    /// succeeds and does nothing.
    pub fn close_key(&mut self, hkey: u32) -> bool {
        Self::predefined_path(hkey).is_some() || self.open.remove(&hkey).is_some()
    }

    /// Returns `true` if `hkey` is a predefined root or a currently open
    /// handle of this registry.
    #[must_use]
    pub fn owns(&self, hkey: u32) -> bool {
        Self::predefined_path(hkey).is_some() || self.open.contains_key(&hkey)
    }

    /// Returns the normalized key path behind `hkey`, for predefined roots
    /// and open handles alike; `None` for anything else.
    #[must_use]
    pub fn path_of(&self, hkey: u32) -> Option<&str> {
        match Self::predefined_path(hkey) {
            Some(root) => Some(root),
            None => self.open.get(&hkey).map(|k| k.path.as_str()),
        }
    }

    /// Returns `true` when `path` (already normalized) is stored directly or
    /// is an ancestor of at least one stored key.
    fn key_exists(&self, path: &str) -> bool {
        if self.keys.contains_key(path) {
            return true;
        }
        // All keys that start with `prefix` are contiguous in the ordered
        // map and begin at the first key >= `prefix`, so one probe suffices.
        let prefix = format!("{}/", path);
        let bounds = (
            std::ops::Bound::Included(prefix.as_str()),
            std::ops::Bound::Unbounded,
        );
        self.keys
            .range::<str, _>(bounds)
            .next()
            .map_or(false, |(stored, _)| stored.starts_with(prefix.as_str()))
    }
}
// Unit tests for the virtual file system, the virtual registry and the
// `Context` builders.
#[cfg(test)]
mod tests {
use super::*;
// Path lookups ignore case and treat `\` and `/` as the same separator.
#[test]
fn vfs_insert_and_read_case_insensitive() {
let mut vfs = VirtualFs::new();
vfs.insert("C:\\Windows\\foo.ini", b"hello".to_vec());
assert_eq!(vfs.read("C:\\Windows\\foo.ini"), Some(&b"hello"[..]));
assert_eq!(vfs.read("c:/windows/FOO.INI"), Some(&b"hello"[..]));
assert_eq!(vfs.read("nope.txt"), None);
}
// Sequential handle reads advance the cursor, return a short count at the
// tail, and return 0 once end of file is reached.
#[test]
fn vfs_open_read_close() {
let mut vfs = VirtualFs::new();
vfs.insert("a.txt", b"hello world".to_vec());
let h = vfs.open("a.txt", FileAccess::Read).expect("opens");
let mut buf = [0u8; 5];
assert_eq!(vfs.read_handle(h, &mut buf), Some(5));
assert_eq!(&buf, b"hello");
assert_eq!(vfs.read_handle(h, &mut buf), Some(5));
assert_eq!(&buf, b" worl");
let mut tail = [0u8; 5];
assert_eq!(vfs.read_handle(h, &mut tail), Some(1));
assert_eq!(&tail[..1], b"d");
assert_eq!(vfs.read_handle(h, &mut buf), Some(0));
assert!(vfs.close(h));
}
// Opening a missing file for writing creates it; consecutive writes append
// at the cursor and the result is visible through path-based reads.
#[test]
fn vfs_write_extends_and_round_trips() {
let mut vfs = VirtualFs::new();
let h = vfs.open("new.txt", FileAccess::Write).expect("opens");
assert_eq!(vfs.write_handle(h, b"hello").unwrap(), 5);
assert_eq!(vfs.write_handle(h, b" world").unwrap(), 6);
vfs.close(h);
assert_eq!(vfs.read("new.txt"), Some(&b"hello world"[..]));
}
// A handle opened with read access rejects writes.
#[test]
fn vfs_read_only_handle_cannot_write() {
let mut vfs = VirtualFs::new();
vfs.insert("a.txt", b"hi".to_vec());
let h = vfs.open("a.txt", FileAccess::Read).unwrap();
assert!(vfs.write_handle(h, b"!").is_none());
}
// Opening a missing file read-only fails instead of creating it.
#[test]
fn vfs_open_nonexistent_read_returns_none() {
let mut vfs = VirtualFs::new();
assert!(vfs.open("missing.txt", FileAccess::Read).is_none());
}
// Key paths and value names are matched case-insensitively, and either
// separator may be used in the key path.
#[test]
fn registry_set_get_case_insensitive() {
let mut reg = VirtualRegistry::new();
reg.set_value(
"HKLM\\Software\\Foo",
"Version",
RegistryValue::Sz("1.2.3".into()),
);
assert_eq!(
reg.get_value("hklm/software/foo", "version"),
Some(&RegistryValue::Sz("1.2.3".into()))
);
assert_eq!(
reg.get_value("HKLM\\Software\\Foo", "VERSION"),
Some(&RegistryValue::Sz("1.2.3".into()))
);
}
// A key stored under the full root name can be opened through the matching
// predefined handle, reports its normalized path, and can be closed.
#[test]
fn registry_open_close_round_trip() {
let mut reg = VirtualRegistry::new();
reg.set_value(
"hkey_local_machine/software/foo",
"x",
RegistryValue::Dword(1),
);
let h = reg.open_key(HKLM, "Software\\Foo").expect("opens");
assert!(reg.owns(h));
assert_eq!(reg.path_of(h), Some("hkey_local_machine/software/foo"));
assert!(reg.close_key(h));
}
// The builder methods attach both resources to a fresh context.
#[test]
fn context_builders() {
let mut vfs = VirtualFs::new();
vfs.insert("a.txt", b"x".to_vec());
let mut reg = VirtualRegistry::new();
reg.set_value("hklm", "v", RegistryValue::Dword(1));
let ctx = Context::new().with_vfs(vfs).with_registry(reg);
assert!(ctx.vfs.is_some());
assert!(ctx.registry.is_some());
}
}