Skip to main content

sh_layer1/
embeddings.rs

//! Embedding model module.
//!
//! Text embedding, batch processing, and caching.
//!
//! Supported embedding providers:
//! - OpenAI Embeddings API
//! - HuggingFace Inference API
//! - Cohere Embed API
//! - Local SentenceTransformers models
10
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18use tokio::sync::RwLock;
19
20// ============================================================================
21// 常量定义
22// ============================================================================
23
/// Name of the embedding model used when none is configured.
pub const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-ada-002";

/// Vector dimension assumed when a model does not specify one.
pub const DEFAULT_EMBEDDING_DIMENSION: usize = 1_536;

/// Default time-to-live of a cache entry, in seconds.
pub const DEFAULT_CACHE_TTL_SECS: u64 = 3_600;

/// Default upper bound on the number of cache entries.
pub const DEFAULT_CACHE_MAX_ENTRIES: usize = 10_000;
35
36// ============================================================================
37// 统一 EmbeddingModel Trait
38// ============================================================================
39
40/// 嵌入模型统一接口
41///
42/// 所有嵌入模型必须实现此 trait。
43#[async_trait]
44pub trait EmbeddingModel: Send + Sync {
45    /// 生成单个文本的嵌入向量
46    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
47
48    /// 批量生成嵌入向量
49    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
50
51    /// 获取向量维度
52    fn dimension(&self) -> usize;
53
54    /// 获取模型名称
55    fn model_name(&self) -> &str;
56
57    /// 获取提供商名称
58    fn provider(&self) -> &str;
59}
60
61// ============================================================================
62// 嵌入缓存
63// ============================================================================
64
/// A single cached embedding together with its bookkeeping data.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached embedding vector.
    embedding: Vec<f32>,
    /// Moment the entry was inserted; used for TTL expiry.
    created_at: Instant,
    /// Number of cache hits served from this entry; used to pick eviction victims.
    access_count: usize,
}
75
76/// 嵌入缓存
77///
78/// 使用 LRU 策略管理缓存条目。
79#[derive(Debug)]
80pub struct EmbeddingCache {
81    /// 缓存存储
82    store: RwLock<HashMap<String, CacheEntry>>,
83    /// 最大条目数
84    max_entries: usize,
85    /// TTL(秒)
86    ttl_secs: u64,
87}
88
89impl EmbeddingCache {
90    /// 创建新的缓存实例
91    pub fn new(max_entries: usize, ttl_secs: u64) -> Self {
92        Self {
93            store: RwLock::new(HashMap::new()),
94            max_entries,
95            ttl_secs,
96        }
97    }
98
99    /// 使用默认配置创建缓存
100    pub fn default_cache() -> Self {
101        Self::new(DEFAULT_CACHE_MAX_ENTRIES, DEFAULT_CACHE_TTL_SECS)
102    }
103
104    /// 生成缓存键
105    fn cache_key(provider: &str, model: &str, text: &str) -> String {
106        use std::collections::hash_map::DefaultHasher;
107        use std::hash::{Hash, Hasher};
108
109        let mut hasher = DefaultHasher::new();
110        provider.hash(&mut hasher);
111        model.hash(&mut hasher);
112        text.hash(&mut hasher);
113        format!("{}:{}:{:016x}", provider, model, hasher.finish())
114    }
115
116    /// 获取缓存的嵌入向量
117    pub async fn get(&self, provider: &str, model: &str, text: &str) -> Option<Vec<f32>> {
118        let key = Self::cache_key(provider, model, text);
119        let mut store = self.store.write().await;
120
121        if let Some(entry) = store.get_mut(&key) {
122            // 检查是否过期
123            if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
124                store.remove(&key);
125                return None;
126            }
127
128            entry.access_count += 1;
129            return Some(entry.embedding.clone());
130        }
131
132        None
133    }
134
135    /// 存储嵌入向量到缓存
136    pub async fn put(&self, provider: &str, model: &str, text: &str, embedding: Vec<f32>) {
137        let key = Self::cache_key(provider, model, text);
138        let mut store = self.store.write().await;
139
140        // 如果达到最大条目数,移除最少访问的条目
141        if store.len() >= self.max_entries {
142            if let Some((lru_key, _)) = store
143                .iter()
144                .min_by_key(|(_, e)| e.access_count)
145                .map(|(k, v)| (k.clone(), v.access_count))
146            {
147                store.remove(&lru_key);
148            }
149        }
150
151        store.insert(
152            key,
153            CacheEntry {
154                embedding,
155                created_at: Instant::now(),
156                access_count: 0,
157            },
158        );
159    }
160
161    /// 批量获取缓存的嵌入向量
162    pub async fn get_batch(
163        &self,
164        provider: &str,
165        model: &str,
166        texts: &[String],
167    ) -> Vec<Option<Vec<f32>>> {
168        let mut results = Vec::with_capacity(texts.len());
169        let mut store = self.store.write().await;
170
171        for text in texts {
172            let key = Self::cache_key(provider, model, text);
173
174            if let Some(entry) = store.get_mut(&key) {
175                if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
176                    store.remove(&key);
177                    results.push(None);
178                } else {
179                    entry.access_count += 1;
180                    results.push(Some(entry.embedding.clone()));
181                }
182            } else {
183                results.push(None);
184            }
185        }
186
187        results
188    }
189
190    /// 清空缓存
191    pub async fn clear(&self) {
192        let mut store = self.store.write().await;
193        store.clear();
194    }
195
196    /// 获取缓存统计信息
197    pub async fn stats(&self) -> CacheStats {
198        let store = self.store.read().await;
199        let total_entries = store.len();
200        let total_access: usize = store.values().map(|e| e.access_count).sum();
201
202        CacheStats {
203            total_entries,
204            total_access,
205            max_entries: self.max_entries,
206            ttl_secs: self.ttl_secs,
207        }
208    }
209}
210
/// Point-in-time statistics reported by the embedding cache.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub total_entries: usize,
    /// Sum of the access counts of all stored entries.
    pub total_access: usize,
    /// Configured maximum number of entries.
    pub max_entries: usize,
    /// Configured time-to-live, in seconds.
    pub ttl_secs: u64,
}
219
220// ============================================================================
221// 模型配置
222// ============================================================================
223
/// The kinds of embedding provider this module can talk to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingProvider {
    /// OpenAI Embeddings API.
    OpenAI,
    /// HuggingFace Inference API.
    HuggingFace,
    /// Cohere Embed API.
    Cohere,
    /// A model running locally.
    Local,
    /// Mock provider, used for tests and as a safe default.
    Mock,
}

impl EmbeddingProvider {
    /// Returns the lowercase identifier of the provider.
    pub fn as_str(&self) -> &'static str {
        match self {
            EmbeddingProvider::OpenAI => "openai",
            EmbeddingProvider::HuggingFace => "huggingface",
            EmbeddingProvider::Cohere => "cohere",
            EmbeddingProvider::Local => "local",
            EmbeddingProvider::Mock => "mock",
        }
    }
}
246
247/// 嵌入模型配置
248#[derive(Debug, Clone)]
249pub struct EmbeddingsConfig {
250    /// 提供商类型
251    pub provider: EmbeddingProvider,
252    /// API 密钥(本地模型可为空)
253    pub api_key: String,
254    /// API 基础 URL(可选)
255    pub base_url: Option<String>,
256    /// 模型名称
257    pub model: String,
258    /// 向量维度(可选,用于本地模型)
259    pub dimension: Option<usize>,
260}
261
262impl Default for EmbeddingsConfig {
263    fn default() -> Self {
264        // 安全默认值:使用 Mock 提供商
265        // 这样可以避免在未配置环境下意外调用外部 API
266        Self {
267            provider: EmbeddingProvider::Mock,
268            api_key: String::new(),
269            base_url: None,
270            model: "mock-embedding".to_string(),
271            dimension: Some(DEFAULT_EMBEDDING_DIMENSION),
272        }
273    }
274}
275
276impl EmbeddingsConfig {
277    /// 从环境变量创建 OpenAI 配置
278    pub fn openai_from_env() -> Result<Self> {
279        let api_key = std::env::var("OPENAI_API_KEY")
280            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
281
282        let base_url = std::env::var("OPENAI_BASE_URL")
283            .ok()
284            .or_else(|| Some("https://api.openai.com/v1".to_string()));
285
286        let model = std::env::var("OPENAI_EMBEDDING_MODEL")
287            .unwrap_or_else(|_| DEFAULT_EMBEDDING_MODEL.to_string());
288
289        Ok(Self {
290            provider: EmbeddingProvider::OpenAI,
291            api_key,
292            base_url,
293            model,
294            dimension: None,
295        })
296    }
297
298    /// 从环境变量创建 HuggingFace 配置
299    pub fn huggingface_from_env() -> Result<Self> {
300        let api_key = std::env::var("HUGGINGFACE_API_KEY")
301            .map_err(|_| anyhow!("HUGGINGFACE_API_KEY environment variable not set"))?;
302
303        let model = std::env::var("HUGGINGFACE_EMBEDDING_MODEL")
304            .unwrap_or_else(|_| "sentence-transformers/all-MiniLM-L6-v2".to_string());
305
306        Ok(Self {
307            provider: EmbeddingProvider::HuggingFace,
308            api_key,
309            base_url: Some(
310                "https://api-inference.huggingface.co/pipeline/feature-extraction".to_string(),
311            ),
312            model,
313            dimension: None,
314        })
315    }
316
317    /// 从环境变量创建 Cohere 配置
318    pub fn cohere_from_env() -> Result<Self> {
319        let api_key = std::env::var("COHERE_API_KEY")
320            .map_err(|_| anyhow!("COHERE_API_KEY environment variable not set"))?;
321
322        let model = std::env::var("COHERE_EMBEDDING_MODEL")
323            .unwrap_or_else(|_| "embed-english-v3.0".to_string());
324
325        Ok(Self {
326            provider: EmbeddingProvider::Cohere,
327            api_key,
328            base_url: Some("https://api.cohere.ai/v1".to_string()),
329            model,
330            dimension: None,
331        })
332    }
333
334    /// 创建本地模型配置
335    pub fn local(model: impl Into<String>, dimension: Option<usize>) -> Self {
336        Self {
337            provider: EmbeddingProvider::Local,
338            api_key: String::new(),
339            base_url: None,
340            model: model.into(),
341            dimension,
342        }
343    }
344
345    /// 检查配置是否有效(本地模型和 Mock 不需要 API key)
346    pub fn is_valid(&self) -> bool {
347        matches!(
348            self.provider,
349            EmbeddingProvider::Local | EmbeddingProvider::Mock
350        ) || !self.api_key.is_empty()
351    }
352}
353
354// ============================================================================
355// OpenAI 实现
356// ============================================================================
357
358/// OpenAI 嵌入模型
359#[derive(Debug)]
360pub struct OpenAIEmbeddings {
361    client: Client,
362    config: EmbeddingsConfig,
363    cache: Option<Arc<EmbeddingCache>>,
364}
365
366impl OpenAIEmbeddings {
367    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
368        if !config.is_valid() {
369            return Err(anyhow!("OpenAI Embeddings API not configured"));
370        }
371
372        Ok(Self {
373            client: Client::new(),
374            config,
375            cache: None,
376        })
377    }
378
379    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
380        let mut embeddings = Self::new(config)?;
381        embeddings.cache = Some(cache);
382        Ok(embeddings)
383    }
384
385    fn base_url(&self) -> &str {
386        self.config
387            .base_url
388            .as_deref()
389            .unwrap_or("https://api.openai.com/v1")
390    }
391}
392
393#[async_trait]
394impl EmbeddingModel for OpenAIEmbeddings {
395    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
396        let embeddings = self.embed_batch(&[text.to_string()]).await?;
397        embeddings
398            .into_iter()
399            .next()
400            .ok_or_else(|| anyhow!("No embedding returned"))
401    }
402
403    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
404        if texts.is_empty() {
405            return Ok(Vec::new());
406        }
407
408        // 检查缓存
409        if let Some(cache) = &self.cache {
410            let cached = cache.get_batch("openai", &self.config.model, texts).await;
411            let all_cached = cached.iter().all(|c| c.is_some());
412            if all_cached {
413                return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
414            }
415        }
416
417        let url = format!("{}/embeddings", self.base_url());
418
419        let request_body = OpenAiEmbeddingRequest {
420            model: self.config.model.clone(),
421            input: texts.to_vec(),
422            encoding_format: Some("float".to_string()),
423        };
424
425        tracing::debug!("Sending OpenAI embedding request for {} texts", texts.len());
426
427        let response = self
428            .client
429            .post(&url)
430            .header("Authorization", format!("Bearer {}", self.config.api_key))
431            .header("Content-Type", "application/json")
432            .json(&request_body)
433            .send()
434            .await?;
435
436        let status = response.status();
437        let response_text = response.text().await?;
438
439        if !status.is_success() {
440            tracing::error!("OpenAI Embedding API error: {} - {}", status, response_text);
441            return Err(anyhow!(
442                "OpenAI Embedding API request failed with status {}: {}",
443                status,
444                response_text
445            ));
446        }
447
448        let response_body: OpenAiEmbeddingResponse =
449            serde_json::from_str(&response_text).map_err(|e| {
450                anyhow!(
451                    "Failed to parse OpenAI embedding response: {} - {}",
452                    e,
453                    response_text
454                )
455            })?;
456
457        // 按 index 排序并提取向量
458        let mut embeddings: Vec<(usize, Vec<f32>)> = response_body
459            .data
460            .into_iter()
461            .map(|item| (item.index, item.embedding))
462            .collect();
463        embeddings.sort_by_key(|(idx, _)| *idx);
464        let result: Vec<Vec<f32>> = embeddings.into_iter().map(|(_, emb)| emb).collect();
465
466        // 存入缓存
467        if let Some(cache) = &self.cache {
468            for (text, embedding) in texts.iter().zip(result.iter()) {
469                cache
470                    .put("openai", &self.config.model, text, embedding.clone())
471                    .await;
472            }
473        }
474
475        Ok(result)
476    }
477
478    fn dimension(&self) -> usize {
479        match self.config.model.as_str() {
480            "text-embedding-ada-002" => 1536,
481            "text-embedding-3-small" => 1536,
482            "text-embedding-3-large" => 3072,
483            _ => DEFAULT_EMBEDDING_DIMENSION,
484        }
485    }
486
487    fn model_name(&self) -> &str {
488        &self.config.model
489    }
490
491    fn provider(&self) -> &str {
492        "openai"
493    }
494}
495
496#[derive(Serialize)]
497struct OpenAiEmbeddingRequest {
498    model: String,
499    input: Vec<String>,
500    #[serde(skip_serializing_if = "Option::is_none")]
501    encoding_format: Option<String>,
502}
503
504#[derive(Deserialize)]
505struct OpenAiEmbeddingResponse {
506    data: Vec<OpenAiEmbeddingData>,
507    #[allow(dead_code)]
508    model: String,
509    #[allow(dead_code)]
510    usage: OpenAiEmbeddingUsage,
511}
512
513#[derive(Deserialize)]
514struct OpenAiEmbeddingData {
515    embedding: Vec<f32>,
516    index: usize,
517    #[allow(dead_code)]
518    object: String,
519}
520
521#[derive(Deserialize)]
522#[allow(dead_code)]
523struct OpenAiEmbeddingUsage {
524    prompt_tokens: u32,
525    total_tokens: u32,
526}
527
528// ============================================================================
529// HuggingFace 实现
530// ============================================================================
531
532/// HuggingFace 嵌入模型
533#[derive(Debug)]
534pub struct HuggingFaceEmbeddings {
535    client: Client,
536    config: EmbeddingsConfig,
537    cache: Option<Arc<EmbeddingCache>>,
538}
539
540impl HuggingFaceEmbeddings {
541    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
542        if !config.is_valid() {
543            return Err(anyhow!("HuggingFace API not configured"));
544        }
545
546        Ok(Self {
547            client: Client::new(),
548            config,
549            cache: None,
550        })
551    }
552
553    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
554        let mut embeddings = Self::new(config)?;
555        embeddings.cache = Some(cache);
556        Ok(embeddings)
557    }
558}
559
560#[async_trait]
561impl EmbeddingModel for HuggingFaceEmbeddings {
562    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
563        // HuggingFace API 返回格式取决于模型,通常需要单独调用
564        let embeddings = self.embed_batch(&[text.to_string()]).await?;
565        embeddings
566            .into_iter()
567            .next()
568            .ok_or_else(|| anyhow!("No embedding returned from HuggingFace"))
569    }
570
571    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
572        if texts.is_empty() {
573            return Ok(Vec::new());
574        }
575
576        // 检查缓存
577        if let Some(cache) = &self.cache {
578            let cached = cache
579                .get_batch("huggingface", &self.config.model, texts)
580                .await;
581            let all_cached = cached.iter().all(|c| c.is_some());
582            if all_cached {
583                return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
584            }
585        }
586
587        let url = format!(
588            "https://api-inference.huggingface.co/pipeline/feature-extraction/{}",
589            self.config.model
590        );
591
592        tracing::debug!(
593            "Sending HuggingFace embedding request for {} texts",
594            texts.len()
595        );
596
597        let response = self
598            .client
599            .post(&url)
600            .header("Authorization", format!("Bearer {}", self.config.api_key))
601            .header("Content-Type", "application/json")
602            .json(&serde_json::json!({ "inputs": texts }))
603            .send()
604            .await?;
605
606        let status = response.status();
607        let response_text = response.text().await?;
608
609        if !status.is_success() {
610            tracing::error!("HuggingFace API error: {} - {}", status, response_text);
611            return Err(anyhow!(
612                "HuggingFace API request failed with status {}: {}",
613                status,
614                response_text
615            ));
616        }
617
618        // HuggingFace 返回格式: [[f32, f32, ...], ...] 或 [[f32], [f32], ...]
619        let embeddings: Vec<Vec<f32>> = serde_json::from_str(&response_text).map_err(|e| {
620            anyhow!(
621                "Failed to parse HuggingFace response: {} - {}",
622                e,
623                response_text
624            )
625        })?;
626
627        // 存入缓存
628        if let Some(cache) = &self.cache {
629            for (text, embedding) in texts.iter().zip(embeddings.iter()) {
630                cache
631                    .put("huggingface", &self.config.model, text, embedding.clone())
632                    .await;
633            }
634        }
635
636        Ok(embeddings)
637    }
638
639    fn dimension(&self) -> usize {
640        // 常见模型的维度
641        match self.config.model.as_str() {
642            "sentence-transformers/all-MiniLM-L6-v2" => 384,
643            "sentence-transformers/all-mpnet-base-v2" => 768,
644            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" => 384,
645            _ => self.config.dimension.unwrap_or(768),
646        }
647    }
648
649    fn model_name(&self) -> &str {
650        &self.config.model
651    }
652
653    fn provider(&self) -> &str {
654        "huggingface"
655    }
656}
657
658// ============================================================================
659// Cohere 实现
660// ============================================================================
661
662/// Cohere 嵌入模型
663#[derive(Debug)]
664pub struct CohereEmbeddings {
665    client: Client,
666    config: EmbeddingsConfig,
667    cache: Option<Arc<EmbeddingCache>>,
668}
669
670impl CohereEmbeddings {
671    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
672        if !config.is_valid() {
673            return Err(anyhow!("Cohere API not configured"));
674        }
675
676        Ok(Self {
677            client: Client::new(),
678            config,
679            cache: None,
680        })
681    }
682
683    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
684        let mut embeddings = Self::new(config)?;
685        embeddings.cache = Some(cache);
686        Ok(embeddings)
687    }
688}
689
690#[async_trait]
691impl EmbeddingModel for CohereEmbeddings {
692    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
693        let embeddings = self.embed_batch(&[text.to_string()]).await?;
694        embeddings
695            .into_iter()
696            .next()
697            .ok_or_else(|| anyhow!("No embedding returned from Cohere"))
698    }
699
700    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
701        if texts.is_empty() {
702            return Ok(Vec::new());
703        }
704
705        // 检查缓存
706        if let Some(cache) = &self.cache {
707            let cached = cache.get_batch("cohere", &self.config.model, texts).await;
708            let all_cached = cached.iter().all(|c| c.is_some());
709            if all_cached {
710                return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
711            }
712        }
713
714        let url = "https://api.cohere.ai/v1/embed";
715
716        let request_body = CohereEmbeddingRequest {
717            model: self.config.model.clone(),
718            texts: texts.to_vec(),
719            input_type: "search_document",
720            embedding_types: Some(vec!["float".to_string()]),
721        };
722
723        tracing::debug!("Sending Cohere embedding request for {} texts", texts.len());
724
725        let response = self
726            .client
727            .post(url)
728            .header("Authorization", format!("Bearer {}", self.config.api_key))
729            .header("Content-Type", "application/json")
730            .json(&request_body)
731            .send()
732            .await?;
733
734        let status = response.status();
735        let response_text = response.text().await?;
736
737        if !status.is_success() {
738            tracing::error!("Cohere API error: {} - {}", status, response_text);
739            return Err(anyhow!(
740                "Cohere API request failed with status {}: {}",
741                status,
742                response_text
743            ));
744        }
745
746        let response_body: CohereEmbeddingResponse = serde_json::from_str(&response_text)
747            .map_err(|e| anyhow!("Failed to parse Cohere response: {} - {}", e, response_text))?;
748
749        let result = response_body.embeddings.float;
750
751        // 存入缓存
752        if let Some(cache) = &self.cache {
753            for (text, embedding) in texts.iter().zip(result.iter()) {
754                cache
755                    .put("cohere", &self.config.model, text, embedding.clone())
756                    .await;
757            }
758        }
759
760        Ok(result)
761    }
762
763    fn dimension(&self) -> usize {
764        match self.config.model.as_str() {
765            "embed-english-v3.0" | "embed-english-light-v3.0" => 1024,
766            "embed-multilingual-v3.0" => 1024,
767            "embed-english-v2.0" => 4096,
768            _ => self.config.dimension.unwrap_or(1024),
769        }
770    }
771
772    fn model_name(&self) -> &str {
773        &self.config.model
774    }
775
776    fn provider(&self) -> &str {
777        "cohere"
778    }
779}
780
781#[derive(Serialize)]
782struct CohereEmbeddingRequest {
783    model: String,
784    texts: Vec<String>,
785    input_type: &'static str,
786    #[serde(skip_serializing_if = "Option::is_none")]
787    embedding_types: Option<Vec<String>>,
788}
789
790#[derive(Deserialize)]
791struct CohereEmbeddingResponse {
792    embeddings: CohereEmbeddingsData,
793    #[allow(dead_code)]
794    id: String,
795    #[allow(dead_code)]
796    text_type: String,
797}
798
799#[derive(Deserialize)]
800struct CohereEmbeddingsData {
801    float: Vec<Vec<f32>>,
802}
803
804// ============================================================================
805// 本地模型实现 (SentenceTransformers)
806// ============================================================================
807
808/// 本地 SentenceTransformers 嵌入模型
809///
810/// 注意:此实现需要 `candle` 或 `ort` 特性启用。
811/// 在纯 Rust 环境下,使用占位实现。
812pub struct LocalEmbeddings {
813    config: EmbeddingsConfig,
814    cache: Option<Arc<EmbeddingCache>>,
815    #[cfg(feature = "local-embeddings")]
816    #[allow(dead_code)]
817    model: Option<std::sync::Mutex<Box<dyn LocalModelBackend>>>,
818}
819
820impl std::fmt::Debug for LocalEmbeddings {
821    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
822        f.debug_struct("LocalEmbeddings")
823            .field("config", &self.config)
824            .field("cache", &self.cache)
825            .field("model", &"<model>")
826            .finish()
827    }
828}
829
830impl LocalEmbeddings {
831    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
832        Ok(Self {
833            config,
834            cache: None,
835            #[cfg(feature = "local-embeddings")]
836            model: None,
837        })
838    }
839
840    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
841        let mut embeddings = Self::new(config)?;
842        embeddings.cache = Some(cache);
843        Ok(embeddings)
844    }
845
846    /// 加载本地模型
847    #[cfg(feature = "local-embeddings")]
848    pub fn load_model(&mut self) -> Result<()> {
849        // 使用 candle 或 ort 加载模型
850        // 这是一个占位实现
851        tracing::info!("Loading local embedding model: {}", self.config.model);
852        Ok(())
853    }
854}
855
856#[async_trait]
857impl EmbeddingModel for LocalEmbeddings {
858    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
859        // 检查缓存
860        if let Some(cache) = &self.cache {
861            if let Some(embedding) = cache.get("local", &self.config.model, text).await {
862                return Ok(embedding);
863            }
864        }
865
866        #[cfg(feature = "local-embeddings")]
867        {
868            // 实际实现使用 candle 或 ort
869            // 这里是占位代码
870            let embedding = vec![0.0f32; self.dimension()];
871
872            if let Some(cache) = &self.cache {
873                cache
874                    .put("local", &self.config.model, text, embedding.clone())
875                    .await;
876            }
877
878            Ok(embedding)
879        }
880
881        #[cfg(not(feature = "local-embeddings"))]
882        {
883            Err(anyhow!(
884                "Local embeddings require 'local-embeddings' feature. \
885                 Enable it in Cargo.toml and ensure candle or ort is available."
886            ))
887        }
888    }
889
890    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
891        // 检查缓存
892        if let Some(cache) = &self.cache {
893            let cached = cache.get_batch("local", &self.config.model, texts).await;
894            if cached.iter().all(|c| c.is_some()) {
895                return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
896            }
897        }
898
899        #[cfg(feature = "local-embeddings")]
900        {
901            let mut results = Vec::with_capacity(texts.len());
902            for text in texts {
903                results.push(self.embed(text).await?);
904            }
905
906            // 存入缓存
907            if let Some(cache) = &self.cache {
908                for (text, embedding) in texts.iter().zip(results.iter()) {
909                    cache
910                        .put("local", &self.config.model, text, embedding.clone())
911                        .await;
912                }
913            }
914
915            Ok(results)
916        }
917
918        #[cfg(not(feature = "local-embeddings"))]
919        {
920            Err(anyhow!(
921                "Local embeddings require 'local-embeddings' feature"
922            ))
923        }
924    }
925
926    fn dimension(&self) -> usize {
927        self.config.dimension.unwrap_or(384)
928    }
929
930    fn model_name(&self) -> &str {
931        &self.config.model
932    }
933
934    fn provider(&self) -> &str {
935        "local"
936    }
937}
938
/// Backend abstraction for a locally loaded model.
///
/// Only compiled with the `local-embeddings` feature.
#[cfg(feature = "local-embeddings")]
#[allow(dead_code)]
trait LocalModelBackend: Send + Sync {
    /// Encodes `text` into an embedding vector.
    fn encode(&self, text: &str) -> Result<Vec<f32>>;
}
945
946// ============================================================================
947// 统一 Embeddings 工厂
948// ============================================================================
949
950/// 嵌入模型工厂
951pub struct EmbeddingsFactory {
952    cache: Arc<EmbeddingCache>,
953}
954
955impl EmbeddingsFactory {
956    pub fn new() -> Self {
957        Self {
958            cache: Arc::new(EmbeddingCache::default_cache()),
959        }
960    }
961
962    pub fn with_cache(cache: Arc<EmbeddingCache>) -> Self {
963        Self { cache }
964    }
965
966    /// 创建嵌入模型实例
967    pub fn create(&self, config: EmbeddingsConfig) -> Result<Box<dyn EmbeddingModel>> {
968        match config.provider {
969            EmbeddingProvider::OpenAI => Ok(Box::new(OpenAIEmbeddings::with_cache(
970                config,
971                self.cache.clone(),
972            )?)),
973            EmbeddingProvider::HuggingFace => Ok(Box::new(HuggingFaceEmbeddings::with_cache(
974                config,
975                self.cache.clone(),
976            )?)),
977            EmbeddingProvider::Cohere => Ok(Box::new(CohereEmbeddings::with_cache(
978                config,
979                self.cache.clone(),
980            )?)),
981            EmbeddingProvider::Local => Ok(Box::new(LocalEmbeddings::with_cache(
982                config,
983                self.cache.clone(),
984            )?)),
985            EmbeddingProvider::Mock => {
986                let dimension = config.dimension.unwrap_or(DEFAULT_EMBEDDING_DIMENSION);
987                #[cfg(any(feature = "mock", test))]
988                {
989                    Ok(Box::new(MockEmbeddingModel::with_name(
990                        dimension,
991                        &config.model,
992                    )))
993                }
994                #[cfg(not(any(feature = "mock", test)))]
995                {
996                    // 当 mock feature 未启用时,使用 LocalEmbeddings 作为回退
997                    let local_config = EmbeddingsConfig::local(&config.model, Some(dimension));
998                    Ok(Box::new(LocalEmbeddings::new(local_config)?))
999                }
1000            }
1001        }
1002    }
1003
1004    /// 创建安全的嵌入模型实例
1005    ///
1006    /// 如果指定的配置无效,自动回退到 Mock 模型。
1007    /// 这确保了即使在未配置环境下也能安全返回一个可用实例。
1008    pub fn create_safe(&self, config: EmbeddingsConfig) -> Box<dyn EmbeddingModel> {
1009        if config.is_valid() {
1010            self.create(config)
1011                .unwrap_or_else(|_| self.create_mock_default())
1012        } else {
1013            self.create_mock_default()
1014        }
1015    }
1016
1017    /// 创建默认 Mock 模型
1018    fn create_mock_default(&self) -> Box<dyn EmbeddingModel> {
1019        #[cfg(any(feature = "mock", test))]
1020        {
1021            Box::new(MockEmbeddingModel::new(DEFAULT_EMBEDDING_DIMENSION))
1022        }
1023        #[cfg(not(any(feature = "mock", test)))]
1024        {
1025            // 如果没有 mock feature,使用 LocalEmbeddings 作为安全回退
1026            let config = EmbeddingsConfig::local("fallback", Some(DEFAULT_EMBEDDING_DIMENSION));
1027            Box::new(LocalEmbeddings::new(config).expect("Local embeddings should always work"))
1028        }
1029    }
1030
1031    /// 创建 OpenAI 嵌入模型
1032    pub fn openai(&self) -> Result<Box<dyn EmbeddingModel>> {
1033        let config = EmbeddingsConfig::openai_from_env()?;
1034        self.create(config)
1035    }
1036
1037    /// 创建 HuggingFace 嵌入模型
1038    pub fn huggingface(&self) -> Result<Box<dyn EmbeddingModel>> {
1039        let config = EmbeddingsConfig::huggingface_from_env()?;
1040        self.create(config)
1041    }
1042
1043    /// 创建 Cohere 嵌入模型
1044    pub fn cohere(&self) -> Result<Box<dyn EmbeddingModel>> {
1045        let config = EmbeddingsConfig::cohere_from_env()?;
1046        self.create(config)
1047    }
1048
1049    /// 创建本地嵌入模型
1050    pub fn local(&self, model: &str, dimension: Option<usize>) -> Result<Box<dyn EmbeddingModel>> {
1051        let config = EmbeddingsConfig::local(model, dimension);
1052        self.create(config)
1053    }
1054
1055    /// 创建 Mock 嵌入模型(仅测试/开发使用)
1056    ///
1057    /// **安全默认值**: Mock 模型返回零向量,不调用任何外部 API。
1058    /// 这是在未配置环境下的安全回退选项。
1059    #[cfg(any(feature = "mock", test))]
1060    pub fn mock(&self, dimension: usize) -> Box<dyn EmbeddingModel> {
1061        Box::new(MockEmbeddingModel::new(dimension))
1062    }
1063
1064    /// 获取缓存实例
1065    pub fn cache(&self) -> Arc<EmbeddingCache> {
1066        self.cache.clone()
1067    }
1068}
1069
1070impl Default for EmbeddingsFactory {
1071    fn default() -> Self {
1072        Self::new()
1073    }
1074}
1075
1076// ============================================================================
1077// Mock 嵌入模型(仅测试/开发使用)
1078// ============================================================================
1079
/// Mock embedding model.
///
/// Used in tests or as a fallback; it yields zero vectors of a fixed
/// dimension.
///
/// **Note**: this type only exists when the `mock` feature is enabled or in
/// test builds. Production code should not use it.
#[cfg(any(feature = "mock", test))]
#[derive(Debug, Clone)]
pub struct MockEmbeddingModel {
    /// Length of every vector the model returns.
    dimension: usize,
    /// Name reported through `EmbeddingModel::model_name`.
    model_name: String,
}
1091
#[cfg(any(feature = "mock", test))]
impl MockEmbeddingModel {
    /// Creates a mock model with the default name `"mock-embedding"`.
    pub fn new(dimension: usize) -> Self {
        Self::with_name(dimension, "mock-embedding")
    }

    /// Creates a mock model that reports a caller-supplied model name.
    pub fn with_name(dimension: usize, model_name: impl Into<String>) -> Self {
        let model_name = model_name.into();
        Self {
            dimension,
            model_name,
        }
    }
}
1110
#[cfg(any(feature = "mock", test))]
#[async_trait]
impl EmbeddingModel for MockEmbeddingModel {
    /// Returns an all-zero vector of `self.dimension` elements; the input text is ignored.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        let zeros = vec![0.0_f32; self.dimension];
        Ok(zeros)
    }

    /// Returns one all-zero vector per input text, in input order.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let zeros = vec![0.0_f32; self.dimension];
        Ok(vec![zeros; texts.len()])
    }

    /// Length of every vector this model produces.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Name supplied at construction time.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Always `"mock"`.
    fn provider(&self) -> &str {
        "mock"
    }
}
1134
1135// ============================================================================
1136// 向后兼容:保留原有 Embeddings 类型别名
1137// ============================================================================
1138
/// Backward-compatible `Embeddings` type.
///
/// Kept as an alias so existing code that refers to `Embeddings` keeps
/// compiling; it resolves to the OpenAI implementation by default.
pub type Embeddings = OpenAIEmbeddings;
1143
1144// ============================================================================
1145// 测试
1146// ============================================================================
1147
#[cfg(test)]
mod tests {
    use super::*;

    // ==========================================================================
    // Environment isolation helpers
    // ==========================================================================

    /// Process-wide lock taken by every test that touches environment variables.
    ///
    /// The test harness runs tests on several threads by default, and the
    /// environment is shared by the whole process. Without this lock one test
    /// could remove `OPENAI_API_KEY` while another is between its `set_var`
    /// and the code that reads the key, producing intermittent failures.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// RAII guard that serializes environment mutation and undoes it on drop.
    ///
    /// Every variable changed through the guard has its previous value
    /// recorded once; dropping the guard (including during a panic unwind
    /// caused by a failed assertion) restores those values before the lock is
    /// released.
    struct EnvGuard {
        /// Variables touched so far, paired with the value they had before.
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
        /// Held for the lifetime of the guard; released after `Drop::drop` runs.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Blocks until no other test is mutating the environment.
        fn acquire() -> Self {
            // A poisoned lock only means another test panicked; the `()`
            // payload cannot be left in a bad state, so keep going.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            Self {
                saved: Vec::new(),
                _lock: lock,
            }
        }

        /// Records the current value of `key` the first time it is touched.
        fn remember(&mut self, key: &'static str) {
            if !self.saved.iter().any(|(known, _)| *known == key) {
                self.saved.push((key, std::env::var_os(key)));
            }
        }

        /// Sets `key` to `value` for the lifetime of the guard.
        fn set(&mut self, key: &'static str, value: &str) {
            self.remember(key);
            std::env::set_var(key, value);
        }

        /// Removes `key` for the lifetime of the guard.
        fn unset(&mut self, key: &'static str) {
            self.remember(key);
            std::env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, previous) in self.saved.drain(..) {
                match previous {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    // ==========================================================================
    // Cache tests
    // ==========================================================================

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = EmbeddingCache::new(100, 3600);

        // put followed by get returns the stored vector
        let embedding = vec![0.1f32, 0.2, 0.3];
        cache
            .put("openai", "test-model", "hello", embedding.clone())
            .await;

        let cached = cache.get("openai", "test-model", "hello").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap(), embedding);

        // cache miss
        let not_cached = cache.get("openai", "test-model", "not-exists").await;
        assert!(not_cached.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = EmbeddingCache::new(100, 3600);

        let texts: Vec<String> = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let embeddings: Vec<Vec<f32>> = texts.iter().map(|t| vec![t.len() as f32]).collect();

        for (text, emb) in texts.iter().zip(embeddings.iter()) {
            cache.put("test", "model", text, emb.clone()).await;
        }

        let cached = cache.get_batch("test", "model", &texts).await;
        assert!(cached.iter().all(|c| c.is_some()));
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = EmbeddingCache::new(100, 3600);

        cache.put("test", "model", "a", vec![1.0f32]).await;
        cache.put("test", "model", "b", vec![2.0]).await;

        let _ = cache.get("test", "model", "a").await;
        let _ = cache.get("test", "model", "a").await;

        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.total_access, 2);
    }

    // ==========================================================================
    // Configuration tests
    // ==========================================================================

    #[test]
    fn test_config_openai_from_env() {
        let mut env = EnvGuard::acquire();
        env.set("OPENAI_API_KEY", "test_key");
        env.unset("OPENAI_BASE_URL");
        env.unset("OPENAI_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::openai_from_env().unwrap();
        assert_eq!(config.api_key, "test_key");
        assert_eq!(config.model, DEFAULT_EMBEDDING_MODEL);
    }

    #[test]
    fn test_config_huggingface_from_env() {
        let mut env = EnvGuard::acquire();
        env.set("HUGGINGFACE_API_KEY", "hf_test");
        env.unset("HUGGINGFACE_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::huggingface_from_env().unwrap();
        assert_eq!(config.api_key, "hf_test");
        assert!(config.model.contains("sentence-transformers"));
    }

    #[test]
    fn test_config_cohere_from_env() {
        let mut env = EnvGuard::acquire();
        env.set("COHERE_API_KEY", "cohere_test");
        env.unset("COHERE_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::cohere_from_env().unwrap();
        assert_eq!(config.api_key, "cohere_test");
        assert!(config.model.starts_with("embed-"));
    }

    #[test]
    fn test_config_local() {
        let config = EmbeddingsConfig::local("all-MiniLM-L6-v2", Some(384));
        assert_eq!(config.provider, EmbeddingProvider::Local);
        assert!(config.api_key.is_empty());
        assert!(config.is_valid()); // local models need no API key
    }

    // ==========================================================================
    // Dimension tests
    // ==========================================================================

    #[test]
    fn test_openai_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-ada-002".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1536);

        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-3-large".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 3072);
    }

    #[test]
    fn test_huggingface_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::HuggingFace,
            api_key: "test".to_string(),
            base_url: None,
            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimension: None,
        };
        let embeddings = HuggingFaceEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 384);
    }

    #[test]
    fn test_cohere_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Cohere,
            api_key: "test".to_string(),
            base_url: None,
            model: "embed-english-v3.0".to_string(),
            dimension: None,
        };
        let embeddings = CohereEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1024);
    }

    // ==========================================================================
    // Factory tests
    // ==========================================================================

    #[test]
    fn test_factory_create_openai() {
        let mut env = EnvGuard::acquire();
        env.set("OPENAI_API_KEY", "test_key");

        let factory = EmbeddingsFactory::new();
        let model = factory.openai().unwrap();
        assert_eq!(model.provider(), "openai");
    }

    #[test]
    fn test_factory_create_local() {
        let factory = EmbeddingsFactory::new();
        let model = factory.local("test-model", Some(384)).unwrap();
        assert_eq!(model.provider(), "local");
        assert_eq!(model.dimension(), 384);
    }

    #[test]
    fn test_factory_create_mock() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(512);
        assert_eq!(model.provider(), "mock");
        assert_eq!(model.dimension(), 512);
    }

    #[test]
    fn test_factory_create_safe_with_invalid_config() {
        let factory = EmbeddingsFactory::new();
        // An OpenAI configuration with an empty api_key is invalid
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: String::new(),
            base_url: None,
            model: "test".to_string(),
            dimension: None,
        };
        let model = factory.create_safe(config);
        // must fall back to the mock model
        assert_eq!(model.provider(), "mock");
    }

    #[test]
    fn test_factory_create_safe_with_valid_config() {
        let mut env = EnvGuard::acquire();
        env.set("OPENAI_API_KEY", "test_key");

        let factory = EmbeddingsFactory::new();
        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let model = factory.create_safe(config);
        assert_eq!(model.provider(), "openai");
    }

    // ==========================================================================
    // Safe-default tests
    // ==========================================================================

    #[test]
    fn test_config_default_is_safe() {
        let config = EmbeddingsConfig::default();
        // the default configuration must select the Mock provider
        assert_eq!(config.provider, EmbeddingProvider::Mock);
        // the Mock provider needs no API key, so the config must be valid
        assert!(config.is_valid());
    }

    #[test]
    fn test_provider_mock_is_valid() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Mock,
            api_key: String::new(),
            base_url: None,
            model: "mock-test".to_string(),
            dimension: Some(256),
        };
        assert!(config.is_valid());
    }

    #[test]
    fn test_embeddings_factory_mock_default_dimension() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(DEFAULT_EMBEDDING_DIMENSION);
        assert_eq!(model.dimension(), DEFAULT_EMBEDDING_DIMENSION);
    }

    // ==========================================================================
    // Backward-compatibility tests
    // ==========================================================================

    #[test]
    fn test_backward_compatible_embeddings() {
        let mut env = EnvGuard::acquire();
        env.set("OPENAI_API_KEY", "test_key");

        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let embeddings = Embeddings::new(config).unwrap();
        assert_eq!(embeddings.provider(), "openai");
    }
}