mod web_context;
use bytes::BytesMut;
use futures::task::AtomicWaker;
use futures::{future::Ready, io, prelude::*};
use js_sys::Array;
use libp2p_core::transport::DialOpts;
use libp2p_core::{
    multiaddr::{Multiaddr, Protocol},
    transport::{ListenerId, TransportError, TransportEvent},
};
use send_wrapper::SendWrapper;
use std::cmp::min;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::{pin::Pin, task::Context, task::Poll};
use wasm_bindgen::prelude::*;
use web_sys::{CloseEvent, Event, MessageEvent, WebSocket};
use crate::web_context::WebContext;
/// A libp2p transport that dials WebSocket multiaddresses through the browser's
/// `WebSocket` API.
///
/// Only outbound connections are supported; `listen_on` always rejects the address.
#[derive(Default)]
pub struct Transport {
    // Prevents construction outside this crate other than via `Default`.
    _private: (),
}
/// Per-connection cap, in bytes (1 MiB), on both the inbound read buffer and the
/// browser's outbound send buffer (`bufferedAmount`).
const MAX_BUFFER: usize = 1 << 20;
impl libp2p_core::Transport for Transport {
    type Output = Connection;
    type Error = Error;
    type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
    type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
    fn listen_on(
        &mut self,
        _: ListenerId,
        addr: Multiaddr,
    ) -> Result<(), TransportError<Self::Error>> {
        Err(TransportError::MultiaddrNotSupported(addr))
    }
    fn remove_listener(&mut self, _id: ListenerId) -> bool {
        false
    }
    fn dial(
        &mut self,
        addr: Multiaddr,
        dial_opts: DialOpts,
    ) -> Result<Self::Dial, TransportError<Self::Error>> {
        if dial_opts.role.is_listener() {
            return Err(TransportError::MultiaddrNotSupported(addr));
        }
        let url = extract_websocket_url(&addr)
            .ok_or_else(|| TransportError::MultiaddrNotSupported(addr))?;
        Ok(async move {
            let socket = match WebSocket::new(&url) {
                Ok(ws) => ws,
                Err(_) => return Err(Error::invalid_websocket_url(&url)),
            };
            Ok(Connection::new(socket))
        }
        .boxed())
    }
    fn poll(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> std::task::Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
        Poll::Pending
    }
}
/// Converts a multiaddress of the form `<host>/tcp/<port>/{ws | tls/ws | wss}[/...]` into a
/// WebSocket URL.
///
/// `<host>` may be `ip4`, `ip6` (rendered in brackets), `dns`, `dns4`, `dns6` or `dnsaddr`.
/// `tls/ws` and `wss` both produce the `wss` scheme, while plain `ws` produces `ws`. Any
/// components after the WebSocket protocol (e.g. `/p2p/<peer-id>`) are ignored. Returns
/// `None` for every other shape.
fn extract_websocket_url(addr: &Multiaddr) -> Option<String> {
    let mut components = addr.iter();

    let authority = match (components.next()?, components.next()?) {
        (Protocol::Ip4(ip), Protocol::Tcp(port)) => format!("{ip}:{port}"),
        (Protocol::Ip6(ip), Protocol::Tcp(port)) => format!("[{ip}]:{port}"),
        (
            Protocol::Dns(host)
            | Protocol::Dns4(host)
            | Protocol::Dns6(host)
            | Protocol::Dnsaddr(host),
            Protocol::Tcp(port),
        ) => format!("{host}:{port}"),
        _ => return None,
    };

    // The second component is only needed for the `tls/ws` form, so it may be absent.
    let (scheme, path) = match (components.next()?, components.next()) {
        (Protocol::Tls, Some(Protocol::Ws(path))) => ("wss", path),
        (Protocol::Ws(path), _) => ("ws", path),
        (Protocol::Wss(path), _) => ("wss", path),
        _ => return None,
    };

    Some(format!("{scheme}://{authority}{path}"))
}
/// Error produced by the browser WebSocket transport.
#[derive(thiserror::Error, Debug)]
#[error("{msg}")]
pub struct Error {
    /// Human-readable description, used verbatim as the `Display` output.
    msg: String,
}

impl Error {
    /// Builds the error returned when the browser refuses to construct a `WebSocket` for `url`.
    fn invalid_websocket_url(url: &str) -> Self {
        let msg = format!("Invalid websocket url: {url}");
        Self { msg }
    }
}
/// A single WebSocket connection opened through the browser's `WebSocket` API.
///
/// The browser objects and the `Rc`-based shared state are not `Send`. They are wrapped
/// in a [`SendWrapper`], which only allows access from the thread that created them.
pub struct Connection {
    inner: SendWrapper<Inner>,
}
/// Browser-side state of a [`Connection`], shared with the JS event handlers via `Rc`.
///
/// NOTE: fields are dropped in declaration order. `Connection`'s `Drop` impl detaches the
/// socket's event handlers and clears the interval before any of these fields are dropped.
struct Inner {
    /// The underlying browser WebSocket.
    socket: WebSocket,
    /// Registered by `poll_read` while no data is buffered; woken by the `onmessage`
    /// handler after bytes are appended to `read_buffer`.
    new_data_waker: Rc<AtomicWaker>,
    /// Bytes received from the remote that `poll_read` has not handed out yet
    /// (bounded by `MAX_BUFFER`).
    read_buffer: Rc<Mutex<BytesMut>>,
    /// Registered while the socket is still connecting; woken by the `onopen` handler.
    open_waker: Rc<AtomicWaker>,
    /// Registered by `poll_write`/`poll_flush`; woken by the interval callback once the
    /// socket's `bufferedAmount` is zero.
    write_waker: Rc<AtomicWaker>,
    /// Registered by `poll_close`; woken by the `onclose` handler.
    close_waker: Rc<AtomicWaker>,
    /// Set by the `onerror` handler, or by `onmessage` when the read buffer would exceed
    /// `MAX_BUFFER`; checked by `error_barrier`.
    errored: Rc<AtomicBool>,
    // The JS callback closures below are owned here so they stay alive for as long as the
    // socket (or the interval) may invoke them.
    _on_open_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_buffered_amount_low_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_close_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_error_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_message_closure: Rc<Closure<dyn FnMut(MessageEvent)>>,
    /// Handle of the interval that polls `bufferedAmount`; cleared in `Connection::drop`.
    buffered_amount_low_interval: i32,
}
impl Inner {
    fn ready_state(&self) -> ReadyState {
        match self.socket.ready_state() {
            0 => ReadyState::Connecting,
            1 => ReadyState::Open,
            2 => ReadyState::Closing,
            3 => ReadyState::Closed,
            unknown => unreachable!("invalid `ReadyState` value: {unknown}"),
        }
    }
    fn poll_open(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
        match self.ready_state() {
            ReadyState::Connecting => {
                self.open_waker.register(cx.waker());
                Poll::Pending
            }
            ReadyState::Open => Poll::Ready(Ok(())),
            ReadyState::Closed | ReadyState::Closing => {
                Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
            }
        }
    }
    fn error_barrier(&self) -> io::Result<()> {
        if self.errored.load(Ordering::SeqCst) {
            return Err(io::ErrorKind::BrokenPipe.into());
        }
        Ok(())
    }
}
/// Typed mirror of the `WebSocket.readyState` values `0` through `3`.
#[derive(PartialEq)]
enum ReadyState {
    /// `0`: the opening handshake has not completed yet.
    Connecting,
    /// `1`: the connection is established and data can be exchanged.
    Open,
    /// `2`: the closing handshake is in progress.
    Closing,
    /// `3`: the connection is closed or could not be opened.
    Closed,
}
impl Connection {
    /// Wraps a freshly constructed browser [`WebSocket`] into a [`Connection`].
    ///
    /// This does the following:
    /// - switches the socket to `ArrayBuffer` binary messages;
    /// - installs the `onopen`, `onclose`, `onerror` and `onmessage` handlers;
    /// - starts a 100ms interval that wakes a parked writer once the socket's
    ///   `bufferedAmount` has dropped to zero.
    ///
    /// # Panics
    ///
    /// Panics if neither a window nor a worker context is available, or if the interval
    /// cannot be registered.
    fn new(socket: WebSocket) -> Self {
        socket.set_binary_type(web_sys::BinaryType::Arraybuffer);

        let open_waker = Rc::new(AtomicWaker::new());
        let new_data_waker = Rc::new(AtomicWaker::new());
        let write_waker = Rc::new(AtomicWaker::new());
        let close_waker = Rc::new(AtomicWaker::new());
        let errored = Rc::new(AtomicBool::new(false));
        let read_buffer = Rc::new(Mutex::new(BytesMut::new()));

        // Terminal events (close, error, inbound overload) wake *every* waker. A task may be
        // parked in `poll_open`, `poll_read`, `poll_write`/`poll_flush` or `poll_close`.
        // Waking only one of them would leave the others pending forever, because no further
        // event will ever arrive on a dead socket.
        let all_wakers = [
            open_waker.clone(),
            new_data_waker.clone(),
            write_waker.clone(),
            close_waker.clone(),
        ];

        let onopen_closure = Closure::<dyn FnMut(_)>::new({
            let open_waker = open_waker.clone();
            move |_| {
                open_waker.wake();
            }
        });
        socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));

        let onclose_closure = Closure::<dyn FnMut(_)>::new({
            let wakers = all_wakers.clone();
            move |_| {
                wakers.iter().for_each(|waker| waker.wake());
            }
        });
        socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));

        let onerror_closure = Closure::<dyn FnMut(_)>::new({
            let errored = errored.clone();
            let wakers = all_wakers.clone();
            move |_| {
                errored.store(true, Ordering::SeqCst);
                wakers.iter().for_each(|waker| waker.wake());
            }
        });
        socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));

        let onmessage_closure = Closure::<dyn FnMut(_)>::new({
            let read_buffer = read_buffer.clone();
            let new_data_waker = new_data_waker.clone();
            let errored = errored.clone();
            let wakers = all_wakers.clone();
            move |e: MessageEvent| {
                let data = js_sys::Uint8Array::new(&e.data());
                let mut read_buffer = read_buffer.lock().unwrap();
                if read_buffer.len() + data.length() as usize > MAX_BUFFER {
                    tracing::warn!("Remote is overloading us with messages, closing connection");
                    errored.store(true, Ordering::SeqCst);
                    // Surface the error to parked tasks. A reader waiting on an empty buffer
                    // would otherwise never be polled again.
                    wakers.iter().for_each(|waker| waker.wake());
                    return;
                }
                read_buffer.extend_from_slice(&data.to_vec());
                new_data_waker.wake();
            }
        });
        socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));

        // There is no `bufferedamountlow` event, so poll `bufferedAmount` on an interval and
        // wake a parked writer once everything queued has been handed to the network.
        let on_buffered_amount_low_closure = Closure::<dyn FnMut(_)>::new({
            let write_waker = write_waker.clone();
            let socket = socket.clone();
            move |_| {
                if socket.buffered_amount() == 0 {
                    write_waker.wake();
                }
            }
        });
        let buffered_amount_low_interval = WebContext::new()
            .expect("to have a window or worker context")
            .set_interval_with_callback_and_timeout_and_arguments(
                on_buffered_amount_low_closure.as_ref().unchecked_ref(),
                // Chosen arbitrarily; likely worth tuning.
                100,
                &Array::new(),
            )
            .expect("to be able to set an interval");

        Self {
            inner: SendWrapper::new(Inner {
                socket,
                new_data_waker,
                read_buffer,
                open_waker,
                write_waker,
                close_waker,
                errored,
                _on_open_closure: Rc::new(onopen_closure),
                _on_buffered_amount_low_closure: Rc::new(on_buffered_amount_low_closure),
                _on_close_closure: Rc::new(onclose_closure),
                _on_error_closure: Rc::new(onerror_closure),
                _on_message_closure: Rc::new(onmessage_closure),
                buffered_amount_low_interval,
            }),
        }
    }

    /// Number of bytes queued with `send()` that the browser has not transmitted yet.
    fn buffered_amount(&self) -> usize {
        self.inner.socket.buffered_amount() as usize
    }
}
impl AsyncRead for Connection {
    /// Copies up to `buf.len()` bytes of already-received data into `buf`.
    ///
    /// A recorded socket error (see `Inner::error_barrier`) fails the read with `BrokenPipe`
    /// immediately. Otherwise, buffered bytes are returned even if the socket has since
    /// started closing. Only when the read buffer is empty does the socket state decide the
    /// outcome:
    /// - `Pending` while connecting or open;
    /// - `BrokenPipe` once closing or closed.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        this.inner.error_barrier()?;

        // Drain buffered data before consulting the socket state. Checking the state first
        // would discard bytes the remote sent right before closing the connection.
        {
            let mut read_buffer = this.inner.read_buffer.lock().unwrap();
            if !read_buffer.is_empty() {
                let len = min(buf.len(), read_buffer.len());
                let chunk = read_buffer.split_to(len);
                buf[..len].copy_from_slice(&chunk);
                return Poll::Ready(Ok(len));
            }
        }

        futures::ready!(this.inner.poll_open(cx))?;
        // Woken by the `onmessage` handler once new bytes have been appended.
        this.inner.new_data_waker.register(cx.waker());
        Poll::Pending
    }
}
impl AsyncWrite for Connection {
    /// Queues as much of `buf` on the socket as fits under [`MAX_BUFFER`].
    ///
    /// Returns `Pending` while the socket is still connecting or while the browser's send
    /// buffer is full. Fails with `BrokenPipe` if the socket errored, is closing or closed,
    /// or rejects the send.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.inner.error_barrier()?;
        futures::ready!(this.inner.poll_open(cx))?;

        let buffered_amount = this.buffered_amount();
        debug_assert!(buffered_amount <= MAX_BUFFER);
        // `bufferedAmount` is reported by the browser. Saturate instead of letting a release
        // build wrap around to a huge value, which would disable backpressure entirely.
        let remaining_space = MAX_BUFFER.saturating_sub(buffered_amount);
        if remaining_space == 0 {
            // Woken once the send buffer has fully drained (see the interval set up in
            // `Connection::new`).
            this.inner.write_waker.register(cx.waker());
            return Poll::Pending;
        }

        let bytes_to_send = min(buf.len(), remaining_space);
        if this
            .inner
            .socket
            .send_with_u8_array(&buf[..bytes_to_send])
            .is_err()
        {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        Poll::Ready(Ok(bytes_to_send))
    }

    /// Resolves once the browser has handed every queued byte to the network.
    ///
    /// Fails with `BrokenPipe` if the connection errored, or if the socket is closed while
    /// data is still queued.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.buffered_amount() == 0 {
            return Poll::Ready(Ok(()));
        }
        self.inner.error_barrier()?;
        // After the socket has closed, queued bytes are never transmitted and `bufferedAmount`
        // never drops back to zero, so waiting for it would hang forever.
        if self.inner.ready_state() == ReadyState::Closed {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        self.inner.write_waker.register(cx.waker());
        Poll::Pending
    }

    /// Starts a regular close handshake (if not already closing) and resolves once the
    /// socket reports `Closed`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Normal closure; see https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1.
        const REGULAR_CLOSE: u16 = 1000;

        if self.inner.ready_state() == ReadyState::Closed {
            return Poll::Ready(Ok(()));
        }
        self.inner.error_barrier()?;
        if self.inner.ready_state() != ReadyState::Closing {
            // The result is ignored. Completion is observed through the `onclose` handler
            // waking `close_waker`.
            let _ = self
                .inner
                .socket
                .close_with_code_and_reason(REGULAR_CLOSE, "user initiated");
        }
        self.inner.close_waker.register(cx.waker());
        Poll::Pending
    }
}
impl Drop for Connection {
    /// Detaches the JS event handlers, closes a still-live socket and stops the
    /// `bufferedAmount` polling interval.
    fn drop(&mut self) {
        // Normal closure; see https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1.
        const REGULAR_CLOSE: u16 = 1000;

        let socket = &self.inner.socket;

        // Unhook the handlers first. The closures they point to are owned by `Inner` and
        // are freed right after this method returns.
        socket.set_onclose(None);
        socket.set_onerror(None);
        socket.set_onopen(None);
        socket.set_onmessage(None);

        let still_live = matches!(
            self.inner.ready_state(),
            ReadyState::Connecting | ReadyState::Open
        );
        if still_live {
            let _ = socket.close_with_code_and_reason(REGULAR_CLOSE, "connection dropped");
        }

        let interval_handle = self.inner.buffered_amount_low_interval;
        WebContext::new()
            .expect("to have a window or worker context")
            .clear_interval_with_handle(interval_handle);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use libp2p_identity::PeerId;

    #[test]
    fn extract_url() {
        let peer_id = PeerId::random();

        // (multiaddress, expected WebSocket URL)
        let convertible = [
            ("/dns4/example.com/tcp/2222/tls/ws".to_owned(), "wss://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/tls/ws/p2p/{peer_id}"),
                "wss://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/tls/ws".to_owned(), "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/tls/ws".to_owned(), "wss://[::1]:2222/"),
            ("/dns4/example.com/tcp/2222/wss".to_owned(), "wss://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}"),
                "wss://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/wss".to_owned(), "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/wss".to_owned(), "wss://[::1]:2222/"),
            ("/dns4/example.com/tcp/2222/ws".to_owned(), "ws://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}"),
                "ws://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/ws".to_owned(), "ws://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/ws".to_owned(), "ws://[::1]:2222/"),
            ("/ip4/127.0.0.1/tcp/2222/ws".to_owned(), "ws://127.0.0.1:2222/"),
        ];
        for (input, expected) in &convertible {
            let addr: Multiaddr = input.parse().unwrap();
            assert_eq!(
                extract_websocket_url(&addr).as_deref(),
                Some(*expected),
                "{input}"
            );
        }

        let rejected = ["/ip4/127.0.0.1/tcp/2222/tls/wss", "/ip4/127.0.0.1/tcp/2222"];
        for input in rejected {
            let addr: Multiaddr = input.parse().unwrap();
            assert!(extract_websocket_url(&addr).is_none(), "{input}");
        }
    }
}