use std::{
ffi::OsString,
fmt::Debug,
future::Future,
pin::Pin,
sync::{Arc, RwLock},
};
#[cfg(test)]
mod test;
use quokka_state::{FromState, ProvideState, ProvideStateRef};
#[derive(Clone, Debug, thiserror::Error, PartialEq)]
pub enum Error {
#[error("Unable to call client command: {0}")]
CommandCallError(String),
}
pub type Result<T> = std::result::Result<T, Error>;
/// Registry of CLI subcommands for an application state `S`.
///
/// The command list lives behind an `Arc`, so every clone shares it: a command registered
/// through one clone is visible through all of them.
pub struct Commands<S> {
    commands: Arc<RwLock<Vec<CommandDef<S>>>>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would require `S: Clone`, yet cloning
// only bumps the `Arc` refcount and never clones an `S`. This matches the hand-written `Debug`
// and `Default` impls, which are also unbounded in `S`.
impl<S> Clone for Commands<S> {
    fn clone(&self) -> Self {
        Self {
            commands: Arc::clone(&self.commands),
        }
    }
}
/// A CLI subcommand: declares its clap definition and runs asynchronously on the parsed matches.
///
/// The handler value is consumed by [`CommandHandler::call`].
pub trait CommandHandler: Send + Sync {
    /// Error returned by [`CommandHandler::call`]. On dispatch, its `Display` text is wrapped in
    /// [`Error::CommandCallError`].
    type Error: std::error::Error;

    /// Builds the clap definition of this subcommand; its name is what routes the parsed
    /// arguments back to this handler.
    fn args() -> clap::Command
    where
        Self: Sized;

    /// Runs the command with the matches of its own subcommand.
    fn call(
        self,
        args: clap::ArgMatches,
    ) -> impl Send + Future<Output = core::result::Result<(), Self::Error>>;
}
/// Extension methods for application state that carries a [`Commands`] registry.
pub trait CommandStateExt<S> {
    /// Returns mutable access to the command registry held by this state.
    fn commands(&mut self) -> &mut Commands<S>;

    /// Registers the command handler `H` in this state's registry.
    fn register_command<H>(&mut self)
    where
        H: CommandHandler + FromState<S> + 'static;
}
/// Produces a boxed, type-erased handler from a borrowed application state.
type CommandFactory<S> = Box<dyn Fn(&S) -> Box<dyn AbstractCommandHandler + Send> + Sync + Send>;
/// One registered subcommand: its clap definition paired with the factory for its handler.
struct CommandDef<S> {
    /// Clap definition; its name is matched against the parsed subcommand.
    args: clap::Command,
    /// Builds the handler from the application state when this subcommand is selected.
    factory: CommandFactory<S>,
}
/// Type-erased form of [`CommandHandler`], so handlers of different concrete types can be stored
/// and run behind `Box<dyn AbstractCommandHandler>`.
#[doc(hidden)]
trait AbstractCommandHandler
where
    Self: Send + Sync,
{
    /// Consumes the boxed handler and returns a pinned, boxed future that runs it with `matches`.
    fn run(
        self: Box<Self>,
        matches: clap::ArgMatches,
    ) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>;
}
impl<T: CommandHandler + 'static> AbstractCommandHandler for T {
fn run(
self: Box<Self>,
matches: clap::ArgMatches,
) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>> {
Box::pin(async move {
self.call(matches)
.await
.inspect_err(|error| tracing::error!(?error, "Unable to run command"))
.map_err(|error| crate::Error::CommandCallError(error.to_string()))?;
Ok(())
})
}
}
impl<S: Send + Sync + Clone + 'static> Commands<S> {
pub fn register_command<C: CommandHandler + 'static>(&mut self)
where
S: ProvideState<C>,
{
self.commands.write().unwrap().push(CommandDef {
args: C::args(),
factory: Box::new(|state| Box::new(ProvideState::<C>::provide(state))),
});
}
pub async fn dispatch<I, T>(self, state: S, args: I) -> crate::Result<()>
where
I: IntoIterator<Item = T>,
T: Into<OsString> + Clone,
{
let command = self.build_clap_command();
let matches = command.clone().get_matches_from(args);
for command in self.commands.write().unwrap().drain(..) {
if let Some(matches) = matches.subcommand_matches(command.args.get_name()) {
let handler = (command.factory)(&state);
handler
.run(matches.clone())
.await
.inspect_err(|error| tracing::debug!(?error, "Unable to dispatch command"))?;
return Ok(());
}
}
Ok(())
}
pub fn build_clap_command(&self) -> clap::Command {
let mut command = clap::Command::new(clap::crate_name!())
.version(clap::crate_version!())
.about(clap::crate_description!())
.author(clap::crate_authors!())
.subcommand_required(true);
for command_def in self.commands.read().unwrap().iter() {
command = command.subcommand(command_def.args.clone());
}
command
}
}
/// Opaque debug output: prints only the type name, not the registered commands.
impl<S> Debug for Commands<S> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.debug_struct("Commands").finish()
    }
}
impl<S> Default for Commands<S> {
fn default() -> Self {
Self {
commands: Default::default(),
}
}
}
impl<S> quokka_config::TryFromConfig for Commands<S> {
    type Error = crate::Error;

    /// Ignores the configuration and always succeeds with an empty registry.
    async fn try_from_config(_config: &quokka_config::Config) -> crate::Result<Self>
    where
        Self: Sized,
    {
        Ok(Default::default())
    }
}
impl<S: Send + Sync + Clone + ProvideStateRef<Commands<S>> + 'static> CommandStateExt<S> for S {
fn commands(&mut self) -> &mut Commands<S> {
self.provide_mut()
}
fn register_command<H: CommandHandler + FromState<S> + 'static>(&mut self) {
self.commands().register_command::<H>();
}
}