Skip to main content

repl_rs/
repl.rs

1use crate::Value;
2use crate::error::*;
3use crate::help::{DefaultHelpViewer, HelpContext, HelpEntry, HelpViewer};
4use crate::{Command, Parameter};
5use rustyline::completion;
6use rustyline::history::DefaultHistory;
7use rustyline_derive::{Helper, Highlighter, Hinter, Validator};
8use std::boxed::Box;
9use std::collections::HashMap;
10use std::fmt::Display;
11use yansi::Paint;
12
13type ErrorHandler<Context, E> = fn(error: E, repl: &Repl<Context, E>) -> Result<()>;
14
15fn default_error_handler<Context, E: std::fmt::Display>(
16    error: E,
17    _repl: &Repl<Context, E>,
18) -> Result<()> {
19    eprintln!("{}", error);
20    Ok(())
21}
22
23/// Main REPL struct
24pub struct Repl<Context, E: std::fmt::Display> {
25    name: String,
26    version: String,
27    description: String,
28    prompt: Box<dyn Display>,
29    custom_prompt: bool,
30    commands: HashMap<String, Command<Context, E>>,
31    context: Context,
32    help_context: Option<HelpContext>,
33    help_viewer: Box<dyn HelpViewer>,
34    error_handler: ErrorHandler<Context, E>,
35    use_completion: bool,
36}
37
38impl<Context, E> Repl<Context, E>
39where
40    E: Display + From<Error>,
41{
42    /// Create a new Repl with the given context's initial value.
43    pub fn new(context: Context) -> Self {
44        let name = String::new();
45
46        Self {
47            name: name.clone(),
48            version: String::new(),
49            description: String::new(),
50            prompt: Box::new(format!("{}> ", name.green().bold())),
51            custom_prompt: false,
52            commands: HashMap::new(),
53            context,
54            help_context: None,
55            help_viewer: Box::new(DefaultHelpViewer::new()),
56            error_handler: default_error_handler,
57            use_completion: false,
58        }
59    }
60
61    /// Give your Repl a name. This is used in the help summary for the Repl.
62    pub fn with_name(mut self, name: &str) -> Self {
63        self.name = name.to_string();
64        if !self.custom_prompt {
65            self.prompt = Box::new(format!("{}> ", name.green().bold()));
66        }
67
68        self
69    }
70
71    /// Give your Repl a version. This is used in the help summary for the Repl.
72    pub fn with_version(mut self, version: &str) -> Self {
73        self.version = version.to_string();
74
75        self
76    }
77
78    /// Give your Repl a description. This is used in the help summary for the Repl.
79    pub fn with_description(mut self, description: &str) -> Self {
80        self.description = description.to_string();
81
82        self
83    }
84
85    /// Give your Repl a custom prompt. The default prompt is the Repl name, followed by
86    /// a `>`, all in green, followed by a space.
87    pub fn with_prompt(mut self, prompt: &'static dyn Display) -> Self {
88        self.prompt = Box::new(prompt);
89        self.custom_prompt = true;
90
91        self
92    }
93
94    /// Pass in a custom help viewer
95    pub fn with_help_viewer<V: 'static + HelpViewer>(mut self, help_viewer: V) -> Self {
96        self.help_viewer = Box::new(help_viewer);
97
98        self
99    }
100
101    /// Pass in a custom error handler. This is really only for testing - the default
102    /// error handler simply prints the error to stderr and then returns
103    pub fn with_error_handler(mut self, handler: ErrorHandler<Context, E>) -> Self {
104        self.error_handler = handler;
105
106        self
107    }
108
109    /// Set whether to use command completion when tab is hit. Defaults to false.
110    pub fn use_completion(mut self, value: bool) -> Self {
111        self.use_completion = value;
112
113        self
114    }
115
116    /// Add a command to your REPL
117    pub fn add_command(mut self, command: Command<Context, E>) -> Self {
118        self.commands.insert(command.name.clone(), command);
119
120        self
121    }
122
123    fn validate_arguments(
124        &self,
125        command: &str,
126        parameters: &[Parameter],
127        args: &[&str],
128    ) -> Result<HashMap<String, Value>> {
129        if args.len() > parameters.len() {
130            return Err(Error::TooManyArguments(command.into(), parameters.len()));
131        }
132
133        let mut validated = HashMap::new();
134        for (index, parameter) in parameters.iter().enumerate() {
135            if index < args.len() {
136                validated.insert(parameter.name.clone(), Value::new(args[index]));
137            } else if parameter.required {
138                return Err(Error::MissingRequiredArgument(
139                    command.into(),
140                    parameter.name.clone(),
141                ));
142            } else if parameter.default.is_some() {
143                validated.insert(
144                    parameter.name.clone(),
145                    Value::new(&parameter.default.clone().unwrap()),
146                );
147            }
148        }
149        Ok(validated)
150    }
151
152    fn handle_command(&mut self, command: &str, args: &[&str]) -> core::result::Result<(), E> {
153        match self.commands.get(command) {
154            Some(definition) => {
155                let validated = self.validate_arguments(command, &definition.parameters, args)?;
156                match (definition.callback)(validated, &mut self.context) {
157                    Ok(Some(value)) => println!("{}", value),
158                    Ok(None) => (),
159                    Err(error) => return Err(error),
160                };
161            }
162            None => {
163                if command == "help" {
164                    self.show_help(args)?;
165                } else {
166                    return Err(Error::UnknownCommand(command.to_string()).into());
167                }
168            }
169        }
170
171        Ok(())
172    }
173
174    fn show_help(&self, args: &[&str]) -> Result<()> {
175        if args.is_empty() {
176            self.help_viewer
177                .help_general(self.help_context.as_ref().unwrap())?;
178        } else {
179            let entry_opt = self
180                .help_context
181                .as_ref()
182                .unwrap()
183                .help_entries
184                .iter()
185                .find(|entry| entry.command == args[0]);
186            match entry_opt {
187                Some(entry) => {
188                    self.help_viewer.help_command(entry)?;
189                }
190                None => eprintln!("Help not found for command '{}'", args[0]),
191            };
192        }
193        Ok(())
194    }
195
196    fn process_line(&mut self, line: String) -> core::result::Result<(), E> {
197        let trimmed = line.trim();
198        if !trimmed.is_empty() {
199            let r = regex::Regex::new(r#"("[^"\n]+"|[\S]+)"#).unwrap();
200            let args = r
201                .captures_iter(trimmed)
202                .map(|a| a[0].to_string().replace('\"', ""))
203                .collect::<Vec<String>>();
204            let mut args = args.iter().fold(vec![], |mut state, a| {
205                state.push(a.as_str());
206                state
207            });
208            let command: String = args.drain(..1).collect();
209            self.handle_command(&command, &args)?;
210        }
211        Ok(())
212    }
213
214    fn construct_help_context(&mut self) {
215        let mut help_entries = self
216            .commands
217            .values()
218            .map(|definition| {
219                HelpEntry::new(
220                    &definition.name,
221                    &definition.parameters,
222                    &definition.help_summary,
223                )
224            })
225            .collect::<Vec<HelpEntry>>();
226        help_entries.sort_by_key(|d| d.command.clone());
227        self.help_context = Some(HelpContext::new(
228            &self.name,
229            &self.version,
230            &self.description,
231            help_entries,
232        ));
233    }
234
235    fn create_helper(&mut self) -> Helper {
236        let mut helper = Helper::new();
237        if self.use_completion {
238            for name in self.commands.keys() {
239                helper.add_command(name.to_string());
240            }
241        }
242
243        helper
244    }
245
246    pub fn run(&mut self) -> Result<()> {
247        self.construct_help_context();
248        let mut editor: rustyline::Editor<Helper, DefaultHistory> =
249            rustyline::Editor::new().map_err(|e| Error::UnknownError(e.to_string()))?;
250        let helper = Some(self.create_helper());
251        editor.set_helper(helper);
252        println!("Welcome to {} {}", self.name, self.version);
253        let mut eof = false;
254        while !eof {
255            self.handle_line(&mut editor, &mut eof)?;
256        }
257
258        Ok(())
259    }
260
261    fn handle_line(
262        &mut self,
263        editor: &mut rustyline::Editor<Helper, DefaultHistory>,
264        eof: &mut bool,
265    ) -> Result<()> {
266        match editor.readline(&format!("{}", self.prompt)) {
267            Ok(line) => {
268                editor
269                    .add_history_entry(line.clone())
270                    .map_err(|e| Error::UnknownError(e.to_string()))?;
271                if let Err(error) = self.process_line(line) {
272                    (self.error_handler)(error, self)?;
273                }
274                *eof = false;
275                Ok(())
276            }
277            Err(rustyline::error::ReadlineError::Eof) => {
278                *eof = true;
279                Ok(())
280            }
281            Err(error) => {
282                eprintln!("Error reading line: {}", error);
283                *eof = false;
284                Ok(())
285            }
286        }
287    }
288}
289
290// rustyline Helper struct
291// Currently just does command completion with <tab>, if
292// use_completion() is set on the REPL
293#[derive(Clone, Helper, Hinter, Highlighter, Validator)]
294struct Helper {
295    commands: Vec<String>,
296}
297
298impl Helper {
299    fn new() -> Self {
300        Self { commands: vec![] }
301    }
302
303    fn add_command(&mut self, command: String) {
304        self.commands.push(command);
305    }
306}
307
308impl completion::Completer for Helper {
309    type Candidate = String;
310
311    fn complete(
312        &self,
313        line: &str,
314        _pos: usize,
315        _ctx: &rustyline::Context<'_>,
316    ) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
317        // Complete based on whether the current line is a substring
318        // of one of the set commands
319        let ret: Vec<Self::Candidate> = self
320            .commands
321            .iter()
322            .filter(|cmd| cmd.contains(line))
323            .map(|s| s.to_string())
324            .collect();
325        Ok((0, ret))
326    }
327}
328
#[cfg(all(test, unix))]
mod tests {
    use crate::error::*;
    use crate::repl::{Helper, Repl};
    use crate::{Command, Parameter};
    use crate::{Value, initialize_repl};
    use clap::{crate_description, crate_name, crate_version};
    use nix::sys::wait::{WaitStatus, waitpid};
    use nix::unistd::{ForkResult, close, dup2_stdin, fork};
    use rustyline::history::DefaultHistory;
    use std::collections::HashMap;
    use std::io::{Write, pipe};
    use tokio::runtime::Runtime;

    /// Error handler that hands the error straight back, so `handle_line`
    /// returns it and the test can compare it with the expected result.
    fn test_error_handler<Context>(error: Error, _repl: &Repl<Context, Error>) -> Result<()> {
        Err(error)
    }

    /// Command callback that echoes its validated arguments.
    fn foo<T>(args: HashMap<String, Value>, _context: &mut T) -> Result<Option<String>> {
        Ok(Some(format!("foo {:?}", args)))
    }

    /// Feed `input` to a single `handle_line` call on `repl` and assert that
    /// it returns `expected`.
    ///
    /// `handle_line` reads from stdin, so the call runs in a forked child whose
    /// stdin is the read end of a pipe; the parent writes `input` into it. The
    /// child reports back through its exit status: 0 on a match, 1 on a
    /// mismatch, 2 if it panicked. The parent requires a normal exit with 0.
    fn run_repl<Context>(mut repl: Repl<Context, Error>, input: &str, expected: Result<()>) {
        let (rdr, mut wrtr) = pipe().expect("failed to create pipe");
        // SAFETY: the child never returns into the test harness. It always
        // leaves through `std::process::exit`, even if the code under test panics.
        match unsafe { fork() }.expect("fork failed") {
            ForkResult::Parent { child, .. } => {
                drop(rdr);
                wrtr.write_all(input.as_bytes())
                    .expect("failed to write REPL input");
                // Close our write end so the child would see EOF after `input`.
                drop(wrtr);
                assert_eq!(
                    waitpid(child, None).expect("waitpid failed"),
                    WaitStatus::Exited(child, 0),
                    "child REPL did not produce the expected result"
                );
            }
            ForkResult::Child => {
                // The child only reads; release its inherited copy of the write end.
                drop(wrtr);
                let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                    dup2_stdin(&rdr).expect("failed to redirect stdin");
                    close(rdr).expect("failed to close pipe");
                    let mut editor: rustyline::Editor<Helper, DefaultHistory> =
                        rustyline::Editor::new().expect("Error creating editor");
                    let mut eof = false;
                    repl.handle_line(&mut editor, &mut eof)
                }));
                let code = match outcome {
                    Ok(result) if result == expected => 0,
                    Ok(result) => {
                        eprintln!("Expected {:?}, got {:?}", expected, result);
                        1
                    }
                    // The default panic hook has already reported the panic.
                    Err(_) => 2,
                };
                std::process::exit(code);
            }
        }
    }

    #[test]
    fn test_initialize_sets_crate_values() -> Result<()> {
        let repl: Repl<(), Error> = initialize_repl!(());

        assert_eq!(crate_name!(), repl.name);
        assert_eq!(crate_version!(), repl.version);
        assert_eq!(crate_description!(), repl.description);

        Ok(())
    }

    #[test]
    fn test_empty_line_does_nothing() -> Result<()> {
        let repl = Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(
                Command::new("foo", foo)
                    .with_parameter(Parameter::new("bar").set_required(true)?)?
                    .with_parameter(Parameter::new("baz").set_required(true)?)?
                    .with_help("Do foo when you can"),
            );
        run_repl(repl, "\n", Ok(()));

        Ok(())
    }

    #[test]
    fn test_missing_required_arg_fails() -> Result<()> {
        let repl = Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(
                Command::new("foo", foo)
                    .with_parameter(Parameter::new("bar").set_required(true)?)?
                    .with_parameter(Parameter::new("baz").set_required(true)?)?
                    .with_help("Do foo when you can"),
            );
        run_repl(
            repl,
            "foo bar\n",
            Err(Error::MissingRequiredArgument("foo".into(), "baz".into())),
        );

        Ok(())
    }

    #[test]
    fn test_unknown_command_fails() -> Result<()> {
        let repl = Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(
                Command::new("foo", foo)
                    .with_parameter(Parameter::new("bar").set_required(true)?)?
                    .with_parameter(Parameter::new("baz").set_required(true)?)?
                    .with_help("Do foo when you can"),
            );
        run_repl(
            repl,
            "bar baz\n",
            Err(Error::UnknownCommand("bar".to_string())),
        );

        Ok(())
    }

    #[test]
    fn test_no_required_after_optional() -> Result<()> {
        assert_eq!(
            Err(Error::IllegalRequiredError("bar".into())),
            Command::<(), Error>::new("foo", foo)
                .with_parameter(Parameter::new("baz").set_default("20")?)?
                .with_parameter(Parameter::new("bar").set_required(true)?)
        );

        Ok(())
    }

    #[test]
    fn test_required_cannot_be_defaulted() -> Result<()> {
        assert_eq!(
            Err(Error::IllegalDefaultError("bar".into())),
            Parameter::new("bar").set_required(true)?.set_default("foo")
        );

        Ok(())
    }

    #[test]
    fn test_string_with_spaces_for_argument() -> Result<()> {
        let repl = Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(
                Command::new("foo", foo)
                    .with_parameter(Parameter::new("bar").set_required(true)?)?
                    .with_parameter(Parameter::new("baz").set_required(true)?)?
                    .with_help("Do foo when you can"),
            );
        run_repl(repl, "foo \"baz test 123\" foo\n", Ok(()));

        Ok(())
    }

    #[test]
    fn test_string_with_spaces_for_argument_last() -> Result<()> {
        let repl = Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(
                Command::new("foo", foo)
                    .with_parameter(Parameter::new("bar").set_required(true)?)?
                    .with_parameter(Parameter::new("baz").set_required(true)?)?
                    .with_help("Do foo when you can"),
            );
        run_repl(repl, "foo foo \"baz test 123\"\n", Ok(()));

        Ok(())
    }

    /// Async body behind the synchronous `bar` callback.
    async fn async_bar<T>(args: HashMap<String, Value>) -> Result<Option<String>> {
        Ok(Some(format!("foo {:?}", args)))
    }

    /// Command callback that drives an async function to completion on a
    /// freshly created tokio runtime.
    fn bar<T>(args: HashMap<String, Value>, _context: &mut T) -> Result<Option<String>> {
        Runtime::new()
            .unwrap()
            .block_on(async move { async_bar::<T>(args).await })
    }

    #[test]
    fn test_async_calling() -> Result<()> {
        let repl = Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(Command::new("bar", bar).with_help("Do bar when you can"));
        // Actually invoke `bar`; an empty line would never reach the callback.
        run_repl(repl, "bar\n", Ok(()));

        Ok(())
    }
}