1use crate::query::tokenize;
11use crate::shard::ShardedColony;
12use crate::types::*;
13use phago_core::topology::TopologyGraph;
14use std::collections::HashMap;
15
/// Tunable parameters for [`DistributedQueryEngine`].
///
/// `PartialEq` is derived so that configurations can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct DistributedHybridConfig {
    /// Blending weight for hybrid scoring.
    ///
    /// NOTE(review): no code visible in this file reads this field; its exact
    /// meaning (presumably the TF-IDF vs. graph-score balance) should be
    /// confirmed against the consumer.
    pub alpha: f64,
    /// Maximum number of results each shard is asked to return; passed as
    /// `max_results` of the per-shard `LocalQueryRequest`.
    pub max_local_results: usize,
    /// Maximum number of results kept after merging all shard results.
    pub max_results: usize,
    /// Candidate over-fetch factor.
    ///
    /// NOTE(review): not read by any code visible in this file - confirm usage.
    pub candidate_multiplier: usize,
}
28
29impl Default for DistributedHybridConfig {
30 fn default() -> Self {
31 Self {
32 alpha: 0.5,
33 max_local_results: 30,
34 max_results: 10,
35 candidate_multiplier: 3,
36 }
37 }
38}
39
40pub struct DistributedQueryEngine {
48 config: DistributedHybridConfig,
49}
50
51impl DistributedQueryEngine {
52 pub fn new(config: DistributedHybridConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn with_defaults() -> Self {
59 Self::new(DistributedHybridConfig::default())
60 }
61
62 pub fn config(&self) -> &DistributedHybridConfig {
64 &self.config
65 }
66
67 pub fn get_local_term_frequencies(
72 &self,
73 shard: &ShardedColony,
74 terms: &[String],
75 ) -> HashMap<String, u64> {
76 shard.get_term_frequencies(terms)
77 }
78
79 pub fn aggregate_global_df(
84 &self,
85 local_dfs: Vec<HashMap<String, u64>>,
86 ) -> HashMap<String, u64> {
87 let mut global_df = HashMap::new();
88 for local in local_dfs {
89 for (term, count) in local {
90 *global_df.entry(term).or_insert(0) += count;
91 }
92 }
93 global_df
94 }
95
96 pub fn execute_local_query(
101 &self,
102 shard: &ShardedColony,
103 request: &LocalQueryRequest,
104 ) -> LocalQueryResult {
105 let graph = shard.local().substrate().graph();
106 let all_nodes = graph.all_nodes();
107 let total_docs = all_nodes.len().max(1) as f64;
108
109 let mut scored: Vec<ScoredNode> = Vec::new();
111
112 for nid in &all_nodes {
113 if let Some(node) = graph.get_node(nid) {
114 let label_lower = node.label.to_lowercase();
115 let label_terms: Vec<String> = label_lower
116 .split(|c: char| !c.is_alphanumeric())
117 .filter(|w| w.len() >= 3)
118 .map(|w| w.to_string())
119 .collect();
120
121 let mut score = 0.0;
122 for qt in &request.query_terms {
123 let tf = label_terms.iter().filter(|t| *t == qt).count() as f64;
124 if tf > 0.0 {
125 let df = *request.global_df.get(qt).unwrap_or(&1) as f64;
127 let idf = (total_docs / df.max(1.0)).ln() + 1.0;
128 score += tf * idf;
129 }
130 }
131
132 for qt in &request.query_terms {
134 if label_lower == *qt {
135 score += 10.0;
136 }
137 }
138
139 if score > 0.0 {
140 scored.push(ScoredNode {
141 node_id: *nid,
142 label: node.label.clone(),
143 score,
144 shard_id: shard.shard_id(),
145 });
146 }
147 }
148 }
149
150 scored.sort_by(|a, b| {
152 b.score
153 .partial_cmp(&a.score)
154 .unwrap_or(std::cmp::Ordering::Equal)
155 });
156 scored.truncate(request.max_results);
157
158 LocalQueryResult {
159 shard_id: shard.shard_id(),
160 results: scored,
161 term_frequencies: shard.get_term_frequencies(&request.query_terms),
162 }
163 }
164
165 pub fn merge_results(&self, results: Vec<LocalQueryResult>) -> Vec<ScoredNode> {
170 let mut all: Vec<ScoredNode> = results.into_iter().flat_map(|r| r.results).collect();
171
172 if let Some(max_score) = all
174 .iter()
175 .map(|s| s.score)
176 .max_by(|a, b| a.partial_cmp(b).unwrap())
177 {
178 if max_score > 0.0 {
179 for node in &mut all {
180 node.score /= max_score;
181 }
182 }
183 }
184
185 all.sort_by(|a, b| {
187 b.score
188 .partial_cmp(&a.score)
189 .unwrap_or(std::cmp::Ordering::Equal)
190 });
191 all.truncate(self.config.max_results);
192 all
193 }
194
195 pub fn distributed_query(
214 &self,
215 shards: &[&ShardedColony],
216 query_text: &str,
217 ) -> Vec<ScoredNode> {
218 let query_terms = tokenize(query_text);
219 if query_terms.is_empty() || shards.is_empty() {
220 return Vec::new();
221 }
222
223 let local_dfs: Vec<HashMap<String, u64>> = shards
225 .iter()
226 .map(|s| self.get_local_term_frequencies(s, &query_terms))
227 .collect();
228
229 let global_df = self.aggregate_global_df(local_dfs);
231
232 let request = LocalQueryRequest {
234 query_terms: query_terms.clone(),
235 max_results: self.config.max_local_results,
236 global_df,
237 };
238
239 let local_results: Vec<LocalQueryResult> = shards
240 .iter()
241 .map(|s| self.execute_local_query(s, &request))
242 .collect();
243
244 self.merge_results(local_results)
246 }
247
248 pub fn local_query(&self, shard: &ShardedColony, query_text: &str) -> Vec<ScoredNode> {
252 self.distributed_query(&[shard], query_text)
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hashing::ConsistentHashRing;
    use phago_core::types::Position;
    use phago_runtime::colony::ColonyConfig;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// Builds a shared hash ring via `ConsistentHashRing::new(3)`.
    fn create_test_ring() -> Arc<RwLock<ConsistentHashRing>> {
        let ring = ConsistentHashRing::new(3);
        Arc::new(RwLock::new(ring))
    }

    /// Builds a shard with the given id that has ingested one small document.
    fn create_test_shard(id: u32) -> ShardedColony {
        let mut shard = ShardedColony::new(
            ShardId::new(id),
            ColonyConfig::default(),
            create_test_ring(),
        );
        shard.local_mut().ingest_document(
            "Test Doc",
            "cell membrane protein transport",
            Position::new(0.0, 0.0),
        );
        shard
    }

    /// Builds a term-count map from string/count pairs.
    fn df_map(pairs: &[(&str, u64)]) -> HashMap<String, u64> {
        pairs
            .iter()
            .map(|(term, count)| (term.to_string(), *count))
            .collect()
    }

    /// Builds a shard result holding exactly one scored node.
    fn single_hit(
        shard: u32,
        node_id: phago_core::types::NodeId,
        label: &str,
        score: f64,
    ) -> LocalQueryResult {
        LocalQueryResult {
            shard_id: ShardId::new(shard),
            results: vec![ScoredNode {
                node_id,
                label: label.to_string(),
                score,
                shard_id: ShardId::new(shard),
            }],
            term_frequencies: HashMap::new(),
        }
    }

    #[test]
    fn test_tokenize() {
        let tokens = tokenize("The cell membrane");

        for kept in &["cell", "membrane"] {
            assert!(tokens.contains(&kept.to_string()));
        }
        // "the" must not survive tokenization.
        assert!(!tokens.contains(&"the".to_string()));
    }

    #[test]
    fn test_aggregate_global_df() {
        let engine = DistributedQueryEngine::with_defaults();

        let global_df = engine.aggregate_global_df(vec![
            df_map(&[("cell", 5), ("membrane", 3)]),
            df_map(&[("cell", 2), ("protein", 4)]),
        ]);

        // "cell" occurs in both maps, so its counts are summed.
        assert_eq!(global_df.get("cell"), Some(&7));
        assert_eq!(global_df.get("membrane"), Some(&3));
        assert_eq!(global_df.get("protein"), Some(&4));
    }

    #[test]
    fn test_merge_results() {
        let engine = DistributedQueryEngine::new(DistributedHybridConfig {
            max_results: 10,
            ..Default::default()
        });

        let merged = engine.merge_results(vec![
            single_hit(0, phago_core::types::NodeId::from_seed(1), "cell", 1.0),
            single_hit(1, phago_core::types::NodeId::from_seed(2), "membrane", 0.5),
        ]);

        assert_eq!(merged.len(), 2);
        // The maximum is already 1.0, so normalization leaves scores unchanged.
        assert!((merged[0].score - 1.0).abs() < 0.001);
        assert!((merged[1].score - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_config_defaults() {
        let DistributedHybridConfig {
            alpha,
            max_local_results,
            max_results,
            candidate_multiplier,
        } = DistributedHybridConfig::default();

        assert_eq!(alpha, 0.5);
        assert_eq!(max_local_results, 30);
        assert_eq!(max_results, 10);
        assert_eq!(candidate_multiplier, 3);
    }

    #[test]
    fn test_engine_creation() {
        let default_engine = DistributedQueryEngine::with_defaults();
        assert_eq!(default_engine.config().max_results, 10);

        let custom_config = DistributedHybridConfig {
            max_results: 20,
            ..Default::default()
        };
        let custom_engine = DistributedQueryEngine::new(custom_config);
        assert_eq!(custom_engine.config().max_results, 20);
    }

    #[test]
    fn test_empty_query() {
        let engine = DistributedQueryEngine::with_defaults();
        let shard = create_test_shard(0);

        // An empty string produces no results.
        let results = engine.distributed_query(&[&shard], "");
        assert!(results.is_empty());

        // Neither does this stop-word-only query.
        let results = engine.distributed_query(&[&shard], "the a an");
        assert!(results.is_empty());
    }

    #[test]
    fn test_empty_shards() {
        let engine = DistributedQueryEngine::with_defaults();

        let results = engine.distributed_query(&[], "cell membrane");

        assert!(results.is_empty());
    }

    #[test]
    fn test_local_query() {
        let engine = DistributedQueryEngine::with_defaults();
        let shard = create_test_shard(0);

        let results = engine.local_query(&shard, "cell membrane");

        // Only the result-count cap is checked here.
        assert!(results.len() <= engine.config().max_results);
    }
}