use crate::{
core::state::compute_state::ComputeStateVec,
db::{
api::{
state::{GenericNodeState, TypedNodeState},
view::{graph::GraphViewOps, NodeViewOps, StaticGraphViewOps},
},
task::{
context::Context,
node::eval_node::EvalNodeView,
task::{ATask, Job, Step},
task_runner::TaskRunner,
},
},
};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Per-node state carried through the FastRP supersteps.
#[derive(Debug, Clone, PartialEq, Default, Deserialize, Serialize)]
pub struct FastRPState {
    /// The node's embedding vector; filled with `embedding_dim` entries by the initial step.
    pub embedding_state: Vec<f64>,
}
/// Computes FastRP (Fast Random Projection) embeddings for every node of `g`.
///
/// **Initial step.** Every node gets a very sparse random vector of length `embedding_dim`.
/// With `n` the node count and `s = sqrt(n)`, each entry is drawn as follows:
/// * `+l * sqrt(s)` with probability `1 / (2s)`;
/// * `0` with probability `1 - 1/s`;
/// * `-l * sqrt(s)` with probability `1 / (2s)`.
///
/// Here `l = (degree / 2n)^(normalization_strength - 1)`. Every entry is then multiplied by
/// `iter_weights[0]`. The per-node RNG is seeded with `node_id ^ seed`.
///
/// **Later iterations.** Each node adds, for every neighbour, that neighbour's
/// previous-iteration embedding times `iter_weights[ss] / (neighbour_count * iter_weights.len())`.
/// `ss` is `vv.graph().ss`, presumably the current superstep index.
///
/// # Arguments
/// * `g` - graph view to embed.
/// * `embedding_dim` - length of each node's embedding vector.
/// * `normalization_strength` - the degree-normalization exponent is `normalization_strength - 1`.
/// * `iter_weights` - the weight for the initial vector followed by the per-iteration weights.
///   `iter_weights.len() - 1` is passed to the task runner as the iteration count.
/// * `seed` - RNG seed. When `None`, a seed is drawn from the thread RNG.
/// * `threads` - optional thread count forwarded to the task runner.
///
/// # Returns
/// A node state built from each node's final [`FastRPState`].
///
/// # Panics
/// Panics if `iter_weights` is empty.
pub fn fast_rp<G>(
    g: &G,
    embedding_dim: usize,
    normalization_strength: f64,
    iter_weights: Vec<f64>,
    seed: Option<u64>,
    threads: Option<usize>,
) -> TypedNodeState<'static, FastRPState, G>
where
    G: StaticGraphViewOps,
{
    // Fail fast with a clear message. Without this check, `len() - 1` underflows: it panics in
    // debug builds and wraps to `usize::MAX` iterations in release builds. `weights[0]` would
    // then panic inside a worker thread.
    assert!(
        !iter_weights.is_empty(),
        "fast_rp: iter_weights must contain at least one weight"
    );

    let ctx: Context<G, ComputeStateVec> = g.into();
    let m = g.count_nodes() as f64;
    let s = m.sqrt();
    // Loop-invariant: magnitude of the non-zero random-projection entries.
    let sqrt_s = s.sqrt();
    let beta = normalization_strength - 1.0;
    let num_iters = iter_weights.len() - 1;
    let initial_weight = iter_weights[0];
    let weights = Arc::new(iter_weights);
    // Draw a seed lazily, so the thread RNG is only touched when the caller supplied none.
    let seed = seed.unwrap_or_else(|| rand::thread_rng().gen());

    let step1 = ATask::new(move |vv| {
        // NOTE(review): the FastRP paper normalizes degree by 2·|E|; the node count is used
        // here. This only rescales all embeddings by a common factor. Confirm it is intended.
        let scale = ((vv.degree() as f64) / (m * 2.0)).powf(beta);
        // A degree-0 node with beta < 0 yields 0^negative = inf. That would put ±inf into the
        // embedding, or NaN when initial_weight == 0, so such nodes get a zero scale instead.
        let l = if scale.is_finite() { scale } else { 0.0 };
        let choices = [
            (l * sqrt_s, 1.0 / (s * 2.0)),
            (0.0, 1.0 - (1.0 / s)),
            (-l * sqrt_s, 1.0 / (s * 2.0)),
        ];
        let mut rng = SmallRng::seed_from_u64(vv.node.0 as u64 ^ seed);
        let state: &mut FastRPState = vv.get_mut();
        state.embedding_state = (0..embedding_dim)
            .map(|_| {
                choices
                    .choose_weighted(&mut rng, |item| item.1)
                    .expect("FastRP sampling weights are non-negative and not all zero")
                    .0
                    * initial_weight
            })
            .collect();
        Step::Continue
    });

    let step2 = ATask::new(move |vv: &mut EvalNodeView<_, FastRPState>| {
        let neighbour_count = vv.neighbours().iter().count();
        let denom: f64 =
            weights[vv.graph().ss] / ((neighbour_count * (num_iters + 1)) as f64);
        for neighbour in vv.neighbours() {
            // Take both slices once per neighbour instead of re-fetching the states for every
            // element. Slicing to `embedding_dim` keeps the original panic if a vector is short.
            let prev = &neighbour.prev().embedding_state[..embedding_dim];
            let current = &mut vv.get_mut().embedding_state[..embedding_dim];
            for (acc, &x) in current.iter_mut().zip(prev) {
                *acc += x * denom;
            }
        }
        Step::Continue
    });

    let mut runner: TaskRunner<G, _> = TaskRunner::new(ctx);
    runner.run(
        vec![Job::new(step1)],
        vec![Job::read_only(step2)],
        None,
        |_, _, _, local: Vec<FastRPState>| {
            TypedNodeState::new(GenericNodeState::new_from_eval(g.clone(), local, None))
        },
        threads,
        num_iters,
        None,
        None,
    )
}