use clap::{arg, ArgAction, Args, Parser, Subcommand};
// Root of the command-line interface, parsed through clap's derive API.
// `author`, `version` and `about` carry no explicit value here, so clap fills them in itself;
// `long_about = None` explicitly disables a separate long description.
// NOTE: plain `//` comments are used on purpose. clap turns `///` doc comments on derived
// items into help text, which would change the existing `--help` output.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
pub struct Cli {
// The subcommand chosen on the command line (one variant of `Commands`).
#[command(subcommand)]
pub command: Commands,
}
// Top-level subcommands of the CLI.
//
// Subcommand and flag names are derived by clap from the variant / field names
// (kebab-case), e.g. `dataset_type` becomes `--dataset-type`.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments on derived items
// into help text, which would change the existing `--help` output.
#[derive(Subcommand)]
pub enum Commands {
    // `chat`: requires `--model`/`-m` and `--msg`; the remaining flags are optional.
    Chat {
        #[arg(short, long)]
        model: String,
        // Long-only: a derived short `-m` would collide with `model`.
        #[arg(long)]
        msg: String,
        // Repeatable: every `--image`/`-i` occurrence is appended to the list.
        #[arg(short, long, action = ArgAction::Append)]
        image: Vec<String>,
        #[arg(short, long)]
        framework: Option<String>,
        #[arg(short, long)]
        adapter: Option<String>,
    },
    // `create <subcommand>`; see `CreateCommands`.
    Create {
        #[command(subcommand)]
        command: CreateCommands,
    },
    // `get <subcommand>`; see `GetCommands`.
    Get {
        #[command(subcommand)]
        command: GetCommands,
    },
    // `stop <subcommand>`; see `StopCommands`.
    Stop {
        #[command(subcommand)]
        command: StopCommands,
    },
    // `delete <subcommand>`; see `DeleteCommands`.
    Delete {
        #[command(subcommand)]
        command: DeleteCommands,
    },
    // `train <subcommand>`; see `TrainCommands`.
    Train {
        #[command(subcommand)]
        command: TrainCommands,
    },
    // `prepare`: `--dataset-type`, `--url` and `--split-ratio` are required.
    Prepare {
        #[arg(short, long)]
        dataset_type: String,
        #[arg(short, long)]
        url: String,
        #[arg(short, long)]
        split_ratio: f64,
        #[arg(short, long)]
        base_path: Option<String>,
        #[arg(short, long)]
        image_path: Option<String>,
        // Typed default, matching the `default_value_t` convention used by the other
        // numeric options in this file (same effective default as the former "8").
        #[arg(long, default_value_t = 8)]
        num_workers: usize,
    },
    // `serve`: both options have defaults.
    Serve {
        // Long-only: a derived short `-h` would collide with clap's built-in help flag.
        #[arg(long, default_value = "127.0.0.1")]
        host: String,
        #[arg(short, long, default_value_t = 8080)]
        port: u16,
    },
    // `report`: `--watch-dir` and `--save-dir` are required.
    // Uses `#[arg]` like every other variant; `#[clap]` is the deprecated spelling.
    Report {
        #[arg(long)]
        watch_dir: String,
        #[arg(long)]
        save_dir: String,
        #[arg(long, default_value_t = 2)]
        debounce_secs: u64,
    },
    // `send <BUFFER_NAME> --file <FILE> [--train <true|false>]`.
    Send {
        // Positional argument (no `#[arg]` attribute).
        buffer_name: String,
        #[arg(short, long)]
        file: String,
        // `Option<bool>`: takes an explicit boolean value; `None` when the flag is absent.
        #[arg(long)]
        train: Option<bool>,
    },
    // `trigger <BUFFER_NAME>`.
    Trigger {
        // Positional argument (no `#[arg]` attribute).
        buffer_name: String,
    },
    // `work <subcommand>`; see `WorkCommands`.
    Work {
        #[command(subcommand)]
        command: WorkCommands,
    },
    // `login`: takes no arguments.
    Login,
}
// Subcommands accepted after `create`.
#[derive(Subcommand)]
pub enum CreateCommands {
// `buffer`: the options of `ReplayBufferCommands` are flattened directly into this
// subcommand, so they appear as its own flags.
Buffer {
#[command(flatten)]
command: ReplayBufferCommands,
},
}
// Options flattened into the `Vllm` variant of `DeploymentCommands`.
//
// The switches that default to `true` (`trust_remote_code`, `enforce_eager`) are declared
// with `ArgAction::Set` and an optional value. With clap's implicit `SetTrue` action for
// `bool` fields, a `true` default can never be turned off: the flag only ever sets `true`.
// Accepted forms for those switches:
//   (omitted)                  -> true  (default)
//   --enforce-eager            -> true  (via `default_missing_value`)
//   --enforce-eager=false      -> false
//   --enforce-eager false      -> false
// Because the value is optional, a non-flag token right after the bare switch is taken
// as its value.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments on derived items
// into help text, which would change the existing `--help` output.
#[derive(Args)]
pub struct VllmOptions {
    // Required `--model`.
    #[arg(long = "model")]
    pub model: String,
    #[arg(long = "model-type")]
    pub model_type: Option<String>,
    #[arg(
        long,
        action = ArgAction::Set,
        num_args = 0..=1,
        default_value_t = true,
        default_missing_value = "true"
    )]
    pub trust_remote_code: bool,
    #[arg(long, default_value_t = 1)]
    pub tensor_parallel_size: i32,
    #[arg(long, default_value_t = 1)]
    pub max_images_per_prompt: i32,
    #[arg(long, default_value = "cuda")]
    pub device: String,
    #[arg(long, default_value_t = 8192)]
    pub max_model_len: i32,
    #[arg(long, default_value_t = 5)]
    pub max_num_seqs: i32,
    #[arg(long, default_value_t = 0.8)]
    pub gpu_memory_utilization: f32,
    #[arg(
        long,
        action = ArgAction::Set,
        num_args = 0..=1,
        default_value_t = true,
        default_missing_value = "true"
    )]
    pub enforce_eager: bool,
    // Plain presence switch: `false` unless `--enable-adapter` is given.
    #[arg(long, default_value_t = false)]
    pub enable_adapter: bool,
}
// Options flattened into the `Easyocr` variant of `DeploymentCommands`.
//
// `gpu` defaults to `true` and is declared with `ArgAction::Set` and an optional value.
// With clap's implicit `SetTrue` action for `bool` fields, a `true` default can never be
// turned off. Accepted forms:
//   (omitted)      -> true  (default)
//   --gpu          -> true  (via `default_missing_value`)
//   --gpu=false    -> false
//   --gpu false    -> false
// Because the value is optional, a non-flag token right after a bare `--gpu` is taken
// as its value.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments on derived items
// into help text, which would change the existing `--help` output.
#[derive(Args)]
pub struct EasyOcrOptions {
    #[arg(long, default_value = "cuda")]
    pub device: String,
    #[arg(
        long,
        action = ArgAction::Set,
        num_args = 0..=1,
        default_value_t = true,
        default_missing_value = "true"
    )]
    pub gpu: bool,
    // Comma-separated list, e.g. `--lang-list en,fr`; defaults to the single entry "en".
    #[arg(long = "lang-list", value_delimiter = ',', default_value = "en")]
    pub lang_list: Vec<String>,
    // Plain presence switch: `false` unless `--quantize` is given.
    #[arg(long, default_value_t = false)]
    pub quantize: bool,
}
// Options flattened into the `Doctr` variant of `DeploymentCommands`.
//
// `pretrained` defaults to `true` and is declared with `ArgAction::Set` and an optional
// value. With clap's implicit `SetTrue` action for `bool` fields, a `true` default can
// never be turned off. Accepted forms:
//   (omitted)             -> true  (default)
//   --pretrained          -> true  (via `default_missing_value`)
//   --pretrained=false    -> false
//   --pretrained false    -> false
// Because the value is optional, a non-flag token right after a bare `--pretrained` is
// taken as its value.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments on derived items
// into help text, which would change the existing `--help` output.
#[derive(Args)]
pub struct DoctrOptions {
    #[arg(long = "det-arch", default_value = "fast_base")]
    pub det_arch: String,
    #[arg(long = "reco-arch", default_value = "crnn_vgg16_bn")]
    pub reco_arch: String,
    #[arg(
        long,
        action = ArgAction::Set,
        num_args = 0..=1,
        default_value_t = true,
        default_missing_value = "true"
    )]
    pub pretrained: bool,
}
// Options flattened into the `STF` variant of `DeploymentCommands`.
// Both flags are optional on the command line because each carries a default.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Args)]
pub struct SentenceTfOptions {
    // `--model`; the long name is derived from the field name.
    #[arg(long, default_value = "clip-ViT-B-32")]
    pub model: String,

    // `--device`.
    #[arg(long, default_value = "cuda")]
    pub device: String,
}
// Options flattened into the `Litellm` variant of `DeploymentCommands`.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Args, Debug)]
pub struct LiteLLMOptions {
// `--api-key KEY=value`; each occurrence is parsed by `parse_key_val` into a
// `(key, value)` pair of strings and collected into this list.
#[arg(long = "api-key", value_parser = parse_key_val::<String, String>)]
pub api_keys: Vec<(String, String)>,
}
/// Splits a `KEY=value` argument at its first `=` and parses each half.
///
/// Everything after the first `=` belongs to the value, so values may themselves
/// contain `=`.
///
/// # Errors
///
/// Returns a human-readable message when `s` contains no `=`, or when the key or the
/// value fails to parse into its target type (the key is checked first).
fn parse_key_val<K, V>(s: &str) -> Result<(K, V), String>
where
    K: std::str::FromStr,
    V: std::str::FromStr,
    K::Err: std::fmt::Display,
    V::Err: std::fmt::Display,
{
    let (raw_key, raw_value) = match s.split_once('=') {
        Some(halves) => halves,
        None => return Err(format!("invalid KEY=value: no `=` found in `{}`", s)),
    };

    let key = raw_key
        .parse::<K>()
        .map_err(|err| format!("invalid key: {}", err))?;
    let value = raw_value
        .parse::<V>()
        .map_err(|err| format!("invalid value: {}", err))?;

    Ok((key, value))
}
// Optional arguments flattened into the `Vllm`, `Easyocr`, `Doctr` and `STF` variants
// of `DeploymentCommands`. Every field is `None` when its flag is not given.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Args)]
pub struct CommonModelArgs {
// `--vram`, kept as a free-form string.
#[arg(long)]
pub vram: Option<String>,
// `--dtype`, kept as a free-form string.
#[arg(long)]
pub dtype: Option<String>,
// `--max-pixels`, parsed as `i32`.
#[arg(long)]
pub max_pixels: Option<i32>,
}
// Deployment subcommands. Each variant flattens its option structs, so their fields
// appear directly as flags of that subcommand.
// NOTE(review): no other item in the visible part of this file references this enum;
// confirm where it is wired into the CLI.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Subcommand)]
pub enum DeploymentCommands {
// `CommonModelArgs` + `VllmOptions`.
Vllm {
#[command(flatten)]
args: CommonModelArgs,
#[command(flatten)]
options: VllmOptions,
},
// `CommonModelArgs` + `EasyOcrOptions`.
Easyocr {
#[command(flatten)]
args: CommonModelArgs,
#[command(flatten)]
options: EasyOcrOptions,
},
// `CommonModelArgs` + `DoctrOptions`.
Doctr {
#[command(flatten)]
args: CommonModelArgs,
#[command(flatten)]
options: DoctrOptions,
},
// `CommonModelArgs` + `SentenceTfOptions`.
// NOTE(review): the all-caps variant name departs from Rust's `UpperCamelCase`
// convention; renaming it would change the public enum, so it is left untouched.
STF {
#[command(flatten)]
args: CommonModelArgs,
#[command(flatten)]
options: SentenceTfOptions,
},
// Only `LiteLLMOptions`; unlike the other variants this one has no `CommonModelArgs`.
Litellm {
#[command(flatten)]
options: LiteLLMOptions,
},
}
// Subcommands accepted after `get`.
// Fields without an `#[arg]` attribute are positional arguments.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Subcommand)]
pub enum GetCommands {
// Optional positional `id`; `None` when omitted.
Deployments {
id: Option<String>,
},
// Optional positional `id`; `None` when omitted.
Trainings {
id: Option<String>,
},
// Optional positional `name`; `None` when omitted.
Buffers {
name: Option<String>,
},
// The remaining subcommands take no arguments.
Models {},
Datasets {},
Adapters {},
}
// Subcommands accepted after `delete`.
// Every variant takes one required positional argument (no `#[arg]` attribute).
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Subcommand)]
pub enum DeleteCommands {
// Required positional `id`.
Deployment {
id: String,
},
// Required positional `name`.
Buffer {
name: String,
},
// Required positional `name`.
Adapter {
name: String,
},
}
// Subcommands accepted after `stop`.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Subcommand)]
pub enum StopCommands {
// Required `--id` / `-i`. Unlike the `delete` subcommands, the identifier is passed
// as a flag rather than as a positional argument.
Training {
#[arg(short, long)]
id: String,
},
}
// Training options flattened into `MSSwiftCommands` (the `train swift` subcommand).
// All flags are long-only, with names derived from the field names in kebab-case.
// Only `--dataset` is required; every other field either has a default or is an
// `Option` that stays `None` when its flag is absent.
// NOTE(review): `MSSwiftBufferArgs` in this file repeats these fields and defaults
// (minus `dataset`); keep the two in sync when changing either.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Args)]
pub struct MSSwiftArgs {
#[arg(long, default_value = "Qwen/Qwen2-VL-7B-Instruct")]
pub model: String,
#[arg(long, default_value = "qwen2-vl-7b-instruct")]
pub model_type: String,
#[arg(long, default_value = "lora")]
pub train_type: String,
#[arg(long, default_value = "zero3")]
pub deepspeed: String,
#[arg(long, default_value = "bfloat16")]
pub torch_dtype: String,
#[arg(long, default_value_t = 8192)]
pub max_length: i32,
// Required: no default and not an `Option`.
#[arg(long)]
pub dataset: String,
#[arg(long, default_value_t = 0.90)]
pub val_split_ratio: f32,
#[arg(long, default_value_t = 3)]
pub num_train_epochs: i32,
#[arg(long, default_value = "epoch")]
pub eval_strategy: String,
#[arg(long, default_value = "epoch")]
pub save_strategy: String,
#[arg(long, default_value_t = 3)]
pub save_total_limit: i32,
#[arg(long)]
pub lora_rank: Option<i32>,
#[arg(long)]
pub lora_alpha: Option<i32>,
#[arg(long, default_value_t = 28)]
pub size_factor: i32,
#[arg(long, default_value_t = 802816)]
pub max_pixels: i32,
// Presence switch: `false` unless `--freeze-vit` is given.
#[arg(long, default_value_t = false)]
pub freeze_vit: bool,
#[arg(long)]
pub rlhf_type: Option<String>,
#[arg(long, default_value_t = 16)]
pub gradient_accumulation_steps_total: i32,
#[arg(long)]
pub learning_rate: Option<f32>,
#[arg(long)]
pub save_steps: Option<i32>,
}
// Argument bundle carried by `TrainCommands::Swift`.
// Both structs are flattened, so their flags sit side by side on the same subcommand.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Args)]
pub struct MSSwiftCommands {
// Flags defined by `TrainArgs`.
#[command(flatten)]
pub train_args: TrainArgs,
// Flags defined by `MSSwiftArgs`.
#[command(flatten)]
pub ms_swift_args: MSSwiftArgs,
}
// Optional, long-only flags flattened into `MSSwiftCommands`.
// Every field is an `Option` and is `None` when its flag is not given.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Args)]
pub struct TrainArgs {
#[arg(long)]
pub name: Option<String>,
#[arg(long)]
pub namespace: Option<String>,
#[arg(long)]
pub vram: Option<String>,
// `Option<Vec<String>>`: distinguishes "flag never given" (`None`) from a list.
#[arg(long)]
pub accelerators: Option<Vec<String>>,
#[arg(long)]
pub cpu_request: Option<String>,
// `Option<bool>`: takes an explicit boolean value rather than acting as a bare switch.
#[arg(long)]
pub trust_remote_code: Option<bool>,
#[arg(long)]
pub adapter: Option<String>,
#[arg(long)]
pub buffer: Option<String>,
// `Option<bool>`: takes an explicit boolean value rather than acting as a bare switch.
#[arg(long)]
pub resume: Option<bool>,
#[arg(long)]
pub queue: Option<String>,
#[arg(long)]
pub platform: Option<String>,
}
// Subcommands accepted after `train`.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Subcommand)]
pub enum TrainCommands {
// `swift`: takes the combined flags of `MSSwiftCommands`.
Swift(MSSwiftCommands),
}
// Long-only training options with the same fields and defaults as `MSSwiftArgs`,
// except that there is no `dataset` field. Nothing here is required: each field has a
// default or is an `Option` that stays `None` when its flag is absent.
// NOTE(review): no other item in the visible part of this file references this struct;
// confirm where it is used, and keep it in sync with `MSSwiftArgs`.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Args)]
pub struct MSSwiftBufferArgs {
#[arg(long, default_value = "Qwen/Qwen2-VL-7B-Instruct")]
pub model: String,
#[arg(long, default_value = "qwen2-vl-7b-instruct")]
pub model_type: String,
#[arg(long, default_value = "lora")]
pub train_type: String,
#[arg(long, default_value = "zero3")]
pub deepspeed: String,
#[arg(long, default_value = "bfloat16")]
pub torch_dtype: String,
#[arg(long, default_value_t = 8192)]
pub max_length: i32,
#[arg(long, default_value_t = 0.90)]
pub val_split_ratio: f32,
#[arg(long, default_value_t = 3)]
pub num_train_epochs: i32,
#[arg(long, default_value = "epoch")]
pub eval_strategy: String,
#[arg(long, default_value = "epoch")]
pub save_strategy: String,
#[arg(long, default_value_t = 3)]
pub save_total_limit: i32,
#[arg(long)]
pub lora_rank: Option<i32>,
#[arg(long)]
pub lora_alpha: Option<i32>,
#[arg(long, default_value_t = 28)]
pub size_factor: i32,
#[arg(long, default_value_t = 802816)]
pub max_pixels: i32,
// Presence switch: `false` unless `--freeze-vit` is given.
#[arg(long, default_value_t = false)]
pub freeze_vit: bool,
#[arg(long)]
pub rlhf_type: Option<String>,
#[arg(long, default_value_t = 16)]
pub gradient_accumulation_steps_total: i32,
#[arg(long)]
pub learning_rate: Option<f32>,
#[arg(long)]
pub save_steps: Option<i32>,
}
// Flags of the `create buffer` subcommand (flattened into `CreateCommands::Buffer`).
// All flags are long-only. `--name` and `--image` are required; the rest have defaults
// or are `Option`s that stay `None` when absent.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Args)]
pub struct ReplayBufferCommands {
// Required.
#[arg(long)]
pub name: String,
#[arg(long)]
pub namespace: Option<String>,
#[arg(long)]
pub train_every: Option<i32>,
#[arg(long, default_value_t = 100)]
pub sample_n: i32,
// Free-form string; the accepted values are not validated at this level.
#[arg(long, default_value = "Random")]
pub sample_strategy: String,
// Required.
#[arg(long)]
pub image: String,
#[arg(long)]
pub command: Option<String>,
// `Option<Vec<String>>`: distinguishes "flag never given" (`None`) from a list.
#[arg(long)]
pub accelerators: Option<Vec<String>>,
#[arg(long, default_value_t = 1)]
pub num_epochs: i32,
}
// Subcommands accepted after `work`.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Subcommand)]
pub enum WorkCommands {
// Subcommand with a single long-only flag, `--k8s-namespace`, defaulting to "default".
ReportTrainings {
#[arg(long, default_value = "default")]
k8s_namespace: String,
},
}