use rand::Rng;
use rustsim_core::{
interaction::{PositionedAgent, SpaceInteraction},
space::Space,
types::{AgentId, NodeId},
};
use std::collections::{HashSet, VecDeque};
use thiserror::Error;
/// A position in a [`GraphSpace`]: the index of one of its vertices.
pub type GraphPos = NodeId;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum GraphSpaceError {
#[error("invalid graph node index {0}")]
InvalidNode(GraphPos),
}
/// Which edge direction to follow when collecting the neighbours of a vertex.
///
/// In an undirected [`GraphSpace`] every edge is recorded in both the outgoing
/// and the incoming lists, so all three variants reach the same vertices.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum NeighborType {
    /// Follow outgoing edges. This is the default.
    #[default]
    Out,
    /// Follow incoming edges.
    In,
    /// Follow edges in either direction.
    All,
}
/// A graph-structured space: vertices are positions, and agent ids are stored per vertex.
///
/// Vertex indices are `0..num_vertices()`. The constructors and mutators in this
/// file keep `adj_out`, `adj_in` and `stored_ids` the same length.
#[derive(Debug, Clone)]
pub struct GraphSpace {
    /// `adj_out[v]` lists the targets of edges leaving `v`.
    adj_out: Vec<Vec<GraphPos>>,
    /// `adj_in[v]` lists the sources of edges entering `v`.
    adj_in: Vec<Vec<GraphPos>>,
    /// `stored_ids[v]` holds the ids of agents placed on vertex `v`.
    stored_ids: Vec<Vec<AgentId>>,
    /// When `false`, `add_edge`/`rem_edge` mirror every edge in both directions.
    directed: bool,
}
impl GraphSpace {
pub fn new(n: usize) -> Self {
Self {
adj_out: vec![Vec::new(); n],
adj_in: vec![Vec::new(); n],
stored_ids: vec![Vec::new(); n],
directed: false,
}
}
pub fn new_directed(n: usize, directed: bool) -> Self {
Self {
adj_out: vec![Vec::new(); n],
adj_in: vec![Vec::new(); n],
stored_ids: vec![Vec::new(); n],
directed,
}
}
pub fn num_vertices(&self) -> usize {
self.adj_out.len()
}
pub fn num_edges(&self) -> usize {
let total: usize = self.adj_out.iter().map(|v| v.len()).sum();
if self.directed {
total
} else {
total / 2
}
}
pub fn is_directed(&self) -> bool {
self.directed
}
pub fn add_vertex(&mut self) -> GraphPos {
let idx = self.adj_out.len();
self.adj_out.push(Vec::new());
self.adj_in.push(Vec::new());
self.stored_ids.push(Vec::new());
idx
}
pub fn rem_vertex(&mut self, n: GraphPos) -> bool {
let nv = self.num_vertices();
if n >= nv {
return false;
}
let out_neighbors: Vec<GraphPos> = self.adj_out[n].clone();
for &neighbor in &out_neighbors {
self.rem_edge(n, neighbor);
}
let in_neighbors: Vec<GraphPos> = self.adj_in[n].clone();
for &neighbor in &in_neighbors {
self.rem_edge(neighbor, n);
}
let last = nv - 1;
if n != last {
self.adj_out.swap(n, last);
self.adj_in.swap(n, last);
self.stored_ids.swap(n, last);
for neighbors in &mut self.adj_out {
for pos in neighbors.iter_mut() {
if *pos == last {
*pos = n;
}
}
}
for neighbors in &mut self.adj_in {
for pos in neighbors.iter_mut() {
if *pos == last {
*pos = n;
}
}
}
}
self.adj_out.pop();
self.adj_in.pop();
self.stored_ids.pop();
true
}
pub fn add_edge(&mut self, a: GraphPos, b: GraphPos) -> bool {
let nv = self.num_vertices();
if a >= nv || b >= nv {
return false;
}
if self.adj_out[a].contains(&b) {
return false; }
self.adj_out[a].push(b);
self.adj_in[b].push(a);
if !self.directed {
self.adj_out[b].push(a);
self.adj_in[a].push(b);
}
true
}
pub fn rem_edge(&mut self, a: GraphPos, b: GraphPos) -> bool {
let nv = self.num_vertices();
if a >= nv || b >= nv {
return false;
}
let removed = remove_from_vec(&mut self.adj_out[a], b);
remove_from_vec(&mut self.adj_in[b], a);
if !self.directed {
remove_from_vec(&mut self.adj_out[b], a);
remove_from_vec(&mut self.adj_in[a], b);
}
removed
}
pub fn neighbors_out(&self, n: GraphPos) -> &[GraphPos] {
&self.adj_out[n]
}
pub fn neighbors_in(&self, n: GraphPos) -> &[GraphPos] {
&self.adj_in[n]
}
pub fn neighbors_all(&self, n: GraphPos) -> Vec<GraphPos> {
let mut set: HashSet<GraphPos> = HashSet::new();
set.extend(&self.adj_out[n]);
set.extend(&self.adj_in[n]);
set.into_iter().collect()
}
pub fn neighbors(&self, n: GraphPos, kind: NeighborType) -> Vec<GraphPos> {
match kind {
NeighborType::Out => self.adj_out[n].clone(),
NeighborType::In => self.adj_in[n].clone(),
NeighborType::All => self.neighbors_all(n),
}
}
pub fn ids_in_position(&self, n: GraphPos) -> &[AgentId] {
&self.stored_ids[n]
}
pub fn positions(&self) -> std::ops::Range<usize> {
0..self.num_vertices()
}
pub fn nearby_positions(&self, pos: GraphPos, r: usize, kind: NeighborType) -> Vec<GraphPos> {
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
visited.insert(pos);
queue.push_back((pos, 0usize));
let mut result = Vec::new();
while let Some((node, dist)) = queue.pop_front() {
if dist > 0 {
result.push(node);
}
if dist < r {
match kind {
NeighborType::Out => {
for &neighbor in &self.adj_out[node] {
if visited.insert(neighbor) {
queue.push_back((neighbor, dist + 1));
}
}
}
NeighborType::In => {
for &neighbor in &self.adj_in[node] {
if visited.insert(neighbor) {
queue.push_back((neighbor, dist + 1));
}
}
}
NeighborType::All => {
for &neighbor in &self.adj_out[node] {
if visited.insert(neighbor) {
queue.push_back((neighbor, dist + 1));
}
}
for &neighbor in &self.adj_in[node] {
if visited.insert(neighbor) {
queue.push_back((neighbor, dist + 1));
}
}
}
}
}
}
result
}
pub fn nearby_agent_ids(&self, pos: GraphPos, r: usize, kind: NeighborType) -> Vec<AgentId> {
let mut ids = Vec::new();
ids.extend_from_slice(&self.stored_ids[pos]);
for neighbor in self.nearby_positions(pos, r, kind) {
ids.extend_from_slice(&self.stored_ids[neighbor]);
}
ids
}
}
/// Removes the first occurrence of `val` from `v` and reports whether one was found.
///
/// Uses `swap_remove`, so the last element fills the vacated slot and the
/// order of the remaining elements is not preserved.
fn remove_from_vec(v: &mut Vec<GraphPos>, val: GraphPos) -> bool {
    let found = v.iter().position(|&x| x == val);
    match found {
        Some(idx) => {
            v.swap_remove(idx);
            true
        }
        None => false,
    }
}
// Empty impl: registers `GraphSpace` as a `Space`; no trait items are overridden here.
impl Space for GraphSpace {}
/// Agent placement on a [`GraphSpace`]: an agent's position is a vertex index.
impl<A> SpaceInteraction<A> for GraphSpace
where
    A: PositionedAgent<Position = GraphPos>,
{
    type Error = GraphSpaceError;

    /// Draws a vertex index from `0..num_vertices()`.
    ///
    /// NOTE(review): with zero vertices the range is empty, and `rand`'s
    /// `gen_range` is documented to panic in that case. Confirm callers never hit this.
    fn random_position<R: rand::RngCore>(&self, rng: &mut R) -> A::Position {
        let upper = self.num_vertices();
        rng.gen_range(0..upper)
    }

    /// Records `agent`'s id on the vertex given by its position.
    ///
    /// # Errors
    /// [`GraphSpaceError::InvalidNode`] if the position is not a vertex index.
    fn add_agent(&mut self, agent: &A) -> Result<(), Self::Error> {
        let node = *agent.position();
        if node < self.num_vertices() {
            self.stored_ids[node].push(agent.id());
            Ok(())
        } else {
            Err(GraphSpaceError::InvalidNode(node))
        }
    }

    /// Removes one occurrence of `agent`'s id from the vertex given by its position.
    ///
    /// If the id is not stored there, nothing changes and `Ok(())` is still returned.
    ///
    /// # Errors
    /// [`GraphSpaceError::InvalidNode`] if the position is not a vertex index.
    fn remove_agent(&mut self, agent: &A) -> Result<(), Self::Error> {
        let node = *agent.position();
        if node < self.num_vertices() {
            let ids = &mut self.stored_ids[node];
            if let Some(idx) = ids.iter().position(|&id| id == agent.id()) {
                ids.swap_remove(idx);
            }
            Ok(())
        } else {
            Err(GraphSpaceError::InvalidNode(node))
        }
    }

    /// Ids on `position` plus those within `radius` hops along outgoing edges.
    fn nearby_ids(&self, position: &A::Position, radius: usize) -> Vec<AgentId> {
        let origin: GraphPos = *position;
        self.nearby_agent_ids(origin, radius, NeighborType::Out)
    }
}