oxios_kernel/memory/
graph.rs1use std::collections::HashMap;
8
/// A directed graph over memory ids, used to rank memories by how strongly
/// they are connected to one another.
///
/// Both adjacency directions are stored so that [`MemoryGraph::pagerank`]
/// can pull score from a node's predecessors without scanning the whole
/// graph. `edges` and `incoming` are always updated together by
/// [`MemoryGraph::add_edge`] and therefore share the same key set.
#[derive(Debug, Clone, Default)]
pub struct MemoryGraph {
    /// Outgoing adjacency: `from -> [to, ...]`, in insertion order, no duplicates.
    edges: HashMap<u64, Vec<u64>>,
    /// Incoming adjacency: `to -> [from, ...]`; the mirror image of `edges`.
    incoming: HashMap<u64, Vec<u64>>,
    /// Cached `edges.len()`, refreshed on every successful `add_edge` call.
    node_count: usize,
}

impl MemoryGraph {
    /// Creates an empty graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds the directed edge `from -> to`.
    ///
    /// Both endpoints become nodes of the graph. Adding an edge that already
    /// exists is a no-op. Self-loops (`from == to`) are ignored entirely and
    /// do not register the node.
    pub fn add_edge(&mut self, from: u64, to: u64) {
        if from == to {
            return;
        }

        // Register both endpoints in both maps so that sinks and sources are
        // still counted as nodes and the two maps keep the same key set.
        self.edges.entry(to).or_default();
        self.incoming.entry(from).or_default();
        let targets = self.edges.entry(from).or_default();
        let sources = self.incoming.entry(to).or_default();

        // Reject duplicates so out-degrees count distinct targets only.
        if !targets.contains(&to) {
            targets.push(to);
            sources.push(from);
        }

        self.node_count = self.edges.len();
    }

    /// Connects `a` and `b` in both directions (an undirected association).
    pub fn link(&mut self, a: u64, b: u64) {
        self.add_edge(a, b);
        self.add_edge(b, a);
    }

    /// Number of distinct nodes, i.e. ids that are an endpoint of some edge.
    pub fn node_count(&self) -> usize {
        self.node_count
    }

    /// Outgoing neighbours of `node` in insertion order, or an empty slice
    /// when `node` is not part of the graph.
    pub fn neighbors(&self, node: u64) -> &[u64] {
        self.edges.get(&node).map(Vec::as_slice).unwrap_or(&[])
    }

    /// Computes PageRank scores for every node by power iteration.
    ///
    /// Each of the `iterations` rounds computes, for every node,
    /// `(1 - damping) / n + damping * (inflow + sink_mass / n)`, where
    /// `inflow` is the sum of each predecessor's score divided by its
    /// out-degree. `sink_mass` is the total score currently held by sinks
    /// (nodes with no outgoing edges), which is spread evenly over all
    /// nodes so that it does not leak. Starting from the uniform
    /// distribution, the scores therefore keep summing to 1.
    ///
    /// `initial_scores` optionally seeds the iteration (warm start). Nodes
    /// absent from it start at `1 / n`, and entries for unknown nodes are
    /// ignored. Seeds are *not* normalised. The gap between their total and
    /// 1 shrinks by a factor of `damping` each round.
    ///
    /// Returns an empty map for an empty graph. With `iterations == 0` the
    /// seed values are returned unchanged. `damping` is expected to lie in
    /// `[0, 1]` and is not validated.
    pub fn pagerank(
        &self,
        damping: f64,
        iterations: usize,
        initial_scores: Option<&HashMap<u64, f64>>,
    ) -> HashMap<u64, f64> {
        if self.node_count == 0 {
            return HashMap::new();
        }

        let n = self.node_count as f64;
        let uniform = 1.0 / n;
        // Teleport share every node receives each round, independent of links.
        let teleport = (1.0 - damping) / n;

        let mut scores: HashMap<u64, f64> = self
            .edges
            .keys()
            .map(|&node| {
                let seed = initial_scores
                    .and_then(|seeds| seeds.get(&node))
                    .copied()
                    .unwrap_or(uniform);
                (node, seed)
            })
            .collect();

        // The graph does not change during the iteration, so the set of sinks
        // is loop-invariant and can be computed once.
        let sinks: Vec<u64> = self
            .edges
            .iter()
            .filter(|(_, targets)| targets.is_empty())
            .map(|(&node, _)| node)
            .collect();

        for _ in 0..iterations {
            let sink_mass: f64 = sinks
                .iter()
                .map(|sink| scores.get(sink).copied().unwrap_or(0.0))
                .sum();

            let next: HashMap<u64, f64> = self
                .edges
                .keys()
                .map(|&node| {
                    let inflow = self.incoming.get(&node).map_or(0.0, |sources| {
                        sources
                            .iter()
                            .map(|src| {
                                // `src` links to `node`, so its out-degree is >= 1.
                                let out_degree = self.edges.get(src).map_or(1, Vec::len) as f64;
                                scores.get(src).copied().unwrap_or(0.0) / out_degree
                            })
                            .sum::<f64>()
                    });
                    (node, teleport + damping * (inflow + sink_mass / n))
                })
                .collect();

            scores = next;
        }

        scores
    }

    /// Builds a co-access graph: every pair of distinct ids appearing in the
    /// same session is [`link`](Self::link)ed in both directions.
    ///
    /// Ids that never co-occur with a *different* id (for example the only
    /// entry of a one-element session) are not added as nodes.
    pub fn from_co_access(sessions: &[Vec<u64>]) -> Self {
        let mut graph = Self::new();
        for session in sessions {
            for (i, &a) in session.iter().enumerate() {
                for &b in &session[i + 1..] {
                    graph.link(a, b);
                }
            }
        }
        graph
    }
}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_graph() {
        let graph = MemoryGraph::new();
        let scores = graph.pagerank(0.85, 20, None);
        assert!(scores.is_empty());
    }

    #[test]
    fn test_single_node() {
        // Self-loops are discarded without registering the node, so the graph
        // stays empty and there is nothing to rank.
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 1);
        assert_eq!(graph.node_count(), 0);
        assert!(graph.neighbors(1).is_empty());
        assert!(graph.pagerank(0.85, 20, None).is_empty());
    }

    #[test]
    fn test_duplicate_edges_ignored() {
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);
        graph.add_edge(1, 2);
        assert_eq!(graph.neighbors(1).to_vec(), vec![2u64]);
        assert!(graph.neighbors(2).is_empty());
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_two_nodes() {
        let mut graph = MemoryGraph::new();
        graph.link(1, 2);

        let scores = graph.pagerank(0.85, 50, None);
        assert_eq!(scores.len(), 2);

        let s1 = scores.get(&1).unwrap();
        let s2 = scores.get(&2).unwrap();
        assert!(
            (s1 - s2).abs() < 0.01,
            "Symmetric graph should have equal scores"
        );
    }

    #[test]
    fn test_sink_mass_conserved() {
        // Node 2 has no outgoing edges; its score must be redistributed
        // rather than leaking out of the system.
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);

        let scores = graph.pagerank(0.85, 50, None);
        let total: f64 = scores.values().sum();
        assert!(
            (total - 1.0).abs() < 1e-9,
            "Scores should sum to 1, got {}",
            total
        );
        assert!(scores[&2] > scores[&1], "Node 2 receives node 1's rank");
    }

    #[test]
    fn test_hub_authority() {
        // Star: node 1 links to and is linked from each of 2, 3 and 4.
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(1, 4);
        graph.add_edge(2, 1);
        graph.add_edge(3, 1);
        graph.add_edge(4, 1);

        let scores = graph.pagerank(0.85, 50, None);
        let s1 = scores.get(&1).unwrap();

        for &node in &[2u64, 3, 4] {
            let sn = scores.get(&node).unwrap();
            assert!(*s1 >= *sn, "Hub node should have >= score than leaf");
        }
    }

    #[test]
    fn test_from_co_access() {
        // Node 2 is the only id shared by both sessions.
        let sessions = vec![vec![1, 2, 3], vec![2, 4]];

        let graph = MemoryGraph::from_co_access(&sessions);
        assert_eq!(graph.node_count(), 4);

        let scores = graph.pagerank(0.85, 50, None);
        let s2 = scores.get(&2).unwrap();
        for &node in &[1u64, 3, 4] {
            let sn = scores.get(&node).unwrap();
            assert!(*s2 >= *sn, "Node 2 should have highest score");
        }
    }

    #[test]
    fn test_initial_scores_influence() {
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);

        let initial = HashMap::from([(1u64, 10.0), (2u64, 0.1)]);
        let warm = graph.pagerank(0.85, 5, Some(&initial));
        let cold = graph.pagerank(0.85, 5, None);

        let s1 = warm.get(&1).unwrap();
        let s2 = warm.get(&2).unwrap();
        assert!(*s1 > 0.0, "Node 1 should have positive score");
        assert!(*s2 > 0.0, "Node 2 should have positive score");

        // Seeds are not normalised: after 5 rounds the warm start still holds
        // about 1 + 9.1 * 0.85^5 (roughly 5.04) units of mass versus 1.
        let warm_total: f64 = warm.values().sum();
        let cold_total: f64 = cold.values().sum();
        assert!(
            warm_total > cold_total + 1.0,
            "Initial scores should still matter after a few iterations"
        );

        // They only change the starting point, not the fixed point.
        let warm_long = graph.pagerank(0.85, 200, Some(&initial));
        let cold_long = graph.pagerank(0.85, 200, None);
        for &node in &[1u64, 2] {
            assert!(
                (warm_long[&node] - cold_long[&node]).abs() < 1e-9,
                "Warm and cold starts should converge to the same scores"
            );
        }
    }
}