1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
use clap::{Parser, Subcommand};
use crate::commands::*;
// Top-level command-line parser. clap reports the program as
// `embedding-train` and uses the `about` string below as its description.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive feeds
// `///` doc comments into the generated help text, so adding them could
// change what `--help` prints.
#[derive(Parser)]
#[command(
    name = "embedding-train",
    about = "A CLI tool for training word embeddings from scratch"
)]
pub struct Cli {
    // The subcommand selected on the command line, with its parsed arguments.
    #[command(subcommand)]
    pub command: Commands,
}
// Subcommands understood by the CLI, one variant per subcommand.
//
// NOTE: the `///` lines below are not only documentation. clap's derive uses
// them as the help text of each subcommand and argument, so they are kept
// verbatim and reviewer notes use plain `//` comments instead.
//
// Flag names are written out explicitly. Each one is exactly what a bare
// `short, long` would derive from the field name (first letter for the short
// flag, kebab-case field name for the long flag), so the command line is
// unchanged, but renaming a field can no longer silently rename a flag.
#[derive(Subcommand)]
pub enum Commands {
    /// Train embeddings from text data
    Train {
        /// Input text file path
        #[arg(short = 'i', long = "input")]
        input: String,

        /// Output model file path
        #[arg(short = 'o', long = "output")]
        output: String,

        /// Output embeddings file path
        #[arg(short = 'e', long = "embeddings")]
        embeddings: String,

        /// Config JSON file path (overrides other flags)
        #[arg(short = 'c', long = "config")]
        config: Option<String>,

        /// Embedding dimension
        #[arg(short = 'd', long = "dim", default_value = "300")]
        dim: usize,

        /// Learning rate
        #[arg(short = 'l', long = "learning-rate", default_value = "0.025")]
        learning_rate: f64,

        // Long-only: `-e` is already taken by `--embeddings` in this subcommand.
        /// Number of training epochs
        #[arg(long = "epochs", default_value = "10")]
        epochs: usize,

        /// Batch size
        #[arg(short = 'b', long = "batch-size", default_value = "32")]
        batch_size: usize,

        /// Context window size
        #[arg(short = 'w', long = "window", default_value = "5")]
        window: usize,

        /// Number of negative samples
        #[arg(short = 'n', long = "negative-samples", default_value = "5")]
        negative_samples: usize,

        /// Model type: skipgram or cbow
        #[arg(short = 'm', long = "model-type", default_value = "skipgram")]
        model_type: String,

        /// Validation data ratio (0.0 = no validation, 0.2 = 20% validation)
        #[arg(long = "validation-ratio", default_value = "0.0")]
        validation_ratio: f64,

        /// Output validation metrics file path
        #[arg(long = "validation-output")]
        validation_output: Option<String>,

        /// Treat input as source code instead of natural language text
        #[arg(long = "code")]
        code: bool,

        /// Programming language for code preprocessing (rust, python, javascript, etc.)
        #[arg(long = "language", default_value = "rust")]
        language: String,
    },

    /// Calculate similarity between two words
    Similarity {
        // `word1` and `word2` carry no `#[arg]` attribute, so clap treats them
        // as positional arguments.
        /// First word
        word1: String,

        /// Second word
        word2: String,

        /// Model file path
        #[arg(short = 'm', long = "model")]
        model: String,

        // NOTE(review): required on the command line, but `run` in this file
        // discards the value instead of forwarding it to the handler.
        /// Vocabulary file path
        #[arg(short = 'v', long = "vocab")]
        vocab: String,
    },

    /// Load and inspect a trained model
    Info {
        /// Model file path
        #[arg(short = 'm', long = "model")]
        model: String,

        // NOTE(review): required on the command line, but `run` in this file
        // discards the value instead of forwarding it to the handler.
        /// Vocabulary file path
        #[arg(short = 'v', long = "vocab")]
        vocab: String,
    },

    /// Export embeddings to different formats
    Export {
        /// Model file path
        #[arg(short = 'm', long = "model")]
        model: String,

        // NOTE(review): required on the command line, but `run` in this file
        // discards the value instead of forwarding it to the handler.
        /// Vocabulary file path
        #[arg(short = 'v', long = "vocab")]
        vocab: String,

        /// Output file path
        #[arg(short = 'o', long = "output")]
        output: String,

        /// Export format: text, json, bin, or word2vec
        #[arg(short = 'f', long = "format", default_value = "text")]
        format: String,
    },

    /// Validate a trained model against held-out data
    Validate {
        /// Model file path
        #[arg(short = 'm', long = "model")]
        model: String,

        /// Validation text file path
        #[arg(short = 'i', long = "input")]
        input: String,

        /// Output metrics file path
        #[arg(short = 'o', long = "output")]
        output: Option<String>,
    },

    /// Train a model and enter interactive mode for queries
    Interactive {
        /// Input text file for training
        #[arg(short = 'i', long = "input")]
        input: String,

        /// Output model file path
        #[arg(short = 'o', long = "output", default_value = "model.json")]
        output: String,

        /// Embedding dimension
        #[arg(short = 'd', long = "dim", default_value = "100")]
        dim: usize,

        /// Number of training epochs
        #[arg(short = 'e', long = "epochs", default_value = "10")]
        epochs: usize,

        /// Learning rate
        #[arg(short = 'l', long = "learning-rate", default_value = "0.025")]
        learning_rate: f64,

        /// Context window size
        #[arg(short = 'w', long = "window", default_value = "5")]
        window: usize,

        /// Number of negative samples
        #[arg(short = 'n', long = "negative-samples", default_value = "5")]
        negative_samples: usize,

        // NOTE(review): in this subcommand `-m/--model` selects the model type
        // (see the help text), whereas in the other subcommands `--model` is a
        // model file path and `train` calls the equivalent flag `--model-type`.
        /// Model type: skipgram or cbow
        #[arg(short = 'm', long = "model", default_value = "skipgram")]
        model: String,
    },
}
pub fn run(cli: Cli) {
match cli.command {
Commands::Train {
input,
output,
embeddings,
config: config_path,
dim,
learning_rate,
epochs,
batch_size,
window,
negative_samples,
model_type,
validation_ratio,
validation_output,
code,
language,
} => handle_train(input, output, embeddings, config_path, dim, learning_rate, epochs, batch_size, window, negative_samples, model_type, validation_ratio, validation_output, code, language),
Commands::Similarity { word1, word2, model, vocab: _vocab } => {
handle_similarity(word1, word2, model);
}
Commands::Info { model, vocab: _vocab } => {
handle_info(model);
}
Commands::Export { model, vocab: _vocab, output, format } => {
handle_export(model, output, format);
}
Commands::Validate { model, input, output } => {
handle_validate(model, input, output);
}
Commands::Interactive { input, output, dim, epochs, learning_rate, window, negative_samples, model } => {
handle_interactive(input, output, dim, epochs, learning_rate, window, negative_samples, model);
}
}
}