#![allow(clippy::uninlined_format_args)]
#![allow(clippy::no_effect_underscore_binding)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::unused_async)]
#![allow(dead_code)]
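
//! Comprehensive embeddings example: creating embeddings, comparing models,
//! batch processing, dimension reduction, similarity search, and testing
//! patterns for embedding-based code.
//!
//! The API responses here are simulated locally (see `simulate_embedding`), so
//! the example runs without network access. Swap the simulated sections for
//! real `client.embeddings().create(...)` calls when you want to hit the API.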

use openai_ergonomic::Client;
use std::collections::HashMap;

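/// The OpenAI embedding models exercised in this example.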
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingModel {
    TextEmbedding3Small,
    TextEmbedding3Large,
    Ada002,
}

impl EmbeddingModel {
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::TextEmbedding3Small => "text-embedding-3-small",
            Self::TextEmbedding3Large => "text-embedding-3-large",
            Self::Ada002 => "text-embedding-ada-002",
        }
    }

    pub const fn default_dimensions(&self) -> usize {
        match self {
            Self::TextEmbedding3Large => 3072,
            Self::TextEmbedding3Small | Self::Ada002 => 1536,
        }
    }

    pub const fn supports_dimensions(&self) -> bool {
        matches!(self, Self::TextEmbedding3Small | Self::TextEmbedding3Large)
    }
}

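/// A single embedding vector along with the text and model that produced it.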
#[derive(Debug, Clone)]
pub struct Embedding {
    /// The embedding vector values.
    pub vector: Vec<f32>,
    /// The original input text.
    pub text: String,
    /// The model that produced this embedding.
    pub model: EmbeddingModel,
    /// Token count for the input text, when known.
    pub token_count: Option<usize>,
}

impl Embedding {
    pub const fn new(vector: Vec<f32>, text: String, model: EmbeddingModel) -> Self {
        Self {
            vector,
            text,
            model,
            token_count: None,
        }
    }

    pub fn dimensions(&self) -> usize {
        self.vector.len()
    }

    pub fn cosine_similarity(&self, other: &Self) -> Result<f32, EmbeddingError> {
        if self.vector.len() != other.vector.len() {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.vector.len(),
                actual: other.vector.len(),
            });
        }

        let dot_product: f32 = self
            .vector
            .iter()
            .zip(&other.vector)
            .map(|(a, b)| a * b)
            .sum();

        let norm_a: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = other.vector.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return Err(EmbeddingError::ZeroVector);
        }

        Ok(dot_product / (norm_a * norm_b))
    }

    pub fn euclidean_distance(&self, other: &Self) -> Result<f32, EmbeddingError> {
        if self.vector.len() != other.vector.len() {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.vector.len(),
                actual: other.vector.len(),
            });
        }

        let distance: f32 = self
            .vector
            .iter()
            .zip(&other.vector)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        Ok(distance)
    }
}

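/// Errors raised while building embedding requests or comparing vectors.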
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    #[error("Cannot calculate similarity with zero vector")]
    ZeroVector,

    #[error("Model {model} does not support dimension reduction")]
    DimensionReductionNotSupported { model: String },

    #[error("Invalid dimensions: {dimensions} (must be between 1 and {max})")]
    InvalidDimensions { dimensions: usize, max: usize },

    #[error("Batch processing failed: {message}")]
    BatchProcessingFailed { message: String },
}

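/// A request to create embeddings for one or more input texts.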
#[derive(Debug, Clone)]
pub struct EmbeddingRequest {
    pub inputs: Vec<String>,
    pub model: EmbeddingModel,
    pub dimensions: Option<usize>,
    pub user: Option<String>,
}

impl EmbeddingRequest {
    pub fn new(text: impl Into<String>, model: EmbeddingModel) -> Self {
        Self {
            inputs: vec![text.into()],
            model,
            dimensions: None,
            user: None,
        }
    }

    pub const fn batch(texts: Vec<String>, model: EmbeddingModel) -> Self {
        Self {
            inputs: texts,
            model,
            dimensions: None,
            user: None,
        }
    }

    pub fn with_dimensions(mut self, dimensions: usize) -> Result<Self, EmbeddingError> {
        if !self.model.supports_dimensions() {
            return Err(EmbeddingError::DimensionReductionNotSupported {
                model: self.model.as_str().to_string(),
            });
        }

        let max_dims = self.model.default_dimensions();
        if dimensions == 0 || dimensions > max_dims {
            return Err(EmbeddingError::InvalidDimensions {
                dimensions,
                max: max_dims,
            });
        }

        self.dimensions = Some(dimensions);
        Ok(self)
    }

    #[must_use]
    pub fn with_user(mut self, user: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self
    }
}

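/// The embeddings returned for a request, plus token usage information.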
#[derive(Debug, Clone)]
pub struct EmbeddingResponse {
    pub embeddings: Vec<Embedding>,
    pub model: EmbeddingModel,
    pub usage: EmbeddingUsage,
}

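/// Token usage reported for an embedding request.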
#[derive(Debug, Clone)]
pub struct EmbeddingUsage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

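/// A single match returned by `EmbeddingCollection::find_similar`.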
#[derive(Debug, Clone)]
pub struct SimilarityResult {
    pub embedding: Embedding,
    pub score: f32,
    pub index: usize,
}

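/// An in-memory collection of embeddings with optional per-entry metadata and
/// brute-force cosine-similarity search.
///
/// A minimal usage sketch built on the simulated embeddings in this example
/// (with real embeddings the similarity scores become semantically meaningful;
/// the simulation only produces deterministic pseudo-random vectors):
///
/// ```ignore
/// let model = EmbeddingModel::TextEmbedding3Small;
/// let mut collection = EmbeddingCollection::new();
/// collection.add(simulate_embedding("cats make great pets", model));
/// collection.add(simulate_embedding("the market closed lower today", model));
///
/// let query = simulate_embedding("I love my cat", model);
/// let top = collection.find_similar(&query, 1)?;
/// println!("best match: {}", top[0].embedding.text);
/// ```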
#[derive(Debug, Clone)]
pub struct EmbeddingCollection {
    embeddings: Vec<Embedding>,
    metadata: HashMap<usize, serde_json::Value>,
}

impl EmbeddingCollection {
    pub fn new() -> Self {
        Self {
            embeddings: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    pub fn add(&mut self, embedding: Embedding) -> usize {
        let index = self.embeddings.len();
        self.embeddings.push(embedding);
        index
    }

    pub fn add_with_metadata(
        &mut self,
        embedding: Embedding,
        metadata: serde_json::Value,
    ) -> usize {
        let index = self.add(embedding);
        self.metadata.insert(index, metadata);
        index
    }

    pub fn find_similar(
        &self,
        query: &Embedding,
        top_k: usize,
    ) -> Result<Vec<SimilarityResult>, EmbeddingError> {
        let mut results = Vec::new();

        for (index, embedding) in self.embeddings.iter().enumerate() {
            let score = query.cosine_similarity(embedding)?;
            results.push(SimilarityResult {
                embedding: embedding.clone(),
                score,
                index,
            });
        }

        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        results.truncate(top_k);
        Ok(results)
    }

    pub fn get_metadata(&self, index: usize) -> Option<&serde_json::Value> {
        self.metadata.get(&index)
    }

    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }
}

impl Default for EmbeddingCollection {
    fn default() -> Self {
        Self::new()
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!(" OpenAI Ergonomic - Comprehensive Embeddings Example\n");

    let client = match Client::from_env() {
        Ok(client_builder) => {
            println!(" Client initialized successfully");
            client_builder.build()
        }
        Err(e) => {
            eprintln!(" Failed to initialize client: {e}");
            eprintln!(" Make sure OPENAI_API_KEY is set in your environment");
            return Err(e.into());
        }
    };

    println!("\n Example 1: Basic Embedding Generation");
    println!("=========================================");

    match basic_embedding_example(&client).await {
        Ok(()) => println!(" Basic embedding example completed"),
        Err(e) => {
            eprintln!(" Basic embedding example failed: {e}");
            handle_embedding_error(e.as_ref());
        }
    }

    println!("\n Example 2: Model Comparison");
    println!("===============================");

    match model_comparison_example(&client).await {
        Ok(()) => println!(" Model comparison example completed"),
        Err(e) => {
            eprintln!(" Model comparison example failed: {e}");
            handle_embedding_error(e.as_ref());
        }
    }

    println!("\n Example 3: Batch Processing");
    println!("===============================");

    match batch_processing_example(&client).await {
        Ok(()) => println!(" Batch processing example completed"),
        Err(e) => {
            eprintln!(" Batch processing example failed: {e}");
            handle_embedding_error(e.as_ref());
        }
    }

    println!("\n Example 4: Dimension Reduction");
    println!("==================================");

    match dimension_reduction_example(&client).await {
        Ok(()) => println!(" Dimension reduction example completed"),
        Err(e) => {
            eprintln!(" Dimension reduction example failed: {e}");
            handle_embedding_error(e.as_ref());
        }
    }

    println!("\n Example 5: Similarity Search");
    println!("================================");

    match similarity_search_example(&client).await {
        Ok(()) => println!(" Similarity search example completed"),
        Err(e) => {
            eprintln!(" Similarity search example failed: {e}");
            handle_embedding_error(e.as_ref());
        }
    }

    println!("\n Example 6: Testing Patterns");
    println!("===============================");

    match testing_patterns_example().await {
        Ok(()) => println!(" Testing patterns example completed"),
        Err(e) => {
            eprintln!(" Testing patterns example failed: {e}");
            handle_embedding_error(e.as_ref());
        }
    }

    println!("\n All examples completed! Check the console output above for results.");
    println!("\nNote: This example simulates API responses. Swap the simulated sections with");
    println!("real `client.embeddings().create(...)` calls when you're ready to hit the API.");

    Ok(())
}

async fn basic_embedding_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
    println!("Creating embeddings for a simple text...");

    let text = "The quick brown fox jumps over the lazy dog";
    let model = EmbeddingModel::TextEmbedding3Small;

    println!(" Input text: \"{}\"", text);
    println!(" Model: {}", model.as_str());
    println!(" Expected dimensions: {}", model.default_dimensions());
    let simulated_embedding = simulate_embedding(text, model);

    println!(
        " Generated embedding with {} dimensions",
        simulated_embedding.dimensions()
    );
    println!(
        " First 5 values: {:?}",
        &simulated_embedding.vector[..5.min(simulated_embedding.vector.len())]
    );

    if let Some(token_count) = simulated_embedding.token_count {
        println!(" Token count: {}", token_count);
    }

    Ok(())
}

async fn model_comparison_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
    println!("Comparing embeddings across different models...");

    let text = "Artificial intelligence is transforming the world";
    let models = [
        EmbeddingModel::TextEmbedding3Small,
        EmbeddingModel::TextEmbedding3Large,
        EmbeddingModel::Ada002,
    ];

    println!(" Input text: \"{}\"", text);
    println!();

    for model in models {
        println!("Testing model: {}", model.as_str());

        let embedding = simulate_embedding(text, model);

        println!(" Dimensions: {}", embedding.dimensions());
        println!(
            " Supports dimension reduction: {}",
            model.supports_dimensions()
        );
        println!(
            " Vector norm: {:.6}",
            calculate_vector_norm(&embedding.vector)
        );
        println!();
    }

    println!(" Different models produce embeddings with different characteristics:");
    println!(" - text-embedding-3-small: Balanced performance and cost");
    println!(" - text-embedding-3-large: Higher quality, more expensive");
    println!(" - ada-002: Legacy model, still widely used");

    Ok(())
}

async fn batch_processing_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
    println!("Processing multiple texts in batch...");

    let texts = vec![
        "The weather is sunny today".to_string(),
        "I love reading science fiction books".to_string(),
        "Machine learning algorithms are fascinating".to_string(),
        "Pizza is my favorite food".to_string(),
        "The ocean is vast and mysterious".to_string(),
    ];

    println!(" Processing {} texts in batch:", texts.len());
    for (i, text) in texts.iter().enumerate() {
        println!(" {}. \"{}\"", i + 1, text);
    }

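    // A real batch request would send all of the texts in one API call (the
    // `EmbeddingRequest::batch` helper above models that shape); here each text
    // is simulated locally instead.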
    let mut embeddings = Vec::new();
    let mut total_tokens = 0;

    for text in &texts {
        let embedding = simulate_embedding(text, EmbeddingModel::TextEmbedding3Small);
        if let Some(tokens) = embedding.token_count {
            total_tokens += tokens;
        }
        embeddings.push(embedding);
    }

    println!("\n Generated {} embeddings", embeddings.len());
    println!(" Total tokens used: {}", total_tokens);
    #[allow(clippy::cast_precision_loss)]
    {
        println!(
            " Average tokens per text: {:.1}",
            total_tokens as f32 / texts.len() as f32
        );
    }

    #[allow(clippy::cast_precision_loss)]
    let avg_norm: f32 = embeddings
        .iter()
        .map(|e| calculate_vector_norm(&e.vector))
        .sum::<f32>()
        / embeddings.len() as f32;

    println!(" Average vector norm: {:.6}", avg_norm);

    println!("\n Batch processing is more efficient for multiple texts:");
    println!(" - Reduced API calls and latency");
    println!(" - Better throughput for large datasets");
    println!(" - Cost-effective for bulk operations");

    Ok(())
}

async fn dimension_reduction_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
    println!("Demonstrating dimension reduction capabilities...");

    let text = "Vector databases enable semantic search at scale";
    let model = EmbeddingModel::TextEmbedding3Small;
    let original_dims = model.default_dimensions();
    let reduced_dims = [512, 256, 128];

    println!(" Input text: \"{}\"", text);
    println!(
        " Model: {} (default: {} dimensions)",
        model.as_str(),
        original_dims
    );

    let original_embedding = simulate_embedding(text, model);
    println!(
        "\n Original embedding: {} dimensions",
        original_embedding.dimensions()
    );

    for &dims in &reduced_dims {
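        // Against the real API, reduced dimensions are requested up front (see
        // `EmbeddingRequest::with_dimensions`); the simulation below simply
        // truncates the full vector and renormalizes it.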
        let reduced_embedding = simulate_reduced_embedding(text, model, dims).unwrap();

        println!(" Reduced to {} dimensions:", dims);

        if let Ok(similarity) = original_embedding.cosine_similarity(&reduced_embedding) {
            println!(" Similarity to original: {:.4}", similarity);
        }

        #[allow(clippy::cast_precision_loss)]
        let compression_ratio = dims as f32 / original_dims as f32;
        println!(" Compression ratio: {:.1}%", compression_ratio * 100.0);

        let storage_savings = (1.0 - compression_ratio) * 100.0;
        println!(" Storage savings: {:.1}%", storage_savings);
    }

    println!("\n Dimension reduction trade-offs:");
    println!(" Pros: Reduced storage, faster search, lower memory usage");
    println!(" Cons: Some semantic information loss");
    println!(" Best practice: Test different dimensions for your use case");

    Ok(())
}
async fn similarity_search_example(_client: &Client) -> Result<(), Box<dyn std::error::Error>> {
    println!("Demonstrating similarity search and comparison...");

    let documents = vec![
        "The cat sat on the mat",
        "A feline rested on the rug",
        "Dogs are loyal companions",
        "Canines make great pets",
        "The weather is sunny today",
        "It's a beautiful clear day",
        "Machine learning is fascinating",
        "AI algorithms are powerful tools",
    ];

    let model = EmbeddingModel::TextEmbedding3Small;
    println!(" Document collection ({} items):", documents.len());
    for (i, doc) in documents.iter().enumerate() {
        println!(" {}. \"{}\"", i + 1, doc);
    }

    let mut collection = EmbeddingCollection::new();
    for doc in &documents {
        let embedding = simulate_embedding(doc, model);
        let metadata = serde_json::json!({
            "text": doc,
            "length": doc.len(),
            "word_count": doc.split_whitespace().count()
        });
        collection.add_with_metadata(embedding, metadata);
    }

    println!(
        "\n Created embedding collection with {} items",
        collection.len()
    );

    let queries = vec![
        "A cat sitting down",
        "Dog pets",
        "Nice weather",
        "Artificial intelligence",
    ];

    for query in queries {
        println!("\n Query: \"{}\"", query);

        let query_embedding = simulate_embedding(query, model);
        let results = collection.find_similar(&query_embedding, 3)?;

        println!(" Top 3 similar documents:");
        for (rank, result) in results.iter().enumerate() {
            println!(
                " {}. \"{}\" (similarity: {:.4})",
                rank + 1,
                result.embedding.text,
                result.score
            );

            if let Some(metadata) = collection.get_metadata(result.index) {
                if let Some(word_count) = metadata["word_count"].as_u64() {
                    println!(" Words: {}", word_count);
                }
            }
        }
    }

    println!("\n Similarity search applications:");
    println!(" Semantic search engines");
    println!(" Document retrieval systems");
    println!(" Recommendation engines");
    println!(" Content deduplication");

    Ok(())
}

async fn testing_patterns_example() -> Result<(), Box<dyn std::error::Error>> {
    println!("Demonstrating testing patterns for embeddings...");

    println!("\n Test 1: Embedding Properties");
    let text = "Test embedding generation";
    let model = EmbeddingModel::TextEmbedding3Small;
    let embedding = simulate_embedding(text, model);

    assert_eq!(embedding.dimensions(), model.default_dimensions());
    assert_eq!(embedding.text, text);
    assert_eq!(embedding.model, model);
    println!(" Embedding properties test passed");

    println!("\n Test 2: Similarity Calculations");
    let text1 = "Hello world";
    let text2 = "Hello world";
    let text3 = "Goodbye world";

    let embed1 = simulate_embedding(text1, model);
    let embed2 = simulate_embedding(text2, model);
    let embed3 = simulate_embedding(text3, model);

    let identical_similarity = embed1.cosine_similarity(&embed2)?;
    let different_similarity = embed1.cosine_similarity(&embed3)?;

    assert!(
        identical_similarity > 0.99,
        "Identical texts should have high similarity"
    );
    assert!(
        different_similarity < identical_similarity,
        "Different texts should have lower similarity"
    );
    println!(" Similarity calculation test passed");
    println!(
        " Identical texts similarity: {:.4}",
        identical_similarity
    );
    println!(
        " Different texts similarity: {:.4}",
        different_similarity
    );

    println!("\n Test 3: Error Handling");
    let small_embed =
        simulate_reduced_embedding("test", EmbeddingModel::TextEmbedding3Small, 256).unwrap();
    let large_embed = simulate_embedding("test", EmbeddingModel::TextEmbedding3Large);

    match small_embed.cosine_similarity(&large_embed) {
        Err(EmbeddingError::DimensionMismatch { expected, actual }) => {
            println!(" Dimension mismatch error handled correctly");
            println!(" Expected: {}, Actual: {}", expected, actual);
        }
        Ok(_) => panic!("Should have failed with dimension mismatch"),
        Err(e) => panic!("Unexpected error: {}", e),
    }

    println!("\n Test 4: Collection Operations");
    let mut collection = EmbeddingCollection::new();
    assert!(collection.is_empty());

    let test_embedding = simulate_embedding("test", model);
    let index = collection.add(test_embedding);
    assert_eq!(index, 0);
    assert_eq!(collection.len(), 1);
    assert!(!collection.is_empty());

    println!(" Collection operations test passed");

    println!("\n Test 5: Model Capabilities");
    assert!(EmbeddingModel::TextEmbedding3Small.supports_dimensions());
    assert!(EmbeddingModel::TextEmbedding3Large.supports_dimensions());
    assert!(!EmbeddingModel::Ada002.supports_dimensions());

    println!(" Model capabilities test passed");

    println!("\n Testing best practices:");
    println!(" Test embedding properties and dimensions");
    println!(" Validate similarity calculations");
    println!(" Test error conditions and edge cases");
    println!(" Test with known similar/dissimilar text pairs");
    println!(" Use deterministic test data for reproducible results");

    Ok(())
}

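/// Simulate an embedding by hashing the text and model name into a seed, then
/// generating a deterministic, unit-length pseudo-random vector of the model's
/// default dimensionality. Identical inputs always yield identical vectors.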
fn simulate_embedding(text: &str, model: EmbeddingModel) -> Embedding {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let dimensions = model.default_dimensions();

    // Derive a deterministic seed from the text and the model name.
    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    model.as_str().hash(&mut hasher);
    let seed = hasher.finish();

    let mut rng = XorShift64Star::new(seed);
    let mut vector = Vec::with_capacity(dimensions);

    for _ in 0..dimensions {
        vector.push(rng.next_f32() - 0.5);
    }

    // Normalize to unit length so cosine similarity behaves as it would for
    // real embedding vectors.
    let norm = calculate_vector_norm(&vector);
    if norm > 0.0 {
        for value in &mut vector {
            *value /= norm;
        }
    }

    let token_count = text.split_whitespace().count().max(1);

    let mut embedding = Embedding::new(vector, text.to_string(), model);
    embedding.token_count = Some(token_count);

    embedding
}

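/// Simulate dimension reduction by truncating a full simulated embedding to the
/// requested size and renormalizing it. Errors if the model does not support
/// dimension reduction.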
fn simulate_reduced_embedding(
    text: &str,
    model: EmbeddingModel,
    dimensions: usize,
) -> Result<Embedding, Box<dyn std::error::Error>> {
    if !model.supports_dimensions() {
        return Err(EmbeddingError::DimensionReductionNotSupported {
            model: model.as_str().to_string(),
        }
        .into());
    }

    let mut original = simulate_embedding(text, model);

    original.vector.truncate(dimensions);
    let norm = calculate_vector_norm(&original.vector);
    if norm > 0.0 {
        for value in &mut original.vector {
            *value /= norm;
        }
    }

    Ok(original)
}

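/// Compute the Euclidean (L2) norm of a vector.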
fn calculate_vector_norm(vector: &[f32]) -> f32 {
    vector.iter().map(|x| x * x).sum::<f32>().sqrt()
}

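/// Minimal xorshift64* pseudo-random number generator used to turn a
/// hash-derived seed into deterministic simulated embedding values.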
struct XorShift64Star {
    state: u64,
}

impl XorShift64Star {
    const fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 { 1 } else { seed },
        }
    }

    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state >> 12;
        self.state ^= self.state << 25;
        self.state ^= self.state >> 27;
        self.state.wrapping_mul(0x2545_F491_4F6C_DD1D)
    }

    #[allow(clippy::cast_precision_loss)]
    fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }
}

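/// Print an embedding error, its source (if any), and generic troubleshooting tips.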
fn handle_embedding_error(error: &dyn std::error::Error) {
    eprintln!(" Embedding Error: {}", error);

    if let Some(source) = error.source() {
        eprintln!(" Caused by: {}", source);
    }

    eprintln!(" Troubleshooting tips:");
    eprintln!(" - Check your API key and network connection");
    eprintln!(" - Verify text input is not empty");
    eprintln!(" - Ensure model supports requested features");
    eprintln!(" - Check dimension parameters are valid");
}