Skip to main content

test_mixedbread/
test_mixedbread.rs

1#[cfg(feature = "mixedbread")]
2use ck_embed::create_embedder;
3#[cfg(feature = "mixedbread")]
4use ck_embed::reranker::create_reranker;
5#[cfg(feature = "mixedbread")]
6use ck_models::{ModelRegistry, RerankModelRegistry};
7
/// Entry point: runs the Mixedbread smoke test when the `mixedbread` feature
/// is compiled in, otherwise prints instructions for enabling it.
fn main() {
    // With the feature on, hand off to the real example.
    #[cfg(feature = "mixedbread")]
    run_example();

    // With the feature off, `run_example` does not exist; just explain how to
    // build the example properly.
    #[cfg(not(feature = "mixedbread"))]
    {
        println!("This example requires the 'mixedbread' feature to be enabled.");
        println!("Run with: cargo run --example test_mixedbread --features mixedbread");
    }
}
19
/// Runs the end-to-end Mixedbread smoke test: embedding-registry lookup,
/// embedder creation and embedding generation, reranker-registry lookup, and
/// reranker creation and reranking.
///
/// Progress is printed to stdout. If any step returns an `Err`, a `❌` line is
/// printed and the process terminates with exit status 1, so a scripted run
/// (e.g. CI) observes the failure. Failed sanity checks panic via `assert!`.
#[cfg(feature = "mixedbread")]
fn run_example() {
    println!("=== Testing Mixedbread Models ===\n");

    check_embedding_registry();
    check_embedder();
    check_rerank_registry();
    check_reranker();

    println!("\n=== All Tests Passed! ===");
}

/// Test 1: checks that the `mxbai-xsmall` alias resolves in the embedding
/// model registry and prints the resolved configuration.
#[cfg(feature = "mixedbread")]
fn check_embedding_registry() {
    println!("1. Testing Model Registry Resolution");
    println!("   Checking if 'mxbai-xsmall' alias resolves...");

    let registry = ModelRegistry::default();
    match registry.resolve(Some("mxbai-xsmall")) {
        Ok((alias, config)) => {
            println!("   ✅ Resolved alias: '{alias}'");
            println!("      Model name: {}", config.name);
            println!("      Provider: {}", config.provider);
            println!("      Dimensions: {}", config.dimensions);
            println!("      Max tokens: {}", config.max_tokens);
        }
        Err(e) => {
            println!("   ❌ Failed to resolve alias: {e}");
            std::process::exit(1);
        }
    }
}

/// Tests 2-3: creates the xsmall embedder, embeds three sentences, and checks
/// the vector count, the 384-dimension width, and that each vector's L2 norm
/// is within 0.01 of 1.0.
#[cfg(feature = "mixedbread")]
fn check_embedder() {
    const EXPECTED_DIM: usize = 384;

    println!("\n2. Testing Mixedbread Embedder Creation");
    println!("   Attempting to create Mixedbread embedder...");

    let mut embedder = match create_embedder(Some("mixedbread-ai/mxbai-embed-xsmall-v1")) {
        Ok(embedder) => embedder,
        Err(e) => {
            println!("   ❌ Failed to create Mixedbread embedder: {e}");
            println!("      Error details: {e:?}");
            std::process::exit(1);
        }
    };
    println!("   ✅ Successfully created embedder: {}", embedder.id());
    println!("      Model name: {}", embedder.model_name());
    println!("      Dimensions: {}", embedder.dim());

    println!("\n3. Testing Embedding Generation");
    let test_texts = vec![
        "Hello world".to_string(),
        "Rust programming language".to_string(),
        "Machine learning and artificial intelligence".to_string(),
    ];
    println!("   Generating embeddings for {} texts...", test_texts.len());

    let embeddings = match embedder.embed(&test_texts) {
        Ok(embeddings) => embeddings,
        Err(e) => {
            println!("   ❌ Failed to generate embeddings: {e}");
            std::process::exit(1);
        }
    };

    // Read the width defensively: indexing `[0]` would panic with a bare
    // index-out-of-bounds (before the count assertion below) if no vectors
    // came back.
    let width = embeddings.first().map_or(0, |emb| emb.len());
    println!("   ✅ Successfully generated embeddings");
    println!(
        "      Shape: {} embeddings of {} dimensions",
        embeddings.len(),
        width
    );

    // Verify dimensions
    assert_eq!(
        embeddings.len(),
        test_texts.len(),
        "Should have one embedding per text"
    );
    assert_eq!(
        width, EXPECTED_DIM,
        "Mixedbread xsmall should produce 384-dim vectors"
    );

    // Check normalization (L2 norm should be ~1.0)
    for (i, emb) in embeddings.iter().enumerate() {
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        println!("      Embedding {i} L2 norm: {norm:.6} (should be ~1.0)");
        assert!(
            (norm - 1.0).abs() < 0.01,
            "Embeddings should be L2-normalized"
        );
    }
}

/// Test 4: checks that the `mxbai` alias resolves in the rerank model
/// registry and prints the resolved configuration.
#[cfg(feature = "mixedbread")]
fn check_rerank_registry() {
    println!("\n4. Testing Reranker Registry Resolution");
    println!("   Checking if 'mxbai' reranker alias resolves...");

    let rerank_registry = RerankModelRegistry::default();
    match rerank_registry.resolve(Some("mxbai")) {
        Ok((alias, config)) => {
            println!("   ✅ Resolved reranker alias: '{alias}'");
            println!("      Model name: {}", config.name);
            println!("      Provider: {}", config.provider);
        }
        Err(e) => {
            println!("   ❌ Failed to resolve reranker alias: {e}");
            std::process::exit(1);
        }
    }
}

/// Tests 5-6: creates the xsmall reranker, reranks four documents against a
/// query, and checks that the scores are non-increasing and lie in [0, 1].
#[cfg(feature = "mixedbread")]
fn check_reranker() {
    // Maximum number of characters of each document shown in the output.
    const PREVIEW_CHARS: usize = 60;

    println!("\n5. Testing Mixedbread Reranker Creation");
    println!("   Attempting to create Mixedbread reranker...");

    let mut reranker = match create_reranker(Some("mixedbread-ai/mxbai-rerank-xsmall-v1")) {
        Ok(reranker) => reranker,
        Err(e) => {
            println!("   ❌ Failed to create Mixedbread reranker: {e}");
            println!("      Error details: {e:?}");
            std::process::exit(1);
        }
    };
    println!("   ✅ Successfully created reranker: {}", reranker.id());

    println!("\n6. Testing Reranking");
    let query = "error handling in Rust";
    let documents = vec![
        "Rust error handling with Result and Option types".to_string(),
        "Python web development frameworks".to_string(),
        "Rust provides excellent error handling mechanisms".to_string(),
        "JavaScript async programming patterns".to_string(),
    ];
    println!("   Query: '{query}'");
    println!("   Reranking {} documents...", documents.len());

    let results = match reranker.rerank(query, &documents) {
        Ok(results) => results,
        Err(e) => {
            println!("   ❌ Failed to rerank: {e}");
            std::process::exit(1);
        }
    };

    println!("   ✅ Successfully reranked documents");
    println!("      Results (sorted by score):");
    for (i, result) in results.iter().enumerate() {
        // Truncate on a char boundary: byte-slicing `[..60]` panics when byte
        // 60 falls inside a multi-byte UTF-8 character.
        let preview = match result.document.char_indices().nth(PREVIEW_CHARS) {
            Some((end, _)) => &result.document[..end],
            None => &result.document[..],
        };
        println!(
            "      {}. Score: {:.4} | Doc: {}",
            i + 1,
            result.score,
            preview
        );
    }

    // Verify results are sorted by score (descending). Comparing neighbours
    // needs no clone and, unlike `partial_cmp(..).unwrap()`, reports a NaN
    // score as an assertion failure instead of an unwrap panic.
    let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
    assert!(
        scores.windows(2).all(|pair| pair[0] >= pair[1]),
        "Results should be sorted by score descending: {scores:?}"
    );

    // Verify scores are in valid range [0, 1]
    for result in &results {
        assert!(
            (0.0..=1.0).contains(&result.score),
            "Rerank scores should be in [0, 1] range"
        );
    }
}