1use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::Duration;
12
/// Tunables for the retry schedule produced by [`RateLimitState::next_backoff`].
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Delay used for the first retry; later retries scale it by `multiplier`.
    pub initial_backoff: Duration,
    /// Upper bound applied to the computed exponential delay.
    pub max_backoff: Duration,
    /// Number of retries after which no further backoff is handed out.
    pub max_retries: u32,
    /// Growth factor applied once per additional retry.
    pub multiplier: f64,
}

impl Default for RateLimitConfig {
    /// Defaults: 1 s initial delay, 60 s cap, 5 retries, delay doubling per retry.
    fn default() -> Self {
        Self {
            initial_backoff: Duration::from_secs(1),
            max_backoff: Duration::from_secs(60),
            max_retries: 5,
            multiplier: 2.0,
        }
    }
}

/// Retry bookkeeping for one rate-limited operation.
#[derive(Debug, Clone)]
pub struct RateLimitState {
    /// Number of backoffs handed out since construction or the last [`reset`](Self::reset).
    pub retry_count: u32,
    /// The delay most recently returned by [`next_backoff`](Self::next_backoff).
    pub current_backoff: Duration,
    /// Explicit delay (e.g. from a `Retry-After` header) that overrides the
    /// exponential schedule while it is `Some`. `next_backoff` neither clears it
    /// nor caps it with `max_backoff`; the caller owns its lifetime.
    pub retry_after: Option<Duration>,
}

impl RateLimitState {
    /// Creates a fresh state: no retries yet, 1 s current backoff, no explicit delay.
    pub fn new() -> Self {
        Self { retry_count: 0, current_backoff: Duration::from_secs(1), retry_after: None }
    }

    /// Registers one more retry and returns how long to wait before it.
    ///
    /// Returns `None` once `config.max_retries` retries have been handed out.
    /// Otherwise the delay is `retry_after` when set, else
    /// `initial_backoff * multiplier^(retry_count - 1)` limited to
    /// `0 ..= max_backoff`. The result is also stored in `current_backoff`.
    pub fn next_backoff(&mut self, config: &RateLimitConfig) -> Option<Duration> {
        if !self.should_retry(config) {
            return None;
        }
        self.retry_count += 1;

        let backoff = match self.retry_after {
            Some(explicit_delay) => explicit_delay,
            None => Self::exponential_delay(self.retry_count, config),
        };
        self.current_backoff = backoff;
        Some(backoff)
    }

    /// Computes the capped exponential delay for the 1-based `attempt`.
    fn exponential_delay(attempt: u32, config: &RateLimitConfig) -> Duration {
        // Saturate instead of wrapping when the attempt number exceeds `i32::MAX`.
        let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
        let cap = config.max_backoff.as_secs_f64();
        let raw = config.initial_backoff.as_secs_f64() * config.multiplier.powi(exponent);
        // `min` first: NaN and +inf collapse to the cap. `max(0.0)` afterwards keeps
        // a negative product (negative multiplier, odd exponent) from reaching
        // `Duration::from_secs_f64`, which panics on negative input.
        Duration::from_secs_f64(raw.min(cap).max(0.0))
    }

    /// Restores the state produced by [`new`](Self::new).
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Returns `true` while fewer than `config.max_retries` retries have been used.
    pub fn should_retry(&self, config: &RateLimitConfig) -> bool {
        self.retry_count < config.max_retries
    }
}

impl Default for RateLimitState {
    /// Same as [`RateLimitState::new`].
    fn default() -> Self {
        Self::new()
    }
}
91
/// Policy consulted by `check_download_allowed` when vetting a file list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SafetyPolicy {
    /// Default policy: files classified as `FileSafety::Unsafe` are rejected.
    #[default]
    SafeOnly,
    /// Every file is accepted without classification.
    AllowUnsafe,
}
105
/// Result of classifying a file name by its extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileSafety {
    /// The extension is on the allow list (e.g. `.safetensors`, `.json`).
    Safe,
    /// The extension is on the deny list (e.g. `.bin`, `.pkl`).
    Unsafe,
    /// The extension is on neither list.
    Unknown,
}
116
117pub fn classify_file_safety(filename: &str) -> FileSafety {
119 const SAFE_EXTENSIONS: &[&str] =
120 &[".safetensors", ".json", ".txt", ".md", ".gguf", ".ggml", ".yaml", ".yml", ".toml"];
121 const UNSAFE_EXTENSIONS: &[&str] = &[".bin", ".pt", ".pth", ".pkl", ".pickle"];
122
123 let lower = filename.to_lowercase();
124 if SAFE_EXTENSIONS.iter().any(|ext| lower.ends_with(ext)) {
125 FileSafety::Safe
126 } else if UNSAFE_EXTENSIONS.iter().any(|ext| lower.ends_with(ext)) {
127 FileSafety::Unsafe
128 } else {
129 FileSafety::Unknown
130 }
131}
132
133pub fn check_download_allowed(files: &[&str], policy: SafetyPolicy) -> Result<(), Vec<String>> {
135 if policy == SafetyPolicy::AllowUnsafe {
136 return Ok(());
137 }
138
139 let unsafe_files: Vec<String> = files
140 .iter()
141 .filter(|f| classify_file_safety(f) == FileSafety::Unsafe)
142 .map(|f| (*f).to_string())
143 .collect();
144
145 if unsafe_files.is_empty() {
147 Ok(())
148 } else {
149 Err(unsafe_files)
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct ModelCardMetadata {
160 pub model_name: String,
161 pub language: Option<String>,
162 pub license: Option<String>,
163 pub tags: Vec<String>,
164 pub library_name: Option<String>,
165 pub pipeline_tag: Option<String>,
166 pub datasets: Vec<String>,
167 pub metrics: HashMap<String, f64>,
168}
169
170impl ModelCardMetadata {
171 pub fn new(model_name: impl Into<String>) -> Self {
172 Self {
173 model_name: model_name.into(),
174 language: None,
175 license: None,
176 tags: Vec::new(),
177 library_name: Some("paiml".to_string()),
178 pipeline_tag: None,
179 datasets: Vec::new(),
180 metrics: HashMap::new(),
181 }
182 }
183
184 pub fn with_license(mut self, license: impl Into<String>) -> Self {
185 self.license = Some(license.into());
186 self
187 }
188
189 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
190 self.tags.push(tag.into());
191 self
192 }
193
194 pub fn with_metric(mut self, name: impl Into<String>, value: f64) -> Self {
195 self.metrics.insert(name.into(), value);
196 self
197 }
198}
199
200pub fn generate_model_card(metadata: &ModelCardMetadata) -> String {
202 let mut card = String::new();
203
204 card.push_str("---\n");
206 let optional_fields: &[(&str, Option<&str>)] = &[
207 ("license", metadata.license.as_deref()),
208 ("language", metadata.language.as_deref()),
209 ("library_name", metadata.library_name.as_deref()),
210 ("pipeline_tag", metadata.pipeline_tag.as_deref()),
211 ];
212 for (key, value) in optional_fields {
213 if let Some(v) = value {
214 card.push_str(&format!("{}: {}\n", key, v));
215 }
216 }
217 if !metadata.tags.is_empty() {
218 card.push_str("tags:\n");
219 for tag in &metadata.tags {
220 card.push_str(&format!(" - {}\n", tag));
221 }
222 }
223 card.push_str("---\n\n");
224
225 card.push_str(&format!("# {}\n\n", metadata.model_name));
227
228 card.push_str("## Model Description\n\n");
230 card.push_str("This model was trained using the PAIML stack.\n\n");
231
232 if !metadata.metrics.is_empty() {
234 card.push_str("## Evaluation Results\n\n");
235 card.push_str("| Metric | Value |\n");
236 card.push_str("|--------|-------|\n");
237 for (name, value) in &metadata.metrics {
238 card.push_str(&format!("| {} | {:.4} |\n", name, value));
239 }
240 card.push('\n');
241 }
242
243 card.push_str("## Training Details\n\n");
245 card.push_str("Trained with [PAIML Stack](https://github.com/paiml).\n");
246
247 card
248}
249
250#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
256pub struct FileHash {
257 pub sha256: String,
258 pub size: u64,
259}
260
261impl FileHash {
262 pub fn new(sha256: impl Into<String>, size: u64) -> Self {
263 Self { sha256: sha256.into(), size }
264 }
265
266 pub fn from_content(content: &[u8]) -> Self {
268 use std::collections::hash_map::DefaultHasher;
269 use std::hash::{Hash, Hasher};
270
271 let mut hasher = DefaultHasher::new();
273 content.hash(&mut hasher);
274 let hash = hasher.finish();
275
276 Self { sha256: format!("{:016x}", hash), size: content.len() as u64 }
277 }
278}
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
282pub struct UploadManifest {
283 pub files: HashMap<String, FileHash>,
284}
285
286impl UploadManifest {
287 pub fn new() -> Self {
288 Self { files: HashMap::new() }
289 }
290
291 pub fn add_file(&mut self, path: impl Into<String>, hash: FileHash) {
292 self.files.insert(path.into(), hash);
293 }
294
295 pub fn diff(&self, remote: &UploadManifest) -> Vec<String> {
297 self.files
298 .iter()
299 .filter(|(path, hash)| remote.files.get(*path) != Some(hash))
300 .map(|(path, _)| path.clone())
301 .collect()
302 }
303
304 pub fn total_size(&self, files: &[String]) -> u64 {
306 files.iter().filter_map(|f| self.files.get(f)).map(|h| h.size).sum()
307 }
308}
309
310impl Default for UploadManifest {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
/// Category of secret a file name suggests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecretType {
    /// API key material. Not produced by the file-name rules in this file.
    ApiKey,
    /// dotenv-style environment file (`.env`, `.env.local`, `prod.env`, ...).
    EnvFile,
    /// Private key file (`.pem`, `.key`, `id_rsa`, `id_ed25519`).
    PrivateKey,
    /// Credential or password store (`credentials`, `secrets`, `password`).
    Password,
}

/// One finding reported by the secret scan.
#[derive(Debug, Clone)]
pub struct SecretDetection {
    /// The offending path, exactly as supplied by the caller.
    pub file: String,
    /// What kind of secret the path looks like.
    pub secret_type: SecretType,
    /// Line number of the finding; `None` for findings based on the file name only.
    pub line: Option<usize>,
}

/// Maps an already lower-cased path to the kind of secret its name suggests.
///
/// Rules are tried in order: env files, private keys, passwords. The first
/// match wins.
///
/// An env file is a path component (split on `/` or `\`) that starts with
/// `.env`, ends with `.env`, or contains `.env.`. A bare `env` substring is
/// deliberately not enough, so `environment.yml`, `venv/...` or `gym_env.py`
/// are not reported as secrets.
fn detect_secret_type(lower: &str) -> Option<SecretType> {
    // Substring rules, matched anywhere in the path.
    const RULES: &[(&[&str], SecretType)] = &[
        (&[".pem", ".key", "id_rsa", "id_ed25519"], SecretType::PrivateKey),
        (&["credentials", "secrets", "password"], SecretType::Password),
    ];

    let is_env_file = lower.split(|c: char| c == '/' || c == '\\').any(|component| {
        component.starts_with(".env") || component.ends_with(".env") || component.contains(".env.")
    });
    if is_env_file {
        return Some(SecretType::EnvFile);
    }

    RULES.iter().find_map(|(patterns, secret_type)| {
        patterns.iter().any(|p| lower.contains(*p)).then_some(*secret_type)
    })
}
349
350pub fn scan_for_secrets(files: &[&str]) -> Vec<SecretDetection> {
351 files
352 .iter()
353 .filter_map(|file| {
354 detect_secret_type(&file.to_lowercase()).map(|secret_type| SecretDetection {
355 file: (*file).to_string(),
356 secret_type,
357 line: None,
358 })
359 })
360 .collect()
361}
362
363pub fn check_push_allowed(files: &[&str]) -> Result<(), Vec<SecretDetection>> {
365 let secrets = scan_for_secrets(files);
366 if secrets.is_empty() {
367 Ok(())
368 } else {
369 Err(secrets)
370 }
371}
372
#[cfg(test)]
// Test names carry the upper-case requirement id (e.g. `HF_CLIENT_001`).
#[allow(non_snake_case)]
mod tests {
    use super::*;

    /// Asserts that `filename` is classified as `expected`.
    fn assert_file_safety(filename: &str, expected: FileSafety) {
        let actual = classify_file_safety(filename);
        assert_eq!(actual, expected, "Expected {filename} to be {expected:?}");
    }

    /// Builds a manifest from `(path, sha, size)` triples.
    fn test_manifest(files: &[(&str, &str, u64)]) -> UploadManifest {
        let mut manifest = UploadManifest::new();
        for (path, sha, size) in files.iter().copied() {
            manifest.add_file(path, FileHash::new(sha, size));
        }
        manifest
    }

    /// Builds metadata for `"test-model"` with the given license, tags and metrics.
    fn make_metadata(
        license: Option<&str>,
        tags: &[&str],
        metrics: &[(&str, f64)],
    ) -> ModelCardMetadata {
        let base = ModelCardMetadata::new("test-model");
        let licensed = match license {
            Some(lic) => base.with_license(lic),
            None => base,
        };
        let tagged = tags.iter().fold(licensed, |meta, tag| meta.with_tag(*tag));
        metrics.iter().fold(tagged, |meta, &(name, value)| meta.with_metric(name, value))
    }

    /// Renders a card from the given inputs and checks every `expected` substring.
    fn assert_card_contains(
        license: Option<&str>,
        tags: &[&str],
        metrics: &[(&str, f64)],
        expected: &[&str],
    ) {
        let card = generate_model_card(&make_metadata(license, tags, metrics));
        for needle in expected {
            assert!(card.contains(*needle), "Card missing expected string: {needle:?}");
        }
    }

    // ---- HF_CLIENT_001: rate limiting ----

    #[test]
    fn test_HF_CLIENT_001_rate_limit_config_default() {
        let RateLimitConfig { initial_backoff, max_retries, multiplier, .. } =
            RateLimitConfig::default();
        assert_eq!(initial_backoff, Duration::from_secs(1));
        assert_eq!(max_retries, 5);
        assert_eq!(multiplier, 2.0);
    }

    #[test]
    fn test_HF_CLIENT_001_rate_limit_state_new() {
        let fresh = RateLimitState::new();
        assert_eq!(fresh.retry_count, 0);
        assert!(fresh.retry_after.is_none());
    }

    #[test]
    fn test_HF_CLIENT_001_rate_limit_exponential_backoff() {
        let config = RateLimitConfig::default();
        let mut state = RateLimitState::new();

        for &expected_secs in &[1u64, 2, 4] {
            let backoff = state.next_backoff(&config).expect("unexpected failure");
            assert_eq!(backoff, Duration::from_secs(expected_secs));
        }
    }

    #[test]
    fn test_HF_CLIENT_001_rate_limit_max_backoff() {
        let config = RateLimitConfig { max_backoff: Duration::from_secs(10), ..Default::default() };
        let mut state = RateLimitState::new();

        // Use up the first four attempts; only the fifth delay is inspected.
        for _ in 0..4 {
            let _ = state.next_backoff(&config);
        }

        let fifth = state.next_backoff(&config).expect("unexpected failure");
        assert!(fifth <= config.max_backoff);
    }

    #[test]
    fn test_HF_CLIENT_001_rate_limit_max_retries() {
        let config = RateLimitConfig { max_retries: 2, ..Default::default() };
        let mut state = RateLimitState::new();

        let first = state.next_backoff(&config);
        let second = state.next_backoff(&config);
        let third = state.next_backoff(&config);

        assert!(first.is_some());
        assert!(second.is_some());
        assert!(third.is_none());
    }

    #[test]
    fn test_HF_CLIENT_001_rate_limit_reset() {
        let config = RateLimitConfig::default();
        let mut state = RateLimitState::new();

        for _ in 0..2 {
            let _ = state.next_backoff(&config);
        }
        assert_eq!(state.retry_count, 2);

        state.reset();
        assert_eq!(state.retry_count, 0);
    }

    #[test]
    fn test_HF_CLIENT_001_rate_limit_retry_after_header() {
        let config = RateLimitConfig::default();
        let server_delay = Duration::from_secs(30);
        let mut state = RateLimitState::new();
        state.retry_after = Some(server_delay);

        let backoff = state.next_backoff(&config).expect("unexpected failure");
        assert_eq!(backoff, server_delay);
    }

    // ---- HF_CLIENT_002: file safety ----

    #[test]
    fn test_HF_CLIENT_002_classify_safetensors_safe() {
        assert_file_safety("model.safetensors", FileSafety::Safe);
    }

    #[test]
    fn test_HF_CLIENT_002_classify_json_safe() {
        assert_file_safety("config.json", FileSafety::Safe);
    }

    #[test]
    fn test_HF_CLIENT_002_classify_gguf_safe() {
        assert_file_safety("model.gguf", FileSafety::Safe);
    }

    #[test]
    fn test_HF_CLIENT_002_classify_bin_unsafe() {
        assert_file_safety("pytorch_model.bin", FileSafety::Unsafe);
    }

    #[test]
    fn test_HF_CLIENT_002_classify_pickle_unsafe() {
        for name in &["model.pkl", "model.pickle"] {
            assert_file_safety(name, FileSafety::Unsafe);
        }
    }

    #[test]
    fn test_HF_CLIENT_002_classify_pt_unsafe() {
        for name in &["model.pt", "model.pth"] {
            assert_file_safety(name, FileSafety::Unsafe);
        }
    }

    #[test]
    fn test_HF_CLIENT_002_check_download_safe_only_pass() {
        let outcome =
            check_download_allowed(&["model.safetensors", "config.json"], SafetyPolicy::SafeOnly);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_HF_CLIENT_002_check_download_safe_only_fail() {
        let outcome = check_download_allowed(
            &["model.safetensors", "pytorch_model.bin"],
            SafetyPolicy::SafeOnly,
        );
        assert_eq!(outcome, Err(vec!["pytorch_model.bin".to_string()]));
    }

    #[test]
    fn test_HF_CLIENT_002_check_download_allow_unsafe() {
        let outcome = check_download_allowed(
            &["model.safetensors", "pytorch_model.bin"],
            SafetyPolicy::AllowUnsafe,
        );
        assert!(outcome.is_ok());
    }

    // ---- HF_CLIENT_003: model card ----

    #[test]
    fn test_HF_CLIENT_003_model_card_metadata_new() {
        let meta = ModelCardMetadata::new("my-model");
        assert_eq!(meta.model_name, "my-model");
        assert_eq!(meta.library_name, Some("paiml".to_string()));
    }

    #[test]
    fn test_HF_CLIENT_003_model_card_with_license() {
        let meta = ModelCardMetadata::new("my-model").with_license("apache-2.0");
        assert_eq!(meta.license, Some("apache-2.0".to_string()));
    }

    #[test]
    fn test_HF_CLIENT_003_model_card_with_tags() {
        let meta = make_metadata(None, &["text-classification", "rust"], &[]);
        assert_eq!(meta.tags.len(), 2);
    }

    #[test]
    fn test_HF_CLIENT_003_model_card_with_metrics() {
        let meta = make_metadata(None, &[], &[("accuracy", 0.95), ("f1", 0.92)]);
        assert_eq!(meta.metrics.len(), 2);
        assert_eq!(meta.metrics.get("accuracy"), Some(&0.95));
    }

    #[test]
    fn test_HF_CLIENT_003_generate_model_card_header() {
        let card = generate_model_card(&make_metadata(None, &[], &[]));
        assert!(card.starts_with("---\n"));
        assert!(card.contains("# test-model"));
    }

    #[test]
    fn test_HF_CLIENT_003_generate_model_card_license() {
        assert_card_contains(Some("mit"), &[], &[], &["license: mit"]);
    }

    #[test]
    fn test_HF_CLIENT_003_generate_model_card_metrics() {
        assert_card_contains(None, &[], &[("acc", 0.9)], &["| acc |", "0.9"]);
    }

    #[test]
    fn test_HF_CLIENT_003_generate_model_card_paiml_footer() {
        assert_card_contains(None, &[], &[], &["PAIML Stack"]);
    }

    // ---- HF_CLIENT_004: hashes and manifests ----

    #[test]
    fn test_HF_CLIENT_004_file_hash_new() {
        let FileHash { sha256, size } = FileHash::new("abc123", 1024);
        assert_eq!(sha256, "abc123");
        assert_eq!(size, 1024);
    }

    #[test]
    fn test_HF_CLIENT_004_file_hash_from_content() {
        let hash = FileHash::from_content(b"hello world");
        assert!(!hash.sha256.is_empty());
        assert_eq!(hash.size, 11);
    }

    #[test]
    fn test_HF_CLIENT_004_file_hash_deterministic() {
        let first = FileHash::from_content(b"test");
        let second = FileHash::from_content(b"test");
        assert_eq!(first.sha256, second.sha256);
    }

    #[test]
    fn test_HF_CLIENT_004_upload_manifest_new() {
        assert!(UploadManifest::new().files.is_empty());
    }

    #[test]
    fn test_HF_CLIENT_004_upload_manifest_add_file() {
        let manifest = test_manifest(&[("model.safetensors", "abc", 1000)]);
        assert_eq!(manifest.files.len(), 1);
    }

    #[test]
    fn test_HF_CLIENT_004_upload_manifest_diff_new_file() {
        let local = test_manifest(&[("new.txt", "abc", 100)]);
        let remote = test_manifest(&[]);

        assert_eq!(local.diff(&remote), vec!["new.txt".to_string()]);
    }

    #[test]
    fn test_HF_CLIENT_004_upload_manifest_diff_changed_file() {
        let local = test_manifest(&[("file.txt", "new_hash", 100)]);
        let remote = test_manifest(&[("file.txt", "old_hash", 100)]);

        assert_eq!(local.diff(&remote), vec!["file.txt".to_string()]);
    }

    #[test]
    fn test_HF_CLIENT_004_upload_manifest_diff_unchanged() {
        let local = test_manifest(&[("file.txt", "same", 100)]);
        let remote = test_manifest(&[("file.txt", "same", 100)]);

        assert!(local.diff(&remote).is_empty());
    }

    #[test]
    fn test_HF_CLIENT_004_upload_manifest_total_size() {
        let manifest = test_manifest(&[("a.txt", "a", 100), ("b.txt", "b", 200)]);

        let selected = ["a.txt".to_string(), "b.txt".to_string()];
        assert_eq!(manifest.total_size(&selected), 300);
    }

    // ---- HF_CLIENT_005: secret scanning ----

    #[test]
    fn test_HF_CLIENT_005_scan_env_file() {
        let secrets = scan_for_secrets(&[".env", "model.safetensors"]);
        assert_eq!(secrets.len(), 1);
        assert_eq!(secrets[0].secret_type, SecretType::EnvFile);
    }

    #[test]
    fn test_HF_CLIENT_005_scan_env_local() {
        let secrets = scan_for_secrets(&[".env.local"]);
        assert_eq!(secrets.len(), 1);
    }

    #[test]
    fn test_HF_CLIENT_005_scan_private_key() {
        let secrets = scan_for_secrets(&["id_rsa", "key.pem"]);
        assert_eq!(secrets.len(), 2);
        for finding in &secrets {
            assert!(finding.secret_type == SecretType::PrivateKey);
        }
    }

    #[test]
    fn test_HF_CLIENT_005_scan_credentials() {
        let secrets = scan_for_secrets(&["credentials.json"]);
        assert_eq!(secrets.len(), 1);
        assert_eq!(secrets[0].secret_type, SecretType::Password);
    }

    #[test]
    fn test_HF_CLIENT_005_scan_no_secrets() {
        let secrets = scan_for_secrets(&["model.safetensors", "config.json", "README.md"]);
        assert!(secrets.is_empty());
    }

    #[test]
    fn test_HF_CLIENT_005_check_push_allowed_clean() {
        let outcome = check_push_allowed(&["model.safetensors", "config.json"]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_HF_CLIENT_005_check_push_blocked() {
        let outcome = check_push_allowed(&["model.safetensors", ".env"]);
        assert!(outcome.is_err());
    }
}