burn_lm_cli/
cli.rs

1use clap::Arg;
2
3use crate::commands;
4
5pub fn run(backend: &str, dtype: &str) -> anyhow::Result<()> {
6    // Define CLI
7    let cli = clap::command!()
8        // Those values are going to be parsed by the burn-lm crate.
9        //
10        // This ensures that this CLI is correctly compiled with the proper backend and dtype.
11        // The goal is also that all values are parsed using a unified CLI with seamless
12        // recompilation if the backend or dtype changes.
13        .args([
14            Arg::new("backend")
15                .short('b')
16                .long("backend")
17                .help("The backend selected."),
18            Arg::new("dtype")
19                .short('d')
20                .long("dtype")
21                .help("The element type used."),
22        ])
23        .subcommand(commands::backends::create())
24        .subcommand(commands::chat::create())
25        .subcommand(commands::delete::create())
26        .subcommand(commands::download::create())
27        .subcommand(commands::models::create())
28        .subcommand(commands::new::create())
29        .subcommand(commands::run::create())
30        .subcommand(commands::server::create())
31        .subcommand(commands::shell::create())
32        .subcommand(commands::web::create());
33
34    // Execute commands
35    let matches = cli.clone().get_matches();
36
37    if let Some(b) = matches.get_one::<String>("backend") {
38        assert_eq!(b, backend);
39    }
40
41    if let Some(d) = matches.get_one::<String>("dtype") {
42        assert_eq!(d, dtype);
43    }
44
45    if matches.subcommand_matches("backends").is_some() {
46        commands::backends::handle().map(|_| ())
47    } else if let Some(args) = matches.subcommand_matches("chat") {
48        commands::chat::handle(args, backend, dtype).map(|_| ())
49    } else if let Some(args) = matches.subcommand_matches("delete") {
50        commands::delete::handle(args).map(|_| ())
51    } else if let Some(args) = matches.subcommand_matches("download") {
52        commands::download::handle(args).map(|_| ())
53    } else if matches.subcommand_matches("models").is_some() {
54        commands::models::handle(false).map(|_| ())
55    } else if let Some(args) = matches.subcommand_matches("new") {
56        commands::new::handle(args).map(|_| ())
57    } else if let Some(args) = matches.subcommand_matches("run") {
58        commands::run::handle(args).map(|_| ())
59    } else if let Some(args) = matches.subcommand_matches("server") {
60        commands::server::handle(args, backend, dtype).map(|_| ())
61    } else if matches.subcommand_matches("shell").is_some() {
62        commands::shell::handle(backend, dtype).map(|_| ())
63    } else if let Some(args) = matches.subcommand_matches("web") {
64        commands::web::handle(args, backend, dtype).map(|_| ())
65    } else {
66        // default action is to start a shell
67        commands::shell::handle(backend, dtype).map(|_| ())
68    }
69}