Skip to main content

rlx_embed/
nomic.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX-compiled NomicBERT encoder for text embeddings.
17
18use std::collections::HashSet;
19use std::path::Path;
20
21use anyhow::Result;
22use rlx_runtime::{CompileCache, Device, PrecisionPolicy};
23
24use rlx_core::config::NomicBertConfig;
25use rlx_core::flow_util::graph_from_built;
26use rlx_core::weight_map::WeightMap;
27use rlx_nomic::flow::build_nomic_built;
28
/// RLX-compiled NomicBERT with shape-bucketed compile cache.
///
/// One graph is compiled per `(batch, seq)` shape and kept in `cache`;
/// `current_key` selects which of them the `forward*` methods run.
pub struct RlxNomicModel {
    /// Compiled graphs, looked up by the packed `(batch, seq)` key.
    cache: CompileCache,
    /// Keys whose compiled graph has already had its parameters set.
    params_loaded: HashSet<u64>,
    /// Model configuration read from the config file when the model is loaded.
    config: NomicBertConfig,
    /// Location of the weight file; it is reopened whenever a new shape is built.
    weights_path: String,
    /// Key of the shape the `forward*` methods will run; updated by `recompile`.
    current_key: u64,
    // Stored at construction but not read by any method visible here,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    device: Device,
    // Same as `device`: kept for reference only, the cache holds its own copy.
    #[allow(dead_code)]
    policy: Option<PrecisionPolicy>,
}
41
42impl RlxNomicModel {
43    fn key(batch: usize, seq: usize) -> u64 {
44        ((batch as u64) << 32) | (seq as u64)
45    }
46
47    pub fn load_sized_on(
48        config_path: &Path,
49        weights_path: &str,
50        batch: usize,
51        seq: usize,
52        device: Device,
53    ) -> Result<Self> {
54        Self::load_sized_with_policy(config_path, weights_path, batch, seq, device, None)
55    }
56
57    pub fn load_sized_with_policy(
58        config_path: &Path,
59        weights_path: &str,
60        batch: usize,
61        seq: usize,
62        device: Device,
63        policy: Option<PrecisionPolicy>,
64    ) -> Result<Self> {
65        let config = NomicBertConfig::from_file(config_path)?;
66        let mut model = Self {
67            cache: CompileCache::with_policy(device, 16, policy.clone()),
68            params_loaded: HashSet::new(),
69            config,
70            weights_path: weights_path.to_string(),
71            current_key: Self::key(batch, seq),
72            device,
73            policy,
74        };
75        model.recompile(batch, seq)?;
76        Ok(model)
77    }
78
79    pub fn load_sized(
80        config_path: &Path,
81        weights_path: &str,
82        batch: usize,
83        seq: usize,
84    ) -> Result<Self> {
85        Self::load_sized_on(config_path, weights_path, batch, seq, Device::Cpu)
86    }
87
88    pub fn load(config_path: &Path, weights_path: &str) -> Result<Self> {
89        Self::load_sized(config_path, weights_path, 1, 1)
90    }
91
92    pub fn recompile(&mut self, batch: usize, seq: usize) -> Result<()> {
93        let key = Self::key(batch, seq);
94        self.current_key = key;
95        if self.cache.contains(key) && self.params_loaded.contains(&key) {
96            return Ok(());
97        }
98        let mut wm = WeightMap::from_file(&self.weights_path)?;
99        let (graph, params) =
100            graph_from_built(build_nomic_built(&self.config, &mut wm, batch, seq)?)?;
101        let compiled = self.cache.get_or_compile(key, || graph);
102        for (name, data) in &params {
103            compiled.set_param(name, data);
104        }
105        self.params_loaded.insert(key);
106        Ok(())
107    }
108
109    pub fn forward(
110        &mut self,
111        input_ids: &[f32],
112        attention_mask: &[f32],
113        token_type_ids: &[f32],
114    ) -> Vec<f32> {
115        let key = self.current_key;
116        let compiled = self.cache.get_or_compile(key, || {
117            unreachable!("forward called without prior recompile/load_sized")
118        });
119        let outputs = compiled.run(&[
120            ("input_ids", input_ids),
121            ("attention_mask", attention_mask),
122            ("token_type_ids", token_type_ids),
123        ]);
124        outputs.into_iter().next().unwrap_or_default()
125    }
126
127    pub fn forward_slots(
128        &mut self,
129        input_ids: &[f32],
130        attention_mask: &[f32],
131        token_type_ids: &[f32],
132    ) -> (*const f32, usize) {
133        let key = self.current_key;
134        let compiled = self.cache.get_or_compile(key, || unreachable!());
135        let slots = compiled.run_slots(&[input_ids, attention_mask, token_type_ids]);
136        if slots.is_empty() {
137            return (std::ptr::null(), 0);
138        }
139        let (off, len) = slots[0];
140        unsafe {
141            let ptr = compiled.arena_ptr().add(off) as *const f32;
142            (ptr, len)
143        }
144    }
145
146    pub fn forward_pipelined(
147        &mut self,
148        input_sets: &[(Vec<f32>, Vec<f32>, Vec<f32>)],
149    ) -> Vec<Vec<Vec<f32>>> {
150        let key = self.current_key;
151        let compiled = self.cache.get_or_compile(key, || unreachable!());
152        let prepared: Vec<Vec<(&str, &[f32])>> = input_sets
153            .iter()
154            .map(|(ids, mask, tt)| {
155                vec![
156                    ("input_ids", ids.as_slice()),
157                    ("attention_mask", mask.as_slice()),
158                    ("token_type_ids", tt.as_slice()),
159                ]
160            })
161            .collect();
162        compiled.run_pipelined(&prepared)
163    }
164
165    pub fn hidden_size(&self) -> usize {
166        self.config.hidden_size
167    }
168}