Skip to main content

rlx_embed/
runtime.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Unified high-level embedding loader (auto-detect arch, lazy recompile).
17
18use std::collections::HashMap;
19use std::path::{Path, PathBuf};
20
21use anyhow::{Context, Result, bail};
22use rlx_core::gguf_config::{EmbedGgufKind, embed_gguf_kind};
23use rlx_core::validate_standard_device;
24use rlx_core::weights::pick_default;
25use rlx_gguf::GgufFile;
26use rlx_runtime::{CompiledGraph, Device};
27
28use rlx_core::weight_map::WeightMap;
29
30#[cfg(feature = "hf-download")]
31use super::arch::default_pooling;
32use super::arch::{Arch, detect_arch, detect_arch_from_gguf};
33use super::pooling::Pooling;
34
35/// High-level embedding model — auto-detects BERT / NomicBERT / NomicVision.
36pub struct RlxEmbed {
37    compiled: CompiledGraph,
38    arch: Arch,
39    hidden_size: usize,
40    device: Device,
41    #[allow(dead_code)]
42    pooling: Pooling,
43    compiled_bs: (usize, usize),
44    config_path: Option<PathBuf>,
45    weights_path: PathBuf,
46}
47
48impl RlxEmbed {
49    /// Load from a local directory (`config.json` + `model.safetensors`) on CPU.
50    pub fn from_dir(dir: &Path, pooling: Pooling) -> Result<Self> {
51        Self::from_dir_on(dir, pooling, Device::Cpu)
52    }
53
54    /// Load from a local directory on the given device.
55    pub fn from_dir_on(dir: &Path, pooling: Pooling, device: Device) -> Result<Self> {
56        validate_standard_device("embed", device)?;
57        let weights_path = pick_default(dir)?;
58        let config_path = resolve_embed_config_path(dir, &weights_path)?;
59        let arch = resolve_embed_arch(config_path.as_deref(), &weights_path)?;
60        let (hidden_size, compiled, _) =
61            compile_model(arch, config_path.as_deref(), &weights_path, 1, 1, device)?;
62
63        Ok(Self {
64            compiled,
65            arch,
66            hidden_size,
67            device,
68            pooling,
69            compiled_bs: (1, 1),
70            config_path,
71            weights_path,
72        })
73    }
74
75    /// Load from a `.gguf` file or a directory containing one (optional sidecar `config.json`).
76    pub fn from_weights(path: &Path, pooling: Pooling) -> Result<Self> {
77        Self::from_weights_on(path, pooling, Device::Cpu)
78    }
79
80    /// Load weights path on the given device.
81    pub fn from_weights_on(path: &Path, pooling: Pooling, device: Device) -> Result<Self> {
82        validate_standard_device("embed", device)?;
83        let weights_path = pick_default(path)?;
84        let config_path = path
85            .parent()
86            .map(|p| p.join("config.json"))
87            .filter(|p| p.is_file());
88        let arch = resolve_embed_arch(config_path.as_deref(), &weights_path)?;
89        let (hidden_size, compiled, _) =
90            compile_model(arch, config_path.as_deref(), &weights_path, 1, 1, device)?;
91
92        Ok(Self {
93            compiled,
94            arch,
95            hidden_size,
96            device,
97            pooling,
98            compiled_bs: (1, 1),
99            config_path,
100            weights_path,
101        })
102    }
103
104    /// Execution device for this instance.
105    pub fn device(&self) -> Device {
106        self.device
107    }
108
109    /// Load by HuggingFace repo id (downloads when `hf-download` feature enabled).
110    #[cfg(feature = "hf-download")]
111    pub fn from_pretrained(repo_id: &str) -> Result<Self> {
112        Self::from_pretrained_on(repo_id, Device::Cpu)
113    }
114
115    /// Load by HuggingFace repo id on the given device.
116    #[cfg(feature = "hf-download")]
117    pub fn from_pretrained_on(repo_id: &str, device: Device) -> Result<Self> {
118        validate_standard_device("embed", device)?;
119        let repo = hf_hub::api::sync::ApiBuilder::new()
120            .with_progress(true)
121            .build()?
122            .model(repo_id.to_string());
123        let config_file = repo.get("config.json")?;
124        let weights_file = repo.get("model.safetensors")?;
125
126        let arch = detect_arch(&config_file)?;
127        let pooling = default_pooling(repo_id);
128        let (hidden_size, compiled, _) =
129            compile_model(arch, Some(&config_file), &weights_file, 1, 1, device)?;
130
131        Ok(Self {
132            compiled,
133            arch,
134            hidden_size,
135            device,
136            pooling,
137            compiled_bs: (1, 1),
138            config_path: Some(config_file),
139            weights_path: weights_file,
140        })
141    }
142
143    pub fn dim(&self) -> usize {
144        self.hidden_size
145    }
146
147    pub fn arch(&self) -> Arch {
148        self.arch
149    }
150
151    /// Forward on pre-tokenized inputs; returns flattened hidden states.
152    pub fn forward(
153        &mut self,
154        inputs: &[(&str, &[f32])],
155        batch: usize,
156        seq: usize,
157    ) -> Result<Vec<f32>> {
158        self.ensure_compiled(batch, seq)?;
159        let outputs = self.compiled.run(inputs);
160        Ok(outputs.into_iter().next().unwrap_or_default())
161    }
162
163    fn ensure_compiled(&mut self, batch: usize, seq: usize) -> Result<()> {
164        if self.compiled_bs == (batch, seq) {
165            return Ok(());
166        }
167        let (_, compiled, _) = compile_model(
168            self.arch,
169            self.config_path.as_deref(),
170            &self.weights_path,
171            batch,
172            seq,
173            self.device,
174        )?;
175        self.compiled = compiled;
176        self.compiled_bs = (batch, seq);
177        Ok(())
178    }
179}
180
181fn resolve_embed_config_path(dir: &Path, weights: &Path) -> Result<Option<PathBuf>> {
182    let sidecar = dir.join("config.json");
183    if sidecar.is_file() {
184        return Ok(Some(sidecar));
185    }
186    if weights.extension().and_then(|s| s.to_str()) == Some("gguf") {
187        return Ok(None);
188    }
189    bail!("{dir:?}: missing config.json (required for safetensors checkpoints)");
190}
191
192fn resolve_embed_arch(config_path: Option<&Path>, weights_path: &Path) -> Result<Arch> {
193    if let Some(cfg) = config_path {
194        return detect_arch(cfg);
195    }
196    let file = pick_default(weights_path)?;
197    if file.extension().and_then(|s| s.to_str()) == Some("gguf") {
198        return detect_arch_from_gguf(&file);
199    }
200    bail!("cannot detect embedding arch without config.json or a .gguf file");
201}
202
203/// Compile an embedding graph for the given batch/seq on `device`.
204pub fn compile_model(
205    arch: Arch,
206    config_path: Option<&Path>,
207    weights_path: &Path,
208    batch: usize,
209    seq: usize,
210    device: Device,
211) -> Result<(usize, CompiledGraph, HashMap<String, Vec<f32>>)> {
212    validate_standard_device("embed", device)?;
213    let file = pick_default(weights_path)?;
214    if file.extension().and_then(|s| s.to_str()) == Some("gguf") {
215        rlx_core::gguf_validate_arch(&file, rlx_core::EMBED_GGUF_ARCHES)?;
216    }
217    let mut wm = WeightMap::from_resolved_path(weights_path)?;
218
219    let (built, hidden_size) = match arch {
220        Arch::Bert => {
221            let cfg = load_bert_config(config_path, weights_path)?;
222            let hs = cfg.hidden_size;
223            let built = rlx_bert::flow::build_bert_built(&cfg, &mut wm, batch, seq)?;
224            (built, hs)
225        }
226        Arch::NomicBert => {
227            let cfg = load_nomic_config(config_path, weights_path)?;
228            let hs = cfg.hidden_size;
229            let built = rlx_nomic::flow::build_nomic_built(&cfg, &mut wm, batch, seq)?;
230            (built, hs)
231        }
232        Arch::NomicVision => {
233            let cfg_path = config_path.context("NomicVision requires config.json")?;
234            let cfg = rlx_core::config::NomicVisionConfig::from_file(cfg_path)?;
235            let hs = cfg.hidden_size;
236            let built = rlx_vision::flow::build_nomic_vision_built(&cfg, &mut wm, batch)?;
237            (built.model, hs)
238        }
239    };
240
241    let params = built.params().clone();
242    let compiled = rlx_core::flow_util::compile_built(built, device)?;
243    Ok((hidden_size, compiled, params))
244}
245
246fn load_bert_config(
247    config_path: Option<&Path>,
248    weights_path: &Path,
249) -> Result<rlx_core::config::BertConfig> {
250    if let Some(p) = config_path {
251        return rlx_core::config::BertConfig::from_file(p);
252    }
253    let raw = GgufFile::from_path(weights_path)?;
254    if !matches!(embed_gguf_kind(&raw)?, EmbedGgufKind::Bert) {
255        bail!("weights are not a BERT-family GGUF; use NomicBERT config or checkpoint");
256    }
257    rlx_core::config::BertConfig::from_gguf(&raw)
258}
259
260fn load_nomic_config(
261    config_path: Option<&Path>,
262    weights_path: &Path,
263) -> Result<rlx_core::config::NomicBertConfig> {
264    if let Some(p) = config_path {
265        return rlx_core::config::NomicBertConfig::from_file(p);
266    }
267    let raw = GgufFile::from_path(weights_path)?;
268    if !matches!(embed_gguf_kind(&raw)?, EmbedGgufKind::NomicBert) {
269        bail!("weights are not a nomic-bert GGUF; use BERT config or checkpoint");
270    }
271    rlx_core::config::NomicBertConfig::from_gguf(&raw)
272}
273
274/// Compile on CPU (convenience for tests and default [`RlxEmbed::from_dir`]).
275pub fn compile_model_cpu(
276    arch: Arch,
277    config_path: Option<&Path>,
278    weights_path: &Path,
279    batch: usize,
280    seq: usize,
281) -> Result<(usize, CompiledGraph, HashMap<String, Vec<f32>>)> {
282    compile_model(arch, config_path, weights_path, batch, seq, Device::Cpu)
283}