1use crate::{Bridge, BridgeConfig, BridgeError, BridgeEvent, Result};
7use async_trait::async_trait;
8use clasp_core::{Message, SetMessage, Value};
9use futures::{SinkExt, StreamExt};
10use parking_lot::Mutex;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::{TcpListener, TcpStream};
16use tokio::sync::mpsc;
17use tokio_tungstenite::{
18 accept_async, connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream,
19 WebSocketStream,
20};
21use tracing::{debug, error, info, warn};
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
25#[serde(rename_all = "lowercase")]
26pub enum WsMessageFormat {
27 #[default]
29 Json,
30 MsgPack,
32 Raw,
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum WsMode {
40 #[default]
42 Client,
43 Server,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct WebSocketBridgeConfig {
50 #[serde(default)]
52 pub mode: WsMode,
53 pub url: String,
55 #[serde(default)]
57 pub path: Option<String>,
58 #[serde(default)]
60 pub format: WsMessageFormat,
61 #[serde(default = "default_ping_interval")]
63 pub ping_interval_secs: u32,
64 #[serde(default = "default_true")]
66 pub auto_reconnect: bool,
67 #[serde(default = "default_reconnect_delay")]
69 pub reconnect_delay_secs: u32,
70 #[serde(default)]
72 pub headers: HashMap<String, String>,
73 #[serde(default = "default_namespace")]
75 pub namespace: String,
76}
77
/// Serde default for flags that are enabled unless explicitly turned off.
fn default_true() -> bool {
    true
}

/// Serde default for `ping_interval_secs`: ping every 30 seconds.
fn default_ping_interval() -> u32 {
    30
}

/// Serde default for `reconnect_delay_secs`: wait 5 seconds between attempts.
fn default_reconnect_delay() -> u32 {
    5
}

/// Serde default for `namespace`: the `/ws` address prefix.
fn default_namespace() -> String {
    String::from("/ws")
}
93
94impl Default for WebSocketBridgeConfig {
95 fn default() -> Self {
96 Self {
97 mode: WsMode::Client,
98 url: "ws://localhost:8080".to_string(),
99 path: None,
100 format: WsMessageFormat::Json,
101 ping_interval_secs: 30,
102 auto_reconnect: true,
103 reconnect_delay_secs: 5,
104 headers: HashMap::new(),
105 namespace: "/ws".to_string(),
106 }
107 }
108}
109
110#[allow(dead_code)]
112type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
113
114#[allow(dead_code)]
116type WsServerStream = WebSocketStream<TcpStream>;
117
118#[allow(dead_code)]
120type WsSink = futures::stream::SplitSink<WsServerStream, WsMessage>;
121
122pub struct WebSocketBridge {
124 config: BridgeConfig,
125 ws_config: WebSocketBridgeConfig,
126 running: Arc<Mutex<bool>>,
127 send_tx: Option<mpsc::Sender<WsMessage>>,
128 shutdown_tx: Option<mpsc::Sender<()>>,
129}
130
131impl WebSocketBridge {
132 pub fn new(ws_config: WebSocketBridgeConfig) -> Self {
134 let config = BridgeConfig {
135 name: "WebSocket Bridge".to_string(),
136 protocol: "websocket".to_string(),
137 bidirectional: true,
138 ..Default::default()
139 };
140
141 Self {
142 config,
143 ws_config,
144 running: Arc::new(Mutex::new(false)),
145 send_tx: None,
146 shutdown_tx: None,
147 }
148 }
149
150 fn parse_message(msg: &WsMessage, format: WsMessageFormat, prefix: &str) -> Option<Message> {
152 match msg {
153 WsMessage::Text(text) => match format {
154 WsMessageFormat::Json | WsMessageFormat::Raw => {
155 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
157 let address = json
158 .get("address")
159 .and_then(|v| v.as_str())
160 .map(|s| s.to_string())
161 .unwrap_or_else(|| format!("{}/message", prefix));
162
163 let value = json
164 .get("value")
165 .map(|v| Self::json_to_value(v.clone()))
166 .or_else(|| json.get("data").map(|v| Self::json_to_value(v.clone())))
167 .unwrap_or_else(|| Self::json_to_value(json));
168
169 Some(Message::Set(SetMessage {
170 address,
171 value,
172 revision: None,
173 lock: false,
174 unlock: false,
175 ttl: None,
176 }))
177 } else {
178 Some(Message::Set(SetMessage {
180 address: format!("{}/text", prefix),
181 value: Value::String(text.clone()),
182 revision: None,
183 lock: false,
184 unlock: false,
185 ttl: None,
186 }))
187 }
188 }
189 WsMessageFormat::MsgPack => {
190 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
193 if let (Some(addr), Some(val)) = (
195 json.get("address").and_then(|a| a.as_str()),
196 json.get("value"),
197 ) {
198 Some(Message::Set(SetMessage {
199 address: addr.to_string(),
200 value: Self::json_to_value(val.clone()),
201 revision: None,
202 lock: false,
203 unlock: false,
204 ttl: None,
205 }))
206 } else {
207 Some(Message::Set(SetMessage {
209 address: format!("{}/text", prefix),
210 value: Value::String(text.clone()),
211 revision: None,
212 lock: false,
213 unlock: false,
214 ttl: None,
215 }))
216 }
217 } else {
218 Some(Message::Set(SetMessage {
220 address: format!("{}/text", prefix),
221 value: Value::String(text.clone()),
222 revision: None,
223 lock: false,
224 unlock: false,
225 ttl: None,
226 }))
227 }
228 }
229 },
230 WsMessage::Binary(data) => match format {
231 WsMessageFormat::MsgPack => {
232 if let Ok((msg, _)) = clasp_core::codec::decode(data) {
234 Some(msg)
235 } else {
236 Some(Message::Set(SetMessage {
238 address: format!("{}/binary", prefix),
239 value: Value::Bytes(data.clone()),
240 revision: None,
241 lock: false,
242 unlock: false,
243 ttl: None,
244 }))
245 }
246 }
247 WsMessageFormat::Raw | WsMessageFormat::Json => Some(Message::Set(SetMessage {
248 address: format!("{}/binary", prefix),
249 value: Value::Bytes(data.clone()),
250 revision: None,
251 lock: false,
252 unlock: false,
253 ttl: None,
254 })),
255 },
256 _ => None,
257 }
258 }
259
260 fn json_to_value(json: serde_json::Value) -> Value {
262 match json {
263 serde_json::Value::Null => Value::Null,
264 serde_json::Value::Bool(b) => Value::Bool(b),
265 serde_json::Value::Number(n) => {
266 if let Some(i) = n.as_i64() {
267 Value::Int(i)
268 } else if let Some(f) = n.as_f64() {
269 Value::Float(f)
270 } else {
271 Value::Null
272 }
273 }
274 serde_json::Value::String(s) => Value::String(s),
275 serde_json::Value::Array(arr) => {
276 Value::Array(arr.into_iter().map(Self::json_to_value).collect())
277 }
278 serde_json::Value::Object(obj) => {
279 let map: HashMap<String, Value> = obj
280 .into_iter()
281 .map(|(k, v)| (k, Self::json_to_value(v)))
282 .collect();
283 Value::Map(map)
284 }
285 }
286 }
287
288 fn message_to_ws(msg: &Message, format: WsMessageFormat) -> Option<WsMessage> {
290 let (address, value) = match msg {
291 Message::Set(set) => (Some(&set.address), Some(&set.value)),
292 Message::Publish(pub_msg) => (Some(&pub_msg.address), pub_msg.value.as_ref()),
293 _ => return None,
294 };
295
296 match format {
297 WsMessageFormat::Json => {
298 let json = serde_json::json!({
299 "address": address,
300 "value": value,
301 });
302 Some(WsMessage::Text(json.to_string()))
303 }
304 WsMessageFormat::MsgPack => {
305 if let Ok(encoded) = clasp_core::codec::encode(msg) {
306 Some(WsMessage::Binary(encoded.to_vec()))
307 } else {
308 None
309 }
310 }
311 WsMessageFormat::Raw => {
312 if let Some(val) = value {
313 match val {
314 Value::String(s) => Some(WsMessage::Text(s.clone())),
315 Value::Bytes(b) => Some(WsMessage::Binary(b.clone())),
316 _ => {
317 let json = serde_json::to_string(val).ok()?;
318 Some(WsMessage::Text(json))
319 }
320 }
321 } else {
322 None
323 }
324 }
325 }
326 }
327
328 #[allow(clippy::too_many_arguments)]
330 async fn run_client(
331 url: String,
332 format: WsMessageFormat,
333 namespace: String,
334 auto_reconnect: bool,
335 reconnect_delay: u32,
336 ping_interval_secs: u32,
337 event_tx: mpsc::Sender<BridgeEvent>,
338 mut send_rx: mpsc::Receiver<WsMessage>,
339 mut shutdown_rx: mpsc::Receiver<()>,
340 running: Arc<Mutex<bool>>,
341 ) {
342 loop {
343 info!("WebSocket connecting to {}", url);
344
345 match connect_async(&url).await {
346 Ok((ws_stream, _)) => {
347 info!("WebSocket connected");
348 *running.lock() = true;
349 let _ = event_tx.send(BridgeEvent::Connected).await;
350
351 let (mut write, mut read) = ws_stream.split();
352
353 let ping_duration = if ping_interval_secs > 0 {
355 Some(std::time::Duration::from_secs(ping_interval_secs as u64))
356 } else {
357 None
358 };
359 let mut ping_interval = ping_duration.map(tokio::time::interval);
360 let mut awaiting_pong = false;
361
362 loop {
363 tokio::select! {
364 msg = read.next() => {
366 match msg {
367 Some(Ok(ws_msg)) => {
368 match &ws_msg {
369 WsMessage::Pong(_) => {
370 awaiting_pong = false;
371 debug!("Received pong");
372 }
373 WsMessage::Ping(data) => {
374 if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
376 error!("Failed to send pong: {}", e);
377 break;
378 }
379 }
380 _ => {
381 if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
382 let _ = event_tx.send(BridgeEvent::ToClasp(Box::new(clasp_msg))).await;
383 }
384 }
385 }
386 }
387 Some(Err(e)) => {
388 error!("WebSocket error: {}", e);
389 break;
390 }
391 None => {
392 warn!("WebSocket connection closed");
393 break;
394 }
395 }
396 }
397 msg = send_rx.recv() => {
399 if let Some(ws_msg) = msg {
400 if let Err(e) = write.send(ws_msg).await {
401 error!("WebSocket send error: {}", e);
402 break;
403 }
404 }
405 }
406 _ = async {
408 if let Some(ref mut interval) = ping_interval {
409 interval.tick().await
410 } else {
411 std::future::pending::<tokio::time::Instant>().await
413 }
414 } => {
415 if awaiting_pong {
416 warn!("Ping timeout - no pong received");
417 break;
418 }
419 if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
420 error!("Failed to send ping: {}", e);
421 break;
422 }
423 awaiting_pong = true;
424 debug!("Sent ping");
425 }
426 _ = shutdown_rx.recv() => {
428 info!("WebSocket shutting down");
429 let _ = write.close().await;
430 *running.lock() = false;
431 return;
432 }
433 }
434 }
435
436 *running.lock() = false;
437 let _ = event_tx
438 .send(BridgeEvent::Disconnected {
439 reason: Some("Connection closed".to_string()),
440 })
441 .await;
442 }
443 Err(e) => {
444 error!("WebSocket connection failed: {}", e);
445 let _ = event_tx
446 .send(BridgeEvent::Error(format!("Connection failed: {}", e)))
447 .await;
448 }
449 }
450
451 if !auto_reconnect {
452 *running.lock() = false;
453 return;
454 }
455
456 info!("Reconnecting in {} seconds...", reconnect_delay);
457 tokio::time::sleep(std::time::Duration::from_secs(reconnect_delay as u64)).await;
458 }
459 }
460
461 #[allow(clippy::too_many_arguments)]
463 async fn run_server(
464 addr: SocketAddr,
465 format: WsMessageFormat,
466 namespace: String,
467 ping_interval_secs: u32,
468 event_tx: mpsc::Sender<BridgeEvent>,
469 mut send_rx: mpsc::Receiver<WsMessage>,
470 mut shutdown_rx: mpsc::Receiver<()>,
471 running: Arc<Mutex<bool>>,
472 ) {
473 use std::sync::atomic::{AtomicU64, Ordering};
474 use tokio::sync::RwLock;
475
476 let listener = match TcpListener::bind(addr).await {
477 Ok(l) => l,
478 Err(e) => {
479 error!("Failed to bind WebSocket server: {}", e);
480 let _ = event_tx
481 .send(BridgeEvent::Error(format!("Bind failed: {}", e)))
482 .await;
483 return;
484 }
485 };
486
487 info!("WebSocket server listening on {}", addr);
488 *running.lock() = true;
489 let _ = event_tx.send(BridgeEvent::Connected).await;
490
491 let clients: Arc<RwLock<HashMap<u64, mpsc::Sender<WsMessage>>>> =
493 Arc::new(RwLock::new(HashMap::new()));
494 let next_client_id = Arc::new(AtomicU64::new(0));
495
496 loop {
497 tokio::select! {
498 result = listener.accept() => {
499 match result {
500 Ok((stream, peer_addr)) => {
501 let client_id = next_client_id.fetch_add(1, Ordering::SeqCst);
502 info!("WebSocket client {} connected: {}", client_id, peer_addr);
503
504 let namespace = namespace.clone();
505 let event_tx = event_tx.clone();
506 let clients = clients.clone();
507 let ping_interval = ping_interval_secs;
508
509 let (client_tx, mut client_rx) = mpsc::channel::<WsMessage>(100);
511 clients.write().await.insert(client_id, client_tx);
512
513 tokio::spawn(async move {
514 if let Ok(ws_stream) = accept_async(stream).await {
515 let (mut write, mut read) = ws_stream.split();
516
517 let ping_duration = if ping_interval > 0 {
519 Some(std::time::Duration::from_secs(ping_interval as u64))
520 } else {
521 None
522 };
523 let mut ping_timer = ping_duration.map(tokio::time::interval);
524 let mut awaiting_pong = false;
525
526 loop {
527 tokio::select! {
528 msg = read.next() => {
530 match msg {
531 Some(Ok(ws_msg)) => {
532 match &ws_msg {
533 WsMessage::Pong(_) => {
534 awaiting_pong = false;
535 debug!("Client {} pong received", client_id);
536 }
537 WsMessage::Ping(data) => {
538 if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
539 debug!("Failed to send pong to client {}: {}", client_id, e);
540 break;
541 }
542 }
543 _ => {
544 if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
545 let _ = event_tx.send(BridgeEvent::ToClasp(Box::new(clasp_msg))).await;
546 }
547 }
548 }
549 }
550 Some(Err(e)) => {
551 debug!("WebSocket client {} error: {}", client_id, e);
552 break;
553 }
554 None => break,
555 }
556 }
557 msg = client_rx.recv() => {
559 match msg {
560 Some(ws_msg) => {
561 if let Err(e) = write.send(ws_msg).await {
562 debug!("Failed to send to client {}: {}", client_id, e);
563 break;
564 }
565 }
566 None => break,
567 }
568 }
569 _ = async {
571 if let Some(ref mut timer) = ping_timer {
572 timer.tick().await
573 } else {
574 std::future::pending::<tokio::time::Instant>().await
575 }
576 } => {
577 if awaiting_pong {
578 warn!("Client {} ping timeout", client_id);
579 break;
580 }
581 if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
582 debug!("Failed to send ping to client {}: {}", client_id, e);
583 break;
584 }
585 awaiting_pong = true;
586 }
587 }
588 }
589 }
590
591 clients.write().await.remove(&client_id);
593 info!("WebSocket client {} disconnected: {}", client_id, peer_addr);
594 });
595 }
596 Err(e) => {
597 error!("WebSocket accept error: {}", e);
598 }
599 }
600 }
601 msg = send_rx.recv() => {
603 if let Some(ws_msg) = msg {
604 let client_list: Vec<_> = clients.read().await.values().cloned().collect();
606 for tx in client_list {
607 let _ = tx.send(ws_msg.clone()).await;
608 }
609 }
610 }
611 _ = shutdown_rx.recv() => {
612 info!("WebSocket server shutting down");
613 break;
614 }
615 }
616 }
617
618 *running.lock() = false;
619 let _ = event_tx
620 .send(BridgeEvent::Disconnected {
621 reason: Some("Server stopped".to_string()),
622 })
623 .await;
624 }
625}
626
627#[async_trait]
628impl Bridge for WebSocketBridge {
629 fn config(&self) -> &BridgeConfig {
630 &self.config
631 }
632
633 async fn start(&mut self) -> Result<mpsc::Receiver<BridgeEvent>> {
634 if *self.running.lock() {
635 return Err(BridgeError::Other("Bridge already running".to_string()));
636 }
637
638 let (event_tx, event_rx) = mpsc::channel(100);
639 let (send_tx, send_rx) = mpsc::channel(100);
640 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
641
642 self.send_tx = Some(send_tx);
643 self.shutdown_tx = Some(shutdown_tx);
644
645 let running = self.running.clone();
646 let ws_config = self.ws_config.clone();
647
648 match ws_config.mode {
649 WsMode::Client => {
650 tokio::spawn(Self::run_client(
651 ws_config.url,
652 ws_config.format,
653 ws_config.namespace,
654 ws_config.auto_reconnect,
655 ws_config.reconnect_delay_secs,
656 ws_config.ping_interval_secs,
657 event_tx,
658 send_rx,
659 shutdown_rx,
660 running,
661 ));
662 }
663 WsMode::Server => {
664 let addr: SocketAddr = ws_config
665 .url
666 .parse()
667 .map_err(|e| BridgeError::Other(format!("Invalid address: {}", e)))?;
668
669 tokio::spawn(Self::run_server(
670 addr,
671 ws_config.format,
672 ws_config.namespace,
673 ws_config.ping_interval_secs,
674 event_tx,
675 send_rx,
676 shutdown_rx,
677 running,
678 ));
679 }
680 }
681
682 info!("WebSocket bridge started in {:?} mode", self.ws_config.mode);
683 Ok(event_rx)
684 }
685
686 async fn stop(&mut self) -> Result<()> {
687 *self.running.lock() = false;
688 if let Some(tx) = self.shutdown_tx.take() {
689 let _ = tx.send(()).await;
690 }
691 self.send_tx = None;
692 info!("WebSocket bridge stopped");
693 Ok(())
694 }
695
696 async fn send(&self, msg: Message) -> Result<()> {
697 let send_tx = self
698 .send_tx
699 .as_ref()
700 .ok_or_else(|| BridgeError::Other("Not connected".to_string()))?;
701
702 if let Some(ws_msg) = Self::message_to_ws(&msg, self.ws_config.format) {
703 send_tx
704 .send(ws_msg)
705 .await
706 .map_err(|e| BridgeError::Other(format!("WebSocket send failed: {}", e)))?;
707 }
708
709 Ok(())
710 }
711
712 fn is_running(&self) -> bool {
713 *self.running.lock()
714 }
715
716 fn namespace(&self) -> &str {
717 &self.ws_config.namespace
718 }
719}
720
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = WebSocketBridgeConfig::default();
        assert_eq!(config.mode, WsMode::Client);
        assert_eq!(config.namespace, "/ws");
        assert_eq!(config.format, WsMessageFormat::Json);
        assert_eq!(config.ping_interval_secs, 30);
        assert!(config.auto_reconnect);
        assert_eq!(config.reconnect_delay_secs, 5);
    }

    #[test]
    fn test_message_formats() {
        let prefix = "/ws";

        // A JSON object with an explicit address/value pair is unpacked.
        let ws_msg = WsMessage::Text(r#"{"address": "/test", "value": 42}"#.to_string());
        match WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix) {
            Some(Message::Set(set)) => {
                assert_eq!(set.address, "/test");
                assert!(matches!(set.value, Value::Int(42)));
            }
            _ => panic!("expected a SET for a JSON object"),
        }

        // JSON without an address falls back to `{prefix}/message`.
        let ws_msg = WsMessage::Text(r#"{"data": 1}"#.to_string());
        match WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix) {
            Some(Message::Set(set)) => {
                assert_eq!(set.address, "/ws/message");
                assert!(matches!(set.value, Value::Int(1)));
            }
            _ => panic!("expected a SET for address-less JSON"),
        }

        // Non-JSON text is forwarded verbatim to `{prefix}/text`.
        let ws_msg = WsMessage::Text("hello".to_string());
        match WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix) {
            Some(Message::Set(set)) => {
                assert_eq!(set.address, "/ws/text");
                assert!(matches!(set.value, Value::String(ref s) if s == "hello"));
            }
            _ => panic!("expected a SET for plain text"),
        }

        // MsgPack-mode text without an address is also treated as plain text.
        let ws_msg = WsMessage::Text(r#"{"value": 1}"#.to_string());
        match WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::MsgPack, prefix) {
            Some(Message::Set(set)) => assert_eq!(set.address, "/ws/text"),
            _ => panic!("expected a SET for MsgPack-mode text"),
        }

        // Binary frames in Raw mode are wrapped as bytes under `{prefix}/binary`.
        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
        match WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Raw, prefix) {
            Some(Message::Set(set)) => {
                assert_eq!(set.address, "/ws/binary");
                assert!(matches!(set.value, Value::Bytes(ref b) if *b == vec![1u8, 2, 3]));
            }
            _ => panic!("expected a SET for a binary frame"),
        }
    }

    #[test]
    fn test_control_frames_are_ignored() {
        let ping = WsMessage::Ping(vec![1]);
        assert!(WebSocketBridge::parse_message(&ping, WsMessageFormat::Json, "/ws").is_none());
    }

    #[test]
    fn test_outgoing_conversion() {
        let msg = Message::Set(SetMessage {
            address: "/ws/out".to_string(),
            value: Value::String("hi".to_string()),
            revision: None,
            lock: false,
            unlock: false,
            ttl: None,
        });

        // Raw mode sends only the payload.
        assert_eq!(
            WebSocketBridge::message_to_ws(&msg, WsMessageFormat::Raw),
            Some(WsMessage::Text("hi".to_string()))
        );

        // JSON mode wraps address and value in an object.
        match WebSocketBridge::message_to_ws(&msg, WsMessageFormat::Json) {
            Some(WsMessage::Text(text)) => {
                let json: serde_json::Value = serde_json::from_str(&text).expect("valid JSON");
                assert_eq!(json["address"], "/ws/out");
            }
            _ => panic!("expected a JSON text frame"),
        }
    }
}
751}