irelia_cli/
ws.rs

//! Bindings for the LCU websocket API: connecting to the running client, subscribing to events and dispatching them to a callback.
2
3use std::borrow::Cow;
4use std::{collections::HashSet, ops::ControlFlow, sync::Arc};
5
6use futures_util::{
7    stream::{SplitSink, SplitStream},
8    SinkExt, StreamExt,
9};
10use rustls::ClientConfig;
11use serde_json::Value;
12use tokio::{
13    net::TcpStream,
14    sync::mpsc::{self, UnboundedSender},
15    task::JoinHandle,
16};
17use tokio_tungstenite::{
18    connect_async_tls_with_config,
19    tungstenite::{client::IntoClientRequest, http::HeaderValue, Message},
20    Connector, MaybeTlsStream, WebSocketStream,
21};
22
23use crate::{
24    utils::{process_info::get_running_client, setup_tls::setup_tls_connector},
25    Error,
26};
27
/// Request / message type codes understood by the LCU websocket.
///
/// The discriminant is the opcode sent as the first element of a frame,
/// e.g. `[5, "OnJsonApiEvent"]` subscribes to `OnJsonApiEvent`. The values
/// follow the WAMP 1.0 message type ids.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestType {
    /// WAMP `WELCOME`.
    Welcome = 0,
    /// WAMP `PREFIX`.
    Prefix = 1,
    /// WAMP `CALL`.
    Call = 2,
    /// WAMP `CALLRESULT`.
    CallResult = 3,
    /// WAMP `CALLERROR`.
    CallError = 4,
    /// WAMP `SUBSCRIBE`: start receiving an event.
    Subscribe = 5,
    /// WAMP `UNSUBSCRIBE`: stop receiving an event.
    Unsubscribe = 6,
    /// WAMP `PUBLISH`.
    Publish = 7,
    /// WAMP `EVENT`: an event delivered by the server.
    Event = 8,
}
41
/// Different event types that can be passed to the
/// subscribe and unsubscribe methods.
#[derive(Debug, Eq, Hash, PartialEq, Clone)]
pub enum EventType {
    /// The `OnJsonApiEvent` event.
    OnJsonApiEvent,
    /// The `OnLcdsEvent` event.
    OnLcdsEvent,
    /// The `OnLog` event.
    OnLog,
    /// The `OnRegionLocaleChanged` event.
    OnRegionLocaleChanged,
    /// The `OnServiceProxyAsyncEvent` event.
    OnServiceProxyAsyncEvent,
    /// The `OnServiceProxyMethodEvent` event.
    OnServiceProxyMethodEvent,
    /// The `OnServiceProxyUuidEvent` event.
    OnServiceProxyUuidEvent,
    /// `OnJsonApiEvent` scoped to one endpoint path
    /// (e.g. `"/lol-gameflow/v1/gameflow-phase"`); every `/` in the path is
    /// replaced by `_` when the event name is built.
    OnJsonApiEventCallback(String),
    /// `OnLcdsEvent` scoped to one endpoint path; every `/` in the path is
    /// replaced by `_` when the event name is built.
    OnLcdsEventCallback(String),
}
56
57impl EventType {
58    fn to_string(&self) -> Cow<'static, str> {
59        match self {
60            EventType::OnJsonApiEvent => "OnJsonApiEvent".into(),
61            EventType::OnLcdsEvent => "OnLcdsEvent".into(),
62            EventType::OnLog => "OnLog".into(),
63            EventType::OnRegionLocaleChanged => "OnRegionLocaleChanged".into(),
64            EventType::OnServiceProxyAsyncEvent => "OnServiceProxyAsyncEvent".into(),
65            EventType::OnServiceProxyMethodEvent => "OnServiceProxyMethodEvent".into(),
66            EventType::OnServiceProxyUuidEvent => "OnServiceProxyUuidEvent".into(),
67            EventType::OnJsonApiEventCallback(callback) => {
68                format!("OnJsonApiEvent{}", callback.replace('/', "_")).into()
69            }
70            EventType::OnLcdsEventCallback(callback) => {
71                format!("OnLcdsEvent{}", callback.replace('/', "_")).into()
72            }
73        }
74    }
75}
76
77/// Struct representing a connection to the LCU websocket
78pub struct LCUWebSocket {
79    ws_sender: UnboundedSender<(RequestType, EventType)>,
80    handle: JoinHandle<()>,
81    url: String,
82    auth_header: String,
83}
84
/// How the event loop should proceed. The callback passed to
/// [`LCUWebSocket::new`] returns this inside `ControlFlow::Continue`.
///
/// Returning `ControlFlow::Break(())` from the callback stops the event loop instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Flow {
    /// Open a fresh connection to the client, then keep processing events.
    TryReconnect,
    /// Keep processing events on the current connection.
    Continue,
}
90
91impl LCUWebSocket {
92    /// Creates a new connection to the LCU websocket
93    ///
94    /// # Errors
95    /// This function will return an error if the LCU is not running,
96    /// or if it cannot connect to the websocket
97    ///
98    /// # Panics
99    ///
100    /// If the auth header returned is somehow invalid (though I have not seen this in practice)
101    pub async fn new(
102        f: impl Fn(Result<&[Value], Error>) -> ControlFlow<(), Flow> + Send + Sync + 'static,
103    ) -> Result<Self, Error> {
104        let tls = setup_tls_connector();
105        let tls = Arc::new(tls);
106        let connector = Connector::Rustls(tls.clone());
107        let (url, auth_header) = get_running_client(false)?;
108        let str_req = format!("wss://{url}");
109        let mut request = str_req
110            .as_str()
111            .into_client_request()
112            .map_err(Error::WebsocketError)?;
113        request.headers_mut().insert(
114            "Authorization",
115            HeaderValue::from_str(&auth_header).expect("This is always a valid header"),
116        );
117
118        let (stream, _) = connect_async_tls_with_config(request, None, false, Some(connector))
119            .await
120            .map_err(Error::WebsocketError)?;
121
122        let (ws_sender, mut ws_receiver) = mpsc::unbounded_channel::<(RequestType, EventType)>();
123
124        let handle = tokio::spawn(async move {
125            let mut active_commands = HashSet::new();
126            let (mut write, mut read) = stream.split();
127
128            loop {
129                if let Ok((code, endpoint)) = ws_receiver.try_recv() {
130                    let endpoint = endpoint.to_string();
131
132                    let command = format!("[{}, \"{endpoint}\"]", code.clone() as u8);
133
134                    if code == RequestType::Subscribe {
135                        active_commands.insert(endpoint.clone());
136                    } else if code == RequestType::Unsubscribe {
137                        active_commands.remove(&endpoint);
138                    };
139
140                    if write.send(command.into()).await.is_err() {
141                        let mut c = f(Err(Error::LCUProcessNotRunning));
142                        if !budget_recursive(&mut c, &str_req, &tls, &f, &mut write, &mut read)
143                            .await
144                        {
145                            break;
146                        };
147                    };
148                };
149
150                if let Some(Ok(data)) = read.next().await {
151                    if let Ok(json) = &serde_json::from_slice::<Vec<Value>>(&data.into_data()) {
152                        let json = if let Some(endpoint) = json[1].as_str() {
153                            if active_commands.contains(endpoint) {
154                                json
155                            } else {
156                                continue;
157                            }
158                        } else {
159                            json
160                        };
161
162                        let mut c = f(Ok(json));
163                        if !budget_recursive(&mut c, &str_req, &tls, &f, &mut write, &mut read)
164                            .await
165                        {
166                            break;
167                        };
168                    };
169                };
170            }
171        });
172
173        Ok(Self {
174            ws_sender,
175            handle,
176            url,
177            auth_header,
178        })
179    }
180
181    #[must_use]
182    /// Returns a reference to the URL in use
183    pub fn url(&self) -> &str {
184        &self.url
185    }
186
187    #[must_use]
188    /// Returns a reference to the auth header in use
189    pub fn auth_header(&self) -> &str {
190        &self.auth_header
191    }
192
193    /// Subscribe to a new API event
194    pub fn subscribe(&mut self, endpoint: EventType) {
195        self.request(RequestType::Subscribe, endpoint);
196    }
197
198    /// Unsubscribe to a new API event
199    pub fn unsubscribe(&mut self, endpoint: EventType) {
200        self.request(RequestType::Unsubscribe, endpoint);
201    }
202
203    /// Terminate the event loop
204    pub fn terminate(&self) {
205        self.handle.abort();
206    }
207
208    #[must_use]
209    pub fn is_finished(&self) -> bool {
210        self.handle.is_finished()
211    }
212
213    /// Allows you to make a generic
214    /// request to the websocket socket
215    pub fn request(&mut self, code: RequestType, endpoint: EventType) {
216        let _ = &self.ws_sender.send((code, endpoint));
217    }
218}
219
220async fn budget_recursive(
221    c: &mut ControlFlow<(), Flow>,
222    str_req: &str,
223    tls: &Arc<ClientConfig>,
224    f: &(impl Fn(Result<&[Value], Error>) -> ControlFlow<(), Flow> + Sync + Send + 'static),
225    write: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
226    read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
227) -> bool {
228    while *c != ControlFlow::Continue(Flow::Continue) {
229        if *c == ControlFlow::Continue(Flow::TryReconnect) {
230            let tls = tls.clone();
231            let rec = reconnect(str_req, tls, write, read).await;
232            if let Err(e) = rec {
233                *c = f(Err(e));
234            } else {
235                break;
236            }
237        } else {
238            return false;
239        }
240    }
241
242    true
243}
244
245async fn reconnect(
246    str_req: &str,
247    tls: Arc<ClientConfig>,
248    write: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
249    read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
250) -> Result<(), Error> {
251    let req = str_req.into_client_request().unwrap();
252    let connector = Connector::Rustls(tls.clone());
253    let (stream, _) = connect_async_tls_with_config(req, None, false, Some(connector))
254        .await
255        .map_err(Error::WebsocketError)?;
256    (*write, *read) = stream.split();
257    Ok(())
258}
259
#[cfg(test)]
mod test {
    use super::{EventType, Flow, LCUWebSocket};
    use std::{ops::ControlFlow, time::Duration};

    /// Connects to a running League client, subscribes to `OnJsonApiEvent`
    /// and prints everything received until the event loop stops.
    ///
    /// Ignored by default because it needs a live client; run it with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires a running League client"]
    async fn it_inits() {
        let mut ws_client = LCUWebSocket::new(|values| {
            println!("{values:?}");
            ControlFlow::Continue(Flow::Continue)
        })
        .await
        .unwrap();
        ws_client.subscribe(EventType::OnJsonApiEvent);

        while !ws_client.is_finished() {
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    }
}
280}