1use std::{
2 ffi::OsString,
3 fmt::Debug,
4 future::Future,
5 pin::Pin,
6 sync::{Arc, RwLock},
7};
8
9#[cfg(test)]
10mod test;
11
12use quokka_state::{FromState, ProvideState, ProvideStateRef};
13
14#[derive(Clone, Debug, thiserror::Error, PartialEq)]
15pub enum Error {
16 #[error("Unable to call client command: {0}")]
17 CommandCallError(String),
18}
19
20pub type Result<T> = std::result::Result<T, Error>;
21
22#[derive(Clone)]
73pub struct Commands<S> {
74 commands: Arc<RwLock<Vec<CommandDef<S>>>>,
75}
76
77pub trait CommandHandler: Send + Sync {
83 type Error: std::error::Error;
84
85 fn args() -> clap::Command
89 where
90 Self: Sized;
91
92 fn call(
96 self,
97 args: clap::ArgMatches,
98 ) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send;
99}
100
101pub trait CommandStateExt<S> {
103 fn commands(&mut self) -> &mut Commands<S>;
104
105 fn register_command<H: CommandHandler + FromState<S> + 'static>(&mut self);
106}
107
108type CommandFactory<S> =
109 Box<dyn Fn(&S) -> Box<dyn AbstractCommandHandler + Send + 'static> + Send + Sync>;
110
111struct CommandDef<S> {
112 args: clap::Command,
113 factory: CommandFactory<S>,
114}
115
116#[doc(hidden)]
117trait AbstractCommandHandler: Send + Sync {
118 fn run(
119 self: Box<Self>,
120 matches: clap::ArgMatches,
121 ) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>;
122}
123
124impl<T: CommandHandler + 'static> AbstractCommandHandler for T {
125 fn run(
126 self: Box<Self>,
127 matches: clap::ArgMatches,
128 ) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>> {
129 Box::pin(async move {
130 self.call(matches)
131 .await
132 .inspect_err(|error| tracing::error!(?error, "Unable to run command"))
133 .map_err(|error| crate::Error::CommandCallError(error.to_string()))?;
134
135 Ok(())
136 })
137 }
138}
139
140impl<S: Send + Sync + Clone + 'static> Commands<S> {
141 pub fn register_command<C: CommandHandler + 'static>(&mut self)
142 where
143 S: ProvideState<C>,
144 {
145 self.commands.write().unwrap().push(CommandDef {
146 args: C::args(),
147 factory: Box::new(|state| Box::new(ProvideState::<C>::provide(state))),
148 });
149 }
150
151 pub async fn dispatch<I, T>(self, state: S, args: I) -> crate::Result<()>
152 where
153 I: IntoIterator<Item = T>,
154 T: Into<OsString> + Clone,
155 {
156 let command = self.build_clap_command();
157 let matches = command.clone().get_matches_from(args);
158
159 for command in self.commands.write().unwrap().drain(..) {
160 if let Some(matches) = matches.subcommand_matches(command.args.get_name()) {
161 let handler = (command.factory)(&state);
162
163 handler
164 .run(matches.clone())
165 .await
166 .inspect_err(|error| tracing::debug!(?error, "Unable to dispatch command"))?;
167
168 return Ok(());
169 }
170 }
171
172 Ok(())
173 }
174
175 pub fn build_clap_command(&self) -> clap::Command {
176 let mut command = clap::Command::new(clap::crate_name!())
177 .version(clap::crate_version!())
178 .about(clap::crate_description!())
179 .author(clap::crate_authors!())
180 .subcommand_required(true);
181
182 for command_def in self.commands.read().unwrap().iter() {
183 command = command.subcommand(command_def.args.clone());
184 }
185
186 command
187 }
188}
189
190impl<S> Debug for Commands<S> {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 f.write_str("Commands")
193 }
194}
195
196impl<S> Default for Commands<S> {
197 fn default() -> Self {
198 Self {
199 commands: Default::default(),
200 }
201 }
202}
203
204impl<S> quokka_config::TryFromConfig for Commands<S> {
205 type Error = crate::Error;
206
207 async fn try_from_config(_: &quokka_config::Config) -> crate::Result<Self>
208 where
209 Self: Sized,
210 {
211 Ok(Self::default())
212 }
213}
214
215impl<S: Send + Sync + Clone + ProvideStateRef<Commands<S>> + 'static> CommandStateExt<S> for S {
216 fn commands(&mut self) -> &mut Commands<S> {
217 self.provide_mut()
218 }
219
220 fn register_command<H: CommandHandler + FromState<S> + 'static>(&mut self) {
221 self.commands().register_command::<H>();
222 }
223}