use super::Architecture;
use crate::error::Result;
/// Collects the `.safetensors` weight files to load from `path`.
///
/// Resolution order:
/// - If `path` is a file, it is returned alone when it has a `.safetensors` extension;
///   any other file yields an empty list.
/// - If `path` is neither a file nor a directory (e.g. it does not exist), an empty list is
///   returned.
/// - If the directory contains a regular file named `model.safetensors`, only that file is
///   returned.
/// - Otherwise every regular `*.safetensors` file directly inside the directory (non-recursive)
///   is gathered and sorted by path. If any of them is a `model-step-<N>.safetensors`
///   checkpoint, only the one with the highest step is returned and a resume notice is printed
///   to stderr; otherwise all gathered files are returned.
///
/// # Errors
/// Currently never returns `Err`: a directory or directory entry that cannot be read is
/// skipped and treated as containing no matching files.
pub(crate) fn find_safetensors_files(path: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
    /// True when `p` has exactly the `safetensors` extension.
    fn has_safetensors_ext(p: &std::path::Path) -> bool {
        p.extension().is_some_and(|e| e == "safetensors")
    }

    if path.is_file() {
        return Ok(if has_safetensors_ext(path) {
            vec![path.to_path_buf()]
        } else {
            vec![]
        });
    }
    if !path.is_dir() {
        return Ok(vec![]);
    }

    // `is_file` rather than `exists`: a directory that happens to be named
    // `model.safetensors` must not be mistaken for a weight file. `is_file` follows
    // symlinks, so symlinked weight files are still accepted.
    let single = path.join("model.safetensors");
    if single.is_file() {
        return Ok(vec![single]);
    }

    // Best-effort listing: an unreadable directory or entry is skipped rather than reported.
    // The cheap extension check runs first so only candidate entries are stat'ed by `is_file`.
    let mut files: Vec<std::path::PathBuf> = std::fs::read_dir(path)
        .into_iter()
        .flatten()
        .flatten()
        .map(|entry| entry.path())
        .filter(|p| has_safetensors_ext(p) && p.is_file())
        .collect();
    files.sort();

    if let Some(latest) = find_latest_checkpoint(&files) {
        eprintln!(" Resuming from checkpoint: {}", latest.display());
        return Ok(vec![latest]);
    }
    Ok(files)
}
/// Picks the checkpoint with the highest step number among `files`.
///
/// Only paths whose file name is recognised by [`parse_checkpoint_step_from_path`] are
/// considered; returns `None` when no such path exists. If several paths share the highest
/// step, the one appearing last in `files` is chosen.
fn find_latest_checkpoint(files: &[std::path::PathBuf]) -> Option<std::path::PathBuf> {
    let mut best: Option<(usize, &std::path::PathBuf)> = None;
    for candidate in files {
        let Some(step) = parse_checkpoint_step_from_path(candidate) else {
            continue;
        };
        // `>=` so that a later path with an equal step replaces an earlier one.
        let take = match best {
            Some((top_step, _)) => step >= top_step,
            None => true,
        };
        if take {
            best = Some((step, candidate));
        }
    }
    best.map(|(_, winner)| winner.clone())
}
/// Extracts `N` from a checkpoint file named `model-step-<N>.safetensors`.
///
/// Only the final path component is inspected. Returns `None` in any of these cases:
/// - the path has no file name;
/// - the file name is not valid UTF-8;
/// - the name does not have the `model-step-` prefix and `.safetensors` suffix;
/// - the text between the prefix and suffix does not parse as a `usize`.
pub(crate) fn parse_checkpoint_step_from_path(path: &std::path::Path) -> Option<usize> {
    let name = path.file_name().and_then(|os_name| os_name.to_str())?;
    let step_text = name
        .strip_prefix("model-step-")
        .and_then(|rest| rest.strip_suffix(".safetensors"))?;
    step_text.parse::<usize>().ok()
}
/// Infers the model [`Architecture`] from the tensor names stored in a safetensors file.
///
/// The heuristics are checked in priority order:
/// 1. Any tensor name starting with `roberta.`, or a `bert.`-prefixed name containing
///    `.encoder.`, selects [`Architecture::RoBERTa`]. BERT-style encoders are routed to the
///    same variant.
/// 2. Otherwise, any name containing `self_attn.q_proj.bias` or `self_attn.k_proj.bias`
///    selects [`Architecture::Qwen2`].
/// 3. Everything else falls back to [`Architecture::Llama`].
pub(crate) fn detect_architecture(tensors: &safetensors::SafeTensors<'_>) -> Architecture {
    // `names()` already hands out borrowed names; scanning them directly avoids
    // copying every tensor name into a fresh `String` just to run prefix checks.
    let names = tensors.names();

    let is_roberta = names
        .iter()
        .any(|n| n.starts_with("roberta.") || (n.starts_with("bert.") && n.contains(".encoder.")));
    if is_roberta {
        return Architecture::RoBERTa;
    }

    // A bias on the attention query/key projections is treated as the Qwen2 signature.
    let has_attn_bias = names
        .iter()
        .any(|n| n.contains("self_attn.q_proj.bias") || n.contains("self_attn.k_proj.bias"));
    if has_attn_bias {
        return Architecture::Qwen2;
    }

    Architecture::Llama
}