1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use async_ws::{
    connection::WsConfig,
    http::{check_upgrade_response, is_upgrade_request, upgrade_request},
};
use futures::{AsyncReadExt, Stream};
use http::Response;

use crate::{RequestSend, Transport};

mod error;

use error::*;

// Aliases for `async_ws` types, specialised to this crate's `Transport` where generic.

/// Kind of a WebSocket message (e.g. `Text` or `Binary`), as passed to [`WsConnection::send`].
pub type WsMessageKind = async_ws::message::WsMessageKind;
/// Error type reported by [`WsConnection::err`].
pub type WsConnectionError = async_ws::connection::WsConnectionError;

/// Value returned by [`WsConnection::send`] and its text/binary shorthands.
pub type WsSend = async_ws::connection::WsSend<Transport>;
/// Item type yielded by the [`WsConnection`] stream.
pub type WsMessageReader = async_ws::connection::WsMessageReader<Transport>;
/// `async_ws` message writer running over [`Transport`].
pub type WsMessageWriter = async_ws::connection::WsMessageWriter<Transport>;

/// WebSocket connection running over this crate's [`Transport`].
///
/// Thin wrapper around [`async_ws::connection::WsConnection`]. Within this module it is
/// constructed by `WsConnection::connect` using `WsConfig::client()`.
pub struct WsConnection {
    /// Underlying `async_ws` connection; the methods and `Stream` impl in this module forward to it.
    inner: async_ws::connection::WsConnection<Transport>,
}

impl WsConnection {
    /// Connects to the WebSocket endpoint at `uri` using the default upgrade request.
    ///
    /// The request comes from [`Self::connect_request_builder`] with an empty body. `uri` is
    /// then installed as its target, and the handshake is delegated to [`Self::connect`].
    ///
    /// # Errors
    ///
    /// Returns an error if `uri` cannot be converted into an [`http::Uri`], or any error
    /// produced by [`Self::connect`].
    ///
    /// # Panics
    ///
    /// Panics if the builder from [`Self::connect_request_builder`] fails to build the request.
    pub async fn connect_with_uri<T>(uri: T) -> Result<Self, WsConnectError>
    where
        http::Uri: TryFrom<T>,
        <http::Uri as TryFrom<T>>::Error: Into<http::uri::InvalidUri>,
    {
        let target = http::Uri::try_from(uri).map_err(Into::into)?;
        let mut upgrade = Self::connect_request_builder().body("").unwrap();
        *upgrade.uri_mut() = target;
        Self::connect(&upgrade).await
    }

    /// Performs the WebSocket opening handshake described by `request`.
    ///
    /// On success, the transport is reclaimed from the response body and wrapped in an
    /// `async_ws` connection configured with `WsConfig::client()`.
    ///
    /// # Errors
    ///
    /// - [`WsConnectError::InvalidUpgradeRequest`] if `async_ws::http::is_upgrade_request`
    ///   rejects `request`. Nothing is sent in that case.
    /// - Any error from `RequestSend`, or from `into_inner()` on the accepted response body.
    /// - [`WsConnectError::InvalidUpgradeResponse`] if the response fails
    ///   `async_ws::http::check_upgrade_response`. The error carries the response head and,
    ///   as its body, the `io::Result` of reading at most 16 KiB of the response body:
    ///   - `Ok(String)` when the bytes are valid UTF-8,
    ///   - `Ok(Vec<u8>)` otherwise,
    ///   - the read error if reading failed.
    pub async fn connect(request: &http::Request<impl AsRef<[u8]>>) -> Result<Self, WsConnectError> {
        // Maximum number of body bytes kept for diagnostics on a rejected handshake.
        const REJECTED_BODY_PREVIEW_LIMIT: u64 = 1 << 14;

        if !is_upgrade_request(request) {
            return Err(WsConnectError::InvalidUpgradeRequest);
        }
        let response = RequestSend::new(request).await?;

        if check_upgrade_response(request, &response) {
            // Accepted: take the transport back out of the response and run WebSocket on it.
            let transport = response.into_body().into_inner()?.1;
            return Ok(Self {
                inner: async_ws::connection::WsConnection::with_config(transport, WsConfig::client()),
            });
        }

        // Rejected: keep only a bounded prefix of the body so the error stays debuggable
        // without buffering an arbitrarily large response.
        let (head, body) = response.into_parts();
        let mut captured = Vec::new();
        let read_outcome = body
            .take(REJECTED_BODY_PREVIEW_LIMIT)
            .read_to_end(&mut captured)
            .await;
        let preview: Box<dyn std::fmt::Debug + Send + Sync> = match String::from_utf8(captured) {
            Ok(text) => Box::new(read_outcome.map(move |_| text)),
            Err(not_utf8) => Box::new(read_outcome.map(move |_| not_utf8.into_bytes())),
        };
        Err(WsConnectError::InvalidUpgradeResponse(
            Response::from_parts(head, preview).into(),
        ))
    }

    /// Returns the request builder produced by `async_ws::http::upgrade_request`.
    ///
    /// The builder presumably carries the WebSocket upgrade headers (not visible here).
    /// Set a URI and body on it, then pass the result to [`Self::connect`].
    pub fn connect_request_builder() -> http::request::Builder {
        upgrade_request()
    }

    /// Forwards to the inner `async_ws` connection's `send` for a message of the given `kind`.
    pub fn send(&self, kind: WsMessageKind) -> WsSend {
        self.inner.send(kind)
    }

    /// Same as [`Self::send`] with [`WsMessageKind::Text`].
    pub fn send_text(&self) -> WsSend {
        self.inner.send(WsMessageKind::Text)
    }

    /// Same as [`Self::send`] with [`WsMessageKind::Binary`].
    pub fn send_binary(&self) -> WsSend {
        self.inner.send(WsMessageKind::Binary)
    }

    /// Returns the error recorded on the inner connection, if any.
    pub fn err(&self) -> Option<Arc<WsConnectionError>> {
        self.inner.err()
    }
}

/// Streams the [`WsMessageReader`]s produced by the inner `async_ws` connection.
///
/// The stream ends when the inner stream ends.
impl Stream for WsConnection {
    type Item = WsMessageReader;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Re-pin the field directly; this relies on the inner connection being `Unpin`.
        Pin::new(&mut self.inner).poll_next(cx)
    }
}