embeddings/
embeddings.rs

1//! Comprehensive embeddings example demonstrating vector embeddings with testing patterns.
2//!
3//! This example showcases the `OpenAI` embeddings API, including:
4//! - Basic embedding generation with different models
5//! - Batch processing of multiple texts
6//! - Dimension reduction capabilities
7//! - Similarity comparisons between embeddings
8//! - Testing patterns for embeddings
9//! - Comprehensive error handling and documentation
10//!
11//! ## Features Demonstrated
12//!
13//! - **Multiple Models**: text-embedding-3-small, text-embedding-3-large, ada-002
14//! - **Batch Processing**: Process multiple texts efficiently in a single request
15//! - **Dimension Control**: Reduce dimensions for optimized storage and performance
16//! - **Similarity Metrics**: Cosine similarity calculations between vectors
17//! - **Error Handling**: Robust error handling for various failure scenarios
18//! - **Testing Patterns**: Mock-friendly design for unit testing
19//!
20//! ## Prerequisites
21//!
22//! Set your `OpenAI` API key:
23//! ```bash
24//! export OPENAI_API_KEY="your-key-here"
25//! ```
26//!
27//! ## Usage
28//!
29//! ```bash
30//! cargo run --example embeddings
31//! ```
32//!
33//! Note: This example uses simulated responses to keep the example runnable without
34//! real `OpenAI` credentials. Replace the simulated sections with
35//! `client.embeddings().create(...)` calls to interact with the live API.
36
37#![allow(clippy::uninlined_format_args)]
38#![allow(clippy::no_effect_underscore_binding)]
39#![allow(clippy::cast_sign_loss)]
40#![allow(clippy::unused_async)]
41#![allow(dead_code)]
42
43use openai_ergonomic::Client;
44use std::collections::HashMap;
45
/// The `OpenAI` embedding models this example knows about.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingModel {
    /// `text-embedding-3-small`, the small third-generation model.
    TextEmbedding3Small,
    /// `text-embedding-3-large`, the large third-generation model.
    TextEmbedding3Large,
    /// `text-embedding-ada-002`, the legacy ada model.
    Ada002,
}

impl EmbeddingModel {
    /// Returns the model identifier string for this variant.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Ada002 => "text-embedding-ada-002",
            Self::TextEmbedding3Large => "text-embedding-3-large",
            Self::TextEmbedding3Small => "text-embedding-3-small",
        }
    }

    /// Returns the vector length the model produces when no reduction is requested.
    pub const fn default_dimensions(&self) -> usize {
        match *self {
            Self::Ada002 | Self::TextEmbedding3Small => 1536,
            Self::TextEmbedding3Large => 3072,
        }
    }

    /// Returns `true` when the model accepts a reduced `dimensions` setting.
    ///
    /// Only the legacy ada model does not.
    pub const fn supports_dimensions(&self) -> bool {
        !matches!(*self, Self::Ada002)
    }
}
80
81/// Represents an embedding vector with metadata
82#[derive(Debug, Clone)]
83pub struct Embedding {
84    /// The vector values
85    pub vector: Vec<f32>,
86    /// The input text that generated this embedding
87    pub text: String,
88    /// The model used to generate this embedding
89    pub model: EmbeddingModel,
90    /// Number of tokens in the input text
91    pub token_count: Option<usize>,
92}
93
94impl Embedding {
95    /// Create a new embedding
96    pub const fn new(vector: Vec<f32>, text: String, model: EmbeddingModel) -> Self {
97        Self {
98            vector,
99            text,
100            model,
101            token_count: None,
102        }
103    }
104
105    /// Get the dimensionality of this embedding
106    pub fn dimensions(&self) -> usize {
107        self.vector.len()
108    }
109
110    /// Calculate cosine similarity with another embedding
111    pub fn cosine_similarity(&self, other: &Self) -> Result<f32, EmbeddingError> {
112        if self.vector.len() != other.vector.len() {
113            return Err(EmbeddingError::DimensionMismatch {
114                expected: self.vector.len(),
115                actual: other.vector.len(),
116            });
117        }
118
119        let dot_product: f32 = self
120            .vector
121            .iter()
122            .zip(&other.vector)
123            .map(|(a, b)| a * b)
124            .sum();
125
126        let norm_a: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
127        let norm_b: f32 = other.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
128
129        if norm_a == 0.0 || norm_b == 0.0 {
130            return Err(EmbeddingError::ZeroVector);
131        }
132
133        Ok(dot_product / (norm_a * norm_b))
134    }
135
136    /// Calculate Euclidean distance with another embedding
137    pub fn euclidean_distance(&self, other: &Self) -> Result<f32, EmbeddingError> {
138        if self.vector.len() != other.vector.len() {
139            return Err(EmbeddingError::DimensionMismatch {
140                expected: self.vector.len(),
141                actual: other.vector.len(),
142            });
143        }
144
145        let distance: f32 = self
146            .vector
147            .iter()
148            .zip(&other.vector)
149            .map(|(a, b)| (a - b).powi(2))
150            .sum::<f32>()
151            .sqrt();
152
153        Ok(distance)
154    }
155}
156
/// Errors produced by the embedding helpers in this example.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// Two vectors that were expected to have the same length did not.
    ///
    /// Returned by `Embedding::cosine_similarity` and `Embedding::euclidean_distance`,
    /// where `expected` is the receiver's length and `actual` is the argument's.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Expected dimension count
        expected: usize,
        /// Actual dimension count
        actual: usize,
    },

    /// Cosine similarity was requested for a vector whose norm is zero.
    #[error("Cannot calculate similarity with zero vector")]
    ZeroVector,

    /// A reduced dimension count was requested for a model for which
    /// `EmbeddingModel::supports_dimensions` returns `false`.
    #[error("Model {model} does not support dimension reduction")]
    DimensionReductionNotSupported {
        /// Model name that doesn't support dimension reduction
        model: String,
    },

    /// The requested dimension count was zero or larger than the model's default.
    #[error("Invalid dimensions: {dimensions} (must be between 1 and {max})")]
    InvalidDimensions {
        /// Requested dimensions
        dimensions: usize,
        /// Maximum allowed dimensions
        max: usize,
    },

    /// Batch processing failed.
    ///
    /// NOTE(review): no code in the visible part of this file constructs this variant.
    #[error("Batch processing failed: {message}")]
    BatchProcessingFailed {
        /// Error message
        message: String,
    },
}
196
197/// Embedding request configuration
198#[derive(Debug, Clone)]
199pub struct EmbeddingRequest {
200    /// Texts to embed
201    pub inputs: Vec<String>,
202    /// Model to use
203    pub model: EmbeddingModel,
204    /// Optional dimension reduction
205    pub dimensions: Option<usize>,
206    /// User identifier for abuse monitoring
207    pub user: Option<String>,
208}
209
210impl EmbeddingRequest {
211    /// Create a new embedding request for a single text
212    pub fn new(text: impl Into<String>, model: EmbeddingModel) -> Self {
213        Self {
214            inputs: vec![text.into()],
215            model,
216            dimensions: None,
217            user: None,
218        }
219    }
220
221    /// Create a new embedding request for multiple texts
222    pub const fn batch(texts: Vec<String>, model: EmbeddingModel) -> Self {
223        Self {
224            inputs: texts,
225            model,
226            dimensions: None,
227            user: None,
228        }
229    }
230
231    /// Set dimension reduction (only for supported models)
232    pub fn with_dimensions(mut self, dimensions: usize) -> Result<Self, EmbeddingError> {
233        if !self.model.supports_dimensions() {
234            return Err(EmbeddingError::DimensionReductionNotSupported {
235                model: self.model.as_str().to_string(),
236            });
237        }
238
239        let max_dims = self.model.default_dimensions();
240        if dimensions == 0 || dimensions > max_dims {
241            return Err(EmbeddingError::InvalidDimensions {
242                dimensions,
243                max: max_dims,
244            });
245        }
246
247        self.dimensions = Some(dimensions);
248        Ok(self)
249    }
250
251    /// Set user identifier for abuse monitoring
252    #[must_use]
253    pub fn with_user(mut self, user: impl Into<String>) -> Self {
254        self.user = Some(user.into());
255        self
256    }
257}
258
/// Result of an embeddings request: the generated vectors plus request-level metadata.
///
/// NOTE(review): nothing in the visible part of this file constructs this type; the
/// simulated examples build `Embedding` values directly.
#[derive(Debug, Clone)]
pub struct EmbeddingResponse {
    /// Generated embeddings
    pub embeddings: Vec<Embedding>,
    /// Model used
    pub model: EmbeddingModel,
    /// Token usage information
    pub usage: EmbeddingUsage,
}
269
/// Token accounting attached to an [`EmbeddingResponse`].
#[derive(Debug, Clone)]
pub struct EmbeddingUsage {
    /// Number of tokens in the input
    pub prompt_tokens: usize,
    /// Total tokens processed
    pub total_tokens: usize,
}
278
/// One hit returned by `EmbeddingCollection::find_similar`.
#[derive(Debug, Clone)]
pub struct SimilarityResult {
    /// The embedding that was matched (a clone of the stored entry)
    pub embedding: Embedding,
    /// Similarity score (higher is more similar); `find_similar` fills this with the
    /// cosine similarity to the query
    pub score: f32,
    /// Index in the original collection; usable with `EmbeddingCollection::get_metadata`
    pub index: usize,
}
289
290/// Collection of embeddings for similarity search
291#[derive(Debug, Clone)]
292pub struct EmbeddingCollection {
293    embeddings: Vec<Embedding>,
294    metadata: HashMap<usize, serde_json::Value>,
295}
296
297impl EmbeddingCollection {
298    /// Create a new embedding collection
299    pub fn new() -> Self {
300        Self {
301            embeddings: Vec::new(),
302            metadata: HashMap::new(),
303        }
304    }
305
306    /// Add an embedding to the collection
307    pub fn add(&mut self, embedding: Embedding) -> usize {
308        let index = self.embeddings.len();
309        self.embeddings.push(embedding);
310        index
311    }
312
313    /// Add an embedding with metadata
314    pub fn add_with_metadata(
315        &mut self,
316        embedding: Embedding,
317        metadata: serde_json::Value,
318    ) -> usize {
319        let index = self.add(embedding);
320        self.metadata.insert(index, metadata);
321        index
322    }
323
324    /// Find the most similar embeddings to a query
325    pub fn find_similar(
326        &self,
327        query: &Embedding,
328        top_k: usize,
329    ) -> Result<Vec<SimilarityResult>, EmbeddingError> {
330        let mut results = Vec::new();
331
332        for (index, embedding) in self.embeddings.iter().enumerate() {
333            let score = query.cosine_similarity(embedding)?;
334            results.push(SimilarityResult {
335                embedding: embedding.clone(),
336                score,
337                index,
338            });
339        }
340
341        // Sort by similarity score (descending)
342        results.sort_by(|a, b| {
343            b.score
344                .partial_cmp(&a.score)
345                .unwrap_or(std::cmp::Ordering::Equal)
346        });
347
348        // Return top k results
349        results.truncate(top_k);
350        Ok(results)
351    }
352
353    /// Get metadata for an embedding by index
354    pub fn get_metadata(&self, index: usize) -> Option<&serde_json::Value> {
355        self.metadata.get(&index)
356    }
357
358    /// Get the number of embeddings in the collection
359    pub fn len(&self) -> usize {
360        self.embeddings.len()
361    }
362
363    /// Check if the collection is empty
364    pub fn is_empty(&self) -> bool {
365        self.embeddings.is_empty()
366    }
367}
368
369impl Default for EmbeddingCollection {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
375#[tokio::main]
376async fn main() -> Result<(), Box<dyn std::error::Error>> {
377    println!(" OpenAI Ergonomic - Comprehensive Embeddings Example\n");
378
379    // Initialize client from environment variables
380    let client = match Client::from_env() {
381        Ok(client_builder) => {
382            println!(" Client initialized successfully");
383            client_builder.build()
384        }
385        Err(e) => {
386            eprintln!(" Failed to initialize client: {e}");
387            eprintln!(" Make sure OPENAI_API_KEY is set in your environment");
388            return Err(e.into());
389        }
390    };
391
392    // Example 1: Basic Embedding Generation
393    println!("\n Example 1: Basic Embedding Generation");
394    println!("=========================================");
395
396    match basic_embedding_example(&client).await {
397        Ok(()) => println!(" Basic embedding example completed"),
398        Err(e) => {
399            eprintln!(" Basic embedding example failed: {e}");
400            handle_embedding_error(e.as_ref());
401        }
402    }
403
404    // Example 2: Model Comparison
405    println!("\n Example 2: Model Comparison");
406    println!("===============================");
407
408    match model_comparison_example(&client).await {
409        Ok(()) => println!(" Model comparison example completed"),
410        Err(e) => {
411            eprintln!(" Model comparison example failed: {e}");
412            handle_embedding_error(e.as_ref());
413        }
414    }
415
416    // Example 3: Batch Processing
417    println!("\n Example 3: Batch Processing");
418    println!("===============================");
419
420    match batch_processing_example(&client).await {
421        Ok(()) => println!(" Batch processing example completed"),
422        Err(e) => {
423            eprintln!(" Batch processing example failed: {e}");
424            handle_embedding_error(e.as_ref());
425        }
426    }
427
428    // Example 4: Dimension Reduction
429    println!("\n Example 4: Dimension Reduction");
430    println!("==================================");
431
432    match dimension_reduction_example(&client).await {
433        Ok(()) => println!(" Dimension reduction example completed"),
434        Err(e) => {
435            eprintln!(" Dimension reduction example failed: {e}");
436            handle_embedding_error(e.as_ref());
437        }
438    }
439
440    // Example 5: Similarity Search
441    println!("\n Example 5: Similarity Search");
442    println!("================================");
443
444    match similarity_search_example(&client).await {
445        Ok(()) => println!(" Similarity search example completed"),
446        Err(e) => {
447            eprintln!(" Similarity search example failed: {e}");
448            handle_embedding_error(e.as_ref());
449        }
450    }
451
452    // Example 6: Testing Patterns
453    println!("\n Example 6: Testing Patterns");
454    println!("===============================");
455
456    match testing_patterns_example().await {
457        Ok(()) => println!(" Testing patterns example completed"),
458        Err(e) => {
459            eprintln!(" Testing patterns example failed: {e}");
460            handle_embedding_error(e.as_ref());
461        }
462    }
463
464    println!("\n All examples completed! Check the console output above for results.");
465    println!("\nNote: This example simulates API responses. Swap the simulated sections with");
466    println!("real `client.embeddings().create(...)` calls when you're ready to hit the API.");
467
468    Ok(())
469}
470
471/// Example 1: Basic embedding generation with a single text
472async fn basic_embedding_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
473    println!("Creating embeddings for a simple text...");
474
475    // This would be the intended API usage:
476    // let builder = client
477    //     .embeddings()
478    //     .text("text-embedding-3-small", "The quick brown fox jumps over the lazy dog");
479    // let response = client.embeddings().create(builder).await?;
480
481    // For now, we'll simulate the response
482    let text = "The quick brown fox jumps over the lazy dog";
483    let model = EmbeddingModel::TextEmbedding3Small;
484
485    println!(" Input text: \"{}\"", text);
486    println!(" Model: {}", model.as_str());
487    println!(" Expected dimensions: {}", model.default_dimensions());
488
489    // Simulate embedding generation
490    let simulated_embedding = simulate_embedding(text, model);
491
492    println!(
493        " Generated embedding with {} dimensions",
494        simulated_embedding.dimensions()
495    );
496    println!(
497        " First 5 values: {:?}",
498        &simulated_embedding.vector[..5.min(simulated_embedding.vector.len())]
499    );
500
501    if let Some(token_count) = simulated_embedding.token_count {
502        println!(" Token count: {}", token_count);
503    }
504
505    Ok(())
506}
507
508/// Example 2: Compare different embedding models
509async fn model_comparison_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
510    println!("Comparing embeddings across different models...");
511
512    let text = "Artificial intelligence is transforming the world";
513    let models = [
514        EmbeddingModel::TextEmbedding3Small,
515        EmbeddingModel::TextEmbedding3Large,
516        EmbeddingModel::Ada002,
517    ];
518
519    println!(" Input text: \"{}\"", text);
520    println!();
521
522    for model in models {
523        println!("Testing model: {}", model.as_str());
524
525        // Simulate embedding generation for each model
526        let embedding = simulate_embedding(text, model);
527
528        println!("   Dimensions: {}", embedding.dimensions());
529        println!(
530            "   Supports dimension reduction: {}",
531            model.supports_dimensions()
532        );
533        println!(
534            "   Vector norm: {:.6}",
535            calculate_vector_norm(&embedding.vector)
536        );
537        println!();
538    }
539
540    println!(" Different models produce embeddings with different characteristics:");
541    println!("   - text-embedding-3-small: Balanced performance and cost");
542    println!("   - text-embedding-3-large: Higher quality, more expensive");
543    println!("   - ada-002: Legacy model, still widely used");
544
545    Ok(())
546}
547
548/// Example 3: Batch processing multiple texts
549async fn batch_processing_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
550    println!("Processing multiple texts in batch...");
551
552    let texts = vec![
553        "The weather is sunny today".to_string(),
554        "I love reading science fiction books".to_string(),
555        "Machine learning algorithms are fascinating".to_string(),
556        "Pizza is my favorite food".to_string(),
557        "The ocean is vast and mysterious".to_string(),
558    ];
559
560    println!(" Processing {} texts in batch:", texts.len());
561    for (i, text) in texts.iter().enumerate() {
562        println!("  {}. \"{}\"", i + 1, text);
563    }
564
565    // This would be the intended API usage:
566    // let builder = client
567    //     .embeddings()
568    //     .builder("text-embedding-3-small")
569    //     .input_texts(texts.clone());
570    // let response = client.embeddings().create(builder).await?;
571
572    // Simulate batch processing
573    let mut embeddings = Vec::new();
574    let mut total_tokens = 0;
575
576    for text in &texts {
577        let embedding = simulate_embedding(text, EmbeddingModel::TextEmbedding3Small);
578        if let Some(tokens) = embedding.token_count {
579            total_tokens += tokens;
580        }
581        embeddings.push(embedding);
582    }
583
584    println!("\n Generated {} embeddings", embeddings.len());
585    println!(" Total tokens used: {}", total_tokens);
586    #[allow(clippy::cast_precision_loss)]
587    {
588        println!(
589            " Average tokens per text: {:.1}",
590            total_tokens as f32 / texts.len() as f32
591        );
592    }
593
594    // Show some statistics
595    #[allow(clippy::cast_precision_loss)]
596    let avg_norm: f32 = embeddings
597        .iter()
598        .map(|e| calculate_vector_norm(&e.vector))
599        .sum::<f32>()
600        / embeddings.len() as f32;
601
602    println!(" Average vector norm: {:.6}", avg_norm);
603
604    println!("\n Batch processing is more efficient for multiple texts:");
605    println!("   - Reduced API calls and latency");
606    println!("   - Better throughput for large datasets");
607    println!("   - Cost-effective for bulk operations");
608
609    Ok(())
610}
611
612/// Example 4: Dimension reduction for optimized storage
613async fn dimension_reduction_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
614    println!("Demonstrating dimension reduction capabilities...");
615
616    let text = "Vector databases enable semantic search at scale";
617    let model = EmbeddingModel::TextEmbedding3Small;
618    let original_dims = model.default_dimensions();
619    let reduced_dims = [512, 256, 128];
620
621    println!(" Input text: \"{}\"", text);
622    println!(
623        " Model: {} (default: {} dimensions)",
624        model.as_str(),
625        original_dims
626    );
627
628    // Generate original embedding
629    let original_embedding = simulate_embedding(text, model);
630    println!(
631        "\n Original embedding: {} dimensions",
632        original_embedding.dimensions()
633    );
634
635    // Test different dimension reductions
636    for &dims in &reduced_dims {
637        // This would be the intended API usage:
638        // let builder = client
639        //     .embeddings()
640        //     .text(model.as_str(), text)
641        //     .dimensions(dims as i32);
642        // let response = client.embeddings().create(builder).await?;
643
644        // Simulate dimension reduction
645        let reduced_embedding = simulate_reduced_embedding(text, model, dims).unwrap();
646
647        println!(" Reduced to {} dimensions:", dims);
648
649        // Calculate similarity between original and reduced
650        if let Ok(similarity) = original_embedding.cosine_similarity(&reduced_embedding) {
651            println!("    Similarity to original: {:.4}", similarity);
652        }
653
654        #[allow(clippy::cast_precision_loss)]
655        let compression_ratio = dims as f32 / original_dims as f32;
656        println!("    Compression ratio: {:.1}%", compression_ratio * 100.0);
657
658        let storage_savings = (1.0 - compression_ratio) * 100.0;
659        println!("    Storage savings: {:.1}%", storage_savings);
660    }
661
662    println!("\n Dimension reduction trade-offs:");
663    println!("    Pros: Reduced storage, faster search, lower memory usage");
664    println!("     Cons: Some semantic information loss");
665    println!("    Best practice: Test different dimensions for your use case");
666
667    Ok(())
668}
669
670/// Example 5: Similarity search and comparison
671async fn similarity_search_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
672    println!("Demonstrating similarity search and comparison...");
673
674    // Create a collection of text documents
675    let documents = vec![
676        "The cat sat on the mat",
677        "A feline rested on the rug",
678        "Dogs are loyal companions",
679        "Canines make great pets",
680        "The weather is sunny today",
681        "It's a beautiful clear day",
682        "Machine learning is fascinating",
683        "AI algorithms are powerful tools",
684    ];
685
686    let model = EmbeddingModel::TextEmbedding3Small;
687    println!(" Document collection ({} items):", documents.len());
688    for (i, doc) in documents.iter().enumerate() {
689        println!("  {}. \"{}\"", i + 1, doc);
690    }
691
692    // Create embeddings for all documents
693    let mut collection = EmbeddingCollection::new();
694    for doc in &documents {
695        let embedding = simulate_embedding(doc, model);
696        let metadata = serde_json::json!({
697            "text": doc,
698            "length": doc.len(),
699            "word_count": doc.split_whitespace().count()
700        });
701        collection.add_with_metadata(embedding, metadata);
702    }
703
704    println!(
705        "\n Created embedding collection with {} items",
706        collection.len()
707    );
708
709    // Query examples
710    let queries = vec![
711        "A cat sitting down",
712        "Dog pets",
713        "Nice weather",
714        "Artificial intelligence",
715    ];
716
717    for query in queries {
718        println!("\n Query: \"{}\"", query);
719
720        let query_embedding = simulate_embedding(query, model);
721        let results = collection.find_similar(&query_embedding, 3)?;
722
723        println!("   Top 3 similar documents:");
724        for (rank, result) in results.iter().enumerate() {
725            println!(
726                "     {}. \"{}\" (similarity: {:.4})",
727                rank + 1,
728                result.embedding.text,
729                result.score
730            );
731
732            if let Some(metadata) = collection.get_metadata(result.index) {
733                if let Some(word_count) = metadata["word_count"].as_u64() {
734                    println!("        Words: {}", word_count);
735                }
736            }
737        }
738    }
739
740    println!("\n Similarity search applications:");
741    println!("    Semantic search engines");
742    println!("    Document retrieval systems");
743    println!("    Recommendation engines");
744    println!("    Content deduplication");
745
746    Ok(())
747}
748
749/// Example 6: Testing patterns for embeddings
750async fn testing_patterns_example() -> Result<(), Box<dyn std::error::Error>> {
751    println!("Demonstrating testing patterns for embeddings...");
752
753    // Test 1: Embedding properties
754    println!("\n Test 1: Embedding Properties");
755    let text = "Test embedding generation";
756    let model = EmbeddingModel::TextEmbedding3Small;
757    let embedding = simulate_embedding(text, model);
758
759    assert_eq!(embedding.dimensions(), model.default_dimensions());
760    assert_eq!(embedding.text, text);
761    assert_eq!(embedding.model, model);
762    println!("    Embedding properties test passed");
763
764    // Test 2: Similarity calculations
765    println!("\n Test 2: Similarity Calculations");
766    let text1 = "Hello world";
767    let text2 = "Hello world"; // Identical
768    let text3 = "Goodbye world"; // Different
769
770    let embed1 = simulate_embedding(text1, model);
771    let embed2 = simulate_embedding(text2, model);
772    let embed3 = simulate_embedding(text3, model);
773
774    let identical_similarity = embed1.cosine_similarity(&embed2)?;
775    let different_similarity = embed1.cosine_similarity(&embed3)?;
776
777    assert!(
778        identical_similarity > 0.99,
779        "Identical texts should have high similarity"
780    );
781    assert!(
782        different_similarity < identical_similarity,
783        "Different texts should have lower similarity"
784    );
785    println!("    Similarity calculation test passed");
786    println!(
787        "      Identical texts similarity: {:.4}",
788        identical_similarity
789    );
790    println!(
791        "      Different texts similarity: {:.4}",
792        different_similarity
793    );
794
795    // Test 3: Dimension mismatch error
796    println!("\n Test 3: Error Handling");
797    let small_embed =
798        simulate_reduced_embedding("test", EmbeddingModel::TextEmbedding3Small, 256).unwrap();
799    let large_embed = simulate_embedding("test", EmbeddingModel::TextEmbedding3Large);
800
801    match small_embed.cosine_similarity(&large_embed) {
802        Err(EmbeddingError::DimensionMismatch { expected, actual }) => {
803            println!("    Dimension mismatch error handled correctly");
804            println!("      Expected: {}, Actual: {}", expected, actual);
805        }
806        Ok(_) => panic!("Should have failed with dimension mismatch"),
807        Err(e) => panic!("Unexpected error: {}", e),
808    }
809
810    // Test 4: Collection operations
811    println!("\n Test 4: Collection Operations");
812    let mut collection = EmbeddingCollection::new();
813    assert!(collection.is_empty());
814
815    let test_embedding = simulate_embedding("test", model);
816    let index = collection.add(test_embedding);
817    assert_eq!(index, 0);
818    assert_eq!(collection.len(), 1);
819    assert!(!collection.is_empty());
820
821    println!("    Collection operations test passed");
822
823    // Test 5: Model capabilities
824    println!("\n Test 5: Model Capabilities");
825    assert!(EmbeddingModel::TextEmbedding3Small.supports_dimensions());
826    assert!(EmbeddingModel::TextEmbedding3Large.supports_dimensions());
827    assert!(!EmbeddingModel::Ada002.supports_dimensions());
828
829    println!("    Model capabilities test passed");
830
831    println!("\n Testing best practices:");
832    println!("    Test embedding properties and dimensions");
833    println!("    Validate similarity calculations");
834    println!("    Test error conditions and edge cases");
835    println!("    Test with known similar/dissimilar text pairs");
836    println!("    Use deterministic test data for reproducible results");
837
838    Ok(())
839}
840
841/// Simulate embedding generation (for demonstration purposes)
842fn simulate_embedding(text: &str, model: EmbeddingModel) -> Embedding {
843    use std::collections::hash_map::DefaultHasher;
844    use std::hash::{Hash, Hasher};
845
846    let dimensions = model.default_dimensions();
847
848    // Create a deterministic "embedding" based on text hash
849    let mut hasher = DefaultHasher::new();
850    text.hash(&mut hasher);
851    model.as_str().hash(&mut hasher);
852    let seed = hasher.finish();
853
854    let mut rng = XorShift64Star::new(seed);
855    let mut vector = Vec::with_capacity(dimensions);
856
857    // Generate random-ish values that sum to create a unit vector
858    for _ in 0..dimensions {
859        vector.push(rng.next_f32() - 0.5);
860    }
861
862    // Normalize to unit vector
863    let norm = calculate_vector_norm(&vector);
864    if norm > 0.0 {
865        for value in &mut vector {
866            *value /= norm;
867        }
868    }
869
870    let token_count = text.split_whitespace().count().max(1);
871
872    let mut embedding = Embedding::new(vector, text.to_string(), model);
873    embedding.token_count = Some(token_count);
874
875    embedding
876}
877
878/// Simulate embedding with reduced dimensions
879fn simulate_reduced_embedding(
880    text: &str,
881    model: EmbeddingModel,
882    dimensions: usize,
883) -> Result<Embedding, Box<dyn std::error::Error>> {
884    if !model.supports_dimensions() {
885        return Err(EmbeddingError::DimensionReductionNotSupported {
886            model: model.as_str().to_string(),
887        }
888        .into());
889    }
890
891    let mut original = simulate_embedding(text, model);
892
893    // Truncate to desired dimensions and renormalize
894    original.vector.truncate(dimensions);
895    let norm = calculate_vector_norm(&original.vector);
896    if norm > 0.0 {
897        for value in &mut original.vector {
898            *value /= norm;
899        }
900    }
901
902    Ok(original)
903}
904
/// Returns the Euclidean (L2) norm of `vector`: the square root of the sum of its
/// squared components.
fn calculate_vector_norm(vector: &[f32]) -> f32 {
    let sum_of_squares: f32 = vector.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
909
/// Minimal xorshift64* pseudo-random generator used to build deterministic fake
/// embeddings.
struct XorShift64Star {
    state: u64,
}

impl XorShift64Star {
    /// Creates a generator from `seed`; a zero seed is replaced by 1.
    const fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state and returns the next 64-bit output.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state = x;
        x.wrapping_mul(0x2545_F491_4F6C_DD1D)
    }

    /// Returns the next value in `[0, 1)`, built from the top 24 bits of the next output.
    #[allow(clippy::cast_precision_loss)]
    fn next_f32(&mut self) -> f32 {
        let top_bits = self.next_u64() >> 40;
        let scale = (1u64 << 24) as f32;
        top_bits as f32 / scale
    }
}
934
/// Prints `error`, its immediate source (if any), and generic troubleshooting tips to
/// stderr.
fn handle_embedding_error(error: &dyn std::error::Error) {
    // Generic guidance; a fuller implementation would tailor it to the error type.
    const TIPS: [&str; 4] = [
        "Check your API key and network connection",
        "Verify text input is not empty",
        "Ensure model supports requested features",
        "Check dimension parameters are valid",
    ];

    eprintln!(" Embedding Error: {error}");

    if let Some(cause) = error.source() {
        eprintln!("   Caused by: {cause}");
    }

    eprintln!(" Troubleshooting tips:");
    for tip in TIPS {
        eprintln!("   - {tip}");
    }
}