league_client/
connector.rs

//! Establishes the websocket connection and manages subscriptions over it.
2
3use futures_util::stream::{SplitSink, SplitStream};
4use futures_util::{SinkExt, StreamExt};
5use tokio::net::TcpStream;
6use tokio_native_tls::TlsStream;
7use tokio_tungstenite::WebSocketStream;
8use tungstenite::Message;
9
10use crate::{core, Error, LCResult as Result};
11
12pub type Connected = WebSocketStream<TlsStream<TcpStream>>;
13
14/// Stores information of the subscription.
15///
16/// Once speaker is dropped, it will unsubscribe from the events and broadcast
17/// that it is finished to the read/write tasks.
18pub struct Speaker {
19    finish: tokio::sync::broadcast::Sender<bool>,
20    writer: flume::Sender<String>,
21    _handles: Vec<tokio::task::JoinHandle<()>>,
22
23    pub reader: flume::Receiver<core::Incoming>,
24}
25
26impl Speaker {
27    pub async fn send(&self, msg: String) -> Result<()> {
28        self.writer.send_async(msg).await.or(Err(Error::SendErr))
29    }
30
31    fn try_send(&self, msg: String) -> Result<()> {
32        self.writer.try_send(msg).or(Err(Error::SendErr))
33    }
34}
35
36impl Drop for Speaker {
37    fn drop(&mut self) {
38        let msg = (6, "OnJsonApiEvent");
39        if let Ok(msg) = serde_json::to_string(&msg) {
40            if let Err(e) = self.try_send(msg) {
41                tracing::error!("failed to unsubscribe: {e}");
42            }
43        };
44
45        if let Err(e) = self.finish.send(true) {
46            tracing::error!("failed to send broadcast: {e}");
47        };
48    }
49}
50
51/// Start a subscription to the socket.
52///
53/// Use the speaker to communicate with the socket.
54pub async fn subscribe(socket: Connected) -> Speaker {
55    let (cleanup_tx, cleanup_rx1) = tokio::sync::broadcast::channel(1);
56    let cleanup_rx2 = cleanup_tx.subscribe();
57
58    let (reader_tx, reader_rx) = flume::unbounded();
59    let (writer_tx, writer_rx) = flume::unbounded();
60
61    let (write, read) = socket.split();
62
63    let read_handle = tokio::task::spawn(read_from(cleanup_rx1, reader_tx, read));
64    let write_handle = tokio::task::spawn(write_to(cleanup_rx2, write, writer_rx));
65
66    Speaker {
67        finish: cleanup_tx,
68        reader: reader_rx,
69        writer: writer_tx,
70        _handles: vec![read_handle, write_handle],
71    }
72}
73
74async fn read_from(
75    mut end: tokio::sync::broadcast::Receiver<bool>,
76    tx: flume::Sender<core::Incoming>,
77    mut read: SplitStream<Connected>,
78) {
79    loop {
80        tokio::select! {
81            Some(msg) = read.next() => {
82                let msg = match msg {
83                    Ok(msg) => msg,
84                    Err(_) => {
85                        tracing::warn!("channel disconnect");
86                        break;
87                    }
88                };
89
90                let msg = msg.to_string();
91                if msg.is_empty() {
92                    continue;
93                }
94
95                let incoming = serde_json::from_str::<core::Incoming>(&msg);
96                let incoming = match incoming {
97                    Ok(incoming) => incoming,
98                    Err(_) => {
99                        tracing::warn!("failed to parse msg into incoming: {msg}");
100                        continue;
101                    },
102                };
103
104                let resp = tx.send_async(incoming).await;
105                if resp.is_err() {
106                    tracing::warn!("channel disconnect");
107                    break;
108                }
109            },
110            _ = end.recv() => { break },
111        };
112    }
113}
114
115async fn write_to(
116    mut end: tokio::sync::broadcast::Receiver<bool>,
117    mut tx: SplitSink<Connected, Message>,
118    read: flume::Receiver<String>,
119) {
120    loop {
121        tokio::select! {
122            msg = read.recv_async() => {
123                let msg = match msg {
124                    Ok(msg) => msg,
125                    Err(_) => {
126                        tracing::warn!("channel disconnect");
127                        break;
128                    }
129                };
130
131                let res = tx.send(Message::Text(msg)).await;
132                if res.is_err() {
133                    tracing::warn!("channel disconnect");
134                    break;
135                }
136            },
137            _ = end.recv() => { break },
138        };
139    }
140}
141
142/// Creates a connection to the wanted websocket. Use this if you want to set up
143/// the connection yourself.
144pub struct Connector {
145    tls: tokio_native_tls::TlsConnector,
146}
147
148impl Connector {
149    fn new(tls: tokio_native_tls::TlsConnector) -> Self {
150        Self { tls }
151    }
152
153    /// create a builder to set up the tls connection.
154    pub fn builder() -> ConnectorBuilder {
155        ConnectorBuilder::default()
156    }
157
158    /// creates a stream and wraps it with tls settings. It will then
159    /// create asocket from the said stream.
160    ///
161    /// the request must have a basic auth included or it will not complete.
162    pub async fn connect(&self, req: tungstenite::http::Request<()>) -> Result<Connected> {
163        let uri = req.uri();
164
165        let host = uri
166            .host()
167            .ok_or(Error::Websocket("host is missing".into()))?;
168        let port = uri
169            .port()
170            .ok_or(Error::Websocket("port is missing".into()))?;
171        let combo = format!("{host}:{port}");
172
173        let stream = tokio::net::TcpStream::connect(&combo)
174            .await
175            .map_err(Error::Stream)?;
176        let stream = self.tls.connect(&combo, stream).await.map_err(Error::Tls)?;
177
178        let (socket, _) = tokio_tungstenite::client_async(req, stream)
179            .await
180            .map_err(Error::Tungstenite)?;
181
182        Ok(socket)
183    }
184}
185
186#[derive(Default)]
187pub struct ConnectorBuilder {
188    insecure: bool,
189}
190
191impl ConnectorBuilder {
192    pub fn insecure(self, value: bool) -> Self {
193        Self { insecure: value }
194    }
195
196    pub fn build(self) -> Result<Connector> {
197        let mut connector = native_tls::TlsConnector::builder();
198
199        if self.insecure {
200            connector.danger_accept_invalid_certs(true);
201        } else {
202            // Work out cert
203            unimplemented!();
204        }
205
206        let connector = connector
207            .build()
208            .map_err(|e| Error::Websocket(e.to_string()))?;
209        let tls = tokio_native_tls::TlsConnector::from(connector);
210
211        Ok(Connector::new(tls))
212    }
213}