use crate::error::{CliError, Result};
use crate::output;
use aprender::format::{apr_import, Architecture, ImportOptions, Source, ValidationConfig};
use colored::Colorize;
use std::path::{Path, PathBuf};
/// Imports a model from `source` into the APR format.
///
/// Pipeline: reject PyTorch pickle checkpoints, optionally enforce
/// single-provenance (rejecting GGUF), resolve the output path, then either
/// take the Q4K-preserving fast path (local GGUF + `inference` feature) or
/// run the standard `apr_import` pipeline with the requested options.
///
/// # Errors
/// Returns `CliError::ValidationFailed` for rejected sources, unparsable
/// source strings, or a failed import/validation run.
#[provable_contracts_macros::contract("apr-cli-safety-v1", equation = "offline_guard")]
pub(crate) fn run(
    source: &str,
    output: Option<&Path>,
    arch: Option<&str>,
    quantize: Option<&str>,
    strict: bool,
    preserve_q4k: bool,
    tokenizer: Option<&PathBuf>,
    enforce_provenance: bool,
    allow_no_config: bool,
) -> Result<()> {
    // Contract precondition macro (expands via provable_contracts) — runs first.
    contract_pre_format_conversion_roundtrip!();
    reject_pytorch_format(source)?;
    check_provenance(source, enforce_provenance)?;
    // Explicit --output wins; otherwise derive "<stem>.apr" from the source.
    let output_path = match output {
        Some(p) => p.to_path_buf(),
        None => derive_output_path(source)?,
    };
    let output = output_path.as_path();
    if preserve_q4k {
        // PMAT-103 made Q4K preservation the default; the flag is kept only for
        // CLI backward compatibility and to select the fast path below.
        eprintln!(
            " {} --preserve-q4k is now the default for GGUF imports (PMAT-103). Flag has no additional effect.",
            output::badge_warn("NOTE")
        );
    }
    // Fast path: an existing local .gguf file keeps Q4K blocks intact
    // (no dequantization) via the fused-kernel converter.
    #[cfg(feature = "inference")]
    if preserve_q4k {
        let source_path = std::path::Path::new(source);
        if source_path.exists()
            && source_path
                .extension()
                .is_some_and(|ext| ext.eq_ignore_ascii_case("gguf"))
        {
            return run_q4k_import(source_path, output);
        }
    }
    // Without the 'inference' feature the fast path is unavailable: warn and
    // fall through to the standard (dequantizing) import instead.
    #[cfg(not(feature = "inference"))]
    if preserve_q4k {
        eprintln!(
            " {} --preserve-q4k requires the 'inference' feature. \
            Falling back to standard import (Q4K will be dequantized to F32).",
            output::badge_warn("WARN")
        );
    }
    let parsed_source = Source::parse(source)
        .map_err(|e| CliError::ValidationFailed(format!("Invalid source: {e}")))?;
    output::header("APR Import Pipeline");
    let source_desc = describe_source(&parsed_source);
    println!(
        "{}",
        output::kv_table(&[
            ("Source", source_desc),
            ("Output", output.display().to_string()),
        ])
    );
    println!();
    let architecture = parse_architecture(arch)?;
    // --strict upgrades validation and is also forwarded as its own option.
    let options = ImportOptions {
        architecture,
        validation: if strict {
            ValidationConfig::Strict
        } else {
            ValidationConfig::Basic
        },
        quantize: parse_quantize(quantize)?,
        compress: None,
        strict,
        cache: true,
        tokenizer_path: tokenizer.cloned(),
        allow_no_config,
    };
    print_import_config(&options);
    output::pipeline_stage("Importing", output::StageStatus::Running);
    print_import_result(apr_import(source, output, options))
}
/// Enforces single-provenance imports (F-GT-001) when `enforce` is set.
///
/// Rejects any source that looks like GGUF — by `.gguf` extension or by the
/// HuggingFace `-GGUF` repo-name convention — since SafeTensors is the
/// canonical source format for single-provenance testing.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when enforcement is on and the source
/// appears to be GGUF.
fn check_provenance(source: &str, enforce: bool) -> Result<()> {
    if !enforce {
        return Ok(());
    }
    // Bug fix: detect the GGUF marker case-insensitively. The previous check
    // lowercased only the extension test and matched just the literal
    // "-GGUF"/"-gguf" variants, silently missing mixed-case names like "-Gguf".
    let lowered = source.to_ascii_lowercase();
    let is_gguf = lowered.ends_with(".gguf") || lowered.contains("-gguf");
    if is_gguf {
        return Err(CliError::ValidationFailed(
            "F-GT-001: --enforce-provenance rejects GGUF imports. \
            Use SafeTensors as the canonical source format for single-provenance testing. \
            See Section 0 of qwen2.5-coder-showcase-demo.md for rationale."
                .to_string(),
        ));
    }
    Ok(())
}
/// Renders a parsed [`Source`] as a human-readable string for the import table.
fn describe_source(source: &Source) -> String {
    match source {
        // HuggingFace sources render as "hf://org/repo[/file]".
        Source::HuggingFace { org, repo, file } => match file {
            Some(f) => format!("hf://{org}/{repo}/{f}"),
            None => format!("hf://{org}/{repo}"),
        },
        Source::Local(path) => path.display().to_string(),
        Source::Url(url) => url.clone(),
    }
}
/// Maps a CLI architecture flag to an [`Architecture`] variant.
///
/// Accepts several aliases per family (e.g. `starcoder`/`bigcode` -> GPT-2,
/// `phi3`/`phi4` -> Phi, `gemma*` -> the Llama-compatible loader). `None` or
/// `"auto"` requests auto-detection.
///
/// # Errors
/// Returns `CliError::ValidationFailed` for recognized-but-unsupported
/// architectures (falcon, mamba, t5) and for unknown names.
fn parse_architecture(arch: Option<&str>) -> Result<Architecture> {
    match arch {
        Some("whisper") => Ok(Architecture::Whisper),
        Some("llama") => Ok(Architecture::Llama),
        Some("bert") => Ok(Architecture::Bert),
        Some("qwen2") => Ok(Architecture::Qwen2),
        Some("qwen3") => Ok(Architecture::Qwen3),
        Some("qwen3_5" | "qwen3.5") => Ok(Architecture::Qwen3_5),
        Some("gpt2" | "starcoder" | "bigcode") => Ok(Architecture::Gpt2),
        Some("gpt-neox" | "gpt_neox" | "pythia") => Ok(Architecture::GptNeoX),
        Some("opt" | "galactica") => Ok(Architecture::Opt),
        Some("phi" | "phi3" | "phi4") => Ok(Architecture::Phi),
        // Gemma checkpoints load through the Llama-compatible path.
        Some("gemma" | "gemma2" | "gemma3") => Ok(Architecture::Llama),
        Some(unsupported @ ("falcon" | "mamba" | "t5")) => Err(CliError::ValidationFailed(format!(
            "Architecture '{unsupported}' is not yet supported. Tracking: https://github.com/anthropics/aprender/issues"
        ))),
        Some("auto") | None => Ok(Architecture::Auto),
        // Bug fix: the previous message listed falcon, mamba, and t5 as
        // "Supported" even though the arm above explicitly rejects them.
        Some(other) => Err(CliError::ValidationFailed(format!(
            "Unknown architecture: {other}. Supported: whisper, llama, bert, qwen2, qwen3, qwen3_5, gpt2, starcoder, gpt-neox, opt, phi, gemma, auto"
        ))),
    }
}
/// Prints the resolved import configuration as a key/value table.
/// Architecture and validation are always shown; quantization only when set.
fn print_import_config(options: &ImportOptions) {
    let mut rows: Vec<(&str, String)> = vec![
        ("Architecture", format!("{:?}", options.architecture)),
        ("Validation", format!("{:?}", options.validation)),
    ];
    if let Some(q) = options.quantize.as_ref() {
        rows.push(("Quantization", format!("{q:?}")));
    }
    println!("{}", output::kv_table(&rows));
    println!();
}
/// Renders the outcome of `apr_import`: a validation report on success, or a
/// failure badge plus a `CliError::ValidationFailed` on error.
fn print_import_result(
    result: std::result::Result<aprender::format::ValidationReport, aprender::error::AprenderError>,
) -> Result<()> {
    // Handle the failure path up front so the happy path reads linearly.
    let report = match result {
        Ok(r) => r,
        Err(e) => {
            println!();
            println!(" {}", output::badge_fail("Import failed"));
            return Err(CliError::ValidationFailed(e.to_string()));
        }
    };
    println!();
    output::subheader("Validation Report");
    let grade = report.grade();
    println!(
        "{}",
        output::kv_table(&[
            ("Score", format!("{}/100", report.total_score)),
            ("Grade", output::grade_color(grade).to_string()),
        ])
    );
    println!();
    // A score of 95+ counts as a clean pass; anything lower still imports,
    // but with a warning badge.
    if report.passed(95) {
        println!(" {}", output::badge_pass("Import successful"));
    } else {
        println!(" {}", output::badge_warn("Import completed with warnings"));
    }
    Ok(())
}
/// Maps a CLI quantization flag to a [`QuantizationType`]; `None` means no
/// quantization was requested.
///
/// # Errors
/// Returns `CliError::ValidationFailed` for unrecognized quantization names.
fn parse_quantize(
    quantize: Option<&str>,
) -> Result<Option<aprender::format::converter::QuantizationType>> {
    use aprender::format::converter::QuantizationType;
    let Some(spec) = quantize else {
        return Ok(None);
    };
    let parsed = match spec {
        "int8" => QuantizationType::Int8,
        "int4" => QuantizationType::Int4,
        "fp16" => QuantizationType::Fp16,
        "q4k" | "q4_k" => QuantizationType::Q4K,
        other => {
            return Err(CliError::ValidationFailed(format!(
                "Unknown quantization: {other}. Supported: int8, int4, fp16, q4k"
            )))
        }
    };
    Ok(Some(parsed))
}
/// Converts a local GGUF file to APR while preserving Q4K quantization blocks
/// (fused-kernel fast path; compiled only with the `inference` feature).
///
/// # Errors
/// Returns `CliError::ValidationFailed` when the converter fails.
#[cfg(feature = "inference")]
fn run_q4k_import(source: &Path, output: &Path) -> Result<()> {
    use humansize::{format_size, BINARY};
    use realizar::convert::GgufToAprQ4KConverter;
    output::header("APR Q4K Import (Fused Kernel)");
    println!(
        "{}",
        output::kv_table(&[
            ("Source", format!("{} (GGUF)", source.display())),
            ("Output", format!("{} (APR with Q4K)", output.display())),
        ])
    );
    println!();
    output::pipeline_stage("Preserving Q4K quantization", output::StageStatus::Running);
    // Bail out early on conversion failure so the report code stays flat.
    let stats = match GgufToAprQ4KConverter::convert(source, output) {
        Ok(stats) => stats,
        Err(e) => {
            println!();
            println!(" {}", output::badge_fail("Q4K import failed"));
            return Err(CliError::ValidationFailed(e.to_string()));
        }
    };
    println!();
    output::subheader("Q4K Import Report");
    println!(
        "{}",
        output::kv_table(&[
            ("Total tensors", stats.tensor_count.to_string()),
            ("Q4K tensors", stats.q4k_tensor_count.to_string()),
            ("Total bytes", format_size(stats.total_bytes as u64, BINARY)),
            ("Architecture", stats.architecture.clone()),
            ("Layers", stats.num_layers.to_string()),
            ("Hidden size", stats.hidden_size.to_string()),
        ])
    );
    println!();
    println!(" {}", output::badge_pass("Q4K import successful"));
    println!(
        "{}",
        " Model ready for fused kernel inference (30+ tok/s CPU target)".dimmed()
    );
    Ok(())
}
/// Derives a default output path ("<stem>.apr") from the import source when
/// the user did not pass `--output`.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when no usable file stem can be
/// extracted from the source.
fn derive_output_path(source: &str) -> Result<PathBuf> {
    // If the source string doesn't parse, fall back to treating it as a raw
    // filesystem path.
    let Ok(parsed) = Source::parse(source) else {
        let stem = Path::new(source)
            .file_stem()
            .and_then(|s| s.to_str())
            .ok_or_else(|| {
                CliError::ValidationFailed(
                    "Cannot derive output name from source. Please specify --output.".into(),
                )
            })?;
        return Ok(PathBuf::from(format!("{stem}.apr")));
    };
    let base = match parsed {
        // Prefer the stem of the specific file; fall back to the repo name.
        Source::HuggingFace { org: _, repo, file } => match file {
            Some(f) => Path::new(&f)
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or(&repo)
                .to_string(),
            None => repo,
        },
        Source::Local(path) => path
            .file_stem()
            .and_then(|s| s.to_str())
            .map(str::to_string)
            .ok_or_else(|| {
                CliError::ValidationFailed("Cannot derive output name from source".into())
            })?,
        // Use the last URL segment; "model" covers trailing-slash/odd URLs.
        Source::Url(url) => {
            let last_segment = url.rsplit('/').next().unwrap_or("model");
            Path::new(last_segment)
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or("model")
                .to_string()
        }
    };
    Ok(PathBuf::from(format!("{base}.apr")))
}
/// Rejects PyTorch pickle checkpoints (GH-267), which cannot be parsed in
/// pure Rust.
///
/// Only `.bin`/`.pt` extensions are considered. Existing files are sniffed by
/// magic bytes (unreadable or short files pass through as best-effort);
/// nonexistent ones are rejected on extension alone.
///
/// # Errors
/// Returns the GH-267 conversion error when the source looks like a PyTorch
/// checkpoint.
fn reject_pytorch_format(source: &str) -> Result<()> {
    let path = Path::new(source);
    let suspicious_extension = path
        .extension()
        .is_some_and(|ext| ext.eq_ignore_ascii_case("bin") || ext.eq_ignore_ascii_case("pt"));
    if !suspicious_extension {
        return Ok(());
    }
    if !path.exists() {
        // Can't inspect the bytes (e.g. a remote ref); judge by extension.
        return Err(pytorch_conversion_error(source));
    }
    match read_magic_bytes(path) {
        Ok(magic) if is_pytorch_magic(&magic) => Err(pytorch_conversion_error(source)),
        // Non-PyTorch magic or an unreadable/short file: let the importer try.
        _ => Ok(()),
    }
}
/// Reads the first four bytes of `path` for format sniffing.
///
/// # Errors
/// Propagates I/O errors from open/read, including `UnexpectedEof` for files
/// shorter than four bytes.
fn read_magic_bytes(path: &Path) -> std::io::Result<[u8; 4]> {
    use std::io::Read;
    let mut magic = [0u8; 4];
    std::fs::File::open(path)?.read_exact(&mut magic)?;
    Ok(magic)
}
/// Returns `true` when the magic bytes look like a PyTorch checkpoint.
///
/// Two encodings are recognized: the ZIP local-file header ("PK\x03\x04",
/// used by modern `torch.save` archives) and a raw pickle stream starting
/// with the PROTO opcode (0x80) followed by a protocol number in 2..=5.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_pytorch_magic(magic: &[u8; 4]) -> bool {
    let zip_header = magic == b"PK\x03\x04";
    let pickle_proto = magic[0] == 0x80 && matches!(magic[1], 2..=5);
    zip_header || pickle_proto
}
/// Builds the GH-267 rejection error for PyTorch pickle checkpoints
/// (`.bin`/`.pt`), including step-by-step instructions for converting the
/// checkpoint to SafeTensors before re-running the import.
// NOTE(review): the suggested `huggingface-cli convert ... --to safetensors`
// invocation should be verified against the current huggingface-hub CLI —
// TODO confirm before relying on it in user-facing docs.
fn pytorch_conversion_error(source: &str) -> CliError {
    CliError::ValidationFailed(format!(
        "GH-267: '{source}' appears to be a PyTorch checkpoint (model.bin / .pt).\n\
        \n\
        PyTorch checkpoints use Python pickle format which cannot be parsed in pure Rust.\n\
        Convert to SafeTensors first using one of these methods:\n\
        \n\
        Method 1 (recommended): HuggingFace CLI\n\
        pip install huggingface-hub\n\
        huggingface-cli convert {source} --to safetensors\n\
        \n\
        Method 2: Python one-liner\n\
        pip install torch safetensors\n\
        python -c \"import torch; from safetensors.torch import save_file; \\\n\
        sd = torch.load('{source}', weights_only=True); \\\n\
        save_file(sd, '{source}'.replace('.bin', '.safetensors'))\"\n\
        \n\
        Then import the resulting .safetensors file:\n\
        apr import model.safetensors -o model.apr"
    ))
}
// Unit tests live in a sibling file (import_tests.rs) to keep this module
// focused on the import pipeline itself.
#[cfg(test)]
#[path = "import_tests.rs"]
mod tests;