1use std::collections::HashSet;
19use std::path::Path;
20
21use anyhow::Result;
22use rlx_runtime::{CompileCache, Device, PrecisionPolicy};
23
24use rlx_core::config::NomicBertConfig;
25use rlx_core::flow_util::graph_from_built;
26use rlx_core::weight_map::WeightMap;
27use rlx_nomic::flow::build_nomic_built;
28
29pub struct RlxNomicModel {
31 cache: CompileCache,
32 params_loaded: HashSet<u64>,
33 config: NomicBertConfig,
34 weights_path: String,
35 current_key: u64,
36 #[allow(dead_code)]
37 device: Device,
38 #[allow(dead_code)]
39 policy: Option<PrecisionPolicy>,
40}
41
42impl RlxNomicModel {
43 fn key(batch: usize, seq: usize) -> u64 {
44 ((batch as u64) << 32) | (seq as u64)
45 }
46
47 pub fn load_sized_on(
48 config_path: &Path,
49 weights_path: &str,
50 batch: usize,
51 seq: usize,
52 device: Device,
53 ) -> Result<Self> {
54 Self::load_sized_with_policy(config_path, weights_path, batch, seq, device, None)
55 }
56
57 pub fn load_sized_with_policy(
58 config_path: &Path,
59 weights_path: &str,
60 batch: usize,
61 seq: usize,
62 device: Device,
63 policy: Option<PrecisionPolicy>,
64 ) -> Result<Self> {
65 let config = NomicBertConfig::from_file(config_path)?;
66 let mut model = Self {
67 cache: CompileCache::with_policy(device, 16, policy.clone()),
68 params_loaded: HashSet::new(),
69 config,
70 weights_path: weights_path.to_string(),
71 current_key: Self::key(batch, seq),
72 device,
73 policy,
74 };
75 model.recompile(batch, seq)?;
76 Ok(model)
77 }
78
79 pub fn load_sized(
80 config_path: &Path,
81 weights_path: &str,
82 batch: usize,
83 seq: usize,
84 ) -> Result<Self> {
85 Self::load_sized_on(config_path, weights_path, batch, seq, Device::Cpu)
86 }
87
88 pub fn load(config_path: &Path, weights_path: &str) -> Result<Self> {
89 Self::load_sized(config_path, weights_path, 1, 1)
90 }
91
92 pub fn recompile(&mut self, batch: usize, seq: usize) -> Result<()> {
93 let key = Self::key(batch, seq);
94 self.current_key = key;
95 if self.cache.contains(key) && self.params_loaded.contains(&key) {
96 return Ok(());
97 }
98 let mut wm = WeightMap::from_file(&self.weights_path)?;
99 let (graph, params) =
100 graph_from_built(build_nomic_built(&self.config, &mut wm, batch, seq)?)?;
101 let compiled = self.cache.get_or_compile(key, || graph);
102 for (name, data) in ¶ms {
103 compiled.set_param(name, data);
104 }
105 self.params_loaded.insert(key);
106 Ok(())
107 }
108
109 pub fn forward(
110 &mut self,
111 input_ids: &[f32],
112 attention_mask: &[f32],
113 token_type_ids: &[f32],
114 ) -> Vec<f32> {
115 let key = self.current_key;
116 let compiled = self.cache.get_or_compile(key, || {
117 unreachable!("forward called without prior recompile/load_sized")
118 });
119 let outputs = compiled.run(&[
120 ("input_ids", input_ids),
121 ("attention_mask", attention_mask),
122 ("token_type_ids", token_type_ids),
123 ]);
124 outputs.into_iter().next().unwrap_or_default()
125 }
126
127 pub fn forward_slots(
128 &mut self,
129 input_ids: &[f32],
130 attention_mask: &[f32],
131 token_type_ids: &[f32],
132 ) -> (*const f32, usize) {
133 let key = self.current_key;
134 let compiled = self.cache.get_or_compile(key, || unreachable!());
135 let slots = compiled.run_slots(&[input_ids, attention_mask, token_type_ids]);
136 if slots.is_empty() {
137 return (std::ptr::null(), 0);
138 }
139 let (off, len) = slots[0];
140 unsafe {
141 let ptr = compiled.arena_ptr().add(off) as *const f32;
142 (ptr, len)
143 }
144 }
145
146 pub fn forward_pipelined(
147 &mut self,
148 input_sets: &[(Vec<f32>, Vec<f32>, Vec<f32>)],
149 ) -> Vec<Vec<Vec<f32>>> {
150 let key = self.current_key;
151 let compiled = self.cache.get_or_compile(key, || unreachable!());
152 let prepared: Vec<Vec<(&str, &[f32])>> = input_sets
153 .iter()
154 .map(|(ids, mask, tt)| {
155 vec![
156 ("input_ids", ids.as_slice()),
157 ("attention_mask", mask.as_slice()),
158 ("token_type_ids", tt.as_slice()),
159 ]
160 })
161 .collect();
162 compiled.run_pipelined(&prepared)
163 }
164
165 pub fn hidden_size(&self) -> usize {
166 self.config.hidden_size
167 }
168}