Skip to main content

apr_cli/
dispatch_run.rs

1
2/// Dispatch `apr run` — extracted to reduce cognitive complexity of `execute_command`
3#[allow(clippy::too_many_arguments)]
4fn dispatch_run(
5    source: &str,
6    positional_prompt: Option<&String>,
7    input: Option<&Path>,
8    prompt: Option<&String>,
9    max_tokens: usize,
10    stream: bool,
11    language: Option<&str>,
12    task: Option<&str>,
13    format: &str,
14    no_gpu: bool,
15    offline: bool,
16    benchmark: bool,
17    verbose: bool,
18    trace: bool,
19    trace_payload: bool,
20    trace_steps: Option<&[String]>,
21    trace_verbose: bool,
22    trace_output: Option<PathBuf>,
23    trace_level: &str,
24    profile: bool,
25    chat: bool,
26) -> Result<(), CliError> {
27    let effective_trace = trace || trace_payload;
28    let effective_trace_level = if trace_payload {
29        "payload"
30    } else {
31        trace_level
32    };
33    let merged_prompt = prompt.or(positional_prompt).cloned();
34    let effective_prompt = if chat {
35        merged_prompt
36            .as_ref()
37            .map(|p| format!("<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"))
38    } else {
39        merged_prompt
40    };
41
42    run::run(
43        source,
44        input,
45        effective_prompt.as_deref(),
46        max_tokens,
47        stream,
48        language,
49        task,
50        format,
51        no_gpu,
52        offline,
53        benchmark,
54        verbose,
55        effective_trace,
56        trace_steps,
57        trace_verbose,
58        trace_output,
59        effective_trace_level,
60        profile,
61    )
62}
63
64/// Build server config and launch serve.
65#[allow(clippy::too_many_arguments)]
66fn dispatch_serve(
67    file: &Path,
68    port: u16,
69    host: &str,
70    no_cors: bool,
71    no_metrics: bool,
72    no_gpu: bool,
73    gpu: bool,
74    batch: bool,
75    trace: bool,
76    trace_level: &str,
77    profile: bool,
78    verbose: bool,
79) -> Result<(), CliError> {
80    let config = serve::ServerConfig {
81        port,
82        host: host.to_owned(),
83        cors: !no_cors,
84        metrics: !no_metrics,
85        no_gpu,
86        gpu,
87        batch,
88        trace,
89        trace_level: trace_level.to_owned(),
90        profile,
91        verbose,
92        ..Default::default()
93    };
94    serve::run(file, &config)
95}
96
97/// Parse hex offset and run hex inspection.
98#[allow(clippy::too_many_arguments)]
99fn dispatch_hex(
100    file: &Path,
101    tensor: Option<&str>,
102    limit: usize,
103    stats: bool,
104    list: bool,
105    json: bool,
106    header: bool,
107    blocks: bool,
108    distribution: bool,
109    contract: bool,
110    entropy: bool,
111    raw: bool,
112    offset: &str,
113    width: usize,
114    slice: Option<&str>,
115) -> Result<(), CliError> {
116    let parsed_offset = hex::parse_hex_offset(offset).map_err(CliError::InvalidFormat)?;
117    hex::run(&hex::HexOptions {
118        file: file.to_path_buf(),
119        tensor: tensor.map(String::from),
120        limit,
121        stats,
122        list,
123        json,
124        header,
125        blocks,
126        distribution,
127        contract,
128        entropy,
129        raw,
130        offset: parsed_offset,
131        width,
132        slice: slice.map(String::from),
133    })
134}
135
136/// Dispatch a rosetta subcommand.
137fn dispatch_rosetta(action: &RosettaCommands, global_json: bool) -> Result<(), CliError> {
138    match action {
139        RosettaCommands::Inspect {
140            file,
141            hexdump,
142            json,
143        } => rosetta::run_inspect(file, *hexdump, *json || global_json),
144        RosettaCommands::Convert {
145            source,
146            target,
147            quantize,
148            verify,
149            json,
150            tokenizer,
151        } => rosetta::run_convert(
152            source,
153            target,
154            quantize.as_deref(),
155            *verify,
156            *json || global_json,
157            tokenizer.as_deref(),
158        ),
159        RosettaCommands::Chain {
160            source,
161            formats,
162            work_dir,
163            json,
164        } => rosetta::run_chain(source, formats, work_dir, *json || global_json),
165        RosettaCommands::Verify {
166            source,
167            intermediate,
168            tolerance,
169            json,
170        } => rosetta::run_verify(source, intermediate, *tolerance, *json || global_json),
171        RosettaCommands::CompareInference {
172            model_a,
173            model_b,
174            prompt,
175            max_tokens,
176            temperature,
177            tolerance,
178            json,
179        } => rosetta::run_compare_inference(
180            model_a,
181            model_b,
182            prompt,
183            *max_tokens,
184            *temperature,
185            *tolerance,
186            *json || global_json,
187        ),
188        RosettaCommands::DiffTensors {
189            model_a,
190            model_b,
191            mismatches_only,
192            show_values,
193            filter,
194            json,
195        } => rosetta::run_diff_tensors(
196            model_a,
197            model_b,
198            *mismatches_only,
199            *show_values,
200            filter.as_deref(),
201            *json || global_json,
202        ),
203        RosettaCommands::Fingerprint {
204            model,
205            model_b,
206            output,
207            filter,
208            verbose,
209            json,
210        } => rosetta::run_fingerprint(
211            model,
212            model_b.as_ref().map(std::path::PathBuf::as_path),
213            output.as_ref().map(std::path::PathBuf::as_path),
214            filter.as_deref(),
215            *verbose,
216            *json || global_json,
217        ),
218        RosettaCommands::ValidateStats {
219            model,
220            reference,
221            fingerprints,
222            threshold,
223            strict,
224            json,
225        } => rosetta::run_validate_stats(
226            model,
227            reference.as_ref().map(std::path::PathBuf::as_path),
228            fingerprints.as_ref().map(std::path::PathBuf::as_path),
229            *threshold,
230            *strict,
231            *json || global_json,
232        ),
233    }
234}
235
236/// Execute the CLI command and return the result.
237pub fn execute_command(cli: &Cli) -> Result<(), CliError> {
238    // PMAT-237: Contract gate — refuse to operate on corrupt models
239    if !cli.skip_contract {
240        let paths = extract_model_paths(&cli.command);
241        validate_model_contract(&paths)?;
242    }
243
244    dispatch_core_command(cli).unwrap_or_else(|| dispatch_extended_command(cli))
245}