1use crate::error::*;
2use crate::help::{DefaultHelpViewer, HelpContext, HelpEntry, HelpViewer};
3use crate::{Command, Parameter};
4use rustyline::completion;
5use rustyline_derive::{Helper, Highlighter, Hinter, Validator};
6use std::boxed::Box;
7use std::collections::BTreeMap;
8use std::fmt::Display;
9use std::path::{Path, PathBuf};
10
/// Callback invoked with the error produced while processing an input line.
///
/// Returning `Ok(())` lets the REPL continue; returning `Err` is propagated out of
/// `Repl::run`, which stops the loop.
type ErrorHandler<Context, E> = fn(error: E, repl: &Repl<Context, E>) -> Result<()>;
12
13fn default_error_handler<Context, E: std::fmt::Display>(
14 error: E,
15 _repl: &Repl<Context, E>,
16) -> Result<()> {
17 eprintln!("{}", error);
18 Ok(())
19}
20
/// Supplies the prompt text and argument completions for a REPL context.
pub trait Prompt {
    /// Returns the text displayed before each input line.
    fn prompt(&self) -> String;
    /// Returns completion candidates for the argument currently being typed.
    ///
    /// `command` is the canonical command name (aliases already resolved), `args`
    /// holds the arguments entered before the current one, and `incomplete` is the
    /// partially typed last argument (empty right after a space).
    fn complete(&self, command: &str, args: &[&str], incomplete: &str) -> Vec<String>;
}
25
26impl Prompt for () {
27 fn prompt(&self) -> String { "> ".into() }
28 fn complete(&self, _command: &str, _args: &[&str], _incomplete: &str) -> Vec<String> { vec![] }
29}
30
/// An interactive read-eval-print loop that dispatches input lines to registered
/// commands.
///
/// `Context` is user state handed to command callbacks; `E` is the error type the
/// callbacks return.
pub struct Repl<Context, E: std::fmt::Display> {
    /// Program name, printed in the welcome banner and passed to the help context.
    name: String,
    /// Program version, printed in the welcome banner and passed to the help context.
    version: String,
    /// Program description, passed to the help context.
    description: String,
    /// Registered commands keyed by their canonical name.
    commands: BTreeMap<String, Command<Context, E>>,
    /// Maps each alias to the canonical name of its command.
    aliases: BTreeMap<String, String>,
    /// User context; temporarily moved into the line-editor helper while a line is read.
    context: Option<Context>,
    /// Help data built from `commands` by `construct_help_context`.
    help_context: Option<HelpContext>,
    /// Renders general and per-command help.
    help_viewer: Box<dyn HelpViewer>,
    /// Invoked with every error produced while processing a line.
    error_handler: ErrorHandler<Context, E>,
    /// Set via `use_completion`. NOTE(review): not read anywhere in the code visible
    /// here — `run` always installs the completion helper; confirm intent.
    use_completion: bool,
    /// History file path; `None` disables persistent history.
    history: Option<PathBuf>,
}
45
46impl<Context, E> Repl<Context, E>
47where
48 E: Display + From<Error>,
49 Context: Prompt,
50{
51 pub fn new(context: Context) -> Self {
53 let name = String::new();
54
55 Self {
56 name: name.clone(),
57 version: String::new(),
58 description: String::new(),
59 commands: BTreeMap::new(),
60 aliases: BTreeMap::new(),
61 context: Some(context),
62 help_context: None,
63 help_viewer: Box::new(DefaultHelpViewer::new()),
64 error_handler: default_error_handler,
65 use_completion: false,
66 history: None,
67 }
68 }
69
70 pub fn with_name(mut self, name: &str) -> Self {
72 self.name = name.to_string();
73 self
74 }
75
76 pub fn with_version(mut self, version: &str) -> Self {
78 self.version = version.to_string();
79
80 self
81 }
82
83 pub fn with_description(mut self, description: &str) -> Self {
85 self.description = description.to_string();
86
87 self
88 }
89
90 pub fn with_help_viewer<V: 'static + HelpViewer>(mut self, help_viewer: V) -> Self {
92 self.help_viewer = Box::new(help_viewer);
93
94 self
95 }
96
97 pub fn with_error_handler(mut self, handler: ErrorHandler<Context, E>) -> Self {
100 self.error_handler = handler;
101
102 self
103 }
104
105 pub fn use_completion(mut self, value: bool) -> Self {
107 self.use_completion = value;
108
109 self
110 }
111
112 pub fn with_history_in(mut self, p: &std::path::Path) -> Self {
113 self.history = Some(p.to_path_buf());
114
115 self
116 }
117
118 pub fn with_history(self) -> Self {
119 self.with_history_in(Path::new(".polkajam-repl-history"))
120 }
121
122 pub fn add_command(mut self, command: Command<Context, E>) -> Self {
124 for i in command.aliases.iter() {
125 self.aliases.insert(i.clone(), command.name.clone());
126 }
127 self.commands.insert(command.name.clone(), command);
128 self
129 }
130
131 fn validate_arguments(
132 &self,
133 command: &str,
134 parameters: &[Parameter],
135 args: &[&str],
136 ) -> Result<BTreeMap<String, String>> {
137 if args.len() > parameters.len() {
138 return Err(Error::TooManyArguments(command.into(), parameters.len()));
139 }
140
141 let mut validated = BTreeMap::new();
142 for (index, parameter) in parameters.iter().enumerate() {
143 if index < args.len() {
144 validated.insert(parameter.name.clone(), args[index].to_string());
145 } else if parameter.required {
146 return Err(Error::MissingRequiredArgument(
147 command.into(),
148 parameter.name.clone(),
149 ));
150 } else if parameter.default.is_some() {
151 validated.insert(
152 parameter.name.clone(),
153 parameter.default.clone().unwrap().to_string(),
154 );
155 }
156 }
157 Ok(validated)
158 }
159
160 fn handle_command(&mut self, command: &str, args: &[&str]) -> core::result::Result<(), E> {
161 let canon = self.aliases.get(command).cloned().unwrap_or(command.into());
162 match self.commands.get(&canon) {
163 Some(definition) => {
164 let validated = self.validate_arguments(command, &definition.parameters, args)?;
165 match (definition.callback)(validated, self.context.as_mut().unwrap()) {
166 Ok(Some(value)) => println!("{}", value),
167 Ok(None) => (),
168 Err(error) => return Err(error),
169 };
170 }
171 None => {
172 if command == "help" {
173 self.show_help(args)?;
174 } else {
175 return Err(Error::UnknownCommand(command.to_string()).into());
176 }
177 }
178 }
179
180 Ok(())
181 }
182
183 fn show_help(&self, args: &[&str]) -> Result<()> {
184 if args.is_empty() {
185 self.help_viewer
186 .help_general(self.help_context.as_ref().unwrap())?;
187 } else {
188 let entry_opt = self
189 .help_context
190 .as_ref()
191 .unwrap()
192 .help_entries
193 .iter()
194 .find(|entry| entry.command == args[0]);
195 match entry_opt {
196 Some(entry) => {
197 self.help_viewer.help_command(entry)?;
198 }
199 None => eprintln!("Help not found for command '{}'", args[0]),
200 };
201 }
202 Ok(())
203 }
204
205 fn process_line(&mut self, line: &str) -> core::result::Result<(), E> {
206 let (command, args) = split_line(line.trim());
207 if !command.is_empty() {
208 self.handle_command(&command, &args.iter().map(|s| s.as_ref()).collect::<Vec<_>>())?;
209 }
210 Ok(())
211 }
212
213 fn construct_help_context(&mut self) {
214 let mut help_entries = self
215 .commands
216 .values()
217 .map(|definition| {
218 HelpEntry::new(
219 &definition.name,
220 &definition.aliases,
221 &definition.parameters,
222 &definition.help_summary,
223 )
224 })
225 .collect::<Vec<HelpEntry>>();
226 help_entries.sort_by_key(|d| d.command.clone());
227 self.help_context = Some(HelpContext::new(
228 &self.name,
229 &self.version,
230 &self.description,
231 help_entries,
232 ));
233 }
234
235 pub fn run(&mut self) -> Result<()> {
236 self.construct_help_context();
237 let mut editor: rustyline::Editor<Helper<Context>> = rustyline::Editor::new();
238 editor.set_helper(Some(Helper {
239 canon: self.commands.keys()
240 .map(|x| (x, x))
241 .chain(self.aliases.iter())
242 .map(|(x, y)| (x.clone(), y.clone()))
243 .collect(),
244 context: None,
245 }));
246 if let Some(history) = self.history.as_ref() {
247 let _ = editor.load_history(history);
248 }
249 println!("Welcome to {} {}", self.name, self.version);
250 let mut eof = false;
251 while !eof {
252 self.handle_line(&mut editor, &mut eof)?;
253 }
254
255 Ok(())
256 }
257
258fn handle_line(
266 &mut self,
267 editor: &mut rustyline::Editor<Helper<Context>>,
268 eof: &mut bool,
269 ) -> Result<()> {
270 let prompt = format!("{}", self.context.as_ref().unwrap().prompt());
271 editor.helper_mut().unwrap().context = self.context.take();
272 let r = editor.readline(&prompt);
273 self.context = editor.helper_mut().unwrap().context.take();
274 match r {
275 Ok(line) => {
276 editor.add_history_entry(line.clone());
277 if let Some(history) = self.history.as_ref() {
278 let _ = editor.append_history(history);
279 }
280 if let Err(error) = self.process_line(&line) {
281 (self.error_handler)(error, self)?;
282 }
283 *eof = false;
284 Ok(())
285 }
286 Err(rustyline::error::ReadlineError::Eof) => {
287 *eof = true;
288 Ok(())
289 }
290 Err(error) => {
291 eprintln!("Error reading line: {}", error);
292 *eof = false;
293 Ok(())
294 }
295 }
296 }
297}
298
299fn split_line(line: &str) -> (String, Vec<String>) {
300 let trimmed = line.trim();
301 if trimmed.is_empty() {
302 Default::default()
303 } else {
304 let r = regex::Regex::new(r#"("[^"\n]+"|[\S]+)"#).unwrap();
305 let mut args = r
306 .captures_iter(trimmed)
307 .map(|a| a[0].to_string().replace('\"', ""))
308 .collect::<Vec<String>>();
309 let command = args.remove(0);
310 if line.ends_with(' ') {
311 args.push(Default::default());
312 }
313 (command, args)
314 }
315}
316
/// Wraps `s` in double quotes when `split_line` would not otherwise read it back as a
/// single argument.
///
/// That is the case when `s` contains any whitespace (space, tab, ...) or is empty.
/// Strings containing `"` cannot be represented either way, because `split_line`
/// strips every quote character.
fn quote_if_needed(s: String) -> String {
    if s.is_empty() || s.contains(char::is_whitespace) {
        format!("\"{s}\"")
    } else {
        s
    }
}
324
/// Line-editor helper that provides tab completion for the REPL.
///
/// Hinting, highlighting and validation come from the `rustyline_derive` derives
/// (presumably their default behaviour — confirm against the rustyline_derive docs).
#[derive(Clone, Helper, Hinter, Highlighter, Validator)]
struct Helper<Context: Prompt> {
    /// Maps every command name and alias to its canonical command name.
    canon: BTreeMap<String, String>,
    /// The REPL context, present only while the REPL lends it during a line read.
    context: Option<Context>,
}
333
334impl<Context: Prompt> completion::Completer for Helper<Context> {
335 type Candidate = String;
336
337 fn complete(
338 &self,
339 line: &str,
340 _pos: usize,
341 _ctx: &rustyline::Context<'_>,
342 ) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
343 let (command, mut args) = split_line(line);
346if let Some(last_arg) = args.pop() {
348 let Some(context) = self.context.as_ref() else { return Ok((0, Vec::new())) };
349 let first_args = args.iter().map(|s| s.as_ref()).collect::<Vec<_>>();
350 let command = self.canon.get(&command).cloned().unwrap_or(command);
351 let last_cands = context.complete(&command, &first_args, &last_arg);
352let args_blob = args.iter().cloned().map(quote_if_needed).collect::<Vec<_>>().join(" ");
354 let prefix = format!("{command}{}{args_blob}", if args_blob.is_empty() { "" } else { " " });
355 let completions = last_cands.into_iter().map(|c| format!("{prefix} {c}")).collect();
356 Ok((0, completions))
357 } else {
358 let ret: Vec<Self::Candidate> = self
359 .canon
360 .keys()
361 .filter(|cmd| cmd.contains(&command))
362 .map(|s| s.to_string())
363 .collect();
364 Ok((0, ret))
365 }
366 }
367}
368
#[cfg(all(test, unix))]
mod tests {
    use crate::error::*;
    use crate::repl::{Helper, Repl};
    use crate::initialize_repl;
    use crate::{Command, Parameter};
    use clap::{crate_description, crate_name, crate_version};
    use nix::sys::wait::{waitpid, WaitStatus};
    use nix::unistd::{close, dup2, fork, pipe, ForkResult};
    use std::collections::BTreeMap;
    use std::fs::File;
    use std::io::Write;
    use std::os::unix::io::FromRawFd;
    use super::Prompt;

    /// Error handler that hands the error back, so `handle_line` returns it.
    fn test_error_handler<Context>(error: Error, _repl: &Repl<Context, Error>) -> Result<()> {
        Err(error)
    }

    /// Command callback that echoes its validated arguments.
    fn foo<T>(args: BTreeMap<String, String>, _context: &mut T) -> Result<Option<String>> {
        Ok(Some(format!("foo {:?}", args)))
    }

    /// Builds the REPL shared by most tests: one `foo` command with two required
    /// parameters, `bar` and `baz`.
    fn foo_repl() -> Result<Repl<(), Error>> {
        Ok(Repl::new(())
            .with_name("test")
            .with_version("v0.1.0")
            .with_description("Testing 1, 2, 3...")
            .with_error_handler(test_error_handler)
            .add_command(
                Command::new("foo", foo)
                    .with_parameter(Parameter::new("bar").set_required(true)?)?
                    .with_parameter(Parameter::new("baz").set_required(true)?)?
                    .with_help("Do foo when you can"),
            ))
    }

    /// Runs one `handle_line` of `repl` in a forked child whose stdin is `input`.
    ///
    /// The child exits 0 only if the result equals `expected`. The parent fails the
    /// test on any other outcome: non-zero exit, death by signal, or a failed fork.
    fn run_repl<Context: Prompt>(mut repl: Repl<Context, Error>, input: &str, expected: Result<()>) {
        let (rdr, wrtr) = pipe().unwrap();
        // SAFETY: `fork` is used to give the REPL its own stdin. The child only runs
        // the code below and always leaves via `std::process::exit`. `wrtr` is owned
        // solely by the `File` created in the parent.
        unsafe {
            match fork() {
                Ok(ForkResult::Parent { child, .. }) => {
                    // The parent only writes; release its copy of the read end.
                    close(rdr).unwrap();
                    let mut f = File::from_raw_fd(wrtr);
                    write!(f, "{}", input).unwrap();
                    // Close the write end so the child sees EOF after `input`.
                    drop(f);
                    let status = waitpid(child, None).unwrap();
                    assert_eq!(
                        status,
                        WaitStatus::Exited(child, 0),
                        "child REPL did not exit cleanly"
                    );
                }
                Ok(ForkResult::Child) => {
                    std::panic::set_hook(Box::new(|panic_info| {
                        println!("Caught panic: {:?}", panic_info);
                        if let Some(location) = panic_info.location() {
                            println!(
                                "panic occurred in file '{}' at line {}",
                                location.file(),
                                location.line(),
                            );
                        } else {
                            println!("panic occurred but can't get location information...");
                        }
                        // Exit explicitly: the forked process may contain only this
                        // thread, and letting it unwind could end the process with
                        // status 0, hiding the failure from the parent.
                        std::process::exit(101);
                    }));

                    dup2(rdr, 0).unwrap();
                    close(rdr).unwrap();
                    // Drop the inherited write end so stdin reaches EOF once the
                    // parent is done writing.
                    close(wrtr).unwrap();
                    let mut editor: rustyline::Editor<Helper<Context>> = rustyline::Editor::new();
                    let mut eof = false;
                    let result = repl.handle_line(&mut editor, &mut eof);
                    let _ = std::panic::take_hook();
                    if expected == result {
                        std::process::exit(0);
                    } else {
                        eprintln!("Expected {:?}, got {:?}", expected, result);
                        std::process::exit(1);
                    }
                }
                Err(error) => panic!("fork failed: {}", error),
            }
        }
    }

    #[test]
    fn test_initialize_sets_crate_values() -> Result<()> {
        let repl: Repl<(), Error> = initialize_repl!(());

        assert_eq!(crate_name!(), repl.name);
        assert_eq!(crate_version!(), repl.version);
        assert_eq!(crate_description!(), repl.description);

        Ok(())
    }

    #[test]
    fn test_empty_line_does_nothing() -> Result<()> {
        run_repl(foo_repl()?, "\n", Ok(()));

        Ok(())
    }

    #[test]
    fn test_missing_required_arg_fails() -> Result<()> {
        run_repl(
            foo_repl()?,
            "foo bar\n",
            Err(Error::MissingRequiredArgument("foo".into(), "baz".into())),
        );

        Ok(())
    }

    #[test]
    fn test_unknown_command_fails() -> Result<()> {
        run_repl(
            foo_repl()?,
            "bar baz\n",
            Err(Error::UnknownCommand("bar".to_string())),
        );

        Ok(())
    }

    #[test]
    fn test_no_required_after_optional() -> Result<()> {
        assert_eq!(
            Err(Error::IllegalRequiredError("bar".into())),
            Command::<(), Error>::new("foo", foo)
                .with_parameter(Parameter::new("baz").set_default("20")?)?
                .with_parameter(Parameter::new("bar").set_required(true)?)
        );

        Ok(())
    }

    #[test]
    fn test_required_cannot_be_defaulted() -> Result<()> {
        assert_eq!(
            Err(Error::IllegalDefaultError("bar".into())),
            Parameter::new("bar").set_required(true)?.set_default("foo")
        );

        Ok(())
    }

    #[test]
    fn test_string_with_spaces_for_argument() -> Result<()> {
        run_repl(foo_repl()?, "foo \"baz test 123\" foo\n", Ok(()));

        Ok(())
    }

    #[test]
    fn test_string_with_spaces_for_argument_last() -> Result<()> {
        run_repl(foo_repl()?, "foo foo \"baz test 123\"\n", Ok(()));

        Ok(())
    }
}