use crate::{Error, Result};
use log::warn;
use std::{
collections::HashMap, env, fs, future::Future, io, path::Path, pin::Pin, result, str::FromStr,
sync::Arc,
};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{UnixListener, UnixStream},
sync::Mutex,
};
/// Location of the SDK's state-change trigger script. `StateChange::listen`
/// only checks for its existence and logs a warning when it is missing.
const SDK_TRIGGER_FILENAME: &str =
    "/usr/share/updatehub/state-change-callbacks.d/10-updatehub-sdk-statechange-trigger";
/// Default Unix socket path bound by `StateChange::listen`; the
/// `UH_LISTENER_TEST` environment variable overrides it when set.
const SOCKET_PATH: &str = "/run/updatehub-statechange.sock";
/// Type-erased state callback: receives a [`Handler`] for the current
/// connection and returns a boxed future resolving to the callback's result.
type CallbackFn = dyn Fn(Handler) -> Pin<Box<dyn Future<Output = Result<()>>>>;
/// Registry of callbacks keyed by [`State`], served over a Unix socket by
/// `listen`.
#[derive(Default)]
pub struct StateChange {
    // Callbacks per state, kept in registration order; `emit` runs them in
    // that order and stops at the first one that returns an error.
    callbacks: HashMap<State, Vec<Box<CallbackFn>>>,
}
/// A state reported over the state-change socket for which callbacks can be
/// registered with `StateChange::on_state`.
///
/// Parsed from its lowercase name via the [`FromStr`] implementation below.
/// The enum is fieldless, so it is `Copy`: the same value can be reused, for
/// example to register several callbacks for one state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum State {
    Probe,
    Download,
    Install,
    Reboot,
    Error,
}

impl FromStr for State {
    type Err = io::Error;

    /// Parses an exact, lowercase state name (`"probe"`, `"download"`,
    /// `"install"`, `"reboot"` or `"error"`).
    ///
    /// Matching is case-sensitive and does not trim whitespace; callers must
    /// trim the input themselves.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] of kind [`io::ErrorKind::InvalidInput`] for
    /// any other string.
    fn from_str(s: &str) -> result::Result<Self, Self::Err> {
        match s {
            "probe" => Ok(State::Probe),
            "download" => Ok(State::Download),
            "install" => Ok(State::Install),
            "reboot" => Ok(State::Reboot),
            "error" => Ok(State::Error),
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("the '{}' is not a valid state", s),
            )),
        }
    }
}
pub struct Handler {
stream: Arc<Mutex<UnixStream>>,
}
impl Handler {
pub async fn cancel(&mut self) -> Result<()> {
self.stream.lock().await.write_all(b"cancel").await.map_err(Error::Io)
}
pub async fn proceed(&self) -> Result<()> {
Ok(())
}
}
impl StateChange {
#[inline]
pub fn new() -> Self {
StateChange::default()
}
pub fn on_state<F, Fut>(&mut self, state: State, f: F)
where
F: Fn(Handler) -> Fut + 'static,
Fut: Future<Output = Result<()>> + 'static,
{
self.callbacks.entry(state).or_insert_with(Vec::new).push(Box::new(move |d| Box::pin(f(d))))
}
pub async fn listen(&self) -> Result<()> {
let sdk_trigger = Path::new(SDK_TRIGGER_FILENAME);
if !sdk_trigger.exists() {
warn!("WARNING: updatehub-sdk-statechange-trigger not found on {:?}", sdk_trigger);
}
let socket_path = env::var("UH_LISTENER_TEST").unwrap_or_else(|_| SOCKET_PATH.to_string());
let socket_path = Path::new(&socket_path);
if socket_path.exists() {
fs::remove_file(socket_path)?;
}
let listener = UnixListener::bind(socket_path)?;
loop {
let (socket, ..) = listener.accept().await?;
self.handle_connection(socket).await?;
}
}
async fn handle_connection(&self, mut stream: UnixStream) -> Result<()> {
let mut line = String::new();
{
let mut reader = BufReader::new(&mut stream);
reader.read_line(&mut line).await?;
}
self.emit(stream, line.trim()).await
}
async fn emit(&self, stream: UnixStream, input: &str) -> Result<()> {
let state = State::from_str(input)?;
if let Some(callbacks) = self.callbacks.get(&state) {
let stream = Arc::new(Mutex::new(stream));
for f in callbacks {
let stream = stream.clone();
f(Handler { stream }).await?;
}
}
Ok(())
}
}