use std::{
mem,
sync::{Arc, Once, atomic, mpsc},
time::Duration,
};
use bon::bon;
use tracing::{debug, error, instrument, trace, warn};
use widestring::u16cstr;
use windows::{
Win32::{
Foundation::{HINSTANCE, HWND, LPARAM, LRESULT, WPARAM},
System::DataExchange::COPYDATASTRUCT,
UI::WindowsAndMessaging::{
CreateWindowExW, DefWindowProcW, DispatchMessageW, GWL_USERDATA, GetMessageW,
GetWindowLongPtrW, MSG, PostMessageW, RegisterClassW, ReplyMessage, SendMessageW,
SetWindowLongPtrW, WINDOW_EX_STYLE, WINDOW_STYLE, WM_APP, WM_COPYDATA, WM_QUIT,
WNDCLASSW,
},
},
core::PCWSTR,
};
use crate::{IpcWindow, windows::get_current_module_handle};
mod types;
pub use types::*;
mod ext;
/// Errors returned by the `WM_COPYDATA`-based Everything client.
#[derive(Debug, thiserror::Error)]
pub enum IpcError {
    /// No Everything IPC window was found (`IpcWindow::new` / `IpcWindow::with_instance`
    /// returned `None`).
    #[error("IPC window not found")]
    NoIpcWindow,
    /// The hidden reply window (or its thread) could not be created.
    #[error("failed to create reply window")]
    CreateReplyWindow,
    /// The query could not be posted, or its result channel was closed before results arrived.
    #[error("failed to send query to Everything")]
    Send,
    /// No results arrived within the caller's timeout.
    #[error("query timed out")]
    Timeout,
    /// Query failure described by a static message.
    #[error("query: {0}")]
    Query(&'static str),
}
/// Window class name shared by every reply window created by this module.
const WINDOW_CLASS_NAME: &widestring::U16CStr = u16cstr!("everything_ipc::wm");
/// Makes sure the reply window class is registered at most once per process.
static CLASS_REGISTERED: Once = Once::new();

/// Registers the reply window class on the first call; later calls do nothing.
///
/// A registration failure is only logged. `Once` does not retry, so later `CreateWindowExW`
/// calls for this class are then expected to fail.
fn register_window_class() {
    CLASS_REGISTERED.call_once(|| {
        // SAFETY: every pointer in the class description is 'static (the class-name literal and
        // the window procedure), and RegisterClassW copies the struct.
        let atom = unsafe {
            RegisterClassW(&WNDCLASSW {
                lpfnWndProc: Some(reply_window_wndproc),
                hInstance: get_current_module_handle().into(),
                lpszClassName: PCWSTR(WINDOW_CLASS_NAME.as_ptr()),
                ..Default::default()
            })
        };
        if atom != 0 {
            debug!(
                "Registered window class {}",
                WINDOW_CLASS_NAME.to_string_lossy()
            );
        } else {
            error!("Failed to register window class");
        }
    });
}
#[derive(Debug)]
struct ReplyWindow {
hwnd: HWND,
_thread: std::thread::JoinHandle<()>,
}
unsafe impl Send for ReplyWindow {}
unsafe impl Sync for ReplyWindow {}
#[derive(Debug)]
struct MessageLoopResult {
hwnd_usize: usize,
}
impl ReplyWindow {
pub fn new(inner_ptr: *mut ClientInner) -> Result<Self, IpcError> {
register_window_class();
let (tx, rx) = mpsc::channel::<MessageLoopResult>();
let inner_ptr_usize = inner_ptr as usize;
let thread = std::thread::spawn(move || {
let hwnd = unsafe {
CreateWindowExW(
WINDOW_EX_STYLE(0),
PCWSTR(WINDOW_CLASS_NAME.as_ptr()),
None,
WINDOW_STYLE(0),
0,
0,
0,
0,
None,
None,
Some(HINSTANCE::default()),
None,
)
};
let hwnd = match hwnd.ok() {
Some(h) => h,
None => {
debug!("Failed to create window in message loop thread");
let _ = tx.send(MessageLoopResult { hwnd_usize: 0 });
return;
}
};
if let Err(_) = tx.send(MessageLoopResult {
hwnd_usize: hwnd.0 as usize,
}) {
return;
}
debug!(?hwnd, "Created reply window");
unsafe { SetWindowLongPtrW(hwnd, GWL_USERDATA, inner_ptr_usize as isize) };
run_message_loop(hwnd);
});
let result = rx.recv().map_err(|_| IpcError::CreateReplyWindow)?;
let MessageLoopResult { hwnd_usize } = result;
let hwnd = HWND(hwnd_usize as *mut _);
if hwnd.is_invalid() {
return Err(IpcError::CreateReplyWindow);
}
Ok(Self {
hwnd,
_thread: thread,
})
}
pub fn hwnd(&self) -> HWND {
self.hwnd
}
pub fn post_message(
&self,
msg: u32,
w_param: WPARAM,
l_param: LPARAM,
) -> Result<(), windows::core::Error> {
unsafe { PostMessageW(Some(self.hwnd), msg, w_param, l_param) }
}
pub fn quit(&self) {
let _ = self.post_message(WM_QUIT, WPARAM(0), LPARAM(0));
}
}
impl Drop for ReplyWindow {
fn drop(&mut self) {
self.quit();
}
}
/// Raw reply payload tagged with its query ID.
///
/// NOTE(review): appears unused here; replies are delivered as `QueryList`. Confirm whether it
/// is still needed by public users.
#[derive(Debug)]
pub struct QueryResponse {
    /// Query ID echoed back by Everything.
    pub id: u32,
    /// Raw reply bytes.
    pub data: Vec<u8>,
}
/// Window procedure of the hidden reply window.
///
/// * `WM_APP`: posted by `EverythingClient::query_send`. `w_param` carries a request buffer
///   leaked with `Box::into_raw`. The buffer is forwarded to Everything's IPC window as an
///   `EVERYTHING_IPC_COPYDATA_QUERY2W` `WM_COPYDATA` message. If Everything does not return `1`,
///   the pending query's sender is dropped, which disconnects the waiting receiver.
/// * `WM_COPYDATA`: Everything's reply. The payload is copied, Everything is released with
///   `ReplyMessage`, and the parsed `QueryList` is delivered to the pending query sender.
/// * Everything else goes to `DefWindowProcW`.
///
/// `GWL_USERDATA` holds a `*const ClientInner`, or `0` while no state is attached.
#[instrument(skip_all, fields(hwnd))]
unsafe extern "system" fn reply_window_wndproc(
    hwnd: HWND,
    msg: u32,
    w_param: WPARAM,
    l_param: LPARAM,
) -> LRESULT {
    match msg {
        WM_APP => {
            let request_ptr = w_param.0 as *mut Vec<u8>;
            if request_ptr.is_null() {
                // Not one of our requests; `Box::from_raw(null)` would be undefined behavior.
                warn!("WM_APP without a request payload");
                return LRESULT(0);
            }
            // SAFETY: non-null payloads come from `Box::into_raw` in `query_send`. Ownership is
            // taken back exactly once here and freed when `request` goes out of scope.
            let request = unsafe { Box::from_raw(request_ptr) };
            let inner_ptr =
                unsafe { GetWindowLongPtrW(hwnd, GWL_USERDATA) } as *const ClientInner;
            if inner_ptr.is_null() {
                warn!("WM_APP before the client state was attached");
                return LRESULT(0);
            }
            // SAFETY: GWL_USERDATA is only ever set to a pointer to a live `ClientInner`.
            let inner = unsafe { &*inner_ptr };
            let ipc_hwnd = inner.ipc_window.hwnd();
            let cds = COPYDATASTRUCT {
                dwData: EVERYTHING_IPC_COPYDATA_QUERY2W as usize,
                cbData: request.len() as u32,
                lpData: request.as_ptr() as *mut _,
            };
            // Synchronous call. `request` and `cds` outlive it, so `lpData` stays valid for the
            // whole call.
            let r = unsafe {
                SendMessageW(
                    ipc_hwnd,
                    WM_COPYDATA,
                    Some(WPARAM(hwnd.0 as usize)),
                    Some(LPARAM(&cds as *const COPYDATASTRUCT as isize)),
                )
            };
            if r.0 == 1 {
                trace!(?ipc_hwnd, ?r);
            } else {
                warn!(?ipc_hwnd, ?r);
                // Everything rejected the request: disconnect the waiting receiver.
                drop(inner.current_query_sender.lock().unwrap().take());
            }
            LRESULT(0)
        }
        WM_COPYDATA => {
            let copydata_ptr = l_param.0 as *const COPYDATASTRUCT;
            if copydata_ptr.is_null() {
                warn!("WM_COPYDATA without a COPYDATASTRUCT");
                return LRESULT(0);
            }
            // SAFETY: for WM_COPYDATA, `l_param` points to a COPYDATASTRUCT that is valid while
            // the message is being processed. It is not touched after `ReplyMessage` below.
            let copydata = unsafe { &*copydata_ptr };
            let id = copydata.dwData as u32;
            let len = copydata.cbData;
            let inner_ptr =
                unsafe { GetWindowLongPtrW(hwnd, GWL_USERDATA) } as *const ClientInner;
            if inner_ptr.is_null() {
                error!("No object found");
                return LRESULT(0);
            }
            // SAFETY: GWL_USERDATA is only ever set to a pointer to a live `ClientInner`.
            let inner = unsafe { &*inner_ptr };
            // Take the pending sender and release the lock right away.
            let pending = inner.current_query_sender.lock().unwrap().take();
            let Some(sender) = pending else {
                warn!(id, "No pending query");
                return LRESULT(1);
            };
            let receiver_gone = match &sender {
                // std mpsc has no cheap "closed" probe; a dead receiver shows up on `send`.
                QuerySender::Sync(_) => false,
                #[cfg(feature = "tokio")]
                QuerySender::Tokio(sender) => sender.is_closed(),
            };
            if receiver_gone {
                return LRESULT(1);
            }
            let bytes: &[u8] = if copydata.lpData.is_null() || len == 0 {
                // `slice::from_raw_parts` requires a non-null pointer even for an empty slice.
                &[]
            } else {
                // SAFETY: `lpData` points to `cbData` readable bytes for the duration of the message.
                unsafe {
                    std::slice::from_raw_parts(copydata.lpData as *const u8, len as usize)
                }
            };
            let data = bytes.into();
            // The payload has been copied: let Everything continue.
            _ = unsafe { ReplyMessage(LRESULT(1)) };
            trace!(id, cbData = len, "WM_COPYDATA received");
            let results = QueryList::new(id, data);
            let delivered = match sender {
                QuerySender::Sync(sender) => sender.send(results).is_ok(),
                #[cfg(feature = "tokio")]
                QuerySender::Tokio(sender) => sender.send(results).is_ok(),
            };
            if delivered {
                debug!(id, "Sent query response");
            } else {
                warn!(id, "Failed to send query response");
            }
            LRESULT(1)
        }
        _ => unsafe { DefWindowProcW(hwnd, msg, w_param, l_param) },
    }
}
/// Pumps messages for `hwnd` on the calling thread until `GetMessageW` stops returning a
/// positive value: `0` when `WM_QUIT` is retrieved, `-1` on error.
///
/// Messages are dispatched without `TranslateMessage`.
fn run_message_loop(hwnd: HWND) {
    // SAFETY: `MSG` is a plain-data Win32 struct for which all-zero bytes are a valid value.
    let mut message: MSG = unsafe { mem::zeroed() };
    // SAFETY: `message` is a valid out-buffer for `GetMessageW`.
    while unsafe { GetMessageW(&mut message, Some(hwnd), 0, 0) }.0 > 0 {
        // SAFETY: `message` was just filled in by `GetMessageW`.
        unsafe { DispatchMessageW(&message) };
    }
}
/// Destination for the results of the pending query (blocking or tokio flavour).
enum QuerySender {
    /// Installed by `query` (and therefore `query_wait`).
    Sync(mpsc::Sender<QueryList>),
    /// Installed by `query_tokio` (and therefore `query_wait_tokio`).
    #[cfg(feature = "tokio")]
    Tokio(tokio::sync::oneshot::Sender<QueryList>),
}
/// State shared between an `EverythingClient` and its reply window procedure.
struct ClientInner {
    /// Everything IPC window that queries are sent to.
    ipc_window: IpcWindow,
    /// Sender for the single pending query: each new query replaces it, and the reply handler
    /// takes it.
    current_query_sender: std::sync::Mutex<Option<QuerySender>>,
}
/// Client that queries Everything over `WM_COPYDATA`; dereferences to its [`IpcWindow`].
pub struct EverythingClient {
    /// Shared state. The reply window procedure reaches it through the raw pointer passed to
    /// `ReplyWindow::new`.
    inner: Arc<ClientInner>,
    /// Hidden window (and its thread) that sends queries and receives replies.
    reply_window: ReplyWindow,
}
impl IpcWindow {
pub fn wm_client(&self) -> Result<EverythingClient, IpcError> {
let inner = Arc::new(ClientInner {
ipc_window: self.clone(),
current_query_sender: std::sync::Mutex::new(None),
});
let inner_ptr = Arc::as_ptr(&inner) as *mut ClientInner;
let reply_window = ReplyWindow::new(inner_ptr)?;
Ok(EverythingClient {
inner,
reply_window,
})
}
}
/// Exposes the underlying [`IpcWindow`] API directly on the client.
impl std::ops::Deref for EverythingClient {
    type Target = IpcWindow;

    fn deref(&self) -> &IpcWindow {
        self.ipc_window()
    }
}
impl EverythingClient {
pub fn new() -> Result<Self, IpcError> {
IpcWindow::new().ok_or(IpcError::NoIpcWindow)?.wm_client()
}
pub fn with_instance(instance_name: Option<&str>) -> Result<Self, IpcError> {
IpcWindow::with_instance(instance_name)
.ok_or(IpcError::NoIpcWindow)?
.wm_client()
}
fn ipc_window(&self) -> &IpcWindow {
&self.inner.ipc_window
}
fn next_id(&self) -> u32 {
static NEXT_ID: atomic::AtomicU32 = atomic::AtomicU32::new(0);
NEXT_ID.fetch_add(1, atomic::Ordering::SeqCst)
}
fn query_send(
&self,
search: &str,
search_flags: SearchFlags,
request_flags: RequestFlags,
sort: Sort,
id: u32,
offset: u32,
max_results: Option<u32>,
) -> bool {
let msg_hwnd = self.reply_window.hwnd();
let request = EverythingIpcQuery2::create(
msg_hwnd.0 as u32,
id,
search_flags.bits(),
offset,
max_results.unwrap_or(u32::MAX),
request_flags.bits(),
sort as u32,
search,
);
let request_box = Box::new(request);
let request_ptr = Box::into_raw(request_box);
match self
.reply_window
.post_message(WM_APP, WPARAM(request_ptr as usize), LPARAM(0))
{
Ok(_) => true,
Err(_) => {
let _ = unsafe { Box::from_raw(request_ptr) };
false
}
}
}
}
#[bon]
impl EverythingClient {
#[instrument(skip_all)]
#[builder]
pub fn query(
&self,
#[builder(start_fn)] search: &str,
#[builder(default)] search_flags: SearchFlags,
request_flags: RequestFlags,
#[builder(default)] sort: Sort,
#[builder(default)] offset: u32,
max_results: Option<u32>,
) -> Result<mpsc::Receiver<QueryList>, IpcError> {
let id = self.next_id();
debug!("generating query ID {}", id);
let (sender, receiver) = mpsc::channel::<QueryList>();
let sent = self.query_send(
search,
search_flags,
request_flags,
sort,
id,
offset,
max_results,
);
if !sent {
warn!("failed to send query ID {}", id);
return Err(IpcError::Send);
}
debug!("query ID {} sent successfully", id);
let old_sender = self
.inner
.current_query_sender
.lock()
.unwrap()
.replace(QuerySender::Sync(sender));
drop(old_sender);
Ok(receiver)
}
#[instrument(skip_all)]
#[builder]
pub fn query_wait(
&self,
#[builder(start_fn)] search: &str,
#[builder(default)] search_flags: SearchFlags,
request_flags: RequestFlags,
#[builder(default)] sort: Sort,
#[builder(default)] offset: u32,
max_results: Option<u32>,
#[builder(default = Duration::from_millis(3000))] timeout: Duration,
) -> Result<QueryList, IpcError> {
let receiver = self
.query(search)
.search_flags(search_flags)
.request_flags(request_flags)
.sort(sort)
.offset(offset)
.maybe_max_results(max_results)
.call()?;
match receiver.recv_timeout(timeout) {
Ok(results) => Ok(results),
Err(_) => {
warn!("query timed out");
Err(IpcError::Timeout)
}
}
}
}
#[cfg(feature = "tokio")]
#[bon]
impl EverythingClient {
    /// Tokio flavour of `query`: starts a query and returns a oneshot receiver for its results.
    ///
    /// The client tracks a single pending query. The new query's sender replaces the previous
    /// one, which closes the previous receiver.
    ///
    /// # Errors
    /// [`IpcError::Send`] if the request could not be posted to the reply window.
    #[instrument(skip_all)]
    #[builder]
    pub fn query_tokio(
        &self,
        #[builder(start_fn)] search: &str,
        #[builder(default)] search_flags: SearchFlags,
        request_flags: RequestFlags,
        #[builder(default)] sort: Sort,
        #[builder(default)] offset: u32,
        max_results: Option<u32>,
    ) -> Result<tokio::sync::oneshot::Receiver<QueryList>, IpcError> {
        let id = self.next_id();
        debug!("generating query ID {}", id);
        let (sender, receiver) = tokio::sync::oneshot::channel::<QueryList>();
        // Install the sender *before* posting the request. Otherwise the reply thread may handle
        // Everything's reply (or a rejection) before the sender exists, and the reply is lost.
        let old_sender = self
            .inner
            .current_query_sender
            .lock()
            .unwrap()
            .replace(QuerySender::Tokio(sender));
        drop(old_sender);
        let sent = self.query_send(
            search,
            search_flags,
            request_flags,
            sort,
            id,
            offset,
            max_results,
        );
        if !sent {
            warn!("failed to send query ID {}", id);
            // The request never reached the reply window; unregister its sender again.
            drop(self.inner.current_query_sender.lock().unwrap().take());
            return Err(IpcError::Send);
        }
        debug!("query ID {} sent successfully", id);
        Ok(receiver)
    }

    /// Runs a query and awaits its results for at most `timeout` (default 3 s).
    ///
    /// # Errors
    /// * [`IpcError::Send`] if the request could not be posted, or if the sender was dropped
    ///   before results arrived.
    /// * [`IpcError::Timeout`] if no results arrived within `timeout`.
    #[instrument(skip_all)]
    #[builder]
    pub async fn query_wait_tokio(
        &self,
        #[builder(start_fn)] search: &str,
        #[builder(default)] search_flags: SearchFlags,
        request_flags: RequestFlags,
        #[builder(default)] sort: Sort,
        #[builder(default)] offset: u32,
        max_results: Option<u32>,
        #[builder(default = Duration::from_millis(3000))] timeout: Duration,
    ) -> Result<QueryList, IpcError> {
        let receiver = self
            .query_tokio(search)
            .search_flags(search_flags)
            .request_flags(request_flags)
            .sort(sort)
            .offset(offset)
            .maybe_max_results(max_results)
            .call()?;
        match tokio::time::timeout(timeout, receiver).await {
            Ok(Ok(results)) => Ok(results),
            Ok(Err(_)) => {
                warn!("query receiver error");
                Err(IpcError::Send)
            }
            Err(_) => {
                warn!("query timed out");
                Err(IpcError::Timeout)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Integration tests for the `WM_COPYDATA` client.
    //!
    //! Every test constructs an `EverythingClient` and unwraps the result, so they all need a
    //! running Everything instance that exposes its IPC window.
    use super::*;

    /// Usage example: the 5 largest results for `C:\Windows\ *.exe`, printed as a table.
    #[test_log::test]
    #[test_log(default_log_filter = "trace")]
    fn doc() {
        let everything = EverythingClient::new().expect("not available");
        let list = everything
            .query_wait(r"C:\Windows\ *.exe")
            .request_flags(RequestFlags::FileName | RequestFlags::Size | RequestFlags::Path)
            .sort(Sort::SizeDescending)
            .max_results(5)
            .call()
            .expect("query");
        println!("Found {} items:", list.len());
        println!("{:<25} {:>10} {}", "Filename", "Size", "Path");
        for item in list.iter() {
            let filename = item.get_string(RequestFlags::FileName).unwrap();
            let path = item.get_str(RequestFlags::Path).unwrap().display();
            let size = item.get_size(RequestFlags::Size).unwrap();
            println!("{:<25} {:>10} {}", filename, size, path);
        }
        println!("Total: {} items", list.total_len());
    }

    /// Only checks that the request is posted; the reply is not awaited.
    #[test_log::test]
    #[test_log(default_log_filter = "trace")]
    fn query_empty_search() {
        let everything = EverythingClient::new().unwrap();
        let search = "";
        let search_flags = SearchFlags::MatchCase;
        let request_flags = RequestFlags::FileName | RequestFlags::Path;
        let sort = Sort::NameAscending;
        let result =
            everything.query_send(search, search_flags, request_flags, sort, 1000, 0, Some(5));
        assert!(result, "Query should be sent successfully");
    }

    /// Only checks that the request is posted; the reply is not awaited.
    #[test_log::test]
    #[test_log(default_log_filter = "trace")]
    fn query_with_pattern() {
        let everything = EverythingClient::new().unwrap();
        let search = "test";
        let search_flags = SearchFlags::MatchCase;
        let request_flags = RequestFlags::FileName;
        let sort = Sort::NameAscending;
        let result =
            everything.query_send(search, search_flags, request_flags, sort, 1001, 0, Some(10));
        assert!(result, "Query should be sent successfully");
    }

    /// Posts a request asking for full path, size and modification date.
    #[test]
    fn query_with_full_path() {
        let everything = EverythingClient::new().unwrap();
        let search = "";
        let search_flags = SearchFlags::MatchCase;
        let request_flags =
            RequestFlags::FullPathAndFileName | RequestFlags::Size | RequestFlags::DateModified;
        let sort = Sort::NameAscending;
        let result =
            everything.query_send(search, search_flags, request_flags, sort, 1002, 0, Some(3));
        assert!(result, "Query should be sent successfully");
    }

    /// Posts a request sorted by size.
    #[test]
    fn query_sort_by_size() {
        let everything = EverythingClient::new().unwrap();
        let search = "";
        let search_flags = SearchFlags::MatchCase;
        let request_flags = RequestFlags::FileName | RequestFlags::Size;
        let sort = Sort::SizeAscending;
        let result =
            everything.query_send(search, search_flags, request_flags, sort, 1003, 0, Some(5));
        assert!(result, "Query should be sent successfully");
    }

    /// Posts two consecutive pages of two results each (offsets 0 and 2).
    #[test]
    fn query_with_offset() {
        let everything = EverythingClient::new().unwrap();
        let search = "";
        let search_flags = SearchFlags::MatchCase;
        let request_flags = RequestFlags::FileName;
        let sort = Sort::NameAscending;
        let result1 =
            everything.query_send(search, search_flags, request_flags, sort, 1005, 0, Some(2));
        assert!(result1, "First query should be sent successfully");
        let result2 =
            everything.query_send(search, search_flags, request_flags, sort, 1006, 2, Some(2));
        assert!(
            result2,
            "Second query with offset should be sent successfully"
        );
    }

    /// Posts a request using an ID from the client's own counter.
    #[test]
    fn query_everything() {
        let everything = EverythingClient::new().unwrap();
        let search = "test";
        let search_flags = SearchFlags::MatchCase;
        let request_flags = RequestFlags::FileName;
        let sort = Sort::NameAscending;
        let result = everything.query_send(
            search,
            search_flags,
            request_flags,
            sort,
            everything.next_id(),
            0,
            Some(5),
        );
        assert!(result, "Query should be sent successfully");
    }

    /// Posts a request asking for several columns at once.
    #[test]
    fn query_multiple_requests() {
        let everything = EverythingClient::new().unwrap();
        let search = "";
        let search_flags = SearchFlags::MatchCase;
        let request_flags = RequestFlags::FileName
            | RequestFlags::Path
            | RequestFlags::Size
            | RequestFlags::DateModified
            | RequestFlags::DateCreated;
        let sort = Sort::NameAscending;
        let result =
            everything.query_send(search, search_flags, request_flags, sort, 1004, 0, Some(5));
        assert!(result, "Query should be sent successfully");
    }

    /// Full round trip with an empty search; only checks that a reply arrives.
    #[test]
    fn query_wait_empty() {
        let everything = EverythingClient::new().unwrap();
        let search = "";
        let search_flags = SearchFlags::MatchCase;
        let request_flags = RequestFlags::FileName;
        let sort = Sort::NameAscending;
        assert!(everything.is_ipc_available(), "IPC should be available");
        let result = everything
            .query_wait(search)
            .search_flags(search_flags)
            .request_flags(request_flags)
            .sort(sort)
            .offset(0)
            .max_results(10)
            .call();
        assert!(
            result.is_ok(),
            "query_wait should return Ok when Everything is available"
        );
    }

    /// Full round trip; assumes the indexed volumes contain at least one match for "test".
    #[test_log::test]
    #[test_log(default_log_filter = "trace")]
    fn query_wait() {
        let everything = EverythingClient::new().unwrap();
        let search = "test";
        let search_flags = SearchFlags::MatchCase;
        let request_flags = RequestFlags::FileName;
        let sort = Sort::NameAscending;
        assert!(everything.is_ipc_available(), "IPC should be available");
        let result = everything
            .query_wait(search)
            .search_flags(search_flags)
            .request_flags(request_flags)
            .sort(sort)
            .offset(0)
            .max_results(10)
            .call();
        dbg!(&result);
        assert!(
            result.is_ok(),
            "query_wait should return Ok when Everything is available"
        );
        assert!(
            result.as_ref().is_ok_and(|r| r.total_len() > 0),
            "Expected found_num > 0, got: {:?}",
            result
        );
    }

    /// Issues three queries back to back on one client. Each new query replaces the pending
    /// sender, so the first two receivers are expected to be disconnected.
    ///
    /// NOTE(review): timing-dependent. It assumes no reply is delivered before the next query
    /// replaces the sender, and replies are not matched by ID, so the last receiver may get an
    /// earlier query's results.
    #[test_log::test]
    #[test_log(default_log_filter = "trace")]
    fn query_wait_cancel() {
        let everything = EverythingClient::new().unwrap();
        assert!(everything.is_ipc_available(), "IPC should be available");
        let searches = ["", "test", "rust"];
        let mut receivers = Vec::new();
        for search in &searches {
            let search_flags = SearchFlags::MatchCase;
            let request_flags = RequestFlags::FileName;
            let sort = Sort::NameAscending;
            let receiver = everything
                .query(search)
                .search_flags(search_flags)
                .request_flags(request_flags)
                .sort(sort)
                .offset(0)
                .max_results(10)
                .call()
                .expect("query should succeed");
            receivers.push(receiver);
        }
        let result = receivers[0].recv_timeout(std::time::Duration::from_millis(3000));
        assert!(
            result.is_err(),
            "Query 0 should fail because sender was replaced (got: {:?})",
            result
        );
        let result = receivers[1].recv_timeout(std::time::Duration::from_millis(3000));
        assert!(
            result.is_err(),
            "Query 1 should fail because sender was replaced (got: {:?})",
            result
        );
        let result = receivers[2].recv_timeout(std::time::Duration::from_millis(3000));
        let result = result.expect("Last query should succeed");
        dbg!(&result);
        assert!(
            result.total_len() > 0,
            "Last query should return valid results"
        );
    }

    /// Three independent clients, each with its own reply window and pending-sender slot, query
    /// concurrently; every receiver is expected to get results.
    #[test_log::test]
    #[test_log(default_log_filter = "trace")]
    fn query_wait_parallel() {
        let everything1 = EverythingClient::new().unwrap();
        let everything2 = EverythingClient::new().unwrap();
        let everything3 = EverythingClient::new().unwrap();
        assert!(everything1.is_ipc_available(), "IPC should be available");
        let receiver1 = everything1
            .query("")
            .search_flags(SearchFlags::MatchCase)
            .request_flags(RequestFlags::FileName)
            .sort(Sort::NameAscending)
            .offset(0)
            .max_results(10)
            .call()
            .expect("query should succeed");
        let receiver2 = everything2
            .query("test")
            .search_flags(SearchFlags::MatchCase)
            .request_flags(RequestFlags::FileName)
            .sort(Sort::NameAscending)
            .offset(0)
            .max_results(10)
            .call()
            .expect("query should succeed");
        let receiver3 = everything3
            .query("rust")
            .search_flags(SearchFlags::MatchCase)
            .request_flags(RequestFlags::FileName)
            .sort(Sort::NameAscending)
            .offset(0)
            .max_results(10)
            .call()
            .expect("query should succeed");
        for (i, receiver) in [receiver1, receiver2, receiver3].into_iter().enumerate() {
            let result = receiver.recv_timeout(std::time::Duration::from_millis(5000));
            let result = result.expect(&format!("Query {} timed out", i));
            dbg!(&result);
            assert!(result.len() > 0, "Query {} should return valid results", i);
        }
    }
}