1use crate::{Bridge, BridgeConfig, BridgeError, BridgeEvent, Result};
7use async_trait::async_trait;
8use clasp_core::{Message, PublishMessage, SetMessage, SignalType, Value};
9use futures::{SinkExt, StreamExt};
10use parking_lot::Mutex;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::{TcpListener, TcpStream};
16use tokio::sync::mpsc;
17use tokio_tungstenite::{
18 accept_async, connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream,
19 WebSocketStream,
20};
21use tracing::{debug, error, info, warn};
22
/// Wire encoding used for WebSocket frames exchanged by the bridge.
///
/// Serialized in configuration as `"json"`, `"msgpack"` or `"raw"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum WsMessageFormat {
    /// Text frames carrying a JSON object (`address` plus `value` or `data`);
    /// outgoing messages are sent as `{"address": ..., "value": ...}`. Default.
    #[default]
    Json,
    /// Binary frames encoded/decoded with `clasp_core::codec`; JSON text frames
    /// with both `address` and `value` are still accepted on input.
    MsgPack,
    /// Values are sent as-is: strings as text frames, bytes as binary frames,
    /// anything else as JSON text.
    Raw,
}
35
/// Whether the bridge dials out to a remote server or accepts incoming peers.
///
/// Serialized in configuration as `"client"` or `"server"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum WsMode {
    /// Connect to the configured `url`, optionally reconnecting after
    /// failures. Default.
    #[default]
    Client,
    /// Listen on the configured `url` (parsed as a socket address) and
    /// broadcast outgoing messages to every connected client.
    Server,
}
46
/// User-facing configuration for [`WebSocketBridge`].
///
/// Every field except `url` has a serde default, so a configuration that
/// only specifies `url` deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSocketBridgeConfig {
    /// Client (dial out) or server (listen) operation; defaults to client.
    #[serde(default)]
    pub mode: WsMode,
    /// Client mode: the WebSocket URL to connect to.
    /// Server mode: the listen address, which must parse as a `SocketAddr`
    /// (e.g. `0.0.0.0:8080`, without a `ws://` scheme).
    pub url: String,
    /// Optional path component.
    /// NOTE(review): not read anywhere in this file — confirm whether it is
    /// meant to be appended to `url`.
    #[serde(default)]
    pub path: Option<String>,
    /// Frame encoding used in both directions; defaults to JSON.
    #[serde(default)]
    pub format: WsMessageFormat,
    /// Keep-alive ping period in seconds; `0` disables pings. A connection is
    /// dropped when a ping is still unanswered at the next tick. Defaults to 30.
    #[serde(default = "default_ping_interval")]
    pub ping_interval_secs: u32,
    /// Client mode only: reconnect after a failed or lost connection.
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub auto_reconnect: bool,
    /// Client mode only: seconds to wait before the next connection attempt.
    /// Defaults to 5.
    #[serde(default = "default_reconnect_delay")]
    pub reconnect_delay_secs: u32,
    /// Extra HTTP headers for the opening handshake.
    /// NOTE(review): not read anywhere in this file — the client connects with
    /// the bare URL, so these are currently ignored.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Address prefix for incoming frames without an explicit address
    /// (`<namespace>/text`, `<namespace>/binary`, `<namespace>/message`);
    /// also reported by `Bridge::namespace`. Defaults to `/ws`.
    #[serde(default = "default_namespace")]
    pub namespace: String,
}
77
/// Keep-alive ping period (seconds) used when `ping_interval_secs` is omitted.
const DEFAULT_PING_INTERVAL_SECS: u32 = 30;

/// Reconnect back-off (seconds) used when `reconnect_delay_secs` is omitted.
const DEFAULT_RECONNECT_DELAY_SECS: u32 = 5;

/// Address prefix used when `namespace` is omitted.
const DEFAULT_NAMESPACE: &str = "/ws";

/// Serde default for opt-out boolean settings such as `auto_reconnect`.
fn default_true() -> bool {
    true
}

/// Serde default for `ping_interval_secs`.
fn default_ping_interval() -> u32 {
    DEFAULT_PING_INTERVAL_SECS
}

/// Serde default for `reconnect_delay_secs`.
fn default_reconnect_delay() -> u32 {
    DEFAULT_RECONNECT_DELAY_SECS
}

/// Serde default for `namespace`.
fn default_namespace() -> String {
    String::from(DEFAULT_NAMESPACE)
}
93
94impl Default for WebSocketBridgeConfig {
95 fn default() -> Self {
96 Self {
97 mode: WsMode::Client,
98 url: "ws://localhost:8080".to_string(),
99 path: None,
100 format: WsMessageFormat::Json,
101 ping_interval_secs: 30,
102 auto_reconnect: true,
103 reconnect_delay_secs: 5,
104 headers: HashMap::new(),
105 namespace: "/ws".to_string(),
106 }
107 }
108}
109
110type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
112
113type WsServerStream = WebSocketStream<TcpStream>;
115
116type WsSink = futures::stream::SplitSink<WsServerStream, WsMessage>;
118
/// Bridge between CLASP messages and WebSocket peers.
///
/// In client mode it maintains one outgoing connection; in server mode it
/// accepts many clients and broadcasts outgoing messages to all of them.
pub struct WebSocketBridge {
    /// Static bridge metadata returned by `Bridge::config`.
    config: BridgeConfig,
    /// WebSocket-specific settings supplied by the user.
    ws_config: WebSocketBridgeConfig,
    /// Shared flag set by the background task while a client connection is up
    /// or the server is listening.
    running: Arc<Mutex<bool>>,
    /// Outbound queue into the background task; `None` before `start` and
    /// after `stop`.
    send_tx: Option<mpsc::Sender<WsMessage>>,
    /// Shutdown signal for the background task; `None` before `start` and
    /// after `stop`.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
127
128impl WebSocketBridge {
129 pub fn new(ws_config: WebSocketBridgeConfig) -> Self {
131 let config = BridgeConfig {
132 name: "WebSocket Bridge".to_string(),
133 protocol: "websocket".to_string(),
134 bidirectional: true,
135 ..Default::default()
136 };
137
138 Self {
139 config,
140 ws_config,
141 running: Arc::new(Mutex::new(false)),
142 send_tx: None,
143 shutdown_tx: None,
144 }
145 }
146
147 fn parse_message(msg: &WsMessage, format: WsMessageFormat, prefix: &str) -> Option<Message> {
149 match msg {
150 WsMessage::Text(text) => match format {
151 WsMessageFormat::Json | WsMessageFormat::Raw => {
152 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
154 let address = json
155 .get("address")
156 .and_then(|v| v.as_str())
157 .map(|s| s.to_string())
158 .unwrap_or_else(|| format!("{}/message", prefix));
159
160 let value = json
161 .get("value")
162 .map(|v| Self::json_to_value(v.clone()))
163 .or_else(|| json.get("data").map(|v| Self::json_to_value(v.clone())))
164 .unwrap_or_else(|| Self::json_to_value(json));
165
166 Some(Message::Set(SetMessage {
167 address,
168 value,
169 revision: None,
170 lock: false,
171 unlock: false,
172 }))
173 } else {
174 Some(Message::Set(SetMessage {
176 address: format!("{}/text", prefix),
177 value: Value::String(text.clone()),
178 revision: None,
179 lock: false,
180 unlock: false,
181 }))
182 }
183 }
184 WsMessageFormat::MsgPack => {
185 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
188 if let (Some(addr), Some(val)) = (
190 json.get("address").and_then(|a| a.as_str()),
191 json.get("value"),
192 ) {
193 Some(Message::Set(SetMessage {
194 address: addr.to_string(),
195 value: Self::json_to_value(val.clone()),
196 revision: None,
197 lock: false,
198 unlock: false,
199 }))
200 } else {
201 Some(Message::Set(SetMessage {
203 address: format!("{}/text", prefix),
204 value: Value::String(text.clone()),
205 revision: None,
206 lock: false,
207 unlock: false,
208 }))
209 }
210 } else {
211 Some(Message::Set(SetMessage {
213 address: format!("{}/text", prefix),
214 value: Value::String(text.clone()),
215 revision: None,
216 lock: false,
217 unlock: false,
218 }))
219 }
220 }
221 },
222 WsMessage::Binary(data) => match format {
223 WsMessageFormat::MsgPack => {
224 if let Ok((msg, _)) = clasp_core::codec::decode(data) {
226 Some(msg)
227 } else {
228 Some(Message::Set(SetMessage {
230 address: format!("{}/binary", prefix),
231 value: Value::Bytes(data.clone()),
232 revision: None,
233 lock: false,
234 unlock: false,
235 }))
236 }
237 }
238 WsMessageFormat::Raw | WsMessageFormat::Json => Some(Message::Set(SetMessage {
239 address: format!("{}/binary", prefix),
240 value: Value::Bytes(data.clone()),
241 revision: None,
242 lock: false,
243 unlock: false,
244 })),
245 },
246 _ => None,
247 }
248 }
249
250 fn json_to_value(json: serde_json::Value) -> Value {
252 match json {
253 serde_json::Value::Null => Value::Null,
254 serde_json::Value::Bool(b) => Value::Bool(b),
255 serde_json::Value::Number(n) => {
256 if let Some(i) = n.as_i64() {
257 Value::Int(i)
258 } else if let Some(f) = n.as_f64() {
259 Value::Float(f)
260 } else {
261 Value::Null
262 }
263 }
264 serde_json::Value::String(s) => Value::String(s),
265 serde_json::Value::Array(arr) => {
266 Value::Array(arr.into_iter().map(Self::json_to_value).collect())
267 }
268 serde_json::Value::Object(obj) => {
269 let map: HashMap<String, Value> = obj
270 .into_iter()
271 .map(|(k, v)| (k, Self::json_to_value(v)))
272 .collect();
273 Value::Map(map)
274 }
275 }
276 }
277
278 fn message_to_ws(msg: &Message, format: WsMessageFormat) -> Option<WsMessage> {
280 let (address, value) = match msg {
281 Message::Set(set) => (Some(&set.address), Some(&set.value)),
282 Message::Publish(pub_msg) => (Some(&pub_msg.address), pub_msg.value.as_ref()),
283 _ => return None,
284 };
285
286 match format {
287 WsMessageFormat::Json => {
288 let json = serde_json::json!({
289 "address": address,
290 "value": value,
291 });
292 Some(WsMessage::Text(json.to_string()))
293 }
294 WsMessageFormat::MsgPack => {
295 if let Ok(encoded) = clasp_core::codec::encode(msg) {
296 Some(WsMessage::Binary(encoded.to_vec()))
297 } else {
298 None
299 }
300 }
301 WsMessageFormat::Raw => {
302 if let Some(val) = value {
303 match val {
304 Value::String(s) => Some(WsMessage::Text(s.clone())),
305 Value::Bytes(b) => Some(WsMessage::Binary(b.clone())),
306 _ => {
307 let json = serde_json::to_string(val).ok()?;
308 Some(WsMessage::Text(json))
309 }
310 }
311 } else {
312 None
313 }
314 }
315 }
316 }
317
318 async fn run_client(
320 url: String,
321 format: WsMessageFormat,
322 namespace: String,
323 auto_reconnect: bool,
324 reconnect_delay: u32,
325 ping_interval_secs: u32,
326 event_tx: mpsc::Sender<BridgeEvent>,
327 mut send_rx: mpsc::Receiver<WsMessage>,
328 mut shutdown_rx: mpsc::Receiver<()>,
329 running: Arc<Mutex<bool>>,
330 ) {
331 loop {
332 info!("WebSocket connecting to {}", url);
333
334 match connect_async(&url).await {
335 Ok((ws_stream, _)) => {
336 info!("WebSocket connected");
337 *running.lock() = true;
338 let _ = event_tx.send(BridgeEvent::Connected).await;
339
340 let (mut write, mut read) = ws_stream.split();
341
342 let ping_duration = if ping_interval_secs > 0 {
344 Some(std::time::Duration::from_secs(ping_interval_secs as u64))
345 } else {
346 None
347 };
348 let mut ping_interval = ping_duration.map(tokio::time::interval);
349 let mut awaiting_pong = false;
350
351 loop {
352 tokio::select! {
353 msg = read.next() => {
355 match msg {
356 Some(Ok(ws_msg)) => {
357 match &ws_msg {
358 WsMessage::Pong(_) => {
359 awaiting_pong = false;
360 debug!("Received pong");
361 }
362 WsMessage::Ping(data) => {
363 if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
365 error!("Failed to send pong: {}", e);
366 break;
367 }
368 }
369 _ => {
370 if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
371 let _ = event_tx.send(BridgeEvent::ToClasp(clasp_msg)).await;
372 }
373 }
374 }
375 }
376 Some(Err(e)) => {
377 error!("WebSocket error: {}", e);
378 break;
379 }
380 None => {
381 warn!("WebSocket connection closed");
382 break;
383 }
384 }
385 }
386 msg = send_rx.recv() => {
388 if let Some(ws_msg) = msg {
389 if let Err(e) = write.send(ws_msg).await {
390 error!("WebSocket send error: {}", e);
391 break;
392 }
393 }
394 }
395 _ = async {
397 if let Some(ref mut interval) = ping_interval {
398 interval.tick().await
399 } else {
400 std::future::pending::<tokio::time::Instant>().await
402 }
403 } => {
404 if awaiting_pong {
405 warn!("Ping timeout - no pong received");
406 break;
407 }
408 if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
409 error!("Failed to send ping: {}", e);
410 break;
411 }
412 awaiting_pong = true;
413 debug!("Sent ping");
414 }
415 _ = shutdown_rx.recv() => {
417 info!("WebSocket shutting down");
418 let _ = write.close().await;
419 *running.lock() = false;
420 return;
421 }
422 }
423 }
424
425 *running.lock() = false;
426 let _ = event_tx
427 .send(BridgeEvent::Disconnected {
428 reason: Some("Connection closed".to_string()),
429 })
430 .await;
431 }
432 Err(e) => {
433 error!("WebSocket connection failed: {}", e);
434 let _ = event_tx
435 .send(BridgeEvent::Error(format!("Connection failed: {}", e)))
436 .await;
437 }
438 }
439
440 if !auto_reconnect {
441 *running.lock() = false;
442 return;
443 }
444
445 info!("Reconnecting in {} seconds...", reconnect_delay);
446 tokio::time::sleep(std::time::Duration::from_secs(reconnect_delay as u64)).await;
447 }
448 }
449
450 async fn run_server(
452 addr: SocketAddr,
453 format: WsMessageFormat,
454 namespace: String,
455 ping_interval_secs: u32,
456 event_tx: mpsc::Sender<BridgeEvent>,
457 mut send_rx: mpsc::Receiver<WsMessage>,
458 mut shutdown_rx: mpsc::Receiver<()>,
459 running: Arc<Mutex<bool>>,
460 ) {
461 use std::sync::atomic::{AtomicU64, Ordering};
462 use tokio::sync::RwLock;
463
464 let listener = match TcpListener::bind(addr).await {
465 Ok(l) => l,
466 Err(e) => {
467 error!("Failed to bind WebSocket server: {}", e);
468 let _ = event_tx
469 .send(BridgeEvent::Error(format!("Bind failed: {}", e)))
470 .await;
471 return;
472 }
473 };
474
475 info!("WebSocket server listening on {}", addr);
476 *running.lock() = true;
477 let _ = event_tx.send(BridgeEvent::Connected).await;
478
479 let clients: Arc<RwLock<HashMap<u64, mpsc::Sender<WsMessage>>>> =
481 Arc::new(RwLock::new(HashMap::new()));
482 let next_client_id = Arc::new(AtomicU64::new(0));
483
484 loop {
485 tokio::select! {
486 result = listener.accept() => {
487 match result {
488 Ok((stream, peer_addr)) => {
489 let client_id = next_client_id.fetch_add(1, Ordering::SeqCst);
490 info!("WebSocket client {} connected: {}", client_id, peer_addr);
491
492 let format = format;
493 let namespace = namespace.clone();
494 let event_tx = event_tx.clone();
495 let clients = clients.clone();
496 let ping_interval = ping_interval_secs;
497
498 let (client_tx, mut client_rx) = mpsc::channel::<WsMessage>(100);
500 clients.write().await.insert(client_id, client_tx);
501
502 tokio::spawn(async move {
503 if let Ok(ws_stream) = accept_async(stream).await {
504 let (mut write, mut read) = ws_stream.split();
505
506 let ping_duration = if ping_interval > 0 {
508 Some(std::time::Duration::from_secs(ping_interval as u64))
509 } else {
510 None
511 };
512 let mut ping_timer = ping_duration.map(tokio::time::interval);
513 let mut awaiting_pong = false;
514
515 loop {
516 tokio::select! {
517 msg = read.next() => {
519 match msg {
520 Some(Ok(ws_msg)) => {
521 match &ws_msg {
522 WsMessage::Pong(_) => {
523 awaiting_pong = false;
524 debug!("Client {} pong received", client_id);
525 }
526 WsMessage::Ping(data) => {
527 if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
528 debug!("Failed to send pong to client {}: {}", client_id, e);
529 break;
530 }
531 }
532 _ => {
533 if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
534 let _ = event_tx.send(BridgeEvent::ToClasp(clasp_msg)).await;
535 }
536 }
537 }
538 }
539 Some(Err(e)) => {
540 debug!("WebSocket client {} error: {}", client_id, e);
541 break;
542 }
543 None => break,
544 }
545 }
546 msg = client_rx.recv() => {
548 match msg {
549 Some(ws_msg) => {
550 if let Err(e) = write.send(ws_msg).await {
551 debug!("Failed to send to client {}: {}", client_id, e);
552 break;
553 }
554 }
555 None => break,
556 }
557 }
558 _ = async {
560 if let Some(ref mut timer) = ping_timer {
561 timer.tick().await
562 } else {
563 std::future::pending::<tokio::time::Instant>().await
564 }
565 } => {
566 if awaiting_pong {
567 warn!("Client {} ping timeout", client_id);
568 break;
569 }
570 if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
571 debug!("Failed to send ping to client {}: {}", client_id, e);
572 break;
573 }
574 awaiting_pong = true;
575 }
576 }
577 }
578 }
579
580 clients.write().await.remove(&client_id);
582 info!("WebSocket client {} disconnected: {}", client_id, peer_addr);
583 });
584 }
585 Err(e) => {
586 error!("WebSocket accept error: {}", e);
587 }
588 }
589 }
590 msg = send_rx.recv() => {
592 if let Some(ws_msg) = msg {
593 let client_list: Vec<_> = clients.read().await.values().cloned().collect();
595 for tx in client_list {
596 let _ = tx.send(ws_msg.clone()).await;
597 }
598 }
599 }
600 _ = shutdown_rx.recv() => {
601 info!("WebSocket server shutting down");
602 break;
603 }
604 }
605 }
606
607 *running.lock() = false;
608 let _ = event_tx
609 .send(BridgeEvent::Disconnected {
610 reason: Some("Server stopped".to_string()),
611 })
612 .await;
613 }
614}
615
616#[async_trait]
617impl Bridge for WebSocketBridge {
618 fn config(&self) -> &BridgeConfig {
619 &self.config
620 }
621
622 async fn start(&mut self) -> Result<mpsc::Receiver<BridgeEvent>> {
623 if *self.running.lock() {
624 return Err(BridgeError::Other("Bridge already running".to_string()));
625 }
626
627 let (event_tx, event_rx) = mpsc::channel(100);
628 let (send_tx, send_rx) = mpsc::channel(100);
629 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
630
631 self.send_tx = Some(send_tx);
632 self.shutdown_tx = Some(shutdown_tx);
633
634 let running = self.running.clone();
635 let ws_config = self.ws_config.clone();
636
637 match ws_config.mode {
638 WsMode::Client => {
639 tokio::spawn(Self::run_client(
640 ws_config.url,
641 ws_config.format,
642 ws_config.namespace,
643 ws_config.auto_reconnect,
644 ws_config.reconnect_delay_secs,
645 ws_config.ping_interval_secs,
646 event_tx,
647 send_rx,
648 shutdown_rx,
649 running,
650 ));
651 }
652 WsMode::Server => {
653 let addr: SocketAddr = ws_config
654 .url
655 .parse()
656 .map_err(|e| BridgeError::Other(format!("Invalid address: {}", e)))?;
657
658 tokio::spawn(Self::run_server(
659 addr,
660 ws_config.format,
661 ws_config.namespace,
662 ws_config.ping_interval_secs,
663 event_tx,
664 send_rx,
665 shutdown_rx,
666 running,
667 ));
668 }
669 }
670
671 info!("WebSocket bridge started in {:?} mode", self.ws_config.mode);
672 Ok(event_rx)
673 }
674
675 async fn stop(&mut self) -> Result<()> {
676 *self.running.lock() = false;
677 if let Some(tx) = self.shutdown_tx.take() {
678 let _ = tx.send(()).await;
679 }
680 self.send_tx = None;
681 info!("WebSocket bridge stopped");
682 Ok(())
683 }
684
685 async fn send(&self, msg: Message) -> Result<()> {
686 let send_tx = self
687 .send_tx
688 .as_ref()
689 .ok_or_else(|| BridgeError::Other("Not connected".to_string()))?;
690
691 if let Some(ws_msg) = Self::message_to_ws(&msg, self.ws_config.format) {
692 send_tx
693 .send(ws_msg)
694 .await
695 .map_err(|e| BridgeError::Other(format!("WebSocket send failed: {}", e)))?;
696 }
697
698 Ok(())
699 }
700
701 fn is_running(&self) -> bool {
702 *self.running.lock()
703 }
704
705 fn namespace(&self) -> &str {
706 &self.ws_config.namespace
707 }
708}
709
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a plain SET message for the outbound-encoding tests.
    fn set(address: &str, value: Value) -> Message {
        Message::Set(SetMessage {
            address: address.to_string(),
            value,
            revision: None,
            lock: false,
            unlock: false,
        })
    }

    #[test]
    fn test_config_default() {
        let config = WebSocketBridgeConfig::default();
        assert_eq!(config.mode, WsMode::Client);
        assert_eq!(config.namespace, "/ws");
        assert_eq!(config.format, WsMessageFormat::Json);
        assert_eq!(config.ping_interval_secs, 30);
        assert!(config.auto_reconnect);
        assert_eq!(config.reconnect_delay_secs, 5);
        assert!(config.path.is_none());
        assert!(config.headers.is_empty());
    }

    #[test]
    fn test_config_serde_defaults_match_default_impl() {
        let parsed: WebSocketBridgeConfig =
            serde_json::from_str(r#"{"url": "ws://localhost:8080"}"#)
                .expect("a config with only `url` should deserialize");
        let expected = WebSocketBridgeConfig::default();
        assert_eq!(parsed.mode, expected.mode);
        assert_eq!(parsed.url, expected.url);
        assert_eq!(parsed.format, expected.format);
        assert_eq!(parsed.ping_interval_secs, expected.ping_interval_secs);
        assert_eq!(parsed.auto_reconnect, expected.auto_reconnect);
        assert_eq!(parsed.reconnect_delay_secs, expected.reconnect_delay_secs);
        assert_eq!(parsed.namespace, expected.namespace);
    }

    #[test]
    fn test_enum_serde_names() {
        let format: WsMessageFormat = serde_json::from_str(r#""msgpack""#).unwrap();
        assert_eq!(format, WsMessageFormat::MsgPack);
        let mode: WsMode = serde_json::from_str(r#""server""#).unwrap();
        assert_eq!(mode, WsMode::Server);
    }

    #[test]
    fn test_message_formats() {
        let prefix = "/ws";

        // JSON object with explicit address and value.
        let ws_msg = WsMessage::Text(r#"{"address": "/test", "value": 42}"#.to_string());
        let clasp = WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix);
        assert!(matches!(
            &clasp,
            Some(Message::Set(s)) if s.address == "/test" && matches!(s.value, Value::Int(42))
        ));

        // Non-JSON text falls back to `<prefix>/text`.
        let ws_msg = WsMessage::Text("hello".to_string());
        let clasp = WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix);
        assert!(matches!(
            &clasp,
            Some(Message::Set(s)) if s.address == "/ws/text"
                && matches!(&s.value, Value::String(v) if v == "hello")
        ));

        // Binary in raw mode becomes bytes on `<prefix>/binary`.
        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
        let clasp = WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Raw, prefix);
        assert!(matches!(
            &clasp,
            Some(Message::Set(s)) if s.address == "/ws/binary"
                && matches!(&s.value, Value::Bytes(b) if *b == vec![1u8, 2, 3])
        ));
    }

    #[test]
    fn test_json_without_address_uses_namespace_and_data() {
        let ws_msg = WsMessage::Text(r#"{"data": "payload"}"#.to_string());
        let clasp = WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, "/ws");
        assert!(matches!(
            &clasp,
            Some(Message::Set(s)) if s.address == "/ws/message"
                && matches!(&s.value, Value::String(v) if v == "payload")
        ));
    }

    #[test]
    fn test_control_frames_are_ignored() {
        let ws_msg = WsMessage::Ping(vec![]);
        assert!(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, "/ws").is_none());
    }

    #[test]
    fn test_outbound_encoding() {
        let out = WebSocketBridge::message_to_ws(&set("/a", Value::Int(1)), WsMessageFormat::Json);
        let text = match out {
            Some(WsMessage::Text(text)) => text,
            other => panic!("expected a text frame, got {:?}", other),
        };
        let json: serde_json::Value = serde_json::from_str(&text).unwrap();
        assert_eq!(json["address"], "/a");

        let raw = WebSocketBridge::message_to_ws(
            &set("/a", Value::String("hi".to_string())),
            WsMessageFormat::Raw,
        );
        assert!(matches!(raw, Some(WsMessage::Text(ref t)) if t == "hi"));

        let raw_bytes =
            WebSocketBridge::message_to_ws(&set("/a", Value::Bytes(vec![7, 8])), WsMessageFormat::Raw);
        assert!(matches!(raw_bytes, Some(WsMessage::Binary(ref b)) if *b == vec![7u8, 8]));
    }
}