use crate::error::{Context, RuntimeError, SystemError};
use crate::io::epoll::{Epoll, Message};
use crate::io::fd::HasFixedFd;
use crate::io::input::InputDevice;
use crate::io::internal_pipe;
use crate::io::internal_pipe::{Receiver, Sender};
use crate::persist::blueprint::Blueprint;
use crate::persist::inotify::Inotify;
use crate::persist::interface::HostInterface;
use std::collections::HashSet;
use std::os::unix::io::{AsRawFd, RawFd};
use std::path::PathBuf;
/// Instructions sent from the host side (`HostInterface::commander`) to the persistence
/// worker thread.
pub enum Command {
    /// Register a blueprint with the daemon; the worker immediately tries to open its device
    /// and keeps retrying whenever the daemon's inotify descriptor becomes ready.
    AddBlueprint(Blueprint),
    /// Make the worker leave its event loop. Reports not yet sent in that iteration are
    /// discarded.
    Shutdown,
}
/// Messages sent from the persistence worker thread back to the host
/// (`HostInterface::reporter`).
// `DeviceOpened` is large enough to trip `clippy::large_enum_variant`; the lint is silenced
// instead of boxing the device. NOTE(review): presumably acceptable because reports are
// infrequent — confirm this was a deliberate trade-off.
#[allow(clippy::large_enum_variant)]
pub enum Report {
    /// A blueprint's device was opened; ownership of the device is handed to the host.
    DeviceOpened(InputDevice),
    /// A blueprint's open attempt returned an error, so the blueprint was discarded.
    BlueprintDropped,
    /// The worker thread has stopped. Sent after the worker returns, fails, or panics.
    Shutdown,
}
/// Everything the worker registers with its epoll instance.
enum Pollable {
    /// Receiving end of the host -> worker command pipe.
    Command(Receiver<Command>),
    /// The device-watching daemon; readiness comes from its inotify descriptor.
    Daemon(Daemon),
}
impl AsRawFd for Pollable {
    /// Exposes the descriptor of whichever endpoint this pollable wraps, so epoll can watch it.
    fn as_raw_fd(&self) -> RawFd {
        match *self {
            Pollable::Daemon(ref daemon) => daemon.as_raw_fd(),
            Pollable::Command(ref receiver) => receiver.as_raw_fd(),
        }
    }
}
// SAFETY: `as_raw_fd` only forwards to the wrapped `Receiver` or to the `Daemon`'s inotify
// descriptor. This relies on neither of those ever swapping its descriptor while the
// `Pollable` is alive. NOTE(review): confirm against the `HasFixedFd` contract and the
// `Receiver`/`Inotify` implementations, which are not visible here.
unsafe impl HasFixedFd for Pollable {}
/// Watches the filesystem for the devices described by pending blueprints.
pub struct Daemon {
    /// Blueprints whose devices have not been opened yet. Opened or failed blueprints are
    /// removed by `try_open`.
    blueprints: Vec<Blueprint>,
    /// Watches the directories along each blueprint's symlink chain; its descriptor is what
    /// epoll polls for this daemon.
    inotify: Inotify,
}
/// Spawns the persistence worker thread and returns the host-side handle to it.
///
/// The worker always sends a final `Report::Shutdown`, whether it returned normally,
/// returned an error (which is printed with context), or panicked. A panic is re-raised
/// after that report so it still surfaces through the join handle.
///
/// # Errors
/// Fails if either internal pipe cannot be created.
pub fn launch() -> Result<HostInterface, SystemError> {
    let (commander, worker_commands) = internal_pipe::channel()?;
    let (mut worker_reports, reporter) = internal_pipe::channel()?;

    let join_handle = std::thread::spawn(move || {
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let worker_result = start_worker(worker_commands, &mut worker_reports);
            worker_result
                .with_context("In the persistence subsystem:")
                .print_err();
        }));

        // Tell the host we are gone before (possibly) propagating a panic.
        worker_reports.send(Report::Shutdown).print_err();

        if let Err(payload) = outcome {
            std::panic::resume_unwind(payload);
        }
    });

    Ok(HostInterface {
        commander,
        reporter,
        join_handle,
    })
}
/// Runs the worker's event loop until a `Command::Shutdown` arrives or an error occurs.
///
/// Each iteration first handles every command gathered by `poll`, then forwards the
/// accumulated reports to the host. Reports still pending when `Shutdown` is processed are
/// not sent.
fn start_worker(
    comm_in: Receiver<Command>,
    comm_out: &mut Sender<Report>,
) -> Result<(), RuntimeError> {
    let daemon = Daemon::new()?;
    let mut epoll = Epoll::new()?;
    let daemon_index = epoll.add_file(Pollable::Daemon(daemon))?;
    epoll.add_file(Pollable::Command(comm_in))?;

    loop {
        let (pending_commands, mut outgoing) = poll(&mut epoll)?;

        for command in pending_commands {
            let blueprint = match command {
                Command::Shutdown => return Ok(()),
                Command::AddBlueprint(blueprint) => blueprint,
            };
            // The slot at `daemon_index` was registered as the daemon above.
            let daemon = match &mut epoll[daemon_index] {
                Pollable::Daemon(daemon) => daemon,
                _ => unreachable!(),
            };
            daemon.add_blueprint(blueprint)?;
            try_open_and_report(daemon, &mut outgoing)?;
        }

        for report in outgoing {
            comm_out.send(report)?;
        }
    }
}
/// Blocks until epoll reports activity, then collects incoming commands and any reports
/// produced by the daemon.
///
/// A failure of the epoll wait itself is printed and turned into a synthetic
/// `Command::Shutdown` rather than returned as an error. A `Broken` message, a failed
/// daemon poll, or a failed command receive is returned as an error.
fn poll(epoll: &mut Epoll<Pollable>) -> Result<(Vec<Command>, Vec<Report>), RuntimeError> {
    let mut commands = Vec::<Command>::new();
    let mut reports = Vec::<Report>::new();

    let messages = match epoll.poll(crate::io::epoll::INDEFINITE_TIMEOUT) {
        Ok(messages) => messages,
        Err(error) => {
            error
                .with_context("While the persistence subsystem was polling for events:")
                .print_err();
            commands.push(Command::Shutdown);
            return Ok((commands, reports));
        }
    };

    for message in messages {
        let index = match message {
            Message::Broken(_index) => {
                return Err(SystemError::new("Persistence daemon broken.").into())
            }
            Message::Ready(index) | Message::Hup(index) => index,
        };
        match &mut epoll[index] {
            Pollable::Daemon(daemon) => {
                daemon.poll()?;
                try_open_and_report(daemon, &mut reports)?;
            }
            Pollable::Command(receiver) => commands.push(receiver.recv()?),
        }
    }

    Ok((commands, reports))
}
/// Asks the daemon to open whatever devices it can and appends the outcome to `reports`.
///
/// Each opened device becomes a `Report::DeviceOpened`, and each discarded blueprint a
/// `Report::BlueprintDropped`. Reports are appended even when an error is returned, and
/// that error is the one `Daemon::try_open` encountered while updating its watches.
fn try_open_and_report(daemon: &mut Daemon, reports: &mut Vec<Report>) -> Result<(), RuntimeError> {
    let outcome = daemon.try_open();

    for device in outcome.opened_devices {
        reports.push(Report::DeviceOpened(device));
    }
    for _dropped in outcome.broken_blueprints {
        reports.push(Report::BlueprintDropped);
    }

    match outcome.error_encountered {
        None => Ok(()),
        Some(error) => Err(error),
    }
}
/// Outcome of one `Daemon::try_open` call.
struct TryOpenResult {
    /// Devices that were opened; their blueprints have been removed from the daemon.
    opened_devices: Vec<InputDevice>,
    /// Blueprints whose open attempt returned an error (the error was already printed).
    broken_blueprints: Vec<Blueprint>,
    /// Error from updating the inotify watches, if one stopped the retry loop.
    error_encountered: Option<RuntimeError>,
}
impl Daemon {
    /// Creates a daemon with no blueprints and a fresh inotify instance.
    ///
    /// # Errors
    /// Fails if the inotify instance cannot be created.
    pub fn new() -> Result<Daemon, SystemError> {
        Ok(Daemon {
            blueprints: Vec::new(),
            inotify: Inotify::new()?,
        })
    }

    /// Registers `blueprint` and refreshes the inotify watches to cover its symlink chain.
    ///
    /// This method does not try to open the device itself.
    ///
    /// # Errors
    /// Propagates any failure to update the watched paths.
    pub fn add_blueprint(&mut self, blueprint: Blueprint) -> Result<(), RuntimeError> {
        self.blueprints.push(blueprint);
        self.update_watches()?;
        Ok(())
    }

    /// Delegates to `Inotify::poll`.
    ///
    /// NOTE(review): presumably this drains pending inotify events so the descriptor stops
    /// reporting readiness — confirm against the `Inotify` implementation.
    pub fn poll(&mut self) -> Result<(), SystemError> {
        self.inotify.poll()
    }

    /// Tries to open the device of every pending blueprint.
    ///
    /// Outcomes per blueprint:
    /// - opened: moved into `opened_devices` and no longer tracked;
    /// - open attempt returned an error: error printed, blueprint moved into `broken_blueprints`;
    /// - not available yet: kept for a later attempt.
    ///
    /// After each round the watches are refreshed. If the watched set changed, another round
    /// is run, up to `MAX_TRIES` rounds in total; a one-time warning is emitted if the limit is
    /// reached. NOTE(review): the retry presumably catches devices that appeared before the
    /// new watches were installed — confirm.
    fn try_open(&mut self) -> TryOpenResult {
        const MAX_TRIES: usize = 5;
        let mut result = TryOpenResult {
            opened_devices: Vec::new(),
            broken_blueprints: Vec::new(),
            error_encountered: None,
        };
        for _ in 0..MAX_TRIES {
            let mut remaining_blueprints = Vec::new();
            for blueprint in self.blueprints.drain(..) {
                match blueprint.try_open() {
                    Ok(Some(device)) => result.opened_devices.push(device),
                    Ok(None) => remaining_blueprints.push(blueprint),
                    Err(error) => {
                        error.print_err();
                        result.broken_blueprints.push(blueprint);
                    }
                }
            }
            self.blueprints = remaining_blueprints;
            match self.update_watches() {
                // Watches are stable: nothing new could have been missed.
                Ok(false) => return result,
                // Watches changed: try again.
                Ok(true) => {}
                Err(error) => {
                    result.error_encountered = Some(error);
                    return result;
                }
            }
        }
        crate::utils::warn_once(
            "Warning: maximum try count exceeded while listening for new devices.",
        );
        result
    }

    /// Makes the inotify watch set equal to `get_paths_to_watch()`.
    ///
    /// Returns `Ok(true)` if the watch set had to be changed, `Ok(false)` if it was already
    /// up to date.
    fn update_watches(&mut self) -> Result<bool, RuntimeError> {
        let paths_to_watch: Vec<String> = self.get_paths_to_watch();
        // Scope the borrowed sets so they end before `set_watched_paths` mutates `inotify`.
        let unchanged = {
            let wanted: HashSet<&String> = paths_to_watch.iter().collect();
            let current: HashSet<&String> = self.inotify.watched_paths().collect();
            wanted == current
        };
        if unchanged {
            Ok(false)
        } else {
            self.inotify.set_watched_paths(paths_to_watch)?;
            Ok(true)
        }
    }

    /// Returns the parent directories of every path on each pending blueprint's symlink chain.
    ///
    /// The result is sorted and deduplicated. Directories that are not valid UTF-8 are
    /// skipped, with a one-time warning per distinct message.
    pub fn get_paths_to_watch(&self) -> Vec<String> {
        let mut traversed_directories: Vec<String> = self
            .blueprints
            .iter()
            .flat_map(|blueprint| walk_symlink(blueprint.pre_device.path.clone()))
            .filter_map(|mut path| {
                // Watch the directory containing each link/target, not the entry itself.
                path.pop();
                match path.into_os_string().into_string() {
                    Ok(string) => Some(string),
                    Err(os_string) => {
                        let warning_message = format!(
                            "Error: unable to deal with non-UTF8 path \"{}\".",
                            os_string.to_string_lossy()
                        );
                        crate::utils::warn_once(warning_message);
                        None
                    }
                }
            })
            .collect();
        traversed_directories.sort_unstable();
        traversed_directories.dedup();
        traversed_directories
    }
}
impl AsRawFd for Daemon {
fn as_raw_fd(&self) -> RawFd {
self.inotify.as_raw_fd()
}
}
fn walk_symlink(path: PathBuf) -> Vec<PathBuf> {
const MAX_SYMLINKS: usize = 20;
let mut current_path: PathBuf = path.clone();
let mut traversed_paths: Vec<PathBuf> = vec![current_path.clone()];
while let Ok(next_path_rel) = current_path.read_link() {
current_path.pop();
current_path = current_path.join(next_path_rel);
if traversed_paths.contains(¤t_path) {
break;
}
traversed_paths.push(current_path.clone());
if traversed_paths.len() > MAX_SYMLINKS + 1 {
crate::utils::warn_once(format!(
"Warning: too many symlinks encountered while resolving \"{}\".",
path.display()
));
break;
}
}
traversed_paths
}