use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
use core::ptr::null_mut;
use core::sync::atomic::{AtomicPtr, AtomicU64, Ordering};
use spin::Mutex;
use vck_common::{types::Guid, EncryptedOffsetStore, SectorIo, VckResult, VolumeCipher};
use wdk_sys::{DEVICE_OBJECT, DRIVER_OBJECT};
use crate::{offset::engine::EncryptionEngine, provider::IoConfig};
/// Handover record for one partition, stored in [`VolumeAttachRegistry`] by
/// `set_handover` and read back (as a clone) by `handover`.
///
/// NOTE(review): `vmk` presumably holds the volume master key. Because this type derives
/// `Clone`, every `VolumeAttachRegistry::handover()` call makes another heap copy of those
/// bytes, and nothing in this file zeroizes the copies when they are dropped — confirm
/// whether that is acceptable.
#[derive(Clone)]
pub struct HandoverInfo {
/// GUID of the partition this handover refers to (the only field `set_handover` logs).
pub partition_guid: Guid,
/// Raw key bytes — presumably the volume master key (TODO confirm).
pub vmk: Vec<u8>,
}
/// Process-wide registry pointer. It is null until `set_global_registry` publishes one.
/// It is read back through `global_registry`, and is private to this module, so those two
/// functions are its only accessors.
static GLOBAL_REGISTRY: AtomicPtr<VolumeAttachRegistry> = AtomicPtr::new(null_mut());
/// Publishes `registry` as the process-wide instance later returned by `global_registry`.
///
/// Taking a `&'static` reference guarantees the published pointer never dangles. Calling
/// this again simply replaces the previously published registry. The `Release` store pairs
/// with the `Acquire` load in `global_registry`.
pub fn set_global_registry(registry: &'static VolumeAttachRegistry) {
    // The atomic only stores `*mut`; the pointer is never used for mutation.
    let shared: *const VolumeAttachRegistry = registry;
    let published = shared as *mut VolumeAttachRegistry;
    GLOBAL_REGISTRY.store(published, Ordering::Release);
}
/// Returns the registry published by `set_global_registry`, or `None` if none has been
/// published yet.
pub fn global_registry() -> Option<&'static VolumeAttachRegistry> {
    let published = GLOBAL_REGISTRY.load(Ordering::Acquire);
    // SAFETY: `GLOBAL_REGISTRY` is private to this module and its only writer is
    // `set_global_registry`, which stores a pointer derived from a
    // `&'static VolumeAttachRegistry`. Any non-null value is therefore valid for `'static`,
    // and it is only ever handed out as a shared reference. `as_ref` maps null to `None`.
    unsafe { published.as_ref() }
}
/// One PDO filter attachment recorded by `VolumeAttachRegistry::add_pdo_filter`.
///
/// Within this file the pointers are only copied and compared, never dereferenced.
pub struct PdoFilterEntry {
/// Device object the filter belongs to; matched by `find_pdo_filter`.
pub pdo: *mut DEVICE_OBJECT,
/// Filter device object; returned by the lookups and matched by `pdo_name_for_filter`.
pub filter_do: *mut DEVICE_OBJECT,
/// Device object beneath the filter — presumably the attach target returned when the
/// filter was attached (TODO confirm). Also matched by `find_pdo_filter`.
pub lower_do: *mut DEVICE_OBJECT,
/// NT name of the PDO. May be empty; `pdo_name_for_filter` skips entries with an empty name.
pub pdo_name: alloc::string::String,
}
// SAFETY: raw pointers are neither `Send` nor `Sync` by default. These impls only assert
// that the pointer *values* may move between threads. This file never dereferences them and
// only touches them under the registry's `pdo_filters` mutex.
// NOTE(review): any code that dereferences these pointers must separately ensure the
// referenced DEVICE_OBJECTs are still alive.
unsafe impl Send for PdoFilterEntry {}
unsafe impl Sync for PdoFilterEntry {}
/// Driver-wide bookkeeping. It tracks attached volumes, PDO filter attachments, the
/// driver object, and the handover record, each behind its own lock or atomic.
pub struct VolumeAttachRegistry {
/// Attached volumes keyed by `AttachedVolume::volume_path` (see `insert`).
entries: Mutex<BTreeMap<String, Arc<AttachedVolume>>>,
/// Driver object recorded by `set_driver_object`; null until set.
driver_object: AtomicPtr<DRIVER_OBJECT>,
/// PDO filter attachments in insertion order; lookups return the first match.
pdo_filters: Mutex<alloc::vec::Vec<PdoFilterEntry>>,
/// Most recent handover record, if `set_handover` has been called.
handover: Mutex<Option<HandoverInfo>>,
}
/// A volume tracked by `VolumeAttachRegistry`, together with its I/O configuration and
/// encryption-sweep state.
pub struct AttachedVolume {
/// Key under which the registry stores this volume (see `VolumeAttachRegistry::insert`).
pub volume_path: String,
/// Sector size — presumably in bytes (TODO confirm).
pub sector_size: u32,
/// I/O configuration; `cipher()`, `offset_sector()` and `data_sectors()` are derived from it.
pub io_config: IoConfig,
/// Sweep engine; locked by `sweep_step` and `sync_boundary`.
pub encryption: Mutex<EncryptionEngine>,
/// Offset store handed to the engine on every sweep step.
pub offset_store: Arc<dyn EncryptedOffsetStore>,
/// How the volume was attached; `Handover` makes `is_os_volume()` return true.
pub attach_source: AttachSource,
/// Filter device for this volume, matched by `VolumeAttachRegistry::get_by_filter`.
pub filter_device: AtomicPtr<DEVICE_OBJECT>,
/// I/O backend for the sweep; `sweep_step` clones the `Arc` out before each batch.
pub sweep_io: Mutex<Arc<dyn SectorIo>>,
/// Lock-free mirror of the engine's encrypted boundary, refreshed by `sync_boundary` and by
/// successful `sweep_step` calls. Non-zero means `has_encrypted_data()` is true.
pub encrypted_boundary: AtomicU64,
}
/// Records how a volume came to be attached to the registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttachSource {
/// Attached from the handover record; `AttachedVolume::is_os_volume` treats it as the OS volume.
Handover,
/// Attached on explicit request — presumably a device I/O control from user mode (TODO confirm).
Ioctl,
}
impl VolumeAttachRegistry {
    /// Creates an empty registry. It starts with no volumes, no PDO filters, no handover
    /// record, and a null driver object.
    pub fn new() -> Self {
        Self {
            entries: Mutex::new(BTreeMap::new()),
            driver_object: AtomicPtr::new(null_mut()),
            pdo_filters: Mutex::new(alloc::vec::Vec::new()),
            handover: Mutex::new(None),
        }
    }

    /// Stores `info` as the handover record, replacing any previous one.
    ///
    /// Only the partition GUID is logged; the key bytes are not.
    pub fn set_handover(&self, info: HandoverInfo) {
        crate::vck_log!(
            "registry: handover set partition_guid={}",
            info.partition_guid
        );
        *self.handover.lock() = Some(info);
    }

    /// Returns a clone of the handover record, if one has been set.
    ///
    /// NOTE(review): each call copies the `vmk` bytes into a new `Vec`, and nothing in this
    /// file zeroizes those copies on drop.
    pub fn handover(&self) -> Option<HandoverInfo> {
        self.handover.lock().clone()
    }

    /// Appends a PDO / filter / lower-device triple.
    ///
    /// Duplicates are not detected; every lookup returns the first matching entry in
    /// insertion order.
    pub fn add_pdo_filter(
        &self,
        pdo: *mut DEVICE_OBJECT,
        filter_do: *mut DEVICE_OBJECT,
        lower_do: *mut DEVICE_OBJECT,
        pdo_name: alloc::string::String,
    ) {
        crate::vck_log!("add_pdo_filter: name={} filter={:p}", pdo_name, filter_do);
        self.pdo_filters.lock().push(PdoFilterEntry {
            pdo,
            filter_do,
            lower_do,
            pdo_name,
        });
    }

    /// Returns the first non-empty PDO name recorded for `filter_do`.
    ///
    /// Returns `None` if `filter_do` is null, or if no entry with a non-empty name matches.
    /// The null check keeps a null argument from matching entries that stored a null
    /// pointer, consistent with `get_by_filter`.
    pub fn pdo_name_for_filter(&self, filter_do: *mut DEVICE_OBJECT) -> Option<String> {
        if filter_do.is_null() {
            return None;
        }
        self.pdo_filters
            .lock()
            .iter()
            .find(|entry| entry.filter_do == filter_do && !entry.pdo_name.is_empty())
            .map(|entry| entry.pdo_name.clone())
    }

    /// Finds the entry whose PDO or lower device is `dev` and returns `(filter_do, lower_do)`.
    ///
    /// Returns `None` for a null `dev`, so that entries stored with a null pointer are never
    /// matched by accident.
    pub fn find_pdo_filter(
        &self,
        dev: *mut DEVICE_OBJECT,
    ) -> Option<(*mut DEVICE_OBJECT, *mut DEVICE_OBJECT)> {
        if dev.is_null() {
            return None;
        }
        self.pdo_filters
            .lock()
            .iter()
            .find(|entry| entry.pdo == dev || entry.lower_do == dev)
            .map(|entry| (entry.filter_do, entry.lower_do))
    }

    /// Finds the entry whose PDO name equals `nt_path` and returns `(filter_do, lower_do)`.
    ///
    /// The comparison is ASCII case-insensitive and ignores trailing backslashes on both
    /// sides. A query that is empty after trimming returns `None`, so it never resolves to
    /// an unnamed entry.
    pub fn find_pdo_filter_by_name(
        &self,
        nt_path: &str,
    ) -> Option<(*mut DEVICE_OBJECT, *mut DEVICE_OBJECT)> {
        let query = nt_path.trim_end_matches('\\');
        if query.is_empty() {
            return None;
        }
        // `eq_ignore_ascii_case` gives the same result as comparing `to_ascii_lowercase`
        // copies, without allocating two Strings per entry.
        self.pdo_filters
            .lock()
            .iter()
            .find(|entry| {
                entry
                    .pdo_name
                    .trim_end_matches('\\')
                    .eq_ignore_ascii_case(query)
            })
            .map(|entry| (entry.filter_do, entry.lower_do))
    }

    /// Records the driver object (Release store; pairs with `driver_object`).
    pub fn set_driver_object(&self, driver: *mut DRIVER_OBJECT) {
        self.driver_object.store(driver, Ordering::Release);
    }

    /// Returns the recorded driver object, or null if `set_driver_object` was never called.
    pub fn driver_object(&self) -> *mut DRIVER_OBJECT {
        self.driver_object.load(Ordering::Acquire)
    }

    /// Registers `volume` under its `volume_path`, replacing any volume with the same path.
    pub fn insert(&self, volume: Arc<AttachedVolume>) {
        self.entries
            .lock()
            .insert(volume.volume_path.clone(), volume);
    }

    /// Looks up a volume by its exact (case-sensitive) path.
    pub fn get(&self, volume_path: &str) -> Option<Arc<AttachedVolume>> {
        self.entries.lock().get(volume_path).cloned()
    }

    /// Finds the volume whose `filter_device` equals `filter_do`; `None` for a null pointer.
    pub fn get_by_filter(&self, filter_do: *mut DEVICE_OBJECT) -> Option<Arc<AttachedVolume>> {
        if filter_do.is_null() {
            return None;
        }
        self.entries
            .lock()
            .values()
            .find(|v| v.filter_device.load(Ordering::Acquire) == filter_do)
            .cloned()
    }

    /// Removes and returns the volume registered under `volume_path`, if any.
    pub fn remove(&self, volume_path: &str) -> Option<Arc<AttachedVolume>> {
        self.entries.lock().remove(volume_path)
    }

    /// Returns a snapshot of all attached volumes, ordered by path.
    pub fn all(&self) -> Vec<Arc<AttachedVolume>> {
        self.entries.lock().values().cloned().collect()
    }

    /// Returns `true` if any handover-attached volume already has encrypted data.
    pub fn has_encrypted_os_volume(&self) -> bool {
        self.entries
            .lock()
            .values()
            .any(|v| v.is_os_volume() && v.has_encrypted_data())
    }
}
impl AttachedVolume {
    /// Returns `true` when the volume was attached from the handover record.
    pub fn is_os_volume(&self) -> bool {
        matches!(self.attach_source, AttachSource::Handover)
    }

    /// Returns `true` once the mirrored encrypted boundary is non-zero.
    ///
    /// Reads only the atomic mirror, so it never takes the `encryption` lock.
    pub fn has_encrypted_data(&self) -> bool {
        self.encrypted_boundary.load(Ordering::Acquire) > 0
    }

    /// Sector offset reported by the volume's `IoConfig`.
    pub fn offset_sector(&self) -> u64 {
        self.io_config.offset_sector()
    }

    /// Returns `encrypted_offset.total_sectors` for `Encrypted`/`Custom` configs and `0`
    /// for `Passthrough`.
    pub fn data_sectors(&self) -> u64 {
        match &self.io_config {
            IoConfig::Passthrough => 0,
            IoConfig::Encrypted {
                encrypted_offset, ..
            }
            | IoConfig::Custom {
                encrypted_offset, ..
            } => encrypted_offset.total_sectors,
        }
    }

    /// The cipher configured for an `Encrypted` volume, if one is present; otherwise `None`.
    pub fn cipher(&self) -> Option<&dyn VolumeCipher> {
        match &self.io_config {
            IoConfig::Encrypted {
                cipher: Some(c), ..
            } => Some(&**c),
            _ => None,
        }
    }

    /// Copies the engine's current encrypted boundary into the atomic mirror.
    ///
    /// The store happens while the `encryption` guard is still held. That serializes it
    /// with engine progress, so a slower caller can never overwrite a newer boundary with
    /// a stale one.
    pub fn sync_boundary(&self) {
        let engine = self.encryption.lock();
        self.encrypted_boundary
            .store(engine.encrypted_boundary(), Ordering::Release);
    }

    /// Runs one sweep batch of up to `batch_sectors` sectors through the engine.
    ///
    /// Returns `Ok(false)` without touching the engine when no cipher is configured.
    /// Otherwise it returns whatever the engine's `progress_step` returns. On success the
    /// atomic boundary mirror is refreshed under the same lock acquisition as the step.
    pub fn sweep_step(&self, batch_sectors: u64) -> VckResult<bool> {
        let cipher = match self.cipher() {
            Some(cipher) => cipher,
            None => return Ok(false),
        };
        // Clone the Arc so the sweep_io lock is not held while the batch runs.
        let io = self.sweep_io.lock().clone();
        let mut engine = self.encryption.lock();
        let result = engine.progress_step(
            io.as_ref(),
            cipher,
            self.offset_store.as_ref(),
            batch_sectors,
        );
        if result.is_ok() {
            // Publish the boundary before releasing the engine lock. Re-locking through
            // `sync_boundary` would let a concurrent step's newer value be overwritten.
            self.encrypted_boundary
                .store(engine.encrypted_boundary(), Ordering::Release);
        }
        result
    }
}
impl Default for VolumeAttachRegistry {
fn default() -> Self {
Self::new()
}
}