lighthouse_client/
lighthouse.rs

1use std::{collections::HashMap, fmt::Debug, sync::{atomic::{AtomicI32, Ordering}, Arc}};
2
3use async_tungstenite::tungstenite::{Message, self};
4use futures::{prelude::*, channel::mpsc::{Sender, self}, stream::{SplitSink, SplitStream}, lock::Mutex};
5use lighthouse_protocol::{Authentication, ClientMessage, DirectoryTree, Frame, InputEvent, LaserMetrics, Model, ServerMessage, Value, Verb};
6use serde::{Deserialize, Serialize};
7use stream_guard::GuardStreamExt;
8use tracing::{warn, error, debug, info};
9use crate::{Check, Error, Result, Spawner};
10
11/// A connection to the lighthouse server for sending requests and receiving events.
12pub struct Lighthouse<S> {
13    /// The sink-part of the WebSocket connection.
14    ws_sink: Arc<Mutex<SplitSink<S, Message>>>,
15    /// The response/event slots, keyed by request id.
16    slots: Arc<Mutex<HashMap<i32, Slot<ServerMessage<Value>>>>>,
17    /// The credentials used to authenticate with the lighthouse.
18    authentication: Authentication,
19    /// The next request id. Incremented on every request.
20    request_id: Arc<AtomicI32>,
21}
22
/// Coordinates how responses to a request are handed from the receive loop
/// task to the task that issued the request.
enum Slot<M> {
    /// Nothing had arrived when the requesting task registered: that task
    /// created a channel, parked its sender here for the receive loop to use
    /// and kept the receiver for itself.
    WaitForMessages(Sender<M>),
    /// Messages arrived before the requesting task registered: the receive
    /// loop created this variant and buffers them here until they are
    /// collected.
    EarlyMessages(Vec<M>),
}
38
39impl<S> Lighthouse<S>
40    where S: Stream<Item = tungstenite::Result<Message>>
41           + Sink<Message, Error = tungstenite::Error>
42           + Send
43           + 'static {
44    /// Connects to the lighthouse using the given credentials.
45    /// Asynchronously runs a receive loop using the provided spawner.
46    pub fn new<W>(web_socket: S, authentication: Authentication) -> Result<Self> where W: Spawner {
47        let (ws_sink, ws_stream) = web_socket.split();
48        let slots = Arc::new(Mutex::new(HashMap::new()));
49        let lh = Self {
50            ws_sink: Arc::new(Mutex::new(ws_sink)),
51            slots: slots.clone(),
52            authentication,
53            request_id: Arc::new(AtomicI32::new(0)),
54        };
55        W::spawn(Self::run_receive_loop(ws_stream, slots));
56        Ok(lh)
57    }
58
59    /// Runs a loop that continuously receives events.
60    #[tracing::instrument(skip(ws_stream, slots))]
61    async fn run_receive_loop(mut ws_stream: SplitStream<S>, slots: Arc<Mutex<HashMap<i32, Slot<ServerMessage<Value>>>>>) {
62        loop {
63            match Self::receive_message_from(&mut ws_stream).await {
64                Ok(msg) => {
65                    let mut slots = slots.lock().await;
66                    if let Some(request_id) = msg.request_id {
67                        if let Some(slot) = slots.get_mut(&request_id) {
68                            match slot {
69                                Slot::EarlyMessages(msgs) => msgs.push(msg),
70                                Slot::WaitForMessages(tx) => {
71                                    if let Err(e) = tx.send(msg).await {
72                                        if e.is_disconnected() {
73                                            info!("Receiver for request id {} disconnected, removing the sender...", request_id);
74                                            slots.remove(&request_id);
75                                        } else {
76                                            warn!("Could not send message for request id {} via channel: {:?}", request_id, e);
77                                        }
78                                    }
79                                }
80                            }
81                        } else {
82                            slots.insert(request_id, Slot::EarlyMessages(vec![msg]));
83                        }
84                    } else {
85                        warn!("Got message without request id from server: {:?}", msg);
86                    }
87                },
88                Err(Error::NoNextMessage) => {
89                    info!("No next message available, closing receive loop");
90                    break
91                },
92                Err(e) => error!("Bad message: {:?}", e),
93            }
94        }
95    }
96
97    /// Receives a ServerMessage from the lighthouse.
98    #[tracing::instrument(skip(ws_stream))]
99    async fn receive_message_from<P>(ws_stream: &mut SplitStream<S>) -> Result<ServerMessage<P>>
100    where
101        P: for<'de> Deserialize<'de> {
102        let bytes = Self::receive_raw_from(ws_stream).await?;
103        let message = rmp_serde::from_slice(&bytes)?;
104        Ok(message)
105    }
106
107    /// Receives raw bytes from the lighthouse via the WebSocket connection.
108    #[tracing::instrument(skip(ws_stream))]
109    async fn receive_raw_from(ws_stream: &mut SplitStream<S>) -> Result<Vec<u8>> {
110        loop {
111            let message = ws_stream.next().await.ok_or_else(|| Error::NoNextMessage)??;
112            match message {
113                Message::Binary(bytes) => break Ok(bytes),
114                Message::Ping(_) => {}, // Ignore pings for now
115                Message::Close(_) => break Err(Error::ConnectionClosed),
116                _ => warn!("Got non-binary message: {:?}", message),
117            }
118        }
119    }
120
121    /// Replaces the user's lighthouse model with the given frame.
122    pub async fn put_model(&self, frame: Frame) -> Result<ServerMessage<()>> {
123        let username = self.authentication.username.clone();
124        self.put(&["user".into(), username, "model".into()], Model::Frame(frame)).await
125    }
126
127    /// Requests a stream of events (including key/controller events) for the user's lighthouse model.
128    pub async fn stream_model(&self) -> Result<impl Stream<Item = Result<ServerMessage<Model>>>> {
129        let username = self.authentication.username.clone();
130        self.stream(&["user".into(), username, "model".into()], ()).await
131    }
132
133    /// Sends an input event to the user's input endpoint.
134    /// 
135    /// Note that this is the new API which not all clients may support.
136    pub async fn put_input(&self, payload: InputEvent) -> Result<ServerMessage<()>> {
137        let username = self.authentication.username.clone();
138        self.put(&["user".into(), username, "input".into()], payload).await
139    }
140
141    /// Streams input events from the user's input endpoint.
142    /// 
143    /// Note that this is the new API which not all clients may support (in LUNA
144    /// disabling the legacy mode will send events to this endpoint).  If your
145    /// client or library does not support this, you may need to `stream_model`
146    /// and parse `LegacyInputEvent`s from there.
147    pub async fn stream_input(&self) -> Result<impl Stream<Item = Result<ServerMessage<InputEvent>>>> {
148        let username = self.authentication.username.clone();
149        Ok(
150            self.stream(&["user".into(), username, "input".into()], ()).await?
151                .skip(1) // Skip the persisted input (TODO: Should we handle this at the server level via some form of passthrough resources?)
152        )
153    }
154
155    /// Fetches lamp server metrics.
156    pub async fn get_laser_metrics(&self) -> Result<ServerMessage<LaserMetrics>> {
157        self.get(&["metrics", "laser"]).await
158    }
159
160    /// Combines PUT and CREATE. Requires CREATE and WRITE permission.
161    pub async fn post<P>(&self, path: &[impl AsRef<str> + Debug], payload: P) -> Result<ServerMessage<()>>
162    where
163        P: Serialize {
164        self.perform(&Verb::Post, path, payload).await
165    }
166
167    /// Updates the resource at the given path with the given payload. Requires WRITE permission.
168    pub async fn put<P>(&self, path: &[impl AsRef<str> + Debug], payload: P) -> Result<ServerMessage<()>>
169    where
170        P: Serialize {
171        self.perform(&Verb::Put, path, payload).await
172    }
173
174    /// Creates a resource at the given path. Requires CREATE permission.
175    pub async fn create(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
176        self.perform(&Verb::Create, path, ()).await
177    }
178
179    /// Deletes a resource at the given path. Requires DELETE permission.
180    pub async fn delete(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
181        self.perform(&Verb::Delete, path, ()).await
182    }
183
184    /// Creates a directory at the given path. Requires CREATE permission.
185    pub async fn mkdir(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
186        self.perform(&Verb::Mkdir, path, ()).await
187    }
188
189    /// Lists the directory tree at the given path. Requires READ permission.
190    pub async fn list(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<DirectoryTree>> {
191        self.perform(&Verb::List, path, ()).await
192    }
193
194    /// Gets the resource at the given path. Requires READ permission.
195    pub async fn get<R>(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<R>>
196    where
197        R: for<'de> Deserialize<'de> {
198        self.perform(&Verb::Get, path, ()).await
199    }
200
201    /// Links the given source to the given destination path.
202    pub async fn link(&self, src_path: &[impl AsRef<str> + Debug], dest_path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
203        self.perform(&Verb::Link, dest_path, src_path.iter().map(|s| s.as_ref().to_owned()).collect::<Vec<_>>()).await
204    }
205
206    /// Unlinks the given source from the given destination path.
207    pub async fn unlink(&self, src_path: &[impl AsRef<str> + Debug], dest_path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
208        self.perform(&Verb::Unlink, dest_path, src_path.iter().map(|s| s.as_ref().to_owned()).collect::<Vec<_>>()).await
209    }
210
211    /// Stops the given stream. **Should generally not be called manually**,
212    /// since streams will automatically be stopped once dropped.
213    pub async fn stop(&self, request_id: i32, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
214        self.perform_with_id(request_id, &Verb::Stop, path, ()).await
215    }
216
217    /// Performs a single request to the given path with the given payload.
218    #[tracing::instrument(skip(self, payload))]
219    pub async fn perform<P, R>(&self, verb: &Verb, path: &[impl AsRef<str> + Debug], payload: P) -> Result<ServerMessage<R>>
220    where
221        P: Serialize,
222        R: for<'de> Deserialize<'de> {
223        let request_id = self.next_request_id();
224        self.perform_with_id(request_id, verb, path, payload).await
225    }
226
227    /// Performs a single request to the given path with the given request id.
228    #[tracing::instrument(skip(self, payload))]
229    async fn perform_with_id<P, R>(&self, request_id: i32, verb: &Verb, path: &[impl AsRef<str> + Debug], payload: P) -> Result<ServerMessage<R>>
230    where
231        P: Serialize,
232        R: for<'de> Deserialize<'de> {
233        assert_ne!(verb, &Verb::Stream, "Lighthouse::perform may only be used for one-off requests, use Lighthouse::stream for streaming.");
234        self.send_request(request_id, verb, path, payload).await?;
235        let response = self.receive_single(request_id).await?.check()?.decode_payload()?;
236        Ok(response)
237    }
238    
239    /// Performs a STREAM request to the given path with the given payload.
240    /// Automatically sends a STOP once dropped.
241    #[tracing::instrument(skip(self, payload))]
242    pub async fn stream<P, R>(&self, path: &[impl AsRef<str> + Debug], payload: P) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
243    where
244        P: Serialize,
245        R: for<'de> Deserialize<'de> {
246        let request_id = self.next_request_id();
247        let path: Vec<String> = path.into_iter().map(|s| s.as_ref().to_string()).collect();
248        self.send_request(request_id, &Verb::Stream, &path, payload).await?;
249        let stream = self.receive_streaming(request_id).await?;
250        Ok(stream.map(|m| Ok(m?.check()?.decode_payload()?)).guard({
251            // Stop the stream on drop
252            let this = (*self).clone();
253            move || {
254                tokio::spawn(async move {
255                    if let Err(error) = this.stop(request_id, &path).await {
256                        error! { ?path, %error, "Could not STOP stream" };
257                    }
258                });
259            }
260        }))
261    }
262
263    /// Sends a request to the given path with the given payload.
264    async fn send_request<P>(&self, request_id: i32, verb: &Verb, path: &[impl AsRef<str> + Debug], payload: P) -> Result<i32>
265    where
266        P: Serialize {
267        let path = path.into_iter().map(|s| s.as_ref().to_string()).collect();
268        debug! { %request_id, "Sending request" };
269        self.send_message(&ClientMessage {
270            request_id,
271            authentication: self.authentication.clone(),
272            path,
273            meta: HashMap::new(),
274            verb: verb.clone(),
275            payload
276        }).await?;
277        Ok(request_id)
278    }
279
280    /// Sends a generic message to the lighthouse.
281    async fn send_message<P>(&self, message: &ClientMessage<P>) -> Result<()>
282    where
283        P: Serialize {
284        self.send_raw(rmp_serde::to_vec_named(message)?).await
285    }
286
287    /// Receives a single response for the given request id.
288    #[tracing::instrument(skip(self))]
289    async fn receive_single<R>(&self, request_id: i32) -> Result<ServerMessage<R>>
290    where
291        R: for<'de> Deserialize<'de> {
292        let mut rx = self.receive(request_id).await?;
293        rx.next().await.ok_or_else(|| Error::Custom(format!("No response for {}", request_id)))?
294    }
295
296    /// Receives a stream of responses for the given request id.
297    #[tracing::instrument(skip(self))]
298    async fn receive_streaming<R>(&self, request_id: i32) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
299    where
300        R: for<'de> Deserialize<'de> {
301        self.receive(request_id).await
302    }
303
304    async fn receive<R>(&self, request_id: i32) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
305    where
306        R: for<'de> Deserialize<'de> {
307        let rx = {
308            let capacity = 4;
309            let (tx, rx) = {
310                let mut slots = self.slots.lock().await;
311                if let Some(Slot::EarlyMessages(msgs)) = slots.get_mut(&request_id) {
312                    let (mut tx, rx) = mpsc::channel(capacity.min(msgs.len()));
313                    for msg in msgs.drain(..) {
314                        tx.feed(msg).await.map_err(|e| Error::Custom(format!("Could not feed tx with early message: {}", e)))?;
315                    } 
316                    tx.flush().await.map_err(|e| Error::Custom(format!("Could not flush tx with early messages: {}", e)))?;
317                    (tx, rx)
318                } else {
319                    mpsc::channel(capacity)
320                }
321            };
322            self.slots.lock().await.insert(request_id, Slot::WaitForMessages(tx));
323            rx
324        };
325        Ok(rx.map(|s| Ok(s.decode_payload()?)).guard({
326            let slots = self.slots.clone();
327            move || {
328                tokio::spawn(async move {
329                    let mut slots = slots.lock().await;
330                    slots.remove(&request_id);
331                });
332            }
333        }))
334    }
335
336    /// Sends raw bytes to the lighthouse via the WebSocket connection.
337    async fn send_raw(&self, bytes: impl Into<Vec<u8>> + Debug) -> Result<()> {
338        Ok(self.ws_sink.lock().await.send(Message::Binary(bytes.into())).await?)
339    }
340
341    /// Fetches the next request id.
342    fn next_request_id(&self) -> i32 {
343        self.request_id.fetch_add(1, Ordering::Relaxed)
344    }
345
346    /// Fetches the credentials used to authenticate with the lighthouse.
347    pub fn authentication(&self) -> &Authentication {
348        &self.authentication
349    }
350
351    /// Closes the WebSocket connection gracefully with a close message. While
352    /// the server will usually also handle abruptly closed connections
353    /// properly, it is recommended to always close the [``Lighthouse``].
354    pub async fn close(&self) -> Result<()> {
355        Ok(self.ws_sink.lock().await.close().await?)
356    }
357}
358
359// For some reason `#[derive(Clone)]` adds the trait bound `S: Clone`, despite
360// not actually being needed since the WebSocket sink is already wrapped in an
361// `Arc`, therefore we implement `Clone` manually.
362
363impl<S> Clone for Lighthouse<S> {
364    fn clone(&self) -> Self {
365        Self {
366            ws_sink: self.ws_sink.clone(),
367            slots: self.slots.clone(),
368            authentication: self.authentication.clone(),
369            request_id: self.request_id.clone(),
370        }
371    }
372}