
// apr_cli/dispatch_run.rs

/// Dispatch `apr run` — extracted to reduce cognitive complexity of `execute_command`
#[allow(clippy::too_many_arguments)]
fn dispatch_run(
    source: &str,
    positional_prompt: Option<&String>,
    input: Option<&Path>,
    prompt: Option<&String>,
    max_tokens: usize,
    stream: bool,
    language: Option<&str>,
    task: Option<&str>,
    format: &str,
    no_gpu: bool,
    offline: bool,
    benchmark: bool,
    verbose: bool,
    trace: bool,
    trace_payload: bool,
    trace_steps: Option<&[String]>,
    trace_verbose: bool,
    trace_output: Option<PathBuf>,
    trace_level: &str,
    profile: bool,
    chat: bool,
    // PMAT-496: Sampling parameters
    temperature: f32,
    top_k: usize,
    top_p: Option<f32>,
    seed: u64,
    repeat_penalty: f32,
    repeat_last_n: usize,
    split_prompt: bool,
) -> Result<(), CliError> {
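    // --trace-payload implies tracing: it switches tracing on and forces the
    // trace level to "payload", overriding any explicit trace-level value.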
    let effective_trace = trace || trace_payload;
    let effective_trace_level = if trace_payload {
        "payload"
    } else {
        trace_level
    };
    let merged_prompt = prompt.or(positional_prompt).cloned();
    // GH-638: Auto-detect chat template from the model name when --chat is not
    // passed explicitly. Instruct/Chat models (Qwen-Instruct, LLaMA-Instruct,
    // Mistral-Instruct, etc.) need ChatML wrapping for correct output. Without
    // it, the model ignores the prompt structure and produces garbled responses.
    let use_chat = chat || {
        let src_lower = source.to_lowercase();
        merged_prompt.is_some()
            && (src_lower.contains("instruct") || src_lower.contains("chat"))
    };
    let effective_prompt = if use_chat {
        merged_prompt
            .as_ref()
            .map(|p| format!("<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"))
    } else {
        merged_prompt
    };
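    // Example: a prompt of "Hello" becomes
    //   "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
    // leaving the model positioned to generate the assistant turn.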

    run::run(
        source,
        input,
        effective_prompt.as_deref(),
        max_tokens,
        stream,
        language,
        task,
        format,
        no_gpu,
        offline,
        benchmark,
        verbose,
        effective_trace,
        trace_steps,
        trace_verbose,
        trace_output,
        effective_trace_level,
        profile,
        temperature,
        top_k,
        top_p,
        seed,
        repeat_penalty,
        repeat_last_n,
        split_prompt,
    )
}

/// Build server config and launch serve.
#[allow(clippy::too_many_arguments)]
fn dispatch_serve(
    file: &Path,
    port: u16,
    host: &str,
    no_cors: bool,
    no_metrics: bool,
    no_gpu: bool,
    gpu: bool,
    batch: bool,
    trace: bool,
    trace_level: &str,
    profile: bool,
    verbose: bool,
    backend: &Option<String>,
    otlp_endpoint: &Option<String>,
    context_length: usize,
    no_fp8_cache: bool,
) -> Result<(), CliError> {
    if let Some(ref endpoint) = otlp_endpoint {
        eprintln!("OTLP tracing enabled → {endpoint}");
        eprintln!("  Spans exported as W3C Trace Context (PMAT-485)");
    }
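    // `no_cors` and `no_metrics` arrive as negative flags; the config stores
    // the positive capability, so both are inverted here.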
    let config = serve::ServerConfig {
        port,
        host: host.to_owned(),
        cors: !no_cors,
        metrics: !no_metrics,
        no_gpu,
        gpu,
        batch,
        trace,
        trace_level: trace_level.to_owned(),
        profile,
        verbose,
        backend: backend.clone(),
        otlp_endpoint: otlp_endpoint.clone(),
        context_length,
        no_fp8_cache,
        ..Default::default()
    };
    serve::run(file, &config)
}

/// Route `apr serve` subcommands: plan or run.
fn dispatch_serve_command(command: &ServeCommands, cli: &Cli) -> Result<(), CliError> {
    match command {
        ServeCommands::Plan {
            model,
            gpu,
            batch_size,
            seq_len,
            format,
            quant,
        } => {
            // GH-630: Thread cli.json through to serve plan
            let effective_format = if cli.json { "json" } else { format.as_str() };
            commands::serve_plan::run_serve_plan(
                model, *gpu, *batch_size, *seq_len, effective_format, quant.as_deref(),
            )
        }
        ServeCommands::Run {
            file,
            port,
            host,
            no_cors,
            no_metrics,
            no_gpu,
            gpu,
            batch,
            trace,
            trace_level,
            profile,
            backend,
            otlp_endpoint,
            context_length,
            no_fp8_cache,
        } => crate::error::resolve_model_path(file).and_then(|r| {
            dispatch_serve(
                &r,
                *port,
                host,
                *no_cors,
                *no_metrics,
                *no_gpu,
                *gpu,
                *batch,
                *trace,
                trace_level,
                *profile,
                cli.verbose,
                backend,
                otlp_endpoint,
                *context_length,
                *no_fp8_cache,
            )
        }),
    }
}

/// Parse hex offset and run hex inspection.
#[allow(clippy::too_many_arguments)]
fn dispatch_hex(
    file: &Path,
    tensor: Option<&str>,
    limit: usize,
    stats: bool,
    list: bool,
    json: bool,
    header: bool,
    blocks: bool,
    distribution: bool,
    contract: bool,
    entropy: bool,
    raw: bool,
    offset: &str,
    width: usize,
    slice: Option<&str>,
) -> Result<(), CliError> {
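    // The offset is taken as a string so it can carry hex notation; whatever
    // `hex::parse_hex_offset` rejects surfaces as CliError::InvalidFormat
    // before any file I/O happens.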
    let parsed_offset = hex::parse_hex_offset(offset).map_err(CliError::InvalidFormat)?;
    hex::run(&hex::HexOptions {
        file: file.to_path_buf(),
        tensor: tensor.map(String::from),
        limit,
        stats,
        list,
        json,
        header,
        blocks,
        distribution,
        contract,
        entropy,
        raw,
        offset: parsed_offset,
        width,
        slice: slice.map(String::from),
    })
}

/// Dispatch a rosetta subcommand.
fn dispatch_rosetta(action: &RosettaCommands, global_json: bool) -> Result<(), CliError> {
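    // Every arm OR-s the subcommand's local `json` flag with the global one,
    // so JSON output can be requested at either level.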
    match action {
        RosettaCommands::Inspect {
            file,
            hexdump,
            json,
        } => rosetta::run_inspect(file, *hexdump, *json || global_json),
        RosettaCommands::Convert {
            source,
            target,
            quantize,
            verify,
            json,
            tokenizer,
        } => rosetta::run_convert(
            source,
            target,
            quantize.as_deref(),
            *verify,
            *json || global_json,
            tokenizer.as_deref(),
        ),
        RosettaCommands::Chain {
            source,
            formats,
            work_dir,
            json,
        } => rosetta::run_chain(source, formats, work_dir, *json || global_json),
        RosettaCommands::Verify {
            source,
            intermediate,
            tolerance,
            json,
        } => rosetta::run_verify(source, intermediate, *tolerance, *json || global_json),
        RosettaCommands::CompareInference {
            model_a,
            model_b,
            prompt,
            max_tokens,
            temperature,
            tolerance,
            json,
        } => rosetta::run_compare_inference(
            model_a,
            model_b,
            prompt,
            *max_tokens,
            *temperature,
            *tolerance,
            *json || global_json,
        ),
        RosettaCommands::DiffTensors {
            model_a,
            model_b,
            mismatches_only,
            show_values,
            filter,
            json,
        } => rosetta::run_diff_tensors(
            model_a,
            model_b,
            *mismatches_only,
            *show_values,
            filter.as_deref(),
            *json || global_json,
        ),
        RosettaCommands::Fingerprint {
            model,
            model_b,
            output,
            filter,
            verbose,
            json,
        } => rosetta::run_fingerprint(
            model,
            model_b.as_ref().map(std::path::PathBuf::as_path),
            output.as_ref().map(std::path::PathBuf::as_path),
            filter.as_deref(),
            *verbose,
            *json || global_json,
        ),
        RosettaCommands::ValidateStats {
            model,
            reference,
            fingerprints,
            threshold,
            strict,
            json,
        } => rosetta::run_validate_stats(
            model,
            reference.as_ref().map(std::path::PathBuf::as_path),
            fingerprints.as_ref().map(std::path::PathBuf::as_path),
            *threshold,
            *strict,
            *json || global_json,
        ),
    }
}

/// Execute the CLI command and return the result.
pub fn execute_command(cli: &Cli) -> Result<(), CliError> {
    contract_pre_contract_gate_enforcement!();
    // PMAT-237: Contract gate — refuse to operate on corrupt models
    if !cli.skip_contract {
        let paths = extract_model_paths(&cli.command);
        validate_model_contract(&paths)?;
    }

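    // Two-stage dispatch: `dispatch_core_command` returns None for commands it
    // does not recognize, and `dispatch_extended_command` handles the rest.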
    dispatch_core_command(cli).unwrap_or_else(|| dispatch_extended_command(cli))
}