//! `apr_cli` — `dispatch_run.rs`
//!
//! Subcommand dispatch helpers extracted from `execute_command` to keep its
//! cognitive complexity down: `run`, `serve`, `hex`, and `rosetta` routing.
2/// Dispatch `apr run` — extracted to reduce cognitive complexity of `execute_command`
3#[allow(clippy::too_many_arguments)]
4fn dispatch_run(
5    source: &str,
6    positional_prompt: Option<&String>,
7    input: Option<&Path>,
8    prompt: Option<&String>,
9    max_tokens: usize,
10    stream: bool,
11    language: Option<&str>,
12    task: Option<&str>,
13    format: &str,
14    no_gpu: bool,
15    offline: bool,
16    benchmark: bool,
17    verbose: bool,
18    trace: bool,
19    trace_payload: bool,
20    trace_steps: Option<&[String]>,
21    trace_verbose: bool,
22    trace_output: Option<PathBuf>,
23    trace_level: &str,
24    profile: bool,
25    chat: bool,
26) -> Result<(), CliError> {
27    let effective_trace = trace || trace_payload;
28    let effective_trace_level = if trace_payload {
29        "payload"
30    } else {
31        trace_level
32    };
33    let merged_prompt = prompt.or(positional_prompt).cloned();
34    let effective_prompt = if chat {
35        merged_prompt
36            .as_ref()
37            .map(|p| format!("<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"))
38    } else {
39        merged_prompt
40    };
41
42    run::run(
43        source,
44        input,
45        effective_prompt.as_deref(),
46        max_tokens,
47        stream,
48        language,
49        task,
50        format,
51        no_gpu,
52        offline,
53        benchmark,
54        verbose,
55        effective_trace,
56        trace_steps,
57        trace_verbose,
58        trace_output,
59        effective_trace_level,
60        profile,
61    )
62}
63
64/// Build server config and launch serve.
65#[allow(clippy::too_many_arguments)]
66fn dispatch_serve(
67    file: &Path,
68    port: u16,
69    host: &str,
70    no_cors: bool,
71    no_metrics: bool,
72    no_gpu: bool,
73    gpu: bool,
74    batch: bool,
75    trace: bool,
76    trace_level: &str,
77    profile: bool,
78    verbose: bool,
79) -> Result<(), CliError> {
80    let config = serve::ServerConfig {
81        port,
82        host: host.to_owned(),
83        cors: !no_cors,
84        metrics: !no_metrics,
85        no_gpu,
86        gpu,
87        batch,
88        trace,
89        trace_level: trace_level.to_owned(),
90        profile,
91        verbose,
92        ..Default::default()
93    };
94    serve::run(file, &config)
95}
96
97/// Route `apr serve` subcommands: plan or run.
98fn dispatch_serve_command(command: &ServeCommands, cli: &Cli) -> Result<(), CliError> {
99    match command {
100        ServeCommands::Plan {
101            model,
102            gpu,
103            batch_size,
104            seq_len,
105            format,
106            quant,
107        } => commands::serve_plan::run_serve_plan(
108            model, *gpu, *batch_size, *seq_len, format, quant.as_deref(),
109        ),
110        ServeCommands::Run {
111            file,
112            port,
113            host,
114            no_cors,
115            no_metrics,
116            no_gpu,
117            gpu,
118            batch,
119            trace,
120            trace_level,
121            profile,
122        } => dispatch_serve(
123            file,
124            *port,
125            host,
126            *no_cors,
127            *no_metrics,
128            *no_gpu,
129            *gpu,
130            *batch,
131            *trace,
132            trace_level,
133            *profile,
134            cli.verbose,
135        ),
136    }
137}
138
139/// Parse hex offset and run hex inspection.
140#[allow(clippy::too_many_arguments)]
141fn dispatch_hex(
142    file: &Path,
143    tensor: Option<&str>,
144    limit: usize,
145    stats: bool,
146    list: bool,
147    json: bool,
148    header: bool,
149    blocks: bool,
150    distribution: bool,
151    contract: bool,
152    entropy: bool,
153    raw: bool,
154    offset: &str,
155    width: usize,
156    slice: Option<&str>,
157) -> Result<(), CliError> {
158    let parsed_offset = hex::parse_hex_offset(offset).map_err(CliError::InvalidFormat)?;
159    hex::run(&hex::HexOptions {
160        file: file.to_path_buf(),
161        tensor: tensor.map(String::from),
162        limit,
163        stats,
164        list,
165        json,
166        header,
167        blocks,
168        distribution,
169        contract,
170        entropy,
171        raw,
172        offset: parsed_offset,
173        width,
174        slice: slice.map(String::from),
175    })
176}
177
178/// Dispatch a rosetta subcommand.
179fn dispatch_rosetta(action: &RosettaCommands, global_json: bool) -> Result<(), CliError> {
180    match action {
181        RosettaCommands::Inspect {
182            file,
183            hexdump,
184            json,
185        } => rosetta::run_inspect(file, *hexdump, *json || global_json),
186        RosettaCommands::Convert {
187            source,
188            target,
189            quantize,
190            verify,
191            json,
192            tokenizer,
193        } => rosetta::run_convert(
194            source,
195            target,
196            quantize.as_deref(),
197            *verify,
198            *json || global_json,
199            tokenizer.as_deref(),
200        ),
201        RosettaCommands::Chain {
202            source,
203            formats,
204            work_dir,
205            json,
206        } => rosetta::run_chain(source, formats, work_dir, *json || global_json),
207        RosettaCommands::Verify {
208            source,
209            intermediate,
210            tolerance,
211            json,
212        } => rosetta::run_verify(source, intermediate, *tolerance, *json || global_json),
213        RosettaCommands::CompareInference {
214            model_a,
215            model_b,
216            prompt,
217            max_tokens,
218            temperature,
219            tolerance,
220            json,
221        } => rosetta::run_compare_inference(
222            model_a,
223            model_b,
224            prompt,
225            *max_tokens,
226            *temperature,
227            *tolerance,
228            *json || global_json,
229        ),
230        RosettaCommands::DiffTensors {
231            model_a,
232            model_b,
233            mismatches_only,
234            show_values,
235            filter,
236            json,
237        } => rosetta::run_diff_tensors(
238            model_a,
239            model_b,
240            *mismatches_only,
241            *show_values,
242            filter.as_deref(),
243            *json || global_json,
244        ),
245        RosettaCommands::Fingerprint {
246            model,
247            model_b,
248            output,
249            filter,
250            verbose,
251            json,
252        } => rosetta::run_fingerprint(
253            model,
254            model_b.as_ref().map(std::path::PathBuf::as_path),
255            output.as_ref().map(std::path::PathBuf::as_path),
256            filter.as_deref(),
257            *verbose,
258            *json || global_json,
259        ),
260        RosettaCommands::ValidateStats {
261            model,
262            reference,
263            fingerprints,
264            threshold,
265            strict,
266            json,
267        } => rosetta::run_validate_stats(
268            model,
269            reference.as_ref().map(std::path::PathBuf::as_path),
270            fingerprints.as_ref().map(std::path::PathBuf::as_path),
271            *threshold,
272            *strict,
273            *json || global_json,
274        ),
275    }
276}
277
278/// Execute the CLI command and return the result.
279pub fn execute_command(cli: &Cli) -> Result<(), CliError> {
280    // PMAT-237: Contract gate — refuse to operate on corrupt models
281    if !cli.skip_contract {
282        let paths = extract_model_paths(&cli.command);
283        validate_model_contract(&paths)?;
284    }
285
286    dispatch_core_command(cli).unwrap_or_else(|| dispatch_extended_command(cli))
287}