use std::collections::{HashMap, HashSet};

use crate::node::{NodeId, Weight};
use crate::walk::{WalkId, WalkIdGenerator};
/// An ordered sequence of nodes tagged with a walk identifier.
///
/// Note that the derived `Clone` copies the identifier too, so a clone and
/// its source report the same `walk_id`.
#[derive(Clone)]
pub struct RandomWalk {
    /// Identifier of this walk.
    walk_id: WalkId,
    /// Nodes of the walk, first to last.
    nodes: Vec<NodeId>,
}
impl RandomWalk {
    /// Creates an empty walk with a freshly generated [`WalkId`].
    pub fn new() -> Self {
        Self::from_nodes(Vec::new())
    }

    /// Creates a walk over `nodes`, kept in the given order, with a freshly
    /// generated [`WalkId`].
    pub fn from_nodes(nodes: Vec<NodeId>) -> Self {
        // NOTE(review): a new `WalkIdGenerator` is built per walk, so id
        // uniqueness relies on the generator sharing state across
        // instances — confirm against `crate::walk`.
        let walk_id = WalkIdGenerator::new().get_id();
        RandomWalk { nodes, walk_id }
    }

    /// Appends `node_id` to the end of the walk; same as [`RandomWalk::push`].
    pub fn _add_node(&mut self, node_id: NodeId) {
        self.push(node_id);
    }

    /// Returns the walk's nodes, first to last.
    pub fn get_nodes(&self) -> &[NodeId] {
        &self.nodes
    }

    /// Returns the number of nodes in the walk (repeats included).
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if the walk contains no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns `true` if `node_id` occurs anywhere in the walk (linear scan).
    pub fn contains(&self, node_id: &NodeId) -> bool {
        self.nodes.contains(node_id)
    }

    /// Returns `true` if any node of `nodes` also occurs in the walk.
    ///
    /// Runs in O(`nodes.len()` × `self.len()`); an empty `nodes` yields `false`.
    pub fn intersects_nodes(&self, nodes: &[NodeId]) -> bool {
        nodes.iter().any(|node| self.contains(node))
    }

    /// Returns mutable access to the underlying node vector. The walk id is
    /// not affected by any changes made through it.
    pub fn _get_nodes_mut(&mut self) -> &mut Vec<NodeId> {
        &mut self.nodes
    }

    /// Returns the first node, or `None` for an empty walk.
    pub fn first_node(&self) -> Option<NodeId> {
        self.nodes.first().copied()
    }

    /// Returns the last node, or `None` for an empty walk.
    pub fn last_node(&self) -> Option<NodeId> {
        self.nodes.last().copied()
    }

    /// Returns this walk's identifier.
    pub fn get_walk_id(&self) -> WalkId {
        self.walk_id
    }

    /// Iterates over the walk's nodes by reference, first to last.
    pub fn iter(&self) -> impl Iterator<Item = &NodeId> {
        self.nodes.iter()
    }

    /// Appends `node_id` to the end of the walk.
    pub fn push(&mut self, node_id: NodeId) {
        self.nodes.push(node_id);
    }

    /// Appends every node of `new_segment`, in order, to the end of the walk.
    pub fn extend(&mut self, new_segment: &[NodeId]) {
        self.nodes.extend_from_slice(new_segment);
    }

    /// Splits the walk at `pos`.
    ///
    /// `self` keeps nodes `[0, pos)`. Nodes `[pos, len)` are moved into a
    /// new walk, which is returned with its own freshly generated id.
    ///
    /// # Panics
    ///
    /// Panics if `pos > self.len()` (inherited from [`Vec::split_off`]).
    pub fn split_from(&mut self, pos: usize) -> RandomWalk {
        let split_segment = self.nodes.split_off(pos);
        RandomWalk::from_nodes(split_segment)
    }
}
impl IntoIterator for RandomWalk {
type Item = NodeId;
type IntoIter = std::vec::IntoIter<NodeId>;
fn into_iter(self) -> Self::IntoIter {
self.nodes.into_iter()
}
}
impl RandomWalk {
pub fn calculate_penalties(
&self,
neg_weights: &HashMap<NodeId, Weight>,
) -> HashMap<NodeId, Weight> {
let mut penalties: HashMap<NodeId, Weight> = HashMap::new();
let mut negs = neg_weights.clone();
let mut accumulated_penalty = 0.0;
for &step in self.nodes.iter().rev() {
if let Some(penalty) = negs.remove(&step) {
accumulated_penalty += penalty;
}
if accumulated_penalty != 0.0 {
penalties.insert(step, accumulated_penalty);
}
}
penalties
}
}