1mod arch;
30mod bert;
31mod nomic;
32mod pooling;
33mod registry;
34mod runtime;
35mod text;
36mod tokenizer;
37mod vision;
38
39pub use arch::{Arch, default_pooling, detect_arch};
40pub use bert::RlxBertModel;
41pub use nomic::RlxNomicModel;
42pub use pooling::{Pooling, l2_normalize_in_place, pool_embeddings};
43pub use registry::{
44 EmbeddingModel, ImageEmbeddingModel, ImageModelInfo, ModelArch, ModelInfo, models_map,
45};
46pub use runtime::{RlxEmbed, compile_model, compile_model_cpu};
47pub use text::embed_with_rlx;
48pub use tokenizer::{BertTokenizer, TokenizedBatch};
49pub use vision::{RlxVisionModel, assemble_vision_hidden};
50
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::config::BertConfig;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// Returns a deliberately small single-layer BERT configuration so the
    /// graph-construction test stays cheap.
    fn tiny_bert_cfg() -> BertConfig {
        BertConfig {
            vocab_size: 32,
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            intermediate_size: 32,
            max_position_embeddings: 32,
            type_vocab_size: 2,
            layer_norm_eps: 1e-12,
            hidden_act: "gelu".into(),
        }
    }

    /// Builds a [`WeightMap`] holding one all-zero tensor for every weight name
    /// listed below: the embedding tables, encoder layer 0 and the pooler.
    ///
    /// Each tensor is described only by its shape; the backing buffer is
    /// zero-filled and its length is the product of the shape's dimensions.
    fn tiny_bert_weights(cfg: &BertConfig) -> WeightMap {
        let hidden = cfg.hidden_size;
        let inter = cfg.intermediate_size;
        let layer = "encoder.layer.0";

        // Tensors addressed by their full name.
        let top_level: Vec<(&str, Vec<usize>)> = vec![
            (
                "embeddings.word_embeddings.weight",
                vec![cfg.vocab_size, hidden],
            ),
            (
                "embeddings.position_embeddings.weight",
                vec![cfg.max_position_embeddings, hidden],
            ),
            (
                "embeddings.token_type_embeddings.weight",
                vec![cfg.type_vocab_size, hidden],
            ),
            ("embeddings.LayerNorm.weight", vec![hidden]),
            ("embeddings.LayerNorm.bias", vec![hidden]),
            ("pooler.dense.weight", vec![hidden, hidden]),
            ("pooler.dense.bias", vec![hidden]),
        ];

        // Tensors of the single encoder layer; names are relative to `layer`.
        let per_layer: Vec<(&str, Vec<usize>)> = vec![
            ("attention.self.query.weight", vec![hidden, hidden]),
            ("attention.self.query.bias", vec![hidden]),
            ("attention.self.key.weight", vec![hidden, hidden]),
            ("attention.self.key.bias", vec![hidden]),
            ("attention.self.value.weight", vec![hidden, hidden]),
            ("attention.self.value.bias", vec![hidden]),
            ("attention.output.dense.weight", vec![hidden, hidden]),
            ("attention.output.dense.bias", vec![hidden]),
            ("attention.output.LayerNorm.weight", vec![hidden]),
            ("attention.output.LayerNorm.bias", vec![hidden]),
            ("intermediate.dense.weight", vec![inter, hidden]),
            ("intermediate.dense.bias", vec![inter]),
            ("output.dense.weight", vec![hidden, inter]),
            ("output.dense.bias", vec![hidden]),
            ("output.LayerNorm.weight", vec![hidden]),
            ("output.LayerNorm.bias", vec![hidden]),
        ];

        let named_shapes = top_level
            .into_iter()
            .map(|(name, shape)| (name.to_owned(), shape))
            .chain(
                per_layer
                    .into_iter()
                    .map(|(suffix, shape)| (format!("{layer}.{suffix}"), shape)),
            );

        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = named_shapes
            .map(|(name, shape)| {
                // Zero buffer sized to hold exactly one value per element.
                let data = vec![0.0f32; shape.iter().product()];
                (name, (data, shape))
            })
            .collect();

        WeightMap::from_tensors(tensors)
    }

    /// Building the BERT graph from the tiny fixture must succeed, expose
    /// exactly one graph output and yield a non-empty parameter collection.
    #[test]
    fn rlx_bert_graph_builds() {
        let config = tiny_bert_cfg();
        let mut weights = tiny_bert_weights(&config);
        // NOTE(review): the trailing `1, 4` are presumably batch size and
        // sequence length — confirm against `build_bert_graph_sized`.
        let built = rlx_bert::bert::build_bert_graph_sized(&config, &mut weights, 1, 4);
        let (graph, params) = built.unwrap();
        assert_eq!(graph.outputs.len(), 1);
        assert!(!params.is_empty());
    }

    /// The registry must list at least one supported model and must return
    /// info for `AllMiniLML6V2`.
    #[test]
    fn registry_lists_models() {
        let supported = EmbeddingModel::list_supported();
        assert!(!supported.is_empty());

        let info = EmbeddingModel::AllMiniLML6V2.get_info();
        assert!(info.is_some());
    }
}