1
2#[allow(clippy::too_many_arguments)]
4fn dispatch_run(
5 source: &str,
6 positional_prompt: Option<&String>,
7 input: Option<&Path>,
8 prompt: Option<&String>,
9 max_tokens: usize,
10 stream: bool,
11 language: Option<&str>,
12 task: Option<&str>,
13 format: &str,
14 no_gpu: bool,
15 offline: bool,
16 benchmark: bool,
17 verbose: bool,
18 trace: bool,
19 trace_payload: bool,
20 trace_steps: Option<&[String]>,
21 trace_verbose: bool,
22 trace_output: Option<PathBuf>,
23 trace_level: &str,
24 profile: bool,
25 chat: bool,
26) -> Result<(), CliError> {
27 let effective_trace = trace || trace_payload;
28 let effective_trace_level = if trace_payload {
29 "payload"
30 } else {
31 trace_level
32 };
33 let merged_prompt = prompt.or(positional_prompt).cloned();
34 let effective_prompt = if chat {
35 merged_prompt
36 .as_ref()
37 .map(|p| format!("<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"))
38 } else {
39 merged_prompt
40 };
41
42 run::run(
43 source,
44 input,
45 effective_prompt.as_deref(),
46 max_tokens,
47 stream,
48 language,
49 task,
50 format,
51 no_gpu,
52 offline,
53 benchmark,
54 verbose,
55 effective_trace,
56 trace_steps,
57 trace_verbose,
58 trace_output,
59 effective_trace_level,
60 profile,
61 )
62}
63
64#[allow(clippy::too_many_arguments)]
66fn dispatch_serve(
67 file: &Path,
68 port: u16,
69 host: &str,
70 no_cors: bool,
71 no_metrics: bool,
72 no_gpu: bool,
73 gpu: bool,
74 batch: bool,
75 trace: bool,
76 trace_level: &str,
77 profile: bool,
78 verbose: bool,
79) -> Result<(), CliError> {
80 let config = serve::ServerConfig {
81 port,
82 host: host.to_owned(),
83 cors: !no_cors,
84 metrics: !no_metrics,
85 no_gpu,
86 gpu,
87 batch,
88 trace,
89 trace_level: trace_level.to_owned(),
90 profile,
91 verbose,
92 ..Default::default()
93 };
94 serve::run(file, &config)
95}
96
97#[allow(clippy::too_many_arguments)]
99fn dispatch_hex(
100 file: &Path,
101 tensor: Option<&str>,
102 limit: usize,
103 stats: bool,
104 list: bool,
105 json: bool,
106 header: bool,
107 blocks: bool,
108 distribution: bool,
109 contract: bool,
110 entropy: bool,
111 raw: bool,
112 offset: &str,
113 width: usize,
114 slice: Option<&str>,
115) -> Result<(), CliError> {
116 let parsed_offset = hex::parse_hex_offset(offset).map_err(CliError::InvalidFormat)?;
117 hex::run(&hex::HexOptions {
118 file: file.to_path_buf(),
119 tensor: tensor.map(String::from),
120 limit,
121 stats,
122 list,
123 json,
124 header,
125 blocks,
126 distribution,
127 contract,
128 entropy,
129 raw,
130 offset: parsed_offset,
131 width,
132 slice: slice.map(String::from),
133 })
134}
135
136fn dispatch_rosetta(action: &RosettaCommands, global_json: bool) -> Result<(), CliError> {
138 match action {
139 RosettaCommands::Inspect {
140 file,
141 hexdump,
142 json,
143 } => rosetta::run_inspect(file, *hexdump, *json || global_json),
144 RosettaCommands::Convert {
145 source,
146 target,
147 quantize,
148 verify,
149 json,
150 tokenizer,
151 } => rosetta::run_convert(
152 source,
153 target,
154 quantize.as_deref(),
155 *verify,
156 *json || global_json,
157 tokenizer.as_deref(),
158 ),
159 RosettaCommands::Chain {
160 source,
161 formats,
162 work_dir,
163 json,
164 } => rosetta::run_chain(source, formats, work_dir, *json || global_json),
165 RosettaCommands::Verify {
166 source,
167 intermediate,
168 tolerance,
169 json,
170 } => rosetta::run_verify(source, intermediate, *tolerance, *json || global_json),
171 RosettaCommands::CompareInference {
172 model_a,
173 model_b,
174 prompt,
175 max_tokens,
176 temperature,
177 tolerance,
178 json,
179 } => rosetta::run_compare_inference(
180 model_a,
181 model_b,
182 prompt,
183 *max_tokens,
184 *temperature,
185 *tolerance,
186 *json || global_json,
187 ),
188 RosettaCommands::DiffTensors {
189 model_a,
190 model_b,
191 mismatches_only,
192 show_values,
193 filter,
194 json,
195 } => rosetta::run_diff_tensors(
196 model_a,
197 model_b,
198 *mismatches_only,
199 *show_values,
200 filter.as_deref(),
201 *json || global_json,
202 ),
203 RosettaCommands::Fingerprint {
204 model,
205 model_b,
206 output,
207 filter,
208 verbose,
209 json,
210 } => rosetta::run_fingerprint(
211 model,
212 model_b.as_ref().map(std::path::PathBuf::as_path),
213 output.as_ref().map(std::path::PathBuf::as_path),
214 filter.as_deref(),
215 *verbose,
216 *json || global_json,
217 ),
218 RosettaCommands::ValidateStats {
219 model,
220 reference,
221 fingerprints,
222 threshold,
223 strict,
224 json,
225 } => rosetta::run_validate_stats(
226 model,
227 reference.as_ref().map(std::path::PathBuf::as_path),
228 fingerprints.as_ref().map(std::path::PathBuf::as_path),
229 *threshold,
230 *strict,
231 *json || global_json,
232 ),
233 }
234}
235
236pub fn execute_command(cli: &Cli) -> Result<(), CliError> {
238 if !cli.skip_contract {
240 let paths = extract_model_paths(&cli.command);
241 validate_model_contract(&paths)?;
242 }
243
244 dispatch_core_command(cli).unwrap_or_else(|| dispatch_extended_command(cli))
245}