Skip to main content

apr_cli/
dispatch_run.rs

1
2/// Dispatch `apr run` — extracted to reduce cognitive complexity of `execute_command`
3#[allow(clippy::too_many_arguments)]
4fn dispatch_run(
5    source: &str,
6    positional_prompt: Option<&String>,
7    input: Option<&Path>,
8    prompt: Option<&String>,
9    max_tokens: usize,
10    stream: bool,
11    language: Option<&str>,
12    task: Option<&str>,
13    format: &str,
14    no_gpu: bool,
15    offline: bool,
16    benchmark: bool,
17    verbose: bool,
18    trace: bool,
19    trace_payload: bool,
20    trace_steps: Option<&[String]>,
21    trace_verbose: bool,
22    trace_output: Option<PathBuf>,
23    trace_level: &str,
24    profile: bool,
25    chat: bool,
26    // PMAT-496: Sampling parameters
27    temperature: f32,
28    top_k: usize,
29    top_p: Option<f32>,
30    seed: u64,
31    repeat_penalty: f32,
32    repeat_last_n: usize,
33    split_prompt: bool,
34) -> Result<(), CliError> {
35    let effective_trace = trace || trace_payload;
36    let effective_trace_level = if trace_payload {
37        "payload"
38    } else {
39        trace_level
40    };
41    let merged_prompt = prompt.or(positional_prompt).cloned();
42    let effective_prompt = if chat {
43        merged_prompt
44            .as_ref()
45            .map(|p| format!("<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"))
46    } else {
47        merged_prompt
48    };
49
50    run::run(
51        source,
52        input,
53        effective_prompt.as_deref(),
54        max_tokens,
55        stream,
56        language,
57        task,
58        format,
59        no_gpu,
60        offline,
61        benchmark,
62        verbose,
63        effective_trace,
64        trace_steps,
65        trace_verbose,
66        trace_output,
67        effective_trace_level,
68        profile,
69        temperature,
70        top_k,
71        top_p,
72        seed,
73        repeat_penalty,
74        repeat_last_n,
75        split_prompt,
76    )
77}
78
79/// Build server config and launch serve.
80#[allow(clippy::too_many_arguments)]
81fn dispatch_serve(
82    file: &Path,
83    port: u16,
84    host: &str,
85    no_cors: bool,
86    no_metrics: bool,
87    no_gpu: bool,
88    gpu: bool,
89    batch: bool,
90    trace: bool,
91    trace_level: &str,
92    profile: bool,
93    verbose: bool,
94    backend: &Option<String>,
95    otlp_endpoint: &Option<String>,
96    context_length: usize,
97    no_fp8_cache: bool,
98) -> Result<(), CliError> {
99    if let Some(ref endpoint) = otlp_endpoint {
100        eprintln!("OTLP tracing enabled → {endpoint}");
101        eprintln!("  Spans exported as W3C Trace Context (PMAT-485)");
102    }
103    let config = serve::ServerConfig {
104        port,
105        host: host.to_owned(),
106        cors: !no_cors,
107        metrics: !no_metrics,
108        no_gpu,
109        gpu,
110        batch,
111        trace,
112        trace_level: trace_level.to_owned(),
113        profile,
114        verbose,
115        backend: backend.clone(),
116        otlp_endpoint: otlp_endpoint.clone(),
117        context_length,
118        no_fp8_cache,
119        ..Default::default()
120    };
121    serve::run(file, &config)
122}
123
124/// Route `apr serve` subcommands: plan or run.
125fn dispatch_serve_command(command: &ServeCommands, cli: &Cli) -> Result<(), CliError> {
126    match command {
127        ServeCommands::Plan {
128            model,
129            gpu,
130            batch_size,
131            seq_len,
132            format,
133            quant,
134        } => commands::serve_plan::run_serve_plan(
135            model, *gpu, *batch_size, *seq_len, format, quant.as_deref(),
136        ),
137        ServeCommands::Run {
138            file,
139            port,
140            host,
141            no_cors,
142            no_metrics,
143            no_gpu,
144            gpu,
145            batch,
146            trace,
147            trace_level,
148            profile,
149            backend,
150            otlp_endpoint,
151            context_length,
152            no_fp8_cache,
153        } => crate::error::resolve_model_path(file).and_then(|r| {
154            dispatch_serve(
155                &r,
156                *port,
157                host,
158                *no_cors,
159                *no_metrics,
160                *no_gpu,
161                *gpu,
162                *batch,
163                *trace,
164                trace_level,
165                *profile,
166                cli.verbose,
167                backend,
168                otlp_endpoint,
169                *context_length,
170                *no_fp8_cache,
171            )
172        }),
173    }
174}
175
176/// Parse hex offset and run hex inspection.
177#[allow(clippy::too_many_arguments)]
178fn dispatch_hex(
179    file: &Path,
180    tensor: Option<&str>,
181    limit: usize,
182    stats: bool,
183    list: bool,
184    json: bool,
185    header: bool,
186    blocks: bool,
187    distribution: bool,
188    contract: bool,
189    entropy: bool,
190    raw: bool,
191    offset: &str,
192    width: usize,
193    slice: Option<&str>,
194) -> Result<(), CliError> {
195    let parsed_offset = hex::parse_hex_offset(offset).map_err(CliError::InvalidFormat)?;
196    hex::run(&hex::HexOptions {
197        file: file.to_path_buf(),
198        tensor: tensor.map(String::from),
199        limit,
200        stats,
201        list,
202        json,
203        header,
204        blocks,
205        distribution,
206        contract,
207        entropy,
208        raw,
209        offset: parsed_offset,
210        width,
211        slice: slice.map(String::from),
212    })
213}
214
215/// Dispatch a rosetta subcommand.
216fn dispatch_rosetta(action: &RosettaCommands, global_json: bool) -> Result<(), CliError> {
217    match action {
218        RosettaCommands::Inspect {
219            file,
220            hexdump,
221            json,
222        } => rosetta::run_inspect(file, *hexdump, *json || global_json),
223        RosettaCommands::Convert {
224            source,
225            target,
226            quantize,
227            verify,
228            json,
229            tokenizer,
230        } => rosetta::run_convert(
231            source,
232            target,
233            quantize.as_deref(),
234            *verify,
235            *json || global_json,
236            tokenizer.as_deref(),
237        ),
238        RosettaCommands::Chain {
239            source,
240            formats,
241            work_dir,
242            json,
243        } => rosetta::run_chain(source, formats, work_dir, *json || global_json),
244        RosettaCommands::Verify {
245            source,
246            intermediate,
247            tolerance,
248            json,
249        } => rosetta::run_verify(source, intermediate, *tolerance, *json || global_json),
250        RosettaCommands::CompareInference {
251            model_a,
252            model_b,
253            prompt,
254            max_tokens,
255            temperature,
256            tolerance,
257            json,
258        } => rosetta::run_compare_inference(
259            model_a,
260            model_b,
261            prompt,
262            *max_tokens,
263            *temperature,
264            *tolerance,
265            *json || global_json,
266        ),
267        RosettaCommands::DiffTensors {
268            model_a,
269            model_b,
270            mismatches_only,
271            show_values,
272            filter,
273            json,
274        } => rosetta::run_diff_tensors(
275            model_a,
276            model_b,
277            *mismatches_only,
278            *show_values,
279            filter.as_deref(),
280            *json || global_json,
281        ),
282        RosettaCommands::Fingerprint {
283            model,
284            model_b,
285            output,
286            filter,
287            verbose,
288            json,
289        } => rosetta::run_fingerprint(
290            model,
291            model_b.as_ref().map(std::path::PathBuf::as_path),
292            output.as_ref().map(std::path::PathBuf::as_path),
293            filter.as_deref(),
294            *verbose,
295            *json || global_json,
296        ),
297        RosettaCommands::ValidateStats {
298            model,
299            reference,
300            fingerprints,
301            threshold,
302            strict,
303            json,
304        } => rosetta::run_validate_stats(
305            model,
306            reference.as_ref().map(std::path::PathBuf::as_path),
307            fingerprints.as_ref().map(std::path::PathBuf::as_path),
308            *threshold,
309            *strict,
310            *json || global_json,
311        ),
312    }
313}
314
315/// Execute the CLI command and return the result.
316pub fn execute_command(cli: &Cli) -> Result<(), CliError> {
317    // PMAT-237: Contract gate — refuse to operate on corrupt models
318    if !cli.skip_contract {
319        let paths = extract_model_paths(&cli.command);
320        validate_model_contract(&paths)?;
321    }
322
323    dispatch_core_command(cli).unwrap_or_else(|| dispatch_extended_command(cli))
324}