1use crate::{Bridge, BridgeConfig, BridgeError, BridgeEvent, Result};
7use async_trait::async_trait;
8use clasp_core::{Message, SetMessage, Value};
9use futures::{SinkExt, StreamExt};
10use parking_lot::Mutex;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::{TcpListener, TcpStream};
16use tokio::sync::mpsc;
17use tokio_tungstenite::{
18 accept_async, connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream,
19 WebSocketStream,
20};
21use tracing::{debug, error, info, warn};
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
25#[serde(rename_all = "lowercase")]
26pub enum WsMessageFormat {
27 #[default]
29 Json,
30 MsgPack,
32 Raw,
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum WsMode {
40 #[default]
42 Client,
43 Server,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct WebSocketBridgeConfig {
50 #[serde(default)]
52 pub mode: WsMode,
53 pub url: String,
55 #[serde(default)]
57 pub path: Option<String>,
58 #[serde(default)]
60 pub format: WsMessageFormat,
61 #[serde(default = "default_ping_interval")]
63 pub ping_interval_secs: u32,
64 #[serde(default = "default_true")]
66 pub auto_reconnect: bool,
67 #[serde(default = "default_reconnect_delay")]
69 pub reconnect_delay_secs: u32,
70 #[serde(default)]
72 pub headers: HashMap<String, String>,
73 #[serde(default = "default_namespace")]
75 pub namespace: String,
76}
77
/// Serde default for boolean options that are enabled unless configured otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
81
/// Serde default for `ping_interval_secs`: one keep-alive ping every 30 seconds.
fn default_ping_interval() -> u32 {
    const PING_INTERVAL_SECS: u32 = 30;
    PING_INTERVAL_SECS
}
85
/// Serde default for `reconnect_delay_secs`: wait 5 seconds between reconnect attempts.
fn default_reconnect_delay() -> u32 {
    const RECONNECT_DELAY_SECS: u32 = 5;
    RECONNECT_DELAY_SECS
}
89
/// Serde default for `namespace`: the `/ws` address prefix.
fn default_namespace() -> String {
    String::from("/ws")
}
93
94impl Default for WebSocketBridgeConfig {
95 fn default() -> Self {
96 Self {
97 mode: WsMode::Client,
98 url: "ws://localhost:8080".to_string(),
99 path: None,
100 format: WsMessageFormat::Json,
101 ping_interval_secs: 30,
102 auto_reconnect: true,
103 reconnect_delay_secs: 5,
104 headers: HashMap::new(),
105 namespace: "/ws".to_string(),
106 }
107 }
108}
109
110#[allow(dead_code)]
112type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
113
114#[allow(dead_code)]
116type WsServerStream = WebSocketStream<TcpStream>;
117
118#[allow(dead_code)]
120type WsSink = futures::stream::SplitSink<WsServerStream, WsMessage>;
121
122pub struct WebSocketBridge {
124 config: BridgeConfig,
125 ws_config: WebSocketBridgeConfig,
126 running: Arc<Mutex<bool>>,
127 send_tx: Option<mpsc::Sender<WsMessage>>,
128 shutdown_tx: Option<mpsc::Sender<()>>,
129}
130
131impl WebSocketBridge {
132 pub fn new(ws_config: WebSocketBridgeConfig) -> Self {
134 let config = BridgeConfig {
135 name: "WebSocket Bridge".to_string(),
136 protocol: "websocket".to_string(),
137 bidirectional: true,
138 ..Default::default()
139 };
140
141 Self {
142 config,
143 ws_config,
144 running: Arc::new(Mutex::new(false)),
145 send_tx: None,
146 shutdown_tx: None,
147 }
148 }
149
150 fn parse_message(msg: &WsMessage, format: WsMessageFormat, prefix: &str) -> Option<Message> {
152 match msg {
153 WsMessage::Text(text) => match format {
154 WsMessageFormat::Json | WsMessageFormat::Raw => {
155 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
157 let address = json
158 .get("address")
159 .and_then(|v| v.as_str())
160 .map(|s| s.to_string())
161 .unwrap_or_else(|| format!("{}/message", prefix));
162
163 let value = json
164 .get("value")
165 .map(|v| Self::json_to_value(v.clone()))
166 .or_else(|| json.get("data").map(|v| Self::json_to_value(v.clone())))
167 .unwrap_or_else(|| Self::json_to_value(json));
168
169 Some(Message::Set(SetMessage {
170 address,
171 value,
172 revision: None,
173 lock: false,
174 unlock: false,
175 }))
176 } else {
177 Some(Message::Set(SetMessage {
179 address: format!("{}/text", prefix),
180 value: Value::String(text.clone()),
181 revision: None,
182 lock: false,
183 unlock: false,
184 }))
185 }
186 }
187 WsMessageFormat::MsgPack => {
188 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
191 if let (Some(addr), Some(val)) = (
193 json.get("address").and_then(|a| a.as_str()),
194 json.get("value"),
195 ) {
196 Some(Message::Set(SetMessage {
197 address: addr.to_string(),
198 value: Self::json_to_value(val.clone()),
199 revision: None,
200 lock: false,
201 unlock: false,
202 }))
203 } else {
204 Some(Message::Set(SetMessage {
206 address: format!("{}/text", prefix),
207 value: Value::String(text.clone()),
208 revision: None,
209 lock: false,
210 unlock: false,
211 }))
212 }
213 } else {
214 Some(Message::Set(SetMessage {
216 address: format!("{}/text", prefix),
217 value: Value::String(text.clone()),
218 revision: None,
219 lock: false,
220 unlock: false,
221 }))
222 }
223 }
224 },
225 WsMessage::Binary(data) => match format {
226 WsMessageFormat::MsgPack => {
227 if let Ok((msg, _)) = clasp_core::codec::decode(data) {
229 Some(msg)
230 } else {
231 Some(Message::Set(SetMessage {
233 address: format!("{}/binary", prefix),
234 value: Value::Bytes(data.clone()),
235 revision: None,
236 lock: false,
237 unlock: false,
238 }))
239 }
240 }
241 WsMessageFormat::Raw | WsMessageFormat::Json => Some(Message::Set(SetMessage {
242 address: format!("{}/binary", prefix),
243 value: Value::Bytes(data.clone()),
244 revision: None,
245 lock: false,
246 unlock: false,
247 })),
248 },
249 _ => None,
250 }
251 }
252
253 fn json_to_value(json: serde_json::Value) -> Value {
255 match json {
256 serde_json::Value::Null => Value::Null,
257 serde_json::Value::Bool(b) => Value::Bool(b),
258 serde_json::Value::Number(n) => {
259 if let Some(i) = n.as_i64() {
260 Value::Int(i)
261 } else if let Some(f) = n.as_f64() {
262 Value::Float(f)
263 } else {
264 Value::Null
265 }
266 }
267 serde_json::Value::String(s) => Value::String(s),
268 serde_json::Value::Array(arr) => {
269 Value::Array(arr.into_iter().map(Self::json_to_value).collect())
270 }
271 serde_json::Value::Object(obj) => {
272 let map: HashMap<String, Value> = obj
273 .into_iter()
274 .map(|(k, v)| (k, Self::json_to_value(v)))
275 .collect();
276 Value::Map(map)
277 }
278 }
279 }
280
281 fn message_to_ws(msg: &Message, format: WsMessageFormat) -> Option<WsMessage> {
283 let (address, value) = match msg {
284 Message::Set(set) => (Some(&set.address), Some(&set.value)),
285 Message::Publish(pub_msg) => (Some(&pub_msg.address), pub_msg.value.as_ref()),
286 _ => return None,
287 };
288
289 match format {
290 WsMessageFormat::Json => {
291 let json = serde_json::json!({
292 "address": address,
293 "value": value,
294 });
295 Some(WsMessage::Text(json.to_string()))
296 }
297 WsMessageFormat::MsgPack => {
298 if let Ok(encoded) = clasp_core::codec::encode(msg) {
299 Some(WsMessage::Binary(encoded.to_vec()))
300 } else {
301 None
302 }
303 }
304 WsMessageFormat::Raw => {
305 if let Some(val) = value {
306 match val {
307 Value::String(s) => Some(WsMessage::Text(s.clone())),
308 Value::Bytes(b) => Some(WsMessage::Binary(b.clone())),
309 _ => {
310 let json = serde_json::to_string(val).ok()?;
311 Some(WsMessage::Text(json))
312 }
313 }
314 } else {
315 None
316 }
317 }
318 }
319 }
320
321 #[allow(clippy::too_many_arguments)]
323 async fn run_client(
324 url: String,
325 format: WsMessageFormat,
326 namespace: String,
327 auto_reconnect: bool,
328 reconnect_delay: u32,
329 ping_interval_secs: u32,
330 event_tx: mpsc::Sender<BridgeEvent>,
331 mut send_rx: mpsc::Receiver<WsMessage>,
332 mut shutdown_rx: mpsc::Receiver<()>,
333 running: Arc<Mutex<bool>>,
334 ) {
335 loop {
336 info!("WebSocket connecting to {}", url);
337
338 match connect_async(&url).await {
339 Ok((ws_stream, _)) => {
340 info!("WebSocket connected");
341 *running.lock() = true;
342 let _ = event_tx.send(BridgeEvent::Connected).await;
343
344 let (mut write, mut read) = ws_stream.split();
345
346 let ping_duration = if ping_interval_secs > 0 {
348 Some(std::time::Duration::from_secs(ping_interval_secs as u64))
349 } else {
350 None
351 };
352 let mut ping_interval = ping_duration.map(tokio::time::interval);
353 let mut awaiting_pong = false;
354
355 loop {
356 tokio::select! {
357 msg = read.next() => {
359 match msg {
360 Some(Ok(ws_msg)) => {
361 match &ws_msg {
362 WsMessage::Pong(_) => {
363 awaiting_pong = false;
364 debug!("Received pong");
365 }
366 WsMessage::Ping(data) => {
367 if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
369 error!("Failed to send pong: {}", e);
370 break;
371 }
372 }
373 _ => {
374 if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
375 let _ = event_tx.send(BridgeEvent::ToClasp(Box::new(clasp_msg))).await;
376 }
377 }
378 }
379 }
380 Some(Err(e)) => {
381 error!("WebSocket error: {}", e);
382 break;
383 }
384 None => {
385 warn!("WebSocket connection closed");
386 break;
387 }
388 }
389 }
390 msg = send_rx.recv() => {
392 if let Some(ws_msg) = msg {
393 if let Err(e) = write.send(ws_msg).await {
394 error!("WebSocket send error: {}", e);
395 break;
396 }
397 }
398 }
399 _ = async {
401 if let Some(ref mut interval) = ping_interval {
402 interval.tick().await
403 } else {
404 std::future::pending::<tokio::time::Instant>().await
406 }
407 } => {
408 if awaiting_pong {
409 warn!("Ping timeout - no pong received");
410 break;
411 }
412 if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
413 error!("Failed to send ping: {}", e);
414 break;
415 }
416 awaiting_pong = true;
417 debug!("Sent ping");
418 }
419 _ = shutdown_rx.recv() => {
421 info!("WebSocket shutting down");
422 let _ = write.close().await;
423 *running.lock() = false;
424 return;
425 }
426 }
427 }
428
429 *running.lock() = false;
430 let _ = event_tx
431 .send(BridgeEvent::Disconnected {
432 reason: Some("Connection closed".to_string()),
433 })
434 .await;
435 }
436 Err(e) => {
437 error!("WebSocket connection failed: {}", e);
438 let _ = event_tx
439 .send(BridgeEvent::Error(format!("Connection failed: {}", e)))
440 .await;
441 }
442 }
443
444 if !auto_reconnect {
445 *running.lock() = false;
446 return;
447 }
448
449 info!("Reconnecting in {} seconds...", reconnect_delay);
450 tokio::time::sleep(std::time::Duration::from_secs(reconnect_delay as u64)).await;
451 }
452 }
453
454 #[allow(clippy::too_many_arguments)]
456 async fn run_server(
457 addr: SocketAddr,
458 format: WsMessageFormat,
459 namespace: String,
460 ping_interval_secs: u32,
461 event_tx: mpsc::Sender<BridgeEvent>,
462 mut send_rx: mpsc::Receiver<WsMessage>,
463 mut shutdown_rx: mpsc::Receiver<()>,
464 running: Arc<Mutex<bool>>,
465 ) {
466 use std::sync::atomic::{AtomicU64, Ordering};
467 use tokio::sync::RwLock;
468
469 let listener = match TcpListener::bind(addr).await {
470 Ok(l) => l,
471 Err(e) => {
472 error!("Failed to bind WebSocket server: {}", e);
473 let _ = event_tx
474 .send(BridgeEvent::Error(format!("Bind failed: {}", e)))
475 .await;
476 return;
477 }
478 };
479
480 info!("WebSocket server listening on {}", addr);
481 *running.lock() = true;
482 let _ = event_tx.send(BridgeEvent::Connected).await;
483
484 let clients: Arc<RwLock<HashMap<u64, mpsc::Sender<WsMessage>>>> =
486 Arc::new(RwLock::new(HashMap::new()));
487 let next_client_id = Arc::new(AtomicU64::new(0));
488
489 loop {
490 tokio::select! {
491 result = listener.accept() => {
492 match result {
493 Ok((stream, peer_addr)) => {
494 let client_id = next_client_id.fetch_add(1, Ordering::SeqCst);
495 info!("WebSocket client {} connected: {}", client_id, peer_addr);
496
497 let namespace = namespace.clone();
498 let event_tx = event_tx.clone();
499 let clients = clients.clone();
500 let ping_interval = ping_interval_secs;
501
502 let (client_tx, mut client_rx) = mpsc::channel::<WsMessage>(100);
504 clients.write().await.insert(client_id, client_tx);
505
506 tokio::spawn(async move {
507 if let Ok(ws_stream) = accept_async(stream).await {
508 let (mut write, mut read) = ws_stream.split();
509
510 let ping_duration = if ping_interval > 0 {
512 Some(std::time::Duration::from_secs(ping_interval as u64))
513 } else {
514 None
515 };
516 let mut ping_timer = ping_duration.map(tokio::time::interval);
517 let mut awaiting_pong = false;
518
519 loop {
520 tokio::select! {
521 msg = read.next() => {
523 match msg {
524 Some(Ok(ws_msg)) => {
525 match &ws_msg {
526 WsMessage::Pong(_) => {
527 awaiting_pong = false;
528 debug!("Client {} pong received", client_id);
529 }
530 WsMessage::Ping(data) => {
531 if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
532 debug!("Failed to send pong to client {}: {}", client_id, e);
533 break;
534 }
535 }
536 _ => {
537 if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
538 let _ = event_tx.send(BridgeEvent::ToClasp(Box::new(clasp_msg))).await;
539 }
540 }
541 }
542 }
543 Some(Err(e)) => {
544 debug!("WebSocket client {} error: {}", client_id, e);
545 break;
546 }
547 None => break,
548 }
549 }
550 msg = client_rx.recv() => {
552 match msg {
553 Some(ws_msg) => {
554 if let Err(e) = write.send(ws_msg).await {
555 debug!("Failed to send to client {}: {}", client_id, e);
556 break;
557 }
558 }
559 None => break,
560 }
561 }
562 _ = async {
564 if let Some(ref mut timer) = ping_timer {
565 timer.tick().await
566 } else {
567 std::future::pending::<tokio::time::Instant>().await
568 }
569 } => {
570 if awaiting_pong {
571 warn!("Client {} ping timeout", client_id);
572 break;
573 }
574 if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
575 debug!("Failed to send ping to client {}: {}", client_id, e);
576 break;
577 }
578 awaiting_pong = true;
579 }
580 }
581 }
582 }
583
584 clients.write().await.remove(&client_id);
586 info!("WebSocket client {} disconnected: {}", client_id, peer_addr);
587 });
588 }
589 Err(e) => {
590 error!("WebSocket accept error: {}", e);
591 }
592 }
593 }
594 msg = send_rx.recv() => {
596 if let Some(ws_msg) = msg {
597 let client_list: Vec<_> = clients.read().await.values().cloned().collect();
599 for tx in client_list {
600 let _ = tx.send(ws_msg.clone()).await;
601 }
602 }
603 }
604 _ = shutdown_rx.recv() => {
605 info!("WebSocket server shutting down");
606 break;
607 }
608 }
609 }
610
611 *running.lock() = false;
612 let _ = event_tx
613 .send(BridgeEvent::Disconnected {
614 reason: Some("Server stopped".to_string()),
615 })
616 .await;
617 }
618}
619
620#[async_trait]
621impl Bridge for WebSocketBridge {
622 fn config(&self) -> &BridgeConfig {
623 &self.config
624 }
625
626 async fn start(&mut self) -> Result<mpsc::Receiver<BridgeEvent>> {
627 if *self.running.lock() {
628 return Err(BridgeError::Other("Bridge already running".to_string()));
629 }
630
631 let (event_tx, event_rx) = mpsc::channel(100);
632 let (send_tx, send_rx) = mpsc::channel(100);
633 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
634
635 self.send_tx = Some(send_tx);
636 self.shutdown_tx = Some(shutdown_tx);
637
638 let running = self.running.clone();
639 let ws_config = self.ws_config.clone();
640
641 match ws_config.mode {
642 WsMode::Client => {
643 tokio::spawn(Self::run_client(
644 ws_config.url,
645 ws_config.format,
646 ws_config.namespace,
647 ws_config.auto_reconnect,
648 ws_config.reconnect_delay_secs,
649 ws_config.ping_interval_secs,
650 event_tx,
651 send_rx,
652 shutdown_rx,
653 running,
654 ));
655 }
656 WsMode::Server => {
657 let addr: SocketAddr = ws_config
658 .url
659 .parse()
660 .map_err(|e| BridgeError::Other(format!("Invalid address: {}", e)))?;
661
662 tokio::spawn(Self::run_server(
663 addr,
664 ws_config.format,
665 ws_config.namespace,
666 ws_config.ping_interval_secs,
667 event_tx,
668 send_rx,
669 shutdown_rx,
670 running,
671 ));
672 }
673 }
674
675 info!("WebSocket bridge started in {:?} mode", self.ws_config.mode);
676 Ok(event_rx)
677 }
678
679 async fn stop(&mut self) -> Result<()> {
680 *self.running.lock() = false;
681 if let Some(tx) = self.shutdown_tx.take() {
682 let _ = tx.send(()).await;
683 }
684 self.send_tx = None;
685 info!("WebSocket bridge stopped");
686 Ok(())
687 }
688
689 async fn send(&self, msg: Message) -> Result<()> {
690 let send_tx = self
691 .send_tx
692 .as_ref()
693 .ok_or_else(|| BridgeError::Other("Not connected".to_string()))?;
694
695 if let Some(ws_msg) = Self::message_to_ws(&msg, self.ws_config.format) {
696 send_tx
697 .send(ws_msg)
698 .await
699 .map_err(|e| BridgeError::Other(format!("WebSocket send failed: {}", e)))?;
700 }
701
702 Ok(())
703 }
704
705 fn is_running(&self) -> bool {
706 *self.running.lock()
707 }
708
709 fn namespace(&self) -> &str {
710 &self.ws_config.namespace
711 }
712}
713
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a parse result that is expected to be a `Set` message.
    fn expect_set(parsed: Option<Message>) -> SetMessage {
        match parsed {
            Some(Message::Set(set)) => set,
            _ => panic!("expected Some(Message::Set(..))"),
        }
    }

    /// The hand-written `Default` must agree with the documented defaults.
    #[test]
    fn test_config_default() {
        let config = WebSocketBridgeConfig::default();
        assert_eq!(config.mode, WsMode::Client);
        assert_eq!(config.url, "ws://localhost:8080");
        assert!(config.path.is_none());
        assert_eq!(config.format, WsMessageFormat::Json);
        assert_eq!(config.ping_interval_secs, 30);
        assert!(config.auto_reconnect);
        assert_eq!(config.reconnect_delay_secs, 5);
        assert!(config.headers.is_empty());
        assert_eq!(config.namespace, "/ws");
    }

    /// Checks the address and value produced for each kind of incoming frame,
    /// not merely that parsing produced something.
    #[test]
    fn test_message_formats() {
        let prefix = "/ws";

        // JSON envelope: address and value are taken from the document.
        let ws_msg = WsMessage::Text(r#"{"address": "/test", "value": 42}"#.to_string());
        let set = expect_set(WebSocketBridge::parse_message(
            &ws_msg,
            WsMessageFormat::Json,
            prefix,
        ));
        assert_eq!(set.address, "/test");
        assert!(matches!(set.value, Value::Int(42)));

        // Non-JSON text falls back to a string value under `{prefix}/text`.
        let ws_msg = WsMessage::Text("hello".to_string());
        let set = expect_set(WebSocketBridge::parse_message(
            &ws_msg,
            WsMessageFormat::Json,
            prefix,
        ));
        assert_eq!(set.address, "/ws/text");
        assert!(matches!(&set.value, Value::String(s) if s == "hello"));

        // Binary frames in raw mode are passed through as bytes under `{prefix}/binary`.
        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
        let set = expect_set(WebSocketBridge::parse_message(
            &ws_msg,
            WsMessageFormat::Raw,
            prefix,
        ));
        assert_eq!(set.address, "/ws/binary");
        assert!(matches!(&set.value, Value::Bytes(b) if b.as_slice() == [1u8, 2, 3]));
    }
}