use std::path::Path;
use std::sync::mpsc as std_mpsc;
use std::time::Duration;
use futures::Stream;
use notify::{Config, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
use tokio::sync::mpsc as tokio_mpsc;
/// A change in the attached-device set, as inferred from filesystem events on
/// the watched device directory.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UsbHotplugEvent {
    /// Something appeared in (or was renamed into) the watched directory.
    Arrived,
    /// Something was removed from the watched directory.
    Left,
}
/// Failure while setting up or running the hotplug watcher.
///
/// The type carries nothing but a human-readable message string.
#[derive(Debug)]
pub struct HotplugError(String);
impl std::fmt::Display for HotplugError {
    /// Renders as `hotplug watch error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("hotplug watch error: ")?;
        f.write_str(&self.0)
    }
}

/// Marker impl so `HotplugError` works with `?`/`Box<dyn Error>`; there is no
/// `source()` override because the type stores only a message string.
impl std::error::Error for HotplugError {}
impl From<notify::Error> for HotplugError {
fn from(e: notify::Error) -> Self {
HotplugError(e.to_string())
}
}
/// Directory whose entries are watched for device arrival/removal.
///
/// - Linux: `/sys/block`
/// - Windows: `C:\`
/// - macOS and every other target: `/dev`
///
/// NOTE(review): on Windows, removable drives normally appear as new drive
/// letters rather than as entries under `C:\`. On Linux, inotify may not
/// report kernel-originated changes under sysfs. Confirm that events actually
/// arrive on each platform.
fn watch_path() -> &'static Path {
    let dir = if cfg!(target_os = "linux") {
        "/sys/block"
    } else if cfg!(target_os = "windows") {
        "C:\\"
    } else {
        // macOS and all remaining targets share the device-node directory.
        "/dev"
    };
    Path::new(dir)
}
pub fn watch_usb_events() -> Result<impl Stream<Item = UsbHotplugEvent>, HotplugError> {
let path = watch_path();
if !path.exists() {
let (tx, rx) = tokio_mpsc::unbounded_channel::<UsbHotplugEvent>();
drop(tx); return Ok(UnboundedReceiverStream::new(rx));
}
let (notify_tx, notify_rx) = std_mpsc::channel::<notify::Result<notify::Event>>();
let (hotplug_tx, hotplug_rx) = tokio_mpsc::unbounded_channel::<UsbHotplugEvent>();
let mut watcher = RecommendedWatcher::new(
notify_tx,
Config::default()
.with_poll_interval(Duration::from_secs(2)),
)?;
watcher.watch(path, RecursiveMode::NonRecursive)?;
std::thread::Builder::new()
.name("flashkraft-hotplug-watcher".into())
.spawn(move || {
let _watcher = watcher;
loop {
std::thread::park();
}
})
.ok();
let hotplug_tx_b = hotplug_tx;
std::thread::Builder::new()
.name("flashkraft-hotplug-bridge".into())
.spawn(move || {
for result in ¬ify_rx {
let event = match result {
Ok(e) => e,
Err(_) => continue,
};
let translated = translate_event(&event);
if let Some(hp_event) = translated {
if hotplug_tx_b.send(hp_event).is_err() {
break;
}
}
}
})
.ok();
Ok(UnboundedReceiverStream::new(hotplug_rx))
}
/// Maps a raw [`notify::Event`] to a [`UsbHotplugEvent`].
///
/// Returns `None` when the event should be ignored. The mapping is:
/// - `Create(_)` → `Arrived`
/// - `Modify(Name(_))` (rename) → `Arrived`
/// - `Other` → `Arrived`
/// - `Remove(_)` → `Left`
/// - any other `Modify(_)`, `Access(_)`, or `Any` → `None`
///
/// NOTE(review): mapping `Other` to `Arrived` presumably makes callers rescan on
/// unclassified backend events. Confirm this is intended.
fn translate_event(event: &notify::Event) -> Option<UsbHotplugEvent> {
    match &event.kind {
        EventKind::Create(_) => Some(UsbHotplugEvent::Arrived),
        EventKind::Remove(_) => Some(UsbHotplugEvent::Left),
        // A rename inside the watched directory is treated as an arrival.
        EventKind::Modify(notify::event::ModifyKind::Name(_)) => Some(UsbHotplugEvent::Arrived),
        EventKind::Other => Some(UsbHotplugEvent::Arrived),
        // Content/metadata changes, reads, and unclassified events carry no
        // add/remove information.
        EventKind::Modify(_) | EventKind::Access(_) | EventKind::Any => None,
    }
}
/// Minimal adapter that exposes a tokio unbounded receiver as a
/// [`futures::Stream`].
struct UnboundedReceiverStream<T> {
    /// Channel half that items are pulled from.
    inner: tokio_mpsc::UnboundedReceiver<T>,
}

impl<T> UnboundedReceiverStream<T> {
    /// Wraps `rx` so it can be consumed as a stream.
    fn new(rx: tokio_mpsc::UnboundedReceiver<T>) -> Self {
        UnboundedReceiverStream { inner: rx }
    }
}
impl<T> Stream for UnboundedReceiverStream<T> {
    type Item = T;

    /// Delegates to `UnboundedReceiver::poll_recv`.
    ///
    /// Yields `Ready(Some(item))` for each queued item and `Ready(None)` once
    /// the channel is closed and drained. Otherwise it returns `Pending` and
    /// registers `cx`'s waker.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // The wrapper is `Unpin` (its only field is the receiver), so it can be
        // unpinned to a plain `&mut Self` without `unsafe`.
        let this = self.get_mut();
        this.inner.poll_recv(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use notify::event::{CreateKind, ModifyKind, RemoveKind, RenameMode};

    /// `UsbHotplugEvent` must be cloneable, comparable and debug-printable.
    #[test]
    fn test_event_traits() {
        let (arrived, left) = (UsbHotplugEvent::Arrived, UsbHotplugEvent::Left);

        assert_eq!(arrived.clone(), arrived);
        assert_eq!(left.clone(), left);
        assert_ne!(arrived, left);

        assert!(format!("{arrived:?}").contains("Arrived"));
        assert!(format!("{left:?}").contains("Left"));
    }

    /// Compile-time guard: the match below stops compiling if a variant is added.
    #[test]
    fn test_variant_exhaustiveness() {
        let labels: Vec<&str> = [UsbHotplugEvent::Arrived, UsbHotplugEvent::Left]
            .into_iter()
            .map(|ev| match ev {
                UsbHotplugEvent::Arrived => "arrived",
                UsbHotplugEvent::Left => "left",
            })
            .collect();
        assert!(labels.iter().all(|label| !label.is_empty()));
    }

    /// Builds a `notify::Event` of the given kind with no paths and default attrs.
    fn event_of(kind: EventKind) -> notify::Event {
        notify::Event {
            kind,
            paths: Vec::new(),
            attrs: Default::default(),
        }
    }

    /// Generates a `#[test]` asserting that an event of `$kind` translates to `$want`.
    macro_rules! assert_translation {
        ($test_name:ident, $kind:expr => $want:expr) => {
            #[test]
            fn $test_name() {
                assert_eq!(translate_event(&event_of($kind)), $want);
            }
        };
    }

    assert_translation!(
        test_translate_create_any_arrives,
        EventKind::Create(CreateKind::Any) => Some(UsbHotplugEvent::Arrived)
    );
    assert_translation!(
        test_translate_create_file_arrives,
        EventKind::Create(CreateKind::File) => Some(UsbHotplugEvent::Arrived)
    );
    assert_translation!(
        test_translate_create_folder_arrives,
        EventKind::Create(CreateKind::Folder) => Some(UsbHotplugEvent::Arrived)
    );
    assert_translation!(
        test_translate_remove_any_left,
        EventKind::Remove(RemoveKind::Any) => Some(UsbHotplugEvent::Left)
    );
    assert_translation!(
        test_translate_remove_file_left,
        EventKind::Remove(RemoveKind::File) => Some(UsbHotplugEvent::Left)
    );
    assert_translation!(
        test_translate_rename_arrives,
        EventKind::Modify(ModifyKind::Name(RenameMode::Any)) => Some(UsbHotplugEvent::Arrived)
    );
    assert_translation!(
        test_translate_modify_data_ignored,
        EventKind::Modify(ModifyKind::Data(notify::event::DataChange::Any)) => None
    );
    assert_translation!(
        test_translate_modify_metadata_ignored,
        EventKind::Modify(ModifyKind::Metadata(notify::event::MetadataKind::Any)) => None
    );
    assert_translation!(
        test_translate_access_ignored,
        EventKind::Access(notify::event::AccessKind::Any) => None
    );
    assert_translation!(
        test_translate_other_arrives,
        EventKind::Other => Some(UsbHotplugEvent::Arrived)
    );
    assert_translation!(test_translate_any_ignored, EventKind::Any => None);

    /// Smoke test: setting up the watcher may legitimately fail on some hosts
    /// (the error is printed and accepted), but it must never panic.
    #[test]
    fn test_watch_usb_events_does_not_panic() {
        if let Err(e) = watch_usb_events() {
            println!("watch_usb_events: OS returned error (acceptable): {e}");
        } else {
            println!("watch_usb_events: stream created successfully");
        }
    }

    /// The `Display` output carries both the fixed prefix and the message.
    #[test]
    fn test_hotplug_error_display() {
        let rendered = HotplugError("something went wrong".into()).to_string();
        assert!(rendered.contains("hotplug watch error"));
        assert!(rendered.contains("something went wrong"));
    }

    /// Converting a `notify::Error` keeps the `HotplugError` display prefix.
    #[test]
    fn test_hotplug_error_from_notify() {
        let converted: HotplugError = notify::Error::generic("test error").into();
        assert!(converted.to_string().contains("hotplug watch error"));
    }
}