Skip to main content

rlx_embed/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX-backed text and image embedding models.
17//!
18//! Migrated from `burnembed` — compiles BERT / NomicBERT / NomicVision graphs
19//! via `rlx-runtime` and exposes tier-0 inference helpers.
20//!
21//! ```rust,ignore
22//! use rlx_models::embed::{Pooling, RlxBertModel, BertTokenizer, embed_with_rlx};
23//!
24//! let tok = BertTokenizer::from_dir(model_dir, 512)?;
25//! let mut model = RlxBertModel::load(&config, &weights)?;
26//! let vecs = embed_with_rlx(&mut model, &tok, &["hello", "world"], Pooling::Mean)?;
27//! ```
28
29mod arch;
30mod bert;
31mod nomic;
32mod pooling;
33mod registry;
34mod runtime;
35mod text;
36mod tokenizer;
37mod vision;
38
39pub use arch::{Arch, default_pooling, detect_arch};
40pub use bert::RlxBertModel;
41pub use nomic::RlxNomicModel;
42pub use pooling::{Pooling, l2_normalize_in_place, pool_embeddings};
43pub use registry::{
44    EmbeddingModel, ImageEmbeddingModel, ImageModelInfo, ModelArch, ModelInfo, models_map,
45};
46pub use runtime::{RlxEmbed, compile_model, compile_model_cpu};
47pub use text::embed_with_rlx;
48pub use tokenizer::{BertTokenizer, TokenizedBatch};
49pub use vision::{RlxVisionModel, assemble_vision_hidden};
50
#[cfg(test)]
mod tests {
    //! Smoke tests for the crate root: BERT graph construction from
    //! in-memory zero weights, and basic checks on the model registry.

    use super::*;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// Returns the small single-layer BERT configuration shared by the tests
    /// in this module.
    fn tiny_bert_cfg() -> rlx_core::config::BertConfig {
        rlx_core::config::BertConfig {
            // Embedding tables.
            vocab_size: 32,
            type_vocab_size: 2,
            max_position_embeddings: 32,
            // Encoder stack.
            num_hidden_layers: 1,
            hidden_size: 16,
            num_attention_heads: 4,
            intermediate_size: 32,
            hidden_act: "gelu".into(),
            layer_norm_eps: 1e-12,
        }
    }

    /// Builds a [`WeightMap`] holding an all-zero `f32` tensor for each
    /// parameter name listed below, with shapes derived from `cfg`.
    ///
    /// Each tensor's buffer length is the product of its shape. The values
    /// are all zero; the graph-build test only inspects the output count and
    /// whether any parameters were returned.
    fn tiny_bert_weights(cfg: &rlx_core::config::BertConfig) -> WeightMap {
        let hidden = cfg.hidden_size;
        let ffn = cfg.intermediate_size;
        let layer = "encoder.layer.0";

        // Tensors addressed by their full name.
        let top_level: Vec<(&str, Vec<usize>)> = vec![
            (
                "embeddings.word_embeddings.weight",
                vec![cfg.vocab_size, hidden],
            ),
            (
                "embeddings.position_embeddings.weight",
                vec![cfg.max_position_embeddings, hidden],
            ),
            (
                "embeddings.token_type_embeddings.weight",
                vec![cfg.type_vocab_size, hidden],
            ),
            ("embeddings.LayerNorm.weight", vec![hidden]),
            ("embeddings.LayerNorm.bias", vec![hidden]),
            ("pooler.dense.weight", vec![hidden, hidden]),
            ("pooler.dense.bias", vec![hidden]),
        ];

        // Tensors of the single encoder layer, named relative to `layer`.
        let per_layer: Vec<(&str, Vec<usize>)> = vec![
            ("attention.self.query.weight", vec![hidden, hidden]),
            ("attention.self.query.bias", vec![hidden]),
            ("attention.self.key.weight", vec![hidden, hidden]),
            ("attention.self.key.bias", vec![hidden]),
            ("attention.self.value.weight", vec![hidden, hidden]),
            ("attention.self.value.bias", vec![hidden]),
            ("attention.output.dense.weight", vec![hidden, hidden]),
            ("attention.output.dense.bias", vec![hidden]),
            ("attention.output.LayerNorm.weight", vec![hidden]),
            ("attention.output.LayerNorm.bias", vec![hidden]),
            ("intermediate.dense.weight", vec![ffn, hidden]),
            ("intermediate.dense.bias", vec![ffn]),
            ("output.dense.weight", vec![hidden, ffn]),
            ("output.dense.bias", vec![hidden]),
            ("output.LayerNorm.weight", vec![hidden]),
            ("output.LayerNorm.bias", vec![hidden]),
        ];

        let named_shapes = top_level
            .into_iter()
            .map(|(name, shape)| (name.to_string(), shape))
            .chain(
                per_layer
                    .into_iter()
                    .map(|(suffix, shape)| (format!("{layer}.{suffix}"), shape)),
            );

        // Pair every shape with a zero-filled buffer of matching element count.
        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = named_shapes
            .map(|(name, shape)| {
                let zeros = vec![0.0f32; shape.iter().product::<usize>()];
                (name, (zeros, shape))
            })
            .collect();

        WeightMap::from_tensors(tensors)
    }

    /// Building the BERT graph from the tiny config and zero weights succeeds,
    /// yields exactly one graph output, and returns a non-empty parameter set.
    #[test]
    fn rlx_bert_graph_builds() {
        let config = tiny_bert_cfg();
        let mut weights = tiny_bert_weights(&config);

        // NOTE(review): the trailing `1, 4` are presumably batch size and
        // sequence length — confirm against `build_bert_graph_sized`.
        let built = rlx_bert::bert::build_bert_graph_sized(&config, &mut weights, 1, 4);
        let (graph, params) = built.unwrap();

        assert_eq!(graph.outputs.len(), 1);
        assert!(!params.is_empty());
    }

    /// The registry reports at least one supported model and has an info
    /// entry for `AllMiniLML6V2`.
    #[test]
    fn registry_lists_models() {
        let any_supported = !EmbeddingModel::list_supported().is_empty();
        assert!(any_supported);

        let has_info = EmbeddingModel::AllMiniLML6V2.get_info().is_some();
        assert!(has_info);
    }
}