use std::{
net::IpAddr,
sync::atomic::{AtomicUsize, Ordering},
};
use tokio::sync::RwLock;
use tonic::transport::Channel;
/// A set of tonic [`Channel`]s, each stored together with the [`IpAddr`] it
/// belongs to, supporting lookup by address and round-robin selection.
///
/// [`ReadyChannels::add`] keeps at most one entry per address;
/// [`ReadyChannels::replace_with`] stores the vector it is given as-is.
#[derive(Debug, Default)]
pub(crate) struct ReadyChannels {
    /// `(address, channel)` pairs; their order is the order `get_any` cycles through.
    channels: RwLock<Vec<(IpAddr, Channel)>>,
    /// Wrapping counter advanced by every `get_any` call to pick the next entry.
    index: AtomicUsize,
}

impl ReadyChannels {
    /// Returns a clone of the channel stored for `ip`, or `None` if `ip` is absent.
    pub(crate) async fn find(&self, ip: IpAddr) -> Option<Channel> {
        self.channels
            .read()
            .await
            .iter()
            .find(|(addr, _)| *addr == ip)
            .map(|(_, channel)| channel.clone())
    }

    /// Returns a clone of one stored `(ip, channel)` pair, or `None` if the set is
    /// empty.
    ///
    /// Each call advances a shared counter, so successive calls walk the list in
    /// round-robin order while it is unchanged. The counter only spreads load and
    /// guards no other data, hence `Relaxed`. `fetch_add` wraps on overflow, and
    /// the modulo (with the emptiness check above it) keeps the index in bounds.
    pub(crate) async fn get_any(&self) -> Option<(IpAddr, Channel)> {
        let read_access = self.channels.read().await;
        if read_access.is_empty() {
            return None;
        }
        let index = self.index.fetch_add(1, Ordering::Relaxed);
        Some(read_access[index % read_access.len()].clone())
    }

    /// Stores `channel` for `ip`.
    ///
    /// If `ip` is already present, its channel is replaced in place instead of a
    /// second entry being appended. A duplicate entry would be picked twice as
    /// often by `get_any`, and would make `find` disagree with what `add`/`remove`
    /// callers expect.
    pub(crate) async fn add(&self, ip: IpAddr, channel: Channel) {
        let mut write_access = self.channels.write().await;
        if let Some(entry) = write_access.iter_mut().find(|(addr, _)| *addr == ip) {
            entry.1 = channel;
        } else {
            write_access.push((ip, channel));
        }
    }

    /// Removes every channel stored for `ip`; does nothing if `ip` is absent.
    ///
    /// `retain` is used instead of removing only the first match. This way
    /// `find(ip)` returns `None` afterwards even if `replace_with` was given a
    /// vector that lists `ip` more than once. The relative order of the
    /// remaining entries is preserved.
    pub(crate) async fn remove(&self, ip: IpAddr) {
        self.channels
            .write()
            .await
            .retain(|(addr, _)| *addr != ip);
    }

    /// Replaces the entire set with `new`, dropping all previously stored channels.
    ///
    /// The round-robin counter is not reset.
    pub(crate) async fn replace_with(&self, new: Vec<(IpAddr, Channel)>) {
        *self.channels.write().await = new;
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};

    use tonic::transport::Endpoint;

    use super::*;

    const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    const LOCALHOST_V6: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST);

    /// Builds a lazily-connecting channel so no server has to be running.
    fn default_channel() -> Channel {
        Endpoint::from_static("http://localhost:8080").connect_lazy()
    }

    #[tokio::test]
    async fn find() {
        let ready_channels = ReadyChannels::default();
        assert!(ready_channels.find(LOCALHOST_V4).await.is_none());
        ready_channels.add(LOCALHOST_V6, default_channel()).await;
        assert!(ready_channels.find(LOCALHOST_V4).await.is_none());
        ready_channels.add(LOCALHOST_V4, default_channel()).await;
        assert!(ready_channels.find(LOCALHOST_V4).await.is_some());
    }

    #[tokio::test]
    async fn get_any() {
        let ready_channels = ReadyChannels::default();
        assert!(ready_channels.get_any().await.is_none());
        for i in 0..128 {
            ready_channels
                .add(Ipv4Addr::new(127, 0, 0, i).into(), default_channel())
                .await;
        }
        let mut found = vec![];
        for _ in 0..10 {
            if let Some((ip, _)) = ready_channels.get_any().await {
                found.push(ip);
            } else {
                panic!("No channels found");
            }
        }
        found.sort();
        found.dedup();
        // Round robin over 128 channels must not repeat within 10 picks.
        assert_eq!(found.len(), 10);
    }

    #[tokio::test]
    async fn get_any_survives_index_wraparound() {
        let ready_channels = ReadyChannels::default();
        ready_channels.add(LOCALHOST_V4, default_channel()).await;
        ready_channels.add(LOCALHOST_V6, default_channel()).await;
        // The next `fetch_add` overflows; it must wrap rather than panic.
        // `usize::MAX` is odd, so it selects index 1, then the wrapped 0 selects index 0.
        ready_channels.index.store(usize::MAX, Ordering::Relaxed);
        assert_eq!(
            ready_channels.get_any().await.map(|(ip, _)| ip),
            Some(LOCALHOST_V6)
        );
        assert_eq!(
            ready_channels.get_any().await.map(|(ip, _)| ip),
            Some(LOCALHOST_V4)
        );
        assert_eq!(ready_channels.index.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn add() {
        let ready_channels = ReadyChannels::default();
        assert!(ready_channels.channels.read().await.is_empty());
        ready_channels.add(LOCALHOST_V4, default_channel()).await;
        assert_eq!(ready_channels.channels.read().await.len(), 1);
        ready_channels.add(LOCALHOST_V6, default_channel()).await;
        assert_eq!(ready_channels.channels.read().await.len(), 2);
    }

    #[tokio::test]
    async fn remove() {
        let ready_channels = ReadyChannels::default();
        ready_channels.add(LOCALHOST_V4, default_channel()).await;
        ready_channels.add(LOCALHOST_V6, default_channel()).await;
        assert_eq!(ready_channels.channels.read().await.len(), 2);
        // Removing an address that was never added is a no-op.
        ready_channels.remove([127, 0, 0, 2].into()).await;
        assert_eq!(ready_channels.channels.read().await.len(), 2);
        ready_channels.remove(LOCALHOST_V4).await;
        assert_eq!(ready_channels.channels.read().await.len(), 1);
        assert!(ready_channels.find(LOCALHOST_V4).await.is_none());
        ready_channels.remove(LOCALHOST_V6).await;
        assert!(ready_channels.channels.read().await.is_empty());
        assert!(ready_channels.find(LOCALHOST_V6).await.is_none());
    }

    #[tokio::test]
    async fn replace_with() {
        let ready_channels = ReadyChannels::default();
        ready_channels.add(LOCALHOST_V4, default_channel()).await;
        ready_channels.add(LOCALHOST_V6, default_channel()).await;
        assert_eq!(ready_channels.channels.read().await.len(), 2);
        let new = vec![
            ([127, 0, 0, 2].into(), default_channel()),
            ([0, 0, 0, 0, 0, 0, 0, 2].into(), default_channel()),
        ];
        ready_channels.replace_with(new.clone()).await;
        let stored: Vec<IpAddr> = ready_channels
            .channels
            .read()
            .await
            .iter()
            .map(|(ip, _)| *ip)
            .collect();
        let expected: Vec<IpAddr> = new.iter().map(|(ip, _)| *ip).collect();
        assert_eq!(stored, expected);
    }
}