// oxirs_graphrag/graph/subgraph.rs
use crate::{GraphRAGResult, ScoredEntity, Triple};
use std::collections::HashSet;

/// Tuning knobs for [`SubgraphExtractor`].
///
/// The edge-class flags and `score_weighted` only influence
/// [`SubgraphExtractor::extract`]; [`SubgraphExtractor::extract_steiner`]
/// consults `max_triples` alone.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubgraphConfig {
    /// Upper bound on the number of triples returned by an extraction.
    pub max_triples: usize,
    /// Keep triples whose subject *and* object are both seed entities.
    pub include_internal_edges: bool,
    /// Keep triples where exactly one endpoint is a seed entity.
    pub include_external_edges: bool,
    /// Derive triple scores from the seeds' scores instead of fixed weights.
    pub score_weighted: bool,
}

impl Default for SubgraphConfig {
    /// At most 100 triples, both edge classes enabled, score weighting on.
    fn default() -> Self {
        Self {
            max_triples: 100,
            include_internal_edges: true,
            include_external_edges: true,
            score_weighted: true,
        }
    }
}

30pub struct SubgraphExtractor {
32 config: SubgraphConfig,
33}
34
35impl Default for SubgraphExtractor {
36 fn default() -> Self {
37 Self::new(SubgraphConfig::default())
38 }
39}
40
41impl SubgraphExtractor {
42 pub fn new(config: SubgraphConfig) -> Self {
43 Self { config }
44 }
45
46 pub fn extract(
48 &self,
49 seeds: &[ScoredEntity],
50 expanded_triples: &[Triple],
51 ) -> GraphRAGResult<Vec<Triple>> {
52 let seed_uris: HashSet<String> = seeds.iter().map(|s| s.uri.clone()).collect();
53
54 let mut scored_triples: Vec<(f64, &Triple)> = expanded_triples
56 .iter()
57 .map(|triple| {
58 let score = self.score_triple(triple, seeds, &seed_uris);
59 (score, triple)
60 })
61 .filter(|(score, _)| *score > 0.0)
62 .collect();
63
64 scored_triples.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
66
67 let result: Vec<Triple> = scored_triples
69 .into_iter()
70 .take(self.config.max_triples)
71 .map(|(_, t)| t.clone())
72 .collect();
73
74 Ok(result)
75 }
76
77 fn score_triple(
79 &self,
80 triple: &Triple,
81 seeds: &[ScoredEntity],
82 seed_uris: &HashSet<String>,
83 ) -> f64 {
84 let subject_is_seed = seed_uris.contains(&triple.subject);
85 let object_is_seed = seed_uris.contains(&triple.object);
86
87 if subject_is_seed && object_is_seed {
89 if !self.config.include_internal_edges {
90 return 0.0;
91 }
92
93 if self.config.score_weighted {
94 let subj_score = seeds
96 .iter()
97 .find(|s| s.uri == triple.subject)
98 .map(|s| s.score)
99 .unwrap_or(0.5);
100 let obj_score = seeds
101 .iter()
102 .find(|s| s.uri == triple.object)
103 .map(|s| s.score)
104 .unwrap_or(0.5);
105 return (subj_score + obj_score) / 2.0 * 1.5; }
107 return 1.5;
108 }
109
110 if subject_is_seed || object_is_seed {
112 if !self.config.include_external_edges {
113 return 0.0;
114 }
115
116 if self.config.score_weighted {
117 let seed_uri = if subject_is_seed {
118 &triple.subject
119 } else {
120 &triple.object
121 };
122 return seeds
123 .iter()
124 .find(|s| &s.uri == seed_uri)
125 .map(|s| s.score)
126 .unwrap_or(0.5);
127 }
128 return 1.0;
129 }
130
131 0.1
133 }
134
135 pub fn extract_steiner(
137 &self,
138 seeds: &[ScoredEntity],
139 all_triples: &[Triple],
140 ) -> GraphRAGResult<Vec<Triple>> {
141 use std::collections::HashMap;
143
144 let mut adjacency: HashMap<String, Vec<(String, Triple)>> = HashMap::new();
145 for triple in all_triples {
146 adjacency
147 .entry(triple.subject.clone())
148 .or_default()
149 .push((triple.object.clone(), triple.clone()));
150 adjacency
151 .entry(triple.object.clone())
152 .or_default()
153 .push((triple.subject.clone(), triple.clone()));
154 }
155
156 let seed_uris: Vec<String> = seeds.iter().map(|s| s.uri.clone()).collect();
157 let mut result_triples: HashSet<Triple> = HashSet::new();
158
159 for i in 0..seed_uris.len() {
161 for j in (i + 1)..seed_uris.len() {
162 if let Some(path) = self.bfs_path(&seed_uris[i], &seed_uris[j], &adjacency) {
163 for triple in path {
164 result_triples.insert(triple);
165 }
166 }
167 }
168 }
169
170 Ok(result_triples
171 .into_iter()
172 .take(self.config.max_triples)
173 .collect())
174 }
175
176 fn bfs_path(
178 &self,
179 start: &str,
180 end: &str,
181 adjacency: &std::collections::HashMap<String, Vec<(String, Triple)>>,
182 ) -> Option<Vec<Triple>> {
183 use std::collections::VecDeque;
184
185 if start == end {
186 return Some(vec![]);
187 }
188
189 let mut visited: HashSet<String> = HashSet::new();
190 let mut queue: VecDeque<(String, Vec<Triple>)> = VecDeque::new();
191
192 queue.push_back((start.to_string(), vec![]));
193 visited.insert(start.to_string());
194
195 while let Some((current, path)) = queue.pop_front() {
196 if let Some(neighbors) = adjacency.get(¤t) {
197 for (neighbor, triple) in neighbors {
198 if neighbor == end {
199 let mut result = path.clone();
200 result.push(triple.clone());
201 return Some(result);
202 }
203
204 if !visited.contains(neighbor) && path.len() < 5 {
205 visited.insert(neighbor.clone());
207 let mut new_path = path.clone();
208 new_path.push(triple.clone());
209 queue.push_back((neighbor.clone(), new_path));
210 }
211 }
212 }
213 }
214
215 None
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a vector-sourced seed entity with empty metadata.
    fn seed(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: crate::ScoreSource::Vector,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_subgraph_extraction() {
        let seeds = [seed("http://a", 0.9), seed("http://b", 0.8)];
        let triples = [
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://a", "http://rel", "http://c"),
            Triple::new("http://x", "http://rel", "http://y"),
        ];

        let extracted = SubgraphExtractor::default()
            .extract(&seeds, &triples)
            .unwrap();

        assert!(!extracted.is_empty());
        let has_seed_edge = extracted
            .iter()
            .any(|t| t.subject == "http://a" && t.object == "http://b");
        assert!(has_seed_edge);
    }

    #[test]
    fn test_steiner_extraction() {
        let seeds = [seed("http://a", 0.9), seed("http://c", 0.8)];
        let triples = [
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
        ];

        let connecting = SubgraphExtractor::default()
            .extract_steiner(&seeds, &triples)
            .unwrap();

        assert_eq!(connecting.len(), 2);
    }
}