use super::model_card::{ModelCard, ModelCardBuilder};
use super::{get_hf_token, HubError, Result};
use regex::Regex;
use std::fs;
use std::path::{Path, PathBuf};
fn validate_repo_id(repo_id: &str) -> Result<()> {
let slash_count = repo_id.chars().filter(|&c| c == '/').count();
if slash_count != 1 {
return Err(HubError::InvalidFormat(
"Repository ID must be in format 'username/repo-name'".to_string(),
));
}
let valid_pattern = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*$")
.expect("Invalid regex pattern");
if !valid_pattern.is_match(repo_id) {
return Err(HubError::InvalidFormat(format!(
"Repository ID '{}' contains invalid characters. Only alphanumeric, /, -, _, . are allowed",
repo_id
)));
}
if repo_id.contains("..") {
return Err(HubError::InvalidFormat(
"Repository ID cannot contain '..' (path traversal)".to_string(),
));
}
let dangerous_chars = [
'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r', '"', '\'', '\\',
];
for c in dangerous_chars {
if repo_id.contains(c) {
return Err(HubError::InvalidFormat(format!(
"Repository ID cannot contain shell metacharacter '{}'",
c
)));
}
}
Ok(())
}
/// Checks that `path` names an existing regular file that may be uploaded.
///
/// # Errors
///
/// * [`HubError::InvalidFormat`] if any component of `path` is `..`.
/// * [`HubError::NotFound`] if `path` cannot be canonicalized (for example it
///   does not exist), or if it does not resolve to a regular file.
fn validate_upload_path(path: &Path) -> Result<()> {
    use std::path::Component;

    // Match `..` as a whole path segment. A substring test would also reject
    // harmless file names such as `model..v2.gguf`.
    if path
        .components()
        .any(|part| matches!(part, Component::ParentDir))
    {
        return Err(HubError::InvalidFormat(
            "File path cannot contain '..' (path traversal)".to_string(),
        ));
    }

    // `canonicalize` resolves symlinks, so `is_file` is evaluated on the real target.
    let canonical = path.canonicalize().map_err(|e| {
        HubError::NotFound(format!("Cannot resolve path '{}': {}", path.display(), e))
    })?;
    if !canonical.is_file() {
        return Err(HubError::NotFound(format!(
            "Path '{}' is not a regular file",
            path.display()
        )));
    }
    Ok(())
}
/// Settings that control how a model is published to the Hugging Face Hub.
///
/// Create one with [`UploadConfig::new`] and adjust it with the chained setters.
///
/// The `Debug` implementation redacts `hf_token`, so a config can be logged
/// without leaking the secret.
#[derive(Clone)]
pub struct UploadConfig {
    /// Hugging Face access token, handed to `huggingface-cli` via `HF_TOKEN`.
    pub hf_token: String,
    /// Request a private repository (`--private`) when the repo is created.
    pub private: bool,
    /// Run `huggingface-cli repo create` before uploading.
    pub create_repo: bool,
    /// Whether SONA weights should be uploaded.
    /// NOTE(review): not read by the CLI upload path in this file — confirm.
    pub include_sona_weights: bool,
    /// Generate and upload a `README.md` model card when metadata is supplied.
    pub auto_model_card: bool,
    /// Commit message used for each file upload.
    pub commit_message: String,
}

impl std::fmt::Debug for UploadConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the access token.
        f.debug_struct("UploadConfig")
            .field("hf_token", &"<redacted>")
            .field("private", &self.private)
            .field("create_repo", &self.create_repo)
            .field("include_sona_weights", &self.include_sona_weights)
            .field("auto_model_card", &self.auto_model_card)
            .field("commit_message", &self.commit_message)
            .finish()
    }
}

impl UploadConfig {
    /// Creates a config with the given token and these defaults:
    /// public repo, create the repo, include SONA weights, auto-generate the
    /// model card, and commit message `"Upload RuvLTRA model"`.
    pub fn new(hf_token: String) -> Self {
        Self {
            hf_token,
            private: false,
            create_repo: true,
            include_sona_weights: true,
            auto_model_card: true,
            commit_message: "Upload RuvLTRA model".to_string(),
        }
    }

    /// Sets whether a newly created repository is private.
    #[must_use]
    pub fn private(mut self, private: bool) -> Self {
        self.private = private;
        self
    }

    /// Sets whether the repository is created before uploading.
    #[must_use]
    pub fn create_repo(mut self, create_repo: bool) -> Self {
        self.create_repo = create_repo;
        self
    }

    /// Sets whether SONA weights should be included in the upload.
    #[must_use]
    pub fn include_sona_weights(mut self, include: bool) -> Self {
        self.include_sona_weights = include;
        self
    }

    /// Sets whether a model card is generated and uploaded automatically.
    #[must_use]
    pub fn auto_model_card(mut self, enabled: bool) -> Self {
        self.auto_model_card = enabled;
        self
    }

    /// Sets the commit message used for uploads.
    #[must_use]
    pub fn commit_message(mut self, message: impl Into<String>) -> Self {
        self.commit_message = message.into();
        self
    }
}
/// Snapshot of an upload's progress.
///
/// NOTE(review): nothing in the code visible here constructs or emits an
/// `UploadProgress`; the CLI-based uploader takes no progress callback.
#[derive(Debug, Clone)]
pub struct UploadProgress {
/// Total bytes to upload (presumably across all files of the upload — confirm).
pub total_bytes: u64,
/// Bytes uploaded so far (presumably).
pub uploaded_bytes: u64,
/// Current transfer rate; presumably bytes per second — confirm.
pub speed_bps: f64,
/// The file currently being uploaded (path vs. bare name is not established here).
pub current_file: String,
/// Pipeline stage the upload is currently in.
pub stage: UploadStage,
}
/// Stage of the upload pipeline, intended for progress reporting.
///
/// NOTE(review): not referenced by the uploader in this file. The variants
/// appear to mirror the steps of `ModelUploader::upload_via_cli` — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UploadStage {
/// Preparing inputs before any network activity (assumed).
Preparing,
/// Creating the Hub repository.
CreatingRepo,
/// Uploading the main model file.
UploadingModel,
/// Uploading SONA weights. NOTE(review): no SONA upload step exists in this file.
UploadingSona,
/// Uploading the generated model card.
UploadingCard,
/// All steps finished.
Complete,
/// The upload failed; the payload presumably carries an error description.
Failed(String),
}
/// Descriptive metadata used to generate the model card uploaded alongside a model.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
/// Display name, passed to `ModelCardBuilder::new`.
pub name: String,
/// Optional free-text description for the card.
pub description: Option<String>,
/// Architecture identifier, e.g. `"llama"`.
pub architecture: String,
/// Parameter count in billions; multiplied by 1e9 when building the card.
pub params_b: f32,
/// Context length recorded on the card (tokens assumed — confirm).
pub context_length: usize,
/// Quantization label, e.g. `"Q4_K_M"`; added to the card as a tag.
pub quantization: Option<String>,
/// License identifier; a value that fails to parse is silently left off the card.
pub license: Option<String>,
/// Dataset identifiers, each added to the card.
pub datasets: Vec<String>,
/// Extra tags added to the card.
pub tags: Vec<String>,
}
pub struct ModelUploader {
config: UploadConfig,
}
impl ModelUploader {
pub fn new(hf_token: impl Into<String>) -> Self {
Self {
config: UploadConfig::new(hf_token.into()),
}
}
pub fn with_config(config: UploadConfig) -> Self {
Self { config }
}
pub fn upload(
&self,
model_path: impl AsRef<Path>,
repo_id: &str,
metadata: Option<ModelMetadata>,
) -> Result<String> {
let model_path = model_path.as_ref();
validate_repo_id(repo_id)?;
validate_upload_path(model_path)?;
self.upload_via_cli(model_path, repo_id, metadata)
}
fn upload_via_cli(
&self,
model_path: &Path,
repo_id: &str,
metadata: Option<ModelMetadata>,
) -> Result<String> {
if !self.has_hf_cli() {
return Err(HubError::Config(
"huggingface-cli not found. Install with: pip install huggingface_hub[cli]"
.to_string(),
));
}
if self.config.create_repo {
self.create_repo_cli(repo_id)?;
}
self.upload_file_cli(model_path, repo_id)?;
if self.config.auto_model_card {
if let Some(meta) = metadata {
let card = self.generate_model_card(&meta);
self.upload_model_card_cli(&card, repo_id)?;
}
}
Ok(format!("https://huggingface.co/{}", repo_id))
}
fn has_hf_cli(&self) -> bool {
std::process::Command::new("huggingface-cli")
.arg("--version")
.output()
.map(|o| o.status.success())
.unwrap_or(false)
}
fn create_repo_cli(&self, repo_id: &str) -> Result<()> {
let mut args = vec![
"repo".to_string(),
"create".to_string(),
repo_id.to_string(),
];
if self.config.private {
args.push("--private".to_string());
}
let status = std::process::Command::new("huggingface-cli")
.args(&args)
.env("HF_TOKEN", &self.config.hf_token)
.status()
.map_err(|e| HubError::Network(e.to_string()))?;
if !status.success() && status.code() != Some(1) {
return Err(HubError::Network("Failed to create repository".to_string()));
}
Ok(())
}
fn upload_file_cli(&self, file_path: &Path, repo_id: &str) -> Result<()> {
let args = vec![
"upload".to_string(),
repo_id.to_string(),
file_path.to_str().unwrap().to_string(),
"--commit-message".to_string(),
self.config.commit_message.clone(),
];
let status = std::process::Command::new("huggingface-cli")
.args(&args)
.env("HF_TOKEN", &self.config.hf_token)
.status()
.map_err(|e| HubError::Network(e.to_string()))?;
if !status.success() {
return Err(HubError::Network("Failed to upload file".to_string()));
}
Ok(())
}
fn generate_model_card(&self, metadata: &ModelMetadata) -> ModelCard {
use super::model_card::{Framework, License, TaskType};
let mut builder = ModelCardBuilder::new(&metadata.name);
if let Some(desc) = &metadata.description {
builder = builder.description(desc);
}
builder = builder
.task(TaskType::TextGeneration)
.framework(Framework::Gguf)
.architecture(&metadata.architecture)
.parameters((metadata.params_b * 1e9) as u64)
.context_length(metadata.context_length);
if let Some(quant) = &metadata.quantization {
builder = builder.add_tag(quant);
}
if let Some(license) = &metadata.license {
if let Ok(lic) = license.parse() {
builder = builder.license(lic);
}
}
for dataset in &metadata.datasets {
builder = builder.add_dataset(dataset, None);
}
for tag in &metadata.tags {
builder = builder.add_tag(tag);
}
builder.build()
}
fn upload_model_card_cli(&self, card: &ModelCard, repo_id: &str) -> Result<()> {
let temp_dir = std::env::temp_dir();
let card_path = temp_dir.join("README.md");
fs::write(&card_path, card.to_markdown())?;
self.upload_file_cli(&card_path, repo_id)?;
let _ = fs::remove_file(&card_path);
Ok(())
}
}
/// Error type for upload operations.
///
/// NOTE(review): the uploader in this file returns `HubError`, not
/// `UploadError`; confirm whether this type is still needed.
#[derive(Debug, thiserror::Error)]
pub enum UploadError {
/// Authentication with the Hub failed; the payload is a detail message.
#[error("Authentication failed: {0}")]
Auth(String),
/// A network-level failure; the payload is a detail message.
#[error("Network error: {0}")]
Network(String),
/// Wraps a `std::io::Error`; `#[from]` enables `?` conversion from I/O errors.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_upload_config() {
        let config = UploadConfig::new("test_token".to_string());
        assert!(!config.private);
        assert!(config.create_repo);
        assert!(config.include_sona_weights);
        assert!(config.auto_model_card);
        assert_eq!(config.commit_message, "Upload RuvLTRA model");
    }

    #[test]
    fn test_upload_config_builder() {
        let config = UploadConfig::new("token".to_string())
            .private(true)
            .commit_message("Custom message");
        assert!(config.private);
        assert_eq!(config.commit_message, "Custom message");
    }

    #[test]
    fn test_model_metadata() {
        let metadata = ModelMetadata {
            name: "RuvLTRA Test".to_string(),
            description: Some("Test model".to_string()),
            architecture: "llama".to_string(),
            params_b: 0.5,
            context_length: 4096,
            quantization: Some("Q4_K_M".to_string()),
            license: Some("MIT".to_string()),
            datasets: vec!["dataset1".to_string()],
            tags: vec!["test".to_string()],
        };
        assert_eq!(metadata.params_b, 0.5);
        assert!(metadata.description.is_some());
    }

    #[test]
    fn test_validate_repo_id_accepts_owner_and_name() {
        assert!(validate_repo_id("ruvnet/ruvltra-0.5b").is_ok());
        assert!(validate_repo_id("user_1/model.v2").is_ok());
    }

    #[test]
    fn test_validate_repo_id_rejects_malformed_ids() {
        // Wrong number of slashes.
        assert!(validate_repo_id("").is_err());
        assert!(validate_repo_id("no-slash").is_err());
        assert!(validate_repo_id("a/b/c").is_err());
        // Disallowed leading character or shell metacharacters.
        assert!(validate_repo_id("-user/model").is_err());
        assert!(validate_repo_id("user/$(rm -rf)").is_err());
        assert!(validate_repo_id("user/model;ls").is_err());
        // Matches the character pattern but contains `..`.
        assert!(validate_repo_id("user/a..b").is_err());
    }

    #[test]
    fn test_validate_upload_path_rejects_unusable_paths() {
        let missing = std::path::Path::new("definitely/not/a/real/file.gguf");
        assert!(validate_upload_path(missing).is_err());
        // A directory is not a regular file.
        assert!(validate_upload_path(&std::env::temp_dir()).is_err());
        // A parent-directory component is rejected before touching the filesystem.
        assert!(validate_upload_path(std::path::Path::new("../model.gguf")).is_err());
    }
}