use std::collections::HashMap;
use crate::{
network::Network,
ospf::{LinkWeight, OspfImpl},
types::{Prefix, RouterId, ASN},
};
#[cfg(feature = "rand")]
use rand::{distributions::Uniform, prelude::*, rngs::StdRng};
/// Strategy that decides which OSPF weight a link receives.
///
/// The sampler is asked once per directed link `src -> dst`. It takes `&mut self` so that
/// stateful strategies (for example ones backed by a random number generator) can advance
/// their state between calls. Implementations are free to inspect `net` and `asn` or to
/// ignore them entirely.
pub trait WeightSampler {
    /// Produce the weight for the directed link from `src` to `dst`.
    ///
    /// `asn` is passed alongside the endpoints; presumably it identifies the AS the link
    /// belongs to — TODO confirm against callers.
    fn sample<P, Q, Ospf>(
        &mut self,
        net: &Network<P, Q, Ospf>,
        asn: ASN,
        src: RouterId,
        dst: RouterId,
    ) -> LinkWeight
    where
        P: Prefix,
        Ospf: OspfImpl;
}
/// A plain [`LinkWeight`] is a constant sampler: every link gets this exact value, regardless
/// of the network, AS, or endpoints.
impl WeightSampler for LinkWeight {
    fn sample<P, Q, Ospf>(
        &mut self,
        _net: &Network<P, Q, Ospf>,
        _asn: ASN,
        _src: RouterId,
        _dst: RouterId,
    ) -> LinkWeight
    where
        P: Prefix,
        Ospf: OspfImpl,
    {
        // The sampler *is* the weight; hand back a copy of it.
        let constant = *self;
        constant
    }
}
/// Sampler that uses explicitly configured per-link weights and falls back to another
/// sampler `S` for every link that has no entry.
#[derive(Debug, Clone)]
pub struct Lookup<S> {
    /// Explicit weights keyed by the directed `(src, dst)` router pair.
    lut: HashMap<(RouterId, RouterId), LinkWeight>,
    /// Sampler consulted for any `(src, dst)` pair missing from `lut`.
    default: S,
}
impl<S> Lookup<S> {
pub fn new(default: S) -> Self {
Self {
lut: Default::default(),
default,
}
}
pub fn from(
with: impl IntoIterator<Item = ((RouterId, RouterId), LinkWeight)>,
default: S,
) -> Self {
Self {
lut: with.into_iter().collect(),
default,
}
}
pub fn with(self, from: RouterId, to: RouterId, weight: LinkWeight) -> Self {
let Self { mut lut, default } = self;
lut.insert((from, to), weight);
Self { lut, default }
}
pub fn with_bidirectional(self, from: RouterId, to: RouterId, weight: LinkWeight) -> Self {
let Self { mut lut, default } = self;
lut.insert((from, to), weight);
lut.insert((to, from), weight);
Self { lut, default }
}
}
/// Return the configured weight for `(src, dst)` if present. Otherwise defer to the fallback
/// sampler. The fallback is only invoked (and can only mutate its state) on a miss.
impl<S> WeightSampler for Lookup<S>
where
    S: WeightSampler,
{
    fn sample<P: Prefix, Q, Ospf: OspfImpl>(
        &mut self,
        net: &Network<P, Q, Ospf>,
        asn: ASN,
        src: RouterId,
        dst: RouterId,
    ) -> LinkWeight {
        match self.lut.get(&(src, dst)) {
            Some(&explicit) => explicit,
            None => self.default.sample(net, asn, src, dst),
        }
    }
}
/// Sampler that draws each link weight uniformly at random from the half-open range `[a, b)`,
/// optionally rounding the result. `R` is the random number generator driving the draws.
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
#[derive(Clone, Debug)]
pub struct UniformWeights<R> {
    /// Inclusive lower bound of the sampling range.
    a: LinkWeight,
    /// Exclusive upper bound of the sampling range.
    b: LinkWeight,
    /// Whether sampled values are passed through `round()` before being returned.
    round: bool,
    /// Source of randomness.
    rng: R,
}
#[cfg(feature = "rand")]
impl<R> UniformWeights<R> {
    /// Builder: round every sampled weight with `round()` before returning it.
    pub fn round(mut self) -> Self {
        self.round = true;
        self
    }
}
#[cfg(feature = "rand")]
impl UniformWeights<ThreadRng> {
    /// Create a sampler over `[a, b)` that draws from `thread_rng()`. Rounding is disabled.
    pub fn new(a: LinkWeight, b: LinkWeight) -> Self {
        let generator = thread_rng();
        Self {
            a,
            b,
            round: false,
            rng: generator,
        }
    }
}
#[cfg(feature = "rand")]
impl UniformWeights<StdRng> {
    /// Create a reproducible sampler over `[a, b)` backed by a `StdRng` seeded with `seed`.
    /// Rounding is disabled.
    pub fn seeded(seed: u64, a: LinkWeight, b: LinkWeight) -> Self {
        let generator = StdRng::seed_from_u64(seed);
        Self {
            a,
            b,
            round: false,
            rng: generator,
        }
    }
}
#[cfg(feature = "rand")]
impl<R> UniformWeights<R> {
    /// Create a sampler over `[a, b)` that draws from the caller-supplied generator `rng`.
    /// Rounding is disabled.
    pub fn from_rng(rng: R, a: LinkWeight, b: LinkWeight) -> Self {
        Self {
            a,
            b,
            round: false,
            rng,
        }
    }
}
#[cfg(feature = "rand")]
impl<R: RngCore> WeightSampler for UniformWeights<R> {
    /// Draw a weight uniformly from the half-open range `[a, b)`, ignoring the network, AS and
    /// link endpoints.
    ///
    /// A degenerate range (`a == b`) yields `a` without consuming any randomness. `rand`'s
    /// `Uniform` rejects empty ranges, so without this check a constant configuration such
    /// as `UniformWeights::new(5.0, 5.0)` would panic on the first sample.
    ///
    /// When rounding is enabled, the drawn value is passed through `round()`. Note that this
    /// can produce `b` itself (e.g. `9.7` in `[1, 10)` rounds to `10`).
    ///
    /// # Panics
    ///
    /// Panics (inside `rand`'s `Uniform`) if `a > b` or the range is otherwise rejected by
    /// `rand` (e.g. non-finite bounds).
    fn sample<P: Prefix, Q, Ospf: OspfImpl>(
        &mut self,
        _net: &Network<P, Q, Ospf>,
        _asn: ASN,
        _src: RouterId,
        _dst: RouterId,
    ) -> LinkWeight {
        let x = if self.a == self.b {
            // Empty half-open range: the only sensible value is the bound itself.
            self.a
        } else {
            Uniform::from(self.a..self.b).sample(&mut self.rng)
        };
        if self.round { x.round() } else { x }
    }
}