/// Strips a known quantization/precision suffix (e.g. "-q4_k_m", "-f16")
/// from a GGUF file stem, returning the base model name.
///
/// Matching is ASCII case-insensitive (HF releases often use uppercase
/// suffixes like "-Q4_K_M") and repeats until no known suffix remains, so
/// "model-q8_0-f16" reduces to "model". Stems without a known suffix are
/// returned unchanged.
fn strip_quant_suffix(stem: &str) -> &str {
    // Longer variants listed before their prefixes so "-q4_k_m" wins over "-q4_k".
    const SUFFIXES: &[&str] = &[
        "-q4_k_m", "-q4_k", "-q4k", "-q5_k_m", "-q5_k", "-q5k", "-q6_k", "-q6k",
        "-q8_0", "-f16", "-f32",
    ];
    let mut stem = stem;
    loop {
        let before = stem.len();
        for suffix in SUFFIXES {
            if stem.len() >= suffix.len() {
                let split = stem.len() - suffix.len();
                // All suffixes are ASCII, so a case-insensitive byte match
                // guarantees `split` lands on a char boundary.
                if stem.as_bytes()[split..].eq_ignore_ascii_case(suffix.as_bytes()) {
                    stem = &stem[..split];
                }
            }
        }
        // Fixed point reached: no suffix was removed this pass.
        if stem.len() == before {
            return stem;
        }
    }
}
/// Returns the lexicographically first shard file (e.g.
/// "model-00001-of-000NN.safetensors") in `dir`, but only when a
/// "model.safetensors.index.json" marks the directory as sharded.
/// Returns `None` when the index is absent or no shard exists.
fn find_sharded_safetensors(dir: &Path) -> Option<std::path::PathBuf> {
    if !dir.join("model.safetensors.index.json").exists() {
        return None;
    }
    std::fs::read_dir(dir)
        .ok()?
        .flatten()
        .filter_map(|entry| {
            let file_name = entry.file_name().to_string_lossy().to_string();
            let is_shard =
                file_name.ends_with(".safetensors") && file_name != "model.safetensors";
            is_shard.then_some(entry.path())
        })
        .min()
}
/// Looks for SafeTensors weights inside `parent/base_name/`: a single
/// "model.safetensors" wins; otherwise a sharded layout is tried.
/// Returns `None` when the subdirectory does not exist.
fn discover_sibling_subdir(parent: &Path, base_name: &str) -> Option<std::path::PathBuf> {
    let subdir = parent.join(base_name);
    if subdir.is_dir() {
        let single = subdir.join("model.safetensors");
        if single.exists() {
            Some(single)
        } else {
            find_sharded_safetensors(&subdir)
        }
    } else {
        None
    }
}
/// Locates SafeTensors weights inside an HF snapshot directory.
/// A single "model.safetensors" takes precedence; otherwise the
/// lexicographically first shard file is returned.
fn find_safetensors_in_snapshot(snap_path: &Path) -> Option<std::path::PathBuf> {
    let single = snap_path.join("model.safetensors");
    if single.exists() {
        return Some(single);
    }
    let entries = std::fs::read_dir(snap_path).ok()?;
    entries
        .flatten()
        .filter_map(|entry| {
            let file_name = entry.file_name().to_string_lossy().to_string();
            let is_shard =
                file_name.ends_with(".safetensors") && file_name != "model.safetensors";
            is_shard.then_some(entry.path())
        })
        .min()
}
/// Returns `true` when an HF hub cache directory name ("models--Org--Repo")
/// refers to a model whose lowercased "org/repo" form contains `base_lower`.
/// Non-model entries (e.g. "datasets--...") never match.
fn hf_cache_dir_matches(dir_name: &str, base_lower: &str) -> bool {
    dir_name.starts_with("models--")
        && dir_name
            .trim_start_matches("models--")
            .replace("--", "/")
            .to_lowercase()
            .contains(base_lower)
}
/// Scans every snapshot under `model_dir/snapshots/` and returns the first
/// one containing SafeTensors weights, or `None` if none do (or the
/// snapshots directory cannot be read).
fn search_hf_model_snapshots(model_dir: &Path) -> Option<std::path::PathBuf> {
    let snapshots_dir = model_dir.join("snapshots");
    std::fs::read_dir(&snapshots_dir)
        .ok()?
        .flatten()
        .find_map(|snapshot| find_safetensors_in_snapshot(&snapshot.path()))
}
/// Searches the user's HuggingFace hub cache (~/.cache/huggingface/hub)
/// for a model whose repo name contains `base_name` (case-insensitive)
/// and returns its SafeTensors weights, if any snapshot has them.
fn discover_hf_cache(base_name: &str) -> Option<std::path::PathBuf> {
    let hub = dirs::home_dir()?.join(".cache/huggingface/hub");
    if !hub.is_dir() {
        return None;
    }
    let needle = base_name.to_lowercase();
    std::fs::read_dir(&hub)
        .ok()?
        .flatten()
        .filter(|entry| hf_cache_dir_matches(&entry.file_name().to_string_lossy(), &needle))
        .find_map(|entry| search_hf_model_snapshots(&entry.path()))
}
/// Locates SafeTensors weights inside a repo directory, preferring a
/// sharded layout (index + shards) over a single "model.safetensors".
fn find_safetensors_in_repo(repo_path: &Path) -> Option<std::path::PathBuf> {
    if let Some(shard) = find_sharded_safetensors(repo_path) {
        return Some(shard);
    }
    let single = repo_path.join("model.safetensors");
    if single.exists() {
        Some(single)
    } else {
        None
    }
}
/// Searches the APR cache (~/.apr/cache/hf/<org>/<repo>) for a repo whose
/// name contains `base_name` (case-insensitive) and returns its
/// SafeTensors weights, if present.
fn discover_apr_cache(base_name: &str) -> Option<std::path::PathBuf> {
    let cache_root = dirs::home_dir()?.join(".apr").join("cache").join("hf");
    if !cache_root.is_dir() {
        return None;
    }
    let needle = base_name.to_lowercase();
    std::fs::read_dir(&cache_root)
        .ok()?
        .flatten()
        .filter(|org| org.path().is_dir())
        .find_map(|org| search_org_for_model(&org.path(), &needle))
}
/// Scans one org directory for a repo whose lowercased name contains
/// `base_lower`, returning that repo's SafeTensors weights if found.
fn search_org_for_model(org_path: &Path, base_lower: &str) -> Option<std::path::PathBuf> {
    std::fs::read_dir(org_path)
        .ok()?
        .flatten()
        .filter(|repo| {
            repo.file_name()
                .to_string_lossy()
                .to_lowercase()
                .contains(base_lower)
        })
        .find_map(|repo| find_safetensors_in_repo(&repo.path()))
}
/// Attempts to locate SafeTensors weights matching a GGUF file, trying in
/// order: a sibling "<stem>.safetensors", a sibling subdirectory named
/// after the de-quantized stem, the HuggingFace hub cache, and finally the
/// APR cache. Returns `None` when nothing is found.
fn auto_discover_safetensors(gguf_path: &Path) -> Option<std::path::PathBuf> {
    let parent = gguf_path.parent()?;
    let stem = gguf_path.file_stem()?.to_str()?;
    let sibling = parent.join(format!("{stem}.safetensors"));
    if sibling.exists() {
        return Some(sibling);
    }
    // Drop the quantization suffix so "model-q4_k_m" searches for "model".
    let base_name = strip_quant_suffix(stem);
    discover_sibling_subdir(parent, base_name)
        .or_else(|| discover_hf_cache(base_name))
        .or_else(|| discover_apr_cache(base_name))
}
/// Returns the index of the largest logit, or `None` for an empty slice.
///
/// Uses `f32::total_cmp` (IEEE-754 totalOrder) instead of a
/// `partial_cmp`-with-`Equal` fallback: the comparator is now a genuine
/// total order, so the result is deterministic even when NaNs are present
/// (previously NaN compared "equal" to everything, making the winner
/// depend on element order). Ties resolve to the last maximal index, as
/// `Iterator::max_by` specifies.
fn compute_argmax(logits: &[f32]) -> Option<u32> {
    logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(idx, _)| idx as u32)
}
/// Builds the parity-gate verdict from the two backends' argmax tokens.
///
/// Both present and equal => passed; both present but different => failed
/// (parity broken); either missing => failed with no observed values.
fn compare_argmax_results(
    gguf_argmax: Option<u32>,
    st_argmax: Option<u32>,
    duration: Duration,
) -> GateResult {
    match (gguf_argmax, st_argmax) {
        (Some(gguf_token), Some(st_token)) => {
            let observed = Some(gguf_token as f64);
            let expected = Some(st_token as f64);
            if gguf_token == st_token {
                GateResult::passed(
                    "format_parity",
                    &format!(
                        "GGUF argmax={} == SafeTensors argmax={} (Cross-format parity VERIFIED)",
                        gguf_token, st_token
                    ),
                    observed,
                    expected,
                    duration,
                )
            } else {
                GateResult::failed(
                    "format_parity",
                    &format!(
                        "GGUF argmax={} != SafeTensors argmax={} (Cross-format parity BROKEN)",
                        gguf_token, st_token
                    ),
                    observed,
                    expected,
                    duration,
                )
            }
        }
        _ => GateResult::failed(
            "format_parity",
            "Failed to get argmax from one or both formats",
            None,
            None,
            duration,
        ),
    }
}
/// Resolves the SafeTensors path for the parity gate.
///
/// An explicit `--safetensors-path` from `config` wins; otherwise
/// auto-discovery is attempted (and announced unless JSON output is on).
/// When nothing is found, a failed `GateResult` is returned in `Err` so
/// the gate can report it directly.
fn resolve_safetensors_path(
    gguf_path: &Path,
    config: &QaConfig,
    elapsed: Duration,
) -> std::result::Result<std::path::PathBuf, GateResult> {
    if let Some(explicit) = &config.safetensors_path {
        return Ok(explicit.clone());
    }
    match auto_discover_safetensors(gguf_path) {
        Some(discovered) => {
            if !config.json {
                println!(
                    " {} Auto-discovered SafeTensors: {}",
                    "INFO".cyan(),
                    discovered.display()
                );
            }
            Ok(discovered)
        }
        None => Err(GateResult::failed(
            "format_parity",
            "No SafeTensors found. Provide --safetensors-path or download: \
             huggingface-cli download <model> --include '*.safetensors'",
            None,
            None,
            elapsed,
        )),
    }
}
/// QA gate: cross-format parity between a GGUF model and its SafeTensors
/// counterpart.
///
/// Both backends run one forward pass over the same chat prompt; the gate
/// passes only when their argmax tokens agree. Setup problems (missing
/// SafeTensors, wrong primary format, unconvertible model) are reported as
/// a *failed* `GateResult` wrapped in `Ok`, while read/parse errors
/// propagate as `Err`. Without the "inference" feature the gate is skipped.
fn run_format_parity_gate(path: &Path, config: &QaConfig) -> Result<GateResult> {
    let start = Instant::now();
    if !config.json && config.verbose {
        println!("{}", "Running cross-format parity test...".yellow());
    }
    #[cfg(feature = "inference")]
    {
        use realizar::format::{detect_format, ModelFormat};
        use realizar::gguf::{GGUFModel, MappedGGUFModel, OwnedQuantizedModel};
        // A missing SafeTensors counterpart fails the gate but is not a CLI error.
        let safetensors_path = match resolve_safetensors_path(path, config, start.elapsed()) {
            Ok(p) => p,
            Err(gate_result) => return Ok(gate_result),
        };
        let gguf_bytes = std::fs::read(path)
            .map_err(|e| CliError::ValidationFailed(format!("Failed to read GGUF: {e}")))?;
        // Magic-number sniff only needs the first 8 bytes; min() guards short files.
        let gguf_format = detect_format(&gguf_bytes[..8.min(gguf_bytes.len())]).map_err(|e| {
            CliError::ValidationFailed(format!("Failed to detect GGUF format: {e}"))
        })?;
        if gguf_format != ModelFormat::Gguf {
            return Ok(GateResult::failed(
                "format_parity",
                "Primary model must be GGUF format for cross-format parity test",
                None,
                None,
                start.elapsed(),
            ));
        }
        // Re-check existence: an explicitly supplied --safetensors-path may
        // point at a file that does not exist.
        if !safetensors_path.exists() {
            return Ok(GateResult::failed(
                "format_parity",
                &format!(
                    "SafeTensors not found: {}. Download with: huggingface-cli download <model> --include '*.safetensors'",
                    safetensors_path.display()
                ),
                None,
                None,
                start.elapsed(),
            ));
        }
        // Parsed model is used for tokenization; forward passes below use
        // the memory-mapped quantized model instead.
        let gguf = GGUFModel::from_bytes(&gguf_bytes)
            .map_err(|e| CliError::ValidationFailed(format!("Failed to parse GGUF: {e}")))?;
        // Qwen2 chat-template prompt; both backends must see identical tokens.
        let prompt = "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n";
        let bos = aprender::demo::SpecialTokens::qwen2().bos_id;
        // Fallback when encoding fails: BOS + 9707 (presumably a fixed
        // Qwen2 token id — TODO confirm against the tokenizer vocab).
        let prompt_tokens: Vec<u32> = gguf.encode(prompt).unwrap_or_else(|| vec![bos, 9707]);
        let gguf_logits = {
            let mapped = MappedGGUFModel::from_path(path)
                .map_err(|e| CliError::ValidationFailed(format!("GGUF map failed: {e}")))?;
            let model = OwnedQuantizedModel::from_mapped(&mapped)
                .map_err(|e| CliError::ValidationFailed(format!("GGUF model failed: {e}")))?;
            model
                .forward(&prompt_tokens)
                .map_err(|e| CliError::ValidationFailed(format!("GGUF forward failed: {e}")))?
        };
        let st_logits = match run_safetensors_forward(&safetensors_path, &prompt_tokens) {
            Ok(logits) => logits,
            // An unconvertible model (missing tensor / unsupported arch) is
            // a soft gate failure; any other error aborts the run.
            Err(ForwardError::ConversionFailed(path)) => {
                return Ok(GateResult::failed(
                    "format_parity",
                    &format!("SafeTensors conversion failed: {}", path),
                    None,
                    None,
                    start.elapsed(),
                ));
            }
            Err(ForwardError::Cli(e)) => return Err(e),
        };
        let duration = start.elapsed();
        Ok(compare_argmax_results(
            compute_argmax(&gguf_logits),
            compute_argmax(&st_logits),
            duration,
        ))
    }
    #[cfg(not(feature = "inference"))]
    {
        // Silence unused-parameter warnings in the no-inference build.
        let _ = (path, config);
        Ok(GateResult::skipped(
            "format_parity",
            "Requires 'inference' feature",
        ))
    }
}
/// Failure modes for `run_safetensors_forward`, letting the parity gate
/// distinguish a soft conversion problem (gate fails gracefully) from a
/// hard CLI error (propagated to the caller).
#[cfg(feature = "inference")]
enum ForwardError {
    // SafeTensors -> APR conversion was not possible for this model;
    // payload is the display form of the offending path.
    ConversionFailed(String),
    // Any other load/convert/forward error, wrapped for propagation.
    Cli(CliError),
}
/// Runs one forward pass through the SafeTensors backend and returns the
/// resulting logits.
///
/// Layout detection: when "model.safetensors.index.json" sits beside
/// `safetensors_path`, the sharded loader (plus the sibling config.json)
/// is used; otherwise the single file is converted directly.
///
/// # Errors
/// * `ForwardError::ConversionFailed` — conversion failed for a reason the
///   gate treats as soft (missing tensor / unsupported model), detected by
///   a substring match on the error text.
/// * `ForwardError::Cli` — any other load, convert, or forward error.
#[cfg(feature = "inference")]
fn run_safetensors_forward(
    safetensors_path: &Path,
    prompt_tokens: &[u32],
) -> std::result::Result<Vec<f32>, ForwardError> {
    use realizar::safetensors_infer::SafetensorsToAprConverter;
    use realizar::{SafetensorsConfig, ShardedSafeTensorsModel};
    let parent_dir = safetensors_path.parent().unwrap_or(Path::new("."));
    let index_path = parent_dir.join("model.safetensors.index.json");
    let transformer = if index_path.exists() {
        let sharded = ShardedSafeTensorsModel::load_from_index(&index_path)
            .map_err(|e| ForwardError::Cli(CliError::ValidationFailed(format!("Sharded load failed: {e}"))))?;
        // Sharded weights need the hyperparameters from a sibling config.json.
        let config = SafetensorsConfig::load_from_sibling(safetensors_path)
            .ok_or_else(|| ForwardError::Cli(CliError::ValidationFailed("config.json not found for sharded model".to_string())))?;
        SafetensorsToAprConverter::convert_sharded(&sharded, &config)
    } else {
        SafetensorsToAprConverter::convert(safetensors_path)
    };
    let model = match transformer {
        Ok(t) => t,
        Err(e) => {
            let msg = format!("{e}");
            // Heuristic: these substrings indicate an unsupported model
            // rather than an operational error — report a soft failure.
            if msg.contains("Tensor not found") || msg.contains("not supported") {
                return Err(ForwardError::ConversionFailed(safetensors_path.display().to_string()));
            }
            return Err(ForwardError::Cli(CliError::ValidationFailed(format!("SafeTensors convert failed: {e}"))));
        }
    };
    model.forward(prompt_tokens)
        .map_err(|e| ForwardError::Cli(CliError::ValidationFailed(format!("SafeTensors forward failed: {e}"))))
}
/// Probes the local Ollama HTTP API (port 11434) via `curl`, returning
/// `true` only when /api/tags answers with HTTP 200. Any failure (curl
/// missing, daemon down, non-200 status) yields `false`.
fn check_ollama_available() -> bool {
    let probe = std::process::Command::new("curl")
        .args([
            "-s",
            "-o",
            "/dev/null",
            "-w",
            "%{http_code}",
            "http://localhost:11434/api/tags",
        ])
        .output();
    match probe {
        Ok(output) => String::from_utf8_lossy(&output.stdout).trim() == "200",
        Err(_) => false,
    }
}
/// Maps a lowercase GGUF filename to a parameter-count label ("0.5b",
/// "1.5b", "3b", ...) by scanning for known size tokens.
///
/// A token only counts when it sits on BOTH boundaries: the characters
/// immediately before and after the match must not be ASCII alphanumerics.
/// The previous end-only check let "13b" match the "3b" pattern; the added
/// start-boundary check fixes that. All occurrences of a pattern are
/// examined, not just the first. Returns `None` when nothing matches.
fn detect_size_from_filename(filename_lower: &str) -> Option<&'static str> {
    const SIZE_PATTERNS: &[(&str, &str)] = &[
        ("0.5b", "0.5b"),
        ("0_5b", "0.5b"),
        ("1.5b", "1.5b"),
        ("1_5b", "1.5b"),
        ("3b", "3b"),
        ("7b", "7b"),
        ("14b", "14b"),
        ("32b", "32b"),
    ];
    let bytes = filename_lower.as_bytes();
    SIZE_PATTERNS.iter().find_map(|(pattern, label)| {
        filename_lower
            .match_indices(pattern)
            .any(|(pos, _)| {
                let end = pos + pattern.len();
                let start_ok = pos == 0 || !bytes[pos - 1].is_ascii_alphanumeric();
                let end_ok = end >= bytes.len() || !bytes[end].is_ascii_alphanumeric();
                start_ok && end_ok
            })
            .then_some(*label)
    })
}
/// Guesses a model's parameter-count label from its on-disk byte size.
/// Unreadable or missing paths count as 0 bytes and map to "0.5b".
fn estimate_size_from_file(path: &Path) -> &'static str {
    let bytes = std::fs::metadata(path).map_or(0, |m| m.len());
    if bytes <= 800_000_000 {
        "0.5b"
    } else if bytes <= 2_000_000_000 {
        "1.5b"
    } else if bytes <= 4_000_000_000 {
        "3b"
    } else {
        "7b"
    }
}
include!("ollama.rs");
include!("forward_error_tests.rs");