use std::collections::{HashMap, HashSet};
use crate::core::classifier::QueryIntent;
use crate::core::mmr::cosine_similarity;
use super::super::{CodeIndexer, KG_EXPAND_HOPS};
use super::KG_REFINE_THRESHOLD;
impl CodeIndexer {
/// Expands `seeds` through the symbol graph and returns the neighbouring
/// chunk ids that are not themselves seeds, each with a derived score.
///
/// Every seed is resolved with `symbol_for_chunk`; seeds that do not resolve
/// are skipped. Neighbours come from `neighbors_by_edge`, restricted to the
/// edge kinds chosen by `edge_kinds_for_intent(intent)` and called with
/// `KG_EXPAND_HOPS` as its limit argument. A neighbour's derived score is the
/// seed's score multiplied by the edge kind's `score_multiplier()`; when a
/// neighbour is reached more than once, the largest derived score is kept.
///
/// Returns an empty vector when the graph has no nodes or `seeds` is empty.
/// Otherwise the result is sorted by descending score with ties broken by
/// ascending id. Entries whose score is NaN are placed after all others.
pub(super) async fn kg_expand(
    &self,
    seeds: &[(String, f32)],
    intent: QueryIntent,
) -> Vec<(String, f32)> {
    let graph = self.symbol_graph().await;
    if graph.node_count() == 0 || seeds.is_empty() {
        return Vec::new();
    }

    let edge_kinds = Self::edge_kinds_for_intent(intent);
    // Seeds are already part of the caller's result set, so they must never
    // be re-emitted as expansion candidates.
    let seed_ids: HashSet<&String> = seeds.iter().map(|(id, _)| id).collect();
    let mut best: HashMap<String, f32> = HashMap::new();

    for (seed_id, seed_score) in seeds {
        let Some(symbol) = graph.symbol_for_chunk(seed_id) else {
            continue;
        };
        for (_, neighbour_id, edge_kind) in
            graph.neighbors_by_edge(symbol, &edge_kinds, KG_EXPAND_HOPS)
        {
            if seed_ids.contains(&neighbour_id) {
                continue;
            }
            let derived = seed_score * edge_kind.score_multiplier();
            // Keep the strongest score seen for this neighbour.
            let slot = best.entry(neighbour_id).or_insert(derived);
            if derived > *slot {
                *slot = derived;
            }
        }
    }

    let mut out: Vec<(String, f32)> = best.into_iter().collect();
    out.sort_by(|a, b| {
        // `partial_cmp` alone is not a total order once a NaN score is
        // present, which makes the sort result unspecified (and may panic).
        // Partitioning NaN entries last first keeps the comparator total
        // while leaving the order of all non-NaN scores unchanged.
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
            .then_with(|| a.0.cmp(&b.0))
    });
    out
}
/// Appends knowledge-graph expansion candidates to an already fused result
/// list.
///
/// When `use_kg_first` and `expand_graph` are not both `true`, `fused` is
/// returned unchanged together with an empty id set.
///
/// Otherwise `fused` is used as the seed list for `kg_expand`. If
/// `refine_embedding` is provided, each expanded candidate is re-scored with
/// the cosine similarity between `refine_embedding` and the value returned by
/// `get_embedding` for that id (0.0 when `get_embedding` yields no value).
/// Candidates scoring below `KG_REFINE_THRESHOLD` are dropped, the cosine
/// replaces the graph-derived score, and the survivors are sorted by
/// descending cosine with ties broken by ascending id.
///
/// Returns `fused` followed by the expanded candidates, plus the set of ids
/// that were contributed by the expansion.
pub(super) async fn expand_with_kg(
    &self,
    fused: Vec<(String, f32)>,
    intent: &QueryIntent,
    use_kg_first: bool,
    expand_graph: bool,
    refine_embedding: Option<&[f32]>,
) -> (Vec<(String, f32)>, HashSet<String>) {
    if !(use_kg_first && expand_graph) {
        // Nothing to expand: hand the owned input straight back.
        return (fused, HashSet::new());
    }

    let expanded = self.kg_expand(&fused, intent.clone()).await;
    let expanded = match refine_embedding {
        Some(refine_emb) => {
            let mut scored: Vec<(String, f32)> = expanded
                .into_iter()
                .filter_map(|(id, _kg_score)| {
                    let cos = self
                        .get_embedding(&id)
                        .map(|emb| cosine_similarity(refine_emb, &emb))
                        .unwrap_or(0.0);
                    // A NaN cosine fails this comparison, so NaN never
                    // reaches the sort below.
                    (cos >= KG_REFINE_THRESHOLD).then_some((id, cos))
                })
                .collect();
            scored.sort_by(|a, b| {
                b.1.partial_cmp(&a.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| a.0.cmp(&b.0))
            });
            scored
        }
        None => expanded,
    };

    let kg_ids: HashSet<String> = expanded.iter().map(|(id, _)| id.clone()).collect();
    // The borrow of `fused` ended with the `kg_expand` call, so it can be
    // moved into the result instead of cloned.
    let mut all = fused;
    all.extend(expanded);
    (all, kg_ids)
}
/// Test-only entry point that forwards every argument unchanged to
/// `expand_with_kg` and returns its result.
///
/// Compiled only under `cfg(test)`; it is declared `pub(crate)` whereas
/// `expand_with_kg` itself is `pub(super)`.
#[cfg(test)]
pub(crate) async fn expand_with_kg_for_test(
&self,
fused: Vec<(String, f32)>,
intent: &QueryIntent,
use_kg_first: bool,
expand_graph: bool,
refine_embedding: Option<&[f32]>,
) -> (Vec<(String, f32)>, HashSet<String>) {
self.expand_with_kg(fused, intent, use_kg_first, expand_graph, refine_embedding)
.await
}
}