entrenar/transformer/weights/detect.rs
1//! Architecture detection and SafeTensors file discovery
2
3use super::Architecture;
4use crate::error::Result;
5
6/// Find SafeTensors files in a directory or return single file.
7///
8/// For checkpoint directories containing `model-step-*.safetensors` files,
9/// returns only the latest (highest step number) checkpoint to avoid loading
10/// all intermediate checkpoints. For sharded models (e.g. `model-00001-of-00014.safetensors`),
11/// returns all shards.
12pub(crate) fn find_safetensors_files(path: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
13 if path.is_file() {
14 return Ok(if path.extension().is_some_and(|e| e == "safetensors") {
15 vec![path.to_path_buf()]
16 } else {
17 vec![]
18 });
19 }
20
21 if !path.is_dir() {
22 return Ok(vec![]);
23 }
24
25 // Check for single model.safetensors first
26 let single = path.join("model.safetensors");
27 if single.exists() {
28 return Ok(vec![single]);
29 }
30
31 // Collect all .safetensors files
32 let mut files: Vec<std::path::PathBuf> = std::fs::read_dir(path)
33 .into_iter()
34 .flatten()
35 .flatten()
36 .map(|e| e.path())
37 .filter(|p| p.extension().is_some_and(|e| e == "safetensors"))
38 .collect();
39 files.sort();
40
41 // Check if these are checkpoint files (model-step-*.safetensors).
42 // If so, return only the latest one — loading all checkpoints is wasteful
43 // (each has all 218 tensors, last-write-wins) and was the root cause of
44 // the resume bug where 5x1.5GB of redundant I/O slowed startup.
45 if let Some(latest) = find_latest_checkpoint(&files) {
46 eprintln!(" Resuming from checkpoint: {}", latest.display());
47 return Ok(vec![latest]);
48 }
49
50 Ok(files)
51}
52
53/// Find the checkpoint file with the highest step number.
54///
55/// Returns `None` if no files match the `model-step-{N}.safetensors` pattern.
56/// Ignores non-checkpoint files (e.g. `model-best.safetensors`) — only looks
57/// at files matching the `model-step-{N}.safetensors` pattern.
58fn find_latest_checkpoint(files: &[std::path::PathBuf]) -> Option<std::path::PathBuf> {
59 files
60 .iter()
61 .filter_map(|f| parse_checkpoint_step_from_path(f).map(|step| (step, f)))
62 .max_by_key(|(step, _)| *step)
63 .map(|(_, p)| p.clone())
64}
65
/// Parse the step number from a checkpoint path like `.../model-step-3000.safetensors`.
///
/// Only the final file-name component is inspected. The text between `model-step-`
/// and `.safetensors` must be a non-empty run of ASCII digits that fits in `usize`.
/// Otherwise `None` is returned. This covers a sign such as `model-step-+5`, a missing
/// number, a non-UTF-8 file name, or an overflowing value.
pub(crate) fn parse_checkpoint_step_from_path(path: &std::path::Path) -> Option<usize> {
    let filename = path.file_name()?.to_str()?;
    let digits = filename.strip_prefix("model-step-")?.strip_suffix(".safetensors")?;
    // `usize::from_str` accepts a leading `+`, which is not part of the
    // `model-step-{N}` naming scheme, so insist on plain digits.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse().ok()
}
71
72/// Auto-detect model architecture from tensor names
73pub(crate) fn detect_architecture(tensors: &safetensors::SafeTensors<'_>) -> Architecture {
74 let names: Vec<String> = tensors.names().iter().map(std::string::ToString::to_string).collect();
75
76 // RoBERTa / CodeBERT: look for "roberta." or "bert." prefix with encoder layers
77 let is_roberta = names
78 .iter()
79 .any(|n| n.starts_with("roberta.") || (n.starts_with("bert.") && n.contains(".encoder.")));
80 if is_roberta {
81 return Architecture::RoBERTa;
82 }
83
84 // Qwen2 uses "model.layers.X.self_attn.q_proj" with biases
85 // LLaMA uses same pattern but no biases
86 // Check for bias tensors to distinguish
87 let has_attn_bias = names.iter().any(|n: &String| {
88 n.contains("self_attn.q_proj.bias") || n.contains("self_attn.k_proj.bias")
89 });
90
91 if has_attn_bias {
92 // Qwen2 has attention biases
93 return Architecture::Qwen2;
94 }
95
96 // Check for Mistral-specific patterns
97 // Mistral uses sliding window attention config but same weight names as LLaMA
98 // We default to LLaMA if no Qwen2 bias markers found
99
100 Architecture::Llama
101}