oxirs_graphrag/gnn_encoder/
adjacency.rs1use std::collections::HashMap;
7
/// Owned list of `(String, String, String)` triples.
///
/// [`AdjacencyGraph::from_triples`] accepts a borrowed slice of the same tuple
/// type, so a `&EdgeList` can be passed to it directly.
pub type EdgeList = Vec<(String, String, String)>;

/// Undirected adjacency-list graph over string-named entities.
///
/// Nodes are numbered densely from zero in order of first appearance in the
/// input triples. All four fields are indexed by (or map to) that numbering.
#[derive(Debug, Clone)]
pub struct AdjacencyGraph {
    /// Entity name -> node index.
    pub entity_to_idx: HashMap<String, usize>,
    /// Node index -> entity name (the inverse of `entity_to_idx`).
    pub idx_to_entity: Vec<String>,
    /// `adjacency[i]` holds the neighbor indices of node `i`, in insertion order.
    pub adjacency: Vec<Vec<usize>>,
    /// `edge_features[i][k]` is the predicate feature of the edge between node
    /// `i` and `adjacency[i][k]`; each row is parallel to the matching
    /// `adjacency` row.
    pub edge_features: Vec<Vec<f64>>,
}

impl AdjacencyGraph {
    /// Builds a graph from `(subject, predicate, object)` triples.
    ///
    /// * Subjects and objects become nodes; a name seen more than once maps to
    ///   the same node.
    /// * Every triple adds the edge in both directions, so the graph is
    ///   undirected. A self-loop (`subject == object`) is recorded only once.
    /// * Repeated triples are not collapsed: each occurrence adds another
    ///   neighbor entry.
    /// * The predicate is reduced to a single `f64` feature stored alongside
    ///   each neighbor entry.
    ///
    /// An empty slice yields an empty graph.
    pub fn from_triples(triples: &[(String, String, String)]) -> Self {
        let mut graph = Self {
            entity_to_idx: HashMap::new(),
            idx_to_entity: Vec::new(),
            adjacency: Vec::new(),
            edge_features: Vec::new(),
        };

        for (subject, predicate, object) in triples {
            // Interning the subject before the object keeps node numbering in
            // first-appearance order.
            let subject_idx = graph.intern(subject);
            let object_idx = graph.intern(object);
            // Hash the predicate once; both directions share the same feature.
            let feature = predicate_feature(predicate);

            graph.adjacency[subject_idx].push(object_idx);
            graph.edge_features[subject_idx].push(feature);

            // Skip the reverse entry for self-loops so they appear only once.
            if subject_idx != object_idx {
                graph.adjacency[object_idx].push(subject_idx);
                graph.edge_features[object_idx].push(feature);
            }
        }

        graph
    }

    /// Returns the node index for `name`, registering a new node (with empty
    /// adjacency and feature rows) the first time the name is seen.
    fn intern(&mut self, name: &str) -> usize {
        if let Some(&known) = self.entity_to_idx.get(name) {
            return known;
        }
        let fresh = self.idx_to_entity.len();
        self.entity_to_idx.insert(name.to_owned(), fresh);
        self.idx_to_entity.push(name.to_owned());
        // Keep the per-node rows in lock-step with the name table so that
        // `fresh` is always a valid row index.
        self.adjacency.push(Vec::new());
        self.edge_features.push(Vec::new());
        fresh
    }

    /// Returns the neighbor indices of node `idx`, or an empty slice when
    /// `idx` is out of range.
    pub fn neighbors(&self, idx: usize) -> &[usize] {
        self.adjacency.get(idx).map(Vec::as_slice).unwrap_or(&[])
    }

    /// Returns the number of distinct entities (nodes) in the graph.
    pub fn entity_count(&self) -> usize {
        self.idx_to_entity.len()
    }

    /// Returns the name of node `idx`, or `None` when `idx` is out of range.
    pub fn entity_name(&self, idx: usize) -> Option<&str> {
        self.idx_to_entity.get(idx).map(String::as_str)
    }

    /// Returns the node index of the entity called `name`, or `None` when no
    /// such entity exists.
    pub fn entity_index(&self, name: &str) -> Option<usize> {
        self.entity_to_idx.get(name).copied()
    }
}

/// Maps a predicate string to a deterministic scalar in `[0, 1)`.
///
/// The bytes of `predicate` are folded into a wrapping base-31 polynomial
/// hash, which is then reduced modulo `100_007` and divided by the same
/// value. Distinct predicates can collide; the empty string maps to `0.0`.
fn predicate_feature(predicate: &str) -> f64 {
    const MODULUS: u64 = 100_007;

    let hash = predicate
        .bytes()
        .fold(0u64, |acc, byte| acc.wrapping_mul(31).wrapping_add(u64::from(byte)));

    (hash % MODULUS) as f64 / MODULUS as f64
}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned triple from string slices so fixtures stay terse.
    fn triple(subject: &str, predicate: &str, object: &str) -> (String, String, String) {
        (
            subject.to_string(),
            predicate.to_string(),
            object.to_string(),
        )
    }

    /// An empty triple list produces a graph with no nodes, and looking up
    /// neighbors of a non-existent node yields an empty slice.
    #[test]
    fn test_empty_graph() {
        let graph = AdjacencyGraph::from_triples(&[]);
        assert_eq!(graph.entity_count(), 0);
        assert!(graph.neighbors(0).is_empty());
    }

    /// A single triple creates two nodes that list each other as neighbors.
    #[test]
    fn test_single_triple() {
        let graph = AdjacencyGraph::from_triples(&[triple("Alice", "knows", "Bob")]);
        assert_eq!(graph.entity_count(), 2);

        let alice = graph.entity_index("Alice").expect("Alice must be present");
        let bob = graph.entity_index("Bob").expect("Bob must be present");

        assert_eq!(graph.neighbors(alice), &[bob]);
        assert_eq!(graph.neighbors(bob), &[alice]);
    }

    /// Entities mentioned in several triples are counted only once.
    #[test]
    fn test_entity_deduplication() {
        let fixture = [
            triple("Alice", "knows", "Bob"),
            triple("Alice", "worksAt", "Acme"),
            triple("Bob", "worksAt", "Acme"),
        ];
        let graph = AdjacencyGraph::from_triples(&fixture);
        assert_eq!(graph.entity_count(), 3);
    }

    /// A node with two outgoing triples reports both objects as neighbors.
    #[test]
    fn test_neighbor_lookup() {
        let fixture = [triple("A", "r", "B"), triple("A", "r", "C")];
        let graph = AdjacencyGraph::from_triples(&fixture);

        let a = graph.entity_index("A").expect("A present");
        let around_a = graph.neighbors(a);
        assert_eq!(around_a.len(), 2);

        let b = graph.entity_index("B").expect("B present");
        let c = graph.entity_index("C").expect("C present");
        assert!(around_a.contains(&b));
        assert!(around_a.contains(&c));
    }

    /// Name -> index and index -> name lookups agree for every entity.
    #[test]
    fn test_round_trip_entity_names() {
        let graph = AdjacencyGraph::from_triples(&[triple("X", "p", "Y")]);
        for (name, &idx) in graph.entity_to_idx.iter() {
            assert_eq!(graph.entity_name(idx), Some(name.as_str()));
            assert_eq!(graph.entity_index(name), Some(idx));
        }
    }
}