1
2#[allow(clippy::too_many_arguments)]
4fn dispatch_run(
5 source: &str,
6 positional_prompt: Option<&String>,
7 input: Option<&Path>,
8 prompt: Option<&String>,
9 max_tokens: usize,
10 stream: bool,
11 language: Option<&str>,
12 task: Option<&str>,
13 format: &str,
14 no_gpu: bool,
15 offline: bool,
16 benchmark: bool,
17 verbose: bool,
18 trace: bool,
19 trace_payload: bool,
20 trace_steps: Option<&[String]>,
21 trace_verbose: bool,
22 trace_output: Option<PathBuf>,
23 trace_level: &str,
24 profile: bool,
25 chat: bool,
26) -> Result<(), CliError> {
27 let effective_trace = trace || trace_payload;
28 let effective_trace_level = if trace_payload {
29 "payload"
30 } else {
31 trace_level
32 };
33 let merged_prompt = prompt.or(positional_prompt).cloned();
34 let effective_prompt = if chat {
35 merged_prompt
36 .as_ref()
37 .map(|p| format!("<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"))
38 } else {
39 merged_prompt
40 };
41
42 run::run(
43 source,
44 input,
45 effective_prompt.as_deref(),
46 max_tokens,
47 stream,
48 language,
49 task,
50 format,
51 no_gpu,
52 offline,
53 benchmark,
54 verbose,
55 effective_trace,
56 trace_steps,
57 trace_verbose,
58 trace_output,
59 effective_trace_level,
60 profile,
61 )
62}
63
64#[allow(clippy::too_many_arguments)]
66fn dispatch_serve(
67 file: &Path,
68 port: u16,
69 host: &str,
70 no_cors: bool,
71 no_metrics: bool,
72 no_gpu: bool,
73 gpu: bool,
74 batch: bool,
75 trace: bool,
76 trace_level: &str,
77 profile: bool,
78 verbose: bool,
79) -> Result<(), CliError> {
80 let config = serve::ServerConfig {
81 port,
82 host: host.to_owned(),
83 cors: !no_cors,
84 metrics: !no_metrics,
85 no_gpu,
86 gpu,
87 batch,
88 trace,
89 trace_level: trace_level.to_owned(),
90 profile,
91 verbose,
92 ..Default::default()
93 };
94 serve::run(file, &config)
95}
96
97fn dispatch_serve_command(command: &ServeCommands, cli: &Cli) -> Result<(), CliError> {
99 match command {
100 ServeCommands::Plan {
101 model,
102 gpu,
103 batch_size,
104 seq_len,
105 format,
106 quant,
107 } => commands::serve_plan::run_serve_plan(
108 model, *gpu, *batch_size, *seq_len, format, quant.as_deref(),
109 ),
110 ServeCommands::Run {
111 file,
112 port,
113 host,
114 no_cors,
115 no_metrics,
116 no_gpu,
117 gpu,
118 batch,
119 trace,
120 trace_level,
121 profile,
122 } => dispatch_serve(
123 file,
124 *port,
125 host,
126 *no_cors,
127 *no_metrics,
128 *no_gpu,
129 *gpu,
130 *batch,
131 *trace,
132 trace_level,
133 *profile,
134 cli.verbose,
135 ),
136 }
137}
138
139#[allow(clippy::too_many_arguments)]
141fn dispatch_hex(
142 file: &Path,
143 tensor: Option<&str>,
144 limit: usize,
145 stats: bool,
146 list: bool,
147 json: bool,
148 header: bool,
149 blocks: bool,
150 distribution: bool,
151 contract: bool,
152 entropy: bool,
153 raw: bool,
154 offset: &str,
155 width: usize,
156 slice: Option<&str>,
157) -> Result<(), CliError> {
158 let parsed_offset = hex::parse_hex_offset(offset).map_err(CliError::InvalidFormat)?;
159 hex::run(&hex::HexOptions {
160 file: file.to_path_buf(),
161 tensor: tensor.map(String::from),
162 limit,
163 stats,
164 list,
165 json,
166 header,
167 blocks,
168 distribution,
169 contract,
170 entropy,
171 raw,
172 offset: parsed_offset,
173 width,
174 slice: slice.map(String::from),
175 })
176}
177
178fn dispatch_rosetta(action: &RosettaCommands, global_json: bool) -> Result<(), CliError> {
180 match action {
181 RosettaCommands::Inspect {
182 file,
183 hexdump,
184 json,
185 } => rosetta::run_inspect(file, *hexdump, *json || global_json),
186 RosettaCommands::Convert {
187 source,
188 target,
189 quantize,
190 verify,
191 json,
192 tokenizer,
193 } => rosetta::run_convert(
194 source,
195 target,
196 quantize.as_deref(),
197 *verify,
198 *json || global_json,
199 tokenizer.as_deref(),
200 ),
201 RosettaCommands::Chain {
202 source,
203 formats,
204 work_dir,
205 json,
206 } => rosetta::run_chain(source, formats, work_dir, *json || global_json),
207 RosettaCommands::Verify {
208 source,
209 intermediate,
210 tolerance,
211 json,
212 } => rosetta::run_verify(source, intermediate, *tolerance, *json || global_json),
213 RosettaCommands::CompareInference {
214 model_a,
215 model_b,
216 prompt,
217 max_tokens,
218 temperature,
219 tolerance,
220 json,
221 } => rosetta::run_compare_inference(
222 model_a,
223 model_b,
224 prompt,
225 *max_tokens,
226 *temperature,
227 *tolerance,
228 *json || global_json,
229 ),
230 RosettaCommands::DiffTensors {
231 model_a,
232 model_b,
233 mismatches_only,
234 show_values,
235 filter,
236 json,
237 } => rosetta::run_diff_tensors(
238 model_a,
239 model_b,
240 *mismatches_only,
241 *show_values,
242 filter.as_deref(),
243 *json || global_json,
244 ),
245 RosettaCommands::Fingerprint {
246 model,
247 model_b,
248 output,
249 filter,
250 verbose,
251 json,
252 } => rosetta::run_fingerprint(
253 model,
254 model_b.as_ref().map(std::path::PathBuf::as_path),
255 output.as_ref().map(std::path::PathBuf::as_path),
256 filter.as_deref(),
257 *verbose,
258 *json || global_json,
259 ),
260 RosettaCommands::ValidateStats {
261 model,
262 reference,
263 fingerprints,
264 threshold,
265 strict,
266 json,
267 } => rosetta::run_validate_stats(
268 model,
269 reference.as_ref().map(std::path::PathBuf::as_path),
270 fingerprints.as_ref().map(std::path::PathBuf::as_path),
271 *threshold,
272 *strict,
273 *json || global_json,
274 ),
275 }
276}
277
278pub fn execute_command(cli: &Cli) -> Result<(), CliError> {
280 if !cli.skip_contract {
282 let paths = extract_model_paths(&cli.command);
283 validate_model_contract(&paths)?;
284 }
285
286 dispatch_core_command(cli).unwrap_or_else(|| dispatch_extended_command(cli))
287}