use crate::cli::UI;
use crate::huggingface::{api_compat, cache::HfCache, client::HfClient, download, upload};
use crate::platform::server_registry::ServerRegistry;
use crate::redteam::RedteamBridge;
use anyhow::{bail, Context, Result};
use crate::cli::args::{HfCommands, HfPipelineCommands};
/// Dispatches a parsed `hf` subcommand to its dedicated handler.
///
/// Thin router: all real work (and all error context) lives in the
/// `execute_*` functions below.
pub async fn execute(action: HfCommands, ui: &UI) -> Result<()> {
    match action {
        HfCommands::Pull { model_id, revision, include, exclude } => {
            execute_pull(&model_id, &revision, include, exclude, ui).await
        }
        HfCommands::Push { path, repo, hf_org } => {
            execute_push(&path, &repo, hf_org.as_deref(), ui).await
        }
        HfCommands::Search { query, task, library, limit } => {
            execute_search(&query, task.as_deref(), library.as_deref(), limit, ui).await
        }
        HfCommands::Info { model_id } => execute_info(&model_id, ui).await,
        HfCommands::Scan { model_id } => execute_scan(&model_id, ui).await,
        // The two pipeline subcommands are matched with nested patterns
        // instead of a second `match` on `action`.
        HfCommands::Pipeline {
            action: HfPipelineCommands::Trigger { model_id, server, project_id, token },
        } => {
            execute_pipeline_trigger(
                &model_id,
                server.as_deref(),
                project_id,
                token.as_deref(),
                ui,
            )
            .await
        }
        HfCommands::Pipeline {
            action: HfPipelineCommands::Status { pipeline_id, server },
        } => execute_pipeline_status(pipeline_id, server.as_deref(), ui).await,
        HfCommands::FullPipeline {
            model_id,
            hf_org,
            hf_hardware,
            min_fix_rate,
            probes,
            epochs,
            fail_on_regression,
        } => {
            execute_fullpipeline(
                &model_id,
                &hf_org,
                &hf_hardware,
                min_fix_rate,
                probes.as_deref(),
                epochs,
                fail_on_regression,
                ui,
            )
            .await
        }
        HfCommands::ApiCheck => execute_api_check(ui).await,
    }
}
/// Downloads a model snapshot from the HuggingFace Hub into the local cache.
///
/// Before downloading, opportunistically re-validates API compatibility
/// when a recorded baseline exists and is older than 24 hours; that check
/// is best-effort and never aborts the pull.
async fn execute_pull(
    model_id: &str,
    revision: &str,
    include: Vec<String>,
    exclude: Vec<String>,
    ui: &UI,
) -> Result<()> {
    let client = HfClient::from_env();
    let cache = HfCache::new();

    ui.header("HuggingFace Model Pull");
    ui.field("Model", model_id);
    ui.field("Revision", revision);

    // Only recheck when a baseline exists AND it is stale; a missing
    // baseline (first run) skips the check entirely.
    let baseline_is_stale = api_compat::load_state()
        .map(|state| api_compat::is_stale(&state, 24))
        .unwrap_or(false);
    if baseline_is_stale {
        let spinner = ui.spinner("Checking API compatibility...");
        let outcome = match api_compat::check_api_compat(&client).await {
            Ok(level) => format!("API check: {}", level),
            Err(e) => format!("API check failed: {}", e),
        };
        ui.finish_progress(&spinner, &outcome);
    }

    let spinner = ui.spinner("Downloading model...");
    let opts = download::DownloadOptions {
        revision: revision.to_string(),
        include,
        exclude,
    };
    let result = download::download_model(&client, &cache, model_id, &opts).await?;
    ui.finish_progress(&spinner, "Download complete");

    if result.from_cache {
        ui.success("Model already cached");
    } else {
        ui.field("Files downloaded", result.files_downloaded);
        ui.field("Total size", format_bytes(result.total_bytes));
    }
    ui.field("Location", result.snapshot_path.display());
    Ok(())
}
/// Uploads a local model directory to a HuggingFace repository.
///
/// Creates the (public) repository first if needed, then pushes all files
/// to the `main` revision with a fixed commit message.
async fn execute_push(
    path: &std::path::Path,
    repo_id: &str,
    hf_org: Option<&str>,
    ui: &UI,
) -> Result<()> {
    let client = HfClient::from_env();

    // Qualify the repo id with the organization when one was supplied.
    let full_repo = match hf_org {
        Some(org) => format!("{}/{}", org, repo_id),
        None => repo_id.to_string(),
    };

    ui.header("HuggingFace Model Push");
    ui.field("Path", path.display());
    ui.field("Repository", &full_repo);

    // Idempotence of repeated create calls depends on the client impl.
    client.create_repo(&full_repo, false).await?;

    let spinner = ui.spinner("Uploading model...");
    let opts = upload::UploadOptions {
        repo_id: full_repo,
        revision: "main".to_string(),
        commit_message: "Upload model via securegit".to_string(),
    };
    let result = upload::upload_model(&client, path, &opts).await?;
    ui.finish_progress(&spinner, "Upload complete");

    ui.field("Files uploaded", result.files_uploaded);
    ui.field("URL", &result.commit_url);
    Ok(())
}
/// Searches the HuggingFace Hub and prints a short card per result.
async fn execute_search(
    query: &str,
    task: Option<&str>,
    library: Option<&str>,
    limit: usize,
    ui: &UI,
) -> Result<()> {
    let client = HfClient::from_env();
    let spinner = ui.spinner("Searching HuggingFace Hub...");
    let models = client.search_models(query, task, library, limit).await?;
    ui.finish_progress(&spinner, &format!("Found {} models", models.len()));

    for hit in &models {
        ui.blank();
        // Prefer the human-readable model_id; fall back to the raw id.
        ui.field("Model", hit.model_id.as_deref().unwrap_or(&hit.id));
        if let Some(tag) = hit.pipeline_tag.as_deref() {
            ui.field("Task", tag);
        }
        if let Some(lib) = hit.library_name.as_deref() {
            ui.field("Library", lib);
        }
        ui.field("Downloads", hit.downloads);
        ui.field("Likes", hit.likes);
    }
    Ok(())
}
/// Fetches and prints metadata for a single model, including its file
/// listing with per-file sizes when available.
///
/// # Errors
/// Returns an error if the Hub request for model info fails.
async fn execute_info(model_id: &str, ui: &UI) -> Result<()> {
    let client = HfClient::from_env();
    let spinner = ui.spinner("Fetching model info...");
    let info = client.model_info(model_id).await?;
    ui.finish_progress(&spinner, "");

    ui.header(&format!(
        "Model: {}",
        info.model_id.as_deref().unwrap_or(&info.id)
    ));
    if let Some(ref sha) = info.sha {
        ui.field("SHA", sha);
    }
    if let Some(ref tag) = info.pipeline_tag {
        ui.field("Task", tag);
    }
    if let Some(ref lib) = info.library_name {
        ui.field("Library", lib);
    }
    ui.field("Downloads", info.downloads);
    ui.field("Likes", info.likes);
    ui.field("Private", info.private);

    if let Some(ref siblings) = info.siblings {
        ui.blank();
        ui.field("Files", siblings.len());
        for file in siblings {
            // A file may report `size` directly, or only inside its LFS
            // pointer metadata. `or_else` keeps the fallback lazy (the
            // original `.or(...)` evaluated it eagerly — clippy
            // `or_fun_call`); "?" when neither source has a size.
            let size_str = file
                .size
                .or_else(|| file.lfs.as_ref().map(|l| l.size))
                .map(format_bytes)
                .unwrap_or_else(|| "?".to_string());
            ui.raw(format!(" {} ({})", file.filename, size_str));
        }
    }
    Ok(())
}
/// Runs security scans against a model: the LLM red-team pipeline scan,
/// plus — when a cached local snapshot exists — whichever of the external
/// static scanners (ModelScan / PickleScan / Fickling) are enabled in the
/// config and installed on this machine.
///
/// External scanner failures are reported as warnings, never errors.
///
/// # Errors
/// Fails if the redteam bridge binary is unavailable or the pipeline scan
/// itself errors.
async fn execute_scan(model_id: &str, ui: &UI) -> Result<()> {
    ui.header("Model Security Scan");
    ui.field("Model", model_id);

    let bridge = RedteamBridge::new(None);
    if !bridge.is_available() {
        bail!(
            "armyknife-llm-redteam-mcp not found on PATH. \
             Install it or set SECUREGIT_REDTEAM_BIN to the binary path."
        );
    }

    let spinner = ui.spinner("Running LLM security scan...");
    let result = bridge
        .pipeline_scan(model_id, None)
        .await
        .map_err(|e| anyhow::anyhow!("{}", e))?;
    ui.finish_progress(&spinner, "Scan complete");

    // Summarize structured output when the bridge returned JSON;
    // otherwise dump the raw text untouched.
    match serde_json::from_str::<serde_json::Value>(&result) {
        Ok(json) => {
            let findings = json["findings"].as_array().map(|a| a.len()).unwrap_or(0);
            let severity = json["max_severity"].as_str().unwrap_or("unknown");
            ui.field("Findings", findings);
            ui.field("Max severity", severity);
            if findings > 0 {
                ui.raw(serde_json::to_string_pretty(&json).unwrap_or_else(|_| result.clone()));
            }
        }
        Err(_) => {
            ui.raw(result);
        }
    }

    // Static scans only make sense against local files, so they run only
    // when the model's "main" snapshot is already cached on disk.
    let cache = HfCache::new();
    if let Some(snapshot) = cache.cached_snapshot_path(model_id, "main") {
        if snapshot.exists() {
            ui.blank();
            ui.info("Scanning cached model files with external tools...");
            let llm_config = &crate::core::Config::default().llm_security;

            // The three external scanners were three copy-pasted blocks;
            // this local macro keeps the gate/scan/report logic in one
            // place. A macro (not a fn) is required because each bridge
            // is a distinct type with no shared trait visible here.
            // `concat!` reproduces the original messages byte-for-byte.
            macro_rules! run_external_scan {
                ($enabled:expr, $tool:expr, $label:literal) => {
                    if $enabled {
                        let tool = $tool;
                        if tool.is_available() {
                            match tool.scan(&snapshot).await {
                                Ok(findings) if !findings.is_empty() => {
                                    ui.field(concat!($label, " findings"), findings.len());
                                    for f in &findings {
                                        ui.field(&format!("[{}]", f.severity), &f.title);
                                    }
                                }
                                Ok(_) => {
                                    ui.status_item(true, concat!($label, ": clean"));
                                }
                                Err(e) => {
                                    ui.warning(format!(concat!($label, " failed: {}"), e));
                                }
                            }
                        }
                    }
                };
            }

            run_external_scan!(
                llm_config.run_modelscan,
                crate::toolbridges::modelscan::ModelScanBridge::new(),
                "ModelScan"
            );
            run_external_scan!(
                llm_config.run_picklescan,
                crate::toolbridges::picklescan::PickleScanBridge::new(),
                "PickleScan"
            );
            run_external_scan!(
                llm_config.run_fickling,
                crate::toolbridges::fickling::FicklingBridge::new(),
                "Fickling"
            );
        }
    }
    Ok(())
}
/// Runs the full cloud hardening pipeline end to end:
/// scan -> harden (DPO training on HF hardware) -> verify -> publish.
///
/// NOTE(review): the progress labels say "Step 1/6".."4/6" but only four
/// steps are printed here — confirm whether remaining steps run inside
/// the bridge or the labels are stale.
///
/// # Arguments
/// * `model_id` - source model on the Hub (e.g. "org/name").
/// * `hf_org` - organization that will own the hardened model.
/// * `hf_hardware` - training hardware flavor forwarded to the bridge.
/// * `min_fix_rate` - fraction (0.0-1.0) of findings that must be fixed.
/// * `_probes` - accepted for CLI compatibility; not forwarded yet.
/// * `epochs` - DPO training epochs.
/// * `fail_on_regression` - forwarded to the verification step.
///
/// # Errors
/// Fails if the redteam bridge binary is missing, any pipeline stage
/// errors, or verification does not pass its thresholds.
async fn execute_fullpipeline(
    model_id: &str,
    hf_org: &str,
    hf_hardware: &str,
    min_fix_rate: f64,
    _probes: Option<&str>,
    epochs: u32,
    fail_on_regression: bool,
    ui: &UI,
) -> Result<()> {
    ui.header("Full Cloud Pipeline");
    ui.field("Model", model_id);
    ui.field("HF Org", hf_org);
    ui.field("Hardware", hf_hardware);
    ui.field("Min fix rate", format!("{:.0}%", min_fix_rate * 100.0));
    ui.field("Epochs", epochs);
    ui.blank();
    let bridge = RedteamBridge::new(None);
    if !bridge.is_available() {
        bail!(
            "armyknife-llm-redteam-mcp not found on PATH. \
             Install it or set SECUREGIT_REDTEAM_BIN to the binary path."
        );
    }
    // Step 1: scan the source model remotely; artifacts land in the
    // "scan-results" directory consumed by the later stages.
    ui.info("Step 1/6: Scanning model on HuggingFace Inference...");
    let scan_model_spec = format!("huggingface://{}", model_id);
    let spinner = ui.spinner("Running cloud scan...");
    let scan_result = bridge
        .pipeline_scan(&scan_model_spec, Some("scan-results"))
        .await
        .map_err(|e| anyhow::anyhow!("{}", e))?;
    ui.finish_progress(&spinner, "Scan complete");
    let scan_json: serde_json::Value =
        serde_json::from_str(&scan_result).unwrap_or(serde_json::json!({}));
    // The bridge may return `total_findings` at the top level, or wrap a
    // JSON string inside `content[0].text` (MCP-style tool envelope).
    // Try both; default to 0 when neither shape matches.
    let total_findings = scan_json["total_findings"]
        .as_u64()
        .or_else(|| {
            scan_json["content"]
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|c| c["text"].as_str())
                .and_then(|text| serde_json::from_str::<serde_json::Value>(text).ok())
                .and_then(|v| v["total_findings"].as_u64())
        })
        .unwrap_or(0);
    ui.field("Findings", total_findings);
    // Nothing to harden: stop early with success.
    if total_findings == 0 {
        ui.success("Model is already clean — no vulnerabilities detected.");
        return Ok(());
    }
    ui.blank();
    // Step 2: generate DPO pairs from the scan findings and kick off
    // cloud training. NOTE(review): argument meanings ("dpo", "firm",
    // "hf", ...) follow RedteamBridge::pipeline_harden's positional
    // contract — confirm against the bridge's documentation.
    ui.info("Step 2/6: Generating DPO training pairs...");
    let spinner = ui.spinner("Generating training data...");
    let harden_result = bridge
        .pipeline_harden(
            model_id,
            "scan-results",
            Some("hardened-model"),
            Some("dpo"),
            Some("firm"),
            Some(epochs),
            Some("hf"),
            Some(hf_org),
            Some(hf_hardware),
        )
        .await
        .map_err(|e| anyhow::anyhow!("{}", e))?;
    ui.finish_progress(&spinner, "Training data generated + cloud training started");
    let harden_json: serde_json::Value =
        serde_json::from_str(&harden_result).unwrap_or(serde_json::json!({}));
    // Extract the training model URL from the envelope; fall back to a
    // constructed HF URL when the bridge did not report one.
    let hf_model_url = harden_json["content"]
        .as_array()
        .and_then(|arr| arr.first())
        .and_then(|c| c["text"].as_str())
        .and_then(|text| serde_json::from_str::<serde_json::Value>(text).ok())
        .and_then(|v| v["hf_model_url"].as_str().map(String::from))
        .unwrap_or_else(|| format!("https://huggingface.co/{hf_org}/{}", model_id.replace('/', "-")));
    ui.field("Training model URL", &hf_model_url);
    // Derive the hardened repo name from the last path segment, with
    // dots replaced (presumably to satisfy HF repo-name rules — verify).
    let hardened_model_name = model_id
        .split('/')
        .last()
        .unwrap_or(model_id)
        .replace('.', "-");
    let hardened_repo = format!("{}/{}-Hardened", hf_org, hardened_model_name);
    ui.blank();
    // Step 3: re-scan only the original findings against the hardened
    // model and compare against the fix-rate threshold.
    ui.info("Step 3/6: Verifying hardened model (targeted re-scan)...");
    let spinner = ui.spinner("Running verification scan...");
    let hardened_model_spec = format!("huggingface://{}", hardened_repo);
    let verify_result = bridge
        .pipeline_verify(
            "scan-results/findings.json",
            &hardened_model_spec,
            Some(model_id),
            Some("verification-report"),
            Some(min_fix_rate),
            Some(fail_on_regression),
        )
        .await
        .map_err(|e| anyhow::anyhow!("{}", e))?;
    ui.finish_progress(&spinner, "Verification complete");
    let verify_json: serde_json::Value =
        serde_json::from_str(&verify_result).unwrap_or(serde_json::json!({}));
    // Unwrap the envelope if present; otherwise read the top-level JSON.
    let verify_data = verify_json["content"]
        .as_array()
        .and_then(|arr| arr.first())
        .and_then(|c| c["text"].as_str())
        .and_then(|text| serde_json::from_str::<serde_json::Value>(text).ok())
        .unwrap_or(verify_json.clone());
    let fix_rate = verify_data["fix_rate"].as_f64().unwrap_or(0.0);
    let verdict = verify_data["verdict"].as_str().unwrap_or("UNKNOWN");
    let passed = verify_data["passed"].as_bool().unwrap_or(false);
    let regression_count = verify_data["regression_count"].as_u64().unwrap_or(0);
    ui.field("Fix rate", format!("{:.1}%", fix_rate * 100.0));
    ui.field("Verdict", verdict);
    ui.field("Regressions", regression_count);
    // A missing/false `passed` flag halts the pipeline before publishing.
    if !passed {
        ui.blank();
        ui.error(format!("Verification FAILED (fix rate {:.1}% below {:.1}% threshold)", fix_rate * 100.0, min_fix_rate * 100.0));
        if regression_count > 0 {
            ui.error(format!("{} regressions detected", regression_count));
        }
        bail!("Pipeline halted: verification failed");
    }
    ui.blank();
    // Step 4: publish the hardened model plus its verification scorecard.
    ui.info("Step 4/6: Publishing hardened model to HuggingFace...");
    let spinner = ui.spinner("Publishing model and scorecard...");
    let publish_result = bridge
        .pipeline_publish(
            "hardened-model",
            "verification-report",
            &hardened_model_name,
            Some(hf_org),
        )
        .await
        .map_err(|e| anyhow::anyhow!("{}", e))?;
    ui.finish_progress(&spinner, "Published");
    let publish_json: serde_json::Value =
        serde_json::from_str(&publish_result).unwrap_or(serde_json::json!({}));
    // Prefer the repo URL reported by the bridge; fall back to the
    // deterministic URL derived from the hardened repo name.
    let repo_url = publish_json["content"]
        .as_array()
        .and_then(|arr| arr.first())
        .and_then(|c| c["text"].as_str())
        .and_then(|text| serde_json::from_str::<serde_json::Value>(text).ok())
        .and_then(|v| v["repo_url"].as_str().map(String::from))
        .unwrap_or_else(|| format!("https://huggingface.co/{}", hardened_repo));
    ui.blank();
    ui.header("Pipeline Complete");
    ui.field("Original model", model_id);
    ui.field("Hardened model", &hardened_repo);
    ui.field("Model URL", &repo_url);
    ui.field("Fix rate", format!("{:.1}%", fix_rate * 100.0));
    ui.field("Verdict", verdict);
    ui.field("Findings (before)", total_findings);
    ui.field("Regressions", regression_count);
    ui.success("Full pipeline completed successfully.");
    Ok(())
}
/// Triggers a remote model-hardening CI pipeline for `model_id`.
///
/// The target server is looked up in the local server registry (default
/// name "gpubox"). The project id comes from `--project-id` or the
/// SECUREGIT_PIPELINE_PROJECT_ID environment variable; the trigger token
/// from `--token` or SECUREGIT_PIPELINE_TOKEN. CLI values win over the
/// environment.
///
/// NOTE(review): the `/projects/{id}/trigger/pipeline` path and the
/// `token`/`ref`/`variables[...]` form fields match the GitLab pipeline
/// trigger API — confirm the registered server is GitLab-compatible.
///
/// # Errors
/// Fails when the registry cannot be loaded, the server is unknown,
/// project id or token are missing, or the HTTP trigger request fails.
async fn execute_pipeline_trigger(
    model_id: &str,
    server_name: Option<&str>,
    project_id: Option<u64>,
    trigger_token: Option<&str>,
    ui: &UI,
) -> Result<()> {
    let server_name = server_name.unwrap_or("gpubox");
    let registry = ServerRegistry::load().context("Failed to load server registry")?;
    let server = registry
        .get(server_name)
        .ok_or_else(|| anyhow::anyhow!("Server '{}' not found in registry", server_name))?;
    // Resolve project id: CLI flag first, then environment variable.
    let project_id = project_id
        .or_else(|| {
            std::env::var("SECUREGIT_PIPELINE_PROJECT_ID")
                .ok()?
                .parse()
                .ok()
        })
        .ok_or_else(|| {
            anyhow::anyhow!(
                "Pipeline project ID required (--project-id or SECUREGIT_PIPELINE_PROJECT_ID)"
            )
        })?;
    // Resolve trigger token with the same CLI-then-env precedence.
    let token = trigger_token
        .map(|t| t.to_string())
        .or_else(|| std::env::var("SECUREGIT_PIPELINE_TOKEN").ok())
        .ok_or_else(|| {
            anyhow::anyhow!("Pipeline trigger token required (--token or SECUREGIT_PIPELINE_TOKEN)")
        })?;
    ui.header("Trigger Model Hardening Pipeline");
    ui.field("Model", model_id);
    ui.field("Server", server_name);
    ui.field("Project ID", project_id);
    let api_url = format!(
        "{}/projects/{}/trigger/pipeline",
        server.api_url, project_id
    );
    // Form-encoded trigger request: always targets the "main" ref and
    // passes the model id as a pipeline variable.
    let client = reqwest::Client::new();
    let resp = client
        .post(&api_url)
        .form(&[
            ("token", token.as_str()),
            ("ref", "main"),
            ("variables[MODEL_ID]", model_id),
        ])
        .send()
        .await
        .context("Failed to trigger pipeline")?;
    if !resp.status().is_success() {
        // Include the response body in the error to aid debugging.
        let status = resp.status();
        let text = resp.text().await.unwrap_or_default();
        bail!("Pipeline trigger failed ({}): {}", status, text);
    }
    let result: serde_json::Value = resp.json().await?;
    let pipeline_id = result["id"].as_u64().unwrap_or(0);
    let web_url = result["web_url"].as_str().unwrap_or("");
    ui.field("Pipeline ID", pipeline_id);
    ui.field("URL", web_url);
    ui.success("Pipeline triggered successfully");
    Ok(())
}
/// Shows the status of one pipeline (when an id is given) or the five
/// most recent pipelines for the project otherwise.
///
/// Requires SECUREGIT_PIPELINE_PROJECT_ID in the environment and stored
/// credentials for the target server (default "gpubox").
async fn execute_pipeline_status(
    pipeline_id: Option<u64>,
    server_name: Option<&str>,
    ui: &UI,
) -> Result<()> {
    let server_name = server_name.unwrap_or("gpubox");
    let registry = ServerRegistry::load().context("Failed to load server registry")?;
    let server = registry
        .get(server_name)
        .ok_or_else(|| anyhow::anyhow!("Server '{}' not found in registry", server_name))?;
    let project_id: u64 = std::env::var("SECUREGIT_PIPELINE_PROJECT_ID")
        .ok()
        .and_then(|v| v.parse().ok())
        .ok_or_else(|| anyhow::anyhow!("SECUREGIT_PIPELINE_PROJECT_ID required"))?;
    let token_str = crate::auth::token_for_server(server)
        .ok_or_else(|| anyhow::anyhow!("No credentials for server '{}'", server_name))?;

    // Single-pipeline endpoint vs. most-recent-first listing.
    let url = match pipeline_id {
        Some(pid) => format!(
            "{}/projects/{}/pipelines/{}",
            server.api_url, project_id, pid
        ),
        None => format!(
            "{}/projects/{}/pipelines?per_page=5&order_by=id&sort=desc",
            server.api_url, project_id
        ),
    };

    let resp = reqwest::Client::new()
        .get(&url)
        .header("PRIVATE-TOKEN", token_str.as_str())
        .send()
        .await
        .context("Failed to fetch pipeline status")?;
    if !resp.status().is_success() {
        let status = resp.status();
        let text = resp.text().await.unwrap_or_default();
        bail!("Pipeline status failed ({}): {}", status, text);
    }

    let result: serde_json::Value = resp.json().await?;
    match pipeline_id {
        // Single pipeline: the response is one object.
        Some(_) => {
            ui.field("Pipeline ID", result["id"].as_u64().unwrap_or(0));
            ui.field("Status", result["status"].as_str().unwrap_or("unknown"));
            ui.field("Ref", result["ref"].as_str().unwrap_or(""));
            ui.field("URL", result["web_url"].as_str().unwrap_or(""));
        }
        // Listing: the response is an array; tolerate a non-array body
        // by printing nothing.
        None => {
            let empty = Vec::new();
            for p in result.as_array().unwrap_or(&empty) {
                ui.field(
                    &format!("#{}", p["id"].as_u64().unwrap_or(0)),
                    format!(
                        "{} ({})",
                        p["status"].as_str().unwrap_or("?"),
                        p["ref"].as_str().unwrap_or("?")
                    ),
                );
            }
        }
    }
    Ok(())
}
async fn execute_api_check(ui: &UI) -> Result<()> {
let client = HfClient::from_env();
ui.header("HuggingFace API Compatibility Check");
let spinner = ui.spinner("Checking API...");
let level = api_compat::check_api_compat(&client).await?;
ui.finish_progress(&spinner, "");
match &level {
api_compat::ApiCompatLevel::Ok => ui.success("API compatible"),
api_compat::ApiCompatLevel::FirstRun => ui.success("Baseline recorded"),
api_compat::ApiCompatLevel::Warn(msgs) => {
for msg in msgs {
ui.warning(msg);
}
}
api_compat::ApiCompatLevel::Error(msgs) => {
for msg in msgs {
ui.error(msg);
}
}
}
Ok(())
}
/// Formats a byte count as a human-readable string using binary units
/// (KB = 1024 B), with one decimal place above the B range.
fn format_bytes(bytes: u64) -> String {
    // Largest unit first; the first scale that fits wins.
    const UNITS: [(u64, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];
    for (scale, suffix) in UNITS {
        if bytes >= scale {
            return format!("{:.1} {}", bytes as f64 / scale as f64, suffix);
        }
    }
    format!("{} B", bytes)
}