Documentation
use crate::http1::{SendRequest, Sender};
use cpool::backend::{Address, BackendState};
use cpool::strategy::LbStrategy;
use cpool::{Error, debug, instrument_current_span, tokio_spawn, trace};
use hyper::client::conn::http1;
use std::collections::HashMap;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::atomic::{AtomicBool, AtomicU32};
use std::sync::{Arc, LazyLock, Mutex};

/// Process-wide allocator for sender ids; the first id handed out is 1.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));

/// Senders cached for a single backend, keyed by the id allocated from `SENDER_ID`.
struct Inner {
    /// Idle senders that may be handed out again.
    free: HashMap<u32, Arc<SendRequest>>,
    /// Senders currently in use.
    work: HashMap<u32, Arc<SendRequest>>,
}

impl Inner {
    fn new() -> Self {
        Inner {
            free: HashMap::new(),
            work: HashMap::new(),
        }
    }

    fn add_work_sr(&mut self, sr: Arc<SendRequest>) -> u32 {
        let id = SENDER_ID.fetch_add(1, Relaxed);
        self.work.insert(id, sr);
        id
    }

    fn remove_sr(&mut self, id: u32) -> bool {
        // 不应该同时处于work和free中
        if self.work.remove(&id).is_some() {
            true
        } else if self.free.remove(&id).is_some() {
            true
        } else {
            false
        }
    }

    /// 返回值是被删除的元素个数, 及查找到的sr
    fn get_sr(&mut self) -> (usize, Option<Arc<SendRequest>>) {
        if self.free.is_empty() {
            // 挪动
            self.work.retain(|id, sr| {
                // 引用计数为2时,代表外界不存在引用, 队列里一份, run那里一份
                if Arc::strong_count(sr) == 2 {
                    self.free.insert(*id, sr.clone());
                    false
                } else {
                    true
                }
            });
        }

        // 这里虽然删除了一些closed状态的, 但算法却有延迟性,它无法及时删除所有的closed
        let mut del = 0;
        let ks: Vec<u32> = self.free.keys().cloned().collect();
        for k in ks {
            if let Some(sr) = self.free.remove(&k) {
                if sr.is_closed() {
                    del += 1;
                } else {
                    // 挪入work
                    self.work.insert(k, sr.clone());
                    return (del, Some(sr));
                }
            }
        }

        (del, None)
    }
}

/// HTTP/1.1 connection pool.
///
/// The number of connections is bounded by `max_conn`. The default keepalive is one minute.
/// A pooled connection is released in three situations:
///     1. the peer closes the connection;
///     2. `remove_backend` removes the backend address, which drops every cached connection
///        to that address from the pool;
///     3. no external reference to the sender exists for the keepalive duration (this does
///        not apply when no keepalive is set).
pub struct Pool {
    /// Pool identifier, returned by `Http1Pool::get_id`.
    id: String,
    /// Shared pool state; holds the load-balancing strategy and the `max_conn`/`cur_conn`
    /// connection counters.
    state: cpool::pool::BaseState,
    /// Whether newly created connections use TLS.
    use_tls: AtomicBool,
    /// Cached senders grouped by backend hash code.
    free_conn_map: Mutex<HashMap<u64, Inner>>,
}

impl Pool {
    /// Creates an empty pool named `id` that selects backends through `strategy`.
    ///
    /// TLS starts disabled and the keepalive is set to 60 seconds.
    pub fn new(id: String, strategy: Arc<dyn LbStrategy>) -> Self {
        const DEFAULT_KEEPALIVE: std::time::Duration = std::time::Duration::from_secs(60);

        let pool = Self {
            id,
            state: cpool::pool::BaseState::new(strategy),
            use_tls: AtomicBool::new(false),
            free_conn_map: Mutex::new(HashMap::new()),
        };
        <Self as cpool::pool::Pool>::set_keepalive(&pool, Some(DEFAULT_KEEPALIVE));
        pool
    }
}

impl Default for Pool {
    fn default() -> Self {
        Pool::new(
            "".to_string(),
            Arc::new(cpool::strategy::CHStrategy::default()),
        )
    }
}

impl cpool::pool::Pool for Pool {
    cpool::macros::base_pool_impl! {state}

    /// Removes `addr` from the load-balancing strategy and, if that succeeded, drops every
    /// cached connection to it. Returns whether the backend was removed.
    fn remove_backend(&self, addr: &Address) -> bool {
        let removed = self.state.lb_strategy.remove_backend(addr);
        if removed {
            // Cached senders for a backend that no longer exists must not be reused.
            self.clear_bs_sr(addr.hash_code());
        }
        removed
    }

    /// Enables or disables TLS for connections created after this call.
    fn use_tls(&self, tls: bool) {
        self.use_tls.store(tls, Relaxed);
    }

    /// Reports whether new connections are created with TLS.
    fn tls(&self) -> bool {
        self.use_tls.load(Relaxed)
    }
}

impl Pool {
    fn clear_bs_sr(&self, hash_code: u64) {
        let mut guard = self.free_conn_map.lock().unwrap();
        if let Some(inner) = guard.remove(&hash_code) {
            assert!(
                self.state
                    .cur_conn
                    .fetch_sub(inner.free.len() + inner.work.len(), Relaxed)
                    > 0
            );
            trace!(
                "[http/1.1 pool] [desc] current connection count: {}",
                self.state.cur_conn.load(Relaxed)
            );
        }
    }

    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
        let mut guard = self.free_conn_map.lock().unwrap();
        let inner = guard.get_mut(&bs.hash_code())?;

        let sr = {
            let (del_cnt, sr) = inner.get_sr();
            if del_cnt > 0 {
                assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
                trace!(
                    "[http/1.1 pool] [desc] current connection count: {}",
                    self.state.cur_conn.load(Relaxed)
                );
            }
            sr
        };

        let tls = <Pool as cpool::pool::Pool>::tls(self);
        sr.map(|s| Sender::new(s, crate::utils::base_url(tls, bs.get_address())))
    }

    fn add_sr(&self, hash_code: u64, sr: Arc<SendRequest>) -> u32 {
        let mut guard = self.free_conn_map.lock().unwrap();
        if let Some(inner) = guard.get_mut(&hash_code) {
            inner.add_work_sr(sr)
        } else {
            let mut inner = Inner::new();
            let id = inner.add_work_sr(sr);
            guard.insert(hash_code, inner);
            id
        }
    }

    fn remove_sr(&self, hash_code: u64, id: u32) {
        let mut guard = self.free_conn_map.lock().unwrap();
        if let Some(inner) = guard.get_mut(&hash_code) {
            if inner.remove_sr(id) {
                assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
                trace!(
                    "[http/1.1 pool] [desc] current connection count: {}",
                    self.state.cur_conn.load(Relaxed)
                );
            }
        }
    }

    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
        pool: Arc<Self>,
        c: C,
        sr: Arc<SendRequest>,
        bs: &BackendState,
    ) {
        let hash_code = bs.hash_code();

        // 加入work,且必须在下面执行super::run_conn之前加入
        let id = pool.add_sr(hash_code, sr.clone());
        let ka = <Pool as cpool::pool::Pool>::get_keepalive(&pool);

        tokio_spawn! {
            instrument_current_span! {
                async move {
                    let _r = crate::utils::run_conn(sr, c, ka).await;
                    debug!("[http/1.1 pool] connection closed: {:?}", _r);
                    pool.remove_sr(hash_code, id);
                }
            }
        };
    }

    async fn create_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
        let tls = <Pool as cpool::pool::Pool>::tls(&self);
        let addr = bs.get_address();

        // 获取连接
        let tcp = crate::utils::create_https_stream(addr).await?;
        let tls_tcp =
            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP1_TLS_CLIENT_CFG.clone())
                .await?;

        // 握手建立 sender 和 connection
        let io = hyper_util::rt::TokioIo::new(tls_tcp);
        let pair = http1::handshake(io)
            .await
            .map_err(|e| Error::from_other(e))?;

        // 启动连接驱动
        let sr = Arc::new(pair.0);
        Pool::run_conn(self, pair.1, sr.clone(), &bs);

        // 返回一份
        Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
    }

    async fn create_non_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
        let addr = bs.get_address();
        let tls = <Pool as cpool::pool::Pool>::tls(&self);

        // 获取连接
        let tcp = crate::utils::create_http_stream(addr).await?;

        // 握手建立 sender 和 connection
        let io = hyper_util::rt::TokioIo::new(tcp);
        let pair = http1::handshake(io)
            .await
            .map_err(|e| Error::from_other(e))?;

        // 启动连接驱动
        let sr = Arc::new(pair.0);
        Pool::run_conn(self, pair.1, sr.clone(), &bs);

        // 返回一份
        Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
    }

    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
        if <Pool as cpool::pool::Pool>::tls(&self) {
            self.create_tls_sender(bs).await
        } else {
            self.create_non_tls_sender(bs).await
        }
    }
}

/// Asynchronous access to an HTTP/1.1 connection pool.
pub trait Http1Pool {
    /// Returns a sender for the backend that `key` is routed to. The returned future is `Send`.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;

    /// Returns the pool identifier.
    fn get_id(&self) -> &String;
}

impl Http1Pool for Pool {
    /// Returns a sender for the backend that the load-balancing strategy picks for `key`.
    ///
    /// A cached idle sender is reused when available. Otherwise a connection slot is reserved
    /// against `max_conn` and a new connection is opened.
    ///
    /// # Errors
    /// `Error::NoBackend` if the strategy yields no backend for `key`. Otherwise any error from
    /// reserving a slot or from creating the connection; in that case the reserved slot is
    /// given back.
    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
        // Runs the wrapped closure on drop unless it was disarmed first.
        struct ReleaseOnDrop<F: FnOnce()>(Option<F>);

        impl<F: FnOnce()> ReleaseOnDrop<F> {
            fn disarm(&mut self) {
                self.0 = None;
            }
        }

        impl<F: FnOnce()> Drop for ReleaseOnDrop<F> {
            fn drop(&mut self) {
                if let Some(release) = self.0.take() {
                    release();
                }
            }
        }

        let bs = self
            .state
            .lb_strategy
            .get_backend(key)
            .ok_or(Error::NoBackend)?;

        if let Some(sender) = self.get_sender(&bs) {
            return Ok(sender);
        }

        // Reserve the slot up front so concurrent callers cannot exceed `max_conn`.
        cpool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;

        // Give the slot back if connecting fails *or* if this future is dropped while awaiting
        // (cancellation). Otherwise the reservation would leak and permanently shrink capacity.
        let mut release = ReleaseOnDrop(Some(|| {
            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
        }));

        let sender = self.clone().create_sender(&bs).await?;

        // The new connection now owns the slot; it is released by the pool's bookkeeping.
        release.disarm();
        trace!(
            "[http/1.1 pool] [incr] current connection count: {}",
            self.state.cur_conn.load(Relaxed)
        );
        Ok(sender)
    }

    /// Returns the pool identifier.
    fn get_id(&self) -> &String {
        &self.id
    }
}

/// Convenience wrapper that forwards to [`Http1Pool::get`] for `pool`.
pub async fn get<P: Http1Pool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
    P::get(pool, key).await
}