use crate::query::tokenize;
use crate::shard::ShardedColony;
use crate::types::*;
use phago_core::topology::TopologyGraph;
use std::collections::HashMap;
/// Tuning parameters for [`DistributedQueryEngine`].
#[derive(Debug, Clone)]
pub struct DistributedHybridConfig {
    /// Blend weight. NOTE(review): not read anywhere in this module; presumably
    /// mixes two scoring signals in a hybrid ranker — confirm with its users.
    pub alpha: f64,
    /// Upper bound on the number of hits each shard returns from a local query.
    pub max_local_results: usize,
    /// Upper bound on the number of hits returned after merging all shards.
    pub max_results: usize,
    /// NOTE(review): not read anywhere in this module — confirm intended use.
    pub candidate_multiplier: usize,
}

impl Default for DistributedHybridConfig {
    /// `alpha = 0.5`, 30 hits per shard, 10 merged hits, multiplier 3.
    fn default() -> Self {
        let alpha = 0.5;
        let max_local_results = 30;
        let max_results = 10;
        let candidate_multiplier = 3;
        DistributedHybridConfig {
            alpha,
            max_local_results,
            max_results,
            candidate_multiplier,
        }
    }
}
/// Runs keyword queries across one or more [`ShardedColony`] shards.
///
/// Holds only its [`DistributedHybridConfig`]; all shard state is passed in
/// per call.
pub struct DistributedQueryEngine {
    /// Settings used for every query issued through this engine.
    config: DistributedHybridConfig,
}
impl DistributedQueryEngine {
    /// Creates an engine that uses `config` for every query.
    pub fn new(config: DistributedHybridConfig) -> Self {
        Self { config }
    }

    /// Creates an engine using [`DistributedHybridConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(DistributedHybridConfig::default())
    }

    /// Returns the engine's configuration.
    pub fn config(&self) -> &DistributedHybridConfig {
        &self.config
    }

    /// Phase 1 of a distributed query: returns `shard`'s per-term counts for
    /// `terms`, delegating to `ShardedColony::get_term_frequencies`.
    pub fn get_local_term_frequencies(
        &self,
        shard: &ShardedColony,
        terms: &[String],
    ) -> HashMap<String, u64> {
        shard.get_term_frequencies(terms)
    }

    /// Sums per-shard term counts into one global document-frequency map.
    ///
    /// A term present in any input map appears in the result with the sum of
    /// its counts; terms absent from a shard simply contribute nothing.
    pub fn aggregate_global_df(
        &self,
        local_dfs: Vec<HashMap<String, u64>>,
    ) -> HashMap<String, u64> {
        let mut global_df = HashMap::new();
        for (term, count) in local_dfs.into_iter().flatten() {
            *global_df.entry(term).or_insert(0) += count;
        }
        global_df
    }

    /// Phase 2 of a distributed query: scores every node of `shard` against
    /// `request.query_terms` using the shared `request.global_df`.
    ///
    /// For each query term, a node earns `tf * idf`, where `tf` counts label
    /// tokens equal to the term (label lowercased, split on non-alphanumeric
    /// characters, tokens shorter than 3 bytes discarded) and
    /// `idf = ln(total_docs / df) + 1` with `df` taken from `global_df`
    /// (missing terms default to 1, and `df` is floored at 1). Each query term
    /// equal to the entire lowercased label adds a further `10.0`. Only nodes
    /// with a positive score are kept; they are sorted by descending score and
    /// truncated to `request.max_results`.
    pub fn execute_local_query(
        &self,
        shard: &ShardedColony,
        request: &LocalQueryRequest,
    ) -> LocalQueryResult {
        let graph = shard.local().substrate().graph();
        let all_nodes = graph.all_nodes();
        // NOTE(review): `total_docs` counts only this shard's nodes while
        // `request.global_df` is summed over all shards. When a term's global
        // DF exceeds the local node count, `ln(total_docs / df) + 1` becomes
        // negative and can push a genuinely matching node to a score <= 0, so
        // it is dropped. A proper fix needs the global document count in
        // `LocalQueryRequest`.
        let total_docs = all_nodes.len().max(1) as f64;
        let mut scored: Vec<ScoredNode> = Vec::new();
        for nid in &all_nodes {
            if let Some(node) = graph.get_node(nid) {
                let label_lower = node.label.to_lowercase();
                let label_terms: Vec<String> = label_lower
                    .split(|c: char| !c.is_alphanumeric())
                    .filter(|w| w.len() >= 3)
                    .map(|w| w.to_string())
                    .collect();
                let mut score = 0.0;
                for qt in &request.query_terms {
                    let tf = label_terms.iter().filter(|t| *t == qt).count() as f64;
                    if tf > 0.0 {
                        let df = *request.global_df.get(qt).unwrap_or(&1) as f64;
                        let idf = (total_docs / df.max(1.0)).ln() + 1.0;
                        score += tf * idf;
                    }
                }
                // Exact whole-label match bonus; applied after the TF-IDF sum
                // so the floating-point accumulation order is unchanged.
                for qt in &request.query_terms {
                    if label_lower == *qt {
                        score += 10.0;
                    }
                }
                if score > 0.0 {
                    scored.push(ScoredNode {
                        node_id: *nid,
                        label: node.label.clone(),
                        score,
                        shard_id: shard.shard_id(),
                    });
                }
            }
        }
        // `total_cmp` is a total order, so `sort_by` cannot panic on it.
        scored.sort_by(|a, b| b.score.total_cmp(&a.score));
        scored.truncate(request.max_results);
        LocalQueryResult {
            shard_id: shard.shard_id(),
            results: scored,
            term_frequencies: shard.get_term_frequencies(&request.query_terms),
        }
    }

    /// Combines per-shard hits into a single ranked list.
    ///
    /// Hits with a NaN score are discarded, because they cannot be ranked or
    /// normalized. If the best remaining score is finite and positive, every
    /// score is divided by it so the top hit is `1.0`. The list is then sorted
    /// by descending score and truncated to `config.max_results`.
    pub fn merge_results(&self, results: Vec<LocalQueryResult>) -> Vec<ScoredNode> {
        let mut all: Vec<ScoredNode> = results
            .into_iter()
            .flat_map(|r| r.results)
            // Previously a NaN here made `max_by(.. .unwrap())` panic.
            .filter(|s| !s.score.is_nan())
            .collect();
        let max_score = all
            .iter()
            .map(|s| s.score)
            .fold(f64::NEG_INFINITY, f64::max);
        // An infinite max would turn scores into NaN (inf / inf) or 0, so skip
        // normalization in that case.
        if max_score > 0.0 && max_score.is_finite() {
            for node in &mut all {
                node.score /= max_score;
            }
        }
        all.sort_by(|a, b| b.score.total_cmp(&a.score));
        all.truncate(self.config.max_results);
        all
    }

    /// Runs `query_text` against every shard in `shards` and returns the
    /// merged, normalized top hits.
    ///
    /// Returns an empty vector when the query tokenizes to nothing or when
    /// `shards` is empty.
    pub fn distributed_query(
        &self,
        shards: &[&ShardedColony],
        query_text: &str,
    ) -> Vec<ScoredNode> {
        let query_terms = tokenize(query_text);
        if query_terms.is_empty() || shards.is_empty() {
            return Vec::new();
        }
        // Phase 1: gather each shard's term counts and sum them.
        let local_dfs: Vec<HashMap<String, u64>> = shards
            .iter()
            .map(|s| self.get_local_term_frequencies(s, &query_terms))
            .collect();
        let global_df = self.aggregate_global_df(local_dfs);
        // Phase 2: score every shard against the shared global DF table.
        // `query_terms` is not used again, so move it instead of cloning.
        let request = LocalQueryRequest {
            query_terms,
            max_results: self.config.max_local_results,
            global_df,
        };
        let local_results: Vec<LocalQueryResult> = shards
            .iter()
            .map(|s| self.execute_local_query(s, &request))
            .collect();
        self.merge_results(local_results)
    }

    /// Convenience wrapper: runs [`Self::distributed_query`] over a single shard.
    pub fn local_query(&self, shard: &ShardedColony, query_text: &str) -> Vec<ScoredNode> {
        self.distributed_query(&[shard], query_text)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hashing::ConsistentHashRing;
    use phago_core::types::{NodeId, Position};
    use phago_runtime::colony::ColonyConfig;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// Builds shard `id` on a fresh hash ring and ingests one small document.
    fn create_test_shard(id: u32) -> ShardedColony {
        let ring = Arc::new(RwLock::new(ConsistentHashRing::new(3)));
        let mut colony = ShardedColony::new(ShardId::new(id), ColonyConfig::default(), ring);
        colony.local_mut().ingest_document(
            "Test Doc",
            "cell membrane protein transport",
            Position::new(0.0, 0.0),
        );
        colony
    }

    /// Builds a `LocalQueryResult` from `shard` holding exactly one hit.
    fn single_hit(shard: u32, node_id: NodeId, label: &str, score: f64) -> LocalQueryResult {
        LocalQueryResult {
            shard_id: ShardId::new(shard),
            results: vec![ScoredNode {
                node_id,
                label: label.to_string(),
                score,
                shard_id: ShardId::new(shard),
            }],
            term_frequencies: HashMap::new(),
        }
    }

    #[test]
    fn test_tokenize() {
        let tokens = tokenize("The cell membrane");
        for kept in ["cell", "membrane"] {
            assert!(tokens.contains(&kept.to_string()));
        }
        assert!(!tokens.contains(&"the".to_string()));
    }

    #[test]
    fn test_aggregate_global_df() {
        let engine = DistributedQueryEngine::with_defaults();
        let shard_a: HashMap<String, u64> =
            [("cell".to_string(), 5u64), ("membrane".to_string(), 3u64)]
                .into_iter()
                .collect();
        let shard_b: HashMap<String, u64> =
            [("cell".to_string(), 2u64), ("protein".to_string(), 4u64)]
                .into_iter()
                .collect();
        let global_df = engine.aggregate_global_df(vec![shard_a, shard_b]);
        for (term, expected) in [("cell", 7u64), ("membrane", 3), ("protein", 4)] {
            assert_eq!(global_df.get(term), Some(&expected));
        }
    }

    #[test]
    fn test_merge_results() {
        let engine = DistributedQueryEngine::new(DistributedHybridConfig {
            max_results: 10,
            ..Default::default()
        });
        let merged = engine.merge_results(vec![
            single_hit(0, NodeId::from_seed(1), "cell", 1.0),
            single_hit(1, NodeId::from_seed(2), "membrane", 0.5),
        ]);
        assert_eq!(merged.len(), 2);
        let expected_scores = [1.0, 0.5];
        for (hit, want) in merged.iter().zip(expected_scores) {
            assert!((hit.score - want).abs() < 0.001);
        }
    }

    #[test]
    fn test_config_defaults() {
        let defaults = DistributedHybridConfig::default();
        assert_eq!(defaults.alpha, 0.5);
        assert_eq!(defaults.max_local_results, 30);
        assert_eq!(defaults.max_results, 10);
        assert_eq!(defaults.candidate_multiplier, 3);
    }

    #[test]
    fn test_engine_creation() {
        assert_eq!(DistributedQueryEngine::with_defaults().config().max_results, 10);
        let tuned = DistributedQueryEngine::new(DistributedHybridConfig {
            max_results: 20,
            ..Default::default()
        });
        assert_eq!(tuned.config().max_results, 20);
    }

    #[test]
    fn test_empty_query() {
        let engine = DistributedQueryEngine::with_defaults();
        let shard = create_test_shard(0);
        for text in ["", "the a an"] {
            assert!(engine.distributed_query(&[&shard], text).is_empty());
        }
    }

    #[test]
    fn test_empty_shards() {
        let engine = DistributedQueryEngine::with_defaults();
        assert!(engine.distributed_query(&[], "cell membrane").is_empty());
    }

    #[test]
    fn test_local_query() {
        let engine = DistributedQueryEngine::with_defaults();
        let shard = create_test_shard(0);
        let hits = engine.local_query(&shard, "cell membrane");
        assert!(hits.len() <= engine.config().max_results);
    }
}