use std::{
error::Error,
mem,
os::windows::prelude::AsRawHandle,
sync::{
atomic::{self, AtomicBool},
mpsc, Arc,
},
thread::{self, JoinHandle},
};
use log::{debug, info, trace, warn};
use parking_lot::Mutex;
use windows::{
Foundation::AsyncActionCompletedHandler,
Win32::{
Foundation::{HANDLE, LPARAM, WPARAM},
System::{
Threading::{GetCurrentThreadId, GetThreadId},
WinRT::{
CreateDispatcherQueueController, DispatcherQueueOptions, RoInitialize,
RoUninitialize, DQTAT_COM_NONE, DQTYPE_THREAD_CURRENT, RO_INIT_MULTITHREADED,
},
},
UI::WindowsAndMessaging::{
DispatchMessageW, GetMessageW, PostQuitMessage, PostThreadMessageW, TranslateMessage,
MSG, WM_QUIT,
},
},
};
use crate::{
frame::Frame,
graphics_capture_api::{GraphicsCaptureApi, InternalCaptureControl, RESULT},
settings::WindowsCaptureSettings,
};
#[derive(thiserror::Error, Eq, PartialEq, Clone, Copy, Debug)]
pub enum CaptureControlError {
#[error("Failed To Join Thread")]
FailedToJoinThread,
#[error("Thread Handle Is Taken Out Of Struct")]
ThreadHandleIsTaken,
}
/// Owner-side handle to a capture session running on its own thread.
///
/// Holds the thread's join handle plus the two pieces of state shared with that thread:
/// the halt flag and the user's handler.
pub struct CaptureControl<T>
where
    T: WindowsCaptureHandler + Send + 'static,
{
    /// Join handle of the capture thread; `None` once it has been taken out.
    thread_handle: Option<JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>>,
    /// Shared flag that is set to `true` when a stop is requested.
    halt_handle: Arc<AtomicBool>,
    /// The user's handler, shared with the capture thread.
    callback: Arc<Mutex<T>>,
}
impl<T: WindowsCaptureHandler + Send + 'static> CaptureControl<T> {
#[must_use]
pub fn new(
thread_handle: JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>,
halt_handle: Arc<AtomicBool>,
callback: Arc<Mutex<T>>,
) -> Self {
Self {
thread_handle: Some(thread_handle),
halt_handle,
callback,
}
}
#[must_use]
pub fn is_finished(&self) -> bool {
self.thread_handle
.as_ref()
.map_or(true, |thread_handle| thread_handle.is_finished())
}
#[must_use]
pub fn into_thread_handle(self) -> JoinHandle<Result<(), Box<dyn Error + Send + Sync>>> {
self.thread_handle.unwrap()
}
#[must_use]
pub fn halt_handle(&self) -> Arc<AtomicBool> {
self.halt_handle.clone()
}
#[must_use]
pub fn callback(&self) -> Arc<Mutex<T>> {
self.callback.clone()
}
pub fn wait(mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
if let Some(thread_handle) = self.thread_handle.take() {
match thread_handle.join() {
Ok(result) => result?,
Err(_) => {
return Err(Box::new(CaptureControlError::FailedToJoinThread));
}
}
} else {
return Err(Box::new(CaptureControlError::ThreadHandleIsTaken));
}
Ok(())
}
pub fn stop(mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
self.halt_handle.store(true, atomic::Ordering::Relaxed);
if let Some(thread_handle) = self.thread_handle.take() {
let handle = thread_handle.as_raw_handle();
let handle = HANDLE(handle as isize);
let therad_id = unsafe { GetThreadId(handle) };
loop {
match unsafe {
PostThreadMessageW(therad_id, WM_QUIT, WPARAM::default(), LPARAM::default())
} {
Ok(_) => break,
Err(e) => {
if thread_handle.is_finished() {
break;
}
if e.code().0 == -2147023452 {
warn!("Thread Is Not In Message Loop Yet");
} else {
Err(e)?;
}
}
}
}
match thread_handle.join() {
Ok(result) => result?,
Err(_) => {
return Err(Box::new(CaptureControlError::FailedToJoinThread));
}
}
} else {
return Err(Box::new(CaptureControlError::ThreadHandleIsTaken));
}
Ok(())
}
}
pub trait WindowsCaptureHandler: Sized {
type Flags;
fn start(
settings: WindowsCaptureSettings<Self::Flags>,
) -> Result<(), Box<dyn Error + Send + Sync>>
where
Self: Send + 'static,
<Self as WindowsCaptureHandler>::Flags: Send,
{
trace!("Initializing WinRT");
unsafe { RoInitialize(RO_INIT_MULTITHREADED)? };
trace!("Creating A Dispatcher Queue For Capture Thread");
let options = DispatcherQueueOptions {
dwSize: mem::size_of::<DispatcherQueueOptions>() as u32,
threadType: DQTYPE_THREAD_CURRENT,
apartmentType: DQTAT_COM_NONE,
};
let controller = unsafe { CreateDispatcherQueueController(options)? };
let thread_id = unsafe { GetCurrentThreadId() };
debug!("Thread ID: {thread_id}");
info!("Starting Capture Thread");
let callback = Arc::new(Mutex::new(Self::new(settings.flags)?));
let mut capture = GraphicsCaptureApi::new(
settings.item,
callback,
settings.capture_cursor,
settings.draw_border,
settings.color_format,
thread_id,
)?;
capture.start_capture()?;
trace!("Entering Message Loop");
let mut message = MSG::default();
unsafe {
while GetMessageW(&mut message, None, 0, 0).as_bool() {
TranslateMessage(&message);
DispatchMessageW(&message);
}
}
trace!("Shutting Down Dispatcher Queue");
let async_action = controller.ShutdownQueueAsync()?;
async_action.SetCompleted(&AsyncActionCompletedHandler::new(
move |_, _| -> Result<(), windows::core::Error> {
unsafe { PostQuitMessage(0) };
Ok(())
},
))?;
trace!("Entering Final Message Loop");
let mut message = MSG::default();
unsafe {
while GetMessageW(&mut message, None, 0, 0).as_bool() {
TranslateMessage(&message);
DispatchMessageW(&message);
}
}
info!("Stopping Capture Thread");
capture.stop_capture();
trace!("Uninitializing WinRT");
unsafe { RoUninitialize() };
trace!("Checking RESULT");
let result = RESULT.take().expect("Failed To Take RESULT");
result?;
Ok(())
}
fn start_free_threaded(
settings: WindowsCaptureSettings<Self::Flags>,
) -> Result<CaptureControl<Self>, Box<dyn Error + Send + Sync>>
where
Self: Send + 'static,
<Self as WindowsCaptureHandler>::Flags: Send,
{
let (halt_sender, halt_receiver) = mpsc::channel::<Arc<AtomicBool>>();
let (callback_sender, callback_receiver) = mpsc::channel::<Arc<Mutex<Self>>>();
let thread_handle = thread::spawn(move || -> Result<(), Box<dyn Error + Send + Sync>> {
trace!("Initializing WinRT");
unsafe { RoInitialize(RO_INIT_MULTITHREADED)? };
trace!("Creating A Dispatcher Queue For Capture Thread");
let options = DispatcherQueueOptions {
dwSize: mem::size_of::<DispatcherQueueOptions>() as u32,
threadType: DQTYPE_THREAD_CURRENT,
apartmentType: DQTAT_COM_NONE,
};
let controller = unsafe { CreateDispatcherQueueController(options)? };
let thread_id = unsafe { GetCurrentThreadId() };
debug!("Thread ID: {thread_id}");
info!("Starting Capture Thread");
let callback = Arc::new(Mutex::new(Self::new(settings.flags)?));
let mut capture = GraphicsCaptureApi::new(
settings.item,
callback.clone(),
settings.capture_cursor,
settings.draw_border,
settings.color_format,
thread_id,
)?;
capture.start_capture()?;
trace!("Sending Halt Handle");
let halt_handle = capture.halt_handle();
halt_sender.send(halt_handle)?;
trace!("Sending Callback");
callback_sender.send(callback)?;
trace!("Entering Message Loop");
let mut message = MSG::default();
unsafe {
while GetMessageW(&mut message, None, 0, 0).as_bool() {
TranslateMessage(&message);
DispatchMessageW(&message);
}
}
trace!("Shutting Down Dispatcher Queue");
let async_action = controller.ShutdownQueueAsync()?;
async_action.SetCompleted(&AsyncActionCompletedHandler::new(
move |_, _| -> Result<(), windows::core::Error> {
unsafe { PostQuitMessage(0) };
Ok(())
},
))?;
trace!("Entering Final Message Loop");
let mut message = MSG::default();
unsafe {
while GetMessageW(&mut message, None, 0, 0).as_bool() {
TranslateMessage(&message);
DispatchMessageW(&message);
}
}
info!("Stopping Capture Thread");
capture.stop_capture();
trace!("Uninitializing WinRT");
unsafe { RoUninitialize() };
trace!("Checking RESULT");
let result = RESULT.take().expect("Failed To Take RESULT");
result?;
Ok(())
});
let halt_handle = match halt_receiver.recv() {
Ok(halt_handle) => halt_handle,
Err(_) => match thread_handle.join() {
Ok(result) => return Err(result.err().unwrap()),
Err(_) => {
return Err(Box::new(CaptureControlError::FailedToJoinThread));
}
},
};
let callback = match callback_receiver.recv() {
Ok(callback_handle) => callback_handle,
Err(_) => match thread_handle.join() {
Ok(result) => return Err(result.err().unwrap()),
Err(_) => {
return Err(Box::new(CaptureControlError::FailedToJoinThread));
}
},
};
Ok(CaptureControl::new(thread_handle, halt_handle, callback))
}
fn new(flags: Self::Flags) -> Result<Self, Box<dyn Error + Send + Sync>>;
fn on_frame_arrived(
&mut self,
frame: &mut Frame,
capture_control: InternalCaptureControl,
) -> Result<(), Box<dyn Error + Send + Sync>>;
fn on_closed(&mut self) -> Result<(), Box<dyn Error + Send + Sync>>;
}