use crate::AletheiaDB;
use crate::core::error::{Result, StorageError};
use crate::core::hasher::IdentityHasher;
use crate::core::id::NodeId;
use rand::prelude::*;
use std::collections::HashMap;
use std::hash::BuildHasherDefault;
/// Random-walk based estimator over the graph stored in an [`AletheiaDB`].
///
/// The oracle only borrows the database; it owns no graph data itself and
/// lives no longer than the borrow `'a`.
pub struct Oracle<'a> {
/// Database whose nodes and outgoing edges the random walks traverse.
db: &'a AletheiaDB,
}
impl<'a> Oracle<'a> {
pub fn new(db: &'a AletheiaDB) -> Self {
Self { db }
}
pub fn personalized_page_rank(
&self,
seed: NodeId,
alpha: f32,
num_walks: usize,
max_steps: usize,
) -> Result<HashMap<NodeId, f32>> {
let mut visits: HashMap<NodeId, usize, BuildHasherDefault<IdentityHasher>> =
HashMap::default();
let mut rng = rand::thread_rng();
if self.db.get_node(seed).is_err() {
return Err(crate::core::error::Error::Storage(
StorageError::NodeNotFound(seed),
));
}
for _ in 0..num_walks {
let mut current = seed;
*visits.entry(current).or_default() += 1;
for _step in 0..max_steps {
if rng.r#gen::<f32>() < alpha {
break;
}
let edges = self.db.get_outgoing_edges(current);
if edges.is_empty() {
break; }
let edge_idx = rng.gen_range(0..edges.len());
let edge_id = edges[edge_idx];
match self.db.get_edge_target(edge_id) {
Ok(target) => {
current = target;
*visits.entry(current).or_default() += 1;
}
Err(_) => break,
}
}
}
let total_visits: usize = visits.values().sum();
let mut scores = HashMap::new();
if total_visits > 0 {
for (node, count) in visits {
scores.insert(node, count as f32 / total_visits as f32);
}
} else {
scores.insert(seed, 1.0);
}
Ok(scores)
}
pub fn reachability_probability(
&self,
start: NodeId,
target: NodeId,
max_steps: usize,
num_simulations: usize,
) -> Result<f32> {
if self.db.get_node(start).is_err() {
return Err(crate::core::error::Error::Storage(
StorageError::NodeNotFound(start),
));
}
if self.db.get_node(target).is_err() {
return Err(crate::core::error::Error::Storage(
StorageError::NodeNotFound(target),
));
}
if start == target {
return Ok(1.0);
}
let mut hits = 0;
let mut rng = rand::thread_rng();
for _ in 0..num_simulations {
let mut current = start;
let mut found = false;
for _ in 0..max_steps {
let edges = self.db.get_outgoing_edges(current);
if edges.is_empty() {
break;
}
let edge_idx = rng.gen_range(0..edges.len());
let edge_id = edges[edge_idx];
match self.db.get_edge_target(edge_id) {
Ok(next) => {
if next == target {
found = true;
break;
}
current = next;
}
Err(_) => break,
}
}
if found {
hits += 1;
}
}
Ok(hits as f32 / num_simulations as f32)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;

    /// Inserts a node labelled `"Node"` with an empty property map.
    fn add_node(db: &AletheiaDB) -> NodeId {
        db.create_node("Node", PropertyMapBuilder::new().build())
            .unwrap()
    }

    /// Inserts a `"NEXT"` edge `from -> to` with an empty property map.
    fn link(db: &AletheiaDB, from: NodeId, to: NodeId) {
        db.create_edge(from, to, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
    }

    /// Graph `A -> B` where `B` is a sink: every walk visits `A`, and about
    /// half of them (alpha = 0.5) continue to `B`, so `P(A)` is roughly 2/3.
    #[test]
    fn test_oracle_ppr_sink() {
        let db = AletheiaDB::new().unwrap();
        let a = add_node(&db);
        let b = add_node(&db);
        link(&db, a, b);

        let ranks = Oracle::new(&db)
            .personalized_page_rank(a, 0.5, 1000, 10)
            .unwrap();
        let score_a = ranks.get(&a).copied().unwrap();
        let score_b = ranks.get(&b).copied().unwrap();

        assert!(
            score_a > score_b,
            "Start node should have higher probability in sink graph"
        );
        assert!(
            score_a > 0.6 && score_a < 0.75,
            "P(A) should be around 0.66 (got {})",
            score_a
        );
    }

    /// Chain `A -> B -> C`: `C` is always reached within 5 steps and never
    /// within a single step.
    #[test]
    fn test_oracle_reachability() {
        let db = AletheiaDB::new().unwrap();
        let a = add_node(&db);
        let b = add_node(&db);
        let c = add_node(&db);
        link(&db, a, b);
        link(&db, b, c);

        let oracle = Oracle::new(&db);

        let prob = oracle.reachability_probability(a, c, 5, 100).unwrap();
        assert!(prob > 0.95, "Should reach C with high probability");

        let prob_short = oracle.reachability_probability(a, c, 1, 100).unwrap();
        assert!(prob_short < 0.01, "Should not reach C in 1 step");
    }

    /// Two-node cycle `A <-> B`: with a small stop probability the walk
    /// alternates between both nodes, so their scores end up close.
    #[test]
    fn test_oracle_cycle() {
        let db = AletheiaDB::new().unwrap();
        let a = add_node(&db);
        let b = add_node(&db);
        link(&db, a, b);
        link(&db, b, a);

        let ranks = Oracle::new(&db)
            .personalized_page_rank(a, 0.1, 2000, 100)
            .unwrap();
        let score_a = ranks.get(&a).copied().unwrap();
        let score_b = ranks.get(&b).copied().unwrap();

        assert!(
            (score_a - score_b).abs() < 0.1,
            "Scores should be roughly equal in symmetric cycle (A: {}, B: {})",
            score_a,
            score_b
        );
    }
}