oxirs_graphrag/graph/
subgraph.rs

1//! Subgraph extraction for context building
2
3use crate::{GraphRAGResult, ScoredEntity, Triple};
4use std::collections::HashSet;
5
/// Settings that control how [`SubgraphExtractor`] selects triples.
///
/// `max_triples` caps the output of every extraction method. The three
/// boolean switches are only consulted when triples are scored against the
/// seed entities.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubgraphConfig {
    /// Upper bound on the number of triples an extraction returns.
    pub max_triples: usize,
    /// Keep edges whose subject and object are both seed entities.
    /// When `false` such edges score zero and are dropped.
    pub include_internal_edges: bool,
    /// Keep edges that have exactly one seed endpoint.
    /// When `false` such edges score zero and are dropped.
    pub include_external_edges: bool,
    /// Rank seed-touching triples by the scores of the seeds they touch
    /// instead of by fixed per-category weights.
    pub score_weighted: bool,
}
18
19impl Default for SubgraphConfig {
20    fn default() -> Self {
21        Self {
22            max_triples: 100,
23            include_internal_edges: true,
24            include_external_edges: true,
25            score_weighted: true,
26        }
27    }
28}
29
30/// Subgraph extractor
31pub struct SubgraphExtractor {
32    config: SubgraphConfig,
33}
34
35impl Default for SubgraphExtractor {
36    fn default() -> Self {
37        Self::new(SubgraphConfig::default())
38    }
39}
40
41impl SubgraphExtractor {
42    pub fn new(config: SubgraphConfig) -> Self {
43        Self { config }
44    }
45
46    /// Extract relevant subgraph for LLM context
47    pub fn extract(
48        &self,
49        seeds: &[ScoredEntity],
50        expanded_triples: &[Triple],
51    ) -> GraphRAGResult<Vec<Triple>> {
52        let seed_uris: HashSet<String> = seeds.iter().map(|s| s.uri.clone()).collect();
53
54        // Score triples based on relevance to seeds
55        let mut scored_triples: Vec<(f64, &Triple)> = expanded_triples
56            .iter()
57            .map(|triple| {
58                let score = self.score_triple(triple, seeds, &seed_uris);
59                (score, triple)
60            })
61            .filter(|(score, _)| *score > 0.0)
62            .collect();
63
64        // Sort by score (descending)
65        scored_triples.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
66
67        // Take top triples
68        let result: Vec<Triple> = scored_triples
69            .into_iter()
70            .take(self.config.max_triples)
71            .map(|(_, t)| t.clone())
72            .collect();
73
74        Ok(result)
75    }
76
77    /// Score a triple based on relevance to seeds
78    fn score_triple(
79        &self,
80        triple: &Triple,
81        seeds: &[ScoredEntity],
82        seed_uris: &HashSet<String>,
83    ) -> f64 {
84        let subject_is_seed = seed_uris.contains(&triple.subject);
85        let object_is_seed = seed_uris.contains(&triple.object);
86
87        // Internal edges (both endpoints are seeds)
88        if subject_is_seed && object_is_seed {
89            if !self.config.include_internal_edges {
90                return 0.0;
91            }
92
93            if self.config.score_weighted {
94                // Average score of both seed entities
95                let subj_score = seeds
96                    .iter()
97                    .find(|s| s.uri == triple.subject)
98                    .map(|s| s.score)
99                    .unwrap_or(0.5);
100                let obj_score = seeds
101                    .iter()
102                    .find(|s| s.uri == triple.object)
103                    .map(|s| s.score)
104                    .unwrap_or(0.5);
105                return (subj_score + obj_score) / 2.0 * 1.5; // Boost internal edges
106            }
107            return 1.5;
108        }
109
110        // External edges (one endpoint is seed)
111        if subject_is_seed || object_is_seed {
112            if !self.config.include_external_edges {
113                return 0.0;
114            }
115
116            if self.config.score_weighted {
117                let seed_uri = if subject_is_seed {
118                    &triple.subject
119                } else {
120                    &triple.object
121                };
122                return seeds
123                    .iter()
124                    .find(|s| &s.uri == seed_uri)
125                    .map(|s| s.score)
126                    .unwrap_or(0.5);
127            }
128            return 1.0;
129        }
130
131        // Neither endpoint is seed (context edges)
132        0.1
133    }
134
135    /// Extract minimal subgraph connecting seeds
136    pub fn extract_steiner(
137        &self,
138        seeds: &[ScoredEntity],
139        all_triples: &[Triple],
140    ) -> GraphRAGResult<Vec<Triple>> {
141        // Build adjacency for path finding
142        use std::collections::HashMap;
143
144        let mut adjacency: HashMap<String, Vec<(String, Triple)>> = HashMap::new();
145        for triple in all_triples {
146            adjacency
147                .entry(triple.subject.clone())
148                .or_default()
149                .push((triple.object.clone(), triple.clone()));
150            adjacency
151                .entry(triple.object.clone())
152                .or_default()
153                .push((triple.subject.clone(), triple.clone()));
154        }
155
156        let seed_uris: Vec<String> = seeds.iter().map(|s| s.uri.clone()).collect();
157        let mut result_triples: HashSet<Triple> = HashSet::new();
158
159        // Find shortest paths between all pairs of seeds
160        for i in 0..seed_uris.len() {
161            for j in (i + 1)..seed_uris.len() {
162                if let Some(path) = self.bfs_path(&seed_uris[i], &seed_uris[j], &adjacency) {
163                    for triple in path {
164                        result_triples.insert(triple);
165                    }
166                }
167            }
168        }
169
170        Ok(result_triples
171            .into_iter()
172            .take(self.config.max_triples)
173            .collect())
174    }
175
176    /// BFS to find shortest path between two nodes
177    fn bfs_path(
178        &self,
179        start: &str,
180        end: &str,
181        adjacency: &std::collections::HashMap<String, Vec<(String, Triple)>>,
182    ) -> Option<Vec<Triple>> {
183        use std::collections::VecDeque;
184
185        if start == end {
186            return Some(vec![]);
187        }
188
189        let mut visited: HashSet<String> = HashSet::new();
190        let mut queue: VecDeque<(String, Vec<Triple>)> = VecDeque::new();
191
192        queue.push_back((start.to_string(), vec![]));
193        visited.insert(start.to_string());
194
195        while let Some((current, path)) = queue.pop_front() {
196            if let Some(neighbors) = adjacency.get(&current) {
197                for (neighbor, triple) in neighbors {
198                    if neighbor == end {
199                        let mut result = path.clone();
200                        result.push(triple.clone());
201                        return Some(result);
202                    }
203
204                    if !visited.contains(neighbor) && path.len() < 5 {
205                        // Limit path length
206                        visited.insert(neighbor.clone());
207                        let mut new_path = path.clone();
208                        new_path.push(triple.clone());
209                        queue.push_back((neighbor.clone(), new_path));
210                    }
211                }
212            }
213        }
214
215        None
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a vector-sourced seed entity with no metadata.
    fn seed(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: crate::ScoreSource::Vector,
            metadata: HashMap::new(),
        }
    }

    /// Projects triples onto `(subject, object)` pairs for compact assertions.
    fn endpoints(triples: &[Triple]) -> Vec<(&str, &str)> {
        triples
            .iter()
            .map(|t| (t.subject.as_str(), t.object.as_str()))
            .collect()
    }

    #[test]
    fn test_subgraph_extraction() {
        let extractor = SubgraphExtractor::default();
        let seeds = vec![seed("http://a", 0.9), seed("http://b", 0.8)];

        // Deliberately listed worst-first so the ranking has work to do.
        let triples = vec![
            Triple::new("http://x", "http://rel", "http://y"),
            Triple::new("http://a", "http://rel", "http://c"),
            Triple::new("http://a", "http://rel", "http://b"),
        ];

        let result = extractor.extract(&seeds, &triples).unwrap();

        // Internal a->b outranks external a->c, which outranks context x->y.
        assert_eq!(
            endpoints(&result),
            vec![
                ("http://a", "http://b"),
                ("http://a", "http://c"),
                ("http://x", "http://y"),
            ]
        );
    }

    #[test]
    fn test_steiner_extraction() {
        let extractor = SubgraphExtractor::default();
        let seeds = vec![seed("http://a", 0.9), seed("http://c", 0.8)];

        let triples = vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
        ];

        let result = extractor.extract_steiner(&seeds, &triples).unwrap();

        // Should find path a->b->c; compared order-agnostically.
        let mut found = endpoints(&result);
        found.sort_unstable();
        assert_eq!(
            found,
            vec![("http://a", "http://b"), ("http://b", "http://c")]
        );
    }

    #[test]
    fn test_steiner_disconnected_seeds() {
        let extractor = SubgraphExtractor::default();
        let seeds = vec![seed("http://a", 0.9), seed("http://c", 0.8)];

        // No triple reaches http://c, so no path can be found.
        let triples = vec![Triple::new("http://a", "http://rel", "http://b")];

        let result = extractor.extract_steiner(&seeds, &triples).unwrap();
        assert!(result.is_empty());
    }
}