use anyhow::Result;
use clap::{CommandFactory, Parser, Subcommand};
use colored::*;
use std::path::PathBuf;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
mod commands;
mod config;
mod utils;
use commands::*;
// Top-level command-line definition for the `torsh` binary.
//
// Every flag is declared `global`, so it is accepted both before and after the
// subcommand. A plain `//` comment is used for this header on purpose: the
// command description shown by `--help` comes from the explicit
// `about`/`long_about` attributes below. The `///` comments on the fields, in
// contrast, are picked up by clap's derive macro as per-flag help text.
#[derive(Parser)]
#[command(
    name = "torsh",
    author,
    version,
    about = "Command-line tools for the ToRSh deep learning framework",
    long_about = r#"
ToRSh CLI provides a comprehensive suite of command-line tools for machine learning workflows:
• Model operations: convert, optimize, quantize, and inspect models
• Training utilities: start training, resume from checkpoints, distributed training
• Dataset tools: download, preprocess, validate, and analyze datasets
• Benchmarking: performance testing and profiling
• Hub integration: download and upload models to the ToRSh Hub
• Development tools: code generation, testing, and debugging
Examples:
torsh model convert --input model.pth --output model.torsh --format torsh
torsh train --config config.yaml --resume checkpoint.pth
torsh benchmark --model resnet50 --batch-size 32
torsh hub download microsoft/resnet50 --cache-dir ./models
"#
)]
struct Cli {
    /// Enable verbose (debug-level) logging
    #[arg(short, long, global = true)]
    verbose: bool,

    /// Only log errors (takes precedence over --verbose)
    #[arg(short, long, global = true)]
    quiet: bool,

    /// Path to a configuration file
    #[arg(short, long, global = true)]
    config: Option<PathBuf>,

    /// Output format for command results
    #[arg(long, global = true, default_value = "table")]
    output_format: String,

    /// Disable colored output
    #[arg(long, global = true)]
    no_color: bool,

    // The subcommand selected by the user; `main` dispatches on it.
    #[command(subcommand)]
    command: Commands,
}
// Top-level subcommands of the `torsh` binary; `main` matches on this enum and
// forwards each payload to the corresponding `execute` function.
//
// Plain `//` comments are used deliberately: `///` doc comments on clap-derived
// variants would become the subcommand descriptions shown in `--help`.
//
// Variants marked `#[command(subcommand)]` wrap a nested subcommand enum; the
// payload types are not defined in this file (presumably they come from the
// `commands` module via `use commands::*` — confirm there).
#[derive(Subcommand)]
enum Commands {
// Model operations (the long help lists: convert, optimize, quantize, inspect).
#[command(subcommand)]
Model(ModelCommands),
// Training utilities (the long help lists: start, resume from checkpoints, distributed).
#[command(subcommand)]
Train(TrainCommands),
// Dataset tools (the long help lists: download, preprocess, validate, analyze).
#[command(subcommand)]
Dataset(DatasetCommands),
// Benchmarking (the long help lists: performance testing and profiling).
#[command(subcommand)]
Benchmark(BenchmarkCommands),
// Hub integration (the long help lists: download and upload models).
#[command(subcommand)]
Hub(HubCommands),
// Development tools (the long help lists: code generation, testing, debugging).
#[command(subcommand)]
Dev(DevCommands),
// Handled by `info::execute`; its exact purpose is not visible from this file.
Info(InfoCommand),
// Writes a completion script for the given shell to stdout
// (see `generate_completions`).
Completions {
// `value_enum` makes clap accept the shell by its enum value name.
#[arg(value_enum)]
shell: clap_complete::Shell,
},
// Handled by `init::execute`; its exact purpose is not visible from this file.
Init(InitCommand),
// Handled by `update::execute`; its exact purpose is not visible from this file.
Update(UpdateCommand),
}
#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
init_logging(cli.verbose, cli.quiet)?;
if cli.no_color || !console::Term::stdout().features().colors_supported() {
colored::control::set_override(false);
}
let config = config::load_config(cli.config.as_deref()).await?;
info!("ToRSh CLI starting");
match cli.command {
Commands::Model(cmd) => model::execute(cmd, &config, &cli.output_format).await,
Commands::Train(cmd) => train::execute(cmd, &config, &cli.output_format).await,
Commands::Dataset(cmd) => dataset::execute(cmd, &config, &cli.output_format).await,
Commands::Benchmark(cmd) => benchmark::execute(cmd, &config, &cli.output_format).await,
Commands::Hub(cmd) => hub::execute(cmd, &config, &cli.output_format).await,
Commands::Dev(cmd) => dev::execute(cmd, &config, &cli.output_format).await,
Commands::Info(cmd) => info::execute(cmd, &config, &cli.output_format).await,
Commands::Completions { shell } => {
generate_completions(shell);
Ok(())
}
Commands::Init(cmd) => init::execute(cmd, &config, &cli.output_format).await,
Commands::Update(cmd) => update::execute(cmd, &config, &cli.output_format).await,
}
}
fn init_logging(verbose: bool, quiet: bool) -> Result<()> {
let level = if quiet {
tracing::Level::ERROR
} else if verbose {
tracing::Level::DEBUG
} else {
tracing::Level::INFO
};
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| match level {
tracing::Level::ERROR => tracing_subscriber::EnvFilter::new("error"),
tracing::Level::WARN => tracing_subscriber::EnvFilter::new("warn"),
tracing::Level::INFO => tracing_subscriber::EnvFilter::new("info"),
tracing::Level::DEBUG => tracing_subscriber::EnvFilter::new("debug"),
tracing::Level::TRACE => tracing_subscriber::EnvFilter::new("trace"),
}),
)
.with(tracing_subscriber::fmt::layer().with_target(false))
.init();
Ok(())
}
/// Writes a completion script for `shell` to stdout.
///
/// The script is generated from the clap definition of [`Cli`] and uses the
/// command's own name as the binary name.
fn generate_completions(shell: clap_complete::Shell) {
    let mut cmd = Cli::command();
    // Copy the name out first: `generate` needs `cmd` mutably borrowed.
    let bin_name = cmd.get_name().to_string();
    let mut out = std::io::stdout();
    clap_complete::generate(shell, &mut cmd, bin_name, &mut out);
}
/// Prints the ToRSh start-up banner to stdout.
///
/// Output, in order: the ASCII-art logo (bold bright cyan), a tagline (bold
/// bright white), a line with the crate version captured at compile time and a
/// fixed `dev` build label (bright black), and a trailing blank line.
pub fn display_banner() {
    const TAGLINE: &str = "ToRSh CLI - Advanced Deep Learning Framework Tools";

    // The raw string starts at column 0 on purpose: its bytes are printed verbatim.
    let logo = r#"
______ _____ _____ _
|__ _| | __ \ / ____| |
| | ___ _ _| |__) | (___ | |__
| |/ _ \| '__| _ / \___ \| '_ \
_| | (_) | | | | \ \ ____) | | | |
|_| \___/|_| |_| \_\_____/|_| |_|
"#;
    let version_line = format!("Version: {} | Build: {}", env!("CARGO_PKG_VERSION"), "dev");

    println!("{}", logo.bright_cyan().bold());
    println!("{}", TAGLINE.bright_white().bold());
    println!("{}", version_line.bright_black());
    println!();
}
#[cfg(test)]
mod tests {
    /// The crate version embedded at compile time must not be blank.
    #[test]
    fn test_version_check() {
        let version: &str = env!("CARGO_PKG_VERSION");
        assert!(!version.is_empty());
    }
}