use std::ops::{Index, IndexMut};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::index::IndexBase;
use crate::{NodeIndex, PortIndex, UnmanagedDenseMap};
/// Weights attached to nodes and ports.
///
/// Node weights of type `N` are keyed by [`NodeIndex`] and port weights of
/// type `P` are keyed by [`PortIndex`]. The index base types `NI` and `PI`
/// default to `u32`. Besides the public fields, both maps can be accessed
/// directly through the `Index` / `IndexMut` impls on `Weights`
/// (`weights[node]`, `weights[port]`).
///
/// NOTE(review): both maps are `UnmanagedDenseMap`s and nothing in this type
/// ties them to a particular set of live nodes/ports — presumably callers are
/// responsible for keeping them in sync; confirm against the map's docs.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Weights<N, P, NI: IndexBase = u32, PI: IndexBase = u32> {
/// Weight of each node, keyed by node index.
pub nodes: UnmanagedDenseMap<NodeIndex<NI>, N>,
/// Weight of each port, keyed by port index.
pub ports: UnmanagedDenseMap<PortIndex<PI>, P>,
}
impl<N, P, NI, PI> Weights<N, P, NI, PI>
where
    NI: IndexBase,
    PI: IndexBase,
    N: Clone + Default,
    P: Clone + Default,
{
    /// Creates a `Weights` whose node and port maps are both freshly
    /// constructed with `UnmanagedDenseMap::new`.
    #[inline]
    pub fn new() -> Self {
        let nodes = UnmanagedDenseMap::new();
        let ports = UnmanagedDenseMap::new();
        Self { nodes, ports }
    }

    /// Creates a `Weights` whose maps are built with
    /// `UnmanagedDenseMap::with_capacity`, forwarding `nodes` to the node map
    /// and `ports` to the port map.
    #[inline]
    pub fn with_capacity(nodes: usize, ports: usize) -> Self {
        let node_weights = UnmanagedDenseMap::with_capacity(nodes);
        let port_weights = UnmanagedDenseMap::with_capacity(ports);
        Self {
            nodes: node_weights,
            ports: port_weights,
        }
    }
}
impl<N, P, NI: IndexBase, PI: IndexBase> Default for Weights<N, P, NI, PI>
where
    N: Clone + Default,
    P: Clone + Default,
{
    /// Returns an empty `Weights`; identical to [`Weights::new`].
    #[inline]
    fn default() -> Self {
        // Delegate instead of duplicating the field initialisers so the two
        // empty-constructor paths cannot drift apart.
        Self::new()
    }
}
/// Read access to a node's weight via `weights[node]`.
///
/// Only `N: Clone` is required: this impl never touches the port map, so it
/// places no bound on the port weight type `P`.
impl<N, P, NI: IndexBase, PI: IndexBase> Index<NodeIndex<NI>> for Weights<N, P, NI, PI>
where
    N: Clone,
{
    type Output = N;

    /// Returns the weight stored for node `key`.
    ///
    /// This forwards to `self.nodes[key]`, so the behaviour for indices that
    /// were never written is whatever `UnmanagedDenseMap`'s `Index` impl
    /// does (this module's tests expect the default value on a fresh map).
    fn index(&self, key: NodeIndex<NI>) -> &Self::Output {
        &self.nodes[key]
    }
}
/// Write access to a node's weight via `weights[node] = value`.
///
/// Only `N: Clone` is required: this impl never touches the port map, so it
/// places no bound on the port weight type `P`.
impl<N, P, NI: IndexBase, PI: IndexBase> IndexMut<NodeIndex<NI>> for Weights<N, P, NI, PI>
where
    N: Clone,
{
    /// Returns a mutable reference to the weight stored for node `key`.
    ///
    /// This forwards to `self.nodes[key]`; any handling of indices beyond
    /// the map's current contents is delegated to `UnmanagedDenseMap`'s
    /// `IndexMut` impl.
    fn index_mut(&mut self, key: NodeIndex<NI>) -> &mut Self::Output {
        &mut self.nodes[key]
    }
}
/// Read access to a port's weight via `weights[port]`.
///
/// Only `P: Clone` is required: this impl never touches the node map, so it
/// places no bound on the node weight type `N`.
impl<N, P, NI: IndexBase, PI: IndexBase> Index<PortIndex<PI>> for Weights<N, P, NI, PI>
where
    P: Clone,
{
    type Output = P;

    /// Returns the weight stored for port `key`.
    ///
    /// This forwards to `self.ports[key]`, so the behaviour for indices that
    /// were never written is whatever `UnmanagedDenseMap`'s `Index` impl
    /// does (this module's tests expect the default value on a fresh map).
    fn index(&self, key: PortIndex<PI>) -> &Self::Output {
        &self.ports[key]
    }
}
/// Write access to a port's weight via `weights[port] = value`.
///
/// Only `P: Clone` is required: this impl never touches the node map, so it
/// places no bound on the node weight type `N`.
impl<N, P, NI: IndexBase, PI: IndexBase> IndexMut<PortIndex<PI>> for Weights<N, P, NI, PI>
where
    P: Clone,
{
    /// Returns a mutable reference to the weight stored for port `key`.
    ///
    /// This forwards to `self.ports[key]`; any handling of indices beyond
    /// the map's current contents is delegated to `UnmanagedDenseMap`'s
    /// `IndexMut` impl.
    fn index_mut(&mut self, key: PortIndex<PI>) -> &mut Self::Output {
        &mut self.ports[key]
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips one node weight and one port weight through the
    /// `Index` / `IndexMut` impls, starting from an empty container.
    #[test]
    fn test_weights() {
        let node = NodeIndex::<u32>::new(0);
        let port = PortIndex::<u32>::new(0);
        let mut weights: Weights<usize, isize, u32, u32> = Weights::new();

        // Nothing has been written yet: both lookups see the default value.
        assert_eq!(weights[node], 0);
        assert_eq!(weights[port], 0);

        weights[node] = 42;
        weights[port] = -1;

        // The writes are visible through the read path.
        assert_eq!(weights[node], 42);
        assert_eq!(weights[port], -1);
    }
}