repl_rs/
repl.rs

1use crate::error::*;
2use crate::help::{DefaultHelpViewer, HelpContext, HelpEntry, HelpViewer};
3use crate::Value;
4use crate::{Command, Parameter};
5use rustyline::completion;
6use rustyline_derive::{Helper, Highlighter, Hinter, Validator};
7use std::boxed::Box;
8use std::collections::HashMap;
9use std::fmt::Display;
10use yansi::Paint;
11
12type ErrorHandler<Context, E> = fn(error: E, repl: &Repl<Context, E>) -> Result<()>;
13
14fn default_error_handler<Context, E: std::fmt::Display>(
15    error: E,
16    _repl: &Repl<Context, E>,
17) -> Result<()> {
18    eprintln!("{}", error);
19    Ok(())
20}
21
22/// Main REPL struct
23pub struct Repl<Context, E: std::fmt::Display> {
24    name: String,
25    version: String,
26    description: String,
27    prompt: Box<dyn Display>,
28    custom_prompt: bool,
29    commands: HashMap<String, Command<Context, E>>,
30    context: Context,
31    help_context: Option<HelpContext>,
32    help_viewer: Box<dyn HelpViewer>,
33    error_handler: ErrorHandler<Context, E>,
34    use_completion: bool,
35}
36
37impl<Context, E> Repl<Context, E>
38where
39    E: Display + From<Error>,
40{
41    /// Create a new Repl with the given context's initial value.
42    pub fn new(context: Context) -> Self {
43        let name = String::new();
44
45        Self {
46            name: name.clone(),
47            version: String::new(),
48            description: String::new(),
49            prompt: Box::new(Paint::green(format!("{}> ", name)).bold()),
50            custom_prompt: false,
51            commands: HashMap::new(),
52            context,
53            help_context: None,
54            help_viewer: Box::new(DefaultHelpViewer::new()),
55            error_handler: default_error_handler,
56            use_completion: false,
57        }
58    }
59
60    /// Give your Repl a name. This is used in the help summary for the Repl.
61    pub fn with_name(mut self, name: &str) -> Self {
62        self.name = name.to_string();
63        if !self.custom_prompt {
64            self.prompt = Box::new(Paint::green(format!("{}> ", name)).bold());
65        }
66
67        self
68    }
69
70    /// Give your Repl a version. This is used in the help summary for the Repl.
71    pub fn with_version(mut self, version: &str) -> Self {
72        self.version = version.to_string();
73
74        self
75    }
76
77    /// Give your Repl a description. This is used in the help summary for the Repl.
78    pub fn with_description(mut self, description: &str) -> Self {
79        self.description = description.to_string();
80
81        self
82    }
83
84    /// Give your Repl a custom prompt. The default prompt is the Repl name, followed by
85    /// a `>`, all in green, followed by a space.
86    pub fn with_prompt(mut self, prompt: &'static dyn Display) -> Self {
87        self.prompt = Box::new(prompt);
88        self.custom_prompt = true;
89
90        self
91    }
92
93    /// Pass in a custom help viewer
94    pub fn with_help_viewer<V: 'static + HelpViewer>(mut self, help_viewer: V) -> Self {
95        self.help_viewer = Box::new(help_viewer);
96
97        self
98    }
99
100    /// Pass in a custom error handler. This is really only for testing - the default
101    /// error handler simply prints the error to stderr and then returns
102    pub fn with_error_handler(mut self, handler: ErrorHandler<Context, E>) -> Self {
103        self.error_handler = handler;
104
105        self
106    }
107
108    /// Set whether to use command completion when tab is hit. Defaults to false.
109    pub fn use_completion(mut self, value: bool) -> Self {
110        self.use_completion = value;
111
112        self
113    }
114
115    /// Add a command to your REPL
116    pub fn add_command(mut self, command: Command<Context, E>) -> Self {
117        self.commands.insert(command.name.clone(), command);
118
119        self
120    }
121
122    fn validate_arguments(
123        &self,
124        command: &str,
125        parameters: &[Parameter],
126        args: &[&str],
127    ) -> Result<HashMap<String, Value>> {
128        if args.len() > parameters.len() {
129            return Err(Error::TooManyArguments(command.into(), parameters.len()));
130        }
131
132        let mut validated = HashMap::new();
133        for (index, parameter) in parameters.iter().enumerate() {
134            if index < args.len() {
135                validated.insert(parameter.name.clone(), Value::new(args[index]));
136            } else if parameter.required {
137                return Err(Error::MissingRequiredArgument(
138                    command.into(),
139                    parameter.name.clone(),
140                ));
141            } else if parameter.default.is_some() {
142                validated.insert(
143                    parameter.name.clone(),
144                    Value::new(&parameter.default.clone().unwrap()),
145                );
146            }
147        }
148        Ok(validated)
149    }
150
151    fn handle_command(&mut self, command: &str, args: &[&str]) -> core::result::Result<(), E> {
152        match self.commands.get(command) {
153            Some(definition) => {
154                let validated = self.validate_arguments(command, &definition.parameters, args)?;
155                match (definition.callback)(validated, &mut self.context) {
156                    Ok(Some(value)) => println!("{}", value),
157                    Ok(None) => (),
158                    Err(error) => return Err(error),
159                };
160            }
161            None => {
162                if command == "help" {
163                    self.show_help(args)?;
164                } else {
165                    return Err(Error::UnknownCommand(command.to_string()).into());
166                }
167            }
168        }
169
170        Ok(())
171    }
172
173    fn show_help(&self, args: &[&str]) -> Result<()> {
174        if args.is_empty() {
175            self.help_viewer
176                .help_general(self.help_context.as_ref().unwrap())?;
177        } else {
178            let entry_opt = self
179                .help_context
180                .as_ref()
181                .unwrap()
182                .help_entries
183                .iter()
184                .find(|entry| entry.command == args[0]);
185            match entry_opt {
186                Some(entry) => {
187                    self.help_viewer.help_command(entry)?;
188                }
189                None => eprintln!("Help not found for command '{}'", args[0]),
190            };
191        }
192        Ok(())
193    }
194
195    fn process_line(&mut self, line: String) -> core::result::Result<(), E> {
196        let trimmed = line.trim();
197        if !trimmed.is_empty() {
198            let r = regex::Regex::new(r#"("[^"\n]+"|[\S]+)"#).unwrap();
199            let args = r
200                .captures_iter(trimmed)
201                .map(|a| a[0].to_string().replace('\"', ""))
202                .collect::<Vec<String>>();
203            let mut args = args.iter().fold(vec![], |mut state, a| {
204                state.push(a.as_str());
205                state
206            });
207            let command: String = args.drain(..1).collect();
208            self.handle_command(&command, &args)?;
209        }
210        Ok(())
211    }
212
213    fn construct_help_context(&mut self) {
214        let mut help_entries = self
215            .commands
216            .values()
217            .map(|definition| {
218                HelpEntry::new(
219                    &definition.name,
220                    &definition.parameters,
221                    &definition.help_summary,
222                )
223            })
224            .collect::<Vec<HelpEntry>>();
225        help_entries.sort_by_key(|d| d.command.clone());
226        self.help_context = Some(HelpContext::new(
227            &self.name,
228            &self.version,
229            &self.description,
230            help_entries,
231        ));
232    }
233
234    fn create_helper(&mut self) -> Helper {
235        let mut helper = Helper::new();
236        if self.use_completion {
237            for name in self.commands.keys() {
238                helper.add_command(name.to_string());
239            }
240        }
241
242        helper
243    }
244
245    pub fn run(&mut self) -> Result<()> {
246        self.construct_help_context();
247        let mut editor: rustyline::Editor<Helper> = rustyline::Editor::new();
248        let helper = Some(self.create_helper());
249        editor.set_helper(helper);
250        println!("Welcome to {} {}", self.name, self.version);
251        let mut eof = false;
252        while !eof {
253            self.handle_line(&mut editor, &mut eof)?;
254        }
255
256        Ok(())
257    }
258
259    fn handle_line(
260        &mut self,
261        editor: &mut rustyline::Editor<Helper>,
262        eof: &mut bool,
263    ) -> Result<()> {
264        match editor.readline(&format!("{}", self.prompt)) {
265            Ok(line) => {
266                editor.add_history_entry(line.clone());
267                if let Err(error) = self.process_line(line) {
268                    (self.error_handler)(error, self)?;
269                }
270                *eof = false;
271                Ok(())
272            }
273            Err(rustyline::error::ReadlineError::Eof) => {
274                *eof = true;
275                Ok(())
276            }
277            Err(error) => {
278                eprintln!("Error reading line: {}", error);
279                *eof = false;
280                Ok(())
281            }
282        }
283    }
284}
285
286// rustyline Helper struct
287// Currently just does command completion with <tab>, if
288// use_completion() is set on the REPL
289#[derive(Clone, Helper, Hinter, Highlighter, Validator)]
290struct Helper {
291    commands: Vec<String>,
292}
293
294impl Helper {
295    fn new() -> Self {
296        Self { commands: vec![] }
297    }
298
299    fn add_command(&mut self, command: String) {
300        self.commands.push(command);
301    }
302}
303
304impl completion::Completer for Helper {
305    type Candidate = String;
306
307    fn complete(
308        &self,
309        line: &str,
310        _pos: usize,
311        _ctx: &rustyline::Context<'_>,
312    ) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
313        // Complete based on whether the current line is a substring
314        // of one of the set commands
315        let ret: Vec<Self::Candidate> = self
316            .commands
317            .iter()
318            .filter(|cmd| cmd.contains(line))
319            .map(|s| s.to_string())
320            .collect();
321        Ok((0, ret))
322    }
323}
324
#[cfg(all(test, unix))]
mod tests {
    use crate::error::*;
    use crate::repl::{Helper, Repl};
    use crate::{initialize_repl, Value};
    use crate::{Command, Parameter};
    use clap::{crate_description, crate_name, crate_version};
    use nix::sys::wait::{waitpid, WaitStatus};
    use nix::unistd::{close, dup2, fork, pipe, ForkResult};
    use std::collections::HashMap;
    use std::fs::File;
    use std::io::Write;
    use std::os::unix::io::FromRawFd;

    /// Error handler that hands the error back so the test can compare it.
    fn test_error_handler<Context>(error: Error, _repl: &Repl<Context, Error>) -> Result<()> {
        Err(error)
    }

    /// Command callback that echoes its validated arguments.
    fn foo<T>(args: HashMap<String, Value>, _context: &mut T) -> Result<Option<String>> {
        Ok(Some(format!("foo {:?}", args)))
    }

    /// The REPL shared by most tests: one `foo` command with two required
    /// parameters (`bar`, `baz`) and an error handler that propagates errors.
    fn foo_repl() -> Result<Repl<(), Error>> {
        Ok(Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(
                Command::new("foo", foo)
                    .with_parameter(Parameter::new("bar").set_required(true)?)?
                    .with_parameter(Parameter::new("baz").set_required(true)?)?
                    .with_help("Do foo when you can"),
            ))
    }

    /// Feeds `input` to one `handle_line` call in a forked child process and asserts
    /// that the call returned `expected`.
    ///
    /// The fork lets rustyline read from a pipe instead of the terminal. The child
    /// reports back through its exit status: 0 = matched, 1 = mismatch, 101 = panic.
    /// Any other outcome fails the test.
    fn run_repl<Context>(mut repl: Repl<Context, Error>, input: &str, expected: Result<()>) {
        let (rdr, wrtr) = pipe().unwrap();
        // SAFETY / NOTE(review): forking a multithreaded test binary and then
        // allocating in the child is not async-signal-safe in general; this is
        // tolerated in test code because the child handles a single line and exits.
        // `from_raw_fd` takes sole ownership of the parent's copy of `wrtr`.
        unsafe {
            match fork() {
                Ok(ForkResult::Parent { child, .. }) => {
                    {
                        let mut writer = File::from_raw_fd(wrtr);
                        write!(writer, "{}", input).unwrap();
                        // Dropping `writer` closes the write end, so the child sees EOF
                        // rather than blocking if `input` has no trailing newline.
                    }
                    let status = waitpid(child, None).unwrap();
                    close(rdr).unwrap();
                    match status {
                        WaitStatus::Exited(_, exit_code) => assert_eq!(
                            0, exit_code,
                            "child REPL exited with status {}",
                            exit_code
                        ),
                        other => panic!("child REPL did not exit normally: {:?}", other),
                    }
                }
                Ok(ForkResult::Child) => {
                    // Exit non-zero on panic. Otherwise an unwinding panic would return
                    // into the test harness inside the child, and the parent could not
                    // tell that anything went wrong.
                    std::panic::set_hook(Box::new(|panic_info| {
                        println!("Caught panic: {:?}", panic_info);
                        if let Some(location) = panic_info.location() {
                            println!(
                                "panic occurred in file '{}' at line {}",
                                location.file(),
                                location.line(),
                            );
                        } else {
                            println!("panic occurred but can't get location information...");
                        }
                        std::process::exit(101);
                    }));

                    close(wrtr).unwrap();
                    dup2(rdr, 0).unwrap();
                    close(rdr).unwrap();
                    let mut editor: rustyline::Editor<Helper> = rustyline::Editor::new();
                    let mut eof = false;
                    let result = repl.handle_line(&mut editor, &mut eof);
                    let _ = std::panic::take_hook();
                    if expected == result {
                        std::process::exit(0);
                    } else {
                        eprintln!("Expected {:?}, got {:?}", expected, result);
                        std::process::exit(1);
                    }
                }
                Err(error) => panic!("fork failed: {}", error),
            }
        }
    }

    #[test]
    fn test_initialize_sets_crate_values() -> Result<()> {
        let repl: Repl<(), Error> = initialize_repl!(());

        assert_eq!(crate_name!(), repl.name);
        assert_eq!(crate_version!(), repl.version);
        assert_eq!(crate_description!(), repl.description);

        Ok(())
    }

    #[test]
    fn test_empty_line_does_nothing() -> Result<()> {
        run_repl(foo_repl()?, "\n", Ok(()));

        Ok(())
    }

    #[test]
    fn test_missing_required_arg_fails() -> Result<()> {
        run_repl(
            foo_repl()?,
            "foo bar\n",
            Err(Error::MissingRequiredArgument("foo".into(), "baz".into())),
        );

        Ok(())
    }

    #[test]
    fn test_unknown_command_fails() -> Result<()> {
        run_repl(
            foo_repl()?,
            "bar baz\n",
            Err(Error::UnknownCommand("bar".to_string())),
        );

        Ok(())
    }

    #[test]
    fn test_no_required_after_optional() -> Result<()> {
        assert_eq!(
            Err(Error::IllegalRequiredError("bar".into())),
            Command::<(), Error>::new("foo", foo)
                .with_parameter(Parameter::new("baz").set_default("20")?)?
                .with_parameter(Parameter::new("bar").set_required(true)?)
        );

        Ok(())
    }

    #[test]
    fn test_required_cannot_be_defaulted() -> Result<()> {
        assert_eq!(
            Err(Error::IllegalDefaultError("bar".into())),
            Parameter::new("bar").set_required(true)?.set_default("foo")
        );

        Ok(())
    }

    #[test]
    fn test_string_with_spaces_for_argument() -> Result<()> {
        run_repl(foo_repl()?, "foo \"baz test 123\" foo\n", Ok(()));

        Ok(())
    }

    #[test]
    fn test_string_with_spaces_for_argument_last() -> Result<()> {
        run_repl(foo_repl()?, "foo foo \"baz test 123\"\n", Ok(()));

        Ok(())
    }
}