use crate::generated::quilkin::filters::load_balancer::v1alpha1 as proto;
mod config;
mod endpoint_chooser;
use crate::filters::prelude::*;
use endpoint_chooser::EndpointChooser;
pub use config::{Config, Policy};
/// Filter that decides which endpoint(s) a packet is forwarded to.
///
/// The selection strategy is the [`EndpointChooser`] produced from the
/// configured [`Policy`]; it writes its choice into the read context's
/// destination list.
pub struct LoadBalancer {
    /// Policy-specific strategy used on every `read` to pick destinations.
    endpoint_chooser: Box<dyn EndpointChooser>,
}
impl LoadBalancer {
fn new(config: Config) -> Self {
Self {
endpoint_chooser: config.policy.as_endpoint_chooser(),
}
}
pub fn testing(config: Config) -> Self {
Self::new(config)
}
}
impl Filter for LoadBalancer {
    /// Asks the configured chooser to fill `ctx.destinations`, using the
    /// cluster's endpoints and the packet's source address.
    ///
    /// Always returns `Ok(())`; this filter never fails a read.
    fn read<P: PacketMut>(&self, ctx: &mut ReadContext<'_, P>) -> Result<(), FilterError> {
        let chooser = &self.endpoint_chooser;
        chooser.choose_endpoints(ctx.destinations, ctx.endpoints, &ctx.source);
        Ok(())
    }
}
impl StaticFilter for LoadBalancer {
    const NAME: &'static str = "quilkin.filters.load_balancer.v1alpha1.LoadBalancer";
    type Configuration = Config;
    type BinaryConfiguration = proto::LoadBalancer;

    /// Builds the filter from its optional configuration.
    ///
    /// Any error from `ensure_config_exists` is propagated unchanged. That is
    /// presumably the case where `config` is `None` (confirm against the
    /// trait's docs).
    fn try_from_config(config: Option<Self::Configuration>) -> Result<Self, CreationError> {
        let config = Self::ensure_config_exists(config)?;
        Ok(Self::new(config))
    }
}
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, net::Ipv4Addr};

    use super::*;
    use crate::{
        net::endpoint::{Endpoint, EndpointAddress},
        test::alloc_buffer,
    };

    /// Runs one read through `filter` for a packet from `source`, with an
    /// empty buffer (`alloc_buffer([])`).
    ///
    /// The cluster is a default cluster containing `input_addresses`. Returns
    /// the destinations the filter selected.
    fn get_response_addresses(
        filter: &LoadBalancer,
        input_addresses: &[EndpointAddress],
        source: EndpointAddress,
    ) -> Vec<EndpointAddress> {
        let endpoints = input_addresses
            .iter()
            .cloned()
            .map(Endpoint::new)
            .collect::<std::collections::BTreeSet<_>>();
        let endpoints = crate::net::cluster::ClusterMap::new_default(endpoints);
        let mut dest = Vec::new();
        {
            // Scoped so the context's mutable borrow of `dest` ends before we return it.
            let mut context = ReadContext::new(&endpoints, source, alloc_buffer([]), &mut dest);
            filter.read(&mut context).unwrap();
        }
        dest
    }

    /// Queries the same `source` `times` times and returns the selection.
    ///
    /// Works like `get_response_addresses`, but asserts that every query
    /// selected identical destinations. Unlike `vec![call; n]`, which
    /// evaluates the call once and clones the result, this really performs
    /// every query.
    fn get_stable_response_addresses(
        filter: &LoadBalancer,
        input_addresses: &[EndpointAddress],
        source: EndpointAddress,
        times: usize,
    ) -> Vec<EndpointAddress> {
        let first = get_response_addresses(filter, input_addresses, source.clone());
        for _ in 1..times {
            assert_eq!(
                first,
                get_response_addresses(filter, input_addresses, source.clone()),
                "the same source was routed to different endpoints"
            );
        }
        first
    }

    #[tokio::test]
    async fn round_robin_load_balancer_policy() {
        let addresses: Vec<EndpointAddress> = vec![
            ([127, 0, 0, 1], 8080).into(),
            ([127, 0, 0, 2], 8080).into(),
            ([127, 0, 0, 3], 8080).into(),
        ];
        let yaml = "policy: ROUND_ROBIN";
        let filter = LoadBalancer::from_config(serde_yaml::from_str(yaml).unwrap());
        // Each round of `addresses.len()` reads should visit every endpoint once, in order.
        let expected_sequence = addresses
            .iter()
            .map(|addr| vec![addr.clone()])
            .collect::<Vec<_>>();
        for _ in 0..10 {
            assert_eq!(expected_sequence, {
                let mut responses = Vec::new();
                for _ in 0..addresses.len() {
                    responses.push(get_response_addresses(
                        &filter,
                        &addresses,
                        "127.0.0.1:8080".parse().unwrap(),
                    ));
                }
                responses
            });
        }
    }

    #[tokio::test]
    async fn random_load_balancer_policy() {
        let addresses: Vec<EndpointAddress> = vec![
            "127.0.0.1:8080".parse().unwrap(),
            "127.0.0.2:8080".parse().unwrap(),
            "127.0.0.3:8080".parse().unwrap(),
        ];
        let yaml = "
policy: RANDOM
";
        let filter = LoadBalancer::from_config(serde_yaml::from_str(yaml).unwrap());
        let mut result_sequences = vec![];
        for _ in 0..10 {
            for _ in 0..addresses.len() {
                result_sequences.push(get_response_addresses(
                    &filter,
                    &addresses,
                    "127.0.0.1:8080".parse().unwrap(),
                ));
            }
        }
        // Over 30 picks every endpoint should appear at least once...
        assert_eq!(
            addresses.into_iter().collect::<HashSet<_>>(),
            result_sequences
                .iter()
                .flatten()
                .cloned()
                .collect::<HashSet<_>>(),
        );
        // ...and the picks should not all be identical.
        assert!(
            result_sequences[1..]
                .iter()
                .any(|seq| seq != &result_sequences[0]),
            "the same sequence of addresses were chosen for random load balancer"
        );
    }

    #[tokio::test]
    async fn hash_load_balancer_policy() {
        let addresses: Vec<EndpointAddress> = vec![
            ([127, 0, 0, 1], 8080).into(),
            ([127, 0, 0, 2], 8080).into(),
            ([127, 0, 0, 3], 8080).into(),
        ];
        let source_ips = [[127u8, 1, 1, 1], [127, 2, 2, 2], [127, 3, 3, 3]];
        let source_ports = [11111u16, 22222, 33333, 44444, 55555];
        let yaml = "policy: HASH";
        let filter = LoadBalancer::from_config(serde_yaml::from_str(yaml).unwrap());

        // A single fixed source must always map to the same endpoint.
        let mut result_sequences = vec![];
        for _ in 0..10 {
            for _ in 0..addresses.len() {
                result_sequences.push(get_response_addresses(
                    &filter,
                    &addresses,
                    (Ipv4Addr::LOCALHOST, 8080).into(),
                ));
            }
        }
        assert_eq!(
            1,
            result_sequences
                .into_iter()
                .flatten()
                .collect::<HashSet<_>>()
                .len(),
        );

        // Varying only the source port must spread sources over more than one
        // endpoint, while each individual source stays stable.
        let per_port_results: Vec<Vec<EndpointAddress>> = source_ports
            .iter()
            .copied()
            .map(|port| {
                get_stable_response_addresses(
                    &filter,
                    &addresses,
                    (Ipv4Addr::LOCALHOST, port).into(),
                    addresses.len(),
                )
            })
            .collect();
        assert_ne!(
            1,
            per_port_results
                .into_iter()
                .flatten()
                .collect::<HashSet<_>>()
                .len(),
        );

        // Across all ip/port combinations every endpoint should be reachable.
        let mut result_sequences = vec![];
        for ip in source_ips {
            for port in source_ports.iter().copied() {
                result_sequences.push(get_stable_response_addresses(
                    &filter,
                    &addresses,
                    (ip, port).into(),
                    addresses.len(),
                ));
            }
        }
        assert_eq!(
            addresses.iter().cloned().collect::<HashSet<_>>(),
            result_sequences
                .iter()
                .flatten()
                .cloned()
                .collect::<HashSet<_>>(),
        );
        assert!(
            result_sequences[1..]
                .iter()
                .any(|seq| seq != &result_sequences[0]),
            "the same sequence of addresses were chosen for hash load balancer"
        );
    }
}