use crate::hf_pipeline::error::{FetchError, Result};
use ndarray::Array2;
use std::path::Path;
use super::{MemoryEstimate, TeacherModel};
/// Teacher model backed by a local `model.safetensors` checkpoint.
///
/// [`SafeTensorsTeacher::load`] records only checkpoint metadata (tensor names,
/// parameter count, detected layer count and hidden size); the `weights` map is
/// left empty.
#[derive(Debug)]
pub struct SafeTensorsTeacher {
    /// Weight matrices keyed by tensor name (empty for instances built by `load`).
    weights: std::collections::HashMap<String, Array2<f32>>,
    /// Names of the tensors found in the checkpoint.
    tensor_names: Vec<String>,
    /// Layer count detected from tensor names.
    num_layers: usize,
    /// Hidden dimension inferred from square weight shapes (0 when not detected).
    hidden_size: usize,
    /// Total number of elements across all tensors.
    param_count: u64,
}
impl SafeTensorsTeacher {
    /// Loads checkpoint metadata from `path/model.safetensors`.
    ///
    /// Records every tensor name (sorted), the total parameter count, the layer
    /// count detected from tensor names and the hidden size inferred from weight
    /// shapes. No weight data is retained: [`weights`](Self::weights) is empty on
    /// the returned value.
    ///
    /// # Errors
    ///
    /// * [`FetchError::FileNotFound`] if `path` contains no `model.safetensors`.
    /// * The crate error converted from [`std::io::Error`] if the file cannot be read.
    /// * [`FetchError::SafeTensorsParseError`] if the file is not valid safetensors.
    pub fn load(path: &Path) -> Result<Self> {
        use safetensors::SafeTensors;

        let model_path = path.join("model.safetensors");
        if !model_path.exists() {
            return Err(FetchError::FileNotFound {
                repo: path.display().to_string(),
                file: "model.safetensors".into(),
            });
        }

        // NOTE(review): the whole checkpoint is read into memory although only
        // header metadata (names and shapes) is used below.
        let data = std::fs::read(&model_path)?;
        let tensors = SafeTensors::deserialize(&data)
            .map_err(|e| FetchError::SafeTensorsParseError { message: e.to_string() })?;

        // Sort so `tensor_names()` and the first-match hidden-size heuristics are
        // deterministic even if the safetensors crate reports names in an
        // unspecified order.
        let mut tensor_names: Vec<String> =
            tensors.names().iter().map(|s| (*s).to_string()).collect();
        tensor_names.sort_unstable();

        // Every name came from this same header, so lookups are not expected to
        // fail; any that do are skipped rather than aborting the load.
        let param_count: u64 = tensor_names
            .iter()
            .filter_map(|name| tensors.tensor(name).ok())
            .map(|view| view.shape().iter().map(|&dim| dim as u64).product::<u64>())
            .sum();

        let num_layers = detect_layer_count(&tensor_names);
        let hidden_size = detect_hidden_size(&tensors, &tensor_names);

        Ok(Self {
            weights: std::collections::HashMap::new(),
            tensor_names,
            num_layers,
            hidden_size,
            param_count,
        })
    }

    /// Names of all tensors in the checkpoint.
    #[must_use]
    pub fn tensor_names(&self) -> &[String] {
        &self.tensor_names
    }

    /// Loaded weight matrices keyed by tensor name (empty after [`load`](Self::load)).
    #[must_use]
    pub fn weights(&self) -> &std::collections::HashMap<String, Array2<f32>> {
        &self.weights
    }

    /// Test-only constructor with synthetic metadata and no tensors.
    ///
    /// The parameter count is `num_layers * hidden_size^2 * 4`.
    #[cfg(test)]
    pub fn mock(num_layers: usize, hidden_size: usize) -> Self {
        let param_count = (num_layers as u64) * (hidden_size as u64).pow(2) * 4;
        Self {
            weights: std::collections::HashMap::new(),
            tensor_names: Vec::new(),
            num_layers,
            hidden_size,
            param_count,
        }
    }
}
/// Counts the distinct layer indices that appear in `names`.
///
/// Falls back to 12 when no name carries a recognisable layer index.
fn detect_layer_count(names: &[String]) -> usize {
    use std::collections::HashSet;

    const FALLBACK_LAYER_COUNT: usize = 12;

    let distinct: HashSet<usize> = names
        .iter()
        .filter_map(|name| extract_layer_index(name))
        .collect();

    match distinct.len() {
        0 => FALLBACK_LAYER_COUNT,
        count => count,
    }
}
/// Parses the text before the first `.` in `s` (or all of `s` when it has no
/// `.`) as a `usize`, returning `None` if it is not a valid number.
fn parse_leading_index(s: &str) -> Option<usize> {
    let head = s.split_once('.').map_or(s, |(head, _rest)| head);
    head.parse().ok()
}
fn extract_layer_index(name: &str) -> Option<usize> {
const PATTERNS: &[&str] = &[".layer.", ".layers.", ".h."];
PATTERNS.iter().find_map(|pattern| {
let pos = name.find(pattern)?;
let after_pattern = &name[pos + pattern.len()..];
parse_leading_index(after_pattern)
})
}
/// Returns the side length of tensor `name` if it is a square 2-D matrix whose
/// side is at least `min_size`; `None` otherwise, including when the tensor
/// cannot be looked up.
fn square_dim(
    tensors: &safetensors::SafeTensors<'_>,
    name: &str,
    min_size: usize,
) -> Option<usize> {
    let view = tensors.tensor(name).ok()?;
    match view.shape() {
        &[rows, cols] if rows == cols && rows >= min_size => Some(rows),
        _ => None,
    }
}
/// Infers the hidden size from the first attention query projection whose
/// weight is a square matrix.
///
/// Query projections are recognised by name suffix.
fn detect_from_query_weights(
    tensors: &safetensors::SafeTensors<'_>,
    names: &[String],
) -> Option<usize> {
    // Note: `.self_attn.q_proj.weight` is already covered by `.q_proj.weight`.
    const QUERY_PATTERNS: &[&str] =
        &[".query.weight", ".q_proj.weight", ".self_attn.q_proj.weight"];
    names
        .iter()
        .filter(|name| QUERY_PATTERNS.iter().any(|suffix| name.ends_with(suffix)))
        .find_map(|name| square_dim(tensors, name, 1))
}
/// Fallback hidden-size heuristic: the side length of the first square
/// `*weight*` matrix whose side is at least 256.
fn detect_from_square_weights(
    tensors: &safetensors::SafeTensors<'_>,
    names: &[String],
) -> Option<usize> {
    // NOTE(review): the 256 floor presumably skips small square matrices that are
    // not hidden-by-hidden projections — confirm.
    const MIN_SIDE: usize = 256;
    names.iter().find_map(|name| {
        if name.contains("weight") {
            square_dim(tensors, name, MIN_SIDE)
        } else {
            None
        }
    })
}
/// Infers the model's hidden size from tensor shapes.
///
/// Query-projection weights are preferred. If none are found, the first
/// sufficiently large square weight is used. Returns 0 when neither heuristic
/// finds a candidate.
fn detect_hidden_size(tensors: &safetensors::SafeTensors<'_>, names: &[String]) -> usize {
    if let Some(dim) = detect_from_query_weights(tensors, names) {
        return dim;
    }
    detect_from_square_weights(tensors, names).unwrap_or(0)
}
impl TeacherModel for SafeTensorsTeacher {
    /// Placeholder forward pass: returns a copy of `input` unchanged.
    fn forward(&self, input: &Array2<f32>) -> Result<Array2<f32>> {
        let output = input.clone();
        Ok(output)
    }

    /// Placeholder hidden states: one copy of `input` per detected layer.
    fn hidden_states(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
        let states = (0..self.num_layers).map(|_| input.clone()).collect();
        Ok(states)
    }

    /// Placeholder attention maps: one all-ones `(n, n)` matrix per layer, where
    /// `n` is the number of rows of `input`.
    ///
    /// NOTE(review): the original code labels axis 0 as `batch`; a square map
    /// over that axis is only meaningful if rows are tokens — confirm the
    /// intended input layout.
    fn attention_weights(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
        let rows = input.nrows();
        Ok(vec![Array2::<f32>::ones((rows, rows)); self.num_layers])
    }

    /// Memory estimate assuming fp16 storage, based on the recorded parameter
    /// count and hidden size.
    fn estimate_memory(&self, batch_size: usize, seq_len: usize) -> MemoryEstimate {
        let (params, hidden) = (self.param_count, self.hidden_size);
        MemoryEstimate::fp16(params, batch_size, seq_len, hidden)
    }

    fn param_count(&self) -> u64 {
        self.param_count
    }

    fn num_layers(&self) -> usize {
        self.num_layers
    }

    fn hidden_size(&self) -> usize {
        self.hidden_size
    }
}