use crate::{GraphRAGResult, ScoredEntity, Triple};
use std::collections::HashSet;
/// Tuning knobs controlling how a `SubgraphExtractor` selects and ranks triples.
#[derive(Debug, Clone)]
pub struct SubgraphConfig {
    /// Upper bound on the number of triples a single extraction returns.
    pub max_triples: usize,
    /// Keep triples whose subject and object are both seed entities.
    pub include_internal_edges: bool,
    /// Keep triples where exactly one endpoint is a seed entity.
    pub include_external_edges: bool,
    /// Derive triple scores from the seeds' own scores instead of fixed per-class constants.
    pub score_weighted: bool,
}

impl Default for SubgraphConfig {
    /// Keeps up to 100 triples, includes both internal and external edges, and weights
    /// scores by seed relevance.
    fn default() -> Self {
        let max_triples = 100;
        Self {
            score_weighted: true,
            include_external_edges: true,
            include_internal_edges: true,
            max_triples,
        }
    }
}
/// Selects a relevant subgraph of triples around a set of scored seed entities.
pub struct SubgraphExtractor {
    /// Limits and scoring options applied by every extraction.
    config: SubgraphConfig,
}

impl Default for SubgraphExtractor {
    /// Builds an extractor configured with `SubgraphConfig::default()`.
    fn default() -> Self {
        Self::new(Default::default())
    }
}
impl SubgraphExtractor {
pub fn new(config: SubgraphConfig) -> Self {
Self { config }
}
pub fn extract(
&self,
seeds: &[ScoredEntity],
expanded_triples: &[Triple],
) -> GraphRAGResult<Vec<Triple>> {
let seed_uris: HashSet<String> = seeds.iter().map(|s| s.uri.clone()).collect();
let mut scored_triples: Vec<(f64, &Triple)> = expanded_triples
.iter()
.map(|triple| {
let score = self.score_triple(triple, seeds, &seed_uris);
(score, triple)
})
.filter(|(score, _)| *score > 0.0)
.collect();
scored_triples.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let result: Vec<Triple> = scored_triples
.into_iter()
.take(self.config.max_triples)
.map(|(_, t)| t.clone())
.collect();
Ok(result)
}
fn score_triple(
&self,
triple: &Triple,
seeds: &[ScoredEntity],
seed_uris: &HashSet<String>,
) -> f64 {
let subject_is_seed = seed_uris.contains(&triple.subject);
let object_is_seed = seed_uris.contains(&triple.object);
if subject_is_seed && object_is_seed {
if !self.config.include_internal_edges {
return 0.0;
}
if self.config.score_weighted {
let subj_score = seeds
.iter()
.find(|s| s.uri == triple.subject)
.map(|s| s.score)
.unwrap_or(0.5);
let obj_score = seeds
.iter()
.find(|s| s.uri == triple.object)
.map(|s| s.score)
.unwrap_or(0.5);
return (subj_score + obj_score) / 2.0 * 1.5; }
return 1.5;
}
if subject_is_seed || object_is_seed {
if !self.config.include_external_edges {
return 0.0;
}
if self.config.score_weighted {
let seed_uri = if subject_is_seed {
&triple.subject
} else {
&triple.object
};
return seeds
.iter()
.find(|s| &s.uri == seed_uri)
.map(|s| s.score)
.unwrap_or(0.5);
}
return 1.0;
}
0.1
}
pub fn extract_steiner(
&self,
seeds: &[ScoredEntity],
all_triples: &[Triple],
) -> GraphRAGResult<Vec<Triple>> {
use std::collections::HashMap;
let mut adjacency: HashMap<String, Vec<(String, Triple)>> = HashMap::new();
for triple in all_triples {
adjacency
.entry(triple.subject.clone())
.or_default()
.push((triple.object.clone(), triple.clone()));
adjacency
.entry(triple.object.clone())
.or_default()
.push((triple.subject.clone(), triple.clone()));
}
let seed_uris: Vec<String> = seeds.iter().map(|s| s.uri.clone()).collect();
let mut result_triples: HashSet<Triple> = HashSet::new();
for i in 0..seed_uris.len() {
for j in (i + 1)..seed_uris.len() {
if let Some(path) = self.bfs_path(&seed_uris[i], &seed_uris[j], &adjacency) {
for triple in path {
result_triples.insert(triple);
}
}
}
}
Ok(result_triples
.into_iter()
.take(self.config.max_triples)
.collect())
}
fn bfs_path(
&self,
start: &str,
end: &str,
adjacency: &std::collections::HashMap<String, Vec<(String, Triple)>>,
) -> Option<Vec<Triple>> {
use std::collections::VecDeque;
if start == end {
return Some(vec![]);
}
let mut visited: HashSet<String> = HashSet::new();
let mut queue: VecDeque<(String, Vec<Triple>)> = VecDeque::new();
queue.push_back((start.to_string(), vec![]));
visited.insert(start.to_string());
while let Some((current, path)) = queue.pop_front() {
if let Some(neighbors) = adjacency.get(¤t) {
for (neighbor, triple) in neighbors {
if neighbor == end {
let mut result = path.clone();
result.push(triple.clone());
return Some(result);
}
if !visited.contains(neighbor) && path.len() < 5 {
visited.insert(neighbor.clone());
let mut new_path = path.clone();
new_path.push(triple.clone());
queue.push_back((neighbor.clone(), new_path));
}
}
}
}
None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a vector-sourced seed entity with empty metadata.
    fn seed(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: crate::ScoreSource::Vector,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_subgraph_extraction() {
        let seeds = vec![seed("http://a", 0.9), seed("http://b", 0.8)];
        let triples = vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://a", "http://rel", "http://c"),
            Triple::new("http://x", "http://rel", "http://y"),
        ];

        let result = SubgraphExtractor::default()
            .extract(&seeds, &triples)
            .expect("should succeed");

        // The seed-to-seed edge must survive scoring and truncation.
        let has_internal_edge = result
            .iter()
            .any(|t| t.subject == "http://a" && t.object == "http://b");
        assert!(!result.is_empty());
        assert!(has_internal_edge);
    }

    #[test]
    fn test_steiner_extraction() {
        let seeds = vec![seed("http://a", 0.9), seed("http://c", 0.8)];
        // a -- b -- c: connecting the two seeds requires both triples.
        let triples = vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
        ];

        let result = SubgraphExtractor::default()
            .extract_steiner(&seeds, &triples)
            .expect("should succeed");

        assert_eq!(result.len(), 2);
    }
}