use std::collections::HashMap;
use std::sync::mpsc as std_mpsc;
use tracing::{debug, error, info, trace, warn};
use windows::Win32::System::RemoteDesktop::{
IWTSListenerCallback, IWTSPlugin, IWTSVirtualChannel, IWTSVirtualChannelCallback, IWTSVirtualChannelManager,
};
use windows_core::{BOOL, BSTR};
use crate::com::{ActiveChannel, OnWriteDvc, VirtualChannel};
/// A request handed to the COM worker thread (`run_com_worker`).
///
/// Variants are listed in the order they typically occur over a session.
pub(crate) enum ComCommand {
    /// Ask the worker to notify the plugin via `IWTSPlugin::Connected`.
    Connected,
    /// Offer a newly opened dynamic virtual channel to the plugin listener
    /// registered under `channel_name`.
    ChannelOpened {
        /// Name used to look up the registered listener.
        channel_name: String,
        /// Identifier the worker uses to route later data/close commands.
        channel_id: u32,
        /// The worker answers `true` here if the plugin accepted the channel
        /// and `false` otherwise.
        accept_tx: std_mpsc::SyncSender<bool>,
    },
    /// Forward `data` to the callback of channel `channel_id`
    /// (`IWTSVirtualChannelCallback::OnDataReceived`).
    DataReceived {
        channel_id: u32,
        data: Vec<u8>,
    },
    /// Close channel `channel_id` (`OnClose`) and forget it.
    ChannelClosed {
        channel_id: u32,
    },
    /// Close every active channel, notify the plugin with `Disconnected` and
    /// `Terminated`, and stop the worker.
    Shutdown,
}
/// Runs the loop that drives a DVC client plugin (`IWTSPlugin`) through COM.
///
/// Every COM call this function makes is issued on the calling thread; other
/// threads interact with the plugin by sending [`ComCommand`]s on `command_rx`.
///
/// For each [`ComCommand::ChannelOpened`] the worker takes exactly one
/// [`OnWriteDvc`] from `on_write_dvc_rx` without blocking. The sender is
/// therefore expected to enqueue one write callback per channel-open command,
/// *before* sending that command.
///
/// The loop ends on [`ComCommand::Shutdown`] or once every sender of
/// `command_rx` has been dropped. In both cases all still-active channels get
/// `OnClose`, and the plugin gets `Disconnected` followed by `Terminated`.
///
/// # Parameters
/// - `plugin`: plugin instance notified of connection lifecycle events.
/// - `_manager`: not used directly. It is owned here, so it is only dropped when
///   this function returns (presumably to keep it alive for the plugin — confirm).
/// - `listeners`: listener callbacks keyed by channel name.
/// - `command_rx`: commands to execute.
/// - `on_write_dvc_rx`: per-channel write callbacks, consumed in channel-open order.
pub(crate) fn run_com_worker(
    plugin: IWTSPlugin,
    _manager: IWTSVirtualChannelManager,
    listeners: HashMap<String, IWTSListenerCallback>,
    command_rx: std_mpsc::Receiver<ComCommand>,
    on_write_dvc_rx: std_mpsc::Receiver<OnWriteDvc>,
) {
    info!("COM worker thread started");
    let mut active_channels: HashMap<u32, ActiveChannel> = HashMap::new();

    loop {
        let Ok(cmd) = command_rx.recv() else {
            debug!("Command channel closed, shutting down COM worker");
            break;
        };
        match cmd {
            ComCommand::Connected => {
                debug!("Notifying plugin: Connected");
                // SAFETY: `plugin` is a live COM interface pointer owned by this function.
                if let Err(e) = unsafe { plugin.Connected() } {
                    warn!("IWTSPlugin::Connected returned error (non-fatal): {e}");
                }
            }
            ComCommand::ChannelOpened {
                channel_name,
                channel_id,
                accept_tx,
            } => {
                debug!(channel_name = %channel_name, channel_id, "Opening DVC channel via COM plugin");
                let accepted = open_channel(
                    &listeners,
                    &on_write_dvc_rx,
                    &mut active_channels,
                    &channel_name,
                    channel_id,
                );
                // The requester may have stopped waiting; a dropped reply channel is not an error.
                let _ = accept_tx.send(accepted);
            }
            ComCommand::DataReceived { channel_id, data } => {
                trace!(channel_id, size = data.len(), "Forwarding data to COM plugin");
                if let Some(active) = active_channels.get(&channel_id) {
                    // SAFETY: `active.callback` is a live COM interface pointer held in
                    // `active_channels`; `data` outlives the call.
                    if let Err(e) = unsafe { active.callback.OnDataReceived(&data) } {
                        warn!(channel_id, "OnDataReceived failed: {e}");
                    }
                } else {
                    warn!(channel_id, "Data received for unknown channel");
                }
            }
            ComCommand::ChannelClosed { channel_id } => {
                debug!(channel_id, "Closing DVC channel in COM plugin");
                if let Some(active) = active_channels.remove(&channel_id) {
                    // SAFETY: `active.callback` is a live COM interface pointer we own.
                    if let Err(e) = unsafe { active.callback.OnClose() } {
                        warn!(channel_id, "OnClose failed: {e}");
                    }
                }
            }
            ComCommand::Shutdown => break,
        }
    }

    // Tear down on every exit path, not only on an explicit `Shutdown`, so the
    // plugin is still notified if all command senders are simply dropped.
    shutdown_plugin(&plugin, &mut active_channels);
    info!("COM worker thread exiting");
}

/// Offers a new channel to the listener registered under `channel_name` and,
/// if the plugin accepts it, records it in `active_channels`.
///
/// Returns `true` only when the plugin accepted the channel and supplied a
/// channel callback.
fn open_channel(
    listeners: &HashMap<String, IWTSListenerCallback>,
    on_write_dvc_rx: &std_mpsc::Receiver<OnWriteDvc>,
    active_channels: &mut HashMap<u32, ActiveChannel>,
    channel_name: &str,
    channel_id: u32,
) -> bool {
    // Consume this channel's write callback before anything can reject the
    // channel. If it stayed queued, the *next* opened channel would receive it
    // and its writes would be routed through the wrong channel's writer.
    let on_write: OnWriteDvc = match on_write_dvc_rx.try_recv() {
        Ok(cb) => cb,
        Err(_) => {
            error!("No write callback for channel {channel_name}");
            return false;
        }
    };

    let Some(listener_callback) = listeners.get(channel_name) else {
        warn!(channel_name = %channel_name, "No listener registered for channel");
        return false;
    };

    let virtual_channel: IWTSVirtualChannel = VirtualChannel::new(channel_id, on_write).into();
    let mut accept = BOOL::default();
    let mut channel_callback: Option<IWTSVirtualChannelCallback> = None;
    // SAFETY: `listener_callback` and `virtual_channel` are live COM interface
    // pointers for the whole call, and both out-parameters point at initialized
    // locals in this frame.
    let result = unsafe {
        listener_callback.OnNewChannelConnection(
            &virtual_channel,
            &BSTR::default(),
            &mut accept,
            &mut channel_callback,
        )
    };

    match result {
        Ok(()) if accept.as_bool() => {
            let Some(callback) = channel_callback else {
                warn!(
                    channel_name = %channel_name, channel_id,
                    "Plugin accepted channel but returned no callback"
                );
                return false;
            };
            info!(channel_name = %channel_name, channel_id, "Plugin accepted DVC channel");
            let previous = active_channels.insert(
                channel_id,
                ActiveChannel {
                    callback,
                    _channel: virtual_channel,
                },
            );
            // A re-used id would otherwise drop the old channel without ever
            // telling the plugin it was closed.
            if let Some(previous) = previous {
                warn!(channel_id, "Replaced an already-active channel with the same id; closing the old one");
                // SAFETY: `previous.callback` is a live COM interface pointer we own.
                if let Err(e) = unsafe { previous.callback.OnClose() } {
                    warn!(channel_id, "OnClose failed: {e}");
                }
            }
            true
        }
        Ok(()) => {
            debug!(channel_name = %channel_name, channel_id, "Plugin rejected DVC channel");
            false
        }
        Err(e) => {
            warn!(
                channel_name = %channel_name, channel_id,
                "OnNewChannelConnection failed: {e}"
            );
            false
        }
    }
}

/// Best-effort teardown. Calls `OnClose` on every remaining channel, then
/// `IWTSPlugin::Disconnected(0)` and `IWTSPlugin::Terminated`. Failures are
/// logged and otherwise ignored.
fn shutdown_plugin(plugin: &IWTSPlugin, active_channels: &mut HashMap<u32, ActiveChannel>) {
    info!("Shutting down COM plugin");
    for (channel_id, active) in active_channels.drain() {
        // SAFETY: `active.callback` is a live COM interface pointer we own.
        if let Err(e) = unsafe { active.callback.OnClose() } {
            warn!(channel_id, "OnClose during shutdown failed: {e}");
        }
    }
    // SAFETY: `plugin` is a live COM interface pointer for the duration of both calls.
    if let Err(e) = unsafe { plugin.Disconnected(0) } {
        warn!("IWTSPlugin::Disconnected failed: {e}");
    }
    // SAFETY: as above.
    if let Err(e) = unsafe { plugin.Terminated() } {
        warn!("IWTSPlugin::Terminated failed: {e}");
    }
}