use anyhow::Result;
use rlx_core::gguf_config::{EmbedGgufKind, embed_gguf_kind};
use rlx_gguf::GgufFile;
use std::path::Path;
/// Embedding model architecture families distinguished by this module.
///
/// Values are produced by [`detect_arch_from_gguf`] (from a GGUF weights file)
/// and by [`detect_arch`] (from a JSON config file).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Arch {
    /// Returned for `EmbedGgufKind::Bert`, and the fallback of [`detect_arch`]
    /// when none of the Nomic-specific config keys are present.
    Bert,
    /// Returned for `EmbedGgufKind::NomicBert`, or when the JSON config has a
    /// `rotary_emb_base` or `rotary_emb_fraction` key.
    NomicBert,
    /// Returned by [`detect_arch`] when the JSON config has both `img_size`
    /// and `patch_size` keys. [`detect_arch_from_gguf`] never yields it.
    NomicVision,
}
/// Determines the embedding architecture from a GGUF weights file.
///
/// Loads `weights_path` through `GgufFile::from_path`, classifies the result
/// with `embed_gguf_kind`, and translates the reported `EmbedGgufKind` into
/// the matching [`Arch`] variant.
///
/// # Errors
///
/// Propagates any error returned while loading the GGUF file or while
/// classifying it.
pub fn detect_arch_from_gguf(weights_path: &Path) -> Result<Arch> {
    let gguf = GgufFile::from_path(weights_path)?;
    let kind = embed_gguf_kind(&gguf)?;

    // Deliberately no catch-all arm: a new `EmbedGgufKind` variant should
    // surface as a compile error here rather than be silently mis-mapped.
    let arch = match kind {
        EmbedGgufKind::Bert => Arch::Bert,
        EmbedGgufKind::NomicBert => Arch::NomicBert,
    };
    Ok(arch)
}
/// Determines the embedding architecture from a JSON config file.
///
/// The file at `config_path` is read as text and parsed as JSON. Only the
/// *presence* of certain top-level keys is tested (`get(..).is_some()`);
/// their values are not inspected. Checks run in this order:
///
/// 1. both `img_size` and `patch_size` present -> [`Arch::NomicVision`]
/// 2. `rotary_emb_base` or `rotary_emb_fraction` present -> [`Arch::NomicBert`]
/// 3. otherwise -> [`Arch::Bert`]
///
/// # Errors
///
/// Returns an error if the file cannot be read or is not valid JSON.
pub fn detect_arch(config_path: &Path) -> Result<Arch> {
    let text = std::fs::read_to_string(config_path)?;
    let config: serde_json::Value = serde_json::from_str(&text)?;
    let has_key = |key: &str| config.get(key).is_some();

    // The vision check comes first, so it wins even if rotary keys are also present.
    let arch = if has_key("img_size") && has_key("patch_size") {
        Arch::NomicVision
    } else if has_key("rotary_emb_base") || has_key("rotary_emb_fraction") {
        Arch::NomicBert
    } else {
        Arch::Bert
    };
    Ok(arch)
}
/// Chooses the default pooling strategy for a model from its repository id.
///
/// The id is lowercased and searched for the substrings `"bge"` and
/// `"nomic"`. If either occurs anywhere in the id the result is
/// `Pooling::Cls`; otherwise it is `Pooling::Mean`.
///
/// NOTE(review): every id containing "nomic" is given CLS pooling here,
/// text models included — confirm this matches what those checkpoints expect.
pub fn default_pooling(repo_id: &str) -> super::Pooling {
    // Substrings (matched case-insensitively) that select CLS pooling.
    const CLS_MARKERS: [&str; 2] = ["bge", "nomic"];

    let normalized = repo_id.to_lowercase();
    let wants_cls = CLS_MARKERS
        .iter()
        .any(|&marker| normalized.contains(marker));

    if wants_cls {
        super::Pooling::Cls
    } else {
        super::Pooling::Mean
    }
}