use std::{
borrow::Cow,
collections::{HashMap, HashSet},
convert::Infallible,
future::Future,
hash::Hash,
net::SocketAddr,
sync::Arc,
};
use async_broadcast::Receiver;
use crate::{context::Endpoint, loadbalance::error::LoadBalanceError, net::Address};
/// A single discovered service instance.
///
/// Equality compares all fields (address, weight and tags).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Instance {
    /// Network address of the instance. `diff_address` matches instances by this field.
    pub address: Address,
    /// Weight of the instance. `StaticDiscover::from` sets it to `1`;
    /// `WeightedStaticDiscover::from` takes it from the caller.
    /// NOTE(review): presumably consumed by weighted load balancers — confirm.
    pub weight: u32,
    /// Free-form key/value metadata attached to the instance.
    pub tags: HashMap<Cow<'static, str>, Cow<'static, str>>,
}
/// Resolves an [`Endpoint`] into the list of [`Instance`]s that can serve it.
pub trait Discover: Send + Sync + 'static {
    /// Identifier of a discovery target, produced by [`Discover::key`] and carried in
    /// [`Change::key`] notifications.
    type Key: Hash + PartialEq + Eq + Send + Sync + Clone + 'static;
    /// Error returned by discovery; it must be convertible into a [`LoadBalanceError`].
    type Error: Into<LoadBalanceError>;
    /// Returns the current instances for `endpoint`.
    ///
    /// The returned future is required to be `Send`.
    fn discover<'s>(
        &'s self,
        endpoint: &'s Endpoint,
    ) -> impl Future<Output = Result<Vec<Arc<Instance>>, Self::Error>> + Send;
    /// Maps `endpoint` to the key identifying its discovery target.
    /// NOTE(review): presumably used by callers to cache/deduplicate discovery per key — confirm.
    fn key(&self, endpoint: &Endpoint) -> Self::Key;
    /// Returns a receiver of [`Change`] notifications, or `None`.
    ///
    /// Every implementation in this module returns `None`.
    /// NOTE(review): the meaning of `keys == None` (e.g. "all keys") is not defined here — confirm.
    fn watch(&self, keys: Option<&[Self::Key]>) -> Option<Receiver<Change<Self::Key>>>;
}
/// Describes how the instance list associated with `key` changed.
#[derive(Debug, Clone)]
pub struct Change<K> {
    /// Discovery key this change applies to.
    pub key: K,
    /// The complete instance list after the change (`diff_address` stores its `next` here).
    pub all: Vec<Arc<Instance>>,
    /// Instances whose address was not present before the change.
    pub added: Vec<Arc<Instance>>,
    /// Instances present both before and after the change whose data was modified.
    /// NOTE(review): whether this is populated depends on the producer — confirm.
    pub updated: Vec<Arc<Instance>>,
    /// Instances whose address is no longer present after the change.
    pub removed: Vec<Arc<Instance>>,
}
pub fn diff_address<K>(
key: K,
prev: Vec<Arc<Instance>>,
next: Vec<Arc<Instance>>,
) -> (Change<K>, bool)
where
K: Hash + PartialEq + Eq + Send + Sync + 'static,
{
let mut added = Vec::new();
let updated = Vec::new();
let mut removed = Vec::new();
let mut prev_set = HashSet::with_capacity(prev.len());
let mut next_set = HashSet::with_capacity(next.len());
for i in &prev {
prev_set.insert(i.address.clone());
}
for i in &next {
next_set.insert(i.address.clone());
}
for i in &next {
if !prev_set.contains(&i.address) {
added.push(i.clone());
}
}
for i in &prev {
if !next_set.contains(&i.address) {
removed.push(i.clone());
}
}
let changed = !added.is_empty() || !removed.is_empty();
(
Change {
key,
all: next,
added,
updated,
removed,
},
changed,
)
}
/// A [`Discover`] implementation that always yields the same, fixed list of instances.
///
/// It never reports changes: its `watch` returns `None`.
#[derive(Clone, Debug)]
pub struct StaticDiscover {
    instances: Vec<Arc<Instance>>,
}

impl StaticDiscover {
    /// Creates a discoverer that returns exactly `instances` on every `discover` call.
    pub fn new(instances: Vec<Arc<Instance>>) -> Self {
        Self { instances }
    }
}
impl From<Vec<SocketAddr>> for StaticDiscover {
fn from(addrs: Vec<SocketAddr>) -> Self {
let instances = addrs
.into_iter()
.map(|addr| {
Arc::new(Instance {
address: Address::Ip(addr),
weight: 1,
tags: Default::default(),
})
})
.collect();
Self { instances }
}
}
impl Discover for StaticDiscover {
    type Key = ();
    type Error = Infallible;

    /// Hands out a copy of the configured instance list; the endpoint is not consulted.
    async fn discover<'s>(&'s self, _: &'s Endpoint) -> Result<Vec<Arc<Instance>>, Self::Error> {
        let snapshot = Vec::clone(&self.instances);
        Ok(snapshot)
    }

    /// Every endpoint maps to the same unit key.
    fn key(&self, _: &Endpoint) -> Self::Key {}

    /// The list is fixed, so there is nothing to watch.
    fn watch(&self, _: Option<&[Self::Key]>) -> Option<Receiver<Change<Self::Key>>> {
        None
    }
}
/// A [`Discover`] implementation that always yields the same, fixed list of
/// instances, each carrying a caller-chosen weight.
///
/// It never reports changes: its `watch` returns `None`.
#[derive(Clone, Debug)]
pub struct WeightedStaticDiscover {
    instances: Vec<Arc<Instance>>,
}

impl WeightedStaticDiscover {
    /// Creates a discoverer that returns exactly `instances` on every `discover` call.
    pub fn new(instances: Vec<Arc<Instance>>) -> Self {
        Self { instances }
    }
}
impl From<Vec<(SocketAddr, u32)>> for WeightedStaticDiscover {
fn from(addrs: Vec<(SocketAddr, u32)>) -> Self {
let instances = addrs
.into_iter()
.map(|addr| {
Arc::new(Instance {
address: Address::Ip(addr.0),
weight: addr.1,
tags: Default::default(),
})
})
.collect();
Self { instances }
}
}
impl Discover for WeightedStaticDiscover {
    type Key = ();
    type Error = Infallible;

    /// Hands out a copy of the configured weighted instances; the endpoint is ignored.
    async fn discover<'s>(&'s self, _: &'s Endpoint) -> Result<Vec<Arc<Instance>>, Self::Error> {
        let snapshot = Vec::clone(&self.instances);
        Ok(snapshot)
    }

    /// Every endpoint maps to the same unit key.
    fn key(&self, _: &Endpoint) -> Self::Key {}

    /// The list never changes, so no watch channel is provided.
    fn watch(&self, _: Option<&[Self::Key]>) -> Option<Receiver<Change<Self::Key>>> {
        None
    }
}
/// A [`Discover`] implementation that never yields any instance.
#[derive(Clone, Debug, Default)]
pub struct DummyDiscover;

impl Discover for DummyDiscover {
    type Key = ();
    type Error = Infallible;

    /// Always resolves to an empty instance list.
    async fn discover<'s>(&'s self, _: &'s Endpoint) -> Result<Vec<Arc<Instance>>, Self::Error> {
        Ok(vec![])
    }

    // Spell out the return type for consistency with the other discoverers.
    fn key(&self, _: &Endpoint) -> Self::Key {}

    /// Nothing is ever discovered, so there is nothing to watch.
    fn watch(&self, _keys: Option<&[Self::Key]>) -> Option<Receiver<Change<Self::Key>>> {
        None
    }
}
/// Lets discoverers whose `Error = Infallible` satisfy the `Into<LoadBalanceError>` bound.
impl From<Infallible> for LoadBalanceError {
    /// `Infallible` has no values, so this can never be called. The empty `match`
    /// lets the compiler prove that, instead of relying on a runtime `unreachable!()`.
    fn from(never: Infallible) -> Self {
        match never {}
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::{Discover, Instance, StaticDiscover, WeightedStaticDiscover};
    use crate::{context::Endpoint, net::Address};

    /// Builds the tag-less IP instance a static discoverer should yield for `addr`.
    fn ip_instance(addr: &str, weight: u32) -> Arc<Instance> {
        Arc::new(Instance {
            address: Address::Ip(addr.parse().unwrap()),
            weight,
            tags: Default::default(),
        })
    }

    #[test]
    fn test_static_discover() {
        let endpoint = Endpoint::new("".into());
        let discover = StaticDiscover::from(vec![
            "127.0.0.1:8000".parse().unwrap(),
            "127.0.0.2:9000".parse().unwrap(),
        ]);

        let got = futures::executor::block_on(discover.discover(&endpoint)).unwrap();

        let want = vec![
            ip_instance("127.0.0.1:8000", 1),
            ip_instance("127.0.0.2:9000", 1),
        ];
        assert_eq!(got, want);
    }

    #[test]
    fn test_weighted_static_discover() {
        let endpoint = Endpoint::new("".into());
        let discover = WeightedStaticDiscover::from(vec![
            ("127.0.0.1:8000".parse().unwrap(), 2),
            ("127.0.0.2:9000".parse().unwrap(), 3),
            ("127.0.0.3:9000".parse().unwrap(), 4),
        ]);

        let got = futures::executor::block_on(discover.discover(&endpoint)).unwrap();

        let want = vec![
            ip_instance("127.0.0.1:8000", 2),
            ip_instance("127.0.0.2:9000", 3),
            ip_instance("127.0.0.3:9000", 4),
        ];
        assert_eq!(got, want);
    }
}