net_pool/
pool.rs

1use crate::backend::{Address, BackendState};
2use crate::error::Error;
3use crate::strategy::LbStrategy;
4use std::sync::Arc;
5use std::sync::atomic::Ordering::Relaxed;
6use std::sync::atomic::{AtomicU64, AtomicUsize};
7use std::time::Duration;
8
9pub trait Pool: Send + Sync {
10    /// 设置最大连接数
11    fn set_max_conn(&self, max: Option<usize>);
12
13    /// 获取最大连接数
14    fn get_max_conn(&self) -> Option<usize>;
15
16    /// 获取当前的连接数
17    fn get_cur_conn(&self) -> usize;
18
19    /// 设置空闲连接保留时长
20    fn set_keepalive(&self, _: Option<Duration>) {}
21
22    /// 获取空闲连接保留时长
23    fn get_keepalive(&self) -> Option<Duration> {
24        None
25    }
26
27    /// 获取转发策略
28    fn get_strategy(&self) -> Arc<dyn LbStrategy>;
29
30    /// 添加一个后端地址
31    fn add_backend(&self, addr: Address) {
32        self.get_strategy().add_backend(addr)
33    }
34
35    /// 移除一个后端地址
36    fn remove_backend(&self, addr: &Address) -> bool {
37        self.get_strategy().remove_backend(addr)
38    }
39
40    /// 获取所有后端地址切片
41    fn get_backends(&self) -> Vec<BackendState> {
42        self.get_strategy().get_backends()
43    }
44
45    // 是否使用tls
46    fn use_tls(&self, _: bool) {}
47
48    fn tls(&self) -> bool {
49        false
50    }
51}
52
53pub struct BaseState {
54    pub max_conn: AtomicUsize, // usize::MAX表示无设置
55    pub cur_conn: AtomicUsize,
56    pub keepalive: AtomicU64, // 精度只到秒级别, u64::MAX表示无设置
57    pub lb_strategy: Arc<dyn LbStrategy>,
58}
59
60impl BaseState {
61    pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
62        BaseState {
63            max_conn: AtomicUsize::new(usize::MAX),
64            cur_conn: AtomicUsize::new(0),
65            keepalive: AtomicU64::new(u64::MAX),
66            lb_strategy: strategy,
67        }
68    }
69}
70
71pub fn increase_current(max: &AtomicUsize, cur: &AtomicUsize) -> Result<(), Error> {
72    let m = max.load(Relaxed);
73    if m == usize::MAX {
74        cur.fetch_add(1, Relaxed);
75        Ok(())
76    } else {
77        let mut c = cur.load(Relaxed);
78        loop {
79            if c == m {
80                break Err(Error::PoolFull);
81            }
82
83            match cur.compare_exchange(c, c + 1, Relaxed, Relaxed) {
84                Ok(_) => break Ok(()),
85                Err(s) => c = s,
86            }
87        }
88    }
89}