use std::path::PathBuf;
use clap::{Args, Parser, Subcommand, ValueEnum};
use serde::{Deserialize, Serialize};
// Top-level command-line interface for `grammstein`.
//
// NOTE(review): no argument in this file carries a `///` doc comment, which clap's derive turns
// into help text, so `--help` currently shows no per-argument descriptions.
#[derive(Debug, Parser)]
#[command(name = "grammstein")]
#[command(author = "Dylon Edwards")]
#[command(version)]
#[command(about = "Language model training and experimentation CLI")]
#[command(
    long_about = "A unified CLI for training, evaluating, and querying N-gram and hybrid \
    language models that integrate with lling-llang WFST text correction."
)]
#[command(propagate_version = true)]
pub struct Cli {
    // The selected top-level subcommand.
    #[command(subcommand)]
    pub command: Commands,

    // `-v/--verbose`. `global = true` lets it also appear after a subcommand.
    // Declared as conflicting with `--quiet`: asking for both is contradictory, so clap reports
    // an error instead of leaving downstream code to pick one silently.
    // NOTE(review): confirm clap also catches `-v` and `-q` given at different subcommand levels.
    #[arg(short = 'v', long, global = true, conflicts_with = "quiet")]
    pub verbose: bool,

    // `-q/--quiet`, also global.
    #[arg(short = 'q', long, global = true)]
    pub quiet: bool,
}
// Top-level subcommands of `grammstein`.
// Every variant except `Repl` delegates to a nested subcommand enum; `Repl` takes plain args.
#[derive(Debug, Subcommand)]
pub enum Commands {
    #[command(subcommand)]
    Train(TrainCommands),

    #[command(subcommand)]
    Eval(EvalCommands),

    #[command(subcommand)]
    Query(QueryCommands),

    #[command(subcommand)]
    Models(ModelsCommands),

    #[command(subcommand)]
    Corpus(CorpusCommands),

    #[command(subcommand)]
    Convert(ConvertCommands),

    Repl(ReplArgs),
}
// Subcommands under `train`.
// `ImportGoogleBooks` only exists when the crate is built with the `google-books` feature.
#[derive(Debug, Subcommand)]
pub enum TrainCommands {
    Ngram(TrainNgramArgs),
    Embedding(TrainEmbeddingArgs),
    Hybrid(TrainHybridArgs),

    #[cfg(feature = "google-books")]
    ImportGoogleBooks(ImportGoogleBooksArgs),
}
// Arguments for `train ngram`: positional CORPUS and OUTPUT plus training options.
//
// NOTE(review): `-o` here means `--order`, while the eval subcommands use `-o` for `--output`.
// Changing it would break existing invocations, so it is only flagged.
#[derive(Debug, Args)]
pub struct TrainNgramArgs {
    // Kept as a `String` rather than a `PathBuf`, presumably so it can name a non-file
    // source; TODO confirm against the caller.
    #[arg(value_name = "CORPUS")]
    pub corpus: String,

    #[arg(value_name = "OUTPUT")]
    pub output: PathBuf,

    // `-o/--order`. An order of 0 cannot describe any n-gram, so clap rejects it at parse time
    // instead of letting it reach the trainer.
    #[arg(
        short = 'o',
        long,
        default_value_t = 5,
        value_parser = clap::builder::RangedU64ValueParser::<usize>::new().range(1..)
    )]
    pub order: usize,

    // `-m/--min-count`.
    #[arg(short = 'm', long, default_value_t = 2)]
    pub min_count: u64,

    // `-b/--batch-size`. 0 is rejected at parse time.
    // NOTE(review): this assumes the trainer needs at least one item per batch.
    #[arg(
        short = 'b',
        long,
        default_value_t = 10_000,
        value_parser = clap::builder::RangedU64ValueParser::<usize>::new().range(1..)
    )]
    pub batch_size: usize,

    #[arg(short = 'f', long, value_enum, default_value = "plaintext")]
    pub format: CorpusFormat,

    #[arg(long)]
    pub lowercase: bool,

    // Uses the upper-case `-L` short flag, as in the other commands that take `--language`.
    #[arg(short = 'L', long)]
    pub language: Option<String>,

    #[arg(long)]
    pub detect_language: bool,

    // Shared checkpoint and resource options, flattened into this command's flags.
    #[command(flatten)]
    pub checkpoint: CheckpointArgs,

    #[command(flatten)]
    pub resources: ResourceArgs,

    #[arg(long)]
    pub auto_clean: bool,
}
// Arguments for `train embedding`: positional CORPUS and OUTPUT plus embedding hyperparameters.
// Short flags are spelled out explicitly. Each matches the first letter clap would infer from
// the field name.
#[derive(Debug, Args)]
pub struct TrainEmbeddingArgs {
    #[arg(value_name = "CORPUS")]
    pub corpus: String,

    #[arg(value_name = "OUTPUT")]
    pub output: PathBuf,

    #[arg(long, short = 'd', default_value_t = 100)]
    pub dim: usize,

    #[arg(long, short = 'w', default_value_t = 5)]
    pub window: usize,

    #[arg(long, short = 'm', default_value_t = 5)]
    pub min_count: u64,

    #[arg(long, short = 'n', default_value_t = 5)]
    pub neg_samples: usize,

    #[arg(long, short = 'e', default_value_t = 5)]
    pub epochs: u32,

    // Lower-case `-l`; the upper-case `-L` below is `--language`.
    #[arg(long, short = 'l', default_value_t = 0.025)]
    pub learning_rate: f64,

    #[arg(long, short = 'f', value_enum, default_value = "plaintext")]
    pub format: CorpusFormat,

    #[arg(long, short = 'L')]
    pub language: Option<String>,

    #[arg(long)]
    pub detect_language: bool,

    #[arg(long)]
    pub vocab_first: bool,

    // Shared checkpoint and resource options, flattened into this command's flags.
    #[command(flatten)]
    pub checkpoint: CheckpointArgs,

    #[command(flatten)]
    pub resources: ResourceArgs,

    #[arg(long)]
    pub auto_clean: bool,
}
// Arguments for `train hybrid`: three positional paths (n-gram model, embedding model, output)
// plus interpolation settings.
#[derive(Debug, Args)]
pub struct TrainHybridArgs {
    #[arg(value_name = "NGRAM_MODEL")]
    pub ngram_model: PathBuf,

    #[arg(value_name = "EMBEDDING_MODEL")]
    pub embedding_model: PathBuf,

    #[arg(value_name = "OUTPUT")]
    pub output: PathBuf,

    #[arg(long, short = 's', value_enum, default_value = "linear")]
    pub strategy: InterpolationStrategy,

    // NOTE(review): no range check. If this is an interpolation weight it presumably belongs in
    // [0, 1]; confirm and consider a bounded value parser.
    #[arg(long, short = 'a', default_value_t = 0.8)]
    pub alpha: f64,

    #[arg(long, default_value_t = 50_000)]
    pub cache_size: usize,
}
// Arguments for `train import-google-books` (only with the `google-books` feature).
// OUTPUT is the only positional argument; everything else is a flag.
#[cfg(feature = "google-books")]
#[derive(Debug, Args)]
pub struct ImportGoogleBooksArgs {
    #[arg(value_name = "OUTPUT")]
    pub output: PathBuf,

    #[arg(short = 'L', long, default_value = "en")]
    pub language: String,

    // The Google Books Ngram datasets provide 1- through 5-grams, so orders outside 1..=5
    // (including 0) are rejected at parse time.
    // NOTE(review): clap does not check `min_order <= max_order`; confirm the importer does.
    #[arg(long, default_value_t = 1, value_parser = clap::value_parser!(u8).range(1..=5))]
    pub min_order: u8,

    #[arg(long, default_value_t = 5, value_parser = clap::value_parser!(u8).range(1..=5))]
    pub max_order: u8,

    #[arg(short = 'm', long, default_value_t = 40)]
    pub min_count: u64,

    #[arg(long)]
    pub min_year: Option<u16>,

    #[arg(long)]
    pub max_year: Option<u16>,

    #[arg(long, value_name = "DIR")]
    pub local_files: Option<PathBuf>,

    #[arg(long, default_value_t = 4)]
    pub parallel: usize,

    #[arg(long)]
    pub skip_pos_tags: bool,

    #[arg(long)]
    pub no_resume: bool,

    #[arg(long)]
    pub keep_shards: bool,

    #[arg(long)]
    pub cache_files: bool,

    #[arg(long, value_enum, default_value = "enabled")]
    pub sharding: ShardingModeArg,

    #[arg(long)]
    pub prefix: Option<String>,

    #[arg(long, value_name = "ENTRIES")]
    pub lockfree_flush_threshold: Option<u64>,

    #[arg(long, default_value_t = 500_000, value_name = "ENTRIES")]
    pub tx_chunk_size: u64,

    #[arg(long, default_value_t = 10, value_name = "GIB")]
    pub overlay_budget_gib: u64,

    #[command(flatten)]
    pub resources: ResourceArgs,
}
// Subcommands under `eval`.
#[derive(Debug, Subcommand)]
pub enum EvalCommands {
    Perplexity(EvalPerplexityArgs),
    Compare(EvalCompareArgs),
}
// Arguments for `eval perplexity`: positional MODEL and TEST_CORPUS.
#[derive(Debug, Args)]
pub struct EvalPerplexityArgs {
    #[arg(value_name = "MODEL")]
    pub model: PathBuf,

    #[arg(value_name = "TEST_CORPUS")]
    pub test_corpus: String,

    #[arg(long, short = 'f', value_enum, default_value = "plaintext")]
    pub format: CorpusFormat,

    #[arg(long)]
    pub per_sentence: bool,

    // Optional; when absent the caller decides where results go (not visible here).
    #[arg(long, short = 'o')]
    pub output: Option<PathBuf>,
}
// Arguments for `eval compare`: one TEST_CORPUS followed by two or more MODEL paths.
#[derive(Debug, Args)]
pub struct EvalCompareArgs {
    #[arg(value_name = "TEST_CORPUS")]
    pub test_corpus: String,

    // `num_args = 2..` makes clap demand at least two models once the argument is present.
    #[arg(required = true, num_args = 2.., value_name = "MODEL")]
    pub models: Vec<PathBuf>,

    #[arg(long, short = 'f', value_enum, default_value = "plaintext")]
    pub format: CorpusFormat,

    #[arg(long, short = 'o')]
    pub output: Option<PathBuf>,
}
// Subcommands under `query`.
#[derive(Debug, Subcommand)]
pub enum QueryCommands {
    Score(QueryScoreArgs),
    Similar(QuerySimilarArgs),
    Completions(QueryCompletionsArgs),
}
// Arguments for `query score`: a MODEL path followed by any number of TOKENS.
//
// NOTE(review): unlike `query completions`, TOKENS is not `required`, so zero tokens parses
// successfully; confirm the handler copes with an empty list.
#[derive(Debug, Args)]
pub struct QueryScoreArgs {
    #[arg(value_name = "MODEL")]
    pub model: PathBuf,

    #[arg(value_name = "TOKENS")]
    pub tokens: Vec<String>,

    #[arg(long)]
    pub sentence: bool,

    #[arg(long)]
    pub continuation: bool,

    #[arg(long, short = 'j')]
    pub json: bool,
}
// Arguments for `query similar`: positional MODEL and WORD, plus `-n/--top` (default 10).
#[derive(Debug, Args)]
pub struct QuerySimilarArgs {
    #[arg(value_name = "MODEL")]
    pub model: PathBuf,

    #[arg(value_name = "WORD")]
    pub word: String,

    #[arg(long, short = 'n', default_value_t = 10)]
    pub top: usize,

    #[arg(long, short = 'j')]
    pub json: bool,
}
// Arguments for `query completions`: a MODEL path and at least one CONTEXT token.
#[derive(Debug, Args)]
pub struct QueryCompletionsArgs {
    #[arg(value_name = "MODEL")]
    pub model: PathBuf,

    // `required = true` means at least one CONTEXT value must be supplied.
    #[arg(required = true, value_name = "CONTEXT")]
    pub context: Vec<String>,

    #[arg(long, short = 'n', default_value_t = 10)]
    pub top: usize,

    #[arg(long, short = 'j')]
    pub json: bool,
}
// Subcommands under `models`.
#[derive(Debug, Subcommand)]
pub enum ModelsCommands {
    List(ModelsListArgs),
    Info(ModelsInfoArgs),
}
// Arguments for `models list`. All are flags; `--models-dir` defaults to `./models`, which is
// resolved relative to the current working directory.
#[derive(Debug, Args)]
pub struct ModelsListArgs {
    #[arg(long, short = 'L')]
    pub language: Option<String>,

    #[arg(long, value_enum, default_value = "table")]
    pub format: OutputFormat,

    #[arg(long, default_value = "./models")]
    pub models_dir: PathBuf,
}
// Arguments for `models info`: a single positional MODEL path.
#[derive(Debug, Args)]
pub struct ModelsInfoArgs {
    #[arg(value_name = "MODEL")]
    pub model: PathBuf,

    #[arg(long, short = 'j')]
    pub json: bool,
}
// Subcommands under `corpus`.
#[derive(Debug, Subcommand)]
pub enum CorpusCommands {
    Stats(CorpusStatsArgs),
    Sample(CorpusSampleArgs),
    Download(CorpusDownloadArgs),
    Detect(CorpusDetectArgs),
    List(CorpusListArgs),
    Clean(CorpusCleanArgs),
}
// Arguments for `corpus stats`: a positional CORPUS and its format.
#[derive(Debug, Args)]
pub struct CorpusStatsArgs {
    #[arg(value_name = "CORPUS")]
    pub corpus: String,

    #[arg(long, short = 'f', value_enum, default_value = "plaintext")]
    pub format: CorpusFormat,
}
// Arguments for `corpus sample`: a positional CORPUS plus `-n/--count` (default 10).
#[derive(Debug, Args)]
pub struct CorpusSampleArgs {
    #[arg(value_name = "CORPUS")]
    pub corpus: String,

    #[arg(long, short = 'n', default_value_t = 10)]
    pub count: usize,

    #[arg(long, short = 'f', value_enum, default_value = "plaintext")]
    pub format: CorpusFormat,

    // Optional; presumably seeds the sampler for reproducible output. TODO confirm.
    #[arg(long)]
    pub seed: Option<u64>,
}
// Arguments for `corpus download`: a positional LANGUAGE plus source and output options.
#[derive(Debug, Args)]
pub struct CorpusDownloadArgs {
    #[arg(value_name = "LANGUAGE")]
    pub language: String,

    #[arg(long, short = 's', value_enum, default_value = "wikipedia")]
    pub source: CorpusSource,

    #[arg(long, short = 'o')]
    pub output: Option<PathBuf>,

    #[arg(long)]
    pub sample: bool,

    #[arg(long)]
    pub resume: bool,
}
// Arguments for `corpus detect`: a positional CORPUS and its format.
#[derive(Debug, Args)]
pub struct CorpusDetectArgs {
    #[arg(value_name = "CORPUS")]
    pub corpus: String,

    #[arg(long, short = 'f', value_enum, default_value = "plaintext")]
    pub format: CorpusFormat,
}
// Arguments for `corpus list`.
//
// NOTE(review): `verbose` has the same id, short (`-v`) and long (`--verbose`) as the global
// flag on `Cli`. clap therefore skips propagating the global flag into this subcommand, and this
// local flag takes its place. The field is public, so it is left as-is; consider removing it
// and reading `Cli::verbose` instead.
#[derive(Debug, Args)]
pub struct CorpusListArgs {
    #[arg(long, short = 'v')]
    pub verbose: bool,

    #[arg(long, value_enum, default_value = "table")]
    pub format: OutputFormat,
}
// Arguments for `corpus clean`. All are flags; `source` and `older_than` are optional filters.
#[derive(Debug, Args)]
pub struct CorpusCleanArgs {
    #[arg(long)]
    pub dry_run: bool,

    #[arg(long, short = 's', value_enum)]
    pub source: Option<CorpusSource>,

    // NOTE(review): the unit of this value is not visible here (days?); confirm with the caller.
    #[arg(long)]
    pub older_than: Option<u32>,

    #[arg(short = 'f', long)]
    pub force: bool,

    #[arg(long)]
    pub all: bool,
}
// Subcommands under `convert`. `ToPathmap` and `ExtractDict` are only compiled with the
// `google-books` feature.
#[derive(Debug, Subcommand)]
pub enum ConvertCommands {
    ToStatic(ConvertToStaticArgs),

    #[cfg(feature = "google-books")]
    ToPathmap(ConvertToPathmapArgs),

    #[cfg(feature = "google-books")]
    ExtractDict(ExtractDictArgs),

    Info(ConvertInfoArgs),
}
// Arguments for `convert to-static`: positional INPUT and OUTPUT paths.
#[derive(Debug, Args)]
pub struct ConvertToStaticArgs {
    #[arg(value_name = "INPUT")]
    pub input: PathBuf,

    #[arg(value_name = "OUTPUT")]
    pub output: PathBuf,
}
// Arguments for `convert info`: a single positional MODEL path.
#[derive(Debug, Args)]
pub struct ConvertInfoArgs {
    #[arg(value_name = "MODEL")]
    pub model: PathBuf,
}
// Arguments for `convert to-pathmap` (only with the `google-books` feature): positional INPUT
// and OUTPUT, plus an optional `--verify` switch.
#[cfg(feature = "google-books")]
#[derive(Debug, Args)]
pub struct ConvertToPathmapArgs {
    #[arg(value_name = "INPUT")]
    pub input: PathBuf,

    #[arg(value_name = "OUTPUT")]
    pub output: PathBuf,

    #[arg(long)]
    pub verify: bool,
}
// Arguments for `convert extract-dict` (only with the `google-books` feature).
#[cfg(feature = "google-books")]
#[derive(Debug, Args)]
pub struct ExtractDictArgs {
    #[arg(value_name = "MODEL")]
    pub model: PathBuf,

    #[arg(value_name = "OUTPUT")]
    pub output: PathBuf,

    #[arg(long, short = 'm', default_value_t = 100)]
    pub min_count: u64,

    #[arg(long)]
    pub unigrams_only: bool,
}
// Arguments for the `repl` subcommand.
#[derive(Debug, Args)]
pub struct ReplArgs {
    // Optional positional model path.
    #[arg(value_name = "MODEL")]
    pub model: Option<PathBuf>,

    // Neither clap nor the OS expands `~`. Only the shell does, and it never sees this default.
    // So the old default resolved to the relative path `./~/.grammstein_history`.
    // The value parser below expands a leading `~` component. clap runs default values through
    // the value parser too, so the default and user-supplied values are both expanded.
    // Help output still shows the unexpanded default string.
    #[arg(
        long,
        default_value = "~/.grammstein_history",
        value_parser = clap::builder::TypedValueParser::map(
            clap::builder::PathBufValueParser::new(),
            expand_home,
        )
    )]
    pub history: PathBuf,
}

// Replaces a leading `~` path component with the user's home directory.
//
// The home directory is taken from `HOME`, or from `USERPROFILE` as a Windows fallback.
// Paths that do not start with a bare `~` component are returned unchanged; this includes
// `~user/...`, which is not supported. If no home directory is known, the path is also
// returned unchanged.
fn expand_home(path: PathBuf) -> PathBuf {
    let home = match std::env::var_os("HOME").or_else(|| std::env::var_os("USERPROFILE")) {
        Some(dir) if !dir.is_empty() => PathBuf::from(dir),
        _ => return path,
    };
    match path.strip_prefix("~") {
        // A bare `~`: return the home directory itself. `join("")` would add a trailing separator.
        Ok(rest) if rest.as_os_str().is_empty() => home,
        Ok(rest) => home.join(rest),
        Err(_) => path,
    }
}
// Checkpoint options shared by the training commands via `#[command(flatten)]`.
// All are long-only flags.
#[derive(Debug, Args)]
pub struct CheckpointArgs {
    #[arg(long)]
    pub checkpoint: Option<PathBuf>,

    // Kept as a `String`, not a path; the accepted forms are decided by the caller (not visible here).
    #[arg(long)]
    pub resume: Option<String>,

    #[arg(long, default_value_t = 1_000_000)]
    pub checkpoint_interval: u64,

    #[arg(long, default_value_t = 5)]
    pub keep_checkpoints: usize,
}
// Resource options shared by long-running commands via `#[command(flatten)]`.
#[derive(Debug, Args)]
pub struct ResourceArgs {
    // `None` when `--threads` is not given; the fallback is chosen by the caller.
    #[arg(long)]
    pub threads: Option<usize>,

    // Raw string; parsing of any size syntax happens elsewhere (not visible here).
    #[arg(long)]
    pub max_memory: Option<String>,

    #[arg(long)]
    pub no_progress: bool,
}
// Corpus input formats accepted by the `--format` flags. On the command line they appear as
// clap's kebab-case names (`plaintext`, `wikipedia`, `gutenberg`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum CorpusFormat {
    // Also the Rust `Default`, matching the `default_value = "plaintext"` used by the CLI flags.
    #[default]
    Plaintext,
    Wikipedia,
    Gutenberg,
}
// Strategies selectable with `train hybrid --strategy`. On the command line they are spelled
// in kebab-case (`linear`, `log-linear`, `ngram-fallback`, `dynamic`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum InterpolationStrategy {
    // Rust `Default`; matches the CLI default `linear`.
    #[default]
    Linear,
    LogLinear,
    NgramFallback,
    Dynamic,
}
// Listing formats for `models list --format` and `corpus list --format` (`table` or `json`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum OutputFormat {
    // Rust `Default`; matches the CLI default `table`.
    #[default]
    Table,
    Json,
}
// Corpus sources for `corpus download --source` and `corpus clean --source`.
//
// NOTE(review): there is no `#[serde(rename_all = ...)]`, so serde uses the variant names as
// written (`"Wikipedia"`), while the CLI uses clap's kebab-case names (`wikipedia`). Keep that in
// mind if the two representations are ever compared.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum, Serialize, Deserialize)]
pub enum CorpusSource {
    // Rust `Default`; matches the CLI default `wikipedia`.
    #[default]
    Wikipedia,
    Gutenberg,
    Oscar,
}
// Values for `train import-google-books --sharding` (`enabled` or `disabled`); only compiled with
// the `google-books` feature. The `Arg` suffix suggests it is mapped to an internal type
// elsewhere. TODO confirm.
#[cfg(feature = "google-books")]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum ShardingModeArg {
    // Rust `Default`; matches the CLI default `enabled`.
    #[default]
    Enabled,
    Disabled,
}