//! HTTP/2 connection pool (documentation).
use crate::http2::{SendRequest, Sender};
use hyper::client::conn::http2;
use net_pool::backend::{Address, BackendState};
use net_pool::strategy::LbStrategy;
use net_pool::{Error, debug, instrument_current_span, tokio_spawn, trace};
use std::collections::HashMap;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize};
use std::sync::{Arc, LazyLock, Mutex};

/// Allocator for per-connection ids; starts at 1 and is incremented by every `Inner::new`.
///
/// NOTE(review): `AtomicU32::new` is `const`, so the `LazyLock` wrapper is not needed
/// (`static SENDER_ID: AtomicU32 = AtomicU32::new(1);` would do). The id space is u32 and
/// wraps after 2^32 allocations.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::new(1));

/// A pooled HTTP/2 connection.
///
/// * `.0` – the request sender for the connection (cloned into every handed-out [`Sender`]);
/// * `.1` – a shared marker whose strong count tracks how many handles reference this
///   connection (the `Inner` itself, every live `Sender`, and any `reference()` clones);
/// * `.2` – a process-unique id taken from `SENDER_ID`.
struct Inner(SendRequest, Arc<()>, u32);

impl Inner {
    /// Wraps `sender` with a fresh reference marker and a newly allocated id.
    fn new(sender: SendRequest) -> Self {
        Inner(sender, Arc::new(()), SENDER_ID.fetch_add(1, Relaxed))
    }

    /// Id assigned to this connection at creation time.
    fn id(&self) -> u32 {
        self.2
    }

    /// Creates a [`Sender`] on this connection targeting `base_url`.
    ///
    /// Each live sender holds a clone of the marker and therefore adds one to `ref_count`.
    fn new_sender(&self, base_url: String) -> Sender {
        Sender::new(self.0.clone(), base_url, Arc::clone(&self.1))
    }

    /// Current number of strong references to the marker (including this `Inner`'s own).
    fn ref_count(&self) -> usize {
        Arc::strong_count(&self.1)
    }

    /// Returns `true` when the connection has reached `max_streams`; `None` means unlimited.
    ///
    /// The `Inner` itself owns one reference, so the limit is hit once the count exceeds
    /// `max`. This is the same condition as `count >= max + 1`, but it cannot overflow
    /// when `max == usize::MAX`.
    ///
    /// NOTE(review): `Pool::run_conn` also passes a `reference()` clone to the connection
    /// driver task, so while that task holds it the baseline count is 2. That would allow
    /// only `max - 1` concurrent senders — confirm how `crate::utils::run_conn` holds it.
    fn limited(&self, max_streams: Option<usize>) -> bool {
        max_streams.is_some_and(|max| self.ref_count() > max)
    }

    /// Whether the underlying request sender reports the connection as closed.
    fn is_closed(&self) -> bool {
        self.0.is_closed()
    }

    /// Another strong handle to the reference marker (counts toward `ref_count`).
    fn reference(&self) -> Arc<()> {
        Arc::clone(&self.1)
    }
}

/// HTTP/2 connection pool.
///
/// * The total number of connections is bounded by `max_conn` in the shared base state;
///   `HttpPool::get` reserves a slot through `net_pool::pool::increase_current` before
///   opening a new connection.
/// * `Pool::new` sets the keep-alive to one minute.
/// * Each connection serves at most `max_streams` concurrent senders. `Pool::default()`
///   uses 200; `None` or `Some(0)` means unlimited.
///
/// A pooled connection leaves the pool when:
/// 1. the peer closes it (detected lazily by `get_sender`, or when its driver task
///    finishes and calls `remove_inner`);
/// 2. `remove_backend` is called for its address, which drops every cached connection
///    to that address from the pool;
/// 3. no sender has referenced it for the keep-alive duration. This does not apply when
///    no keep-alive is set. (Presumably enforced by `crate::utils::run_conn`, which is not
///    shown here.)
pub struct Pool {
    /// Identifier returned by `HttpPool::get_id`.
    id: String,
    /// Shared base state: load-balancing strategy, connection counter and limit.
    state: net_pool::pool::BaseState,
    /// Per-connection stream limit; `usize::MAX` encodes "unlimited".
    max_streams: AtomicUsize,
    /// Cached connections, keyed by the backend's hash code.
    free_conn_map: Mutex<HashMap<u64, Vec<Inner>>>,
    /// Whether new connections are opened over TLS.
    use_tls: AtomicBool,
}

impl Pool {
    /// Keep-alive installed by [`Pool::new`]: one minute.
    const DEFAULT_KEEPALIVE: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates a pool named `id` that picks backends with `strategy`.
    ///
    /// `max_streams` limits concurrent senders per connection. `None` or `Some(0)` means
    /// unlimited. The keep-alive is set to [`Self::DEFAULT_KEEPALIVE`] and TLS is off.
    pub fn new(id: String, strategy: Arc<dyn LbStrategy>, max_streams: Option<usize>) -> Self {
        let pool = Pool {
            id,
            state: net_pool::pool::BaseState::new(strategy),
            max_streams: AtomicUsize::new(usize::MAX),
            free_conn_map: Mutex::new(HashMap::new()),
            use_tls: AtomicBool::new(false),
        };

        // `set_max_streams` already maps `Some(0)` to "unlimited"; no need to pre-normalize.
        pool.set_max_streams(max_streams);
        <Pool as net_pool::pool::Pool>::set_keepalive(&pool, Some(Self::DEFAULT_KEEPALIVE));
        pool
    }

    /// Sets the per-connection stream limit. `None` or `Some(0)` means unlimited.
    ///
    /// Unlimited is stored as the `usize::MAX` sentinel, so `Some(usize::MAX)` is also
    /// reported as unlimited by [`Pool::get_max_streams`].
    pub fn set_max_streams(&self, max_streams: Option<usize>) {
        let stored = max_streams.filter(|&v| v != 0).unwrap_or(usize::MAX);
        self.max_streams.store(stored, Relaxed);
    }

    /// Returns the per-connection stream limit, or `None` when unlimited.
    pub fn get_max_streams(&self) -> Option<usize> {
        Some(self.max_streams.load(Relaxed)).filter(|&v| v != usize::MAX)
    }
}

impl Default for Pool {
    fn default() -> Self {
        Pool::new(
            "".to_string(),
            Arc::new(net_pool::strategy::CHStrategy::default()),
            Some(200),
        )
    }
}

impl net_pool::pool::Pool for Pool {
    net_pool::macros::base_pool_impl! {state}

    /// Removes `addr` from the load-balancing strategy.
    ///
    /// Only when the strategy actually contained it are the cached connections for that
    /// address dropped as well. Returns whether the backend was removed.
    fn remove_backend(&self, addr: &Address) -> bool {
        let removed = self.state.lb_strategy.remove_backend(addr);
        if removed {
            self.clear_bs_inner(addr);
        }
        removed
    }

    /// Enables or disables TLS for connections opened from now on.
    fn use_tls(&self, tls: bool) {
        self.use_tls.store(tls, Relaxed)
    }

    /// Whether new connections are opened over TLS.
    fn tls(&self) -> bool {
        self.use_tls.load(Relaxed)
    }
}

impl Pool {
    /// Returns a sender on a cached, still-open connection to `bs` that has spare stream
    /// capacity, or `None` if the caller has to open a new connection.
    ///
    /// Every closed connection for this backend is pruned in the same pass, and its slot
    /// in `cur_conn` is released. Among the remaining connections, the most recently added
    /// one with capacity is preferred.
    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
        let mut guard = self.free_conn_map.lock().unwrap();
        let inners = guard.get_mut(&bs.hash_code())?;

        // Prune all closed connections in one O(n) pass. The previous index/`remove` loop
        // was O(n^2) and stopped pruning as soon as it found a usable connection, leaving
        // stale entries counted in `cur_conn`.
        let before = inners.len();
        inners.retain(|inner| !inner.is_closed());
        let closed = before - inners.len();
        if closed > 0 {
            let prev = self.state.cur_conn.fetch_sub(closed, Relaxed);
            // The counter must have covered every connection being released.
            assert!(prev >= closed);
            trace!(
                "[http/2.0 pool] [desc] current connection count: {}",
                self.state.cur_conn.load(Relaxed)
            );
        }

        let max_streams = self.get_max_streams();
        let inner = inners
            .iter()
            .rev()
            .find(|inner| !inner.limited(max_streams))?;

        let tls = <Pool as net_pool::pool::Pool>::tls(self);
        Some(inner.new_sender(crate::utils::base_url(tls, bs.get_address())))
    }

    /// Appends `inner` to the cached connections of the backend identified by `hash_code`.
    ///
    /// Uses the entry API so the map is hashed and probed only once.
    fn add_inner(&self, hash_code: u64, inner: Inner) {
        let mut guard = self.free_conn_map.lock().unwrap();
        guard.entry(hash_code).or_default().push(inner);
    }

    /// Drops every cached connection for `addr` and releases their slots in `cur_conn`.
    ///
    /// The map entry may hold an empty `Vec`, because `get_sender` and `remove_inner`
    /// prune connections without removing the key. The old `fetch_sub(0) > 0` assertion
    /// therefore panicked whenever the counter was already 0. The counter is now only
    /// touched when there is something to release, and the check guards against underflow.
    fn clear_bs_inner(&self, addr: &Address) {
        let mut guard = self.free_conn_map.lock().unwrap();
        let Some(inners) = guard.remove(&addr.hash_code()) else {
            return;
        };

        let released = inners.len();
        if released == 0 {
            return;
        }

        let prev = self.state.cur_conn.fetch_sub(released, Relaxed);
        // The counter must have covered every connection being released.
        assert!(prev >= released);
        trace!(
            "[http/2.0 pool] [desc] current connection count: {}",
            self.state.cur_conn.load(Relaxed)
        );
    }

    /// Removes the connection with `id` from the backend identified by `hash_code`.
    ///
    /// If a connection was removed, its slot in `cur_conn` is released. Nothing happens
    /// when the backend or the id is not present (e.g. it was already pruned).
    fn remove_inner(&self, hash_code: u64, id: u32) {
        let mut guard = self.free_conn_map.lock().unwrap();
        let Some(inners) = guard.get_mut(&hash_code) else {
            return;
        };

        let before = inners.len();
        inners.retain(|inner| inner.id() != id);
        let removed = before - inners.len();
        if removed == 0 {
            return;
        }

        let prev = self.state.cur_conn.fetch_sub(removed, Relaxed);
        assert!(prev >= removed);
        trace!(
            "[http/2.0 pool] [desc] current connection count: {}",
            self.state.cur_conn.load(Relaxed)
        );
    }

    /// Registers a freshly handshaken connection with the pool, spawns the task that
    /// drives it, and returns the first [`Sender`] for it.
    ///
    /// When the driver future completes, the connection is removed from the pool via
    /// `remove_inner`, which releases its slot in `cur_conn`.
    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
        pool: Arc<Self>,
        c: C,
        sr: SendRequest,
        bs: &BackendState,
    ) -> Sender {
        let hash_code = bs.hash_code();
        let keepalive = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
        let use_tls = <Pool as net_pool::pool::Pool>::tls(&pool);
        let base_url = crate::utils::base_url(use_tls, bs.get_address());

        let inner = Inner::new(sr);
        let conn_id = inner.id();
        let conn_ref = inner.reference();
        let sender = inner.new_sender(base_url);

        // Make the connection available to other callers.
        pool.add_inner(hash_code, inner);

        // Drive the connection in the background; deregister it once it ends.
        tokio_spawn! {
            instrument_current_span! {
                async move {
                    let _result = crate::utils::run_conn(conn_ref, c, keepalive).await;
                    debug!("[http/2.0 pool] connection closed: {:?}", _result);
                    pool.remove_inner(hash_code, conn_id);
                }
            }
        };

        sender
    }

    /// Opens a TLS connection to `bs`, performs the HTTP/2 handshake, and registers the
    /// connection with the pool.
    ///
    /// # Errors
    /// Propagates failures from the TCP connect and the TLS handshake. HTTP/2 handshake
    /// failures are wrapped with `Error::from_other`.
    async fn create_tls_sender(
        self: Arc<Self>,
        bs: &BackendState,
        exec: hyper_util::rt::TokioExecutor,
    ) -> Result<Sender, Error> {
        let addr = bs.get_address();

        // Establish the transport.
        let tcp = crate::utils::create_https_stream(addr).await?;
        let tls_tcp =
            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
                .await?;

        // Saturate rather than truncate: `m as u32` would wrap values above u32::MAX
        // (e.g. 2^32 would become 0).
        let max_streams = self
            .get_max_streams()
            .map(|m| u32::try_from(m).unwrap_or(u32::MAX));

        // Handshake to obtain the request sender and the connection driver.
        let io = hyper_util::rt::TokioIo::new(tls_tcp);
        let (send_request, connection) = http2::Builder::new(exec)
            .max_concurrent_streams(max_streams)
            .handshake(io)
            .await
            .map_err(|e| Error::from_other(e))?;

        // Start driving the connection and hand out the first sender.
        Ok(Pool::run_conn(self, connection, send_request, bs))
    }

    /// Opens a plain-TCP connection to `bs`, performs the HTTP/2 handshake with hyper's
    /// default settings, and registers the connection with the pool.
    ///
    /// # Errors
    /// Propagates TCP connect failures. Handshake failures are wrapped with
    /// `Error::from_other`.
    async fn create_non_tls_sender(
        self: Arc<Self>,
        bs: &BackendState,
        exec: hyper_util::rt::TokioExecutor,
    ) -> Result<Sender, Error> {
        let stream = crate::utils::create_http_stream(bs.get_address()).await?;
        let io = hyper_util::rt::TokioIo::new(stream);

        let (send_request, connection) = http2::handshake(exec, io)
            .await
            .map_err(|e| Error::from_other(e))?;

        // Spawn the connection driver and return a sender bound to it.
        Ok(Pool::run_conn(self, connection, send_request, bs))
    }

    /// Opens a new connection to `bs`, over TLS when the pool is configured for it, and
    /// returns the first sender on it.
    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
        let use_tls = <Pool as net_pool::pool::Pool>::tls(&self);
        let exec = hyper_util::rt::tokio::TokioExecutor::new();
        match use_tls {
            true => self.create_tls_sender(bs, exec).await,
            false => self.create_non_tls_sender(bs, exec).await,
        }
    }
}

/// A pool that hands out HTTP/2 [`Sender`]s selected by a load-balancing key.
pub trait HttpPool {
    /// Returns a sender for the backend selected by `key`.
    ///
    /// Takes `self` as `Arc<Self>`, and the returned future is `Send`.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;

    /// Identifier of this pool.
    fn get_id(&self) -> &String;
}

impl HttpPool for Pool {
    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
        let bs = self
            .state
            .lb_strategy
            .get_backend(key)
            .ok_or(Error::NoBackend)?;

        let sender = match self.get_sender(&bs) {
            Some(s) => Ok(s),
            None => {
                // 预分配数量
                net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
                self.clone().create_sender(&bs).await.map(|s| {
                    trace!(
                        "[http/2.0 pool] [incr] current connection count: {}",
                        self.state.cur_conn.load(Relaxed)
                    );
                    s
                })
            }
        };

        if sender.is_err() {
            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
        }

        sender
    }

    fn get_id(&self) -> &String {
        &self.id
    }
}

/// Convenience wrapper: fetches a sender for `key` from any [`HttpPool`].
pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
    <P as HttpPool>::get(pool, key).await
}