test_mixedbread/
test_mixedbread.rs1#[cfg(feature = "mixedbread")]
2use ck_embed::create_embedder;
3#[cfg(feature = "mixedbread")]
4use ck_embed::reranker::create_reranker;
5#[cfg(feature = "mixedbread")]
6use ck_models::{ModelRegistry, RerankModelRegistry};
7
8fn main() {
9 #[cfg(not(feature = "mixedbread"))]
10 {
11 println!("This example requires the 'mixedbread' feature to be enabled.");
12 println!("Run with: cargo run --example test_mixedbread --features mixedbread");
13 return;
14 }
15
16 #[cfg(feature = "mixedbread")]
17 run_example();
18}
19
/// Smoke-tests the Mixedbread embedding and reranking models end to end.
///
/// Runs six numbered steps, printing progress to stdout:
/// 1. resolve the `mxbai-xsmall` alias in [`ModelRegistry`];
/// 2. create the `mixedbread-ai/mxbai-embed-xsmall-v1` embedder;
/// 3. embed three sample texts and check count, width (384) and L2 normalization;
/// 4. resolve the `mxbai` alias in [`RerankModelRegistry`];
/// 5. create the `mixedbread-ai/mxbai-rerank-xsmall-v1` reranker;
/// 6. rerank four sample documents and check ordering and score range.
///
/// # Panics
/// Panics (via `assert!`) when an embedding or rerank result violates a checked property.
///
/// Any error returned by a registry or model is printed and the process exits with status 1,
/// so a failed run is visible to scripts/CI just like a failed assertion.
#[cfg(feature = "mixedbread")]
fn run_example() {
    println!("=== Testing Mixedbread Models ===\n");

    check_model_registry();
    check_embedder();
    check_rerank_registry();
    check_reranker();

    println!("\n=== All Tests Passed! ===");
}

/// Step 1: resolves the `mxbai-xsmall` alias via [`ModelRegistry`] and prints its config.
#[cfg(feature = "mixedbread")]
fn check_model_registry() {
    println!("1. Testing Model Registry Resolution");
    println!(" Checking if 'mxbai-xsmall' alias resolves...");

    let registry = ModelRegistry::default();
    let (alias, config) = registry.resolve(Some("mxbai-xsmall")).unwrap_or_else(|e| {
        println!(" ❌ Failed to resolve alias: {e}");
        std::process::exit(1)
    });
    println!(" ✅ Resolved alias: '{alias}'");
    println!(" Model name: {}", config.name);
    println!(" Provider: {}", config.provider);
    println!(" Dimensions: {}", config.dimensions);
    println!(" Max tokens: {}", config.max_tokens);
}

/// Steps 2–3: creates the xsmall embedder, embeds sample texts and validates the vectors.
#[cfg(feature = "mixedbread")]
fn check_embedder() {
    println!("\n2. Testing Mixedbread Embedder Creation");
    println!(" Attempting to create Mixedbread embedder...");

    let mut embedder =
        create_embedder(Some("mixedbread-ai/mxbai-embed-xsmall-v1")).unwrap_or_else(|e| {
            println!(" ❌ Failed to create Mixedbread embedder: {e}");
            println!(" Error details: {e:?}");
            std::process::exit(1)
        });
    println!(" ✅ Successfully created embedder: {}", embedder.id());
    println!(" Model name: {}", embedder.model_name());
    println!(" Dimensions: {}", embedder.dim());

    println!("\n3. Testing Embedding Generation");
    let test_texts = vec![
        "Hello world".to_string(),
        "Rust programming language".to_string(),
        "Machine learning and artificial intelligence".to_string(),
    ];
    println!(" Generating embeddings for {} texts...", test_texts.len());

    let embeddings = embedder.embed(&test_texts).unwrap_or_else(|e| {
        println!(" ❌ Failed to generate embeddings: {e}");
        std::process::exit(1)
    });
    println!(" ✅ Successfully generated embeddings");
    println!(
        " Shape: {} embeddings of {} dimensions",
        embeddings.len(),
        // `first()` instead of `[0]`: an empty result must reach the count assertion
        // below rather than panic with an index-out-of-bounds error here.
        embeddings.first().map_or(0, |first| first.len())
    );

    assert_eq!(
        embeddings.len(),
        test_texts.len(),
        "Should have one embedding per text"
    );

    for (i, emb) in embeddings.iter().enumerate() {
        // Every vector (not only the first) must have the model's width.
        assert_eq!(
            emb.len(),
            384,
            "Mixedbread xsmall should produce 384-dim vectors"
        );
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        println!(" Embedding {i} L2 norm: {norm:.6} (should be ~1.0)");
        assert!(
            (norm - 1.0).abs() < 0.01,
            "Embeddings should be L2-normalized"
        );
    }
}

/// Step 4: resolves the `mxbai` alias via [`RerankModelRegistry`] and prints its config.
#[cfg(feature = "mixedbread")]
fn check_rerank_registry() {
    println!("\n4. Testing Reranker Registry Resolution");
    println!(" Checking if 'mxbai' reranker alias resolves...");

    let rerank_registry = RerankModelRegistry::default();
    let (alias, config) = rerank_registry.resolve(Some("mxbai")).unwrap_or_else(|e| {
        println!(" ❌ Failed to resolve reranker alias: {e}");
        std::process::exit(1)
    });
    println!(" ✅ Resolved reranker alias: '{alias}'");
    println!(" Model name: {}", config.name);
    println!(" Provider: {}", config.provider);
}

/// Steps 5–6: creates the xsmall reranker, reranks sample documents and validates scores.
#[cfg(feature = "mixedbread")]
fn check_reranker() {
    println!("\n5. Testing Mixedbread Reranker Creation");
    println!(" Attempting to create Mixedbread reranker...");

    let mut reranker =
        create_reranker(Some("mixedbread-ai/mxbai-rerank-xsmall-v1")).unwrap_or_else(|e| {
            println!(" ❌ Failed to create Mixedbread reranker: {e}");
            println!(" Error details: {e:?}");
            std::process::exit(1)
        });
    println!(" ✅ Successfully created reranker: {}", reranker.id());

    println!("\n6. Testing Reranking");
    let query = "error handling in Rust";
    let documents = vec![
        "Rust error handling with Result and Option types".to_string(),
        "Python web development frameworks".to_string(),
        "Rust provides excellent error handling mechanisms".to_string(),
        "JavaScript async programming patterns".to_string(),
    ];
    println!(" Query: '{query}'");
    println!(" Reranking {} documents...", documents.len());

    let results = reranker.rerank(query, &documents).unwrap_or_else(|e| {
        println!(" ❌ Failed to rerank: {e}");
        std::process::exit(1)
    });
    println!(" ✅ Successfully reranked documents");
    println!(" Results (sorted by score):");
    for (i, result) in results.iter().enumerate() {
        println!(
            " {}. Score: {:.4} | Doc: {}",
            i + 1,
            result.score,
            truncate_chars(&result.document, 60)
        );
    }

    // Compare neighbours directly rather than sorting a copy with `partial_cmp().unwrap()`,
    // which panics on a NaN score; a NaN makes `>=` false and fails the assertion instead.
    let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
    assert!(
        scores.windows(2).all(|pair| pair[0] >= pair[1]),
        "Results should be sorted by score descending: {scores:?}"
    );

    for result in &results {
        assert!(
            (0.0..=1.0).contains(&result.score),
            "Rerank scores should be in [0, 1] range"
        );
    }
}

/// Returns the prefix of `text` holding at most `max_chars` Unicode scalar values.
///
/// Cuts on a `char` boundary, so unlike byte slicing (`&text[..n]`) it never panics on
/// multi-byte UTF-8 input. Text with `max_chars` or fewer characters is returned whole.
// Only called from feature-gated code; silence dead-code warnings in non-feature builds.
#[cfg_attr(not(feature = "mixedbread"), allow(dead_code))]
fn truncate_chars(text: &str, max_chars: usize) -> &str {
    match text.char_indices().nth(max_chars) {
        Some((byte_idx, _)) => &text[..byte_idx],
        None => text,
    }
}