Skip to main content

phago_distributed/query/
mod.rs

//! Distributed query implementation.
//!
//! This module handles queries that span multiple shards using a
//! two-round scatter/gather TF-IDF approach (four phases in total):
//!
//! 1. **Scatter (Phase 1)**: Get local term frequencies from all shards
//! 2. **Gather (Phase 2)**: Aggregate to global document frequencies
//! 3. **Scatter (Phase 3)**: Compute TF-IDF with global DF
//! 4. **Gather (Phase 4)**: Merge top-k results
10
11mod distributed;
12
13pub use distributed::{DistributedHybridConfig, DistributedQueryEngine};
14
15use crate::types::*;
16
/// Tokenizer intended to mirror phago-rag's tokenizer.
///
/// Lowercases `text`, splits it on whitespace, strips leading and trailing
/// non-alphanumeric characters from every word, and then drops stopwords and
/// words shorter than three bytes.
///
/// Stopwords are checked *after* punctuation is stripped, so inputs such as
/// `"the,"` or `"(and)"` are removed instead of leaking through as the tokens
/// `"the"` / `"and"`. Punctuation inside a word (e.g. `"cell-membrane"`) is
/// left untouched.
///
/// Returns the surviving tokens in input order; repeated words are kept.
///
/// NOTE(review): confirm that phago-rag applies its stopword check after
/// trimming as well, so both sides produce identical terms.
pub fn tokenize(text: &str) -> Vec<String> {
    // Small fixed list: a linear scan over a const slice avoids rebuilding a
    // hash set on every call.
    const STOPWORDS: &[&str] = &[
        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
        "do", "does", "did", "will", "would", "could", "should", "may", "might", "shall", "can",
        "need", "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
        "during", "before", "after", "above", "below", "between", "out", "off", "over", "under",
        "again", "further", "then", "once", "and", "but", "or", "if", "while", "what", "which",
        "who", "this", "that", "these", "those", "it", "its", "how",
    ];
    // Minimum token length, measured in bytes (as `str::len` does).
    const MIN_TOKEN_LEN: usize = 3;

    text.to_lowercase()
        .split_whitespace()
        // Trim first so the length and stopword checks see the final token.
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|word| word.len() >= MIN_TOKEN_LEN && !STOPWORDS.contains(word))
        .map(str::to_string)
        .collect()
}
41
42/// Merge scored results from multiple shards.
43///
44/// Combines results from multiple shards, sorts by score (descending),
45/// and truncates to the specified maximum.
46pub fn merge_results(results: Vec<Vec<ScoredNode>>, max_results: usize) -> Vec<ScoredNode> {
47    let mut all: Vec<ScoredNode> = results.into_iter().flatten().collect();
48    all.sort_by(|a, b| {
49        b.score
50            .partial_cmp(&a.score)
51            .unwrap_or(std::cmp::Ordering::Equal)
52    });
53    all.truncate(max_results);
54    all
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use phago_core::types::NodeId;

    /// Returns `true` when `tokens` contains an entry equal to `word`.
    fn has_token(tokens: &[String], word: &str) -> bool {
        tokens.iter().any(|token| token == word)
    }

    /// Builds a `ScoredNode` from a node seed, label, score and shard number.
    ///
    /// A macro is used so each literal is handed directly to the constructor
    /// or field it feeds, exactly as in a hand-written struct literal.
    macro_rules! scored {
        ($seed:expr, $label:expr, $score:expr, $shard:expr) => {
            ScoredNode {
                node_id: NodeId::from_seed($seed),
                label: $label.to_string(),
                score: $score,
                shard_id: ShardId::new($shard),
            }
        };
    }

    /// Content words survive tokenization; the stopword "the" does not.
    #[test]
    fn test_tokenize_basic() {
        let tokens = tokenize("The cell membrane");
        for expected in ["cell", "membrane"] {
            assert!(has_token(&tokens, expected));
        }
        assert!(!has_token(&tokens, "the"));
    }

    /// An input made only of short words / stopwords yields no tokens.
    #[test]
    fn test_tokenize_filters_short_words() {
        assert!(tokenize("a is the on by").is_empty());
    }

    /// Trailing punctuation is stripped from tokens.
    #[test]
    fn test_tokenize_trims_punctuation() {
        let tokens = tokenize("cell, membrane.");
        for expected in ["cell", "membrane"] {
            assert!(has_token(&tokens, expected));
        }
    }

    /// Tokens are lowercased regardless of the input's casing.
    #[test]
    fn test_tokenize_lowercase() {
        let tokens = tokenize("CELL Membrane");
        for expected in ["cell", "membrane"] {
            assert!(has_token(&tokens, expected));
        }
    }

    /// Merging no shard results produces an empty list.
    #[test]
    fn test_merge_results_empty() {
        let merged = merge_results(Vec::<Vec<ScoredNode>>::new(), 10);
        assert!(merged.is_empty());
    }

    /// Results from different shards come back ordered by descending score.
    #[test]
    fn test_merge_results_sorting() {
        let per_shard = vec![
            vec![scored!(1, "low", 0.3, 0)],
            vec![scored!(2, "high", 0.9, 1)],
        ];

        let merged = merge_results(per_shard, 10);

        let labels: Vec<&str> = merged.iter().map(|n| n.label.as_str()).collect();
        assert_eq!(labels, ["high", "low"]);
    }

    /// Only the `max_results` best-scoring entries are kept.
    #[test]
    fn test_merge_results_truncates() {
        let single_shard = vec![vec![
            scored!(1, "a", 0.9, 0),
            scored!(2, "b", 0.8, 0),
            scored!(3, "c", 0.7, 0),
        ]];

        let merged = merge_results(single_shard, 2);

        let labels: Vec<&str> = merged.iter().map(|n| n.label.as_str()).collect();
        assert_eq!(labels, ["a", "b"]);
    }
}