Skip to main content

phago_distributed/query/
distributed.rs

1//! Distributed hybrid query engine.
2//!
3//! Implements two-phase TF-IDF for distributed queries:
4//!
5//! 1. **Scatter (Phase 1)**: Collect local term frequencies from each shard
6//! 2. **Gather (Phase 2)**: Aggregate into global document frequencies
7//! 3. **Scatter (Phase 3)**: Execute local queries with global DF for accurate IDF
8//! 4. **Gather (Phase 4)**: Merge and rank top-k results from all shards
9
10use crate::query::tokenize;
11use crate::shard::ShardedColony;
12use crate::types::*;
13use phago_core::topology::TopologyGraph;
14use std::collections::HashMap;
15
/// Tunable parameters for a distributed hybrid query.
///
/// [`Default`] yields `alpha = 0.5`, `max_local_results = 30`,
/// `max_results = 10` and `candidate_multiplier = 3`.
#[derive(Clone, Debug)]
pub struct DistributedHybridConfig {
    /// Weight of the TF-IDF component, intended to lie in `0.0..=1.0`.
    ///
    /// NOTE(review): no method visible in this file reads this field.
    pub alpha: f64,
    /// Upper bound on the number of results each shard may return.
    pub max_local_results: usize,
    /// Upper bound on the number of results in the final, merged ranking.
    pub max_results: usize,
    /// Candidate multiplier for TF-IDF.
    ///
    /// NOTE(review): no method visible in this file reads this field.
    pub candidate_multiplier: usize,
}

impl Default for DistributedHybridConfig {
    /// Returns the stock configuration: an even `alpha` of `0.5`, up to 30
    /// results per shard, 10 final results, and a candidate multiplier of 3.
    fn default() -> DistributedHybridConfig {
        DistributedHybridConfig {
            candidate_multiplier: 3,
            max_results: 10,
            max_local_results: 30,
            alpha: 0.5,
        }
    }
}
39
40/// Distributed query engine implementing two-phase TF-IDF.
41///
42/// This engine executes queries across multiple shards by:
43/// 1. First collecting term frequencies from all shards
44/// 2. Computing global document frequencies
45/// 3. Re-executing queries with the global DF for accurate scoring
46/// 4. Merging and normalizing results across shards
47pub struct DistributedQueryEngine {
48    config: DistributedHybridConfig,
49}
50
51impl DistributedQueryEngine {
52    /// Create a new distributed query engine with the given configuration.
53    pub fn new(config: DistributedHybridConfig) -> Self {
54        Self { config }
55    }
56
57    /// Create a query engine with default configuration.
58    pub fn with_defaults() -> Self {
59        Self::new(DistributedHybridConfig::default())
60    }
61
62    /// Get the configuration.
63    pub fn config(&self) -> &DistributedHybridConfig {
64        &self.config
65    }
66
67    /// Phase 1: Get term frequencies from a shard.
68    ///
69    /// Collects how many documents in this shard contain each query term.
70    /// This is used to compute local document frequencies.
71    pub fn get_local_term_frequencies(
72        &self,
73        shard: &ShardedColony,
74        terms: &[String],
75    ) -> HashMap<String, u64> {
76        shard.get_term_frequencies(terms)
77    }
78
79    /// Phase 2: Aggregate global document frequencies.
80    ///
81    /// Combines local document frequencies from all shards to compute
82    /// the global DF for each term across the entire distributed graph.
83    pub fn aggregate_global_df(
84        &self,
85        local_dfs: Vec<HashMap<String, u64>>,
86    ) -> HashMap<String, u64> {
87        let mut global_df = HashMap::new();
88        for local in local_dfs {
89            for (term, count) in local {
90                *global_df.entry(term).or_insert(0) += count;
91            }
92        }
93        global_df
94    }
95
96    /// Phase 3: Execute local query with global DF.
97    ///
98    /// Computes TF-IDF scores for nodes in a single shard using the
99    /// global document frequencies for accurate IDF computation.
100    pub fn execute_local_query(
101        &self,
102        shard: &ShardedColony,
103        request: &LocalQueryRequest,
104    ) -> LocalQueryResult {
105        let graph = shard.local().substrate().graph();
106        let all_nodes = graph.all_nodes();
107        let total_docs = all_nodes.len().max(1) as f64;
108
109        // Compute TF-IDF for each node
110        let mut scored: Vec<ScoredNode> = Vec::new();
111
112        for nid in &all_nodes {
113            if let Some(node) = graph.get_node(nid) {
114                let label_lower = node.label.to_lowercase();
115                let label_terms: Vec<String> = label_lower
116                    .split(|c: char| !c.is_alphanumeric())
117                    .filter(|w| w.len() >= 3)
118                    .map(|w| w.to_string())
119                    .collect();
120
121                let mut score = 0.0;
122                for qt in &request.query_terms {
123                    let tf = label_terms.iter().filter(|t| *t == qt).count() as f64;
124                    if tf > 0.0 {
125                        // Use global DF if available, otherwise assume 1
126                        let df = *request.global_df.get(qt).unwrap_or(&1) as f64;
127                        let idf = (total_docs / df.max(1.0)).ln() + 1.0;
128                        score += tf * idf;
129                    }
130                }
131
132                // Exact match boost - if the entire label matches a query term
133                for qt in &request.query_terms {
134                    if label_lower == *qt {
135                        score += 10.0;
136                    }
137                }
138
139                if score > 0.0 {
140                    scored.push(ScoredNode {
141                        node_id: *nid,
142                        label: node.label.clone(),
143                        score,
144                        shard_id: shard.shard_id(),
145                    });
146                }
147            }
148        }
149
150        // Sort by score descending and truncate
151        scored.sort_by(|a, b| {
152            b.score
153                .partial_cmp(&a.score)
154                .unwrap_or(std::cmp::Ordering::Equal)
155        });
156        scored.truncate(request.max_results);
157
158        LocalQueryResult {
159            shard_id: shard.shard_id(),
160            results: scored,
161            term_frequencies: shard.get_term_frequencies(&request.query_terms),
162        }
163    }
164
165    /// Phase 4: Merge results from all shards.
166    ///
167    /// Combines results from multiple shards, normalizes scores across
168    /// shards, sorts by score, and returns the top-k results.
169    pub fn merge_results(&self, results: Vec<LocalQueryResult>) -> Vec<ScoredNode> {
170        let mut all: Vec<ScoredNode> = results.into_iter().flat_map(|r| r.results).collect();
171
172        // Normalize scores across shards
173        if let Some(max_score) = all
174            .iter()
175            .map(|s| s.score)
176            .max_by(|a, b| a.partial_cmp(b).unwrap())
177        {
178            if max_score > 0.0 {
179                for node in &mut all {
180                    node.score /= max_score;
181                }
182            }
183        }
184
185        // Sort and truncate to final result count
186        all.sort_by(|a, b| {
187            b.score
188                .partial_cmp(&a.score)
189                .unwrap_or(std::cmp::Ordering::Equal)
190        });
191        all.truncate(self.config.max_results);
192        all
193    }
194
195    /// Execute a full distributed query across multiple shards.
196    ///
197    /// This is the main entry point for distributed queries. It coordinates
198    /// all four phases of the two-phase TF-IDF algorithm:
199    ///
200    /// 1. Collects local term frequencies from each shard
201    /// 2. Aggregates them into global document frequencies
202    /// 3. Executes local queries on each shard with global DF
203    /// 4. Merges and normalizes results
204    ///
205    /// # Arguments
206    ///
207    /// * `shards` - Slice of shard references to query
208    /// * `query_text` - The raw query text to search for
209    ///
210    /// # Returns
211    ///
212    /// A vector of scored nodes, sorted by relevance (highest first).
213    pub fn distributed_query(
214        &self,
215        shards: &[&ShardedColony],
216        query_text: &str,
217    ) -> Vec<ScoredNode> {
218        let query_terms = tokenize(query_text);
219        if query_terms.is_empty() || shards.is_empty() {
220            return Vec::new();
221        }
222
223        // Phase 1: Get local term frequencies
224        let local_dfs: Vec<HashMap<String, u64>> = shards
225            .iter()
226            .map(|s| self.get_local_term_frequencies(s, &query_terms))
227            .collect();
228
229        // Phase 2: Aggregate global DF
230        let global_df = self.aggregate_global_df(local_dfs);
231
232        // Phase 3: Execute local queries with global DF
233        let request = LocalQueryRequest {
234            query_terms: query_terms.clone(),
235            max_results: self.config.max_local_results,
236            global_df,
237        };
238
239        let local_results: Vec<LocalQueryResult> = shards
240            .iter()
241            .map(|s| self.execute_local_query(s, &request))
242            .collect();
243
244        // Phase 4: Merge results
245        self.merge_results(local_results)
246    }
247
248    /// Execute a query on a single shard (for non-distributed use).
249    ///
250    /// This is useful for testing or when the data resides in a single shard.
251    pub fn local_query(&self, shard: &ShardedColony, query_text: &str) -> Vec<ScoredNode> {
252        self.distributed_query(&[shard], query_text)
253    }
254}
255
256#[cfg(test)]
257mod tests {
258    use super::*;
259    use crate::hashing::ConsistentHashRing;
260    use phago_core::types::Position;
261    use phago_runtime::colony::ColonyConfig;
262    use std::sync::Arc;
263    use tokio::sync::RwLock;
264
265    fn create_test_ring() -> Arc<RwLock<ConsistentHashRing>> {
266        Arc::new(RwLock::new(ConsistentHashRing::new(3)))
267    }
268
269    fn create_test_shard(id: u32) -> ShardedColony {
270        let ring = create_test_ring();
271        let mut shard = ShardedColony::new(ShardId::new(id), ColonyConfig::default(), ring);
272
273        // Add some test data directly to the colony
274        shard.local_mut().ingest_document(
275            "Test Doc",
276            "cell membrane protein transport",
277            Position::new(0.0, 0.0),
278        );
279
280        shard
281    }
282
283    #[test]
284    fn test_tokenize() {
285        let tokens = tokenize("The cell membrane");
286        assert!(tokens.contains(&"cell".to_string()));
287        assert!(tokens.contains(&"membrane".to_string()));
288        assert!(!tokens.contains(&"the".to_string())); // Stopword
289    }
290
291    #[test]
292    fn test_aggregate_global_df() {
293        let engine = DistributedQueryEngine::with_defaults();
294
295        let local_dfs = vec![
296            [("cell".to_string(), 5u64), ("membrane".to_string(), 3u64)]
297                .into_iter()
298                .collect(),
299            [("cell".to_string(), 2u64), ("protein".to_string(), 4u64)]
300                .into_iter()
301                .collect(),
302        ];
303
304        let global_df = engine.aggregate_global_df(local_dfs);
305
306        assert_eq!(global_df.get("cell"), Some(&7));
307        assert_eq!(global_df.get("membrane"), Some(&3));
308        assert_eq!(global_df.get("protein"), Some(&4));
309    }
310
311    #[test]
312    fn test_merge_results() {
313        let engine = DistributedQueryEngine::new(DistributedHybridConfig {
314            max_results: 10,
315            ..Default::default()
316        });
317
318        let results = vec![
319            LocalQueryResult {
320                shard_id: ShardId::new(0),
321                results: vec![ScoredNode {
322                    node_id: phago_core::types::NodeId::from_seed(1),
323                    label: "cell".to_string(),
324                    score: 1.0,
325                    shard_id: ShardId::new(0),
326                }],
327                term_frequencies: HashMap::new(),
328            },
329            LocalQueryResult {
330                shard_id: ShardId::new(1),
331                results: vec![ScoredNode {
332                    node_id: phago_core::types::NodeId::from_seed(2),
333                    label: "membrane".to_string(),
334                    score: 0.5,
335                    shard_id: ShardId::new(1),
336                }],
337                term_frequencies: HashMap::new(),
338            },
339        ];
340
341        let merged = engine.merge_results(results);
342        assert_eq!(merged.len(), 2);
343        // After normalization, highest score should be 1.0
344        assert!((merged[0].score - 1.0).abs() < 0.001);
345        // Second should be 0.5 / 1.0 = 0.5
346        assert!((merged[1].score - 0.5).abs() < 0.001);
347    }
348
349    #[test]
350    fn test_config_defaults() {
351        let config = DistributedHybridConfig::default();
352        assert_eq!(config.alpha, 0.5);
353        assert_eq!(config.max_local_results, 30);
354        assert_eq!(config.max_results, 10);
355        assert_eq!(config.candidate_multiplier, 3);
356    }
357
358    #[test]
359    fn test_engine_creation() {
360        let engine = DistributedQueryEngine::with_defaults();
361        assert_eq!(engine.config().max_results, 10);
362
363        let custom_engine = DistributedQueryEngine::new(DistributedHybridConfig {
364            max_results: 20,
365            ..Default::default()
366        });
367        assert_eq!(custom_engine.config().max_results, 20);
368    }
369
370    #[test]
371    fn test_empty_query() {
372        let engine = DistributedQueryEngine::with_defaults();
373        let shard = create_test_shard(0);
374
375        // Empty query text should return empty results
376        let results = engine.distributed_query(&[&shard], "");
377        assert!(results.is_empty());
378
379        // Query with only stopwords should also return empty
380        let results = engine.distributed_query(&[&shard], "the a an");
381        assert!(results.is_empty());
382    }
383
384    #[test]
385    fn test_empty_shards() {
386        let engine = DistributedQueryEngine::with_defaults();
387
388        // No shards should return empty results
389        let results = engine.distributed_query(&[], "cell membrane");
390        assert!(results.is_empty());
391    }
392
393    #[test]
394    fn test_local_query() {
395        let engine = DistributedQueryEngine::with_defaults();
396        let shard = create_test_shard(0);
397
398        // Run some ticks to process the document
399        // (Note: This test may not find results if the document hasn't been
400        // processed into graph nodes yet - depends on colony behavior)
401        let results = engine.local_query(&shard, "cell membrane");
402
403        // Results may be empty if document hasn't been digested into graph nodes
404        // This is expected behavior - the test validates the query path works
405        assert!(results.len() <= engine.config().max_results);
406    }
407}