Skip to main content

entrenar/transformer/weights/
detect.rs

//! Architecture detection and SafeTensors file discovery.
2
3use super::Architecture;
4use crate::error::Result;
5
6/// Find SafeTensors files in a directory or return single file.
7///
8/// For checkpoint directories containing `model-step-*.safetensors` files,
9/// returns only the latest (highest step number) checkpoint to avoid loading
10/// all intermediate checkpoints. For sharded models (e.g. `model-00001-of-00014.safetensors`),
11/// returns all shards.
12pub(crate) fn find_safetensors_files(path: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
13    if path.is_file() {
14        return Ok(if path.extension().is_some_and(|e| e == "safetensors") {
15            vec![path.to_path_buf()]
16        } else {
17            vec![]
18        });
19    }
20
21    if !path.is_dir() {
22        return Ok(vec![]);
23    }
24
25    // Check for single model.safetensors first
26    let single = path.join("model.safetensors");
27    if single.exists() {
28        return Ok(vec![single]);
29    }
30
31    // Collect all .safetensors files
32    let mut files: Vec<std::path::PathBuf> = std::fs::read_dir(path)
33        .into_iter()
34        .flatten()
35        .flatten()
36        .map(|e| e.path())
37        .filter(|p| p.extension().is_some_and(|e| e == "safetensors"))
38        .collect();
39    files.sort();
40
41    // Check if these are checkpoint files (model-step-*.safetensors).
42    // If so, return only the latest one — loading all checkpoints is wasteful
43    // (each has all 218 tensors, last-write-wins) and was the root cause of
44    // the resume bug where 5x1.5GB of redundant I/O slowed startup.
45    if let Some(latest) = find_latest_checkpoint(&files) {
46        eprintln!("  Resuming from checkpoint: {}", latest.display());
47        return Ok(vec![latest]);
48    }
49
50    Ok(files)
51}
52
53/// Find the checkpoint file with the highest step number.
54///
55/// Returns `None` if no files match the `model-step-{N}.safetensors` pattern.
56/// Ignores non-checkpoint files (e.g. `model-best.safetensors`) — only looks
57/// at files matching the `model-step-{N}.safetensors` pattern.
58fn find_latest_checkpoint(files: &[std::path::PathBuf]) -> Option<std::path::PathBuf> {
59    files
60        .iter()
61        .filter_map(|f| parse_checkpoint_step_from_path(f).map(|step| (step, f)))
62        .max_by_key(|(step, _)| *step)
63        .map(|(_, p)| p.clone())
64}
65
/// Extract `N` from a file name of the form `model-step-{N}.safetensors`.
///
/// Only the final path component is inspected, so `/run/ckpt/model-step-3000.safetensors`
/// yields `Some(3000)`. Returns `None` in these cases:
/// - there is no file name;
/// - the file name is not valid UTF-8;
/// - the prefix or suffix does not match;
/// - the text between them does not parse as a `usize`.
pub(crate) fn parse_checkpoint_step_from_path(path: &std::path::Path) -> Option<usize> {
    const PREFIX: &str = "model-step-";
    const SUFFIX: &str = ".safetensors";

    let name = path.file_name().and_then(|os_name| os_name.to_str())?;
    let step_text = name
        .strip_prefix(PREFIX)
        .and_then(|rest| rest.strip_suffix(SUFFIX))?;
    step_text.parse::<usize>().ok()
}
71
72/// Auto-detect model architecture from tensor names
73pub(crate) fn detect_architecture(tensors: &safetensors::SafeTensors<'_>) -> Architecture {
74    let names: Vec<String> = tensors.names().iter().map(std::string::ToString::to_string).collect();
75
76    // RoBERTa / CodeBERT: look for "roberta." or "bert." prefix with encoder layers
77    let is_roberta = names
78        .iter()
79        .any(|n| n.starts_with("roberta.") || (n.starts_with("bert.") && n.contains(".encoder.")));
80    if is_roberta {
81        return Architecture::RoBERTa;
82    }
83
84    // Qwen2 uses "model.layers.X.self_attn.q_proj" with biases
85    // LLaMA uses same pattern but no biases
86    // Check for bias tensors to distinguish
87    let has_attn_bias = names.iter().any(|n: &String| {
88        n.contains("self_attn.q_proj.bias") || n.contains("self_attn.k_proj.bias")
89    });
90
91    if has_attn_bias {
92        // Qwen2 has attention biases
93        return Architecture::Qwen2;
94    }
95
96    // Check for Mistral-specific patterns
97    // Mistral uses sliding window attention config but same weight names as LLaMA
98    // We default to LLaMA if no Qwen2 bias markers found
99
100    Architecture::Llama
101}