Skip to main content

sh_layer3/
retriever_engine.rs

1//! # Retriever Engine
2//!
3//! 检索引擎:向量相似度检索和 RAG 支持。
4//!
5//! ## 功能
6//!
7//! - 文档索引与检索
8//! - 多种分块策略(固定大小、段落、代码)
9//! - 混合检索(向量 + 关键词)
10//! - RAG Pipeline(带重排序)
11//! - OpenAI Embeddings 集成
12
13use crate::types::Layer3Result;
14use async_trait::async_trait;
15use parking_lot::RwLock;
16use sh_layer2::generate_short_id;
17use std::collections::HashMap;
18use std::sync::Arc;
19
/// Retrieval engine trait.
///
/// Provides vector-similarity retrieval over indexed documents. All operations
/// are async (via `async_trait`) and implementors must be `Send + Sync`.
#[async_trait]
pub trait RetrieverEngine: Send + Sync {
    /// Indexes `documents` and returns the IDs of the indexed documents.
    async fn index(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>>;

    /// Retrieves up to `top_k` results most similar to `query`.
    async fn retrieve(&self, query: &str, top_k: usize) -> Layer3Result<Vec<RetrievalResult>>;

    /// Hybrid retrieval (vector + keyword) returning up to `top_k` results.
    async fn hybrid_retrieve(
        &self,
        query: &str,
        top_k: usize,
    ) -> Layer3Result<Vec<RetrievalResult>>;

    /// Hybrid retrieval driven by an explicit [`HybridSearchConfig`].
    ///
    /// The default implementation ignores `config` and simply delegates to
    /// [`RetrieverEngine::hybrid_retrieve`]; engines that support tuning
    /// should override it.
    async fn hybrid_retrieve_with_config(
        &self,
        query: &str,
        top_k: usize,
        config: &HybridSearchConfig,
    ) -> Layer3Result<Vec<RetrievalResult>> {
        // Explicitly discard the argument this default does not use.
        let _ = config;
        self.hybrid_retrieve(query, top_k).await
    }

    /// Retrieval restricted by an optional metadata filter.
    ///
    /// The default implementation ignores `filter` and delegates to
    /// [`RetrieverEngine::retrieve`], i.e. results are NOT filtered unless
    /// the implementor overrides this method.
    async fn retrieve_with_filter(
        &self,
        query: &str,
        top_k: usize,
        filter: Option<crate::vector_store::MetadataFilter>,
    ) -> Layer3Result<Vec<RetrievalResult>> {
        // Explicitly discard the argument this default does not use.
        let _ = filter;
        self.retrieve(query, top_k).await
    }

    /// Deletes the given documents.
    ///
    /// NOTE(review): the meaning of the returned `bool` is implementation-defined
    /// (presumably "something was deleted") — confirm per implementation.
    async fn delete(&self, doc_ids: &[String]) -> Layer3Result<bool>;

    /// Clears the whole index.
    ///
    /// NOTE(review): the meaning of the returned `bool` is implementation-defined — confirm.
    async fn clear(&self) -> Layer3Result<bool>;

    /// Returns the document count.
    async fn count(&self) -> Layer3Result<usize>;
}
69
70/// 文档结构
71#[derive(Debug, Clone)]
72pub struct Document {
73    /// 文档 ID(可选,自动生成)
74    pub id: Option<String>,
75    /// 文档内容
76    pub content: String,
77    /// 元数据
78    pub metadata: HashMap<String, serde_json::Value>,
79    /// 来源(文件路径、URL 等)
80    pub source: Option<String>,
81}
82
83impl Document {
84    pub fn new(content: impl Into<String>) -> Self {
85        Self {
86            id: None,
87            content: content.into(),
88            metadata: HashMap::new(),
89            source: None,
90        }
91    }
92
93    pub fn with_source(mut self, source: impl Into<String>) -> Self {
94        self.source = Some(source.into());
95        self
96    }
97
98    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
99        self.metadata.insert(key.into(), value);
100        self
101    }
102}
103
/// A single retrieval hit.
#[derive(Debug, Clone)]
pub struct RetrievalResult {
    /// ID of the matched item.
    ///
    /// NOTE(review): for `DefaultRetrieverEngine` this is whatever the vector
    /// store reports, which may be a chunk ID rather than a document ID — confirm.
    pub doc_id: String,
    /// Matched text.
    pub content: String,
    /// Relevance score, nominally in 0.0-1.0; higher is more relevant (results
    /// are sorted in descending order). The exact scale depends on the vector
    /// store and on any hybrid re-scoring.
    pub score: f32,
    /// Metadata of the matched item.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Source (file path, URL, ...), if known.
    pub source: Option<String>,
}
118
119// ============================================================================
120// Hybrid Search Configuration
121// ============================================================================
122
/// Relative weights of the vector and keyword signals in hybrid retrieval.
///
/// [`HybridWeights::new`] normalises its inputs so that the two weights sum to
/// 1.0; the preset constructors already satisfy that.
#[derive(Debug, Clone, Copy)]
pub struct HybridWeights {
    /// Weight of the vector-similarity score.
    pub vector: f32,
    /// Weight of the keyword score.
    pub keyword: f32,
}

impl HybridWeights {
    /// Creates weights normalised so that `vector + keyword == 1.0`.
    ///
    /// Negative and NaN inputs are treated as `0.0`. If the remaining total is
    /// zero or not finite (e.g. `new(0.0, 0.0)`), there is nothing meaningful
    /// to normalise by, so the default 70/30 split is returned instead of NaN
    /// weights.
    pub fn new(vector: f32, keyword: f32) -> Self {
        // `f32::max` returns the non-NaN operand, so NaN also collapses to 0.0.
        let vector = vector.max(0.0);
        let keyword = keyword.max(0.0);
        let total = vector + keyword;
        if total <= 0.0 || !total.is_finite() {
            return Self::default_weights();
        }
        Self {
            vector: vector / total,
            keyword: keyword / total,
        }
    }

    /// Default weights: 70% vector + 30% keyword.
    pub fn default_weights() -> Self {
        Self {
            vector: 0.7,
            keyword: 0.3,
        }
    }

    /// Vector search only.
    pub fn vector_only() -> Self {
        Self {
            vector: 1.0,
            keyword: 0.0,
        }
    }

    /// Keyword search only.
    pub fn keyword_only() -> Self {
        Self {
            vector: 0.0,
            keyword: 1.0,
        }
    }

    /// Equal weights for both signals.
    pub fn balanced() -> Self {
        Self {
            vector: 0.5,
            keyword: 0.5,
        }
    }
}

impl Default for HybridWeights {
    /// Same as [`HybridWeights::default_weights`].
    fn default() -> Self {
        Self::default_weights()
    }
}
180
181/// 混合检索配置
182#[derive(Debug, Clone)]
183pub struct HybridSearchConfig {
184    /// 权重配置
185    pub weights: HybridWeights,
186    /// 是否启用短语匹配
187    pub phrase_matching: bool,
188    /// 是否启用 RRIF 重排序
189    pub use_rrif: bool,
190    /// RRIF 参数 K(控制排名衰减)
191    pub rrif_k: f32,
192    /// 候选结果数量倍数(top_k * candidates_multiplier)
193    pub candidates_multiplier: usize,
194}
195
196impl HybridSearchConfig {
197    pub fn new() -> Self {
198        Self {
199            weights: HybridWeights::default(),
200            phrase_matching: true,
201            use_rrif: true,
202            rrif_k: 60.0,
203            candidates_multiplier: 2,
204        }
205    }
206
207    pub fn with_weights(mut self, weights: HybridWeights) -> Self {
208        self.weights = weights;
209        self
210    }
211
212    pub fn with_phrase_matching(mut self, enabled: bool) -> Self {
213        self.phrase_matching = enabled;
214        self
215    }
216
217    pub fn with_rrif(mut self, enabled: bool, k: f32) -> Self {
218        self.use_rrif = enabled;
219        self.rrif_k = k;
220        self
221    }
222}
223
224impl Default for HybridSearchConfig {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
/// Embedding model trait: turns text into dense `f32` vectors.
#[async_trait]
pub trait EmbeddingModel: Send + Sync {
    /// Generates the embedding vector for a single text.
    async fn embed(&self, text: &str) -> Layer3Result<Vec<f32>>;

    /// Generates embedding vectors for several texts at once.
    ///
    /// NOTE(review): `DefaultRetrieverEngine::index` pairs the returned vectors
    /// with its inputs positionally, so implementations are expected to return
    /// one vector per input, in input order — confirm for each implementation.
    async fn embed_batch(&self, texts: &[String]) -> Layer3Result<Vec<Vec<f32>>>;

    /// Dimensionality of the produced vectors.
    fn dimension(&self) -> usize;

    /// Model name.
    fn model_name(&self) -> &str;
}
245
/// Chunking strategy trait: splits a document into chunks prior to embedding.
pub trait ChunkingStrategy: Send + Sync {
    /// Splits `document` into chunks.
    fn chunk(&self, document: &Document) -> Vec<Chunk>;
}
251
/// A piece of a document produced by a [`ChunkingStrategy`].
#[derive(Debug, Clone)]
pub struct Chunk {
    /// Chunk ID. The chunkers in this file format it as `"{doc_id}-{index}"`,
    /// using `"doc"` when the document has no ID.
    pub id: String,
    /// ID of the owning document (empty when the document had no ID).
    pub doc_id: String,
    /// Chunk text.
    pub content: String,
    /// Where the chunk sits in the original document.
    pub position: ChunkPosition,
    /// Metadata; the chunkers in this file copy the owning document's metadata here.
    pub metadata: HashMap<String, serde_json::Value>,
}
266
/// Location of a chunk within its source document.
#[derive(Debug, Clone, Copy)]
pub struct ChunkPosition {
    /// Start offset in the document content. The chunkers in this file compute
    /// offsets with `str::len`, so this is a BYTE offset, not a character count.
    pub start: usize,
    /// End offset (exclusive) in the document content, in bytes.
    pub end: usize,
    /// Zero-based index of this chunk among the document's chunks.
    pub index: usize,
    /// Total number of chunks produced for the document.
    pub total: usize,
}
279
/// Fixed-size chunking strategy.
///
/// Cuts content into windows of `chunk_size` bytes, with consecutive windows
/// sharing `overlap` bytes.
#[derive(Debug, Clone)]
pub struct FixedSizeChunker {
    /// Window size, measured with `str::len` (bytes).
    pub chunk_size: usize,
    /// Number of bytes shared by consecutive windows.
    pub overlap: usize,
}

impl FixedSizeChunker {
    /// Builds a chunker from an explicit window size and overlap.
    pub fn new(chunk_size: usize, overlap: usize) -> Self {
        Self {
            overlap,
            chunk_size,
        }
    }
}

impl Default for FixedSizeChunker {
    /// 500-byte windows overlapping by 50 bytes.
    fn default() -> Self {
        Self::new(500, 50)
    }
}
306
307impl ChunkingStrategy for FixedSizeChunker {
308    fn chunk(&self, document: &Document) -> Vec<Chunk> {
309        let content = &document.content;
310        if content.len() <= self.chunk_size {
311            return vec![Chunk {
312                id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
313                doc_id: document.id.clone().unwrap_or_default(),
314                content: content.clone(),
315                position: ChunkPosition {
316                    start: 0,
317                    end: content.len(),
318                    index: 0,
319                    total: 1,
320                },
321                metadata: document.metadata.clone(),
322            }];
323        }
324
325        let mut chunks = Vec::new();
326        let mut start = 0;
327        let mut index = 0;
328
329        while start < content.len() {
330            let end = (start + self.chunk_size).min(content.len());
331            chunks.push(Chunk {
332                id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
333                doc_id: document.id.clone().unwrap_or_default(),
334                content: content[start..end].to_string(),
335                position: ChunkPosition {
336                    start,
337                    end,
338                    index,
339                    total: 0, // 将在最后更新
340                },
341                metadata: document.metadata.clone(),
342            });
343            // 防止死循环:到达末尾时直接设置 start = end
344            start = if end < content.len() {
345                end.saturating_sub(self.overlap)
346            } else {
347                end
348            };
349            index += 1;
350        }
351
352        let total = chunks.len();
353        for chunk in &mut chunks {
354            chunk.position.total = total;
355        }
356
357        chunks
358    }
359}
360
361// ============================================================================
362// Paragraph Chunking Strategy
363// ============================================================================
364
/// Paragraph chunking strategy.
///
/// Groups whole paragraphs (non-blank lines) into chunks instead of cutting
/// through them, to keep semantic units intact.
#[derive(Debug, Clone)]
pub struct ParagraphChunker {
    // Upper bound used when deciding whether another paragraph still fits.
    max_chunk_size: usize,
    // Lower bound below which a chunk is not emitted on its own.
    min_chunk_size: usize,
}

impl ParagraphChunker {
    /// Creates a chunker with explicit upper and lower size bounds (bytes).
    pub fn new(max_chunk_size: usize, min_chunk_size: usize) -> Self {
        Self {
            min_chunk_size,
            max_chunk_size,
        }
    }
}

impl Default for ParagraphChunker {
    /// Chunks of up to 1000 bytes with a 100-byte minimum.
    fn default() -> Self {
        Self::new(1000, 100)
    }
}
391
392impl ChunkingStrategy for ParagraphChunker {
393    fn chunk(&self, document: &Document) -> Vec<Chunk> {
394        let content = &document.content;
395        let paragraphs: Vec<&str> = content
396            .split('\n')
397            .filter(|p| !p.trim().is_empty())
398            .collect();
399
400        if paragraphs.is_empty() {
401            return vec![Chunk {
402                id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
403                doc_id: document.id.clone().unwrap_or_default(),
404                content: content.clone(),
405                position: ChunkPosition {
406                    start: 0,
407                    end: content.len(),
408                    index: 0,
409                    total: 1,
410                },
411                metadata: document.metadata.clone(),
412            }];
413        }
414
415        let mut chunks = Vec::new();
416        let mut current_chunk = String::new();
417        let mut start = 0;
418        let mut index = 0;
419
420        for paragraph in paragraphs {
421            if current_chunk.len() + paragraph.len() < self.max_chunk_size {
422                if !current_chunk.is_empty() {
423                    current_chunk.push('\n');
424                }
425                current_chunk.push_str(paragraph);
426            } else {
427                if current_chunk.len() >= self.min_chunk_size {
428                    let end = start + current_chunk.len();
429                    chunks.push(Chunk {
430                        id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
431                        doc_id: document.id.clone().unwrap_or_default(),
432                        content: current_chunk.clone(),
433                        position: ChunkPosition {
434                            start,
435                            end,
436                            index,
437                            total: 0,
438                        },
439                        metadata: document.metadata.clone(),
440                    });
441                    start = end;
442                    index += 1;
443                }
444                current_chunk = paragraph.to_string();
445            }
446        }
447
448        if current_chunk.len() >= self.min_chunk_size {
449            chunks.push(Chunk {
450                id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
451                doc_id: document.id.clone().unwrap_or_default(),
452                content: current_chunk,
453                position: ChunkPosition {
454                    start,
455                    end: content.len(),
456                    index,
457                    total: 0,
458                },
459                metadata: document.metadata.clone(),
460            });
461        }
462
463        let total = chunks.len().max(1);
464        for chunk in &mut chunks {
465            chunk.position.total = total;
466        }
467
468        if chunks.is_empty() {
469            vec![Chunk {
470                id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
471                doc_id: document.id.clone().unwrap_or_default(),
472                content: content.clone(),
473                position: ChunkPosition {
474                    start: 0,
475                    end: content.len(),
476                    index: 0,
477                    total: 1,
478                },
479                metadata: document.metadata.clone(),
480            }]
481        } else {
482            chunks
483        }
484    }
485}
486
487// ============================================================================
488// Recursive Chunking Strategy
489// ============================================================================
490
/// Recursive chunking strategy.
///
/// Tries separators from coarsest to finest (blank-line runs, newlines,
/// sentence ends, spaces) and finally falls back to a hard split.
#[derive(Debug, Clone)]
pub struct RecursiveChunker {
    // Maximum chunk size in bytes.
    max_chunk_size: usize,
    // Separators in the order they are tried; "" means "hard split".
    separators: Vec<String>,
}

impl RecursiveChunker {
    /// Creates a chunker with the given maximum chunk size (bytes) and the
    /// standard separator list.
    pub fn new(max_chunk_size: usize) -> Self {
        let separators = ["\n\n\n", "\n\n", "\n", ". ", " ", ""]
            .iter()
            .map(|sep| sep.to_string())
            .collect();
        Self {
            max_chunk_size,
            separators,
        }
    }
}

impl Default for RecursiveChunker {
    /// Chunks of at most 1000 bytes.
    fn default() -> Self {
        Self::new(1000)
    }
}
521
522impl ChunkingStrategy for RecursiveChunker {
523    fn chunk(&self, document: &Document) -> Vec<Chunk> {
524        self._recursive_split(document, &document.content, 0, 0)
525    }
526}
527
528impl RecursiveChunker {
529    fn _recursive_split(
530        &self,
531        document: &Document,
532        text: &str,
533        start_offset: usize,
534        initial_index: usize,
535    ) -> Vec<Chunk> {
536        if text.len() <= self.max_chunk_size {
537            return vec![Chunk {
538                id: format!(
539                    "{}-{}",
540                    document.id.as_deref().unwrap_or("doc"),
541                    initial_index
542                ),
543                doc_id: document.id.clone().unwrap_or_default(),
544                content: text.to_string(),
545                position: ChunkPosition {
546                    start: start_offset,
547                    end: start_offset + text.len(),
548                    index: initial_index,
549                    total: 1,
550                },
551                metadata: document.metadata.clone(),
552            }];
553        }
554
555        for separator in &self.separators {
556            if separator.is_empty() {
557                let mut chunks = Vec::new();
558                let mut start = 0;
559                let mut index = initial_index;
560
561                while start < text.len() {
562                    let end = (start + self.max_chunk_size).min(text.len());
563                    chunks.push(Chunk {
564                        id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
565                        doc_id: document.id.clone().unwrap_or_default(),
566                        content: text[start..end].to_string(),
567                        position: ChunkPosition {
568                            start: start_offset + start,
569                            end: start_offset + end,
570                            index,
571                            total: 0,
572                        },
573                        metadata: document.metadata.clone(),
574                    });
575                    start = end;
576                    index += 1;
577                }
578
579                let total = chunks.len();
580                for chunk in &mut chunks {
581                    chunk.position.total = total;
582                }
583                return chunks;
584            }
585
586            if text.contains(separator) {
587                let parts: Vec<&str> = text.split(separator).collect();
588                let mut chunks = Vec::new();
589                let mut current_chunk = String::new();
590                let mut current_start = start_offset;
591                let mut index = initial_index;
592
593                for (i, part) in parts.iter().enumerate() {
594                    let part_with_sep = if i < parts.len() - 1 {
595                        format!("{}{}", part, separator)
596                    } else {
597                        part.to_string()
598                    };
599
600                    if current_chunk.len() + part_with_sep.len() <= self.max_chunk_size {
601                        current_chunk.push_str(&part_with_sep);
602                    } else {
603                        if !current_chunk.is_empty() {
604                            chunks.push(Chunk {
605                                id: format!(
606                                    "{}-{}",
607                                    document.id.as_deref().unwrap_or("doc"),
608                                    index
609                                ),
610                                doc_id: document.id.clone().unwrap_or_default(),
611                                content: current_chunk.clone(),
612                                position: ChunkPosition {
613                                    start: current_start,
614                                    end: current_start + current_chunk.len(),
615                                    index,
616                                    total: 0,
617                                },
618                                metadata: document.metadata.clone(),
619                            });
620                            current_start += current_chunk.len();
621                            index += 1;
622                        }
623
624                        if part_with_sep.len() > self.max_chunk_size {
625                            let sub_chunks = self._recursive_split(
626                                document,
627                                &part_with_sep,
628                                current_start,
629                                index,
630                            );
631                            for sub in sub_chunks {
632                                current_start = sub.position.end;
633                                index += 1;
634                                chunks.push(sub);
635                            }
636                        } else {
637                            current_chunk = part_with_sep;
638                        }
639                    }
640                }
641
642                if !current_chunk.is_empty() {
643                    chunks.push(Chunk {
644                        id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
645                        doc_id: document.id.clone().unwrap_or_default(),
646                        content: current_chunk,
647                        position: ChunkPosition {
648                            start: current_start,
649                            end: start_offset + text.len(),
650                            index,
651                            total: 0,
652                        },
653                        metadata: document.metadata.clone(),
654                    });
655                }
656
657                let total = chunks.len().max(1);
658                for chunk in &mut chunks {
659                    chunk.position.total = total;
660                }
661                return chunks;
662            }
663        }
664
665        vec![Chunk {
666            id: format!(
667                "{}-{}",
668                document.id.as_deref().unwrap_or("doc"),
669                initial_index
670            ),
671            doc_id: document.id.clone().unwrap_or_default(),
672            content: text.to_string(),
673            position: ChunkPosition {
674                start: start_offset,
675                end: start_offset + text.len(),
676                index: initial_index,
677                total: 1,
678            },
679            metadata: document.metadata.clone(),
680        }]
681    }
682}
683
684// ============================================================================
685// Default Retriever Engine Implementation
686// ============================================================================
687
688use crate::vector_store::{VectorItem, VectorStore};
689
/// Default retrieval engine implementation.
///
/// Combines an embedding model, a chunking strategy and a vector store to
/// provide indexing and retrieval for RAG.
pub struct DefaultRetrieverEngine<VS, EM, CS>
where
    VS: VectorStore,
    EM: EmbeddingModel,
    CS: ChunkingStrategy,
{
    /// Vector store; `index` stores one vector per chunk.
    vector_store: VS,
    /// Embedding model used for both chunks (`index`) and queries (`retrieve`).
    embedding_model: EM,
    /// Strategy used to split documents into chunks before embedding.
    chunking_strategy: CS,
    /// Document index (document ID -> list of chunk IDs).
    doc_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
    /// Chunk content cache (chunk ID -> chunk text), consulted by `retrieve`.
    chunk_cache: Arc<RwLock<HashMap<String, String>>>,
}
710
711impl<VS, EM, CS> DefaultRetrieverEngine<VS, EM, CS>
712where
713    VS: VectorStore,
714    EM: EmbeddingModel,
715    CS: ChunkingStrategy,
716{
717    /// 创建新的检索引擎
718    pub fn new(vector_store: VS, embedding_model: EM, chunking_strategy: CS) -> Self {
719        Self {
720            vector_store,
721            embedding_model,
722            chunking_strategy,
723            doc_index: Arc::new(RwLock::new(HashMap::new())),
724            chunk_cache: Arc::new(RwLock::new(HashMap::new())),
725        }
726    }
727
728    /// 提取关键词(分词 + 去停用词)
729    fn extract_keywords(&self, query: &str) -> Vec<String> {
730        let words: Vec<String> = query
731            .to_lowercase()
732            .split_whitespace()
733            .map(|s| s.to_string())
734            .collect();
735
736        let stop_words = std::collections::HashSet::from([
737            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
738            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
739            "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
740            "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
741            "below", "between", "under", "again", "further", "then", "once", "here", "there",
742            "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
743            "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s",
744            "t", "just", "and", "but", "if", "or", "because", "until", "while", "although",
745        ]);
746
747        words
748            .into_iter()
749            .filter(|w| !stop_words.contains(w.as_str()) && w.len() > 1)
750            .collect()
751    }
752
753    /// 计算关键词分数(BM25 风格)
754    fn compute_keyword_score(
755        &self,
756        query_keywords: &[String],
757        content: &str,
758        config: &HybridSearchConfig,
759    ) -> f32 {
760        if query_keywords.is_empty() {
761            return 0.0;
762        }
763
764        let content_lower = content.to_lowercase();
765
766        // 短语匹配奖励
767        let mut phrase_bonus: f32 = 0.0;
768        if config.phrase_matching {
769            for keyword in query_keywords {
770                if content_lower.contains(keyword) {
771                    phrase_bonus += 0.1;
772                }
773            }
774            phrase_bonus = phrase_bonus.min(0.3);
775        }
776
777        // 计算关键词匹配数量
778        let matched_keywords = query_keywords
779            .iter()
780            .filter(|kw| content_lower.contains(kw.as_str()))
781            .count();
782
783        // BM25 风格的饱和函数
784        let k1 = 1.2;
785        let content_len = content.len() as f32;
786        let avg_len = 500.0;
787        let len_norm = 1.0 - 0.75 + 0.75 * (content_len / avg_len);
788
789        let bm25_score =
790            (matched_keywords as f32 * (k1 + 1.0)) / (matched_keywords as f32 + k1 * len_norm);
791
792        // 归一化到 [0, 1]
793        let normalized_score = bm25_score / (query_keywords.len() as f32 + k1);
794        let normalized_score = normalized_score.min(1.0);
795
796        normalized_score + phrase_bonus
797    }
798
799    /// 仅关键词搜索
800    async fn keyword_only_search(
801        &self,
802        query: &str,
803        candidates: Vec<RetrievalResult>,
804        top_k: usize,
805        config: &HybridSearchConfig,
806    ) -> Layer3Result<Vec<RetrievalResult>> {
807        let query_keywords = self.extract_keywords(query);
808
809        let mut scored_results: Vec<RetrievalResult> = candidates
810            .into_iter()
811            .map(|r| {
812                let keyword_score = self.compute_keyword_score(&query_keywords, &r.content, config);
813                RetrievalResult {
814                    doc_id: r.doc_id,
815                    content: r.content,
816                    score: keyword_score,
817                    metadata: r.metadata,
818                    source: r.source,
819                }
820            })
821            .collect();
822
823        scored_results.sort_by(|a, b| {
824            b.score
825                .partial_cmp(&a.score)
826                .unwrap_or(std::cmp::Ordering::Equal)
827        });
828
829        scored_results.truncate(top_k);
830        Ok(scored_results)
831    }
832
833    /// 应用 RRIF (Reciprocal Rank Fusion) 重排序
834    fn apply_rrif(&self, results: Vec<RetrievalResult>, k: f32) -> Vec<RetrievalResult> {
835        if results.is_empty() {
836            return results;
837        }
838
839        results
840            .into_iter()
841            .enumerate()
842            .map(|(idx, mut r)| {
843                let rank = (idx + 1) as f32;
844                let rrif_score = 1.0 / (k + rank);
845                r.score = r.score * 0.5 + rrif_score * 0.5;
846                r
847            })
848            .collect()
849    }
850}
851
852#[async_trait]
853impl<VS, EM, CS> RetrieverEngine for DefaultRetrieverEngine<VS, EM, CS>
854where
855    VS: VectorStore,
856    EM: EmbeddingModel,
857    CS: ChunkingStrategy,
858{
859    async fn index(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>> {
860        let mut doc_ids = Vec::new();
861
862        for doc in documents {
863            // 生成分块
864            let doc_id = doc.id.clone().unwrap_or_else(generate_short_id);
865            let chunks = self.chunking_strategy.chunk(&Document {
866                id: Some(doc_id.clone()),
867                content: doc.content.clone(),
868                metadata: doc.metadata.clone(),
869                source: doc.source.clone(),
870            });
871
872            // 为每个分块生成 embedding 并存储
873            let chunk_ids: Vec<String> = chunks.iter().map(|c| c.id.clone()).collect();
874
875            let chunk_contents: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
876
877            // 批量生成 embeddings
878            let embeddings = self.embedding_model.embed_batch(&chunk_contents).await?;
879
880            // 构建向量项并存储
881            let vector_items: Vec<VectorItem> = chunks
882                .into_iter()
883                .zip(embeddings)
884                .map(|(chunk, embedding)| {
885                    let mut metadata = chunk.metadata.clone();
886                    metadata.insert("doc_id".to_string(), serde_json::json!(chunk.doc_id));
887                    metadata.insert(
888                        "chunk_index".to_string(),
889                        serde_json::json!(chunk.position.index),
890                    );
891                    if let Some(source) = doc.source.clone() {
892                        metadata.insert("source".to_string(), serde_json::json!(source));
893                    }
894
895                    VectorItem {
896                        id: chunk.id.clone(),
897                        vector: embedding,
898                        metadata,
899                        content: Some(chunk.content.clone()),
900                    }
901                })
902                .collect();
903
904            // 缓存分块内容
905            {
906                let mut cache = self.chunk_cache.write();
907                for item in &vector_items {
908                    cache.insert(item.id.clone(), item.content.clone().unwrap_or_default());
909                }
910            }
911
912            // 存储向量
913            self.vector_store.add_batch(vector_items).await?;
914
915            // 记录文档索引
916            {
917                let mut index = self.doc_index.write();
918                index.insert(doc_id.clone(), chunk_ids);
919            }
920
921            doc_ids.push(doc_id);
922        }
923
924        Ok(doc_ids)
925    }
926
927    async fn retrieve(&self, query: &str, top_k: usize) -> Layer3Result<Vec<RetrievalResult>> {
928        // 生成查询向量
929        let query_embedding = self.embedding_model.embed(query).await?;
930
931        // 搜索相似向量
932        let results = self.vector_store.query(query_embedding, top_k).await?;
933
934        // 补充内容(从缓存中获取完整内容)
935        let cache = self.chunk_cache.read();
936        let enriched_results: Vec<RetrievalResult> = results
937            .into_iter()
938            .map(|r| {
939                let content = cache.get(&r.doc_id).cloned().unwrap_or(r.content);
940                RetrievalResult {
941                    doc_id: r.doc_id,
942                    content,
943                    score: r.score,
944                    metadata: r.metadata,
945                    source: r.source,
946                }
947            })
948            .collect();
949
950        Ok(enriched_results)
951    }
952
953    async fn hybrid_retrieve(
954        &self,
955        query: &str,
956        top_k: usize,
957    ) -> Layer3Result<Vec<RetrievalResult>> {
958        self.hybrid_retrieve_with_config(query, top_k, &HybridSearchConfig::default())
959            .await
960    }
961
962    async fn hybrid_retrieve_with_config(
963        &self,
964        query: &str,
965        top_k: usize,
966        config: &HybridSearchConfig,
967    ) -> Layer3Result<Vec<RetrievalResult>> {
968        // 如果仅使用向量搜索,直接返回
969        if config.weights.keyword == 0.0 {
970            return self.retrieve(query, top_k).await;
971        }
972
973        // 1. 向量搜索:获取更多候选结果
974        let candidates_count = top_k * config.candidates_multiplier;
975        let vector_results = self.retrieve(query, candidates_count).await?;
976
977        // 如果仅使用关键词搜索
978        if config.weights.vector == 0.0 {
979            return self
980                .keyword_only_search(query, vector_results, top_k, config)
981                .await;
982        }
983
984        // 2. 提取查询关键词
985        let query_keywords = self.extract_keywords(query);
986
987        // 3. 计算混合分数
988        let mut scored_results: Vec<RetrievalResult> = vector_results
989            .into_iter()
990            .map(|r| {
991                let keyword_score = self.compute_keyword_score(&query_keywords, &r.content, config);
992
993                // 混合分数
994                let final_score =
995                    r.score * config.weights.vector + keyword_score * config.weights.keyword;
996
997                RetrievalResult {
998                    doc_id: r.doc_id,
999                    content: r.content,
1000                    score: final_score,
1001                    metadata: r.metadata,
1002                    source: r.source,
1003                }
1004            })
1005            .collect();
1006
1007        // 4. 按分数排序
1008        scored_results.sort_by(|a, b| {
1009            b.score
1010                .partial_cmp(&a.score)
1011                .unwrap_or(std::cmp::Ordering::Equal)
1012        });
1013
1014        // 5. 可选:RRIF 重排序
1015        if config.use_rrif {
1016            scored_results = self.apply_rrif(scored_results, config.rrif_k);
1017        }
1018
1019        // 6. 截断并返回
1020        scored_results.truncate(top_k);
1021        Ok(scored_results)
1022    }
1023
1024    async fn delete(&self, doc_ids: &[String]) -> Layer3Result<bool> {
1025        // 先收集需要删除的分块 ID,然后释放锁
1026        let all_chunk_ids: Vec<String> = {
1027            let mut index = self.doc_index.write();
1028            let mut cache = self.chunk_cache.write();
1029
1030            let mut ids_to_delete: Vec<String> = Vec::new();
1031            for doc_id in doc_ids {
1032                if let Some(chunk_ids) = index.remove(doc_id) {
1033                    for chunk_id in &chunk_ids {
1034                        cache.remove(chunk_id);
1035                    }
1036                    ids_to_delete.extend(chunk_ids);
1037                }
1038            }
1039            ids_to_delete
1040        };
1041
1042        if all_chunk_ids.is_empty() {
1043            return Ok(false);
1044        }
1045
1046        self.vector_store.delete_batch(&all_chunk_ids).await?;
1047        Ok(true)
1048    }
1049
1050    async fn clear(&self) -> Layer3Result<bool> {
1051        self.vector_store.clear().await?;
1052        let mut index = self.doc_index.write();
1053        index.clear();
1054        let mut cache = self.chunk_cache.write();
1055        cache.clear();
1056        Ok(true)
1057    }
1058
1059    async fn count(&self) -> Layer3Result<usize> {
1060        let index = self.doc_index.read();
1061        Ok(index.len())
1062    }
1063}
1064
1065/// Layer1 EmbeddingModel wrapper
1066///
1067/// Wraps a layer1 embedding model to implement layer3's EmbeddingModel trait.
1068pub struct Layer1EmbeddingAdapter {
1069    inner: Box<dyn sh_layer1::EmbeddingModel>,
1070}
1071
1072impl Layer1EmbeddingAdapter {
1073    /// Create a new adapter wrapping a layer1 embedding model
1074    pub fn new(model: Box<dyn sh_layer1::EmbeddingModel>) -> Self {
1075        Self { inner: model }
1076    }
1077}
1078
1079#[async_trait]
1080impl EmbeddingModel for Layer1EmbeddingAdapter {
1081    async fn embed(&self, text: &str) -> Layer3Result<Vec<f32>> {
1082        self.inner.embed(text).await
1083    }
1084
1085    async fn embed_batch(&self, texts: &[String]) -> Layer3Result<Vec<Vec<f32>>> {
1086        self.inner.embed_batch(texts).await
1087    }
1088
1089    fn dimension(&self) -> usize {
1090        self.inner.dimension()
1091    }
1092
1093    fn model_name(&self) -> &str {
1094        self.inner.model_name()
1095    }
1096}
1097
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector_store::InMemoryVectorStore;

    /// Build a mock embedding model for tests: Layer1's `MockEmbeddingModel`
    /// wrapped in `Layer1EmbeddingAdapter`.
    fn create_mock_embedding_model(dimension: usize) -> Layer1EmbeddingAdapter {
        Layer1EmbeddingAdapter::new(Box::new(sh_layer1::MockEmbeddingModel::new(dimension)))
    }

    #[test]
    fn test_document_builder() {
        let doc = Document::new("test content")
            .with_source("test.txt")
            .with_metadata("key", serde_json::json!("value"));
        assert_eq!(doc.source, Some("test.txt".to_string()));
        // The builder must also record the metadata entry it was given.
        assert_eq!(doc.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_fixed_size_chunker() {
        let chunker = FixedSizeChunker::new(100, 20);
        let doc = Document::new("a".repeat(250));
        let chunks = chunker.chunk(&doc);
        assert!(!chunks.is_empty());
    }

    #[tokio::test]
    async fn test_default_retriever_engine_index() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        let doc = Document::new("This is a test document for RAG.").with_source("test.txt");

        let doc_ids = engine.index(vec![doc]).await.unwrap();
        assert_eq!(doc_ids.len(), 1);
        assert_eq!(engine.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_default_retriever_engine_retrieve() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        // Index documents.
        let docs = vec![
            Document::new("Rust is a systems programming language."),
            Document::new("Python is great for data science."),
        ];
        engine.index(docs).await.unwrap();

        // Retrieve.
        let results = engine.retrieve("Rust programming", 5).await.unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_default_retriever_engine_delete() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        let doc = Document::new("Test document");
        let doc_ids = engine.index(vec![doc]).await.unwrap();

        let deleted = engine.delete(&doc_ids).await.unwrap();
        assert!(deleted);
        assert_eq!(engine.count().await.unwrap(), 0);

        // Deleting documents that are no longer indexed reports that nothing was removed.
        let deleted_again = engine.delete(&doc_ids).await.unwrap();
        assert!(!deleted_again);
    }

    #[tokio::test]
    async fn test_mock_embedding_model() {
        let model = create_mock_embedding_model(64);

        let embedding = model.embed("test").await.unwrap();
        assert_eq!(embedding.len(), 64);
        assert_eq!(model.dimension(), 64);
        assert_eq!(model.model_name(), "mock-embedding");

        let embeddings = model
            .embed_batch(&["test1".to_string(), "test2".to_string()])
            .await
            .unwrap();
        assert_eq!(embeddings.len(), 2);
    }

    #[tokio::test]
    async fn test_hybrid_retrieve() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        // Index documents.
        let docs = vec![
            Document::new("Rust is a systems programming language designed for performance."),
            Document::new("Python is widely used for data science and machine learning."),
            Document::new("JavaScript runs in the browser for web development."),
        ];
        engine.index(docs).await.unwrap();

        // Hybrid retrieval.
        let results = engine
            .hybrid_retrieve("Rust programming language", 5)
            .await
            .unwrap();
        assert!(!results.is_empty());
        // The Rust-related document should rank first.
        assert!(results[0].content.contains("Rust"));
    }

    #[tokio::test]
    async fn test_hybrid_retrieve_with_config() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        // Index documents.
        let docs = vec![
            Document::new("Machine learning algorithms use neural networks."),
            Document::new("The database stores data for the application."),
        ];
        engine.index(docs).await.unwrap();

        // Vector-only search.
        let config_vector_only =
            HybridSearchConfig::new().with_weights(HybridWeights::vector_only());
        let results = engine
            .hybrid_retrieve_with_config("neural networks", 5, &config_vector_only)
            .await
            .unwrap();
        assert!(!results.is_empty());

        // Keyword-only search.
        let config_keyword_only =
            HybridSearchConfig::new().with_weights(HybridWeights::keyword_only());
        let results = engine
            .hybrid_retrieve_with_config("machine learning", 5, &config_keyword_only)
            .await
            .unwrap();
        assert!(!results.is_empty());

        // Balanced weights with RRIF re-ranking.
        let config_balanced = HybridSearchConfig::new()
            .with_weights(HybridWeights::balanced())
            .with_rrif(true, 60.0);
        let results = engine
            .hybrid_retrieve_with_config("database", 5, &config_balanced)
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_hybrid_weights() {
        let weights = HybridWeights::default_weights();
        assert_eq!(weights.vector, 0.7);
        assert_eq!(weights.keyword, 0.3);

        let vector_only = HybridWeights::vector_only();
        assert_eq!(vector_only.vector, 1.0);
        assert_eq!(vector_only.keyword, 0.0);

        let balanced = HybridWeights::balanced();
        assert_eq!(balanced.vector, 0.5);
        assert_eq!(balanced.keyword, 0.5);
    }

    #[test]
    fn test_extract_keywords() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);

        // Keyword extraction lowercases terms.
        let keywords = engine.extract_keywords("The Rust programming language");
        assert!(keywords.contains(&"rust".to_string()));
        assert!(keywords.contains(&"programming".to_string()));
        assert!(keywords.contains(&"language".to_string()));
        // Stop words should be filtered out.
        assert!(!keywords.contains(&"the".to_string()));
    }

    #[test]
    fn test_bm25_keyword_score() {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        let chunker = FixedSizeChunker::default();

        let engine = DefaultRetrieverEngine::new(vector_store, embedding_model, chunker);
        let config = HybridSearchConfig::new();

        let keywords = vec!["rust".to_string(), "programming".to_string()];

        // Content that matches the keywords well.
        let score_high = engine.compute_keyword_score(
            &keywords,
            "Rust programming language for systems",
            &config,
        );

        // Content that matches poorly.
        let score_low =
            engine.compute_keyword_score(&keywords, "Python data science frameworks", &config);

        assert!(score_high > score_low);
    }
}