Skip to main content

rlx_embed/
arch.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Architecture detection for embedding checkpoints.
17
use anyhow::{Context, Result};
use rlx_core::gguf_config::{EmbedGgufKind, embed_gguf_kind};
use rlx_gguf::GgufFile;
use std::path::Path;
22
/// Embedding architecture detected for a checkpoint.
///
/// Produced either from a HuggingFace-style `config.json` (see `detect_arch`)
/// or from GGUF metadata (see `detect_arch_from_gguf`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Arch {
    /// Plain BERT-style text encoder. This is the fallback when `config.json`
    /// carries none of the Nomic-specific keys.
    Bert,
    /// Nomic BERT text encoder. In `config.json` it is identified by a
    /// `rotary_emb_base` or `rotary_emb_fraction` key.
    NomicBert,
    /// Nomic vision encoder. In `config.json` it is identified by both
    /// `img_size` and `patch_size` keys. GGUF detection never yields it.
    NomicVision,
}
30
31/// Detect architecture from GGUF `general.architecture`.
32pub fn detect_arch_from_gguf(weights_path: &Path) -> Result<Arch> {
33    let raw = GgufFile::from_path(weights_path)?;
34    Ok(match embed_gguf_kind(&raw)? {
35        EmbedGgufKind::Bert => Arch::Bert,
36        EmbedGgufKind::NomicBert => Arch::NomicBert,
37    })
38}
39
40/// Detect architecture from config.json fields.
41pub fn detect_arch(config_path: &Path) -> Result<Arch> {
42    let data = std::fs::read_to_string(config_path)?;
43    let json: serde_json::Value = serde_json::from_str(&data)?;
44
45    if json.get("img_size").is_some() && json.get("patch_size").is_some() {
46        return Ok(Arch::NomicVision);
47    }
48    if json.get("rotary_emb_base").is_some() || json.get("rotary_emb_fraction").is_some() {
49        return Ok(Arch::NomicBert);
50    }
51    Ok(Arch::Bert)
52}
53
54/// Default pooling heuristic from HuggingFace repo id.
55pub fn default_pooling(repo_id: &str) -> super::Pooling {
56    let lower = repo_id.to_lowercase();
57    if lower.contains("bge") || lower.contains("nomic") {
58        super::Pooling::Cls
59    } else {
60        super::Pooling::Mean
61    }
62}