Skip to main content

oxirs_vec/embeddings/
openaiembeddinggenerator_traits.rs

1//! # OpenAIEmbeddingGenerator - Trait Implementations
2//!
3//! This module contains trait implementations for `OpenAIEmbeddingGenerator`.
4//!
5//! ## Implemented Traits
6//!
7//! - `EmbeddingGenerator`
8//! - `AsAny`
9//!
10//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
11
12use crate::Vector;
13use anyhow::{anyhow, Result};
14
15use super::functions::{AsAny, EmbeddingGenerator};
16use super::openaiembeddinggenerator_type::OpenAIEmbeddingGenerator;
17use super::types::{EmbeddableContent, EmbeddingConfig, RateLimiter};
18
19impl EmbeddingGenerator for OpenAIEmbeddingGenerator {
20    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
21        if self.openai_config.enable_cache {
22            let hash = content.content_hash();
23            if let Ok(mut cache) = self.request_cache.lock() {
24                if let Some(cached) = cache.get(&hash) {
25                    return Ok(cached.vector.clone());
26                }
27            }
28        }
29        let rt = tokio::runtime::Runtime::new()
30            .map_err(|e| anyhow!("Failed to create async runtime: {}", e))?;
31        let mut temp_generator = OpenAIEmbeddingGenerator {
32            config: self.config.clone(),
33            openai_config: self.openai_config.clone(),
34            client: self.client.clone(),
35            rate_limiter: RateLimiter::new(self.openai_config.requests_per_minute),
36            request_cache: self.request_cache.clone(),
37            metrics: self.metrics.clone(),
38        };
39        rt.block_on(temp_generator.generate_async(content))
40    }
41    fn dimensions(&self) -> usize {
42        self.config.dimensions
43    }
44    fn config(&self) -> &EmbeddingConfig {
45        &self.config
46    }
47}
48
49impl AsAny for OpenAIEmbeddingGenerator {
50    fn as_any(&self) -> &dyn std::any::Any {
51        self
52    }
53    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
54        self
55    }
56}