use core::cell::Cell;
use core::ffi::c_void;
use std::collections::HashMap;
use std::os::windows::ffi::OsStrExt as _;
use std::path::Path;
use std::sync::{Arc, mpsc as std_mpsc};
use std::thread;
use ironrdp_core::impl_as_any;
use ironrdp_dvc::{DvcClientProcessor, DvcMessage, DvcProcessor};
use ironrdp_pdu::{PduResult, pdu_other_err};
use ironrdp_svc::SvcMessage;
use tracing::{debug, error, info, warn};
use windows::Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryW};
use windows::Win32::System::RemoteDesktop::{IWTSListenerCallback, IWTSPlugin, IWTSVirtualChannelManager};
use windows::core::{HRESULT, PCSTR, PCWSTR};
use windows_core::{GUID, Interface as _};
use crate::com::{ChannelManager, OnWriteDvc};
use crate::worker::{ComCommand, run_com_worker};
/// Signature of the `VirtualChannelGetInstance` export looked up in the plugin DLL.
///
/// It is called twice by `initialize_plugin_on_thread`: once with a null object array to
/// query the object count, then with an array of that many slots to receive the objects.
type VirtualChannelGetInstanceFn = unsafe extern "system" fn(
    refiid: *const GUID,
    pnumobjs: *mut u32,
    ppobjarray: *mut *mut c_void,
) -> HRESULT;
/// Client-side DVC processor that forwards one channel's traffic to a COM plugin.
///
/// Every instance returned by one `load_dvc_plugin` call shares the same worker thread
/// (they hold clones of the same command sender); this type never touches COM itself,
/// it only sends [`ComCommand`]s to that worker.
pub struct DvcComChannel {
    /// Channel name, taken from the listener the plugin registered.
    channel_name: String,
    /// Command queue to the COM worker thread.
    command_tx: std_mpsc::Sender<ComCommand>,
    /// Carries the callback produced by `on_write_dvc_factory` to the worker on every `start`.
    on_write_dvc_tx: std_mpsc::Sender<OnWriteDvc>,
    /// Builds the write callback handed to the worker each time the channel starts.
    on_write_dvc_factory: Arc<dyn Fn() -> OnWriteDvcMessage + Send + Sync>,
    /// Set only on the first channel; its first `start` sends `ComCommand::Connected`.
    needs_connected: bool,
    /// Optional handle to the worker thread. It is never joined: dropping it detaches the thread.
    _worker_handle: Option<thread::JoinHandle<()>>,
}
// `ironrdp_core::impl_as_any!` — presumably provides the `AsAny` downcasting support the
// DVC processor traits rely on; confirm against ironrdp_core.
impl_as_any!(DvcComChannel);
impl DvcProcessor for DvcComChannel {
    fn channel_name(&self) -> &str {
        self.channel_name.as_str()
    }

    /// Announces the channel to the COM worker and blocks until the plugin answers.
    ///
    /// The first channel to start also sends `ComCommand::Connected`. A rejection, or a
    /// worker that drops the reply sender without answering, is only logged.
    ///
    /// # Errors
    ///
    /// Fails if the worker thread has exited and `ChannelOpened` cannot be delivered.
    fn start(&mut self, channel_id: u32) -> PduResult<Vec<DvcMessage>> {
        info!(
            channel_name = %self.channel_name,
            channel_id,
            "DVC COM channel start"
        );

        // One-shot: clear the flag and act on its previous value.
        if core::mem::replace(&mut self.needs_connected, false) {
            let _ = self.command_tx.send(ComCommand::Connected);
        }

        // Hand the worker a fresh write callback before asking it to open the channel.
        let _ = self.on_write_dvc_tx.send((self.on_write_dvc_factory)());

        let (accept_tx, accept_rx) = std_mpsc::sync_channel(1);
        let open_cmd = ComCommand::ChannelOpened {
            channel_name: self.channel_name.clone(),
            channel_id,
            accept_tx,
        };
        self.command_tx
            .send(open_cmd)
            .map_err(|_| pdu_other_err!("COM worker thread is gone"))?;

        // A disconnected reply channel counts as a rejection.
        match accept_rx.recv() {
            Ok(true) => info!(
                channel_name = %self.channel_name,
                channel_id,
                "COM plugin accepted DVC channel"
            ),
            Ok(false) | Err(_) => warn!(
                channel_name = %self.channel_name,
                channel_id,
                "COM plugin rejected DVC channel"
            ),
        }

        Ok(Vec::new())
    }

    /// Forwards a received payload to the COM worker as `ComCommand::DataReceived`.
    ///
    /// # Errors
    ///
    /// Fails if the worker thread has exited.
    fn process(&mut self, channel_id: u32, payload: &[u8]) -> PduResult<Vec<DvcMessage>> {
        let data = payload.to_owned();
        self.command_tx
            .send(ComCommand::DataReceived { channel_id, data })
            .map_err(|_| pdu_other_err!("COM worker thread is gone"))?;
        Ok(Vec::new())
    }

    /// Notifies the COM worker that the channel closed; a send failure is ignored (best effort).
    fn close(&mut self, channel_id: u32) {
        debug!(
            channel_name = %self.channel_name,
            channel_id,
            "DVC COM channel close"
        );
        let closed = ComCommand::ChannelClosed { channel_id };
        let _ = self.command_tx.send(closed);
    }
}
// Empty impl: relies entirely on the trait's provided items.
impl DvcClientProcessor for DvcComChannel {}

impl Drop for DvcComChannel {
    /// Requests that the COM worker shut down; if the worker is already gone the send
    /// failure is ignored.
    ///
    /// NOTE(review): all channels from one `load_dvc_plugin` call share this worker, so
    /// dropping any single channel requests shutdown for all of them — confirm intended.
    fn drop(&mut self) {
        self.command_tx.send(ComCommand::Shutdown).ok();
    }
}
/// Boxed write callback handed to the COM worker. It takes a `u32` (presumably the DVC
/// channel id — confirm against the worker) and the messages to send.
pub(crate) type OnWriteDvcMessage = Box<dyn Fn(u32, Vec<SvcMessage>) -> PduResult<()> + Send + 'static>;

/// Loads a DVC COM plugin DLL and returns one [`DvcComChannel`] per listener it registered.
///
/// A dedicated `dvc-com-worker` thread is spawned. The DLL is loaded, the plugin is
/// initialized and the worker loop runs on that thread, so all COM objects stay on it.
/// This function blocks until initialization has succeeded or failed.
///
/// `on_write_dvc_factory` is shared by all returned channels. Each channel calls it on
/// every `start` to produce the write callback it sends to the worker.
///
/// # Errors
///
/// Returns an error if the worker thread cannot be spawned or dies during initialization.
/// It also returns an error if loading or initializing the plugin fails.
pub fn load_dvc_plugin<F>(dll_path: &Path, on_write_dvc_factory: F) -> PduResult<Vec<DvcComChannel>>
where
    F: Fn() -> OnWriteDvcMessage + Send + Sync + 'static,
{
    info!(dll = %dll_path.display(), "Loading DVC COM plugin");

    let (command_tx, command_rx) = std_mpsc::channel();
    let (on_write_dvc_tx, on_write_dvc_rx) = std_mpsc::channel();
    // One-shot reply from the worker: listener names on success, error text on failure.
    let (init_tx, init_rx) = std_mpsc::sync_channel::<Result<Vec<String>, String>>(1);
    let dll_path_owned = dll_path.to_path_buf();

    let worker_handle = thread::Builder::new()
        .name("dvc-com-worker".into())
        .spawn(move || {
            match initialize_plugin_on_thread(&dll_path_owned) {
                Ok((plugin, manager, listeners)) => {
                    let channel_names: Vec<String> = listeners.keys().cloned().collect();
                    info!(
                        channels = ?channel_names,
                        "Plugin initialized, registered {} listener(s)",
                        channel_names.len()
                    );
                    let _ = init_tx.send(Ok(channel_names));
                    run_com_worker(plugin, manager, listeners, command_rx, on_write_dvc_rx);
                }
                Err(e) => {
                    error!(error = %e, "Failed to initialize DVC COM plugin");
                    let _ = init_tx.send(Err(e));
                }
            }
        })
        // Spawning can fail (e.g. resource exhaustion); report it instead of panicking.
        .map_err(|e| pdu_other_err!("failed to spawn COM worker thread").with_source(e))?;

    let channel_names = init_rx
        .recv()
        .map_err(|_| pdu_other_err!("COM worker thread died during initialization"))?
        .map_err(|e| pdu_other_err!("plugin initialization failed").with_source(std::io::Error::other(e)))?;

    if channel_names.is_empty() {
        warn!(dll = %dll_path.display(), "Plugin registered no listeners");
    }

    let factory: Arc<dyn Fn() -> OnWriteDvcMessage + Send + Sync> = Arc::new(on_write_dvc_factory);
    // The first channel keeps the (never joined) worker handle and is the one that will
    // send `ComCommand::Connected` on its first `start`.
    let mut worker_handle = Some(worker_handle);
    let is_first = Cell::new(true);
    let mut channels = Vec::with_capacity(channel_names.len());
    for name in channel_names {
        debug!(channel_name = %name, "Creating DvcComChannel");
        channels.push(DvcComChannel {
            channel_name: name,
            command_tx: command_tx.clone(),
            on_write_dvc_tx: on_write_dvc_tx.clone(),
            on_write_dvc_factory: Arc::clone(&factory),
            needs_connected: is_first.get(),
            _worker_handle: worker_handle.take(),
        });
        is_first.set(false);
    }
    Ok(channels)
}
fn initialize_plugin_on_thread(
dll_path: &Path,
) -> Result<
(
IWTSPlugin,
IWTSVirtualChannelManager,
HashMap<String, IWTSListenerCallback>,
),
String,
> {
let dll_path_wide: Vec<u16> = dll_path.as_os_str().encode_wide().chain(core::iter::once(0)).collect();
let dll_path_pcwstr = PCWSTR(dll_path_wide.as_ptr());
let hmodule = unsafe { LoadLibraryW(dll_path_pcwstr) }.map_err(|e| format!("LoadLibraryW failed: {e}"))?;
info!(dll = %dll_path.display(), "DLL loaded successfully");
let proc_name = PCSTR::from_raw(c"VirtualChannelGetInstance".as_ptr().cast::<u8>());
let proc_addr = unsafe { GetProcAddress(hmodule, proc_name) }
.ok_or_else(|| "VirtualChannelGetInstance export not found in DLL".to_owned())?;
let get_instance: VirtualChannelGetInstanceFn = unsafe { core::mem::transmute(proc_addr) };
info!("VirtualChannelGetInstance export found");
let iid = IWTSPlugin::IID;
let mut num_objs: u32 = 0;
let hr = unsafe { get_instance(&iid, &mut num_objs, core::ptr::null_mut()) };
if hr.is_err() {
return Err(format!(
"VirtualChannelGetInstance phase 1 failed: HRESULT 0x{:08X}",
hr.0
));
}
info!(count = num_objs, "Plugin reports {} object(s)", num_objs);
if num_objs == 0 {
return Err("plugin returned 0 objects".to_owned());
}
let mut obj_array: Vec<*mut c_void> =
vec![core::ptr::null_mut(); usize::try_from(num_objs).expect("u32 fits in usize")];
let hr = unsafe { get_instance(&iid, &mut num_objs, obj_array.as_mut_ptr()) };
if hr.is_err() {
return Err(format!(
"VirtualChannelGetInstance phase 2 failed: HRESULT 0x{:08X}",
hr.0
));
}
let plugin_ptr = obj_array[0];
if plugin_ptr.is_null() {
return Err("VirtualChannelGetInstance returned null plugin pointer".to_owned());
}
let plugin: IWTSPlugin = unsafe { IWTSPlugin::from_raw(plugin_ptr) };
info!("Got IWTSPlugin COM object");
let listeners_rc = std::rc::Rc::new(core::cell::RefCell::new(HashMap::new()));
let channel_manager_impl = ChannelManager::new(std::rc::Rc::clone(&listeners_rc));
let manager: IWTSVirtualChannelManager = channel_manager_impl.into();
unsafe { plugin.Initialize(&manager) }.map_err(|e| format!("IWTSPlugin::Initialize failed: {e}"))?;
info!("IWTSPlugin::Initialize succeeded");
let listeners: HashMap<String, IWTSListenerCallback> = listeners_rc.borrow().clone();
Ok((plugin, manager, listeners))
}