use std::borrow::Cow;
use std::{collections::HashSet, ops::ControlFlow, sync::Arc};

use futures_util::{
    future::{select, Either},
    stream::{SplitSink, SplitStream},
    SinkExt, StreamExt,
};
use rustls::ClientConfig;
use serde_json::Value;
use tokio::{
    net::TcpStream,
    sync::mpsc::{self, UnboundedSender},
    task::JoinHandle,
};
use tokio_tungstenite::{
    connect_async_tls_with_config,
    tungstenite::{client::IntoClientRequest, http::HeaderValue, Message},
    Connector, MaybeTlsStream, WebSocketStream,
};

use crate::{
    utils::{process_info::get_running_client, setup_tls::setup_tls_connector},
    Error,
};
27
/// Numeric request/message codes used on the client's WebSocket.
///
/// The discriminant (`code as u8`) is what gets written as the first element of
/// an outgoing `[code, "endpoint"]` request frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestType {
    Welcome = 0,
    Prefix = 1,
    Call = 2,
    CallResult = 3,
    CallError = 4,
    Subscribe = 5,
    Unsubscribe = 6,
    Publish = 7,
    Event = 8,
}
41
/// Events that can be subscribed to (or unsubscribed from) on the client's WebSocket.
///
/// The `...Callback` variants append the given string to the base event name,
/// with every `/` in it replaced by `_`.
#[derive(Debug, Eq, Hash, PartialEq, Clone)]
pub enum EventType {
    OnJsonApiEvent,
    OnLcdsEvent,
    OnLog,
    OnRegionLocaleChanged,
    OnServiceProxyAsyncEvent,
    OnServiceProxyMethodEvent,
    OnServiceProxyUuidEvent,
    /// `OnJsonApiEvent` followed by the given string (with `/` replaced by `_`).
    OnJsonApiEventCallback(String),
    /// `OnLcdsEvent` followed by the given string (with `/` replaced by `_`).
    OnLcdsEventCallback(String),
}

impl EventType {
    /// Returns the event name used in subscribe/unsubscribe requests.
    ///
    /// Plain variants borrow a static name. Callback variants allocate the base
    /// name followed by their string, with every `/` replaced by `_`.
    fn to_string(&self) -> Cow<'static, str> {
        match self {
            EventType::OnJsonApiEvent => "OnJsonApiEvent".into(),
            EventType::OnLcdsEvent => "OnLcdsEvent".into(),
            EventType::OnLog => "OnLog".into(),
            EventType::OnRegionLocaleChanged => "OnRegionLocaleChanged".into(),
            EventType::OnServiceProxyAsyncEvent => "OnServiceProxyAsyncEvent".into(),
            EventType::OnServiceProxyMethodEvent => "OnServiceProxyMethodEvent".into(),
            EventType::OnServiceProxyUuidEvent => "OnServiceProxyUuidEvent".into(),
            EventType::OnJsonApiEventCallback(callback) => {
                format!("OnJsonApiEvent{}", callback.replace('/', "_")).into()
            }
            EventType::OnLcdsEventCallback(callback) => {
                format!("OnLcdsEvent{}", callback.replace('/', "_")).into()
            }
        }
    }
}
76
77pub struct LCUWebSocket {
79 ws_sender: UnboundedSender<(RequestType, EventType)>,
80 handle: JoinHandle<()>,
81 url: String,
82 auth_header: String,
83}
84
/// What the listener should do when the callback returns `ControlFlow::Continue`.
///
/// Returning `ControlFlow::Break(())` from the callback stops the listener task instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Flow {
    /// Open a new connection to the client before continuing.
    TryReconnect,
    /// Keep using the current connection.
    Continue,
}
90
91impl LCUWebSocket {
92 pub async fn new(
102 f: impl Fn(Result<&[Value], Error>) -> ControlFlow<(), Flow> + Send + Sync + 'static,
103 ) -> Result<Self, Error> {
104 let tls = setup_tls_connector();
105 let tls = Arc::new(tls);
106 let connector = Connector::Rustls(tls.clone());
107 let (url, auth_header) = get_running_client(false)?;
108 let str_req = format!("wss://{url}");
109 let mut request = str_req
110 .as_str()
111 .into_client_request()
112 .map_err(Error::WebsocketError)?;
113 request.headers_mut().insert(
114 "Authorization",
115 HeaderValue::from_str(&auth_header).expect("This is always a valid header"),
116 );
117
118 let (stream, _) = connect_async_tls_with_config(request, None, false, Some(connector))
119 .await
120 .map_err(Error::WebsocketError)?;
121
122 let (ws_sender, mut ws_receiver) = mpsc::unbounded_channel::<(RequestType, EventType)>();
123
124 let handle = tokio::spawn(async move {
125 let mut active_commands = HashSet::new();
126 let (mut write, mut read) = stream.split();
127
128 loop {
129 if let Ok((code, endpoint)) = ws_receiver.try_recv() {
130 let endpoint = endpoint.to_string();
131
132 let command = format!("[{}, \"{endpoint}\"]", code.clone() as u8);
133
134 if code == RequestType::Subscribe {
135 active_commands.insert(endpoint.clone());
136 } else if code == RequestType::Unsubscribe {
137 active_commands.remove(&endpoint);
138 };
139
140 if write.send(command.into()).await.is_err() {
141 let mut c = f(Err(Error::LCUProcessNotRunning));
142 if !budget_recursive(&mut c, &str_req, &tls, &f, &mut write, &mut read)
143 .await
144 {
145 break;
146 };
147 };
148 };
149
150 if let Some(Ok(data)) = read.next().await {
151 if let Ok(json) = &serde_json::from_slice::<Vec<Value>>(&data.into_data()) {
152 let json = if let Some(endpoint) = json[1].as_str() {
153 if active_commands.contains(endpoint) {
154 json
155 } else {
156 continue;
157 }
158 } else {
159 json
160 };
161
162 let mut c = f(Ok(json));
163 if !budget_recursive(&mut c, &str_req, &tls, &f, &mut write, &mut read)
164 .await
165 {
166 break;
167 };
168 };
169 };
170 }
171 });
172
173 Ok(Self {
174 ws_sender,
175 handle,
176 url,
177 auth_header,
178 })
179 }
180
181 #[must_use]
182 pub fn url(&self) -> &str {
184 &self.url
185 }
186
187 #[must_use]
188 pub fn auth_header(&self) -> &str {
190 &self.auth_header
191 }
192
193 pub fn subscribe(&mut self, endpoint: EventType) {
195 self.request(RequestType::Subscribe, endpoint);
196 }
197
198 pub fn unsubscribe(&mut self, endpoint: EventType) {
200 self.request(RequestType::Unsubscribe, endpoint);
201 }
202
203 pub fn terminate(&self) {
205 self.handle.abort();
206 }
207
208 #[must_use]
209 pub fn is_finished(&self) -> bool {
210 self.handle.is_finished()
211 }
212
213 pub fn request(&mut self, code: RequestType, endpoint: EventType) {
216 let _ = &self.ws_sender.send((code, endpoint));
217 }
218}
219
220async fn budget_recursive(
221 c: &mut ControlFlow<(), Flow>,
222 str_req: &str,
223 tls: &Arc<ClientConfig>,
224 f: &(impl Fn(Result<&[Value], Error>) -> ControlFlow<(), Flow> + Sync + Send + 'static),
225 write: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
226 read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
227) -> bool {
228 while *c != ControlFlow::Continue(Flow::Continue) {
229 if *c == ControlFlow::Continue(Flow::TryReconnect) {
230 let tls = tls.clone();
231 let rec = reconnect(str_req, tls, write, read).await;
232 if let Err(e) = rec {
233 *c = f(Err(e));
234 } else {
235 break;
236 }
237 } else {
238 return false;
239 }
240 }
241
242 true
243}
244
245async fn reconnect(
246 str_req: &str,
247 tls: Arc<ClientConfig>,
248 write: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
249 read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
250) -> Result<(), Error> {
251 let req = str_req.into_client_request().unwrap();
252 let connector = Connector::Rustls(tls.clone());
253 let (stream, _) = connect_async_tls_with_config(req, None, false, Some(connector))
254 .await
255 .map_err(Error::WebsocketError)?;
256 (*write, *read) = stream.split();
257 Ok(())
258}
259
#[cfg(test)]
mod test {
    use super::{EventType, Flow, LCUWebSocket};
    use std::{ops::ControlFlow, time::Duration};

    /// Manual smoke test against a live client.
    ///
    /// Connects, subscribes to `OnJsonApiEvent`, and prints every value the callback
    /// receives until the listener task stops. It is ignored by default because it
    /// needs a running League client, and with a callback that always continues it
    /// only returns once the task is stopped from outside.
    #[tokio::test]
    #[ignore = "requires a running League client"]
    async fn it_inits() {
        let mut ws_client = LCUWebSocket::new(|values| {
            println!("{values:?}");
            ControlFlow::Continue(Flow::Continue)
        })
        .await
        .unwrap();
        ws_client.subscribe(EventType::OnJsonApiEvent);

        while !ws_client.is_finished() {
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    }
}
280}