use crate::Instance;
use crate::lb::{
LoadBalance, LoadBalanceError, RandomLoadBalance, RoundRobinLoadBalance,
WeightRandomLoadBalance, WeightRoundRobinLoadBalance,
};
use dashmap::DashMap;
use reqwest::{Client, Method, RequestBuilder, Url};
use std::time::Duration;
/// Algorithm used by [`LoadBalanceClient`] to choose which instance of a service
/// receives a request.
///
/// A strategy can be forced per request through its URL scheme (see
/// [`LoadBalanceStrategy::as_schema`]). Requests that use the plain `lb://` scheme
/// use the strategy configured for the service, falling back to the
/// [`Default`] strategy, [`LoadBalanceStrategy::Random`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadBalanceStrategy {
    /// Round-robin selection; forced with the `lb-rr://` scheme.
    RoundRobin,
    /// Weighted round-robin selection; forced with the `lb-wrr://` scheme.
    WeightedRoundRobin,
    /// Random selection; forced with the `lb-r://` scheme. This is the default.
    #[default]
    Random,
    /// Weighted random selection; forced with the `lb-wr://` scheme.
    WeightedRandom,
}

impl LoadBalanceStrategy {
    /// Returns the URL scheme that forces this strategy, e.g. `"lb-rr"` for
    /// [`LoadBalanceStrategy::RoundRobin`].
    ///
    /// The returned string is a `'static` literal, so it may outlive `self`.
    pub fn as_schema(&self) -> &'static str {
        match self {
            LoadBalanceStrategy::RoundRobin => "lb-rr",
            LoadBalanceStrategy::WeightedRoundRobin => "lb-wrr",
            LoadBalanceStrategy::Random => "lb-r",
            LoadBalanceStrategy::WeightedRandom => "lb-wr",
        }
    }
}
/// HTTP client that resolves load-balanced URLs (`lb://`, `lb-r://`, `lb-wr://`,
/// `lb-rr://`, `lb-wrr://`) to a concrete service instance before building the
/// `reqwest` request.
pub struct LoadBalanceClient {
    /// Underlying HTTP client used for every request.
    client: Client,
    /// Strategy chosen per service id, consulted for the plain `lb://` scheme.
    strategies: DashMap<String, LoadBalanceStrategy>,
    /// Balancer used for `LoadBalanceStrategy::RoundRobin`.
    round_robin_lb: RoundRobinLoadBalance,
    /// Balancer used for `LoadBalanceStrategy::WeightedRoundRobin`.
    weight_round_robin_lb: WeightRoundRobinLoadBalance,
    /// Balancer used for `LoadBalanceStrategy::Random`.
    random_lb: RandomLoadBalance,
    /// Balancer used for `LoadBalanceStrategy::WeightedRandom`.
    weight_random_lb: WeightRandomLoadBalance,
}
macro_rules! impl_parse_url {
($self:expr, $scheme:expr, $strategy:expr, $url:expr, $parsed_url:expr) => {{
let service_id = $parsed_url.host_str().unwrap();
let instance = $self.get_instance(service_id, $strategy).await?;
let res = $url.replace(
&format!("{}://{}", $scheme, service_id),
&format!(
"{}{}:{}",
LoadBalanceClient::HTTP_PREFIX,
instance.ip,
instance.port
),
);
Ok(res)
}};
}
impl LoadBalanceClient {
pub fn new() -> Self {
Self::new_with_connect_timeout(Duration::from_secs(5))
}
pub fn new_with_connect_timeout(timeout: Duration) -> Self {
let client = Client::builder()
.connect_timeout(timeout)
.build()
.expect("Failed to build HTTP client");
Self {
client,
strategies: Default::default(),
random_lb: RandomLoadBalance,
weight_random_lb: WeightRandomLoadBalance::default(),
round_robin_lb: RoundRobinLoadBalance::default(),
weight_round_robin_lb: WeightRoundRobinLoadBalance::default(),
}
}
pub fn set_strategy(&mut self, service_id: impl Into<String>, strategy: LoadBalanceStrategy) {
self.strategies.insert(service_id.into(), strategy);
}
async fn get_instance(
&self,
service_id: &str,
specify_strategy: Option<LoadBalanceStrategy>,
) -> Result<Instance, LoadBalanceError> {
if let Some(strategy) = specify_strategy {
return self.get_instance_(service_id, &strategy).await;
}
if let Some(strategy) = self.strategies.get(service_id) {
return self.get_instance_(service_id, &strategy).await;
}
let default_strategy = LoadBalanceStrategy::default();
let result = self.get_instance_(service_id, &default_strategy).await;
self.strategies
.insert(service_id.to_string(), default_strategy);
result
}
async fn get_instance_(
&self,
service_id: &str,
strategy: &LoadBalanceStrategy,
) -> Result<Instance, LoadBalanceError> {
match strategy {
LoadBalanceStrategy::Random => self.random_lb.get_instance(service_id).await,
LoadBalanceStrategy::WeightedRandom => {
self.weight_random_lb.get_instance(service_id).await
}
LoadBalanceStrategy::RoundRobin => self.round_robin_lb.get_instance(service_id).await,
LoadBalanceStrategy::WeightedRoundRobin => {
self.weight_round_robin_lb.get_instance(service_id).await
}
}
}
const HTTP_PREFIX: &'static str = "http://";
async fn parse_url(&self, url: &str) -> Result<String, LoadBalanceError> {
let parsed_url = Url::parse(url).unwrap();
let scheme = parsed_url.scheme();
match scheme {
"lb" => {
impl_parse_url!(self, "lb", None, url, parsed_url)
}
"lb-r" => impl_parse_url!(
self,
"lb-r",
Some(LoadBalanceStrategy::Random),
url,
parsed_url
),
"lb-wr" => impl_parse_url!(
self,
"lb-wr",
Some(LoadBalanceStrategy::WeightedRandom),
url,
parsed_url
),
"lb-rr" => impl_parse_url!(
self,
"lb-rr",
Some(LoadBalanceStrategy::RoundRobin),
url,
parsed_url
),
"lb-wrr" => impl_parse_url!(
self,
"lb-wrr",
Some(LoadBalanceStrategy::WeightedRoundRobin),
url,
parsed_url
),
_ => Ok(url.to_string()),
}
}
pub async fn get(&self, url: &str) -> Result<RequestBuilder, LoadBalanceError> {
Ok(self.client.get(self.parse_url(url).await?))
}
pub async fn post(&self, url: &str) -> Result<RequestBuilder, LoadBalanceError> {
Ok(self.client.post(self.parse_url(url).await?))
}
pub async fn put(&self, url: &str) -> Result<RequestBuilder, LoadBalanceError> {
Ok(self.client.put(self.parse_url(url).await?))
}
pub async fn delete(&self, url: &str) -> Result<RequestBuilder, LoadBalanceError> {
Ok(self.client.delete(self.parse_url(url).await?))
}
pub async fn patch(&self, url: &str) -> Result<RequestBuilder, LoadBalanceError> {
Ok(self.client.patch(self.parse_url(url).await?))
}
pub async fn head(&self, url: &str) -> Result<RequestBuilder, LoadBalanceError> {
Ok(self.client.head(self.parse_url(url).await?))
}
pub async fn request(
&self,
method: Method,
url: &str,
) -> Result<RequestBuilder, LoadBalanceError> {
Ok(self.client.request(method, self.parse_url(url).await?))
}
pub fn get_client(&self) -> &Client {
&self.client
}
}
impl Default for LoadBalanceClient {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conf::{ClientConfigBuilder, ConRegConfigBuilder, DiscoveryConfigBuilder};
    use crate::init_with;

    /// End-to-end check that an `lb://` URL is resolved and the request succeeds.
    ///
    /// Needs a discovery server on `127.0.0.1:8000` with a registered `test-server`
    /// instance serving `/hello`, so it is ignored by default; run it with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires a discovery server on 127.0.0.1:8000 and a registered `test-server`"]
    async fn test_load_balance_client() {
        init_client().await;
        let mut client = LoadBalanceClient::new();
        // Configure the service that is actually requested below; the plain
        // `lb://` scheme picks up the per-service strategy.
        client.set_strategy("test-server", LoadBalanceStrategy::RoundRobin);
        let response = client
            .get("lb://test-server/hello")
            .await
            .expect("failed to select an instance of test-server")
            .send()
            .await
            .expect("request to test-server failed");
        let status = response.status();
        let body = response.text().await.expect("failed to read response body");
        assert!(status.is_success(), "unexpected status {}: {}", status, body);
        println!("Response: {:?}", body);
    }

    /// Initializes the client with port 8001 and discovery at `127.0.0.1:8000`.
    async fn init_client() {
        let config = ConRegConfigBuilder::default()
            .client(ClientConfigBuilder::default().port(8001).build().unwrap())
            .discovery(
                DiscoveryConfigBuilder::default()
                    .server_addr("127.0.0.1:8000")
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        init_with(config).await;
    }
}