1use std::collections::HashMap;
19use std::path::{Path, PathBuf};
20
21use anyhow::{Context, Result, bail};
22use rlx_core::gguf_config::{EmbedGgufKind, embed_gguf_kind};
23use rlx_core::validate_standard_device;
24use rlx_core::weights::pick_default;
25use rlx_gguf::GgufFile;
26use rlx_runtime::{CompiledGraph, Device};
27
28use rlx_core::weight_map::WeightMap;
29
30#[cfg(feature = "hf-download")]
31use super::arch::default_pooling;
32use super::arch::{Arch, detect_arch, detect_arch_from_gguf};
33use super::pooling::Pooling;
34
35pub struct RlxEmbed {
37 compiled: CompiledGraph,
38 arch: Arch,
39 hidden_size: usize,
40 device: Device,
41 #[allow(dead_code)]
42 pooling: Pooling,
43 compiled_bs: (usize, usize),
44 config_path: Option<PathBuf>,
45 weights_path: PathBuf,
46}
47
48impl RlxEmbed {
49 pub fn from_dir(dir: &Path, pooling: Pooling) -> Result<Self> {
51 Self::from_dir_on(dir, pooling, Device::Cpu)
52 }
53
54 pub fn from_dir_on(dir: &Path, pooling: Pooling, device: Device) -> Result<Self> {
56 validate_standard_device("embed", device)?;
57 let weights_path = pick_default(dir)?;
58 let config_path = resolve_embed_config_path(dir, &weights_path)?;
59 let arch = resolve_embed_arch(config_path.as_deref(), &weights_path)?;
60 let (hidden_size, compiled, _) =
61 compile_model(arch, config_path.as_deref(), &weights_path, 1, 1, device)?;
62
63 Ok(Self {
64 compiled,
65 arch,
66 hidden_size,
67 device,
68 pooling,
69 compiled_bs: (1, 1),
70 config_path,
71 weights_path,
72 })
73 }
74
75 pub fn from_weights(path: &Path, pooling: Pooling) -> Result<Self> {
77 Self::from_weights_on(path, pooling, Device::Cpu)
78 }
79
80 pub fn from_weights_on(path: &Path, pooling: Pooling, device: Device) -> Result<Self> {
82 validate_standard_device("embed", device)?;
83 let weights_path = pick_default(path)?;
84 let config_path = path
85 .parent()
86 .map(|p| p.join("config.json"))
87 .filter(|p| p.is_file());
88 let arch = resolve_embed_arch(config_path.as_deref(), &weights_path)?;
89 let (hidden_size, compiled, _) =
90 compile_model(arch, config_path.as_deref(), &weights_path, 1, 1, device)?;
91
92 Ok(Self {
93 compiled,
94 arch,
95 hidden_size,
96 device,
97 pooling,
98 compiled_bs: (1, 1),
99 config_path,
100 weights_path,
101 })
102 }
103
104 pub fn device(&self) -> Device {
106 self.device
107 }
108
109 #[cfg(feature = "hf-download")]
111 pub fn from_pretrained(repo_id: &str) -> Result<Self> {
112 Self::from_pretrained_on(repo_id, Device::Cpu)
113 }
114
115 #[cfg(feature = "hf-download")]
117 pub fn from_pretrained_on(repo_id: &str, device: Device) -> Result<Self> {
118 validate_standard_device("embed", device)?;
119 let repo = hf_hub::api::sync::ApiBuilder::new()
120 .with_progress(true)
121 .build()?
122 .model(repo_id.to_string());
123 let config_file = repo.get("config.json")?;
124 let weights_file = repo.get("model.safetensors")?;
125
126 let arch = detect_arch(&config_file)?;
127 let pooling = default_pooling(repo_id);
128 let (hidden_size, compiled, _) =
129 compile_model(arch, Some(&config_file), &weights_file, 1, 1, device)?;
130
131 Ok(Self {
132 compiled,
133 arch,
134 hidden_size,
135 device,
136 pooling,
137 compiled_bs: (1, 1),
138 config_path: Some(config_file),
139 weights_path: weights_file,
140 })
141 }
142
143 pub fn dim(&self) -> usize {
144 self.hidden_size
145 }
146
147 pub fn arch(&self) -> Arch {
148 self.arch
149 }
150
151 pub fn forward(
153 &mut self,
154 inputs: &[(&str, &[f32])],
155 batch: usize,
156 seq: usize,
157 ) -> Result<Vec<f32>> {
158 self.ensure_compiled(batch, seq)?;
159 let outputs = self.compiled.run(inputs);
160 Ok(outputs.into_iter().next().unwrap_or_default())
161 }
162
163 fn ensure_compiled(&mut self, batch: usize, seq: usize) -> Result<()> {
164 if self.compiled_bs == (batch, seq) {
165 return Ok(());
166 }
167 let (_, compiled, _) = compile_model(
168 self.arch,
169 self.config_path.as_deref(),
170 &self.weights_path,
171 batch,
172 seq,
173 self.device,
174 )?;
175 self.compiled = compiled;
176 self.compiled_bs = (batch, seq);
177 Ok(())
178 }
179}
180
181fn resolve_embed_config_path(dir: &Path, weights: &Path) -> Result<Option<PathBuf>> {
182 let sidecar = dir.join("config.json");
183 if sidecar.is_file() {
184 return Ok(Some(sidecar));
185 }
186 if weights.extension().and_then(|s| s.to_str()) == Some("gguf") {
187 return Ok(None);
188 }
189 bail!("{dir:?}: missing config.json (required for safetensors checkpoints)");
190}
191
192fn resolve_embed_arch(config_path: Option<&Path>, weights_path: &Path) -> Result<Arch> {
193 if let Some(cfg) = config_path {
194 return detect_arch(cfg);
195 }
196 let file = pick_default(weights_path)?;
197 if file.extension().and_then(|s| s.to_str()) == Some("gguf") {
198 return detect_arch_from_gguf(&file);
199 }
200 bail!("cannot detect embedding arch without config.json or a .gguf file");
201}
202
203pub fn compile_model(
205 arch: Arch,
206 config_path: Option<&Path>,
207 weights_path: &Path,
208 batch: usize,
209 seq: usize,
210 device: Device,
211) -> Result<(usize, CompiledGraph, HashMap<String, Vec<f32>>)> {
212 validate_standard_device("embed", device)?;
213 let file = pick_default(weights_path)?;
214 if file.extension().and_then(|s| s.to_str()) == Some("gguf") {
215 rlx_core::gguf_validate_arch(&file, rlx_core::EMBED_GGUF_ARCHES)?;
216 }
217 let mut wm = WeightMap::from_resolved_path(weights_path)?;
218
219 let (built, hidden_size) = match arch {
220 Arch::Bert => {
221 let cfg = load_bert_config(config_path, weights_path)?;
222 let hs = cfg.hidden_size;
223 let built = rlx_bert::flow::build_bert_built(&cfg, &mut wm, batch, seq)?;
224 (built, hs)
225 }
226 Arch::NomicBert => {
227 let cfg = load_nomic_config(config_path, weights_path)?;
228 let hs = cfg.hidden_size;
229 let built = rlx_nomic::flow::build_nomic_built(&cfg, &mut wm, batch, seq)?;
230 (built, hs)
231 }
232 Arch::NomicVision => {
233 let cfg_path = config_path.context("NomicVision requires config.json")?;
234 let cfg = rlx_core::config::NomicVisionConfig::from_file(cfg_path)?;
235 let hs = cfg.hidden_size;
236 let built = rlx_vision::flow::build_nomic_vision_built(&cfg, &mut wm, batch)?;
237 (built.model, hs)
238 }
239 };
240
241 let params = built.params().clone();
242 let compiled = rlx_core::flow_util::compile_built(built, device)?;
243 Ok((hidden_size, compiled, params))
244}
245
246fn load_bert_config(
247 config_path: Option<&Path>,
248 weights_path: &Path,
249) -> Result<rlx_core::config::BertConfig> {
250 if let Some(p) = config_path {
251 return rlx_core::config::BertConfig::from_file(p);
252 }
253 let raw = GgufFile::from_path(weights_path)?;
254 if !matches!(embed_gguf_kind(&raw)?, EmbedGgufKind::Bert) {
255 bail!("weights are not a BERT-family GGUF; use NomicBERT config or checkpoint");
256 }
257 rlx_core::config::BertConfig::from_gguf(&raw)
258}
259
260fn load_nomic_config(
261 config_path: Option<&Path>,
262 weights_path: &Path,
263) -> Result<rlx_core::config::NomicBertConfig> {
264 if let Some(p) = config_path {
265 return rlx_core::config::NomicBertConfig::from_file(p);
266 }
267 let raw = GgufFile::from_path(weights_path)?;
268 if !matches!(embed_gguf_kind(&raw)?, EmbedGgufKind::NomicBert) {
269 bail!("weights are not a nomic-bert GGUF; use BERT config or checkpoint");
270 }
271 rlx_core::config::NomicBertConfig::from_gguf(&raw)
272}
273
274pub fn compile_model_cpu(
276 arch: Arch,
277 config_path: Option<&Path>,
278 weights_path: &Path,
279 batch: usize,
280 seq: usize,
281) -> Result<(usize, CompiledGraph, HashMap<String, Vec<f32>>)> {
282 compile_model(arch, config_path, weights_path, batch, seq, Device::Cpu)
283}