1use crate::Vector;
6use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10
11use super::functions::EmbeddingGenerator;
12use super::openaiembeddinggenerator_type::OpenAIEmbeddingGenerator;
13use super::sentencetransformergenerator_type::SentenceTransformerGenerator;
14
/// LRU cache mapping content hashes to previously computed embedding vectors.
///
/// Eviction is least-recently-used; bookkeeping is O(n) per access because
/// recency is tracked in a plain `Vec`.
pub struct EmbeddingCache {
    /// Content hash -> cached embedding vector.
    cache: HashMap<u64, Vector>,
    /// Maximum number of entries before the LRU entry is evicted.
    max_size: usize,
    /// Hashes ordered least-recently-used first; updated on every access.
    access_order: Vec<u64>,
}
21impl EmbeddingCache {
22 pub fn new(max_size: usize) -> Self {
23 Self {
24 cache: HashMap::new(),
25 max_size,
26 access_order: Vec::new(),
27 }
28 }
29 pub fn get(&mut self, content: &EmbeddableContent) -> Option<&Vector> {
30 let hash = content.content_hash();
31 if let Some(vector) = self.cache.get(&hash) {
32 if let Some(pos) = self.access_order.iter().position(|&x| x == hash) {
33 self.access_order.remove(pos);
34 }
35 self.access_order.push(hash);
36 Some(vector)
37 } else {
38 None
39 }
40 }
41 pub fn insert(&mut self, content: &EmbeddableContent, vector: Vector) {
42 let hash = content.content_hash();
43 if self.cache.len() >= self.max_size && !self.cache.contains_key(&hash) {
44 if let Some(&lru_hash) = self.access_order.first() {
45 self.cache.remove(&lru_hash);
46 self.access_order.remove(0);
47 }
48 }
49 self.cache.insert(hash, vector);
50 self.access_order.push(hash);
51 }
52 pub fn clear(&mut self) {
53 self.cache.clear();
54 self.access_order.clear();
55 }
56 pub fn size(&self) -> usize {
57 self.cache.len()
58 }
59}
/// Static description of a transformer model's architecture and expected
/// runtime characteristics.
#[derive(Debug, Clone)]
pub struct ModelDetails {
    /// Size of the tokenizer vocabulary.
    pub vocab_size: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Attention heads per layer.
    pub num_attention_heads: usize,
    /// Width of the hidden representations.
    pub hidden_size: usize,
    /// Width of the feed-forward intermediate layer.
    pub intermediate_size: usize,
    /// Maximum sequence length supported by the position embeddings.
    pub max_position_embeddings: usize,
    /// Languages the model supports.
    pub supports_languages: Vec<String>,
    /// Model size in megabytes.
    pub model_size_mb: usize,
    /// Typical inference latency in milliseconds.
    pub typical_inference_time_ms: u64,
}
/// Backoff policy applied between retried API requests.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RetryStrategy {
    /// Same delay before every retry.
    Fixed,
    /// Exponentially increasing delay between retries.
    ExponentialBackoff,
    /// Linearly increasing delay between retries.
    LinearBackoff,
}
/// Configuration shared by all embedding generators.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Identifier of the underlying model.
    pub model_name: String,
    /// Dimensionality of the produced vectors.
    pub dimensions: usize,
    /// Maximum input length accepted by the model.
    pub max_sequence_length: usize,
    /// Whether produced vectors should be normalized.
    pub normalize: bool,
}
/// Embedding generator stub available only under `#[cfg(test)]`.
#[cfg(test)]
pub struct MockEmbeddingGenerator {
    /// Embedding configuration; `dimensions` controls the mock's output size.
    pub(super) config: EmbeddingConfig,
}
#[cfg(test)]
impl MockEmbeddingGenerator {
    /// Builds a mock generator with the default 128-dimensional config.
    pub fn new() -> Self {
        Self::with_dimensions(128)
    }

    /// Builds a mock generator producing vectors of the given dimensionality;
    /// all other config fields come from `EmbeddingConfig::default()`.
    pub fn with_dimensions(dimensions: usize) -> Self {
        let config = EmbeddingConfig {
            dimensions,
            ..Default::default()
        };
        Self { config }
    }
}
/// The kinds of content that can be turned into an embedding vector.
#[derive(Debug, Clone)]
pub enum EmbeddableContent {
    /// Free-form text.
    Text(String),
    /// An RDF resource described by its URI plus optional annotations.
    RdfResource {
        /// Resource URI.
        uri: String,
        /// Optional human-readable label.
        label: Option<String>,
        /// Optional longer description.
        description: Option<String>,
        /// Property name -> list of values.
        properties: HashMap<String, Vec<String>>,
    },
    /// A SPARQL query string.
    SparqlQuery(String),
    /// A graph pattern string.
    GraphPattern(String),
}
132impl EmbeddableContent {
133 pub fn to_text(&self) -> String {
135 match self {
136 EmbeddableContent::Text(text) => text.clone(),
137 EmbeddableContent::RdfResource {
138 uri,
139 label,
140 description,
141 properties,
142 } => {
143 let mut text_parts = vec![uri.clone()];
144 if let Some(label) = label {
145 text_parts.push(format!("label: {label}"));
146 }
147 if let Some(desc) = description {
148 text_parts.push(format!("description: {desc}"));
149 }
150 for (prop, values) in properties {
151 text_parts.push(format!("{prop}: {}", values.join(", ")));
152 }
153 text_parts.join(" ")
154 }
155 EmbeddableContent::SparqlQuery(query) => query.clone(),
156 EmbeddableContent::GraphPattern(pattern) => pattern.clone(),
157 }
158 }
159 pub fn content_hash(&self) -> u64 {
161 let mut hasher = std::collections::hash_map::DefaultHasher::new();
162 self.to_text().hash(&mut hasher);
163 hasher.finish()
164 }
165}
/// Selects which embedding backend an `EmbeddingManager` constructs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EmbeddingStrategy {
    /// Sparse TF-IDF vectors built from a user-supplied corpus.
    TfIdf,
    /// Default sentence-transformer model.
    SentenceTransformer,
    /// Sentence transformer with an explicit model architecture.
    Transformer(TransformerModelType),
    /// Word2Vec embeddings with the given configuration.
    Word2Vec(crate::word2vec::Word2VecConfig),
    /// OpenAI embedding API backend.
    OpenAI(OpenAIConfig),
    /// Custom model path (currently falls back to the default sentence
    /// transformer; see `EmbeddingManager::new`).
    Custom(String),
}
/// Facade that owns an embedding generator plus an LRU cache and routes
/// embedding requests through the cache.
pub struct EmbeddingManager {
    /// Active backend, chosen from the strategy at construction time.
    generator: Box<dyn EmbeddingGenerator>,
    /// LRU cache of previously generated embeddings.
    cache: EmbeddingCache,
    /// Strategy used to build `generator`; consulted by `build_vocabulary`.
    strategy: EmbeddingStrategy,
}
188impl EmbeddingManager {
189 pub fn new(strategy: EmbeddingStrategy, cache_size: usize) -> Result<Self> {
190 let generator: Box<dyn EmbeddingGenerator> = match &strategy {
191 EmbeddingStrategy::TfIdf => {
192 let config = EmbeddingConfig::default();
193 Box::new(TfIdfEmbeddingGenerator::new(config))
194 }
195 EmbeddingStrategy::SentenceTransformer => {
196 let config = EmbeddingConfig::default();
197 Box::new(SentenceTransformerGenerator::new(config))
198 }
199 EmbeddingStrategy::Transformer(model_type) => {
200 let config = EmbeddingConfig {
201 model_name: format!("{model_type:?}"),
202 dimensions: match model_type {
203 TransformerModelType::DistilBERT => 384,
204 _ => 768,
205 },
206 max_sequence_length: 512,
207 normalize: true,
208 };
209 Box::new(SentenceTransformerGenerator::with_model_type(
210 config,
211 model_type.clone(),
212 ))
213 }
214 EmbeddingStrategy::Word2Vec(word2vec_config) => {
215 let embedding_config = EmbeddingConfig {
216 model_name: "word2vec".to_string(),
217 dimensions: word2vec_config.dimensions,
218 max_sequence_length: 512,
219 normalize: word2vec_config.normalize,
220 };
221 Box::new(crate::word2vec::Word2VecEmbeddingGenerator::new(
222 word2vec_config.clone(),
223 embedding_config,
224 )?)
225 }
226 EmbeddingStrategy::OpenAI(openai_config) => {
227 Box::new(OpenAIEmbeddingGenerator::new(openai_config.clone())?)
228 }
229 EmbeddingStrategy::Custom(_model_path) => {
230 let config = EmbeddingConfig::default();
231 Box::new(SentenceTransformerGenerator::new(config))
232 }
233 };
234 Ok(Self {
235 generator,
236 cache: EmbeddingCache::new(cache_size),
237 strategy,
238 })
239 }
240 pub fn get_embedding(&mut self, content: &EmbeddableContent) -> Result<Vector> {
242 if let Some(cached) = self.cache.get(content) {
243 return Ok(cached.clone());
244 }
245 let embedding = self.generator.generate(content)?;
246 self.cache.insert(content, embedding.clone());
247 Ok(embedding)
248 }
249 pub fn precompute_embeddings(&mut self, contents: &[EmbeddableContent]) -> Result<()> {
251 let embeddings = self.generator.generate_batch(contents)?;
252 for (content, embedding) in contents.iter().zip(embeddings) {
253 self.cache.insert(content, embedding);
254 }
255 Ok(())
256 }
257 pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
259 if let EmbeddingStrategy::TfIdf = self.strategy {
260 if let Some(tfidf_gen) = self
261 .generator
262 .as_any_mut()
263 .downcast_mut::<TfIdfEmbeddingGenerator>()
264 {
265 tfidf_gen.build_vocabulary(documents)?;
266 }
267 }
268 Ok(())
269 }
270 pub fn dimensions(&self) -> usize {
271 self.generator.dimensions()
272 }
273 pub fn cache_stats(&self) -> (usize, usize) {
274 (self.cache.size(), self.cache.max_size)
275 }
276}
/// Supported transformer architectures for sentence embeddings.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum TransformerModelType {
    /// BERT (the default architecture).
    #[default]
    BERT,
    /// RoBERTa.
    RoBERTa,
    /// DistilBERT (configured with 384-dimensional embeddings in
    /// `EmbeddingManager::new`).
    DistilBERT,
    /// Multilingual BERT.
    MultiBERT,
    /// Custom architecture identified by name.
    Custom(String),
}
/// Configuration for the OpenAI embedding backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// API key; must be non-empty (see `validate`).
    pub api_key: String,
    /// Embedding model identifier.
    pub model: String,
    /// Base URL of the API endpoint.
    pub base_url: String,
    /// Per-request timeout in seconds; must be > 0 (see `validate`).
    pub timeout_seconds: u64,
    /// Rate limit; must be > 0 (see `validate`).
    pub requests_per_minute: u32,
    /// Inputs per batched request; must be > 0 (see `validate`).
    pub batch_size: usize,
    /// Whether response caching is enabled.
    pub enable_cache: bool,
    /// Maximum number of cached entries.
    pub cache_size: usize,
    /// Time-to-live for cached entries, in seconds.
    pub cache_ttl_seconds: u64,
    /// Maximum retry attempts for a failed request.
    pub max_retries: u32,
    /// Base delay between retries, in milliseconds.
    pub retry_delay_ms: u64,
    /// How the retry delay evolves across attempts.
    pub retry_strategy: RetryStrategy,
    /// Whether per-request cost tracking is enabled.
    pub track_costs: bool,
    /// Whether metrics collection is enabled.
    pub enable_metrics: bool,
    /// User agent string — presumably sent as the HTTP `User-Agent` header;
    /// confirm in the client implementation.
    pub user_agent: String,
}
326impl OpenAIConfig {
327 pub fn production() -> Self {
329 Self {
330 requests_per_minute: 1000,
331 cache_size: 50000,
332 cache_ttl_seconds: 7200,
333 max_retries: 5,
334 retry_strategy: RetryStrategy::ExponentialBackoff,
335 ..Default::default()
336 }
337 }
338 pub fn development() -> Self {
340 Self {
341 requests_per_minute: 100,
342 cache_size: 1000,
343 cache_ttl_seconds: 300,
344 max_retries: 2,
345 ..Default::default()
346 }
347 }
348 pub fn validate(&self) -> Result<()> {
350 if self.api_key.is_empty() {
351 return Err(anyhow!("OpenAI API key is required"));
352 }
353 if self.requests_per_minute == 0 {
354 return Err(anyhow!("requests_per_minute must be greater than 0"));
355 }
356 if self.batch_size == 0 {
357 return Err(anyhow!("batch_size must be greater than 0"));
358 }
359 if self.timeout_seconds == 0 {
360 return Err(anyhow!("timeout_seconds must be greater than 0"));
361 }
362 Ok(())
363 }
364}
/// Sliding-window rate limiter for outbound API requests.
pub struct RateLimiter {
    /// Maximum requests allowed within any rolling 60-second window.
    requests_per_minute: u32,
    /// Timestamps of recent requests, oldest first.
    request_times: std::collections::VecDeque<std::time::Instant>,
}
370impl RateLimiter {
371 pub fn new(requests_per_minute: u32) -> Self {
372 Self {
373 requests_per_minute,
374 request_times: std::collections::VecDeque::new(),
375 }
376 }
377 pub async fn wait_if_needed(&mut self) {
378 let now = std::time::Instant::now();
379 let minute_ago = now - std::time::Duration::from_secs(60);
380 while let Some(&front_time) = self.request_times.front() {
381 if front_time < minute_ago {
382 self.request_times.pop_front();
383 } else {
384 break;
385 }
386 }
387 if self.request_times.len() >= self.requests_per_minute as usize {
388 if let Some(&oldest) = self.request_times.front() {
389 let wait_time = oldest + std::time::Duration::from_secs(60) - now;
390 if !wait_time.is_zero() {
391 tokio::time::sleep(wait_time).await;
392 }
393 }
394 }
395 self.request_times.push_back(now);
396 }
397}
/// Running counters and aggregates for OpenAI API usage.
#[derive(Debug, Clone, Default)]
pub struct OpenAIMetrics {
    /// Total requests attempted.
    pub total_requests: u64,
    /// Requests that completed successfully.
    pub successful_requests: u64,
    /// Requests that ultimately failed.
    pub failed_requests: u64,
    /// Total tokens processed across all requests.
    pub total_tokens_processed: u64,
    /// Embedding-cache hits.
    pub cache_hits: u64,
    /// Embedding-cache misses.
    pub cache_misses: u64,
    /// Accumulated API cost in USD.
    pub total_cost_usd: f64,
    /// Number of retried attempts.
    pub retry_count: u64,
    /// Number of times the rate limiter forced a wait.
    pub rate_limit_waits: u64,
    /// Mean response latency in milliseconds.
    pub average_response_time_ms: f64,
    /// Wall-clock time of the most recent request, if any.
    pub last_request_time: Option<std::time::SystemTime>,
    /// Request count broken down by model name.
    pub requests_by_model: HashMap<String, u64>,
    /// Error count broken down by error category.
    pub errors_by_type: HashMap<String, u64>,
}
415impl OpenAIMetrics {
416 pub fn cache_hit_ratio(&self) -> f64 {
418 if self.cache_hits + self.cache_misses == 0 {
419 0.0
420 } else {
421 self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
422 }
423 }
424 pub fn success_rate(&self) -> f64 {
426 if self.total_requests == 0 {
427 0.0
428 } else {
429 self.successful_requests as f64 / self.total_requests as f64
430 }
431 }
432 pub fn average_cost_per_request(&self) -> f64 {
434 if self.successful_requests == 0 {
435 0.0
436 } else {
437 self.total_cost_usd / self.successful_requests as f64
438 }
439 }
440 pub fn report(&self) -> String {
442 format!(
443 "OpenAI Metrics Report:\n\
444 Total Requests: {}\n\
445 Success Rate: {:.2}%\n\
446 Cache Hit Ratio: {:.2}%\n\
447 Total Cost: ${:.4}\n\
448 Avg Cost/Request: ${:.6}\n\
449 Avg Response Time: {:.2}ms\n\
450 Retries: {}\n\
451 Rate Limit Waits: {}",
452 self.total_requests,
453 self.success_rate() * 100.0,
454 self.cache_hit_ratio() * 100.0,
455 self.total_cost_usd,
456 self.average_cost_per_request(),
457 self.average_response_time_ms,
458 self.retry_count,
459 self.rate_limit_waits
460 )
461 }
462}
/// An embedding stored in a cache together with provenance metadata.
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    /// The cached embedding vector.
    pub vector: Vector,
    /// When the entry was created.
    pub cached_at: std::time::SystemTime,
    /// Model that produced the vector.
    pub model: String,
    /// API cost attributed to generating this embedding, in USD.
    pub cost_usd: f64,
}
/// TF-IDF based embedding generator over a fixed-size vocabulary.
pub struct TfIdfEmbeddingGenerator {
    /// Embedding configuration; `dimensions` bounds the vocabulary size.
    pub(super) config: EmbeddingConfig,
    /// Word -> embedding index for the highest-frequency corpus words.
    pub(super) vocabulary: HashMap<String, usize>,
    /// Word -> inverse document frequency score.
    idf_scores: HashMap<String, f32>,
}
477impl TfIdfEmbeddingGenerator {
478 pub fn new(config: EmbeddingConfig) -> Self {
479 Self {
480 config,
481 vocabulary: HashMap::new(),
482 idf_scores: HashMap::new(),
483 }
484 }
485 pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
487 let mut word_counts: HashMap<String, usize> = HashMap::new();
488 let mut doc_counts: HashMap<String, usize> = HashMap::new();
489 for doc in documents {
490 let words: Vec<String> = self.tokenize(doc);
491 let unique_words: std::collections::HashSet<_> = words.iter().collect();
492 for word in &words {
493 *word_counts.entry(word.clone()).or_insert(0) += 1;
494 }
495 for word in unique_words {
496 *doc_counts.entry(word.clone()).or_insert(0) += 1;
497 }
498 }
499 let mut word_freq: Vec<(String, usize)> = word_counts.into_iter().collect();
500 word_freq.sort_by(|a, b| b.1.cmp(&a.1));
501 self.vocabulary = word_freq
502 .into_iter()
503 .take(self.config.dimensions)
504 .enumerate()
505 .map(|(idx, (word, _))| (word, idx))
506 .collect();
507 let total_docs = documents.len() as f32;
508 for word in self.vocabulary.keys() {
509 let doc_freq = doc_counts.get(word).unwrap_or(&0);
510 let idf = (total_docs / (*doc_freq as f32 + 1.0)).ln();
511 self.idf_scores.insert(word.clone(), idf);
512 }
513 Ok(())
514 }
515 fn tokenize(&self, text: &str) -> Vec<String> {
516 text.to_lowercase()
517 .split_whitespace()
518 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
519 .filter(|s| !s.is_empty())
520 .map(String::from)
521 .collect()
522 }
523 pub(super) fn calculate_tf_idf(&self, text: &str) -> Vector {
524 let words = self.tokenize(text);
525 let mut tf_counts: HashMap<String, usize> = HashMap::new();
526 for word in &words {
527 *tf_counts.entry(word.clone()).or_insert(0) += 1;
528 }
529 let total_words = words.len() as f32;
530 let mut embedding = vec![0.0; self.config.dimensions];
531 for (word, count) in tf_counts {
532 if let Some(&idx) = self.vocabulary.get(&word) {
533 let tf = count as f32 / total_words;
534 let idf = self.idf_scores.get(&word).unwrap_or(&0.0);
535 embedding[idx] = tf * idf;
536 }
537 }
538 if self.config.normalize {
539 self.normalize_vector(&mut embedding);
540 }
541 Vector::new(embedding)
542 }
543 fn normalize_vector(&self, vector: &mut [f32]) {
544 let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
545 if magnitude > 0.0 {
546 for value in vector {
547 *value /= magnitude;
548 }
549 }
550 }
551}