Skip to main content

sh_layer3/
retriever_engine.rs

1//! # Retriever Engine
2//!
3//! 检索引擎:向量相似度检索和 RAG 支持。
4//!
5//! ## 功能
6//!
7//! - 文档索引与检索
8//! - 多种分块策略(固定大小、段落、代码)
9//! - 混合检索(向量 + 关键词)
10//! - RAG Pipeline(带重排序)
11//! - OpenAI Embeddings 集成
12
13use crate::types::Layer3Result;
14use async_trait::async_trait;
15use parking_lot::RwLock;
16use sh_layer2::generate_short_id;
17use std::collections::HashMap;
18use std::sync::Arc;
19
20/// 检索引擎 trait
21///
22/// 提供向量相似度检索能力。
23#[async_trait]
24pub trait RetrieverEngine: Send + Sync {
25    /// 索引文档
26    async fn index(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>>;
27
28    /// 检索相似文档
29    async fn retrieve(&self, query: &str, top_k: usize) -> Layer3Result<Vec<RetrievalResult>>;
30
31    /// 混合检索(向量 + 关键词)
32    async fn hybrid_retrieve(
33        &self,
34        query: &str,
35        top_k: usize,
36    ) -> Layer3Result<Vec<RetrievalResult>>;
37
38    /// 带配置的混合检索
39    async fn hybrid_retrieve_with_config(
40        &self,
41        query: &str,
42        top_k: usize,
43        config: &HybridSearchConfig,
44    ) -> Layer3Result<Vec<RetrievalResult>> {
45        let _ = config;
46        self.hybrid_retrieve(query, top_k).await
47    }
48
49    /// 带过滤条件的检索
50    async fn retrieve_with_filter(
51        &self,
52        query: &str,
53        top_k: usize,
54        filter: Option<crate::vector_store::MetadataFilter>,
55    ) -> Layer3Result<Vec<RetrievalResult>> {
56        let _ = filter;
57        self.retrieve(query, top_k).await
58    }
59
60    /// 删除文档
61    async fn delete(&self, doc_ids: &[String]) -> Layer3Result<bool>;
62
63    /// 清空索引
64    async fn clear(&self) -> Layer3Result<bool>;
65
66    /// 获取文档数量
67    async fn count(&self) -> Layer3Result<usize>;
68}
69
70/// 文档结构
71#[derive(Debug, Clone)]
72pub struct Document {
73    /// 文档 ID(可选,自动生成)
74    pub id: Option<String>,
75    /// 文档内容
76    pub content: String,
77    /// 元数据
78    pub metadata: HashMap<String, serde_json::Value>,
79    /// 来源(文件路径、URL 等)
80    pub source: Option<String>,
81}
82
83impl Document {
84    pub fn new(content: impl Into<String>) -> Self {
85        Self {
86            id: None,
87            content: content.into(),
88            metadata: HashMap::new(),
89            source: None,
90        }
91    }
92
93    pub fn with_source(mut self, source: impl Into<String>) -> Self {
94        self.source = Some(source.into());
95        self
96    }
97
98    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
99        self.metadata.insert(key.into(), value);
100        self
101    }
102}
103
104/// 检索结果
105#[derive(Debug, Clone)]
106pub struct RetrievalResult {
107    /// 文档 ID
108    pub doc_id: String,
109    /// 文档内容
110    pub content: String,
111    /// 相似度分数 (0.0-1.0)
112    pub score: f32,
113    /// 元数据
114    pub metadata: HashMap<String, serde_json::Value>,
115    /// 来源
116    pub source: Option<String>,
117}
118
119// ============================================================================
120// Hybrid Search Configuration
121// ============================================================================
122
/// Weights used to blend the vector-similarity score and the keyword score in
/// hybrid search.
///
/// Values built through [`HybridWeights::new`] are normalized so that
/// `vector + keyword == 1.0`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HybridWeights {
    /// Weight of the vector-similarity score.
    pub vector: f32,
    /// Weight of the keyword score.
    pub keyword: f32,
}

impl HybridWeights {
    /// Creates a weight pair normalized to sum to `1.0`.
    ///
    /// Negative and NaN inputs are treated as `0.0`. When the resulting weights
    /// have no positive, finite sum (e.g. `new(0.0, 0.0)`), there is no ratio to
    /// preserve. In that case [`HybridWeights::default_weights`] is returned
    /// instead of the NaN weights a plain division would produce.
    pub fn new(vector: f32, keyword: f32) -> Self {
        // `f32::max` returns the non-NaN operand, so NaN also maps to 0.0 here.
        let vector = vector.max(0.0);
        let keyword = keyword.max(0.0);
        let total = vector + keyword;
        if !(total.is_finite() && total > 0.0) {
            return Self::default_weights();
        }
        Self {
            vector: vector / total,
            keyword: keyword / total,
        }
    }

    /// Default weights: 70% vector + 30% keyword.
    pub fn default_weights() -> Self {
        Self {
            vector: 0.7,
            keyword: 0.3,
        }
    }

    /// Vector search only.
    pub fn vector_only() -> Self {
        Self {
            vector: 1.0,
            keyword: 0.0,
        }
    }

    /// Keyword search only.
    pub fn keyword_only() -> Self {
        Self {
            vector: 0.0,
            keyword: 1.0,
        }
    }

    /// Equal weights for both signals.
    pub fn balanced() -> Self {
        Self {
            vector: 0.5,
            keyword: 0.5,
        }
    }
}

impl Default for HybridWeights {
    /// Same as [`HybridWeights::default_weights`].
    fn default() -> Self {
        Self::default_weights()
    }
}
180
181/// 混合检索配置
182#[derive(Debug, Clone)]
183pub struct HybridSearchConfig {
184    /// 权重配置
185    pub weights: HybridWeights,
186    /// 是否启用短语匹配
187    pub phrase_matching: bool,
188    /// 是否启用 RRIF 重排序
189    pub use_rrif: bool,
190    /// RRIF 参数 K(控制排名衰减)
191    pub rrif_k: f32,
192    /// 候选结果数量倍数(top_k * candidates_multiplier)
193    pub candidates_multiplier: usize,
194}
195
196impl HybridSearchConfig {
197    pub fn new() -> Self {
198        Self {
199            weights: HybridWeights::default(),
200            phrase_matching: true,
201            use_rrif: true,
202            rrif_k: 60.0,
203            candidates_multiplier: 2,
204        }
205    }
206
207    pub fn with_weights(mut self, weights: HybridWeights) -> Self {
208        self.weights = weights;
209        self
210    }
211
212    pub fn with_phrase_matching(mut self, enabled: bool) -> Self {
213        self.phrase_matching = enabled;
214        self
215    }
216
217    pub fn with_rrif(mut self, enabled: bool, k: f32) -> Self {
218        self.use_rrif = enabled;
219        self.rrif_k = k;
220        self
221    }
222}
223
224impl Default for HybridSearchConfig {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
230/// Embedding 模型 trait
231#[async_trait]
232pub trait EmbeddingModel: Send + Sync {
233    /// 生成文本嵌入向量
234    async fn embed(&self, text: &str) -> Layer3Result<Vec<f32>>;
235
236    /// 批量生成嵌入向量
237    async fn embed_batch(&self, texts: &[String]) -> Layer3Result<Vec<Vec<f32>>>;
238
239    /// 获取向量维度
240    fn dimension(&self) -> usize;
241
242    /// 模型名称
243    fn model_name(&self) -> &str;
244}
245
246/// 分块策略 trait
247pub trait ChunkingStrategy: Send + Sync {
248    /// 分块文档
249    fn chunk(&self, document: &Document) -> Vec<Chunk>;
250}
251
252/// 文档分块
253#[derive(Debug, Clone)]
254pub struct Chunk {
255    /// 分块 ID
256    pub id: String,
257    /// 文档 ID
258    pub doc_id: String,
259    /// 分块内容
260    pub content: String,
261    /// 在原文中的位置
262    pub position: ChunkPosition,
263    /// 元数据
264    pub metadata: HashMap<String, serde_json::Value>,
265}
266
/// Location of a [`Chunk`] inside its document.
///
/// `start`/`end` are byte offsets into the document's `content`, not character
/// counts; the built-in chunkers compute them with `str::len`.
#[derive(Clone, Copy, Debug)]
pub struct ChunkPosition {
    /// Start offset (inclusive).
    pub start: usize,
    /// End offset (exclusive).
    pub end: usize,
    /// Zero-based index of the chunk.
    pub index: usize,
    /// Total number of chunks produced for the document.
    pub total: usize,
}
279
/// Splits text into fixed-size windows that overlap by a configurable amount.
///
/// Both sizes are measured in bytes of the UTF-8 content. The defaults are
/// `chunk_size = 500` and `overlap = 50`.
#[derive(Clone, Debug)]
pub struct FixedSizeChunker {
    /// Maximum window length (bytes).
    pub chunk_size: usize,
    /// Bytes shared between consecutive windows.
    pub overlap: usize,
}

impl FixedSizeChunker {
    /// Creates a chunker with the given window size and overlap.
    pub fn new(chunk_size: usize, overlap: usize) -> Self {
        FixedSizeChunker {
            overlap,
            chunk_size,
        }
    }
}

impl Default for FixedSizeChunker {
    fn default() -> Self {
        FixedSizeChunker::new(500, 50)
    }
}
306
307impl ChunkingStrategy for FixedSizeChunker {
308    fn chunk(&self, document: &Document) -> Vec<Chunk> {
309        let content = &document.content;
310        if content.len() <= self.chunk_size {
311            return vec![Chunk {
312                id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
313                doc_id: document.id.clone().unwrap_or_default(),
314                content: content.clone(),
315                position: ChunkPosition {
316                    start: 0,
317                    end: content.len(),
318                    index: 0,
319                    total: 1,
320                },
321                metadata: document.metadata.clone(),
322            }];
323        }
324
325        let mut chunks = Vec::new();
326        let mut start = 0;
327        let mut index = 0;
328
329        while start < content.len() {
330            let end = (start + self.chunk_size).min(content.len());
331            chunks.push(Chunk {
332                id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
333                doc_id: document.id.clone().unwrap_or_default(),
334                content: content[start..end].to_string(),
335                position: ChunkPosition {
336                    start,
337                    end,
338                    index,
339                    total: 0, // 将在最后更新
340                },
341                metadata: document.metadata.clone(),
342            });
343            // 防止死循环:到达末尾时直接设置 start = end
344            start = if end < content.len() {
345                end.saturating_sub(self.overlap)
346            } else {
347                end
348            };
349            index += 1;
350        }
351
352        let total = chunks.len();
353        for chunk in &mut chunks {
354            chunk.position.total = total;
355        }
356
357        chunks
358    }
359}
360
361// ============================================================================
362// Paragraph Chunking Strategy
363// ============================================================================
364
/// Paragraph-oriented chunking.
///
/// Splits on line breaks and packs consecutive non-blank lines into chunks, so
/// that chunk boundaries fall between lines rather than inside them.
#[derive(Clone, Debug)]
pub struct ParagraphChunker {
    max_chunk_size: usize,
    min_chunk_size: usize,
}

impl ParagraphChunker {
    /// Creates a chunker with the given maximum and minimum chunk sizes (bytes).
    pub fn new(max_chunk_size: usize, min_chunk_size: usize) -> Self {
        ParagraphChunker {
            min_chunk_size,
            max_chunk_size,
        }
    }
}

impl Default for ParagraphChunker {
    /// `max_chunk_size = 1000`, `min_chunk_size = 100`.
    fn default() -> Self {
        ParagraphChunker::new(1000, 100)
    }
}
391
392impl ChunkingStrategy for ParagraphChunker {
393    fn chunk(&self, document: &Document) -> Vec<Chunk> {
394        let content = &document.content;
395        let paragraphs: Vec<&str> = content
396            .split('\n')
397            .filter(|p| !p.trim().is_empty())
398            .collect();
399
400        if paragraphs.is_empty() {
401            return vec![Chunk {
402                id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
403                doc_id: document.id.clone().unwrap_or_default(),
404                content: content.clone(),
405                position: ChunkPosition {
406                    start: 0,
407                    end: content.len(),
408                    index: 0,
409                    total: 1,
410                },
411                metadata: document.metadata.clone(),
412            }];
413        }
414
415        let mut chunks = Vec::new();
416        let mut current_chunk = String::new();
417        let mut start = 0;
418        let mut index = 0;
419
420        for paragraph in paragraphs {
421            if current_chunk.len() + paragraph.len() < self.max_chunk_size {
422                if !current_chunk.is_empty() {
423                    current_chunk.push('\n');
424                }
425                current_chunk.push_str(paragraph);
426            } else {
427                if current_chunk.len() >= self.min_chunk_size {
428                    let end = start + current_chunk.len();
429                    chunks.push(Chunk {
430                        id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
431                        doc_id: document.id.clone().unwrap_or_default(),
432                        content: current_chunk.clone(),
433                        position: ChunkPosition {
434                            start,
435                            end,
436                            index,
437                            total: 0,
438                        },
439                        metadata: document.metadata.clone(),
440                    });
441                    start = end;
442                    index += 1;
443                }
444                current_chunk = paragraph.to_string();
445            }
446        }
447
448        if current_chunk.len() >= self.min_chunk_size {
449            chunks.push(Chunk {
450                id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
451                doc_id: document.id.clone().unwrap_or_default(),
452                content: current_chunk,
453                position: ChunkPosition {
454                    start,
455                    end: content.len(),
456                    index,
457                    total: 0,
458                },
459                metadata: document.metadata.clone(),
460            });
461        }
462
463        let total = chunks.len().max(1);
464        for chunk in &mut chunks {
465            chunk.position.total = total;
466        }
467
468        if chunks.is_empty() {
469            vec![Chunk {
470                id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
471                doc_id: document.id.clone().unwrap_or_default(),
472                content: content.clone(),
473                position: ChunkPosition {
474                    start: 0,
475                    end: content.len(),
476                    index: 0,
477                    total: 1,
478                },
479                metadata: document.metadata.clone(),
480            }]
481        } else {
482            chunks
483        }
484    }
485}
486
487// ============================================================================
488// Recursive Chunking Strategy
489// ============================================================================
490
/// Recursive chunking.
///
/// Tries separators from coarsest to finest (`"\n\n\n"`, `"\n\n"`, `"\n"`,
/// `". "`, `" "`). The final empty separator stands for a plain fixed-size
/// split.
#[derive(Clone, Debug)]
pub struct RecursiveChunker {
    max_chunk_size: usize,
    separators: Vec<String>,
}

impl RecursiveChunker {
    /// Creates a chunker targeting chunks of at most `max_chunk_size` bytes.
    pub fn new(max_chunk_size: usize) -> Self {
        let separators = ["\n\n\n", "\n\n", "\n", ". ", " ", ""]
            .iter()
            .map(|sep| sep.to_string())
            .collect();
        RecursiveChunker {
            max_chunk_size,
            separators,
        }
    }
}

impl Default for RecursiveChunker {
    /// Uses a 1000-byte maximum chunk size.
    fn default() -> Self {
        RecursiveChunker::new(1000)
    }
}
521
522impl ChunkingStrategy for RecursiveChunker {
523    fn chunk(&self, document: &Document) -> Vec<Chunk> {
524        self._recursive_split(document, &document.content, 0, 0)
525    }
526}
527
528impl RecursiveChunker {
529    fn _recursive_split(
530        &self,
531        document: &Document,
532        text: &str,
533        start_offset: usize,
534        initial_index: usize,
535    ) -> Vec<Chunk> {
536        if text.len() <= self.max_chunk_size {
537            return vec![Chunk {
538                id: format!(
539                    "{}-{}",
540                    document.id.as_deref().unwrap_or("doc"),
541                    initial_index
542                ),
543                doc_id: document.id.clone().unwrap_or_default(),
544                content: text.to_string(),
545                position: ChunkPosition {
546                    start: start_offset,
547                    end: start_offset + text.len(),
548                    index: initial_index,
549                    total: 1,
550                },
551                metadata: document.metadata.clone(),
552            }];
553        }
554
555        for separator in &self.separators {
556            if separator.is_empty() {
557                let mut chunks = Vec::new();
558                let mut start = 0;
559                let mut index = initial_index;
560
561                while start < text.len() {
562                    let end = (start + self.max_chunk_size).min(text.len());
563                    chunks.push(Chunk {
564                        id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
565                        doc_id: document.id.clone().unwrap_or_default(),
566                        content: text[start..end].to_string(),
567                        position: ChunkPosition {
568                            start: start_offset + start,
569                            end: start_offset + end,
570                            index,
571                            total: 0,
572                        },
573                        metadata: document.metadata.clone(),
574                    });
575                    start = end;
576                    index += 1;
577                }
578
579                let total = chunks.len();
580                for chunk in &mut chunks {
581                    chunk.position.total = total;
582                }
583                return chunks;
584            }
585
586            if text.contains(separator) {
587                let parts: Vec<&str> = text.split(separator).collect();
588                let mut chunks = Vec::new();
589                let mut current_chunk = String::new();
590                let mut current_start = start_offset;
591                let mut index = initial_index;
592
593                for (i, part) in parts.iter().enumerate() {
594                    let part_with_sep = if i < parts.len() - 1 {
595                        format!("{}{}", part, separator)
596                    } else {
597                        part.to_string()
598                    };
599
600                    if current_chunk.len() + part_with_sep.len() <= self.max_chunk_size {
601                        current_chunk.push_str(&part_with_sep);
602                    } else {
603                        if !current_chunk.is_empty() {
604                            chunks.push(Chunk {
605                                id: format!(
606                                    "{}-{}",
607                                    document.id.as_deref().unwrap_or("doc"),
608                                    index
609                                ),
610                                doc_id: document.id.clone().unwrap_or_default(),
611                                content: current_chunk.clone(),
612                                position: ChunkPosition {
613                                    start: current_start,
614                                    end: current_start + current_chunk.len(),
615                                    index,
616                                    total: 0,
617                                },
618                                metadata: document.metadata.clone(),
619                            });
620                            current_start += current_chunk.len();
621                            index += 1;
622                        }
623
624                        if part_with_sep.len() > self.max_chunk_size {
625                            let sub_chunks = self._recursive_split(
626                                document,
627                                &part_with_sep,
628                                current_start,
629                                index,
630                            );
631                            for sub in sub_chunks {
632                                current_start = sub.position.end;
633                                index += 1;
634                                chunks.push(sub);
635                            }
636                        } else {
637                            current_chunk = part_with_sep;
638                        }
639                    }
640                }
641
642                if !current_chunk.is_empty() {
643                    chunks.push(Chunk {
644                        id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
645                        doc_id: document.id.clone().unwrap_or_default(),
646                        content: current_chunk,
647                        position: ChunkPosition {
648                            start: current_start,
649                            end: start_offset + text.len(),
650                            index,
651                            total: 0,
652                        },
653                        metadata: document.metadata.clone(),
654                    });
655                }
656
657                let total = chunks.len().max(1);
658                for chunk in &mut chunks {
659                    chunk.position.total = total;
660                }
661                return chunks;
662            }
663        }
664
665        vec![Chunk {
666            id: format!(
667                "{}-{}",
668                document.id.as_deref().unwrap_or("doc"),
669                initial_index
670            ),
671            doc_id: document.id.clone().unwrap_or_default(),
672            content: text.to_string(),
673            position: ChunkPosition {
674                start: start_offset,
675                end: start_offset + text.len(),
676                index: initial_index,
677                total: 1,
678            },
679            metadata: document.metadata.clone(),
680        }]
681    }
682}
683
684// ============================================================================
685// Default Retriever Engine Implementation
686// ============================================================================
687
688use crate::vector_store::{VectorItem, VectorStore};
689
690/// 默认检索引擎实现
691///
692/// 结合 Embedding 模型、分块策略和向量存储提供完整的 RAG 功能。
693pub struct DefaultRetrieverEngine<VS, EM, CS>
694where
695    VS: VectorStore,
696    EM: EmbeddingModel,
697    CS: ChunkingStrategy,
698{
699    /// 向量存储
700    vector_store: VS,
701    /// Embedding 模型
702    embedding_model: EM,
703    /// 分块策略
704    chunking_strategy: CS,
705    /// 文档索引(文档 ID -> 分块 ID 列表)
706    doc_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
707    /// 分块内容缓存(分块 ID -> 内容)
708    chunk_cache: Arc<RwLock<HashMap<String, String>>>,
709}
710
711impl<VS, EM, CS> DefaultRetrieverEngine<VS, EM, CS>
712where
713    VS: VectorStore,
714    EM: EmbeddingModel,
715    CS: ChunkingStrategy,
716{
717    /// 创建新的检索引擎
718    pub fn new(vector_store: VS, embedding_model: EM, chunking_strategy: CS) -> Self {
719        Self {
720            vector_store,
721            embedding_model,
722            chunking_strategy,
723            doc_index: Arc::new(RwLock::new(HashMap::new())),
724            chunk_cache: Arc::new(RwLock::new(HashMap::new())),
725        }
726    }
727
728    /// 提取关键词(分词 + 去停用词)
729    fn extract_keywords(&self, query: &str) -> Vec<String> {
730        let words: Vec<String> = query
731            .to_lowercase()
732            .split_whitespace()
733            .map(|s| s.to_string())
734            .collect();
735
736        let stop_words = std::collections::HashSet::from([
737            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
738            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
739            "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
740            "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
741            "below", "between", "under", "again", "further", "then", "once", "here", "there",
742            "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
743            "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s",
744            "t", "just", "and", "but", "if", "or", "because", "until", "while", "although",
745        ]);
746
747        words
748            .into_iter()
749            .filter(|w| !stop_words.contains(w.as_str()) && w.len() > 1)
750            .collect()
751    }
752
753    /// 计算关键词分数(BM25 风格)
754    fn compute_keyword_score(
755        &self,
756        query_keywords: &[String],
757        content: &str,
758        config: &HybridSearchConfig,
759    ) -> f32 {
760        if query_keywords.is_empty() {
761            return 0.0;
762        }
763
764        let content_lower = content.to_lowercase();
765
766        // 短语匹配奖励
767        let mut phrase_bonus: f32 = 0.0;
768        if config.phrase_matching {
769            for keyword in query_keywords {
770                if content_lower.contains(keyword) {
771                    phrase_bonus += 0.1;
772                }
773            }
774            phrase_bonus = phrase_bonus.min(0.3);
775        }
776
777        // 计算关键词匹配数量
778        let matched_keywords = query_keywords
779            .iter()
780            .filter(|kw| content_lower.contains(kw.as_str()))
781            .count();
782
783        // BM25 风格的饱和函数
784        let k1 = 1.2;
785        let content_len = content.len() as f32;
786        let avg_len = 500.0;
787        let len_norm = 1.0 - 0.75 + 0.75 * (content_len / avg_len);
788
789        let bm25_score =
790            (matched_keywords as f32 * (k1 + 1.0)) / (matched_keywords as f32 + k1 * len_norm);
791
792        // 归一化到 [0, 1]
793        let normalized_score = bm25_score / (query_keywords.len() as f32 + k1);
794        let normalized_score = normalized_score.min(1.0);
795
796        normalized_score + phrase_bonus
797    }
798
799    /// 仅关键词搜索
800    async fn keyword_only_search(
801        &self,
802        query: &str,
803        candidates: Vec<RetrievalResult>,
804        top_k: usize,
805        config: &HybridSearchConfig,
806    ) -> Layer3Result<Vec<RetrievalResult>> {
807        let query_keywords = self.extract_keywords(query);
808
809        let mut scored_results: Vec<RetrievalResult> = candidates
810            .into_iter()
811            .map(|r| {
812                let keyword_score = self.compute_keyword_score(&query_keywords, &r.content, config);
813                RetrievalResult {
814                    doc_id: r.doc_id,
815                    content: r.content,
816                    score: keyword_score,
817                    metadata: r.metadata,
818                    source: r.source,
819                }
820            })
821            .collect();
822
823        scored_results.sort_by(|a, b| {
824            b.score
825                .partial_cmp(&a.score)
826                .unwrap_or(std::cmp::Ordering::Equal)
827        });
828
829        scored_results.truncate(top_k);
830        Ok(scored_results)
831    }
832
833    /// 应用 RRIF (Reciprocal Rank Fusion) 重排序
834    fn apply_rrif(&self, results: Vec<RetrievalResult>, k: f32) -> Vec<RetrievalResult> {
835        if results.is_empty() {
836            return results;
837        }
838
839        results
840            .into_iter()
841            .enumerate()
842            .map(|(idx, mut r)| {
843                let rank = (idx + 1) as f32;
844                let rrif_score = 1.0 / (k + rank);
845                r.score = r.score * 0.5 + rrif_score * 0.5;
846                r
847            })
848            .collect()
849    }
850}
851
852#[async_trait]
853impl<VS, EM, CS> RetrieverEngine for DefaultRetrieverEngine<VS, EM, CS>
854where
855    VS: VectorStore,
856    EM: EmbeddingModel,
857    CS: ChunkingStrategy,
858{
859    async fn index(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>> {
860        let mut doc_ids = Vec::new();
861
862        for doc in documents {
863            // 生成分块
864            let doc_id = doc.id.clone().unwrap_or_else(generate_short_id);
865            let chunks = self.chunking_strategy.chunk(&Document {
866                id: Some(doc_id.clone()),
867                content: doc.content.clone(),
868                metadata: doc.metadata.clone(),
869                source: doc.source.clone(),
870            });
871
872            // 为每个分块生成 embedding 并存储
873            let chunk_ids: Vec<String> = chunks.iter().map(|c| c.id.clone()).collect();
874
875            let chunk_contents: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
876
877            // 批量生成 embeddings
878            let embeddings = self.embedding_model.embed_batch(&chunk_contents).await?;
879
880            // 构建向量项并存储
881            let vector_items: Vec<VectorItem> = chunks
882                .into_iter()
883                .zip(embeddings)
884                .map(|(chunk, embedding)| {
885                    let mut metadata = chunk.metadata.clone();
886                    metadata.insert("doc_id".to_string(), serde_json::json!(chunk.doc_id));
887                    metadata.insert(
888                        "chunk_index".to_string(),
889                        serde_json::json!(chunk.position.index),
890                    );
891                    if let Some(source) = doc.source.clone() {
892                        metadata.insert("source".to_string(), serde_json::json!(source));
893                    }
894
895                    VectorItem {
896                        id: chunk.id.clone(),
897                        vector: embedding,
898                        metadata,
899                        content: Some(chunk.content.clone()),
900                    }
901                })
902                .collect();
903
904            // 缓存分块内容
905            {
906                let mut cache = self.chunk_cache.write();
907                for item in &vector_items {
908                    cache.insert(item.id.clone(), item.content.clone().unwrap_or_default());
909                }
910            }
911
912            // 存储向量
913            self.vector_store.add_batch(vector_items).await?;
914
915            // 记录文档索引
916            {
917                let mut index = self.doc_index.write();
918                index.insert(doc_id.clone(), chunk_ids);
919            }
920
921            doc_ids.push(doc_id);
922        }
923
924        Ok(doc_ids)
925    }
926
927    async fn retrieve(&self, query: &str, top_k: usize) -> Layer3Result<Vec<RetrievalResult>> {
928        // 生成查询向量
929        let query_embedding = self.embedding_model.embed(query).await?;
930
931        // 搜索相似向量
932        let results = self.vector_store.query(query_embedding, top_k).await?;
933
934        // 补充内容(从缓存中获取完整内容)
935        let cache = self.chunk_cache.read();
936        let enriched_results: Vec<RetrievalResult> = results
937            .into_iter()
938            .map(|r| {
939                let content = cache.get(&r.doc_id).cloned().unwrap_or(r.content);
940                RetrievalResult {
941                    doc_id: r.doc_id,
942                    content,
943                    score: r.score,
944                    metadata: r.metadata,
945                    source: r.source,
946                }
947            })
948            .collect();
949
950        Ok(enriched_results)
951    }
952
953    async fn hybrid_retrieve(
954        &self,
955        query: &str,
956        top_k: usize,
957    ) -> Layer3Result<Vec<RetrievalResult>> {
958        self.hybrid_retrieve_with_config(query, top_k, &HybridSearchConfig::default())
959            .await
960    }
961
962    async fn hybrid_retrieve_with_config(
963        &self,
964        query: &str,
965        top_k: usize,
966        config: &HybridSearchConfig,
967    ) -> Layer3Result<Vec<RetrievalResult>> {
968        // 如果仅使用向量搜索,直接返回
969        if config.weights.keyword == 0.0 {
970            return self.retrieve(query, top_k).await;
971        }
972
973        // 1. 向量搜索:获取更多候选结果
974        let candidates_count = top_k * config.candidates_multiplier;
975        let vector_results = self.retrieve(query, candidates_count).await?;
976
977        // 如果仅使用关键词搜索
978        if config.weights.vector == 0.0 {
979            return self
980                .keyword_only_search(query, vector_results, top_k, config)
981                .await;
982        }
983
984        // 2. 提取查询关键词
985        let query_keywords = self.extract_keywords(query);
986
987        // 3. 计算混合分数
988        let mut scored_results: Vec<RetrievalResult> = vector_results
989            .into_iter()
990            .map(|r| {
991                let keyword_score = self.compute_keyword_score(&query_keywords, &r.content, config);
992
993                // 混合分数
994                let final_score =
995                    r.score * config.weights.vector + keyword_score * config.weights.keyword;
996
997                RetrievalResult {
998                    doc_id: r.doc_id,
999                    content: r.content,
1000                    score: final_score,
1001                    metadata: r.metadata,
1002                    source: r.source,
1003                }
1004            })
1005            .collect();
1006
1007        // 4. 按分数排序
1008        scored_results.sort_by(|a, b| {
1009            b.score
1010                .partial_cmp(&a.score)
1011                .unwrap_or(std::cmp::Ordering::Equal)
1012        });
1013
1014        // 5. 可选:RRIF 重排序
1015        if config.use_rrif {
1016            scored_results = self.apply_rrif(scored_results, config.rrif_k);
1017        }
1018
1019        // 6. 截断并返回
1020        scored_results.truncate(top_k);
1021        Ok(scored_results)
1022    }
1023
1024    async fn delete(&self, doc_ids: &[String]) -> Layer3Result<bool> {
1025        // 先收集需要删除的分块 ID,然后释放锁
1026        let all_chunk_ids: Vec<String> = {
1027            let mut index = self.doc_index.write();
1028            let mut cache = self.chunk_cache.write();
1029
1030            let mut ids_to_delete: Vec<String> = Vec::new();
1031            for doc_id in doc_ids {
1032                if let Some(chunk_ids) = index.remove(doc_id) {
1033                    for chunk_id in &chunk_ids {
1034                        cache.remove(chunk_id);
1035                    }
1036                    ids_to_delete.extend(chunk_ids);
1037                }
1038            }
1039            ids_to_delete
1040        };
1041
1042        if all_chunk_ids.is_empty() {
1043            return Ok(false);
1044        }
1045
1046        self.vector_store.delete_batch(&all_chunk_ids).await?;
1047        Ok(true)
1048    }
1049
1050    async fn clear(&self) -> Layer3Result<bool> {
1051        self.vector_store.clear().await?;
1052        let mut index = self.doc_index.write();
1053        index.clear();
1054        let mut cache = self.chunk_cache.write();
1055        cache.clear();
1056        Ok(true)
1057    }
1058
1059    async fn count(&self) -> Layer3Result<usize> {
1060        let index = self.doc_index.read();
1061        Ok(index.len())
1062    }
1063}
1064
1065/// Layer1 EmbeddingModel wrapper
1066///
1067/// Wraps a layer1 embedding model to implement layer3's EmbeddingModel trait.
1068pub struct Layer1EmbeddingAdapter {
1069    inner: Box<dyn sh_layer1::EmbeddingModel>,
1070}
1071
1072impl Layer1EmbeddingAdapter {
1073    /// Create a new adapter wrapping a layer1 embedding model
1074    pub fn new(model: Box<dyn sh_layer1::EmbeddingModel>) -> Self {
1075        Self { inner: model }
1076    }
1077}
1078
1079#[async_trait]
1080impl EmbeddingModel for Layer1EmbeddingAdapter {
1081    async fn embed(&self, text: &str) -> Layer3Result<Vec<f32>> {
1082        self.inner.embed(text).await
1083    }
1084
1085    async fn embed_batch(&self, texts: &[String]) -> Layer3Result<Vec<Vec<f32>>> {
1086        self.inner.embed_batch(texts).await
1087    }
1088
1089    fn dimension(&self) -> usize {
1090        self.inner.dimension()
1091    }
1092
1093    fn model_name(&self) -> &str {
1094        self.inner.model_name()
1095    }
1096}
1097
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector_store::InMemoryVectorStore;

    /// Creates a mock embedding model of the given `dimension` for tests, by wrapping layer1's
    /// `MockEmbeddingModel` in a [`Layer1EmbeddingAdapter`].
    ///
    /// No extra `cfg` gating is needed: this module only exists under `#[cfg(test)]`.
    fn create_mock_embedding_model(dimension: usize) -> Layer1EmbeddingAdapter {
        Layer1EmbeddingAdapter::new(Box::new(sh_layer1::MockEmbeddingModel::new(dimension)))
    }

    #[test]
    fn test_document_builder() {
        let doc = Document::new("test content")
            .with_source("test.txt")
            .with_metadata("key", serde_json::json!("value"));
        assert_eq!(doc.content, "test content");
        assert_eq!(doc.source, Some("test.txt".to_string()));
        assert_eq!(doc.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_fixed_size_chunker() {
        let chunker = FixedSizeChunker::new(100, 20);
        let doc = Document::new("a".repeat(250));
        let chunks = chunker.chunk(&doc);
        assert!(!chunks.is_empty());
    }

    #[tokio::test]
    async fn test_default_retriever_engine_index() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        let doc = Document::new("This is a test document for RAG.").with_source("test.txt");

        let doc_ids = engine.index(vec![doc]).await.unwrap();
        assert_eq!(doc_ids.len(), 1);
        assert_eq!(engine.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_default_retriever_engine_retrieve() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        // Index documents.
        let docs = vec![
            Document::new("Rust is a systems programming language."),
            Document::new("Python is great for data science."),
        ];
        engine.index(docs).await.unwrap();

        // Retrieve.
        let results = engine.retrieve("Rust programming", 5).await.unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_default_retriever_engine_delete() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        let doc = Document::new("Test document");
        let doc_ids = engine.index(vec![doc]).await.unwrap();

        let deleted = engine.delete(&doc_ids).await.unwrap();
        assert!(deleted);
        assert_eq!(engine.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_mock_embedding_model() {
        let model = create_mock_embedding_model(64);

        let embedding = model.embed("test").await.unwrap();
        assert_eq!(embedding.len(), 64);
        assert_eq!(model.dimension(), 64);
        assert_eq!(model.model_name(), "mock-embedding");

        let embeddings = model
            .embed_batch(&["test1".to_string(), "test2".to_string()])
            .await
            .unwrap();
        assert_eq!(embeddings.len(), 2);
    }

    #[tokio::test]
    async fn test_hybrid_retrieve() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        // Index documents.
        let docs = vec![
            Document::new("Rust is a systems programming language designed for performance."),
            Document::new("Python is widely used for data science and machine learning."),
            Document::new("JavaScript runs in the browser for web development."),
        ];
        engine.index(docs).await.unwrap();

        // Hybrid retrieval.
        let results = engine
            .hybrid_retrieve("Rust programming language", 5)
            .await
            .unwrap();
        assert!(!results.is_empty());
        // The Rust document should rank first.
        // NOTE(review): this relies on the keyword score outweighing whatever similarity the
        // mock embeddings produce — confirm `MockEmbeddingModel` is deterministic.
        assert!(results[0].content.contains("Rust"));
    }

    #[tokio::test]
    async fn test_hybrid_retrieve_with_config() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        // Index documents.
        let docs = vec![
            Document::new("Machine learning algorithms use neural networks."),
            Document::new("The database stores data for the application."),
        ];
        engine.index(docs).await.unwrap();

        // Vector-only search.
        let config_vector_only =
            HybridSearchConfig::new().with_weights(HybridWeights::vector_only());
        let results = engine
            .hybrid_retrieve_with_config("neural networks", 5, &config_vector_only)
            .await
            .unwrap();
        assert!(!results.is_empty());

        // Keyword-only search.
        let config_keyword_only =
            HybridSearchConfig::new().with_weights(HybridWeights::keyword_only());
        let results = engine
            .hybrid_retrieve_with_config("machine learning", 5, &config_keyword_only)
            .await
            .unwrap();
        assert!(!results.is_empty());

        // Balanced weights with RRIF re-ranking.
        let config_balanced = HybridSearchConfig::new()
            .with_weights(HybridWeights::balanced())
            .with_rrif(true, 60.0);
        let results = engine
            .hybrid_retrieve_with_config("database", 5, &config_balanced)
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_hybrid_weights() {
        let weights = HybridWeights::default_weights();
        assert_eq!(weights.vector, 0.7);
        assert_eq!(weights.keyword, 0.3);

        let vector_only = HybridWeights::vector_only();
        assert_eq!(vector_only.vector, 1.0);
        assert_eq!(vector_only.keyword, 0.0);

        let balanced = HybridWeights::balanced();
        assert_eq!(balanced.vector, 0.5);
        assert_eq!(balanced.keyword, 0.5);
    }

    #[test]
    fn test_extract_keywords() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        // Keyword extraction lowercases terms.
        let keywords = engine.extract_keywords("The Rust programming language");
        assert!(keywords.contains(&"rust".to_string()));
        assert!(keywords.contains(&"programming".to_string()));
        assert!(keywords.contains(&"language".to_string()));
        // Stop words must be filtered out.
        assert!(!keywords.contains(&"the".to_string()));
    }

    #[test]
    fn test_bm25_keyword_score() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);
        let config = HybridSearchConfig::new();

        let keywords = vec!["rust".to_string(), "programming".to_string()];

        // Content that matches the keywords.
        let score_high = engine.compute_keyword_score(
            &keywords,
            "Rust programming language for systems",
            &config,
        );

        // Content that matches none of the keywords.
        let score_low =
            engine.compute_keyword_score(&keywords, "Python data science frameworks", &config);

        assert!(score_high > score_low);
    }
}