1use crate::error::*;
2use crate::help::{DefaultHelpViewer, HelpContext, HelpEntry, HelpViewer};
3use crate::Value;
4use crate::{Command, Parameter};
5use rustyline::completion;
6use rustyline_derive::{Helper, Highlighter, Hinter, Validator};
7use std::boxed::Box;
8use std::collections::HashMap;
9use std::fmt::Display;
10use yansi::Paint;
11
12type ErrorHandler<Context, E> = fn(error: E, repl: &Repl<Context, E>) -> Result<()>;
13
14fn default_error_handler<Context, E: std::fmt::Display>(
15 error: E,
16 _repl: &Repl<Context, E>,
17) -> Result<()> {
18 eprintln!("{}", error);
19 Ok(())
20}
21
22pub struct Repl<Context, E: std::fmt::Display> {
24 name: String,
25 version: String,
26 description: String,
27 prompt: Box<dyn Display>,
28 custom_prompt: bool,
29 commands: HashMap<String, Command<Context, E>>,
30 context: Context,
31 help_context: Option<HelpContext>,
32 help_viewer: Box<dyn HelpViewer>,
33 error_handler: ErrorHandler<Context, E>,
34 use_completion: bool,
35}
36
37impl<Context, E> Repl<Context, E>
38where
39 E: Display + From<Error>,
40{
41 pub fn new(context: Context) -> Self {
43 let name = String::new();
44
45 Self {
46 name: name.clone(),
47 version: String::new(),
48 description: String::new(),
49 prompt: Box::new(Paint::green(format!("{}> ", name)).bold()),
50 custom_prompt: false,
51 commands: HashMap::new(),
52 context,
53 help_context: None,
54 help_viewer: Box::new(DefaultHelpViewer::new()),
55 error_handler: default_error_handler,
56 use_completion: false,
57 }
58 }
59
60 pub fn with_name(mut self, name: &str) -> Self {
62 self.name = name.to_string();
63 if !self.custom_prompt {
64 self.prompt = Box::new(Paint::green(format!("{}> ", name)).bold());
65 }
66
67 self
68 }
69
70 pub fn with_version(mut self, version: &str) -> Self {
72 self.version = version.to_string();
73
74 self
75 }
76
77 pub fn with_description(mut self, description: &str) -> Self {
79 self.description = description.to_string();
80
81 self
82 }
83
84 pub fn with_prompt(mut self, prompt: &'static dyn Display) -> Self {
87 self.prompt = Box::new(prompt);
88 self.custom_prompt = true;
89
90 self
91 }
92
93 pub fn with_help_viewer<V: 'static + HelpViewer>(mut self, help_viewer: V) -> Self {
95 self.help_viewer = Box::new(help_viewer);
96
97 self
98 }
99
100 pub fn with_error_handler(mut self, handler: ErrorHandler<Context, E>) -> Self {
103 self.error_handler = handler;
104
105 self
106 }
107
108 pub fn use_completion(mut self, value: bool) -> Self {
110 self.use_completion = value;
111
112 self
113 }
114
115 pub fn add_command(mut self, command: Command<Context, E>) -> Self {
117 self.commands.insert(command.name.clone(), command);
118
119 self
120 }
121
122 fn validate_arguments(
123 &self,
124 command: &str,
125 parameters: &[Parameter],
126 args: &[&str],
127 ) -> Result<HashMap<String, Value>> {
128 if args.len() > parameters.len() {
129 return Err(Error::TooManyArguments(command.into(), parameters.len()));
130 }
131
132 let mut validated = HashMap::new();
133 for (index, parameter) in parameters.iter().enumerate() {
134 if index < args.len() {
135 validated.insert(parameter.name.clone(), Value::new(args[index]));
136 } else if parameter.required {
137 return Err(Error::MissingRequiredArgument(
138 command.into(),
139 parameter.name.clone(),
140 ));
141 } else if parameter.default.is_some() {
142 validated.insert(
143 parameter.name.clone(),
144 Value::new(¶meter.default.clone().unwrap()),
145 );
146 }
147 }
148 Ok(validated)
149 }
150
151 fn handle_command(&mut self, command: &str, args: &[&str]) -> core::result::Result<(), E> {
152 match self.commands.get(command) {
153 Some(definition) => {
154 let validated = self.validate_arguments(command, &definition.parameters, args)?;
155 match (definition.callback)(validated, &mut self.context) {
156 Ok(Some(value)) => println!("{}", value),
157 Ok(None) => (),
158 Err(error) => return Err(error),
159 };
160 }
161 None => {
162 if command == "help" {
163 self.show_help(args)?;
164 } else {
165 return Err(Error::UnknownCommand(command.to_string()).into());
166 }
167 }
168 }
169
170 Ok(())
171 }
172
173 fn show_help(&self, args: &[&str]) -> Result<()> {
174 if args.is_empty() {
175 self.help_viewer
176 .help_general(self.help_context.as_ref().unwrap())?;
177 } else {
178 let entry_opt = self
179 .help_context
180 .as_ref()
181 .unwrap()
182 .help_entries
183 .iter()
184 .find(|entry| entry.command == args[0]);
185 match entry_opt {
186 Some(entry) => {
187 self.help_viewer.help_command(entry)?;
188 }
189 None => eprintln!("Help not found for command '{}'", args[0]),
190 };
191 }
192 Ok(())
193 }
194
195 fn process_line(&mut self, line: String) -> core::result::Result<(), E> {
196 let trimmed = line.trim();
197 if !trimmed.is_empty() {
198 let r = regex::Regex::new(r#"("[^"\n]+"|[\S]+)"#).unwrap();
199 let args = r
200 .captures_iter(trimmed)
201 .map(|a| a[0].to_string().replace('\"', ""))
202 .collect::<Vec<String>>();
203 let mut args = args.iter().fold(vec![], |mut state, a| {
204 state.push(a.as_str());
205 state
206 });
207 let command: String = args.drain(..1).collect();
208 self.handle_command(&command, &args)?;
209 }
210 Ok(())
211 }
212
213 fn construct_help_context(&mut self) {
214 let mut help_entries = self
215 .commands
216 .values()
217 .map(|definition| {
218 HelpEntry::new(
219 &definition.name,
220 &definition.parameters,
221 &definition.help_summary,
222 )
223 })
224 .collect::<Vec<HelpEntry>>();
225 help_entries.sort_by_key(|d| d.command.clone());
226 self.help_context = Some(HelpContext::new(
227 &self.name,
228 &self.version,
229 &self.description,
230 help_entries,
231 ));
232 }
233
234 fn create_helper(&mut self) -> Helper {
235 let mut helper = Helper::new();
236 if self.use_completion {
237 for name in self.commands.keys() {
238 helper.add_command(name.to_string());
239 }
240 }
241
242 helper
243 }
244
245 pub fn run(&mut self) -> Result<()> {
246 self.construct_help_context();
247 let mut editor: rustyline::Editor<Helper> = rustyline::Editor::new();
248 let helper = Some(self.create_helper());
249 editor.set_helper(helper);
250 println!("Welcome to {} {}", self.name, self.version);
251 let mut eof = false;
252 while !eof {
253 self.handle_line(&mut editor, &mut eof)?;
254 }
255
256 Ok(())
257 }
258
259 fn handle_line(
260 &mut self,
261 editor: &mut rustyline::Editor<Helper>,
262 eof: &mut bool,
263 ) -> Result<()> {
264 match editor.readline(&format!("{}", self.prompt)) {
265 Ok(line) => {
266 editor.add_history_entry(line.clone());
267 if let Err(error) = self.process_line(line) {
268 (self.error_handler)(error, self)?;
269 }
270 *eof = false;
271 Ok(())
272 }
273 Err(rustyline::error::ReadlineError::Eof) => {
274 *eof = true;
275 Ok(())
276 }
277 Err(error) => {
278 eprintln!("Error reading line: {}", error);
279 *eof = false;
280 Ok(())
281 }
282 }
283 }
284}
285
286#[derive(Clone, Helper, Hinter, Highlighter, Validator)]
290struct Helper {
291 commands: Vec<String>,
292}
293
294impl Helper {
295 fn new() -> Self {
296 Self { commands: vec![] }
297 }
298
299 fn add_command(&mut self, command: String) {
300 self.commands.push(command);
301 }
302}
303
304impl completion::Completer for Helper {
305 type Candidate = String;
306
307 fn complete(
308 &self,
309 line: &str,
310 _pos: usize,
311 _ctx: &rustyline::Context<'_>,
312 ) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
313 let ret: Vec<Self::Candidate> = self
316 .commands
317 .iter()
318 .filter(|cmd| cmd.contains(line))
319 .map(|s| s.to_string())
320 .collect();
321 Ok((0, ret))
322 }
323}
324
#[cfg(all(test, unix))]
mod tests {
    use crate::error::*;
    use crate::repl::{Helper, Repl};
    use crate::{initialize_repl, Value};
    use crate::{Command, Parameter};
    use clap::{crate_description, crate_name, crate_version};
    use nix::sys::wait::{waitpid, WaitStatus};
    use nix::unistd::{close, dup2, fork, pipe, ForkResult};
    use std::collections::HashMap;
    use std::fs::File;
    use std::io::Write;
    use std::os::unix::io::FromRawFd;
    use std::panic::AssertUnwindSafe;

    /// Error handler that hands the error back so tests can compare it.
    fn test_error_handler<Context>(error: Error, _repl: &Repl<Context, Error>) -> Result<()> {
        Err(error)
    }

    /// Command callback that echoes its validated arguments.
    fn foo<T>(args: HashMap<String, Value>, _context: &mut T) -> Result<Option<String>> {
        Ok(Some(format!("foo {:?}", args)))
    }

    /// REPL shared by most tests: a single `foo` command with two required
    /// parameters, `bar` and `baz`, and errors returned to the caller.
    fn foo_repl() -> Result<Repl<(), Error>> {
        Ok(Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(
                Command::new("foo", foo)
                    .with_parameter(Parameter::new("bar").set_required(true)?)?
                    .with_parameter(Parameter::new("baz").set_required(true)?)?
                    .with_help("Do foo when you can"),
            ))
    }

    /// Feeds `input` to a single `handle_line` call in a forked child whose stdin
    /// is a pipe, and fails the test unless the child's result equals `expected`.
    ///
    /// The child reports through its exit status: 0 for a match, 1 for a mismatch
    /// or a panic. The parent treats a failed fork or any abnormal child exit as a
    /// test failure too.
    fn run_repl<Context>(mut repl: Repl<Context, Error>, input: &str, expected: Result<()>) {
        let (rdr, wrtr) = pipe().unwrap();
        // SAFETY: `fork` in a multi-threaded test process is only sound if the
        // child avoids state held by other threads. The child here only runs one
        // REPL line and always leaves through `process::exit`. NOTE(review): this
        // relies on rustyline not touching such state — confirm if tests hang.
        // `from_raw_fd` takes sole ownership of the pipe's write end in the parent.
        unsafe {
            match fork() {
                Ok(ForkResult::Parent { child, .. }) => {
                    close(rdr).unwrap();
                    let mut f = File::from_raw_fd(wrtr);
                    write!(f, "{}", input).unwrap();
                    // Close the write end so the child sees end-of-input after `input`.
                    drop(f);
                    match waitpid(child, None).unwrap() {
                        WaitStatus::Exited(_, exit_code) => {
                            assert_eq!(exit_code, 0, "child reported an unexpected result")
                        }
                        status => panic!("child did not exit normally: {:?}", status),
                    }
                }
                Ok(ForkResult::Child) => {
                    std::panic::set_hook(Box::new(|panic_info| {
                        println!("Caught panic: {:?}", panic_info);
                        if let Some(location) = panic_info.location() {
                            println!(
                                "panic occurred in file '{}' at line {}",
                                location.file(),
                                location.line(),
                            );
                        } else {
                            println!("panic occurred but can't get location information...");
                        }
                    }));

                    close(wrtr).unwrap();
                    dup2(rdr, 0).unwrap();
                    close(rdr).unwrap();

                    // Never unwind back into the forked copy of the test harness: a
                    // panic becomes a failing exit status instead.
                    let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| {
                        let mut editor: rustyline::Editor<Helper> = rustyline::Editor::new();
                        let mut eof = false;
                        repl.handle_line(&mut editor, &mut eof)
                    }));
                    let _ = std::panic::take_hook();

                    match outcome {
                        Ok(result) if result == expected => std::process::exit(0),
                        Ok(result) => {
                            eprintln!("Expected {:?}, got {:?}", expected, result);
                            std::process::exit(1);
                        }
                        Err(_) => std::process::exit(1),
                    }
                }
                Err(error) => panic!("Fork failed: {:?}", error),
            }
        }
    }

    #[test]
    fn test_initialize_sets_crate_values() -> Result<()> {
        let repl: Repl<(), Error> = initialize_repl!(());

        assert_eq!(crate_name!(), repl.name);
        assert_eq!(crate_version!(), repl.version);
        assert_eq!(crate_description!(), repl.description);

        Ok(())
    }

    #[test]
    fn test_empty_line_does_nothing() -> Result<()> {
        run_repl(foo_repl()?, "\n", Ok(()));

        Ok(())
    }

    #[test]
    fn test_missing_required_arg_fails() -> Result<()> {
        run_repl(
            foo_repl()?,
            "foo bar\n",
            Err(Error::MissingRequiredArgument("foo".into(), "baz".into())),
        );

        Ok(())
    }

    #[test]
    fn test_unknown_command_fails() -> Result<()> {
        run_repl(
            foo_repl()?,
            "bar baz\n",
            Err(Error::UnknownCommand("bar".to_string())),
        );

        Ok(())
    }

    #[test]
    fn test_no_required_after_optional() -> Result<()> {
        assert_eq!(
            Err(Error::IllegalRequiredError("bar".into())),
            Command::<(), Error>::new("foo", foo)
                .with_parameter(Parameter::new("baz").set_default("20")?)?
                .with_parameter(Parameter::new("bar").set_required(true)?)
        );

        Ok(())
    }

    #[test]
    fn test_required_cannot_be_defaulted() -> Result<()> {
        assert_eq!(
            Err(Error::IllegalDefaultError("bar".into())),
            Parameter::new("bar").set_required(true)?.set_default("foo")
        );

        Ok(())
    }

    #[test]
    fn test_string_with_spaces_for_argument() -> Result<()> {
        run_repl(foo_repl()?, "foo \"baz test 123\" foo\n", Ok(()));

        Ok(())
    }

    #[test]
    fn test_string_with_spaces_for_argument_last() -> Result<()> {
        run_repl(foo_repl()?, "foo foo \"baz test 123\"\n", Ok(()));

        Ok(())
    }
}