1use std::{collections::HashMap, fmt::Debug, sync::{atomic::{AtomicI32, Ordering}, Arc}};
2
3use async_tungstenite::tungstenite::{Message, self};
4use futures::{prelude::*, channel::mpsc::{Sender, self}, stream::{SplitSink, SplitStream}, lock::Mutex};
5use lighthouse_protocol::{Authentication, ClientMessage, DirectoryTree, Frame, InputEvent, LaserMetrics, Model, ServerMessage, Value, Verb};
6use serde::{Deserialize, Serialize};
7use stream_guard::GuardStreamExt;
8use tracing::{warn, error, debug, info};
9use crate::{Check, Error, Result, Spawner};
10
11pub struct Lighthouse<S> {
13 ws_sink: Arc<Mutex<SplitSink<S, Message>>>,
15 slots: Arc<Mutex<HashMap<i32, Slot<ServerMessage<Value>>>>>,
17 authentication: Authentication,
19 request_id: Arc<AtomicI32>,
21}
22
23enum Slot<M> {
26 EarlyMessages(Vec<M>),
31 WaitForMessages(Sender<M>),
37}
38
39impl<S> Lighthouse<S>
40 where S: Stream<Item = tungstenite::Result<Message>>
41 + Sink<Message, Error = tungstenite::Error>
42 + Send
43 + 'static {
44 pub fn new<W>(web_socket: S, authentication: Authentication) -> Result<Self> where W: Spawner {
47 let (ws_sink, ws_stream) = web_socket.split();
48 let slots = Arc::new(Mutex::new(HashMap::new()));
49 let lh = Self {
50 ws_sink: Arc::new(Mutex::new(ws_sink)),
51 slots: slots.clone(),
52 authentication,
53 request_id: Arc::new(AtomicI32::new(0)),
54 };
55 W::spawn(Self::run_receive_loop(ws_stream, slots));
56 Ok(lh)
57 }
58
59 #[tracing::instrument(skip(ws_stream, slots))]
61 async fn run_receive_loop(mut ws_stream: SplitStream<S>, slots: Arc<Mutex<HashMap<i32, Slot<ServerMessage<Value>>>>>) {
62 loop {
63 match Self::receive_message_from(&mut ws_stream).await {
64 Ok(msg) => {
65 let mut slots = slots.lock().await;
66 if let Some(request_id) = msg.request_id {
67 if let Some(slot) = slots.get_mut(&request_id) {
68 match slot {
69 Slot::EarlyMessages(msgs) => msgs.push(msg),
70 Slot::WaitForMessages(tx) => {
71 if let Err(e) = tx.send(msg).await {
72 if e.is_disconnected() {
73 info!("Receiver for request id {} disconnected, removing the sender...", request_id);
74 slots.remove(&request_id);
75 } else {
76 warn!("Could not send message for request id {} via channel: {:?}", request_id, e);
77 }
78 }
79 }
80 }
81 } else {
82 slots.insert(request_id, Slot::EarlyMessages(vec![msg]));
83 }
84 } else {
85 warn!("Got message without request id from server: {:?}", msg);
86 }
87 },
88 Err(Error::NoNextMessage) => {
89 info!("No next message available, closing receive loop");
90 break
91 },
92 Err(e) => error!("Bad message: {:?}", e),
93 }
94 }
95 }
96
97 #[tracing::instrument(skip(ws_stream))]
99 async fn receive_message_from<P>(ws_stream: &mut SplitStream<S>) -> Result<ServerMessage<P>>
100 where
101 P: for<'de> Deserialize<'de> {
102 let bytes = Self::receive_raw_from(ws_stream).await?;
103 let message = rmp_serde::from_slice(&bytes)?;
104 Ok(message)
105 }
106
107 #[tracing::instrument(skip(ws_stream))]
109 async fn receive_raw_from(ws_stream: &mut SplitStream<S>) -> Result<Vec<u8>> {
110 loop {
111 let message = ws_stream.next().await.ok_or_else(|| Error::NoNextMessage)??;
112 match message {
113 Message::Binary(bytes) => break Ok(bytes),
114 Message::Ping(_) => {}, Message::Close(_) => break Err(Error::ConnectionClosed),
116 _ => warn!("Got non-binary message: {:?}", message),
117 }
118 }
119 }
120
121 pub async fn put_model(&self, frame: Frame) -> Result<ServerMessage<()>> {
123 let username = self.authentication.username.clone();
124 self.put(&["user".into(), username, "model".into()], Model::Frame(frame)).await
125 }
126
127 pub async fn stream_model(&self) -> Result<impl Stream<Item = Result<ServerMessage<Model>>>> {
129 let username = self.authentication.username.clone();
130 self.stream(&["user".into(), username, "model".into()], ()).await
131 }
132
133 pub async fn put_input(&self, payload: InputEvent) -> Result<ServerMessage<()>> {
137 let username = self.authentication.username.clone();
138 self.put(&["user".into(), username, "input".into()], payload).await
139 }
140
141 pub async fn stream_input(&self) -> Result<impl Stream<Item = Result<ServerMessage<InputEvent>>>> {
148 let username = self.authentication.username.clone();
149 Ok(
150 self.stream(&["user".into(), username, "input".into()], ()).await?
151 .skip(1) )
153 }
154
155 pub async fn get_laser_metrics(&self) -> Result<ServerMessage<LaserMetrics>> {
157 self.get(&["metrics", "laser"]).await
158 }
159
160 pub async fn post<P>(&self, path: &[impl AsRef<str> + Debug], payload: P) -> Result<ServerMessage<()>>
162 where
163 P: Serialize {
164 self.perform(&Verb::Post, path, payload).await
165 }
166
167 pub async fn put<P>(&self, path: &[impl AsRef<str> + Debug], payload: P) -> Result<ServerMessage<()>>
169 where
170 P: Serialize {
171 self.perform(&Verb::Put, path, payload).await
172 }
173
174 pub async fn create(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
176 self.perform(&Verb::Create, path, ()).await
177 }
178
179 pub async fn delete(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
181 self.perform(&Verb::Delete, path, ()).await
182 }
183
184 pub async fn mkdir(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
186 self.perform(&Verb::Mkdir, path, ()).await
187 }
188
189 pub async fn list(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<DirectoryTree>> {
191 self.perform(&Verb::List, path, ()).await
192 }
193
194 pub async fn get<R>(&self, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<R>>
196 where
197 R: for<'de> Deserialize<'de> {
198 self.perform(&Verb::Get, path, ()).await
199 }
200
201 pub async fn link(&self, src_path: &[impl AsRef<str> + Debug], dest_path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
203 self.perform(&Verb::Link, dest_path, src_path.iter().map(|s| s.as_ref().to_owned()).collect::<Vec<_>>()).await
204 }
205
206 pub async fn unlink(&self, src_path: &[impl AsRef<str> + Debug], dest_path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
208 self.perform(&Verb::Unlink, dest_path, src_path.iter().map(|s| s.as_ref().to_owned()).collect::<Vec<_>>()).await
209 }
210
211 pub async fn stop(&self, request_id: i32, path: &[impl AsRef<str> + Debug]) -> Result<ServerMessage<()>> {
214 self.perform_with_id(request_id, &Verb::Stop, path, ()).await
215 }
216
217 #[tracing::instrument(skip(self, payload))]
219 pub async fn perform<P, R>(&self, verb: &Verb, path: &[impl AsRef<str> + Debug], payload: P) -> Result<ServerMessage<R>>
220 where
221 P: Serialize,
222 R: for<'de> Deserialize<'de> {
223 let request_id = self.next_request_id();
224 self.perform_with_id(request_id, verb, path, payload).await
225 }
226
227 #[tracing::instrument(skip(self, payload))]
229 async fn perform_with_id<P, R>(&self, request_id: i32, verb: &Verb, path: &[impl AsRef<str> + Debug], payload: P) -> Result<ServerMessage<R>>
230 where
231 P: Serialize,
232 R: for<'de> Deserialize<'de> {
233 assert_ne!(verb, &Verb::Stream, "Lighthouse::perform may only be used for one-off requests, use Lighthouse::stream for streaming.");
234 self.send_request(request_id, verb, path, payload).await?;
235 let response = self.receive_single(request_id).await?.check()?.decode_payload()?;
236 Ok(response)
237 }
238
239 #[tracing::instrument(skip(self, payload))]
242 pub async fn stream<P, R>(&self, path: &[impl AsRef<str> + Debug], payload: P) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
243 where
244 P: Serialize,
245 R: for<'de> Deserialize<'de> {
246 let request_id = self.next_request_id();
247 let path: Vec<String> = path.into_iter().map(|s| s.as_ref().to_string()).collect();
248 self.send_request(request_id, &Verb::Stream, &path, payload).await?;
249 let stream = self.receive_streaming(request_id).await?;
250 Ok(stream.map(|m| Ok(m?.check()?.decode_payload()?)).guard({
251 let this = (*self).clone();
253 move || {
254 tokio::spawn(async move {
255 if let Err(error) = this.stop(request_id, &path).await {
256 error! { ?path, %error, "Could not STOP stream" };
257 }
258 });
259 }
260 }))
261 }
262
263 async fn send_request<P>(&self, request_id: i32, verb: &Verb, path: &[impl AsRef<str> + Debug], payload: P) -> Result<i32>
265 where
266 P: Serialize {
267 let path = path.into_iter().map(|s| s.as_ref().to_string()).collect();
268 debug! { %request_id, "Sending request" };
269 self.send_message(&ClientMessage {
270 request_id,
271 authentication: self.authentication.clone(),
272 path,
273 meta: HashMap::new(),
274 verb: verb.clone(),
275 payload
276 }).await?;
277 Ok(request_id)
278 }
279
280 async fn send_message<P>(&self, message: &ClientMessage<P>) -> Result<()>
282 where
283 P: Serialize {
284 self.send_raw(rmp_serde::to_vec_named(message)?).await
285 }
286
287 #[tracing::instrument(skip(self))]
289 async fn receive_single<R>(&self, request_id: i32) -> Result<ServerMessage<R>>
290 where
291 R: for<'de> Deserialize<'de> {
292 let mut rx = self.receive(request_id).await?;
293 rx.next().await.ok_or_else(|| Error::Custom(format!("No response for {}", request_id)))?
294 }
295
296 #[tracing::instrument(skip(self))]
298 async fn receive_streaming<R>(&self, request_id: i32) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
299 where
300 R: for<'de> Deserialize<'de> {
301 self.receive(request_id).await
302 }
303
304 async fn receive<R>(&self, request_id: i32) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
305 where
306 R: for<'de> Deserialize<'de> {
307 let rx = {
308 let capacity = 4;
309 let (tx, rx) = {
310 let mut slots = self.slots.lock().await;
311 if let Some(Slot::EarlyMessages(msgs)) = slots.get_mut(&request_id) {
312 let (mut tx, rx) = mpsc::channel(capacity.min(msgs.len()));
313 for msg in msgs.drain(..) {
314 tx.feed(msg).await.map_err(|e| Error::Custom(format!("Could not feed tx with early message: {}", e)))?;
315 }
316 tx.flush().await.map_err(|e| Error::Custom(format!("Could not flush tx with early messages: {}", e)))?;
317 (tx, rx)
318 } else {
319 mpsc::channel(capacity)
320 }
321 };
322 self.slots.lock().await.insert(request_id, Slot::WaitForMessages(tx));
323 rx
324 };
325 Ok(rx.map(|s| Ok(s.decode_payload()?)).guard({
326 let slots = self.slots.clone();
327 move || {
328 tokio::spawn(async move {
329 let mut slots = slots.lock().await;
330 slots.remove(&request_id);
331 });
332 }
333 }))
334 }
335
336 async fn send_raw(&self, bytes: impl Into<Vec<u8>> + Debug) -> Result<()> {
338 Ok(self.ws_sink.lock().await.send(Message::Binary(bytes.into())).await?)
339 }
340
341 fn next_request_id(&self) -> i32 {
343 self.request_id.fetch_add(1, Ordering::Relaxed)
344 }
345
346 pub fn authentication(&self) -> &Authentication {
348 &self.authentication
349 }
350
351 pub async fn close(&self) -> Result<()> {
355 Ok(self.ws_sink.lock().await.close().await?)
356 }
357}
358
359impl<S> Clone for Lighthouse<S> {
364 fn clone(&self) -> Self {
365 Self {
366 ws_sink: self.ws_sink.clone(),
367 slots: self.slots.clone(),
368 authentication: self.authentication.clone(),
369 request_id: self.request_id.clone(),
370 }
371 }
372}