use crate::error::CliError;
use aprender::format::model_card::ModelCard;
#[cfg(feature = "hf-hub")]
use aprender::hf_hub::{HfHubClient, PushOptions, UploadProgress};
use std::fs;
use std::path::Path;
#[cfg(feature = "hf-hub")]
use std::sync::Arc;
/// Validate inputs for a directory-mode (non-manifest) publish.
///
/// Checks that `repo_id` is `org/repo-name` with exactly two non-empty
/// segments, that `directory` exists, and that it contains at least one
/// model file (`.apr`, `.safetensors`, or `.gguf`).
///
/// # Errors
/// Returns `CliError::ValidationFailed` for a malformed repo ID or when no
/// model files are found, and `CliError::FileNotFound` when `directory`
/// does not exist.
fn validate_publish_inputs(
    directory: &Path,
    repo_id: &str,
) -> Result<Vec<std::path::PathBuf>, CliError> {
    // Require exactly two non-empty segments: the old `count() == 2` check
    // accepted "org/" and "/repo", which are not valid Hub repo IDs.
    let mut segments = repo_id.split('/');
    let well_formed = matches!(
        (segments.next(), segments.next(), segments.next()),
        (Some(org), Some(name), None) if !org.is_empty() && !name.is_empty()
    );
    if !well_formed {
        return Err(CliError::ValidationFailed(format!(
            "Invalid repo ID '{}'. Expected format: org/repo-name",
            repo_id
        )));
    }
    if !directory.exists() {
        return Err(CliError::FileNotFound(directory.to_path_buf()));
    }
    let files = find_model_files(directory)?;
    if files.is_empty() {
        return Err(CliError::ValidationFailed(format!(
            "No model files found in {}. Expected .apr, .safetensors, or .gguf files.",
            directory.display()
        )));
    }
    Ok(files)
}
/// Describe which upload route a file of `size_bytes` would take, for the
/// dry-run listing. Flags a would-be failure when a large file needs the
/// Xet backend but the binary was built without the `xet` feature.
#[cfg(feature = "hf-hub")]
fn format_upload_route(size_bytes: u64) -> &'static str {
    let needs_xet = aprender::hf_hub::xet::should_use_xet(size_bytes);
    match (needs_xet, cfg!(feature = "xet")) {
        (true, true) => "[→ Xet CAS (>5 GiB)]",
        (true, false) => "[✗ would FAIL: rebuild with --features xet]",
        (false, _) => "[→ HTTP LFS (≤5 GiB)]",
    }
}
/// Fallback when the `hf-hub` feature is compiled out: the route cannot be
/// determined, so the dry-run listing shows a placeholder.
#[cfg(not(feature = "hf-hub"))]
fn format_upload_route(_size: u64) -> &'static str {
    "[? hf-hub feature off]"
}
/// Upload primary `files`, then `extra_files`, then either the manifest
/// (as `manifest.yaml`) or the generated README to `repo_id`.
///
/// Per-file progress is printed only when `verbose` is set.
///
/// # Errors
/// `ValidationFailed` for paths without a file name; `NetworkError` when
/// any push to the Hub fails.
#[cfg(feature = "hf-hub")]
#[allow(clippy::too_many_arguments)]
fn upload_to_hub_extended(
    client: &HfHubClient,
    repo_id: &str,
    files: &[std::path::PathBuf],
    readme_content: &str,
    manifest: Option<&Path>,
    extra_files: &[std::path::PathBuf],
    commit_msg: &str,
    verbose: bool,
) -> Result<(), CliError> {
    // Shared progress reporter; a no-op unless verbose mode is on.
    let on_progress: Arc<dyn Fn(UploadProgress) + Send + Sync> = Arc::new(move |p| {
        if verbose {
            println!(
                " [{}/{}] {} ({:.1}%)",
                p.files_completed + 1,
                p.total_files,
                p.current_file,
                p.percentage()
            );
        }
    });
    // Read a local file and push it to the Hub under `path_in_repo`.
    let push_file = |src: &Path, path_in_repo: &str| -> Result<(), CliError> {
        if verbose {
            let size = fs::metadata(src).map_or(0, |m| m.len());
            println!(
                "Uploading {} ({:.1} MB)...",
                path_in_repo,
                size as f64 / 1_000_000.0
            );
        }
        let bytes = fs::read(src)?;
        let opts = PushOptions::new()
            .with_filename(path_in_repo.to_string())
            .with_commit_message(commit_msg)
            .with_progress_callback(on_progress.clone())
            .with_create_repo(true);
        client
            .push_to_hub(repo_id, &bytes, opts)
            .map_err(|e| CliError::NetworkError(format!("Upload failed: {e}")))?;
        Ok(())
    };
    for src in files {
        let name = src
            .file_name()
            .ok_or_else(|| CliError::ValidationFailed("Invalid file path".into()))?
            .to_string_lossy()
            .into_owned();
        push_file(src, &name)?;
    }
    for src in extra_files {
        let name = src
            .file_name()
            .ok_or_else(|| CliError::ValidationFailed("Invalid extra-file path".into()))?
            .to_string_lossy()
            .into_owned();
        push_file(src, &name)?;
    }
    match manifest {
        // A manifest supersedes the generated README as provenance.
        Some(manifest_path) => push_file(manifest_path, "manifest.yaml")?,
        None => {
            if verbose {
                println!("Uploading README.md...");
            }
            // The README push does not request repo creation.
            let readme_opts = PushOptions::new()
                .with_filename("README.md")
                .with_commit_message(commit_msg)
                .with_create_repo(false);
            client
                .push_to_hub(repo_id, readme_content.as_bytes(), readme_opts)
                .map_err(|e| CliError::NetworkError(format!("README upload failed: {e}")))?;
        }
    }
    Ok(())
}
/// Legacy upload path: pushes `files` plus the generated README only (no
/// manifest or extra-file support). Kept for reference; superseded by
/// `upload_to_hub_extended`.
///
/// # Errors
/// `ValidationFailed` for paths without a file name; `NetworkError` when
/// any push to the Hub fails.
#[cfg(feature = "hf-hub")]
#[allow(dead_code)]
fn upload_to_hub(
    client: &HfHubClient,
    repo_id: &str,
    files: &[std::path::PathBuf],
    readme_content: &str,
    commit_msg: &str,
    verbose: bool,
) -> Result<(), CliError> {
    // Progress reporting is silent unless verbose mode is enabled.
    let on_progress: Arc<dyn Fn(UploadProgress) + Send + Sync> = Arc::new(move |p| {
        if verbose {
            println!(
                " [{}/{}] {} ({:.1}%)",
                p.files_completed + 1,
                p.total_files,
                p.current_file,
                p.percentage()
            );
        }
    });
    for src in files {
        let name = src
            .file_name()
            .ok_or_else(|| CliError::ValidationFailed("Invalid file path".into()))?
            .to_string_lossy()
            .into_owned();
        if verbose {
            let size = fs::metadata(src).map_or(0, |m| m.len());
            println!(
                "Uploading {} ({:.1} MB)...",
                name,
                size as f64 / 1_000_000.0
            );
        }
        let bytes = fs::read(src)?;
        let opts = PushOptions::new()
            .with_filename(name)
            .with_commit_message(commit_msg)
            .with_progress_callback(on_progress.clone())
            .with_create_repo(true);
        client
            .push_to_hub(repo_id, &bytes, opts)
            .map_err(|e| CliError::NetworkError(format!("Upload failed: {e}")))?;
    }
    if verbose {
        println!("Uploading README.md...");
    }
    // The README push does not request repo creation.
    let readme_opts = PushOptions::new()
        .with_filename("README.md")
        .with_commit_message(commit_msg)
        .with_create_repo(false);
    client
        .push_to_hub(repo_id, readme_content.as_bytes(), readme_opts)
        .map_err(|e| CliError::NetworkError(format!("README upload failed: {e}")))?;
    Ok(())
}
/// Publish model artifacts from `directory` to the Hugging Face Hub repo
/// `repo_id` (`org/repo-name`).
///
/// Two modes:
/// - Manifest mode (`manifest = Some(..)`): the manifest is preflighted
///   against the local artifact (sha256 / size) and is uploaded as
///   provenance in place of a generated README.
/// - Directory mode: all `.apr`/`.safetensors`/`.gguf` files found in
///   `directory` are uploaded together with a generated README model card.
///
/// `dry_run` prints the upload plan and the generated README without any
/// network activity. When built without the `hf-hub` feature, any
/// non-dry-run invocation fails with rebuild guidance.
///
/// # Errors
/// `ValidationFailed` for malformed repo IDs, missing model files, manifest
/// mismatches, or missing authentication; `FileNotFound` for missing paths;
/// `NetworkError` for upload failures.
#[provable_contracts_macros::contract(
    "apr-cli-command-safety-v1",
    equation = "mutating_output_contract"
)]
#[allow(clippy::too_many_arguments)]
pub fn execute(
    directory: &Path,
    repo_id: &str,
    model_name: Option<&str>,
    license: &str,
    pipeline_tag: &str,
    library_name: Option<&str>,
    tags: &[String],
    commit_message: Option<&str>,
    dry_run: bool,
    verbose: bool,
    manifest: Option<&Path>,
    extra_files: &[std::path::PathBuf],
) -> Result<(), CliError> {
    // Validate repo ID and directory up front for BOTH modes. Previously the
    // manifest path ran the (possibly expensive) manifest preflight first and
    // duplicated these checks afterwards, so a bad repo ID was masked by any
    // manifest error. Both segments must be non-empty ("org/" is invalid).
    let mut segments = repo_id.split('/');
    let repo_ok = matches!(
        (segments.next(), segments.next(), segments.next()),
        (Some(org), Some(name), None) if !org.is_empty() && !name.is_empty()
    );
    if !repo_ok {
        return Err(CliError::ValidationFailed(format!(
            "Invalid repo ID '{}'. Expected format: org/repo-name",
            repo_id
        )));
    }
    if !directory.exists() {
        return Err(CliError::FileNotFound(directory.to_path_buf()));
    }
    let files = if let Some(manifest_path) = manifest {
        // Manifest mode uploads exactly the artifact the manifest declares.
        vec![preflight_manifest_guard(manifest_path, directory)?]
    } else {
        validate_publish_inputs(directory, repo_id)?
    };
    if verbose {
        println!("Uploading {} primary artifact(s):", files.len());
        for f in &files {
            println!(" - {}", f.display());
        }
    }
    // Fail fast if any explicitly listed extra file is missing.
    for ef in extra_files {
        if !ef.exists() {
            return Err(CliError::FileNotFound(ef.clone()));
        }
    }
    let (model_card, file_names) = generate_model_card(
        repo_id,
        model_name,
        license,
        pipeline_tag,
        library_name,
        tags,
        &files,
    );
    let readme_content =
        model_card.to_huggingface_extended(pipeline_tag, library_name, tags, &file_names);
    if dry_run {
        println!("=== DRY RUN: Would publish to {} ===\n", repo_id);
        println!("Files to upload:");
        for f in &files {
            let size = fs::metadata(f).map(|m| m.len()).unwrap_or(0);
            println!(
                " - {} ({:.1} MB) {}",
                f.display(),
                size as f64 / 1_000_000.0,
                format_upload_route(size)
            );
        }
        for ef in extra_files {
            let size = fs::metadata(ef).map(|m| m.len()).unwrap_or(0);
            println!(
                " - {} ({:.1} MB) [extra-file] {}",
                ef.display(),
                size as f64 / 1_000_000.0,
                format_upload_route(size)
            );
        }
        if let Some(m) = manifest {
            let size = fs::metadata(m).map(|meta| meta.len()).unwrap_or(0);
            println!(
                " - {} ({:.1} KB) [manifest]",
                m.display(),
                size as f64 / 1_000.0
            );
            println!("\n(README.md auto-generation suppressed: manifest provides provenance)");
        } else {
            println!("\nGenerated README.md:\n");
            println!("{}", readme_content);
        }
        println!("\n=== DRY RUN COMPLETE ===");
        return Ok(());
    }
    #[cfg(not(feature = "hf-hub"))]
    {
        // Silence unused-variable warnings in the feature-off build.
        let _ = (commit_message, verbose, manifest, extra_files);
        return Err(CliError::ValidationFailed(
            "Publishing requires the 'hf-hub' feature. Rebuild with: \
             cargo install --path crates/apr-cli --features hf-hub"
                .to_string(),
        ));
    }
    #[cfg(feature = "hf-hub")]
    {
        let client = HfHubClient::new().map_err(|e| {
            CliError::ValidationFailed(format!("Failed to create HF Hub client: {e}"))
        })?;
        if !client.is_authenticated() {
            return Err(CliError::ValidationFailed(
                "HF_TOKEN environment variable not set. Set it with: export HF_TOKEN=hf_...".into(),
            ));
        }
        let commit_msg = commit_message.unwrap_or("Upload via apr-cli publish");
        println!("Publishing to https://huggingface.co/{}", repo_id);
        // Estimated total: primary files + extras + manifest, plus the
        // generated README only when no manifest supersedes it.
        let extras_size: u64 = extra_files
            .iter()
            .map(|f| fs::metadata(f).map(|m| m.len()).unwrap_or(0))
            .sum();
        let manifest_size: u64 = manifest
            .map(|m| fs::metadata(m).map(|meta| meta.len()).unwrap_or(0))
            .unwrap_or(0);
        let total_size: u64 = files
            .iter()
            .map(|f| fs::metadata(f).map(|m| m.len()).unwrap_or(0))
            .sum::<u64>()
            + extras_size
            + manifest_size
            + if manifest.is_some() {
                0
            } else {
                readme_content.len() as u64
            };
        println!(
            "Total upload size: {:.1} MB",
            total_size as f64 / 1_000_000.0
        );
        upload_to_hub_extended(
            &client,
            repo_id,
            &files,
            &readme_content,
            manifest,
            extra_files,
            commit_msg,
            verbose,
        )?;
        println!("\n✓ Published to https://huggingface.co/{}", repo_id);
        Ok(())
    }
}
/// Preflight a publish manifest against the local artifact it describes.
///
/// Parses the manifest YAML, requires the `sha256` and `artifact_url`
/// fields, locates the artifact by basename inside `directory`, and
/// verifies the local file's SHA-256 (and `size_bytes`, when declared)
/// against the manifest. Returns the local artifact path on success.
///
/// # Errors
/// `ValidationFailed` for an unreadable or invalid manifest, a missing
/// local artifact, or a hash/size mismatch.
fn preflight_manifest_guard(
    manifest_path: &Path,
    directory: &Path,
) -> Result<std::path::PathBuf, CliError> {
    let manifest_src = fs::read_to_string(manifest_path).map_err(|e| {
        CliError::ValidationFailed(format!(
            "Cannot read manifest {}: {e}",
            manifest_path.display()
        ))
    })?;
    let parsed: serde_yaml::Value = serde_yaml::from_str(&manifest_src)
        .map_err(|e| CliError::ValidationFailed(format!("Manifest YAML parse error: {e}")))?;
    let declared_sha = parsed
        .get("sha256")
        .and_then(|v| v.as_str())
        .ok_or_else(|| {
            CliError::ValidationFailed("Manifest missing required field: sha256".into())
        })?;
    let artifact_url = parsed
        .get("artifact_url")
        .and_then(|v| v.as_str())
        .ok_or_else(|| {
            CliError::ValidationFailed("Manifest missing required field: artifact_url".into())
        })?;
    let declared_size = parsed.get("size_bytes").and_then(serde_yaml::Value::as_u64);
    // `rsplit('/').next()` always yields at least one (possibly empty)
    // segment, so the previous `ok_or_else` was unreachable. Reject an empty
    // basename explicitly (artifact_url ending in '/'), which would otherwise
    // make `directory.join("")` point at the directory itself.
    let artifact_basename = artifact_url
        .rsplit('/')
        .next()
        .filter(|name| !name.is_empty())
        .ok_or_else(|| CliError::ValidationFailed("artifact_url has no basename".into()))?;
    let local_artifact = directory.join(artifact_basename);
    if !local_artifact.exists() {
        return Err(CliError::ValidationFailed(format!(
            "Manifest-declared artifact not found locally: {}",
            local_artifact.display()
        )));
    }
    let computed_sha = stream_sha256(&local_artifact)?;
    // Compare hex digests case-insensitively: manifests produced by other
    // tools may encode the same digest in uppercase hex.
    if !computed_sha.eq_ignore_ascii_case(declared_sha) {
        return Err(CliError::ValidationFailed(format!(
            "sha256 mismatch — manifest-declared vs local artifact.\n \
             manifest: {declared_sha}\n \
             local: {computed_sha}\n \
             file: {}",
            local_artifact.display()
        )));
    }
    if let Some(expected) = declared_size {
        let actual = fs::metadata(&local_artifact).map(|m| m.len()).unwrap_or(0);
        if expected != actual {
            return Err(CliError::ValidationFailed(format!(
                "size_bytes mismatch — manifest {expected}, local {actual}"
            )));
        }
    }
    Ok(local_artifact)
}
/// Compute the lowercase hex SHA-256 of a file, reading it in 64 KiB
/// chunks so large artifacts never need to fit in memory.
///
/// # Errors
/// Propagates I/O errors from opening or reading the file.
fn stream_sha256(path: &Path) -> Result<String, CliError> {
    use sha2::{Digest, Sha256};
    use std::io::Read;
    let mut file = fs::File::open(path)?;
    let mut digest = Sha256::new();
    let mut chunk = vec![0u8; 65_536];
    loop {
        match file.read(&mut chunk)? {
            // A zero-byte read signals end of file.
            0 => break,
            n => digest.update(&chunk[..n]),
        }
    }
    Ok(format!("{:x}", digest.finalize()))
}
/// Collect model artifacts directly inside `directory` (non-recursive):
/// files whose extension is `.apr`, `.safetensors`, or `.gguf`
/// (case-insensitive), sorted by path for deterministic output.
///
/// # Errors
/// Propagates the I/O error if `directory` cannot be read; unreadable
/// individual entries are skipped.
fn find_model_files(directory: &Path) -> Result<Vec<std::path::PathBuf>, CliError> {
    const MODEL_EXTENSIONS: [&str; 3] = ["apr", "safetensors", "gguf"];
    let mut found: Vec<std::path::PathBuf> = fs::read_dir(directory)?
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| path.is_file())
        .filter(|path| {
            path.extension()
                .map(|ext| ext.to_string_lossy().to_lowercase())
                .map_or(false, |ext| MODEL_EXTENSIONS.contains(&ext.as_str()))
        })
        .collect();
    found.sort();
    Ok(found)
}
/// Build a minimal `ModelCard` plus the list of artifact file names.
///
/// The display name falls back to the trailing segment of `repo_id` when
/// `model_name` is not given. The underscore-prefixed parameters are kept
/// for signature compatibility; they are consumed by the README renderer,
/// not by the card itself.
fn generate_model_card(
    repo_id: &str,
    model_name: Option<&str>,
    license: &str,
    _pipeline_tag: &str,
    _library_name: Option<&str>,
    _tags: &[String],
    files: &[std::path::PathBuf],
) -> (ModelCard, Vec<String>) {
    let display_name = match model_name {
        Some(explicit) => explicit,
        None => repo_id.split('/').next_back().unwrap_or(repo_id),
    };
    let mut file_names = Vec::with_capacity(files.len());
    for path in files {
        // Entries without a final component are silently skipped.
        if let Some(os_name) = path.file_name() {
            file_names.push(os_name.to_string_lossy().into_owned());
        }
    }
    let card = ModelCard::new(repo_id, "1.0.0")
        .with_name(display_name)
        .with_license(license)
        .with_description(format!("{} model published via aprender", display_name));
    (card, file_names)
}
/// Extension trait adding extended HuggingFace README rendering to
/// [`ModelCard`] without modifying the upstream type.
trait ModelCardExt {
    /// Render the card as README.md content: YAML front matter (license,
    /// `pipeline_tag`, optional `library_name`, tags, model-index)
    /// followed by markdown sections listing `file_names` and usage notes.
    fn to_huggingface_extended(
        &self,
        pipeline_tag: &str,
        library_name: Option<&str>,
        extra_tags: &[String],
        file_names: &[String],
    ) -> String;
}
impl ModelCardExt for ModelCard {
    /// Render the card as a HuggingFace README: YAML front matter (license,
    /// pipeline tag, deduplicated tags, model-index) followed by markdown
    /// sections for available formats, usage, framework, and citation.
    fn to_huggingface_extended(
        &self,
        pipeline_tag: &str,
        library_name: Option<&str>,
        extra_tags: &[String],
        file_names: &[String],
    ) -> String {
        use std::fmt::Write;
        let mut output = String::from("---\n");
        if let Some(license) = &self.license {
            let _ = writeln!(output, "license: {}", license.to_lowercase());
        }
        // ASR models get a default language block.
        if pipeline_tag == "automatic-speech-recognition" {
            output.push_str("language:\n");
            output.push_str(" - en\n");
            output.push_str(" - multilingual\n");
        }
        let _ = writeln!(output, "pipeline_tag: {}", pipeline_tag);
        if let Some(lib) = library_name {
            let _ = writeln!(output, "library_name: {}", lib);
        }
        output.push_str("tags:\n");
        // Deduplicate across every tag source. Fix: the architecture tag is
        // now recorded too, so an extra tag equal to the architecture — or an
        // architecture literally named "rust"/"aprender" — is emitted once.
        let mut seen_tags: std::collections::HashSet<String> = std::collections::HashSet::new();
        if let Some(arch) = &self.architecture {
            let arch_tag = arch.to_lowercase();
            if seen_tags.insert(arch_tag.clone()) {
                let _ = writeln!(output, " - {}", arch_tag);
            }
        }
        for base_tag in ["aprender", "rust"] {
            if seen_tags.insert(base_tag.to_string()) {
                let _ = writeln!(output, " - {}", base_tag);
            }
        }
        if pipeline_tag == "automatic-speech-recognition" {
            for task_tag in ["speech-recognition", "audio"] {
                if seen_tags.insert(task_tag.to_string()) {
                    let _ = writeln!(output, " - {}", task_tag);
                }
            }
        }
        for tag in extra_tags {
            if seen_tags.insert(tag.clone()) {
                let _ = writeln!(output, " - {}", tag);
            }
        }
        output.push_str("model-index:\n");
        let _ = writeln!(output, " - name: {}", self.model_id);
        output.push_str(" results:\n");
        output.push_str(" - task:\n");
        let _ = writeln!(output, " type: {}", pipeline_tag);
        output.push_str(" dataset:\n");
        output.push_str(" name: custom\n");
        output.push_str(" type: custom\n");
        output.push_str(" metrics:\n");
        if self.metrics.is_empty() {
            // Placeholder entry so the model-index block is never empty.
            output.push_str(" - name: accuracy\n");
            output.push_str(" type: custom\n");
            output.push_str(" value: N/A\n");
        } else {
            for (key, value) in &self.metrics {
                let _ = writeln!(output, " - name: {}", key);
                output.push_str(" type: custom\n");
                let _ = writeln!(output, " value: {}", value);
            }
        }
        output.push_str("---\n\n");
        let _ = writeln!(output, "# {}\n", self.name);
        if let Some(desc) = &self.description {
            let _ = writeln!(output, "{}\n", desc);
        }
        output.push_str("## Available Formats\n\n");
        output.push_str("| Format | Description |\n");
        output.push_str("|--------|-------------|\n");
        if file_names.is_empty() {
            // No artifacts known: document the default native format.
            output.push_str("| `model.apr` | Native APR format (streaming, WASM-optimized) |\n");
        } else {
            for name in file_names {
                let desc = match std::path::Path::new(name)
                    .extension()
                    .and_then(|e| e.to_str())
                {
                    Some("apr") => "Native APR format (streaming, WASM-optimized)",
                    Some("safetensors") => "HuggingFace SafeTensors format",
                    Some("gguf") => "GGUF format (llama.cpp compatible)",
                    Some("bin") | Some("pt") | Some("pth") => "PyTorch binary format",
                    _ => "Model file",
                };
                let _ = writeln!(output, "| `{}` | {} |", name, desc);
            }
        }
        output.push('\n');
        output.push_str("## Usage\n\n");
        output.push_str("```rust\n");
        output.push_str("use aprender::Model;\n");
        output.push('\n');
        output.push_str("let model = Model::load(\"model.apr\")?;\n");
        output.push_str("let result = model.run(&input)?;\n");
        output.push_str("```\n\n");
        output.push_str("## Framework\n\n");
        let _ = writeln!(output, "- **Version:** {}", self.framework_version);
        if let Some(rust) = &self.rust_version {
            let _ = writeln!(output, "- **Rust:** {}", rust);
        }
        output.push('\n');
        output.push_str("## Citation\n\n");
        output.push_str("```bibtex\n");
        output.push_str("@software{aprender,\n");
        output.push_str(" title = {aprender: Rust ML Library},\n");
        output.push_str(" author = {PAIML},\n");
        output.push_str(" year = {2025},\n");
        output.push_str(" url = {https://github.com/paiml/aprender}\n");
        output.push_str("}\n");
        output.push_str("```\n");
        output
    }
}
#[cfg(test)]
#[path = "publish_tests.rs"]
mod tests;