1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
pub mod round_robin;

use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;

use tokio::io;
use tokio::net::TcpStream;

/// Policy that decides which backend address a connection should be forwarded to,
/// given only an IP address (in `LoadBalancer` this is the IP of the address passed to
/// `bridge`).
#[async_trait::async_trait]
trait LoadBalancingStrategy {
    /// Picks the backend `SocketAddr` for a connection associated with `client_ip`.
    async fn entry(&self, client_ip: IpAddr) -> SocketAddr;
}

#[derive(Debug, Clone, Copy, clap::ValueEnum)]
pub enum Strategy {
    RoundRobinSimple,
    RoundRobinIpHash,
}

impl std::fmt::Display for Strategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Strategy::RoundRobinSimple => write!(f, "round_robin"),
            Strategy::RoundRobinIpHash => write!(f, "round_robin_ip_hash"),
        }
    }
}

impl FromStr for Strategy {
    type Err = &'static str;

    /// Parses the snake_case identifiers produced by this type's `Display` impl.
    ///
    /// Matching is exact (case-sensitive, no trimming); anything else yields
    /// `Err("invalid strategy")`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s {
            "round_robin" => Some(Self::RoundRobinSimple),
            "round_robin_ip_hash" => Some(Self::RoundRobinIpHash),
            _ => None,
        };
        parsed.ok_or("invalid strategy")
    }
}

/// Shared, thread-safe handle to the strategy that picks a backend for each connection.
type DynLoadBalancer = Arc<dyn LoadBalancingStrategy + Send + Sync>;

/// TCP proxy that forwards each incoming connection to a backend chosen by a [`Strategy`].
///
/// Cloning is cheap: every clone shares the same strategy instance through an [`Arc`].
#[derive(Clone)]
pub struct LoadBalancer {
    strategy: DynLoadBalancer,
}

impl LoadBalancer {
    /// Builds a load balancer that spreads connections across `targets` using `strategy`.
    pub fn new(strategy: Strategy, targets: impl AsRef<[SocketAddr]>) -> Self {
        // The explicit type forces each arm's concrete `Arc<T>` to coerce to the trait object.
        let strategy: DynLoadBalancer = match strategy {
            Strategy::RoundRobinSimple => Arc::new(round_robin::RoundRobinSimple::new(targets)),
            Strategy::RoundRobinIpHash => Arc::new(round_robin::RoundRobinIpHash::new(targets)),
        };
        Self { strategy }
    }

    /// Asks the strategy for the backend that should serve `client`; only its IP is used.
    async fn entry(&self, client: SocketAddr) -> SocketAddr {
        self.strategy.entry(client.ip()).await
    }

    /// Opens a TCP connection to the backend the strategy selects for `client`.
    async fn connect(&self, client: SocketAddr) -> io::Result<TcpStream> {
        let backend = self.entry(client).await;
        TcpStream::connect(backend).await
    }

    /// Connects to a backend chosen for `socket` and relays bytes in both directions between
    /// it and `tcp` using `io::copy_bidirectional`.
    ///
    /// `socket` is presumably the peer address of `tcp`; only its IP reaches the strategy.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from connecting to the backend, or from relaying data between
    /// the two streams.
    pub async fn bridge(&self, mut tcp: TcpStream, socket: SocketAddr) -> io::Result<()> {
        let mut backend = self.connect(socket).await?;
        // The per-direction byte counts returned on success are not needed by callers.
        io::copy_bidirectional(&mut backend, &mut tcp).await?;
        Ok(())
    }
}