#[cfg(test)]
mod endpoint_test;
use std::{
    collections::HashMap,
    fmt, iter,
    net::{IpAddr, SocketAddr},
    ops::{Index, IndexMut},
    sync::Arc,
    time::Instant,
};
use crate::association::Association;
use crate::chunk::chunk_type::CT_INIT;
use crate::config::{ClientConfig, EndpointConfig, ServerConfig, TransportConfig};
use crate::packet::PartialDecode;
use crate::shared::{
    AssociationEvent, AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner,
};
use crate::util::{AssociationIdGenerator, RandomAssociationIdGenerator};
use crate::{Payload, Transmit};
use shared::EcnCodepoint;
use bytes::Bytes;
use fxhash::FxHashMap;
use log::{debug, trace};
use rand::{rngs::StdRng, SeedableRng};
use slab::Slab;
use thiserror::Error;
/// Routing state shared by all associations of one endpoint.
///
/// The endpoint decodes incoming datagrams far enough to decide whether they
/// belong to an association it already tracks or whether they are the first
/// packet of a new one (see `handle`). It also creates outgoing associations
/// (see `connect`).
pub struct Endpoint {
    // Seeded from OS entropy in `new`; not read by any code in this section.
    rng: StdRng,
    // Only ever removed from (on `Drained`) in the code shown here; nothing
    // visible inserts into it.
    association_ids_init: HashMap<AssociationId, AssociationHandle>,
    // Local association id -> handle; this is the table `handle` routes by.
    association_ids: FxHashMap<AssociationId, AssociationHandle>,
    // Per-association bookkeeping; the slab key is the `AssociationHandle` index.
    associations: Slab<AssociationMeta>,
    // Source of candidate local association ids (see `new_aid`).
    local_cid_generator: Box<dyn AssociationIdGenerator>,
    endpoint_config: Arc<EndpointConfig>,
    // Configuration handed to associations created from incoming first packets.
    server_config: Option<Arc<ServerConfig>>,
    // Once set, every incoming first packet is refused.
    reject_new_associations: bool,
}
/// Manual `Debug` implementation for [`Endpoint`].
///
/// The `local_cid_generator` field is not printed. Every label matches the
/// name of the struct field it shows, and the type is reported as `Endpoint`
/// (it has no generic parameter).
impl fmt::Debug for Endpoint {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Endpoint")
            .field("rng", &self.rng)
            .field("association_ids_init", &self.association_ids_init)
            .field("association_ids", &self.association_ids)
            .field("associations", &self.associations)
            .field("endpoint_config", &self.endpoint_config)
            .field("server_config", &self.server_config)
            .field("reject_new_associations", &self.reject_new_associations)
            .finish()
    }
}
impl Endpoint {
    /// Creates an endpoint with empty routing tables.
    ///
    /// * `endpoint_config` - shared endpoint settings; its
    ///   `aid_generator_factory` is invoked once to build the local id generator.
    /// * `server_config` - configuration used for associations created from
    ///   incoming first packets, if any.
    pub fn new(
        endpoint_config: Arc<EndpointConfig>,
        server_config: Option<Arc<ServerConfig>>,
    ) -> Self {
        let rng = StdRng::from_entropy();
        let make_generator = endpoint_config.aid_generator_factory.as_ref();
        let local_cid_generator = make_generator();

        Self {
            rng,
            association_ids_init: HashMap::default(),
            association_ids: FxHashMap::default(),
            associations: Slab::new(),
            local_cid_generator,
            endpoint_config,
            server_config,
            reject_new_associations: false,
        }
    }
    /// Replaces the server configuration; the previous value is dropped.
    pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
        drop(std::mem::replace(&mut self.server_config, server_config));
    }
    pub fn handle_event(&mut self, ch: AssociationHandle, event: EndpointEvent) {
        match event.0 {
            EndpointEventInner::Drained => {
                let conn = self.associations.remove(ch.0);
                self.association_ids_init.remove(&conn.init_cid);
                for cid in conn.loc_cids.values() {
                    self.association_ids.remove(cid);
                }
            }
        }
    }
    /// Routes one incoming datagram.
    ///
    /// The header is partially decoded; a decode failure is traced and yields
    /// `None`. The routing key is the verification tag when it is non-zero,
    /// otherwise the initiate tag of a leading INIT chunk (if present). When
    /// the key maps to a tracked association, the datagram is wrapped in an
    /// `AssociationEvent` for that association. Otherwise the packet is handed
    /// to `handle_first_packet`, which may create a new association.
    pub fn handle(
        &mut self,
        now: Instant,
        remote: SocketAddr,
        local_ip: Option<IpAddr>,
        ecn: Option<EcnCodepoint>,
        data: Bytes,
    ) -> Option<(AssociationHandle, DatagramEvent)> {
        let partial_decode = PartialDecode::unmarshal(&data)
            .map_err(|err| trace!("malformed header: {}", err))
            .ok()?;

        let verification_tag = partial_decode.common_header.verification_tag;
        let lookup_key = if verification_tag > 0 {
            Some(verification_tag)
        } else if partial_decode.first_chunk_type == CT_INIT {
            partial_decode.initiate_tag
        } else {
            None
        };
        let known_ch = lookup_key.and_then(|aid| self.association_ids.get(&aid).copied());

        match known_ch {
            Some(ch) => {
                let datagram = Transmit {
                    now,
                    remote,
                    ecn,
                    payload: Payload::PartialDecode(partial_decode),
                    local_ip,
                };
                let event = AssociationEvent(AssociationEventInner::Datagram(datagram));
                Some((ch, DatagramEvent::AssociationEvent(event)))
            }
            None => self
                .handle_first_packet(now, remote, local_ip, ecn, partial_decode)
                .map(|(ch, association)| (ch, DatagramEvent::NewAssociation(association))),
        }
    }
    /// Starts an outgoing association to `remote`.
    ///
    /// The association is registered under a fresh local id and created with
    /// `config.transport`, no server configuration, no local IP and the
    /// current time.
    ///
    /// # Errors
    ///
    /// * `ConnectError::TooManyAssociations` when `is_full` reports true.
    /// * `ConnectError::InvalidRemoteAddress` when the port of `remote` is 0.
    pub fn connect(
        &mut self,
        config: ClientConfig,
        remote: SocketAddr,
    ) -> Result<(AssociationHandle, Association), ConnectError> {
        match (self.is_full(), remote.port()) {
            (true, _) => return Err(ConnectError::TooManyAssociations),
            (false, 0) => return Err(ConnectError::InvalidRemoteAddress(remote)),
            _ => {}
        }

        let remote_aid = RandomAssociationIdGenerator::new().generate_aid();
        let local_aid = self.new_aid();
        let now = Instant::now();

        Ok(self.add_association(
            remote_aid,
            local_aid,
            remote,
            None,
            now,
            None,
            config.transport,
        ))
    }
    /// Returns a local association id that is not currently a key of
    /// `association_ids`.
    ///
    /// Candidates are drawn from the configured generator until an unused one
    /// turns up; the loop does not terminate if the generator only ever yields
    /// ids that are already registered.
    fn new_aid(&mut self) -> AssociationId {
        let mut candidate = self.local_cid_generator.generate_aid();
        while self.association_ids.contains_key(&candidate) {
            candidate = self.local_cid_generator.generate_aid();
        }
        candidate
    }
    fn handle_first_packet(
        &mut self,
        now: Instant,
        remote: SocketAddr,
        local_ip: Option<IpAddr>,
        ecn: Option<EcnCodepoint>,
        partial_decode: PartialDecode,
    ) -> Option<(AssociationHandle, Association)> {
        if partial_decode.first_chunk_type != CT_INIT
            || (partial_decode.first_chunk_type == CT_INIT && partial_decode.initiate_tag.is_none())
        {
            debug!("refusing first packet with Non-INIT or emtpy initial_tag INIT");
            return None;
        }
        let server_config = self.server_config.as_ref().unwrap();
        if self.associations.len() >= server_config.concurrent_associations as usize
            || self.reject_new_associations
            || self.is_full()
        {
            debug!("refusing association");
            return None;
        }
        let server_config = server_config.clone();
        let transport_config = server_config.transport.clone();
        let remote_aid = *partial_decode.initiate_tag.as_ref().unwrap();
        let local_aid = self.new_aid();
        let (ch, mut conn) = self.add_association(
            remote_aid,
            local_aid,
            remote,
            local_ip,
            now,
            Some(server_config),
            transport_config,
        );
        conn.handle_event(AssociationEvent(AssociationEventInner::Datagram(
            Transmit {
                now,
                remote,
                ecn,
                payload: Payload::PartialDecode(partial_decode),
                local_ip,
            },
        )));
        Some((ch, conn))
    }
    #[allow(clippy::too_many_arguments)]
    fn add_association(
        &mut self,
        remote_aid: AssociationId,
        local_aid: AssociationId,
        remote_addr: SocketAddr,
        local_ip: Option<IpAddr>,
        now: Instant,
        server_config: Option<Arc<ServerConfig>>,
        transport_config: Arc<TransportConfig>,
    ) -> (AssociationHandle, Association) {
        let conn = Association::new(
            server_config,
            transport_config,
            self.endpoint_config.get_max_payload_size(),
            local_aid,
            remote_addr,
            local_ip,
            now,
        );
        let id = self.associations.insert(AssociationMeta {
            init_cid: remote_aid,
            cids_issued: 0,
            loc_cids: iter::once((0, local_aid)).collect(),
            initial_remote: remote_addr,
        });
        let ch = AssociationHandle(id);
        self.association_ids.insert(local_aid, ch);
        (ch, conn)
    }
    /// Makes the endpoint refuse every subsequent incoming first packet.
    ///
    /// Nothing in this section clears the flag again.
    pub fn reject_new_associations(&mut self) {
        let flag = &mut self.reject_new_associations;
        *flag = true;
    }
    /// Borrows the endpoint configuration this endpoint was created with.
    pub fn endpoint_config(&self) -> &EndpointConfig {
        &*self.endpoint_config
    }
    /// Reports whether the number of registered association ids exceeds
    /// roughly three quarters of the `u32` range
    /// (`(u32::MAX >> 1) + (u32::MAX >> 2)`).
    fn is_full(&self) -> bool {
        let threshold = ((u32::MAX >> 1) + (u32::MAX >> 2)) as usize;
        self.association_ids.len() > threshold
    }
}
/// Bookkeeping the endpoint keeps for one association.
#[derive(Debug)]
pub(crate) struct AssociationMeta {
    // Set from the `remote_aid` argument of `add_association`; used as the key
    // removed from `association_ids_init` when the association is drained.
    init_cid: AssociationId,
    // Initialised to 0 and not modified by any code in this section.
    cids_issued: u64,
    // Sequence number -> local association id; seeded with `(0, local_aid)`.
    // Every value is removed from `association_ids` on `Drained`.
    loc_cids: FxHashMap<u64, AssociationId>,
    // Remote address the association was created with.
    initial_remote: SocketAddr,
}
/// Identifier of an association within one endpoint.
///
/// The wrapped value is the key of the association's entry in the endpoint's
/// slab of per-association metadata.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct AssociationHandle(pub usize);

/// Unwraps the handle into its raw slab index.
impl From<AssociationHandle> for usize {
    fn from(x: AssociationHandle) -> usize {
        let AssociationHandle(index) = x;
        index
    }
}
impl Index<AssociationHandle> for Slab<AssociationMeta> {
    type Output = AssociationMeta;
    fn index(&self, ch: AssociationHandle) -> &AssociationMeta {
        &self[ch.0]
    }
}
impl IndexMut<AssociationHandle> for Slab<AssociationMeta> {
    fn index_mut(&mut self, ch: AssociationHandle) -> &mut AssociationMeta {
        &mut self[ch.0]
    }
}
/// Outcome of `Endpoint::handle` for a datagram that was not dropped.
// NOTE(review): the lint is presumably silenced because `Association` is much
// larger than `AssociationEvent` — confirm before boxing either variant.
#[allow(clippy::large_enum_variant)] pub enum DatagramEvent {
    /// The datagram matched a tracked association; the event carries the
    /// datagram for that association.
    AssociationEvent(AssociationEvent),
    /// The datagram was accepted as the first packet of a new association,
    /// which is returned here.
    NewAssociation(Association),
}
/// Errors reported when starting an outgoing association.
///
/// Within this section only `TooManyAssociations` and `InvalidRemoteAddress`
/// are produced (by `Endpoint::connect`); the remaining variants are not
/// constructed by any code shown here.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
    /// Not produced in this section; presumably reported by a higher layer
    /// while the endpoint is shutting down — confirm against callers.
    #[error("endpoint stopping")]
    EndpointStopping,
    /// The endpoint already tracks too many association ids (`is_full`).
    #[error("too many associations")]
    TooManyAssociations,
    /// Not produced in this section; carries the offending name.
    #[error("invalid DNS name: {0}")]
    InvalidDnsName(String),
    /// The remote address cannot be connected to; `connect` returns this when
    /// the port is 0.
    #[error("invalid remote address: {0}")]
    InvalidRemoteAddress(SocketAddr),
    /// Not produced in this section.
    #[error("no default client config")]
    NoDefaultClientConfig,
}