1use futures::{Sink, Stream, StreamExt};
2use futures_util::lock::Mutex;
3use futures_util::stream::SplitSink;
4use futures_util::{
5 future::{join_all, ready},
6 SinkExt,
7};
8use log::debug;
9use serde_json::Value;
10use std::{collections::HashMap, sync::Arc};
11use futures::channel::oneshot;
12use futures::channel::oneshot::{Receiver, Sender};
13use tokio_tungstenite::tungstenite::Error;
14use tokio_tungstenite::{
15 connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream,
16};
17
18use super::command::{create_command, Command, CommandResponse};
19use crate::command::CommandRequest;
20
21use serde::{Deserialize, Serialize};
22
23pub struct WebosClient<T> {
25 write: Box<Mutex<T>>,
26 next_command_id: Arc<Mutex<u64>>,
27 callbacks: Arc<Mutex<HashMap<String, Sender<CommandResponse>>>>,
28 pub key: Option<String>,
29}
30
/// Errors returned by [`WebosClient`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientError {
    /// The configured address could not be parsed as a URL.
    MalformedUrl,
    /// The connection to the TV could not be established.
    ConnectionError,
    /// A command could not be sent, or no response was received for it.
    CommandSendError,
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            ClientError::MalformedUrl => "malformed webOS TV address",
            ClientError::ConnectionError => "could not connect to the webOS TV",
            ClientError::CommandSendError => "could not send command to the webOS TV",
        };
        f.write_str(description)
    }
}

// Lets `ClientError` flow through `?` into `Box<dyn std::error::Error>` and
// other error-trait-based handling.
impl std::error::Error for ClientError {}
37
38#[derive(Serialize, Deserialize)]
39pub struct WebOsClientConfig {
40 pub address: String,
41 pub key: Option<String>,
42}
43
44impl Default for WebOsClientConfig {
45 fn default() -> Self {
46 WebOsClientConfig::new("ws://lgwebostv:3000/", None)
47 }
48}
49
50impl WebOsClientConfig {
51 pub fn new(addr: &str, key: Option<String>) -> WebOsClientConfig {
53 let address = String::from(addr);
54 WebOsClientConfig { address, key }
55 }
56}
57
58impl Clone for WebOsClientConfig {
59 fn clone(&self) -> Self {
60 let addr = self.address.clone();
61 let key = self.key.clone();
62 WebOsClientConfig { address: addr, key }
63 }
64}
65
66impl WebosClient<SplitSink<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, Message>> {
67 pub async fn new(config: WebOsClientConfig) -> Result<Self, ClientError> {
69 let url = url::Url::parse(&config.address).map_err(|_| ClientError::MalformedUrl)?;
70 let (ws_stream, _) = connect_async(url)
71 .await
72 .map_err(|_| ClientError::ConnectionError)?;
73 debug!("WebSocket handshake has been successfully completed");
74 let (write, read) = ws_stream.split();
75 WebosClient::from_stream_and_sink(read, write, config).await
76 }
77}
78
79impl<T> WebosClient<T>
80where
81 T: Sink<Message, Error = Error> + Unpin,
82{
83 pub async fn from_stream_and_sink<S>(
85 stream: S,
86 mut sink: T,
87 config: WebOsClientConfig,
88 ) -> Result<Self, ClientError>
89 where
90 S: Stream<Item = Result<Message, Error>> + Send + 'static,
91 {
92 let command_id_generator = Arc::from(Mutex::from(0));
93 let callbacks: Arc<Mutex<HashMap<String, Sender<CommandResponse>>>> =
94 Arc::from(Mutex::from(HashMap::new()));
95 let callbacks_copy = callbacks.clone();
96 let (sender, receiver) = oneshot::channel::<CommandResponse>();
97 tokio::spawn(async move { process_messages_from_server(stream, callbacks_copy).await });
98 let mut handshake = get_handshake();
99 if let Some(key) = config.key {
101 handshake["payload"]["client-key"] = Value::from(key);
102 }
103 let registration_id = 0.to_string();
104 handshake["id"] = Value::from(registration_id.to_string());
105 callbacks.lock().await.insert(registration_id, sender);
106 let formatted_handshake = format!("{handshake}");
107 sink.send(Message::text(formatted_handshake))
108 .await
109 .map_err(|_| ClientError::CommandSendError)?;
110 let key = Some(receiver.await.unwrap().payload.unwrap().to_string());
111 Ok(WebosClient {
112 write: Box::new(Mutex::new(sink)),
113 next_command_id: command_id_generator,
114 callbacks,
115 key,
116 })
117 }
118 pub async fn send_command(&self, cmd: Command) -> Result<CommandResponse, ClientError> {
120 let (message, promise) = self
121 .prepare_command_to_send(cmd)
122 .await
123 .map_err(|_| ClientError::CommandSendError)?;
124 self.write
125 .lock()
126 .await
127 .send(message)
128 .await
129 .map_err(|_| ClientError::CommandSendError)?;
130 promise.await.map_err(|_| ClientError::CommandSendError)
131 }
132
133 pub async fn send_all_commands(
135 self,
136 cmds: Vec<Command>,
137 ) -> Result<Vec<CommandResponse>, ClientError> {
138 let mut promises: Vec<Receiver<CommandResponse>> = vec![];
139 let commands = join_all(
140 cmds.into_iter()
141 .map(|cmd| async { self.prepare_command_to_send(cmd).await }),
142 )
143 .await;
144 let messages: Vec<Result<Message, Error>> = commands
145 .into_iter()
146 .map(|command| {
147 let x = command.unwrap();
148 promises.push(x.1);
149 Ok(x.0)
150 })
151 .collect();
152
153 let mut iter = futures_util::stream::iter(messages);
154 self.write
155 .lock()
156 .await
157 .send_all(&mut iter)
158 .await
159 .map_err(|_| ClientError::CommandSendError)?;
160 Ok(join_all(promises)
161 .await
162 .into_iter()
163 .map(|resp| resp.unwrap())
164 .collect())
165 }
166
167 async fn prepare_command_to_send(
168 &self,
169 cmd: Command,
170 ) -> Result<(Message, Receiver<CommandResponse>), ()> {
171 let id = self.generate_next_id().await;
172 let (sender, receiver) = oneshot::channel::<CommandResponse>();
173
174 if let Some(mut lock) = self.callbacks.try_lock() {
175 lock.insert(id.clone(), sender);
176 let message = Message::from(&create_command(id, cmd));
177 Ok((message, receiver))
178 } else {
179 Err(())
180 }
181 }
182
183 async fn generate_next_id(&self) -> String {
184 let mut guard = self.next_command_id.lock().await;
185 *guard += 1;
186 guard.to_string()
187 }
188}
189
190async fn process_messages_from_server<T>(
191 stream: T,
192 pending_requests: Arc<Mutex<HashMap<String, Sender<CommandResponse>>>>,
193) where
194 T: Stream<Item = Result<Message, Error>> + Send,
195{
196 stream
197 .for_each(|message| match message {
198 Ok(_message) => {
199 if let Ok(text_message) = _message.into_text() {
200 if let Ok(json) = serde_json::from_str::<Value>(&text_message) {
201 if let Some(r) = json["id"].as_str() {
202 if json["payload"]["pairingType"] != "PROMPT" {
204 let response = CommandResponse {
205 id: Some(r.to_string()),
206 payload: Some(json["payload"].clone()),
207 };
208 if let Some(mut requests) = pending_requests.try_lock() {
209 if let Some(id) = response.id.clone() {
210 if let Some(sender) = requests.remove(&id) {
211 sender.send(response).unwrap();
212 }
213 }
214 }
215 }
216 }
217 }
218 }
219 ready(())
220 }
221 Err(_) => ready(()),
222 })
223 .await
224}
225
226impl From<&CommandRequest> for Message {
227 fn from(request: &CommandRequest) -> Self {
228 Message::text(serde_json::to_string(request).unwrap())
229 }
230}
231
232fn get_handshake() -> serde_json::Value {
238 serde_json::from_str(include_str!("../handshake.json")).expect("Could not parse handshake json")
239}
240
#[cfg(test)]
mod tests {
    use crate::client::{WebOsClientConfig, WebosClient};
    use crate::command::Command;
    use futures_util::{Sink, Stream, StreamExt};
    use serde_json::Value;
    use std::collections::{HashMap, VecDeque};
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};
    use tokio_tungstenite::tungstenite::{Error, Message};

    /// Reply the fake device yields first, answering the registration handshake (id `"0"`).
    const REGISTERED: &str = r#"{
        "id": "0",
        "type": "registered",
        "payload": {
            "client-key": "key"
        }
    }"#;

    /// In-memory fake TV that the client uses as both its sink and its stream.
    ///
    /// Reading from it first yields [`REGISTERED`]. After that it yields, in FIFO
    /// order, the canned responses whose ids matched messages written to it.
    struct LgDevice {
        registered: bool,
        /// Canned responses keyed by the request id that triggers them.
        responses: HashMap<String, Message>,
        /// Triggered responses that have not been read yet.
        queue: VecDeque<Message>,
        /// Waker of a reader parked on an empty queue.
        waiting_reader: Option<Waker>,
    }

    impl LgDevice {
        fn new(responses: HashMap<String, Message>) -> Self {
            LgDevice {
                registered: false,
                responses,
                queue: VecDeque::new(),
                waiting_reader: None,
            }
        }
    }

    impl Sink<Message> for LgDevice {
        type Error = Error;

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
            let device = self.get_mut();
            if let Ok(text_message) = item.into_text() {
                if let Ok(json) = serde_json::from_str::<Value>(&text_message) {
                    let id = json["id"]
                        .as_str()
                        .expect("the client always sends a string id");
                    if let Some(response) = device.responses.remove(id) {
                        // Append so responses are read in the order they were triggered.
                        device.queue.push_back(response);
                        if let Some(waker) = device.waiting_reader.take() {
                            waker.wake();
                        }
                    }
                }
            }
            Ok(())
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl Stream for LgDevice {
        type Item = Result<Message, Error>;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let device = self.get_mut();
            if !device.registered {
                device.registered = true;
                return Poll::Ready(Some(Ok(Message::Text(REGISTERED.to_owned()))));
            }
            match device.queue.pop_front() {
                Some(message) => Poll::Ready(Some(Ok(message))),
                None => {
                    // Park until `start_send` queues a response instead of busy-polling.
                    device.waiting_reader = Some(cx.waker().clone());
                    Poll::Pending
                }
            }
        }
    }

    #[tokio::test]
    async fn create_client() {
        let (sink, stream) = LgDevice::new(HashMap::new()).split();
        let client = WebosClient::from_stream_and_sink(stream, sink, WebOsClientConfig::default())
            .await
            .expect("registration should succeed");
        assert!(client.key.is_some());
    }

    #[tokio::test]
    async fn send_command() {
        let mut responses = HashMap::new();
        responses.insert(
            "1".to_owned(),
            Message::Text(
                r#"
            {
                "id": "1",
                "payload": {
                    "returnValue": true
                },
                "type":"response"
            }"#
                .to_owned(),
            ),
        );

        let (sink, stream) = LgDevice::new(responses).split();
        let client = WebosClient::from_stream_and_sink(stream, sink, WebOsClientConfig::default())
            .await
            .unwrap();
        let response = client.send_command(Command::ChannelUp).await.unwrap();
        assert_eq!(response.id.as_deref(), Some("1"));
        assert_eq!(
            response.payload.expect("response carries a payload")["returnValue"],
            true
        );
    }
}