1use std::{
51 collections::HashMap,
52 io::{Cursor, Write},
53 mem::size_of,
54 net::SocketAddr,
55 sync::{
56 atomic::{AtomicUsize, Ordering},
57 Arc,
58 },
59};
60
61use anyhow::anyhow;
62use futures_util::{stream::SplitSink, SinkExt, StreamExt, TryFutureExt};
63use serde::{Deserialize, Serialize};
64use tokio::sync::{mpsc, RwLock};
65use tokio_stream::wrappers::ReceiverStream;
66use uuid::Uuid;
67use warp::{
68 ws::{Message, WebSocket},
69 Filter,
70};
71
/// One channel entry inside an outgoing `advertise` message.
///
/// Field names are serialized in camelCase (`schema_name` -> `schemaName`, etc.).
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ServerChannelMessage {
    /// Server-assigned channel id; clients send it back as `channelId` when subscribing.
    id: usize,
    /// Topic name of the channel.
    topic: String,
    /// Message encoding string, forwarded as given to `FoxgloveWebSocket::publish`.
    encoding: String,
    /// Schema name, forwarded as given to `FoxgloveWebSocket::publish`.
    schema_name: String,
    /// Schema text, forwarded as given to `FoxgloveWebSocket::publish`.
    schema: String,
    /// Schema encoding string, forwarded as given to `FoxgloveWebSocket::publish`.
    schema_encoding: String,
}
82
/// JSON text messages sent from the server to clients.
///
/// Internally tagged: the camelCase variant name goes in an `"op"` field
/// (`"serverInfo"` / `"advertise"`), next to the variant's camelCase fields.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "op", rename_all = "camelCase")]
enum ServerMessage {
    /// Description of the server; `session_id` identifies the individual connection.
    #[serde(rename_all = "camelCase")]
    ServerInfo {
        name: String,
        capabilities: Vec<String>,
        supported_encodings: Vec<String>,
        metadata: HashMap<String, String>,
        session_id: String,
    },
    /// Announces one or more channels that clients may subscribe to.
    #[serde(rename_all = "camelCase")]
    Advertise { channels: Vec<ServerChannelMessage> },
}
97
/// Subscription id chosen by the client; it is written into the header of every binary
/// message frame sent for that subscription.
type ClientChannelId = u32;
99
/// One entry of a client's `subscribe` request.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ClientSubscriptionMessage {
    /// Client-chosen subscription id.
    id: ClientChannelId,
    /// Server channel to subscribe to (JSON key `channelId`).
    channel_id: usize,
}
106
/// JSON text messages accepted from clients.
///
/// Internally tagged by an `"op"` field (`"subscribe"` / `"unsubscribe"`); variant fields use
/// camelCase keys (`subscriptions`, `subscriptionIds`).
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "op", rename_all = "camelCase")]
enum ClientMessage {
    /// Subscribe to one or more server channels.
    #[serde(rename_all = "camelCase")]
    Subscribe {
        subscriptions: Vec<ClientSubscriptionMessage>,
    },
    /// Cancel subscriptions by their client-chosen ids.
    #[serde(rename_all = "camelCase")]
    Unsubscribe {
        subscription_ids: Vec<ClientChannelId>,
    },
}
119
/// Server-side state for one connected client.
#[derive(Debug)]
struct Client {
    /// Connection id; the same value is reported to the client as its session id.
    id: Uuid,
    /// Sender side of this client's bounded outgoing message queue.
    tx: mpsc::Sender<Message>,
    /// Active subscriptions: server channel id -> client-chosen subscription id.
    subscriptions: HashMap<usize, ClientChannelId>,
}
126
/// All connected clients keyed by connection id, behind an async read/write lock.
type Clients = RwLock<HashMap<Uuid, Client>>;
128
/// Client registry shared, via `Arc`, by the server and every `Channel` it hands out.
#[derive(Debug, Default)]
struct ClientState {
    clients: Clients,
}
133
/// A message payload together with its timestamp.
///
/// Encoded once per subscriber by `build_message`; latching channels also keep the most
/// recent one as their pinned message.
#[derive(Debug)]
struct MessageData {
    /// Timestamp copied verbatim into the frame header. NOTE(review): nanoseconds per the
    /// `_ns` suffix; the unit is not checked here.
    timestamp_ns: u64,
    /// Raw payload bytes, appended unchanged after the frame header.
    data: Vec<u8>,
}
139
140impl MessageData {
141 fn build_message(&self, subscription_id: u32) -> anyhow::Result<Message> {
142 let mut buffer =
143 vec![0; size_of::<u8>() + size_of::<u32>() + size_of::<u64>() + self.data.len()];
144 {
145 let mut w = Cursor::new(&mut buffer);
146 w.write(&(1 as u8).to_le_bytes())?;
148 w.write(&subscription_id.to_le_bytes())?;
150 w.write(&self.timestamp_ns.to_le_bytes())?;
151 w.write(&self.data)?;
152 }
153 Ok(Message::binary(buffer))
154 }
155}
156
/// Handle for sending messages on one channel; created by `FoxgloveWebSocket::publish`.
#[derive(Debug)]
pub struct Channel {
    /// Channel id, matched against each client's subscriptions.
    id: usize,
    /// Topic name, used in log output.
    topic: String,
    /// When true, the last sent message is kept and replayed to clients when they subscribe.
    is_latching: bool,

    /// Client registry shared with the server.
    clients: Arc<ClientState>,
    /// Last sent message (stored only for latching channels); shared with the server's
    /// `ChannelMetadata` entry for this channel.
    pinned_message: Arc<RwLock<Option<MessageData>>>,
}
167
168impl Channel {
169 pub async fn send(&self, timestamp_ns: u64, data: &[u8]) -> anyhow::Result<()> {
176 let message_data = MessageData {
177 timestamp_ns,
178 data: data.to_vec(),
179 };
180 for client in self.clients.clients.read().await.values() {
181 if let Some(subscription_id) = client.subscriptions.get(&self.id) {
182 log::debug!(
183 "Send message on {} to client {} ({}).",
184 self.topic,
185 client.id,
186 client.tx.capacity()
187 );
188 client
189 .tx
190 .try_send(message_data.build_message(*subscription_id)?)?;
191 }
192 }
193
194 if self.is_latching {
195 *self.pinned_message.write().await = Some(message_data);
196 }
197
198 Ok(())
199 }
200}
201
/// Server-side registry entry for a published channel.
#[derive(Debug)]
struct ChannelMetadata {
    /// Advertisement sent to clients for this channel.
    channel_message: ServerChannelMessage,
    /// Shared with the channel's `Channel` handle; read when a client subscribes so a latched
    /// message can be replayed.
    pinned_message: Arc<RwLock<Option<MessageData>>>,
}
207
/// Published channels keyed by channel id, behind an async read/write lock.
type Channels = RwLock<HashMap<usize, ChannelMetadata>>;
209
/// Channel registry plus the counter used to number new channels.
#[derive(Debug, Default)]
struct ChannelState {
    /// Next channel id to hand out; starts at 0 (via `Default`) and is advanced with
    /// `fetch_add`.
    next_channel_id: AtomicUsize,
    channels: Channels,
}
215
/// A minimal Foxglove WebSocket server (`foxglove.websocket.v1` subprotocol).
///
/// Cloning is cheap, and clones share the same client and channel registries (both are held
/// behind `Arc`). Call `serve` to accept connections and `publish` to create channels.
#[derive(Clone, Debug, Default)]
pub struct FoxgloveWebSocket {
    clients: Arc<ClientState>,
    channels: Arc<ChannelState>,
}
222
223async fn initialize_client(
224 user_ws_tx: &mut SplitSink<WebSocket, Message>,
225 channels: &Channels,
226 client_id: &Uuid,
227) -> anyhow::Result<()> {
228 user_ws_tx
229 .send(Message::text(
230 serde_json::to_string(&ServerMessage::ServerInfo {
231 name: "test_server".to_string(),
232 capabilities: vec![],
233 supported_encodings: vec![],
234 metadata: HashMap::default(),
235 session_id: client_id.as_hyphenated().to_string(),
236 })
237 .unwrap(),
238 ))
239 .await?;
240
241 let channel_messages = channels
242 .read()
243 .await
244 .values()
245 .map(|metadata| metadata.channel_message.clone())
246 .collect();
247
248 user_ws_tx
249 .send(Message::text(
250 serde_json::to_string(&ServerMessage::Advertise {
251 channels: channel_messages,
252 })
253 .unwrap(),
254 ))
255 .await?;
256
257 Ok(())
258}
259
260async fn handle_client_msg(
261 tx: &mpsc::Sender<Message>,
262 clients: &Arc<ClientState>,
263 channels: &Arc<ChannelState>,
264 client_id: &Uuid,
265 ws_msg: &Message,
266) -> anyhow::Result<()> {
267 let msg = if ws_msg.is_text() {
268 serde_json::from_str::<ClientMessage>(ws_msg.to_str().unwrap())?
269 } else if ws_msg.is_binary() {
270 return Err(anyhow!("Got binary message: unhandled at the moment."));
271 } else if ws_msg.is_close() {
272 return Ok(());
275 } else {
276 return Err(anyhow!(
277 "Got strage message, neither text nor binary: unhandled at the moment. {:?}",
278 ws_msg
279 ));
280 };
281
282 let mut clients = clients.clients.write().await;
283
284 let channels = channels.channels.read().await;
285
286 match msg {
287 ClientMessage::Subscribe { ref subscriptions } => {
288 let client = clients
289 .get_mut(client_id)
290 .ok_or(anyhow!("Client gone from client map?"))?;
291 for ClientSubscriptionMessage { id, channel_id } in subscriptions {
292 log::debug!(
293 "Client {} subscribed to {} with its own {}.",
294 client_id,
295 channel_id,
296 id
297 );
298
299 if let Some(ref channel_metadata) = channels.get(channel_id) {
300 client.subscriptions.insert(*channel_id, *id);
301 if let Some(message_data) =
302 channel_metadata.pinned_message.read().await.as_ref()
303 {
304 log::debug!("Sending latched: client {}.", client_id);
305 tx.send(message_data.build_message(*id)?).await?;
306 }
307 }
308 }
309 }
310 ClientMessage::Unsubscribe {
311 ref subscription_ids,
312 } => {
313 let client = clients
314 .get_mut(client_id)
315 .ok_or(anyhow!("Client gone from client map?"))?;
316 log::debug!("Client {} unsubscribes {:?}.", client_id, subscription_ids);
317 client
318 .subscriptions
319 .retain(|_, subscription_id| !subscription_ids.contains(subscription_id));
320 }
321 }
322 Ok(())
323}
324
325async fn client_connected(ws: WebSocket, clients: Arc<ClientState>, channels: Arc<ChannelState>) {
326 let (mut user_ws_tx, mut user_ws_rx) = ws.split();
328
329 let client_id = Uuid::new_v4();
330 log::info!("Client {} connected.", client_id);
331
332 if let Err(err) = initialize_client(&mut user_ws_tx, &channels.channels, &client_id).await {
334 log::error!("Failed to initialize client: {}.", err);
335 return;
336 }
337
338 let (tx, rx) = mpsc::channel(10);
340 let mut rx = ReceiverStream::new(rx);
341
342 tokio::task::spawn(async move {
344 while let Some(message) = rx.next().await {
345 user_ws_tx
346 .send(message)
347 .unwrap_or_else(|e| {
348 log::error!("Failed websocket send: {}.", e);
349 })
350 .await;
351 }
352 });
353
354 clients.clients.write().await.insert(
356 client_id,
357 Client {
358 id: client_id,
359 tx: tx.clone(),
360 subscriptions: HashMap::new(),
361 },
362 );
363
364 while let Some(result) = user_ws_rx.next().await {
365 let ws_msg = match result {
366 Ok(ws_msg) => ws_msg,
367 Err(err) => {
368 log::error!("Failed receiving, websocket error: {}.", err);
369 break;
370 }
371 };
372 if let Err(err) = handle_client_msg(&tx, &clients, &channels, &client_id, &ws_msg).await {
373 log::error!("Failed handling client message: {}.", err);
374 break;
375 }
376 }
377
378 log::debug!("Client {} closed.", client_id);
379 clients.clients.write().await.remove(&client_id);
380}
381
382impl FoxgloveWebSocket {
383 pub fn new() -> Self {
385 FoxgloveWebSocket::default()
386 }
387
388 pub async fn serve(&self, addr: impl Into<SocketAddr>) {
394 let clients = self.clients.clone();
395 let clients = warp::any().map(move || clients.clone());
396 let channels = self.channels.clone();
397 let channels = warp::any().map(move || channels.clone());
398 let foxglove_ws = warp::path::end().and(
399 warp::ws()
400 .and(clients)
401 .and(channels)
402 .map(|ws: warp::ws::Ws, clients, channels| {
403 ws.on_upgrade(move |socket| client_connected(socket, clients, channels))
404 })
405 .map(|reply| {
406 warp::reply::with_header(
407 reply,
408 "Sec-WebSocket-Protocol",
409 "foxglove.websocket.v1",
410 )
411 }),
412 );
413 warp::serve(foxglove_ws).run(addr).await;
414 }
415
416 pub async fn publish(
432 &self,
433 topic: String,
434 encoding: String,
435 schema_name: String,
436 schema: String,
437 schema_encoding: String,
438 is_latching: bool,
439 ) -> anyhow::Result<Channel> {
440 let channel_id = self
441 .channels
442 .next_channel_id
443 .fetch_add(1, Ordering::Relaxed);
444 log::info!("Publish new channel {}: {}.", topic, channel_id);
445 let channel = Channel {
446 id: channel_id,
447 topic: topic.clone(),
448 is_latching,
449 clients: self.clients.clone(),
450 pinned_message: Arc::default(),
451 };
452 let channel_message = ServerChannelMessage {
453 id: channel_id,
454 topic,
455 encoding,
456 schema_name,
457 schema,
458 schema_encoding,
459 };
460
461 for client in self.clients.clients.read().await.values() {
463 client
464 .tx
465 .send(Message::text(
466 serde_json::to_string(&ServerMessage::Advertise {
467 channels: vec![channel_message.clone()],
468 })
469 .unwrap(),
470 ))
471 .await?;
472 }
473
474 self.channels.channels.write().await.insert(
475 channel_id,
476 ChannelMetadata {
477 channel_message,
478 pinned_message: channel.pinned_message.clone(),
479 },
480 );
481
482 Ok(channel)
483 }
484}